| |
| |
|
|
| import os |
| os.environ['TOKENIZERS_PARALLELISM'] = 'true' |
| import re |
| import sys |
| import glob |
| import json |
| import numpy as np |
| import pandas as pd |
| import torch |
| from dataclasses import dataclass, field |
| from typing import Any, Dict, List, Optional, Union |
| from argparse import ArgumentParser |
|
|
| import evaluate |
| from evaluate import load |
| from datasets import Dataset, Audio, load_dataset, ClassLabel |
| from transformers import Wav2Vec2CTCTokenizer |
| from transformers import Wav2Vec2FeatureExtractor |
| from transformers import Wav2Vec2Processor |
| from transformers import Wav2Vec2ForCTC |
| from transformers import TrainingArguments |
| from transformers import Trainer |
| from safetensors.torch import save_file as safe_save_file |
| from transformers.models.wav2vec2.modeling_wav2vec2 import WAV2VEC2_ADAPTER_SAFE_FILE |
|
|
| |
| |
|
|
# Command-line interface. Each plain option is declared once in the table
# below as (flag, default, type, help) and registered in that order.
_CLI_OPTIONS = (
    ('--input_dir', './', str, 'Directory with a training dataset'),
    ('--output_dir', 'models-1', str, 'Directory to save model checkpoints'),
    ('--lang', 'ady', str, 'Language'),
    ('--model_name', 'facebook/mms-1b-l1107', str, 'Pretrained model'),
    ('--attn_implementation', 'flash_attention_2', str, 'Attention implementation'),
    ('--n_epochs', 30, int, 'Number of epochs to train'),
    ('--batch_size', 2, int, 'Batch size'),
    ('--accum', 1, int, 'Number of steps for gradient accumulation'),
    ('--lr', 1e-3, float, 'Learning rate'),
    ('--num_workers', os.cpu_count(), int, 'Number of workers'),
    ('--reduce_p', 2, int, 'Patience for learning rate reduction'),
    ('--reduce_f', 0.5, float, 'Factor for learning rate reduction'),
    ('--reduce_mode', 'min', str, 'Mode (min/max) for learning rate reduction'),
    ('--max_length', None, int, 'Audio max length in frames (duration in seconds by 16_000)'),
)

parser = ArgumentParser()
for flag, default, value_type, help_text in _CLI_OPTIONS:
    parser.add_argument(flag, default=default, type=value_type, help=help_text)
# --truncation is the only option restricted to a fixed set of values (0 or 1).
parser.add_argument('--truncation', default=0, type=int, choices=[0, 1], help='Truncation')
args = parser.parse_args()

# Echo the effective configuration, one "name value" line per option.
for name, value in vars(args).items():
    if '__' not in name:
        print('%-25s %s' % (name, value))
|
|
| |
| |
| |
|
|
# Patterns used by clean(); compiled once at module import.
bracketed = re.compile(r"\[[^\]]+\]")      # anything inside square brackets
unintell_paren = re.compile(r"\(\?+\)")    # "(?)", "(??)", ...
repl_punc = re.compile('[,?¿¡!";:]+')      # runs of punctuation replaced by a space
multispace = re.compile(" +")              # runs of spaces


def clean(t):
    """
    Official cleaning function

    Normalise one transcript string and return the cleaned string.

    Steps, in order: bracketed spans and "(?)"-style markers become a space,
    a free-standing " ... " becomes a single space, the literal "#x27;" becomes
    an apostrophe, runs of the punctuation in ``repl_punc`` become a space,
    single periods become a space while "..." is kept, and finally runs of
    spaces are collapsed to one. Leading/trailing spaces are not stripped.
    """
    text = bracketed.sub(" ", t)
    text = unintell_paren.sub(" ", text)
    text = text.replace(" ... ", " ").replace("#x27;", "'")
    text = repl_punc.sub(" ", text)
    # Shield "..." behind a placeholder so that only lone periods are blanked.
    # The placeholder's "!" is safe here because "!" was already removed above.
    text = text.replace("...", "!ELLIPSIS!")
    text = text.replace(".", " ")
    text = text.replace("!ELLIPSIS!", "...")
    return multispace.sub(" ", text)
