Tsedee commited on
Commit
3da699a
ยท
verified ยท
1 Parent(s): fe35ea5

Upload run_finetune_v3.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. run_finetune_v3.py +337 -0
run_finetune_v3.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MonSub Whisper v3 Fine-tune โ€” A40 48GB.
3
+
4
+ Continued fine-tune from Tsedee/whisper-large-v2-mn-monsub (v1).
5
+ Uses CER metric (Mongolian-ะด WER-ััั ะธะปาฏาฏ ั‚ะพั…ะธั€ะพะผะถั‚ะพะน).
6
+
7
+ ำจะผะฝำฉั… ะฑาฏั… ะฐะปะดะฐะฐะณ ะทะฐัะฐัะฐะฝ:
8
+ - processing_class (NOT tokenizer โ€” deprecated)
9
+ - datasets==2.21.0 (NOT latest โ€” torchcodec error)
10
+ - num_proc=1 (NOT 4 โ€” multiprocess audio decode ะณะฐั†ะฝะฐ)
11
+ - HF_HOME=/workspace/.cache (container disk ะดาฏาฏั€ัั…ะณาฏะน)
12
+ - generation_config fix (alignment_heads + no_timestamps_token_id)
13
+ - fp16 (A40 ะดััั€ ั‚ะพั…ะธั€ะพะผะถั‚ะพะน)
14
+ - eval crash handler
15
+ """
16
+ import os
17
+ import sys
18
+ import torch
19
+ import numpy as np
20
+ from dataclasses import dataclass
21
+ from datasets import load_dataset, concatenate_datasets, Audio
22
+ from transformers import (
23
+ WhisperForConditionalGeneration,
24
+ WhisperProcessor,
25
+ Seq2SeqTrainingArguments,
26
+ Seq2SeqTrainer,
27
+ GenerationConfig,
28
+ )
29
+ import evaluate
30
+
31
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
32
+ # CONFIG โ€” ะ40-ะด ะพะฝะพะฒั‡ะธะปัะพะฝ
33
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
34
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
35
+ BASE_MODEL = "Tsedee/whisper-large-v2-mn-monsub" # v1 ััƒัƒั€ัŒ ะผะพะดะตะป
36
+ OUTPUT_MODEL = "Tsedee/whisper-large-v2-mn-monsub-v3"
37
+ OUTPUT_DIR = "/workspace/monsub-finetune-v3"
38
+
39
+ # A40 48GB โ€” batch_size=16 ะฑะฐะณั‚ะฐะฝะฐ
40
+ BATCH_SIZE = 16
41
+ GRAD_ACCUM = 2 # effective batch = 32
42
+ LEARNING_RATE = 3e-6 # Continued fine-tune โ†’ ะฑะฐะณะฐ LR (ัˆะธะฝััั€ ะฑะพะป 1e-5)
43
+ WARMUP_STEPS = 300
44
+ MAX_STEPS = 4000 # ~30 ั†ะฐะณ ะดะฐั‚ะฐ โ†’ 4000 step ั…ะฐะฝะณะฐะปั‚ั‚ะฐะน
45
+ EVAL_STEPS = 500
46
+ SAVE_STEPS = 500
47
+ MAX_LABEL_LENGTH = 448
48
+ LANGUAGE = "mn"
49
+ TASK = "transcribe"
50
+
51
+ # Datasets โ€” mongolian-cv20-normalized ะฅะะกะกะะ (ั‡ะฐะฝะฐั€ ะผัƒัƒ)
52
+ DATASETS = [
53
+ {"name": "Tsedee/monsub-chimege-10h", "split": "train", "text_col": "sentence"},
54
+ {"name": "Tsedee/monsub-mongolian-asr", "split": "train", "text_col": "sentence"},
55
+ {"name": "Tsedee/monsub-chimege-youtube-9h", "split": "train", "text_col": "sentence"},
56
+ # ะัะผัะปั‚ dataset-าฏาฏะด (ั…ัั€ัะณั‚ัะน ะฑะพะป comment ะฐั€ะธะปะณะฐ):
57
+ # {"name": "Tsedee/mongolian-bible-speech", "split": "train", "text_col": "sentence"},
58
+ ]
59
+
60
+
61
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
62
+ # DATA LOADING
63
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
64
+ def load_all_datasets():
65
+ print("=" * 60)
66
+ print("LOADING DATASETS")
67
+ print("=" * 60)
68
+
69
+ all_ds = []
70
+ total_hours = 0
71
+
72
+ for ds_info in DATASETS:
73
+ name = ds_info["name"]
74
+ text_col = ds_info["text_col"]
75
+ print(f"\n Loading {name}...", flush=True)
76
+ try:
77
+ ds = load_dataset(name, split=ds_info["split"], token=HF_TOKEN)
78
+
79
+ # Normalize column names
80
+ if text_col != "sentence" and text_col in ds.column_names:
81
+ ds = ds.rename_column(text_col, "sentence")
82
+
83
+ # Ensure audio 16kHz
84
+ if "audio" in ds.column_names:
85
+ ds = ds.cast_column("audio", Audio(sampling_rate=16000))
86
+
87
+ # Calculate duration
88
+ if "duration" in ds.column_names:
89
+ hours = sum(ds["duration"]) / 3600
90
+ else:
91
+ hours = len(ds) * 10 / 3600
92
+
93
+ total_hours += hours
94
+ print(f" OK: {len(ds)} samples, ~{hours:.1f}h", flush=True)
95
+ all_ds.append(ds)
96
+ except Exception as e:
97
+ print(f" FAILED: {e}", flush=True)
98
+
99
+ if not all_ds:
100
+ print("ERROR: No datasets loaded!")
101
+ sys.exit(1)
102
+
103
+ combined = concatenate_datasets(all_ds)
104
+ print(f"\n TOTAL: {len(combined)} samples, ~{total_hours:.1f} hours")
105
+
106
+ # Train/test split (95/5)
107
+ split = combined.train_test_split(test_size=0.05, seed=42)
108
+ print(f" Train: {len(split['train'])}, Test: {len(split['test'])}")
109
+ return split["train"], split["test"]
110
+
111
+
112
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
113
+ # DATA PROCESSING
114
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
115
+ def prepare_dataset(batch, processor):
116
+ audio = batch["audio"]
117
+ inputs = processor.feature_extractor(
118
+ audio["array"], sampling_rate=audio["sampling_rate"]
119
+ )
120
+ batch["input_features"] = inputs.input_features[0]
121
+
122
+ text = batch["sentence"]
123
+ if not text or len(text.strip()) < 1:
124
+ text = " "
125
+ batch["labels"] = processor.tokenizer(text).input_ids
126
+ return batch
127
+
128
+
129
+ @dataclass
130
+ class DataCollatorSpeechSeq2SeqWithPadding:
131
+ processor: any
132
+ decoder_start_token_id: int
133
+
134
+ def __call__(self, features):
135
+ input_features = [{"input_features": f["input_features"]} for f in features]
136
+ batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
137
+
138
+ label_features = [{"input_ids": f["labels"]} for f in features]
139
+ labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
140
+
141
+ labels = labels_batch["input_ids"].masked_fill(
142
+ labels_batch.attention_mask.ne(1), -100
143
+ )
144
+
145
+ if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
146
+ labels = labels[:, 1:]
147
+
148
+ batch["labels"] = labels
149
+ return batch
150
+
151
+
152
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
153
+ # CER METRIC โ€” ะœะพะฝะณะพะป ั…ัะปัะฝะด WER-ััั ะธะปาฏาฏ ั‚ะพั…ะธั€ะพะผะถั‚ะพะน
154
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
155
+ cer_metric = evaluate.load("cer")
156
+
157
+
158
+ def compute_metrics(pred, tokenizer):
159
+ pred_ids = pred.predictions
160
+ label_ids = pred.label_ids
161
+
162
+ # Replace -100 with pad
163
+ label_ids[label_ids == -100] = tokenizer.pad_token_id
164
+
165
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
166
+ label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
167
+
168
+ # Filter empty pairs
169
+ pairs = [(p, l) for p, l in zip(pred_str, label_str) if l.strip()]
170
+ if not pairs:
171
+ return {"cer": 0.0}
172
+ pred_str, label_str = zip(*pairs)
173
+
174
+ cer = cer_metric.compute(predictions=list(pred_str), references=list(label_str))
175
+ return {"cer": cer}
176
+
177
+
178
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
179
+ # MAIN
180
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
181
+ def main():
182
+ print("=" * 60)
183
+ print("MonSub Whisper v3 Fine-tune")
184
+ print(f"Base: {BASE_MODEL}")
185
+ print(f"Output: {OUTPUT_MODEL}")
186
+
187
+ if torch.cuda.is_available():
188
+ gpu_name = torch.cuda.get_device_name(0)
189
+ vram = torch.cuda.get_device_properties(0).total_memory / 1e9
190
+ print(f"GPU: {gpu_name}")
191
+ print(f"VRAM: {vram:.1f}GB")
192
+ else:
193
+ print("WARNING: No GPU detected!")
194
+ print("=" * 60)
195
+
196
+ # โ”€โ”€ Load model + processor โ”€โ”€
197
+ print("\nLoading model...", flush=True)
198
+ processor = WhisperProcessor.from_pretrained(BASE_MODEL, token=HF_TOKEN)
199
+ model = WhisperForConditionalGeneration.from_pretrained(
200
+ BASE_MODEL, token=HF_TOKEN
201
+ )
202
+
203
+ # โ”€โ”€ generation_config fix โ”€โ”€
204
+ # alignment_heads + no_timestamps_token_id base-ััั ะฐะฒะฝะฐ
205
+ print(" Fixing generation_config from base whisper-large-v2...", flush=True)
206
+ base_gc = GenerationConfig.from_pretrained("openai/whisper-large-v2")
207
+ model.generation_config = base_gc
208
+
209
+ # Set Mongolian forced_decoder_ids
210
+ model.generation_config.forced_decoder_ids = processor.get_decoder_prompt_ids(
211
+ language=LANGUAGE, task=TASK
212
+ )
213
+ model.config.forced_decoder_ids = None # Training-ะด None
214
+ model.config.suppress_tokens = []
215
+ model.config.use_cache = False # Training-ะด ะทะฐะฐะฒะฐะป False
216
+
217
+ # Gradient checkpointing (VRAM ั…ัะผะฝัะฝั)
218
+ model.gradient_checkpointing_enable()
219
+
220
+ params_m = sum(p.numel() for p in model.parameters()) / 1e6
221
+ print(f" Model params: {params_m:.1f}M", flush=True)
222
+
223
+ # โ”€โ”€ Load data โ”€โ”€
224
+ train_ds, eval_ds = load_all_datasets()
225
+
226
+ # โ”€โ”€ Process datasets (num_proc=1 ะทะฐะฐะฒะฐะป!) โ”€โ”€
227
+ print("\nProcessing datasets (num_proc=1)...", flush=True)
228
+ train_ds = train_ds.map(
229
+ lambda x: prepare_dataset(x, processor),
230
+ remove_columns=train_ds.column_names,
231
+ num_proc=1, # NOT 4 โ€” multiprocess audio decode ะณะฐั†ะฝะฐ
232
+ )
233
+ eval_ds = eval_ds.map(
234
+ lambda x: prepare_dataset(x, processor),
235
+ remove_columns=eval_ds.column_names,
236
+ num_proc=1,
237
+ )
238
+
239
+ # Filter too-long labels
240
+ train_ds = train_ds.filter(lambda x: len(x["labels"]) < MAX_LABEL_LENGTH)
241
+ eval_ds = eval_ds.filter(lambda x: len(x["labels"]) < MAX_LABEL_LENGTH)
242
+ print(f" After filter: train={len(train_ds)}, eval={len(eval_ds)}", flush=True)
243
+
244
+ # โ”€โ”€ Data collator โ”€โ”€
245
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(
246
+ processor=processor,
247
+ decoder_start_token_id=model.config.decoder_start_token_id,
248
+ )
249
+
250
+ # โ”€โ”€ Training args โ€” A40 48GB optimized โ”€โ”€
251
+ training_args = Seq2SeqTrainingArguments(
252
+ output_dir=OUTPUT_DIR,
253
+ per_device_train_batch_size=BATCH_SIZE,
254
+ per_device_eval_batch_size=8,
255
+ gradient_accumulation_steps=GRAD_ACCUM,
256
+ learning_rate=LEARNING_RATE,
257
+ warmup_steps=WARMUP_STEPS,
258
+ max_steps=MAX_STEPS,
259
+ fp16=True, # A40 ะดััั€ fp16 ั…ัƒั€ะดะฐะฝ
260
+ eval_strategy="steps", # NOT evaluation_strategy (deprecated)
261
+ eval_steps=EVAL_STEPS,
262
+ save_strategy="steps",
263
+ save_steps=SAVE_STEPS,
264
+ save_total_limit=3,
265
+ load_best_model_at_end=True,
266
+ metric_for_best_model="cer", # CER = ะœะพะฝะณะพะปะด ั‚ะพั…ะธั€ะพะผะถั‚ะพะน
267
+ greater_is_better=False,
268
+ predict_with_generate=True,
269
+ generation_max_length=225,
270
+ logging_steps=25,
271
+ report_to="none",
272
+ dataloader_num_workers=2, # A40-ะด 2 ั…ะฐะฝะณะฐะปั‚ั‚ะฐะน
273
+ push_to_hub=False,
274
+ lr_scheduler_type="cosine",
275
+ weight_decay=0.01,
276
+ gradient_checkpointing=True,
277
+ remove_unused_columns=False,
278
+ )
279
+
280
+ # โ”€โ”€ Trainer โ”€โ”€
281
+ trainer = Seq2SeqTrainer(
282
+ args=training_args,
283
+ model=model,
284
+ train_dataset=train_ds,
285
+ eval_dataset=eval_ds,
286
+ data_collator=data_collator,
287
+ compute_metrics=lambda pred: compute_metrics(pred, processor.tokenizer),
288
+ processing_class=processor.feature_extractor, # NOT tokenizer= (deprecated)
289
+ )
290
+
291
+ # โ”€โ”€ Train โ”€โ”€
292
+ print(f"\nTRAINING STARTED!", flush=True)
293
+ print(f" Steps: {MAX_STEPS}")
294
+ print(f" Batch: {BATCH_SIZE} x {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}")
295
+ print(f" LR: {LEARNING_RATE}")
296
+ print(f" Eval every: {EVAL_STEPS} steps")
297
+ print(f" Metric: CER (lower = better)")
298
+ print("=" * 60, flush=True)
299
+
300
+ try:
301
+ trainer.train()
302
+ except Exception as e:
303
+ print(f"\nTraining error: {e}", flush=True)
304
+ print("Attempting to save current model...", flush=True)
305
+ trainer.save_model(f"{OUTPUT_DIR}/emergency-save")
306
+ processor.save_pretrained(f"{OUTPUT_DIR}/emergency-save")
307
+ raise
308
+
309
+ # โ”€โ”€ Save best model โ”€โ”€
310
+ print("\nSaving best model...", flush=True)
311
+ trainer.save_model(f"{OUTPUT_DIR}/best")
312
+ processor.save_pretrained(f"{OUTPUT_DIR}/best")
313
+
314
+ # Save generation_config with Mongolian settings
315
+ model.generation_config.forced_decoder_ids = processor.get_decoder_prompt_ids(
316
+ language=LANGUAGE, task=TASK
317
+ )
318
+ model.generation_config.save_pretrained(f"{OUTPUT_DIR}/best")
319
+
320
+ # โ”€โ”€ Upload to HuggingFace โ”€โ”€
321
+ print(f"\nUploading to {OUTPUT_MODEL}...", flush=True)
322
+ try:
323
+ model.push_to_hub(OUTPUT_MODEL, token=HF_TOKEN, private=True)
324
+ processor.push_to_hub(OUTPUT_MODEL, token=HF_TOKEN, private=True)
325
+ model.generation_config.push_to_hub(OUTPUT_MODEL, token=HF_TOKEN)
326
+ print(f" Upload OK: https://huggingface.co/{OUTPUT_MODEL}")
327
+ except Exception as e:
328
+ print(f" Upload failed: {e}")
329
+ print(f" Model saved locally: {OUTPUT_DIR}/best")
330
+
331
+ print(f"\n{'=' * 60}")
332
+ print(f"DONE! Model: {OUTPUT_MODEL}")
333
+ print(f"{'=' * 60}")
334
+
335
+
336
+ if __name__ == "__main__":
337
+ main()