# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch>=2.0.0",
#     "torchaudio>=2.0.0",
#     "transformers>=4.36.0",
#     "datasets>=2.14.0",
#     "evaluate>=0.4.0",
#     "click>=8.0.0",
#     "tqdm>=4.60.0",
#     "wandb>=0.15.0",
#     "python-dotenv>=1.0.0",
#     "jiwer>=3.0.0",
#     "huggingface_hub>=0.20.0",
# ]
# ///
"""
Training script for ASR-1 Vietnamese Speech Recognition.

Fine-tunes OpenAI Whisper on Vietnamese speech datasets.

Usage:
    uv run src/train.py
    uv run src/train.py --base-model openai/whisper-large-v3
    uv run src/train.py --dataset vivos
    uv run src/train.py --wandb --wandb-project asr-1
"""
import sys
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Dict, List, Union

from dotenv import load_dotenv

load_dotenv()

import torch
import click
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from datasets import Audio
import evaluate

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.data import load_common_voice, load_vivos, prepare_dataset


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Data collator for Whisper speech-to-text training."""

    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # Split inputs and labels
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Pad labels
        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 for loss computation
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # Remove BOS token if present (Whisper adds it during generation)
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch


def compute_metrics(pred, processor, wer_metric, cer_metric):
    """Compute WER and CER metrics."""
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # Replace -100 with pad token so labels can be decoded
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)
    cer = 100 * cer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer, "cer": cer}


@click.command()
@click.option('--base-model', default='openai/whisper-large-v3', help='Base Whisper model')
@click.option('--dataset', type=click.Choice(['common_voice', 'vivos', 'both']),
              default='common_voice', help='Training dataset')
@click.option('--output', '-o', default='models/asr-1', help='Output directory')
@click.option('--epochs', default=3, type=int, help='Number of training epochs')
@click.option('--batch-size', default=8, type=int, help='Per-device batch size')
@click.option('--grad-accum', default=2, type=int, help='Gradient accumulation steps')
@click.option('--lr', default=1e-5, type=float, help='Learning rate')
@click.option('--warmup-steps', default=500, type=int, help='Warmup steps')
@click.option('--max-steps', default=-1, type=int, help='Max training steps (-1 for epoch-based)')
@click.option('--fp16/--no-fp16', default=True, help='Use mixed precision')
@click.option('--wandb', 'use_wandb', is_flag=True, help='Enable W&B logging')
@click.option('--wandb-project', default='asr-1', help='W&B project name')
@click.option('--push-to-hub', is_flag=True, help='Push model to HuggingFace Hub')
@click.option('--hub-model-id', default='undertheseanlp/asr-1', help='HuggingFace Hub model ID')
@click.option('--eval-steps', default=500, type=int, help='Evaluate every N steps')
@click.option('--save-steps', default=500, type=int, help='Save checkpoint every N steps')
@click.option('--cache-dir', default=None, help='Dataset cache directory')
def train(base_model, dataset, output, epochs, batch_size, grad_accum, lr,
          warmup_steps, max_steps, fp16, use_wandb, wandb_project, push_to_hub,
          hub_model_id, eval_steps, save_steps, cache_dir):
    """Train ASR-1 Vietnamese Speech Recognition model."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    click.echo(f"Using device: {device}")

    click.echo("=" * 60)
    click.echo("ASR-1: Vietnamese Automatic Speech Recognition")
    click.echo("=" * 60)

    # Load processor and model
    click.echo(f"\nLoading base model: {base_model}")
    processor = WhisperProcessor.from_pretrained(base_model)
    model = WhisperForConditionalGeneration.from_pretrained(base_model)

    # Force Vietnamese language and transcription task
    model.generation_config.language = "vi"
    model.generation_config.task = "transcribe"
    model.generation_config.forced_decoder_ids = None

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    click.echo(f"  Parameters: {n_params:,}")

    # Load datasets
    click.echo(f"\nLoading dataset: {dataset}")
    if dataset == "common_voice":
        train_ds = load_common_voice("train", cache_dir=cache_dir)
        eval_ds = load_common_voice("validation", cache_dir=cache_dir)
    elif dataset == "vivos":
        train_ds = load_vivos("train", cache_dir=cache_dir)
        eval_ds = load_vivos("test", cache_dir=cache_dir)
    else:  # both
        from datasets import concatenate_datasets
        cv_train = load_common_voice("train", cache_dir=cache_dir)
        vivos_train = load_vivos("train", cache_dir=cache_dir)
        train_ds = concatenate_datasets([cv_train, vivos_train])
        eval_ds = load_common_voice("validation", cache_dir=cache_dir)

    click.echo(f"  Train: {len(train_ds)} samples")
    click.echo(f"  Eval: {len(eval_ds)} samples")

    # Prepare datasets
    click.echo("\nPreparing datasets...")
    train_ds = train_ds.map(
        lambda batch: prepare_dataset(batch, processor),
        remove_columns=train_ds.column_names,
    )
    eval_ds = eval_ds.map(
        lambda batch: prepare_dataset(batch, processor),
        remove_columns=eval_ds.column_names,
    )

    # Data collator
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

    # Metrics
    wer_metric = evaluate.load("wer")
    cer_metric = evaluate.load("cer")

    # Training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=output,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=grad_accum,
        learning_rate=lr,
        warmup_steps=warmup_steps,
        max_steps=max_steps,
        num_train_epochs=epochs,
        fp16=fp16 and torch.cuda.is_available(),
        eval_strategy="steps",
        eval_steps=eval_steps,
        save_strategy="steps",
        save_steps=save_steps,
        logging_steps=25,
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
        predict_with_generate=True,
        generation_max_length=225,
        report_to="wandb" if use_wandb else "none",
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id if push_to_hub else None,
        save_total_limit=3,
        dataloader_num_workers=4,
        remove_unused_columns=False,
    )

    # Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=data_collator,
        processing_class=processor.feature_extractor,
        compute_metrics=lambda pred: compute_metrics(pred, processor, wer_metric, cer_metric),
    )

    # Train
    click.echo(f"\nTraining for {epochs} epochs...")
    trainer.train()

    # Save best model
    click.echo(f"\nSaving model to {output}")
    trainer.save_model(output)
    processor.save_pretrained(output)

    # Final evaluation
    click.echo("\nFinal evaluation...")
    metrics = trainer.evaluate()
    click.echo(f"  WER: {metrics['eval_wer']:.2f}%")
    click.echo(f"  CER: {metrics['eval_cer']:.2f}%")

    click.echo(f"\nModel saved to: {output}")

    if push_to_hub:
        click.echo(f"Pushing to HuggingFace Hub: {hub_model_id}")
        trainer.push_to_hub()


if __name__ == '__main__':
    train()
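
# ---------------------------------------------------------------------------
# Example (sketch): a quick sanity check of the fine-tuned checkpoint. This
# assumes the model and processor were saved to models/asr-1 by this script,
# and that `audio` is a hypothetical 16 kHz mono float array loaded by the
# caller; adjust paths and loading to your setup.
#
#     import torch
#     from transformers import WhisperProcessor, WhisperForConditionalGeneration
#
#     processor = WhisperProcessor.from_pretrained("models/asr-1")
#     model = WhisperForConditionalGeneration.from_pretrained("models/asr-1")
#     inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
#     with torch.no_grad():
#         ids = model.generate(inputs.input_features, language="vi", task="transcribe")
#     print(processor.batch_decode(ids, skip_special_tokens=True)[0])
# ---------------------------------------------------------------------------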