|
|
|
|
|
|
@dataclass
class DataCollatorCTCWithPadding:
    """
    Collate a list of examples into one padded batch of tensors.

    Adapted from
    https://github.com/huggingface/transformers/blob/7e61d56a45c19284cfda0cee8995fb552f6b1f4e/
    examples/pytorch/speech-recognition/run_speech_recognition_ctc.py#L219

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor whose ``pad`` method is used for both inputs and labels.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Padding strategy forwarded unchanged to ``processor.pad``:
            * :obj:`True` or :obj:`'longest'`: pad to the longest sequence in the batch (no padding for a
              single sequence).
            * :obj:`'max_length'`: pad to the length given by :obj:`max_length`, or to the maximum acceptable
              input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: no padding (the batch may contain sequences of different
              lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad ``input_values`` and ``labels`` of ``features`` and return the batch.

        Each feature must provide the keys ``"input_values"`` and ``"labels"``.
        The returned batch is the padded input batch with an extra ``"labels"``
        entry in which padded label positions are set to -100.
        """
        # Inputs and labels are padded by two separate processor.pad calls.
        audio_items = []
        label_items = []
        for feature in features:
            audio_items.append({"input_values": feature["input_values"]})
            label_items.append({"input_ids": feature["labels"]})

        batch = self.processor.pad(audio_items, padding=self.padding, return_tensors="pt")
        padded_labels = self.processor.pad(labels=label_items, padding=self.padding, return_tensors="pt")

        # Positions whose attention mask is not 1 are padding; mark them with -100
        # (presumably so the loss skips them -- compute_metrics maps -100 back to pad).
        is_padding = padded_labels.attention_mask.ne(1)
        batch["labels"] = padded_labels["input_ids"].masked_fill(is_padding, -100)
        return batch
|
|
| |
| |
| |
|
|
# Checkpoints, vocab and adapter weights for this language go to <output_dir>/<lang>.
output_dir = os.path.join(args.output_dir, args.lang)
os.makedirs(output_dir, exist_ok=True)

# Locate and read the three Common Voice split tables (tab-separated).
train_tsv = os.path.join(args.input_dir, 'cv-corpus-23.0-2025-09-05/%s/train.tsv' % args.lang)
dev_tsv = os.path.join(args.input_dir, 'cv-corpus-23.0-2025-09-05/%s/dev.tsv' % args.lang)
test_tsv = os.path.join(args.input_dir, 'cv-corpus-23.0-2025-09-05/%s/test.tsv' % args.lang)
train_df = pd.read_csv(train_tsv, sep='\t')
dev_df = pd.read_csv(dev_tsv, sep='\t')
test_df = pd.read_csv(test_tsv, sep='\t')

# Tag every row with the split it came from, then stack the splits into one
# frame so the filters that follow run once over the whole corpus.
for split_name, split_df in (('train', train_df), ('dev', dev_df), ('test', test_df)):
    split_df['split'] = split_name
corpus_df = pd.concat([train_df, dev_df, test_df])
# Row count before any filtering; printed in the final summary.
corpus_size = len(corpus_df)
|
|
| |
# Remove rows whose sentence_id appears in reported.tsv. The filter is skipped
# when the file is missing or zero bytes long; selector_reported then holds a
# one-element placeholder array.
reported_file = os.path.join(args.input_dir, 'cv-corpus-23.0-2025-09-05/%s/reported.tsv' % args.lang)
has_reported = os.path.exists(reported_file) and os.path.getsize(reported_file) > 0
if not has_reported:
    selector_reported = np.array([0])
else:
    reported_df = pd.read_csv(reported_file, sep='\t')
    reported_ids = reported_df['sentence_id'].values
    selector_reported = corpus_df['sentence_id'].isin(reported_ids)
    corpus_df = corpus_df[~selector_reported]

# Same procedure for invalidated.tsv, applied to the already filtered corpus.
invalidated_file = os.path.join(args.input_dir, 'cv-corpus-23.0-2025-09-05/%s/invalidated.tsv' % args.lang)
has_invalidated = os.path.exists(invalidated_file) and os.path.getsize(invalidated_file) > 0
if not has_invalidated:
    selector_invalidated = np.array([0])
else:
    invalidated_df = pd.read_csv(invalidated_file, sep='\t')
    invalidated_ids = invalidated_df['sentence_id'].values
    selector_invalidated = corpus_df['sentence_id'].isin(invalidated_ids)
    corpus_df = corpus_df[~selector_invalidated]
