|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Training script for ASR-1 Vietnamese Speech Recognition. |
|
|
|
|
|
Fine-tunes OpenAI Whisper on Vietnamese speech datasets. |
|
|
|
|
|
Usage: |
|
|
uv run src/train.py |
|
|
uv run src/train.py --base-model openai/whisper-large-v3 |
|
|
uv run src/train.py --dataset vivos |
|
|
uv run src/train.py --wandb --wandb-project asr-1 |
|
|
""" |
|
|
|
|
|
import sys
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Dict, List, Union

from dotenv import load_dotenv

# Load .env before importing torch/transformers/datasets.
# NOTE(review): presumably so environment variables (HF token, cache dirs,
# W&B settings) are already set when those libraries initialize — confirm.
load_dotenv()

import torch
import click
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from datasets import Audio
import evaluate

# Make the repository root importable so `src.*` resolves when this file is
# executed directly (e.g. `uv run src/train.py`) rather than as a module.
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.data import load_common_voice, load_vivos, prepare_dataset
|
|
|
|
|
|
|
|
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate Whisper examples: pad audio features and transcript labels.

    Audio inputs and token labels have different lengths, so each side is
    padded independently — features by the feature extractor, labels by the
    tokenizer. Padded label positions are set to -100 so the loss ignores them.
    """

    # Wrapped WhisperProcessor exposing .feature_extractor and .tokenizer.
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Pad the log-mel input features into a single batched tensor.
        audio_side = [{"input_features": item["input_features"]} for item in features]
        batch = self.processor.feature_extractor.pad(audio_side, return_tensors="pt")

        # Pad the tokenized transcripts separately via the tokenizer.
        text_side = [{"input_ids": item["labels"]} for item in features]
        padded = self.processor.tokenizer.pad(text_side, return_tensors="pt")

        # Mask out padding positions with -100 so they don't contribute to loss.
        labels = padded["input_ids"].masked_fill(padded.attention_mask.ne(1), -100)

        # If every sequence begins with BOS, strip it: the model prepends the
        # decoder start token itself during teacher forcing.
        starts_with_bos = (labels[:, 0] == self.processor.tokenizer.bos_token_id).all()
        if starts_with_bos.cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
|
|
|
|
|
|
|
|
def compute_metrics(pred, processor, wer_metric, cer_metric):
    """Decode predictions/labels and return WER and CER as percentages.

    Args:
        pred: EvalPrediction-like object with .predictions and .label_ids.
        processor: WhisperProcessor providing the tokenizer for decoding.
        wer_metric, cer_metric: loaded `evaluate` metrics with .compute().
    """
    predicted_ids = pred.predictions
    reference_ids = pred.label_ids

    # -100 marks loss-ignored positions; restore pad tokens so decoding works.
    # (Note: this mutates pred.label_ids in place, matching original behavior.)
    reference_ids[reference_ids == -100] = processor.tokenizer.pad_token_id

    decoded_preds = processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)
    decoded_refs = processor.tokenizer.batch_decode(reference_ids, skip_special_tokens=True)

    return {
        "wer": 100 * wer_metric.compute(predictions=decoded_preds, references=decoded_refs),
        "cer": 100 * cer_metric.compute(predictions=decoded_preds, references=decoded_refs),
    }
|
|
|
|
|
|
|
|
@click.command()
@click.option('--base-model', default='openai/whisper-large-v3', help='Base Whisper model')
@click.option('--dataset', type=click.Choice(['common_voice', 'vivos', 'both']), default='common_voice',
              help='Training dataset')
