Automatic Speech Recognition
Transformers
Vietnamese
vietnamese
whisper
speech-to-text
File size: 8,687 Bytes
5763d9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch>=2.0.0",
#     "torchaudio>=2.0.0",
#     "transformers>=4.46.0",
#     "accelerate>=0.26.0",
#     "datasets>=2.14.0",
#     "evaluate>=0.4.0",
#     "click>=8.0.0",
#     "tqdm>=4.60.0",
#     "wandb>=0.15.0",
#     "python-dotenv>=1.0.0",
#     "jiwer>=3.0.0",
#     "huggingface_hub>=0.20.0",
# ]
# ///
"""
Training script for ASR-1 Vietnamese Speech Recognition.

Fine-tunes OpenAI Whisper on Vietnamese speech datasets.

Usage:
    uv run src/train.py
    uv run src/train.py --base-model openai/whisper-large-v3
    uv run src/train.py --dataset vivos
    uv run src/train.py --wandb --wandb-project asr-1
"""

import sys
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Dict, List, Union

from dotenv import load_dotenv
load_dotenv()

import torch
import click
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from datasets import Audio
import evaluate

sys.path.insert(0, str(Path(__file__).parent.parent))
from src.data import load_common_voice, load_vivos, prepare_dataset


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Data collator for Whisper speech-to-text training."""

    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Pad labels
        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 for loss computation
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # Remove BOS token if present (Whisper adds it during generation)
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch


def compute_metrics(pred, processor, wer_metric, cer_metric):
    """Compute WER and CER, both as percentages, for generated predictions.

    Args:
        pred: Trainer prediction output exposing ``predictions`` (generated
            token ids, or a tuple whose first element holds them) and
            ``label_ids``.
        processor: Whisper processor whose ``tokenizer`` decodes token ids.
        wer_metric: Object with ``compute(predictions=..., references=...)``
            returning the word error rate as a fraction (scaled by 100 here).
        cer_metric: Same interface, returning the character error rate.

    Returns:
        ``{"wer": float, "cer": float}`` expressed in percent.
    """
    pad_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        # Some model outputs wrap the generated ids in a tuple; only the ids are decoded.
        pred_ids = pred_ids[0]
    label_ids = pred.label_ids

    # -100 marks ignored label positions (set by the data collator). The
    # Trainer may also use it to pad generated predictions of different
    # lengths when concatenating eval batches. The tokenizer cannot decode
    # negative ids, so map -100 to the pad token in BOTH arrays. Copy first
    # so the caller's arrays are not mutated.
    pred_ids = pred_ids.copy()
    pred_ids[pred_ids == -100] = pad_id
    label_ids = label_ids.copy()
    label_ids[label_ids == -100] = pad_id

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)
    cer = 100 * cer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer, "cer": cer}


def _load_raw_datasets(dataset, cache_dir):
    """Return ``(train_ds, eval_ds)`` for the ``--dataset`` choice.

    ``common_voice`` and ``vivos`` use their own train/eval splits. ``both``
    trains on the concatenation of the two train splits and evaluates on the
    Common Voice validation split only.
    """
    if dataset == "common_voice":
        return (
            load_common_voice("train", cache_dir=cache_dir),
            load_common_voice("validation", cache_dir=cache_dir),
        )
    if dataset == "vivos":
        return (
            load_vivos("train", cache_dir=cache_dir),
            load_vivos("test", cache_dir=cache_dir),
        )
    # dataset == "both"
    from datasets import concatenate_datasets
    cv_train = load_common_voice("train", cache_dir=cache_dir)
    vivos_train = load_vivos("train", cache_dir=cache_dir)
    train_ds = concatenate_datasets([cv_train, vivos_train])
    eval_ds = load_common_voice("validation", cache_dir=cache_dir)
    return train_ds, eval_ds


@click.command()
@click.option('--base-model', default='openai/whisper-large-v3', help='Base Whisper model')
@click.option('--dataset', type=click.Choice(['common_voice', 'vivos', 'both']), default='common_voice',
              help='Training dataset')
@click.option('--output', '-o', default='models/asr-1', help='Output directory')
@click.option('--epochs', default=3, type=int, help='Number of training epochs')
@click.option('--batch-size', default=8, type=int, help='Per-device batch size')
@click.option('--grad-accum', default=2, type=int, help='Gradient accumulation steps')
@click.option('--lr', default=1e-5, type=float, help='Learning rate')
@click.option('--warmup-steps', default=500, type=int, help='Warmup steps')
@click.option('--max-steps', default=-1, type=int, help='Max training steps (-1 for epoch-based)')
@click.option('--fp16/--no-fp16', default=True, help='Use mixed precision')
@click.option('--wandb', 'use_wandb', is_flag=True, help='Enable W&B logging')
@click.option('--wandb-project', default='asr-1', envvar='WANDB_PROJECT', show_envvar=True,
              help='W&B project name')
@click.option('--push-to-hub', is_flag=True, help='Push model to HuggingFace Hub')
@click.option('--hub-model-id', default='undertheseanlp/asr-1', help='HuggingFace Hub model ID')
@click.option('--eval-steps', default=500, type=int, help='Evaluate every N steps')
@click.option('--save-steps', default=500, type=int, help='Save checkpoint every N steps')
@click.option('--cache-dir', default=None, help='Dataset cache directory')
def train(base_model, dataset, output, epochs, batch_size, grad_accum, lr,
          warmup_steps, max_steps, fp16, use_wandb, wandb_project, push_to_hub,
          hub_model_id, eval_steps, save_steps, cache_dir):
    """Train ASR-1 Vietnamese Speech Recognition model.

    Loads a Whisper checkpoint and fine-tunes it on the selected dataset(s)
    with ``Seq2SeqTrainer``. It saves the best checkpoint (lowest eval WER)
    and the processor to ``output``, prints final WER/CER, and optionally
    pushes the result to the Hugging Face Hub.
    """
    import os

    device = "cuda" if torch.cuda.is_available() else "cpu"
    click.echo(f"Using device: {device}")

    click.echo("=" * 60)
    click.echo("ASR-1: Vietnamese Automatic Speech Recognition")
    click.echo("=" * 60)

    # The transformers W&B integration reads the project name from the
    # WANDB_PROJECT environment variable. Without this export, --wandb-project
    # would be parsed but silently ignored.
    if use_wandb:
        os.environ["WANDB_PROJECT"] = wandb_project

    # Load processor and model
    click.echo(f"\nLoading base model: {base_model}")
    processor = WhisperProcessor.from_pretrained(base_model)
    model = WhisperForConditionalGeneration.from_pretrained(base_model)

    # Force Vietnamese language and transcription task
    model.generation_config.language = "vi"
    model.generation_config.task = "transcribe"
    model.generation_config.forced_decoder_ids = None

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    click.echo(f"  Parameters: {n_params:,}")

    # Load datasets
    click.echo(f"\nLoading dataset: {dataset}")
    train_ds, eval_ds = _load_raw_datasets(dataset, cache_dir)

    click.echo(f"  Train: {len(train_ds)} samples")
    click.echo(f"  Eval: {len(eval_ds)} samples")

    # Prepare datasets (raw columns are dropped; prepare_dataset produces the
    # model inputs consumed by the collator).
    click.echo("\nPreparing datasets...")
    train_ds = train_ds.map(
        lambda batch: prepare_dataset(batch, processor),
        remove_columns=train_ds.column_names,
    )
    eval_ds = eval_ds.map(
        lambda batch: prepare_dataset(batch, processor),
        remove_columns=eval_ds.column_names,
    )

    # Data collator
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

    # Metrics
    wer_metric = evaluate.load("wer")
    cer_metric = evaluate.load("cer")

    # Training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=output,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=grad_accum,
        learning_rate=lr,
        warmup_steps=warmup_steps,
        max_steps=max_steps,
        num_train_epochs=epochs,
        fp16=fp16 and torch.cuda.is_available(),
        eval_strategy="steps",
        eval_steps=eval_steps,
        save_strategy="steps",
        save_steps=save_steps,
        logging_steps=25,
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
        predict_with_generate=True,
        generation_max_length=225,
        report_to="wandb" if use_wandb else "none",
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id if push_to_hub else None,
        save_total_limit=3,
        dataloader_num_workers=4,
        remove_unused_columns=False,
    )

    # Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=data_collator,
        processing_class=processor.feature_extractor,
        compute_metrics=lambda pred: compute_metrics(pred, processor, wer_metric, cer_metric),
    )

    # Train. A positive max_steps overrides num_train_epochs, so report
    # whichever budget is actually in effect.
    if max_steps > 0:
        click.echo(f"\nTraining for {max_steps} steps...")
    else:
        click.echo(f"\nTraining for {epochs} epochs...")
    trainer.train()

    # Save best model (load_best_model_at_end restored it after training)
    click.echo(f"\nSaving model to {output}")
    trainer.save_model(output)
    processor.save_pretrained(output)

    # Final evaluation
    click.echo("\nFinal evaluation...")
    metrics = trainer.evaluate()
    click.echo(f"  WER: {metrics['eval_wer']:.2f}%")
    click.echo(f"  CER: {metrics['eval_cer']:.2f}%")

    click.echo(f"\nModel saved to: {output}")

    if push_to_hub:
        click.echo(f"Pushing to HuggingFace Hub: {hub_model_id}")
        trainer.push_to_hub()


if __name__ == "__main__":
    # Script entry point: the click command parses sys.argv itself.
    train()