|
|
| |
# Clips listed with a duration of 0 ms. NOTE: this mask is only computed here;
# it is not used to filter corpus_df in this section.
clip_durations_file = os.path.join(args.input_dir, 'cv-corpus-23.0-2025-09-05/%s/clip_durations.tsv' % args.lang)
clip_durations_df = pd.read_csv(clip_durations_file, sep='\t')
is_zero_ms = clip_durations_df['duration[ms]'].eq(0)
zero_duration_df = clip_durations_df[is_zero_ms].copy()
selector_zero_duration = corpus_df['path'].isin(zero_duration_df['clip'].values)

# Drop every row whose audio path occurs more than once
# (keep=False marks all occurrences, so no copy survives).
selector_dup_file = corpus_df.duplicated(subset=['path'], keep=False)
corpus_df = corpus_df[~selector_dup_file]

# Drop rows without a transcript, then rows whose transcript is empty.
selector_no_trans = corpus_df['sentence'].isna()
corpus_df = corpus_df[~selector_no_trans]
selector_zero_len_trans = corpus_df['sentence'].map(len).eq(0)
corpus_df = corpus_df[~selector_zero_len_trans]

# Rows with zero up-votes. NOTE: computed only; not applied as a filter here.
selector_non_voted = corpus_df['up_votes'].eq(0)

# Normalise the transcripts with clean().
corpus_df['sentence'] = corpus_df['sentence'].map(clean)

# Drop every row whose cleaned transcript occurs more than once (all copies).
selector_dup_trans = corpus_df.duplicated(subset=['sentence'], keep=False)
corpus_df = corpus_df[~selector_dup_trans]
|
|
| |
# Full path of each clip: <input_dir>/cv-corpus-23.0-2025-09-05/<lang>/clips/<path>.
clips_dir = os.path.join(args.input_dir, 'cv-corpus-23.0-2025-09-05/%s/clips' % args.lang)
corpus_df['file'] = [os.path.join(clips_dir, clip_name) for clip_name in corpus_df['path']]

# The vocabulary is later built from the 'transcription' column.
corpus_df['transcription'] = corpus_df['sentence']

# Final split: the original train and dev rows are used for training,
# the original test rows are used for evaluation.
in_training_part = corpus_df['split'].isin(['train', 'dev'])
train_df = corpus_df[in_training_part].copy()
dev_df = corpus_df[corpus_df['split'] == 'test'].copy()

# Number of client_id values present in both parts (computed, not printed here).
train_speakers = set(train_df['client_id'])
dev_speakers = set(dev_df['client_id'])
n_intersected_speakers = len(train_speakers & dev_speakers)

print('CORPUS raw size:', corpus_size)
print('FINAL. Train: %d Dev: %d' % (len(train_df), len(dev_df)))
|
|
| |
| |
| |
|
|
def _build_audio_dataset(split_df):
    """Turn one split frame into an audio ``Dataset``.

    ``split_df`` must have the columns ``'file'`` and ``'sentence'``. Returns
    ``(frame, dataset)`` where ``frame`` has the columns ``audio``,
    ``sentence`` and ``path`` (``audio`` and ``path`` both hold the file path)
    and ``dataset`` is built from it with ``audio`` cast to
    ``Audio(sampling_rate=16000)``.

    Raises:
        ValueError: if the split contains no rows.
        FileNotFoundError: if the first clip of the split does not exist on disk.
    """
    frame = split_df[['file', 'sentence']].copy()
    frame['path'] = frame['file']
    frame.columns = ['audio', 'sentence', 'path']

    # Explicit checks instead of `assert`: asserts vanish under `python -O`,
    # and an empty split would otherwise fail with a bare IndexError.
    if frame.empty:
        raise ValueError('Split contains no rows after filtering')
    first_path = frame.iloc[0]['path']
    if not os.path.exists(first_path):
        raise FileNotFoundError('Cannot find .mp3 file: %s' % first_path)

    dataset = Dataset.from_pandas(frame, preserve_index=False)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    return frame, dataset


# Training part (original train + dev rows).
tr_df, common_voice_train = _build_audio_dataset(train_df)
print(common_voice_train)

