Wav2Vec2-XLS-R-300M fine-tuned for Arabic

Wav2Vec2-XLS-R-300M fine-tuned for Arabic using Common Voice 11.0. When using the model, make sure the audio files are sampled at 16 kHz.

Evaluation

The model can be used directly (without a language model) as follows:

import re
import string

import torch
import torchaudio
from datasets import Audio, load_dataset, load_metric
from transformers import (Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor,
                          Wav2Vec2ForCTC, Wav2Vec2Processor)

punctuation = '''`รทร—ุ›<>_()*&^%][ู€ุŒ/:"ุŸ.,'{}~ยฆ+|!โ€โ€ฆโ€œโ€“ู€''' + string.punctuation
arabic_diacritics = re.compile("""
                             ู‘    | # Shadda
                             ูŽ    | # Fatha
                             ู‹    | # Tanwin Fath
                             ู    | # Damma
                             ูŒ    | # Tanwin Damm
                             ู    | # Kasra
                             ู    | # Tanwin Kasr
                             ู’    | # Sukun
                             ู€     # Tatwil/Kashid
                         """, re.VERBOSE)


def process_text(text):
    translator = str.maketrans('', '', punctuation)
    text = text.translate(translator)
    text = re.sub("[0123456789]", '', text)

    # remove Tashkeel
    text = re.sub(arabic_diacritics, '', text)

    # remove elongation
    text = re.sub("[ุฅุฃุขุง]", "ุง", text)
    text = re.sub("ู‰", "ูŠ", text)
    text = re.sub("ุค", "ุก", text)
    text = re.sub("ุฆ", "ุก", text)
    text = re.sub("ุฉ", "ู‡", text)
    text = re.sub("ฺฏ", "ูƒ", text)

    text = ' '.join(word for word in text.split())


# Build the processor (feature extractor + CTC tokenizer) and load the model.
# BUG FIX: the original hard-coded .to("cuda"), crashing on CPU-only machines,
# and never defined the `device` global that get_predictions() reads.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0,
                                             do_normalize=True, return_attention_mask=True)
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("aitor-alvarez/wav2vec2-xls-r-300m-ar", unk_token="[UNK]",
                                                 pad_token="[PAD]", word_delimiter_token="|")
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
model = Wav2Vec2ForCTC.from_pretrained("aitor-alvarez/wav2vec2-xls-r-300m-ar").to(device)


def prepare_dataset(batch):
  """Convert one Common Voice example into model inputs and CTC labels.

  Adds to the example dict:
    input_values: normalized waveform features from the processor.
    input_length: number of samples (useful for length-based filtering).
    labels: token ids of the normalized target sentence.
  """
  audio = batch["audio"]

  batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
  batch["input_length"] = len(batch["input_values"])

  # BUG FIX: the labels line was dedented out of the `with` block in the
  # original, which is an IndentationError (empty with-body); it must run
  # inside as_target_processor() so the tokenizer — not the feature
  # extractor — encodes the sentence.
  with processor.as_target_processor():
      batch["labels"] = processor(batch["sentence"]).input_ids
  return batch


def remove_ar_special_characters(batch):
    """Replace the example's transcript with its normalized, lowercased form."""
    cleaned = process_text(batch["sentence"])
    batch["sentence"] = cleaned.lower()
    return batch

# Evaluation data: the Arabic test split of Common Voice 11, resampled to the
# 16 kHz rate the model expects.
speech_test = load_dataset("mozilla-foundation/common_voice_11_0", "ar", split="test")
speech_test = speech_test.cast_column("audio", Audio(sampling_rate=16_000))
# Normalize reference transcripts first, then turn each example into model
# inputs; remove_columns drops the raw dataset columns so only the mapped
# outputs (input_values, input_length, labels) remain.
speech_test = speech_test.map(remove_ar_special_characters)
speech_test = speech_test.map(prepare_dataset, remove_columns=speech_test.column_names)


def get_predictions(batch):
  """Greedy (no language model) CTC decoding for one example.

  Adds `pred_txt` (decoded hypothesis) and `txt` (decoded reference) to the
  example so WER can be computed over the mapped dataset.
  """
  with torch.no_grad():
    input_dict = processor(batch["input_values"], return_tensors="pt", padding=True)
    # BUG FIX: the original referenced an undefined global `device`; use the
    # device the model actually lives on instead.
    device = model.device
    logits = model(input_dict.input_values.to(device),
                   attention_mask=input_dict.attention_mask.to(device)).logits
    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_txt"] = processor.batch_decode(pred_ids)[0]
    batch["txt"] = processor.decode(batch["labels"])
    return batch

# BUG FIX: `wer_metric` was never defined anywhere in the script.
wer_metric = load_metric("wer")
results = speech_test.map(get_predictions)
# The metric returns a fraction in [0, 1]; scale by 100 so the printout matches
# the reported "Test WER: 20.71" figure instead of "0.21".
print("Test WER: {:.2f}".format(100 * wer_metric.compute(predictions=results["pred_txt"], references=results["txt"])))

Test WER: 20.71 %

Downloads last month
9
Safetensors
Model size
0.3B params
Tensor type
F32
ยท
Inference Providers NEW
This model isn't deployed by any Inference Provider. ๐Ÿ™‹ Ask for provider support

Space using aitor-alvarez/wav2vec2-xls-r-300m-ar 1