Tsedee commited on
Commit
fe35ea5
Β·
verified Β·
1 Parent(s): 25ea3a2

Upload run_finetune.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. run_finetune.py +312 -0
run_finetune.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MonSub Whisper Full Fine-tune β€” RunPod A100 80GB.
3
+
4
+ Base: Tsedee/whisper-large-v2-mn-monsub (existing Mongolian fine-tune)
5
+ Data: 27 hours (4 HF datasets combined)
6
+ Method: Full fine-tune (all 1.5B params)
7
+ Output: Tsedee/whisper-large-v2-mn-monsub-v2
8
+
9
+ Proven techniques:
10
+ - Mixed precision bf16
11
+ - Gradient checkpointing (saves VRAM)
12
+ - Linear warmup + cosine decay
13
+ - Mongolian forced_decoder_ids
14
+ - Eval every 500 steps, save best by WER
15
+ """
16
+ import os
17
+ import sys
18
+ import torch
19
+ from dataclasses import dataclass
20
+ from datasets import load_dataset, concatenate_datasets, Audio
21
+ from transformers import (
22
+ WhisperForConditionalGeneration,
23
+ WhisperProcessor,
24
+ WhisperFeatureExtractor,
25
+ WhisperTokenizer,
26
+ Seq2SeqTrainingArguments,
27
+ Seq2SeqTrainer,
28
+ GenerationConfig,
29
+ )
30
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
31
+ import evaluate
32
+
33
+ # ═══════════════════════════════════════════════════════════
34
+ # CONFIG
35
+ # ═══════════════════════════════════════════════════════════
36
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
37
+ BASE_MODEL = "Tsedee/whisper-large-v2-mn-monsub"
38
+ OUTPUT_MODEL = "Tsedee/whisper-large-v2-mn-monsub-v2"
39
+ OUTPUT_DIR = "/workspace/monsub-finetune"
40
+
41
+ # Training hyperparams (proven for Whisper full fine-tune)
42
+ BATCH_SIZE = 8 # A40 48GB
43
+ GRAD_ACCUM = 4 # effective batch = 32
44
+ LEARNING_RATE = 5e-6 # low LR for continued fine-tune (not from scratch)
45
+ WARMUP_STEPS = 500
46
+ NUM_EPOCHS = 3
47
+ EVAL_STEPS = 500
48
+ SAVE_STEPS = 500
49
+ MAX_LABEL_LENGTH = 448
50
+ LANGUAGE = "mn"
51
+ TASK = "transcribe"
52
+
53
+ # Datasets to combine
54
+ DATASETS = [
55
+ {"name": "Tsedee/monsub-chimege-10h", "split": "train", "text_col": "sentence"},
56
+ {"name": "Tsedee/monsub-mongolian-asr", "split": "train", "text_col": "sentence"},
57
+ {"name": "Tsedee/mongolian-cv20-normalized", "split": "train", "text_col": "sentence"},
58
+ {"name": "Tsedee/monsub-chimege-youtube-9h", "split": "train", "text_col": "sentence"},
59
+ ]
60
+
61
+ normalizer = BasicTextNormalizer()
62
+
63
+
64
+ # ═══════════════════════════════════════════════════════════
65
+ # DATA LOADING
66
+ # ═══════════════════════════════════════════════════════════
67
+ def load_all_datasets():
68
+ """Load and combine all datasets."""
69
+ print("=" * 60)
70
+ print("LOADING DATASETS")
71
+ print("=" * 60)
72
+
73
+ all_ds = []
74
+ total_hours = 0
75
+
76
+ for ds_info in DATASETS:
77
+ name = ds_info["name"]
78
+ text_col = ds_info["text_col"]
79
+ print(f"\n Loading {name}...", flush=True)
80
+ try:
81
+ ds = load_dataset(name, split=ds_info["split"], token=HF_TOKEN)
82
+ # Normalize column names
83
+ if text_col != "sentence" and text_col in ds.column_names:
84
+ ds = ds.rename_column(text_col, "sentence")
85
+ # Ensure audio column
86
+ if "audio" in ds.column_names:
87
+ ds = ds.cast_column("audio", Audio(sampling_rate=16000))
88
+
89
+ # Calculate duration
90
+ if "duration" in ds.column_names:
91
+ hours = sum(ds["duration"]) / 3600
92
+ else:
93
+ hours = len(ds) * 10 / 3600 # estimate ~10s per sample
94
+
95
+ total_hours += hours
96
+ print(f" βœ… {len(ds)} samples, ~{hours:.1f}h", flush=True)
97
+ all_ds.append(ds)
98
+ except Exception as e:
99
+ print(f" ❌ Failed: {e}", flush=True)
100
+
101
+ # Combine
102
+ combined = concatenate_datasets(all_ds)
103
+ print(f"\n TOTAL: {len(combined)} samples, ~{total_hours:.1f} hours")
104
+
105
+ # Train/test split (95/5)
106
+ split = combined.train_test_split(test_size=0.05, seed=42)
107
+ print(f" Train: {len(split['train'])}, Test: {len(split['test'])}")
108
+ return split["train"], split["test"]
109
+
110
+
111
+ # ═══════════════════════════════════════════════════════════
112
+ # DATA PROCESSING
113
+ # ═══════════════════════════════════════════════════════════
114
+ def prepare_dataset(batch, processor):
115
+ """Process a batch: audio β†’ features, text β†’ labels."""
116
+ audio = batch["audio"]
117
+ inputs = processor.feature_extractor(
118
+ audio["array"], sampling_rate=audio["sampling_rate"]
119
+ )
120
+ batch["input_features"] = inputs.input_features[0]
121
+
122
+ # Tokenize text
123
+ text = batch["sentence"]
124
+ if not text or len(text.strip()) < 1:
125
+ text = " "
126
+ batch["labels"] = processor.tokenizer(text).input_ids
127
+ return batch
128
+
129
+
130
+ @dataclass
131
+ class DataCollatorSpeechSeq2SeqWithPadding:
132
+ processor: any
133
+ decoder_start_token_id: int
134
+
135
+ def __call__(self, features):
136
+ input_features = [{"input_features": f["input_features"]} for f in features]
137
+ batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
138
+
139
+ label_features = [{"input_ids": f["labels"]} for f in features]
140
+ labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
141
+
142
+ labels = labels_batch["input_ids"].masked_fill(
143
+ labels_batch.attention_mask.ne(1), -100
144
+ )
145
+
146
+ # Remove decoder_start_token_id from labels
147
+ if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
148
+ labels = labels[:, 1:]
149
+
150
+ batch["labels"] = labels
151
+ return batch
152
+
153
+
154
+ # ═══════════════════════════════════════════════════════════
155
+ # METRICS
156
+ # ═══════════════════════════════════════════════════════════
157
+ wer_metric = evaluate.load("wer")
158
+
159
+
160
+ def compute_metrics(pred, tokenizer):
161
+ pred_ids = pred.predictions
162
+ label_ids = pred.label_ids
163
+
164
+ # Replace -100 with pad
165
+ label_ids[label_ids == -100] = tokenizer.pad_token_id
166
+
167
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
168
+ label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
169
+
170
+ # Normalize
171
+ pred_str = [normalizer(p) for p in pred_str]
172
+ label_str = [normalizer(l) for l in label_str]
173
+
174
+ wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)
175
+ return {"wer": wer}
176
+
177
+
178
+ # ═══════════════════════════════════════════════════════════
179
+ # MAIN
180
+ # ═══════════════════════════════════════════════════════════
181
+ def main():
182
+ print("=" * 60)
183
+ print("MonSub Whisper Full Fine-tune")
184
+ print(f"Base: {BASE_MODEL}")
185
+ print(f"Output: {OUTPUT_MODEL}")
186
+ print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
187
+ print(f"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f}GB" if torch.cuda.is_available() else "")
188
+ print("=" * 60)
189
+
190
+ # Load model + processor
191
+ print("\nπŸ“¦ Loading model...", flush=True)
192
+ processor = WhisperProcessor.from_pretrained(BASE_MODEL, token=HF_TOKEN)
193
+ model = WhisperForConditionalGeneration.from_pretrained(
194
+ BASE_MODEL, token=HF_TOKEN, torch_dtype=torch.bfloat16
195
+ )
196
+
197
+ # Fix generation config from base whisper-large-v2
198
+ print(" Fixing generation_config...", flush=True)
199
+ base_gc = GenerationConfig.from_pretrained("openai/whisper-large-v2")
200
+ model.generation_config = base_gc
201
+
202
+ # Set Mongolian
203
+ model.generation_config.forced_decoder_ids = processor.get_decoder_prompt_ids(
204
+ language=LANGUAGE, task=TASK
205
+ )
206
+ model.config.forced_decoder_ids = model.generation_config.forced_decoder_ids
207
+ model.config.suppress_tokens = []
208
+
209
+ # Enable gradient checkpointing (saves VRAM)
210
+ model.config.use_cache = False
211
+ model.gradient_checkpointing_enable()
212
+
213
+ print(f" Model params: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M", flush=True)
214
+
215
+ # Load data
216
+ train_ds, eval_ds = load_all_datasets()
217
+
218
+ # Process datasets
219
+ print("\nπŸ”„ Processing datasets...", flush=True)
220
+ train_ds = train_ds.map(
221
+ lambda x: prepare_dataset(x, processor),
222
+ remove_columns=train_ds.column_names,
223
+ num_proc=4,
224
+ )
225
+ eval_ds = eval_ds.map(
226
+ lambda x: prepare_dataset(x, processor),
227
+ remove_columns=eval_ds.column_names,
228
+ num_proc=4,
229
+ )
230
+
231
+ # Filter too-long labels
232
+ train_ds = train_ds.filter(lambda x: len(x["labels"]) < MAX_LABEL_LENGTH)
233
+ eval_ds = eval_ds.filter(lambda x: len(x["labels"]) < MAX_LABEL_LENGTH)
234
+ print(f" After filter: train={len(train_ds)}, eval={len(eval_ds)}", flush=True)
235
+
236
+ # Data collator
237
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(
238
+ processor=processor,
239
+ decoder_start_token_id=model.config.decoder_start_token_id,
240
+ )
241
+
242
+ # Training args
243
+ training_args = Seq2SeqTrainingArguments(
244
+ output_dir=OUTPUT_DIR,
245
+ per_device_train_batch_size=BATCH_SIZE,
246
+ per_device_eval_batch_size=8,
247
+ gradient_accumulation_steps=GRAD_ACCUM,
248
+ learning_rate=LEARNING_RATE,
249
+ warmup_steps=WARMUP_STEPS,
250
+ num_train_epochs=NUM_EPOCHS,
251
+ bf16=True,
252
+ evaluation_strategy="steps",
253
+ eval_steps=EVAL_STEPS,
254
+ save_strategy="steps",
255
+ save_steps=SAVE_STEPS,
256
+ save_total_limit=3,
257
+ load_best_model_at_end=True,
258
+ metric_for_best_model="wer",
259
+ greater_is_better=False,
260
+ predict_with_generate=True,
261
+ generation_max_length=225,
262
+ logging_steps=50,
263
+ report_to="none",
264
+ dataloader_num_workers=4,
265
+ push_to_hub=False,
266
+ lr_scheduler_type="cosine",
267
+ weight_decay=0.01,
268
+ gradient_checkpointing=True,
269
+ remove_unused_columns=False,
270
+ )
271
+
272
+ # Trainer
273
+ trainer = Seq2SeqTrainer(
274
+ args=training_args,
275
+ model=model,
276
+ train_dataset=train_ds,
277
+ eval_dataset=eval_ds,
278
+ data_collator=data_collator,
279
+ compute_metrics=lambda pred: compute_metrics(pred, processor.tokenizer),
280
+ tokenizer=processor.feature_extractor,
281
+ )
282
+
283
+ # Train!
284
+ print("\nπŸš€ TRAINING STARTED!", flush=True)
285
+ print(f" Epochs: {NUM_EPOCHS}")
286
+ print(f" Batch: {BATCH_SIZE} Γ— {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}")
287
+ print(f" LR: {LEARNING_RATE}")
288
+ print(f" Eval every: {EVAL_STEPS} steps")
289
+ print("=" * 60, flush=True)
290
+
291
+ trainer.train()
292
+
293
+ # Save best model
294
+ print("\nπŸ’Ύ Saving best model...", flush=True)
295
+ trainer.save_model(f"{OUTPUT_DIR}/best")
296
+ processor.save_pretrained(f"{OUTPUT_DIR}/best")
297
+
298
+ # Upload to HuggingFace
299
+ print(f"\nπŸ“€ Uploading to {OUTPUT_MODEL}...", flush=True)
300
+ model.push_to_hub(OUTPUT_MODEL, token=HF_TOKEN, private=True)
301
+ processor.push_to_hub(OUTPUT_MODEL, token=HF_TOKEN, private=True)
302
+
303
+ # Also upload generation_config
304
+ model.generation_config.save_pretrained(f"{OUTPUT_DIR}/best")
305
+
306
+ print(f"\n{'=' * 60}")
307
+ print(f"βœ… DONE! Model: https://huggingface.co/{OUTPUT_MODEL}")
308
+ print(f"{'=' * 60}")
309
+
310
+
311
+ if __name__ == "__main__":
312
+ main()