# Evaluation part (original test rows).
te_df, common_voice_test = _build_audio_dataset(dev_df)
print(common_voice_test)
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
# Every character that occurs in any transcript of the filtered corpus
# (all splits), duplicates included.
chars = [char for line in corpus_df['transcription'] for char in line]

# Distinct characters without the space and without "|": "|" is added below
# as a fixed entry of the vocabulary. discard() is used for both so that a
# corpus without any space (or without "|") does not raise ValueError.
unique_chars = set(chars)
unique_chars.discard(" ")
unique_chars.discard("|")

# Fixed special tokens first, then the corpus characters in sorted order;
# a token's id is its position in this list.
vocab_list = ["<pad>", "<s>", "</s>", "<unk>", "|"] + sorted(unique_chars)
vocab_dict = {token: index for index, token in enumerate(vocab_list)}

# The vocabulary is stored nested under the language code; the tokenizer is
# later loaded from this directory with target_lang=args.lang.
nested_vocab_dict = {args.lang: vocab_dict}

with open(os.path.join(output_dir, 'vocab.json'), 'wt', encoding='utf-8') as vocab_file:
    json.dump(nested_vocab_dict, vocab_file)
|
|
| |
| |
| |
|
|
# Tokenizer loaded from output_dir, where vocab.json was written, and bound
# to the current language.
tokenizer_options = dict(
    pad_token="<pad>",
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    word_delimiter_token="|",
    do_lower_case=False,
    target_lang=args.lang,
)
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(output_dir, **tokenizer_options)

# Feature extractor for 16 kHz single-feature input, with normalisation and
# attention masks enabled.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

# One object that wraps both the feature extractor and the tokenizer.
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

# Report the ids of the special tokens and the size of this language's vocab.
print('Tokenizer spec:')
for id_attribute in ('pad_token_id', 'bos_token_id', 'eos_token_id', 'unk_token_id', 'word_delimiter_token_id'):
    print(id_attribute + ':', getattr(tokenizer, id_attribute))
print('vocab_size:', len(tokenizer.vocab[args.lang]))
|
|
| |
| |
| |
|
|
def prepare_dataset(batch):
    """Convert one dataset example into model inputs and label ids.

    Reads ``batch["audio"]`` (its ``"array"`` and ``"sampling_rate"``) and
    ``batch["sentence"]``; adds ``"input_values"``, ``"input_length"`` and
    ``"labels"`` and returns the same mapping. ``--max_length`` and
    ``--truncation`` are forwarded to the processor for the audio.

    NOTE(review): map() is called without batched=True below, so this
    presumably receives a single example despite the parameter name.
    """
    audio = batch["audio"]
    audio_features = processor(
        audio["array"],
        sampling_rate=audio["sampling_rate"],
        max_length=args.max_length,
        truncation=bool(args.truncation),
    )
    # The processor output is indexed with [0] to get this example's values.
    batch["input_values"] = audio_features.input_values[0]
    batch["input_length"] = len(batch["input_values"])
    batch["labels"] = processor(text=batch["sentence"]).input_ids
    return batch


# Dataset.map only accepts a positive num_proc (or None for the main process),
# so `--num_workers 0` is mapped to None instead of raising.
map_num_proc = args.num_workers or None

# Replace the raw columns with the prepared ones in both datasets.
common_voice_train = common_voice_train.map(
    prepare_dataset,
    remove_columns=common_voice_train.column_names,
    num_proc=map_num_proc,
)
common_voice_test = common_voice_test.map(
    prepare_dataset,
    remove_columns=common_voice_test.column_names,
    num_proc=map_num_proc,
)
print(common_voice_train)
print(common_voice_test)
|
|
| |
| |
| |
|
|
# Collator that pads each batch via the processor.
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)


# Word error rate metric used by compute_metrics.
wer_metric = load("wer")


