Automatic Speech Recognition
Transformers
Vietnamese
vietnamese
whisper
speech-to-text
asr-1 / src /train.py
rain1024's picture
Initial commit: ASR-1 Vietnamese speech recognition model
5763d9e
# /// script
# requires-python = ">=3.10"
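# dependencies = [
#   "torch>=2.0.0",
#   "torchaudio>=2.0.0",
#   "transformers>=4.46.0",  # eval_strategy / processing_class arguments
#   "accelerate>=0.26.0",  # required by the PyTorch Trainer
#   "datasets>=2.14.0",
#   "evaluate>=0.4.0",  # imported for the WER/CER metrics
#   "click>=8.0.0",
#   "tqdm>=4.60.0",
#   "wandb>=0.15.0",
#   "python-dotenv>=1.0.0",
#   "jiwer>=3.0.0",
#   "huggingface_hub>=0.20.0",
# ]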
# ///
"""
Training script for ASR-1 Vietnamese Speech Recognition.
Fine-tunes OpenAI Whisper on Vietnamese speech datasets.
Usage:
uv run src/train.py
uv run src/train.py --base-model openai/whisper-large-v3
uv run src/train.py --dataset vivos
uv run src/train.py --wandb --wandb-project asr-1
"""
import sys
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from dotenv import load_dotenv
load_dotenv()
import torch
import click
from transformers import (
WhisperProcessor,
WhisperForConditionalGeneration,
Seq2SeqTrainingArguments,
Seq2SeqTrainer,
)
from datasets import Audio
import evaluate
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.data import load_common_voice, load_vivos, prepare_dataset
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Data collator for Whisper speech-to-text training.

    Pads ``input_features`` with the processor's feature extractor and the
    ``labels`` token ids with its tokenizer. Label padding is replaced by
    ``-100`` so the loss ignores it.

    Attributes:
        processor: Whisper processor exposing ``feature_extractor`` and
            ``tokenizer``.
        decoder_start_token_id: Token the model prepends to the decoder inputs
            when it shifts the labels right. When every label row already
            starts with it, that leading column is dropped so it is not seen
            twice. If ``None``, it is looked up from the tokenizer as
            ``<|startoftranscript|>`` (falling back to ``bos_token_id`` when
            the tokenizer has no such token).
    """

    processor: Any
    decoder_start_token_id: Union[int, None] = None

    def _start_token_id(self) -> Union[int, None]:
        """Return the token id that is stripped when it leads every label row."""
        if self.decoder_start_token_id is not None:
            return self.decoder_start_token_id
        tokenizer = self.processor.tokenizer
        # Whisper's tokenizer reports <|endoftext|> as its bos_token, but label
        # sequences begin with <|startoftranscript|>, so look that up explicitly.
        sot_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if sot_id is not None and sot_id != tokenizer.unk_token_id:
            return sot_id
        return tokenizer.bos_token_id

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate a list of examples into a padded training batch.

        Args:
            features: Examples, each holding ``input_features`` and ``labels``
                (a sequence of token ids).

        Returns:
            The feature extractor's padded batch with an extra ``labels``
            tensor in which padding positions are ``-100``.
        """
        # Audio features and token labels have different lengths and need
        # different padding, so pad them separately.
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 so these positions are ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # The model builds decoder inputs by shifting labels right and
        # prepending the decoder start token. If the labels already begin with
        # that token, drop it so it does not appear twice.
        start_id = self._start_token_id()
        if (
            labels.size(1) > 0
            and start_id is not None
            and bool((labels[:, 0] == start_id).all())
        ):
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
def compute_metrics(pred, processor, wer_metric, cer_metric):
    """Compute WER and CER metrics (as percentages) for an evaluation pass.

    Args:
        pred: Prediction object with ``predictions`` (generated token ids)
            and ``label_ids`` (reference token ids, ``-100`` where ignored).
            Neither array is modified.
        processor: Processor whose ``tokenizer`` decodes the token ids.
        wer_metric: Object exposing ``compute(predictions=..., references=...)``
            that returns the word error rate as a fraction.
        cer_metric: Same interface, returning the character error rate.

    Returns:
        ``{"wer": float, "cer": float}``, each scaled by 100.
    """
    import numpy as np

    pad_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    # Some models return (ids, extra outputs); only the ids are decoded.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 marks ignored label positions, and the Trainer also uses -100 to pad
    # generated sequences of different lengths when it concatenates eval
    # batches. Tokenizers cannot decode negative ids, so map them to padding.
    # np.where returns new arrays, leaving the caller's prediction untouched.
    pred_ids = np.where(np.asarray(pred_ids) == -100, pad_id, pred_ids)
    label_ids = np.where(np.asarray(pred.label_ids) == -100, pad_id, pred.label_ids)

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)
    cer = 100 * cer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer, "cer": cer}
def _load_datasets(dataset, cache_dir):
    """Return ``(train_ds, eval_ds)`` for the ``--dataset`` choice.

    ``"both"`` concatenates the Common Voice and VIVOS training splits but
    evaluates on the Common Voice validation split only.
    """
    if dataset == "common_voice":
        train_ds = load_common_voice("train", cache_dir=cache_dir)
        eval_ds = load_common_voice("validation", cache_dir=cache_dir)
    elif dataset == "vivos":
        train_ds = load_vivos("train", cache_dir=cache_dir)
        eval_ds = load_vivos("test", cache_dir=cache_dir)
    else:  # both
        from datasets import concatenate_datasets

        cv_train = load_common_voice("train", cache_dir=cache_dir)
        vivos_train = load_vivos("train", cache_dir=cache_dir)
        train_ds = concatenate_datasets([cv_train, vivos_train])
        eval_ds = load_common_voice("validation", cache_dir=cache_dir)
    return train_ds, eval_ds


def _prepare(ds, processor):
    """Map ``prepare_dataset`` over *ds*, keeping only the columns it returns."""
    return ds.map(
        lambda batch: prepare_dataset(batch, processor),
        remove_columns=ds.column_names,
    )


@click.command()
@click.option('--base-model', default='openai/whisper-large-v3', help='Base Whisper model')
@click.option('--dataset', type=click.Choice(['common_voice', 'vivos', 'both']), default='common_voice',
              help='Training dataset')
@click.option('--output', '-o', default='models/asr-1', help='Output directory')
@click.option('--epochs', default=3, type=int, help='Number of training epochs')
@click.option('--batch-size', default=8, type=int, help='Per-device batch size')
@click.option('--grad-accum', default=2, type=int, help='Gradient accumulation steps')
@click.option('--lr', default=1e-5, type=float, help='Learning rate')
@click.option('--warmup-steps', default=500, type=int, help='Warmup steps')
@click.option('--max-steps', default=-1, type=int, help='Max training steps (-1 for epoch-based)')
@click.option('--fp16/--no-fp16', default=True, help='Use mixed precision')
@click.option('--wandb', 'use_wandb', is_flag=True, help='Enable W&B logging')
@click.option('--wandb-project', default='asr-1', help='W&B project name')
@click.option('--push-to-hub', is_flag=True, help='Push model to HuggingFace Hub')
@click.option('--hub-model-id', default='undertheseanlp/asr-1', help='HuggingFace Hub model ID')
@click.option('--eval-steps', default=500, type=int, help='Evaluate every N steps')
@click.option('--save-steps', default=500, type=int, help='Save checkpoint every N steps')
@click.option('--cache-dir', default=None, help='Dataset cache directory')
def train(base_model, dataset, output, epochs, batch_size, grad_accum, lr,
          warmup_steps, max_steps, fp16, use_wandb, wandb_project, push_to_hub,
          hub_model_id, eval_steps, save_steps, cache_dir):
    """Train ASR-1 Vietnamese Speech Recognition model.

    \f
    Fine-tunes ``base_model`` on the selected dataset with ``Seq2SeqTrainer``.
    Evaluation and checkpointing run every ``eval_steps`` and ``save_steps``
    steps, and the checkpoint with the lowest WER is reloaded at the end.
    The model and processor are saved to ``output``, evaluated once more, and
    optionally pushed to the Hub. A positive ``max_steps`` overrides
    ``epochs``.
    """
    import os

    # load_best_model_at_end=True requires every checkpoint to coincide with
    # an evaluation. Check this before spending time on model download and
    # dataset preprocessing instead of failing inside Seq2SeqTrainingArguments.
    if eval_steps > 0 and save_steps % eval_steps != 0:
        raise click.UsageError(
            f"--save-steps ({save_steps}) must be a multiple of --eval-steps "
            f"({eval_steps}) because the best checkpoint is reloaded at the end."
        )

    # The W&B integration reads the project name from the environment;
    # without this the --wandb-project option would have no effect.
    if use_wandb:
        os.environ["WANDB_PROJECT"] = wandb_project

    device = "cuda" if torch.cuda.is_available() else "cpu"
    click.echo(f"Using device: {device}")
    click.echo("=" * 60)
    click.echo("ASR-1: Vietnamese Automatic Speech Recognition")
    click.echo("=" * 60)

    # Load processor and model
    click.echo(f"\nLoading base model: {base_model}")
    processor = WhisperProcessor.from_pretrained(base_model)
    model = WhisperForConditionalGeneration.from_pretrained(base_model)

    # Force Vietnamese language and transcription task during generation.
    model.generation_config.language = "vi"
    model.generation_config.task = "transcribe"
    model.generation_config.forced_decoder_ids = None

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    click.echo(f"  Parameters: {n_params:,}")

    # Load datasets
    click.echo(f"\nLoading dataset: {dataset}")
    train_ds, eval_ds = _load_datasets(dataset, cache_dir)
    click.echo(f"  Train: {len(train_ds)} samples")
    click.echo(f"  Eval: {len(eval_ds)} samples")

    # Prepare datasets
    click.echo("\nPreparing datasets...")
    train_ds = _prepare(train_ds, processor)
    eval_ds = _prepare(eval_ds, processor)

    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

    wer_metric = evaluate.load("wer")
    cer_metric = evaluate.load("cer")

    training_args = Seq2SeqTrainingArguments(
        output_dir=output,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=grad_accum,
        learning_rate=lr,
        warmup_steps=warmup_steps,
        max_steps=max_steps,
        num_train_epochs=epochs,
        fp16=fp16 and torch.cuda.is_available(),
        eval_strategy="steps",
        eval_steps=eval_steps,
        save_strategy="steps",
        save_steps=save_steps,
        logging_steps=25,
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
        predict_with_generate=True,
        generation_max_length=225,
        report_to="wandb" if use_wandb else "none",
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id if push_to_hub else None,
        save_total_limit=3,
        dataloader_num_workers=4,
        # Keep the columns produced by prepare_dataset; the custom collator
        # consumes them directly.
        remove_unused_columns=False,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=data_collator,
        processing_class=processor.feature_extractor,
        compute_metrics=lambda pred: compute_metrics(pred, processor, wer_metric, cer_metric),
    )

    # A positive max_steps overrides num_train_epochs, so report the real budget.
    if max_steps > 0:
        click.echo(f"\nTraining for {max_steps} steps...")
    else:
        click.echo(f"\nTraining for {epochs} epochs...")
    trainer.train()

    # Save the best model (reloaded because load_best_model_at_end=True).
    click.echo(f"\nSaving model to {output}")
    trainer.save_model(output)
    processor.save_pretrained(output)

    click.echo("\nFinal evaluation...")
    metrics = trainer.evaluate()
    click.echo(f"  WER: {metrics['eval_wer']:.2f}%")
    click.echo(f"  CER: {metrics['eval_cer']:.2f}%")

    click.echo(f"\nModel saved to: {output}")

    if push_to_hub:
        click.echo(f"Pushing to HuggingFace Hub: {hub_model_id}")
        trainer.push_to_hub()
# Invoke the click command only when run as a script (e.g. `uv run src/train.py`);
# importing this module does not start training.
if __name__ == "__main__":
    train()