@click.option('--output', '-o', default='models/asr-1', help='Output directory')
@click.option('--epochs', default=3, type=int, help='Number of training epochs')
@click.option('--batch-size', default=8, type=int, help='Per-device batch size')
@click.option('--grad-accum', default=2, type=int, help='Gradient accumulation steps')
@click.option('--lr', default=1e-5, type=float, help='Learning rate')
@click.option('--warmup-steps', default=500, type=int, help='Warmup steps')
@click.option('--max-steps', default=-1, type=int, help='Max training steps (-1 for epoch-based)')
@click.option('--fp16/--no-fp16', default=True, help='Use mixed precision')
@click.option('--wandb', 'use_wandb', is_flag=True, help='Enable W&B logging')
@click.option('--wandb-project', default='asr-1', help='W&B project name')
@click.option('--push-to-hub', is_flag=True, help='Push model to HuggingFace Hub')
@click.option('--hub-model-id', default='undertheseanlp/asr-1', help='HuggingFace Hub model ID')
@click.option('--eval-steps', default=500, type=int, help='Evaluate every N steps')
@click.option('--save-steps', default=500, type=int, help='Save checkpoint every N steps')
@click.option('--cache-dir', default=None, help='Dataset cache directory')
def train(base_model, dataset, output, epochs, batch_size, grad_accum, lr,
          warmup_steps, max_steps, fp16, use_wandb, wandb_project, push_to_hub,
          hub_model_id, eval_steps, save_steps, cache_dir):
    """Train ASR-1 Vietnamese Speech Recognition model.

    Fine-tunes the given Whisper checkpoint on Common Voice and/or VIVOS,
    evaluates with WER/CER during training, saves the best model and the
    processor to ``output``, and optionally pushes to the HuggingFace Hub.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    click.echo(f"Using device: {device}")

    click.echo("=" * 60)
    click.echo("ASR-1: Vietnamese Automatic Speech Recognition")
    click.echo("=" * 60)

    # FIX: --wandb-project was accepted but never applied. The Trainer's W&B
    # integration reads the project name from the WANDB_PROJECT environment
    # variable, so export it here when logging is enabled.
    if use_wandb:
        import os
        os.environ["WANDB_PROJECT"] = wandb_project

    # --- Model & processor -------------------------------------------------
    click.echo(f"\nLoading base model: {base_model}")
    processor = WhisperProcessor.from_pretrained(base_model)
    model = WhisperForConditionalGeneration.from_pretrained(base_model)

    # Generate Vietnamese transcriptions. Clearing forced_decoder_ids lets the
    # generation_config language/task settings take effect during generation.
    model.generation_config.language = "vi"
    model.generation_config.task = "transcribe"
    model.generation_config.forced_decoder_ids = None

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    click.echo(f" Parameters: {n_params:,}")

    # --- Datasets ----------------------------------------------------------
    click.echo(f"\nLoading dataset: {dataset}")
    if dataset == "common_voice":
        train_ds = load_common_voice("train", cache_dir=cache_dir)
        eval_ds = load_common_voice("validation", cache_dir=cache_dir)
    elif dataset == "vivos":
        train_ds = load_vivos("train", cache_dir=cache_dir)
        eval_ds = load_vivos("test", cache_dir=cache_dir)
    else:
        # "both": merge the two training splits; evaluate on Common Voice
        # validation only.
        from datasets import concatenate_datasets
        cv_train = load_common_voice("train", cache_dir=cache_dir)
        vivos_train = load_vivos("train", cache_dir=cache_dir)
        train_ds = concatenate_datasets([cv_train, vivos_train])
        eval_ds = load_common_voice("validation", cache_dir=cache_dir)

    click.echo(f" Train: {len(train_ds)} samples")
    click.echo(f" Eval: {len(eval_ds)} samples")

    # Convert raw audio/text into model inputs (log-mel features + label ids).
    click.echo("\nPreparing datasets...")
    train_ds = train_ds.map(
        lambda batch: prepare_dataset(batch, processor),
        remove_columns=train_ds.column_names,
    )
    eval_ds = eval_ds.map(
        lambda batch: prepare_dataset(batch, processor),
        remove_columns=eval_ds.column_names,
    )

    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

    wer_metric = evaluate.load("wer")
    cer_metric = evaluate.load("cer")

    # --- Training configuration --------------------------------------------
    training_args = Seq2SeqTrainingArguments(
        output_dir=output,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=grad_accum,
        learning_rate=lr,
        warmup_steps=warmup_steps,
        max_steps=max_steps,  # -1 defers to num_train_epochs
        num_train_epochs=epochs,
        fp16=fp16 and torch.cuda.is_available(),  # fp16 requires CUDA
        eval_strategy="steps",
        eval_steps=eval_steps,
        save_strategy="steps",
        save_steps=save_steps,
        logging_steps=25,
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,  # lower WER is better
        predict_with_generate=True,
        generation_max_length=225,
        report_to="wandb" if use_wandb else "none",
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id if push_to_hub else None,
        save_total_limit=3,
        dataloader_num_workers=4,
        remove_unused_columns=False,  # the collator consumes the raw columns
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=data_collator,
        processing_class=processor.feature_extractor,
        compute_metrics=lambda pred: compute_metrics(pred, processor, wer_metric, cer_metric),
    )

    # --- Train, save, evaluate ---------------------------------------------
    click.echo(f"\nTraining for {epochs} epochs...")
    trainer.train()

    click.echo(f"\nSaving model to {output}")
    trainer.save_model(output)
    processor.save_pretrained(output)  # save tokenizer/feature extractor too

    click.echo("\nFinal evaluation...")
    metrics = trainer.evaluate()
    click.echo(f" WER: {metrics['eval_wer']:.2f}%")
    click.echo(f" CER: {metrics['eval_cer']:.2f}%")

    click.echo(f"\nModel saved to: {output}")

    if push_to_hub:
        click.echo(f"Pushing to HuggingFace Hub: {hub_model_id}")
        trainer.push_to_hub()
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # CLI entry point: Click parses argv and invokes train() with the options.
    train()
|
|
|