def compute_metrics(pred):
    """Return ``{"wer": ...}`` for one evaluation pass.

    ``pred.predictions`` holds the logits and ``pred.label_ids`` the reference
    ids. Predicted ids are the argmax over the last axis. ``pred.label_ids``
    is modified in place: every -100 entry is replaced by the pad token id
    before decoding.
    """
    predicted_ids = np.argmax(pred.predictions, axis=-1)

    # In place: restore the pad token id where the collator wrote -100.
    ignored_positions = pred.label_ids == -100
    pred.label_ids[ignored_positions] = processor.tokenizer.pad_token_id

    hypotheses = processor.batch_decode(predicted_ids)
    # References are decoded with group_tokens=False.
    references = processor.batch_decode(pred.label_ids, group_tokens=False)

    return {"wer": wer_metric.compute(predictions=hypotheses, references=references)}
|
|
| |
| |
| |
|
|
# Load the pretrained checkpoint with a CTC head sized to the new tokenizer.
# ignore_mismatched_sizes=True is passed, presumably because vocab_size
# differs from the checkpoint's own head -- TODO confirm.
model_options = dict(
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    ignore_mismatched_sizes=True,
    attn_implementation=args.attn_implementation,
)
model = Wav2Vec2ForCTC.from_pretrained(args.model_name, **model_options)


# Adapter setup: call init_adapter_layers() and freeze_base_model(), then
# switch requires_grad back on for every tensor returned by _get_adapters().
model.init_adapter_layers()
model.freeze_base_model()
adapter_weights = model._get_adapters()
for adapter_param in adapter_weights.values():
    adapter_param.requires_grad = True
|
|
| |
| |
| |
|
|
# DataLoader worker count taken from --num_workers. A missing value (None,
# e.g. when os.cpu_count() cannot determine the CPU count) is treated as 0.
dataloader_workers = args.num_workers or 0

training_args = TrainingArguments(
    output_dir=output_dir,
    seed=42,
    data_seed=None,
    group_by_length=False,

    # Data loading. prefetch_factor is only set when worker processes exist;
    # TrainingArguments rejects a prefetch factor combined with 0 workers.
    dataloader_num_workers=dataloader_workers,
    dataloader_pin_memory=True,
    dataloader_prefetch_factor=2 if dataloader_workers > 0 else None,
    torch_empty_cache_steps=None,

    # Training batches.
    per_device_train_batch_size=args.batch_size,
    gradient_accumulation_steps=args.accum,
    auto_find_batch_size=False,

    # Evaluation batches.
    per_device_eval_batch_size=args.batch_size,
    eval_accumulation_steps=None,

    # Optimisation. The plateau scheduler is configured from the --reduce_*
    # options (mode, factor, patience).
    optim='adamw_torch_fused',
    weight_decay=0.05,
    learning_rate=args.lr,
    lr_scheduler_type='reduce_lr_on_plateau',
    lr_scheduler_kwargs={'mode': args.reduce_mode, 'factor': args.reduce_f, 'patience': args.reduce_p},
    warmup_steps=20,
    num_train_epochs=args.n_epochs,
    gradient_checkpointing=True,
    fp16=True,
    bf16=False,
    fp16_opt_level='O1',
    half_precision_backend='auto',
    bf16_full_eval=False,
    fp16_full_eval=False,

    # Evaluate and log once per epoch.
    eval_strategy='epoch',
    logging_strategy='epoch',

    # Checkpointing: keep one checkpoint, selected by the lowest 'wer'
    # (greater_is_better=False), and reload the best model when training ends.
    save_strategy='best',
    save_total_limit=1,
    save_only_model=False,
    load_best_model_at_end=True,
    metric_for_best_model='wer',
    greater_is_better=False,
    push_to_hub=False,
)
|
|
| |
| |
| |
|
|
# Wire the model, data, collator and metric function into the Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor,
)


# Run training and keep the returned output object.
train_output = trainer.train()


# Report the best value of the tracked metric, framed by separator lines.
separator = '-'*20
print(separator)
print('best_metric:', trainer.state.best_metric)
print(separator)
|
|
| |
| |
| |
|
|
# Export the adapter tensors to a safetensors file in the training output
# directory. The file name comes from the WAV2VEC2_ADAPTER_SAFE_FILE template
# filled with the language code.
adapter_name = WAV2VEC2_ADAPTER_SAFE_FILE.format(args.lang)
adapter_file = os.path.join(training_args.output_dir, adapter_name)

adapter_tensors = model._get_adapters()
safe_save_file(adapter_tensors, adapter_file, metadata={"format": "pt"})
|
|
|
|
| |
| |
| |
| |
|
|
|
|