neurlang commited on
Commit
435e38e
·
verified ·
1 Parent(s): 0709e6e

Add trainer script itself

Browse files
Files changed (1) hide show
  1. whisper_medium.py +249 -0
whisper_medium.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================
2
+ # RESUMABLE WHISPER TRAINING SCRIPT WITH TIMESTAMP SUPPORT
3
+ # ============================================================
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Any, Dict, List, Union
7
+ import os
8
+ import gc
9
+ import torch
10
+ import pandas as pd
11
+ import evaluate
12
+
13
+ from datasets import Dataset, Audio
14
+ from transformers import (
15
+ WhisperForConditionalGeneration,
16
+ WhisperProcessor,
17
+ Seq2SeqTrainer,
18
+ Seq2SeqTrainingArguments,
19
+ )
20
+
21
+ # ============================================================
22
+ # CONFIG
23
+ # ============================================================
24
+
25
+ MODEL_SIZE = "medium"
26
+ BASE_MODEL = f"neurlang/ipa-whisper-{MODEL_SIZE}"
27
+ OUTPUT_DIR = f"whisper-{MODEL_SIZE}-finetuned"
28
+
29
+ RESUME_TRAINING = False # 🔁 flip to True to resume
30
+ RESUME_CHECKPOINT = "checkpoint-1840000" # e.g. "checkpoint-40000"
31
+ RESUME_CHECKPOINT_TARGET = 1880000 # e.g. 80000
32
+
33
+ # don't tune this, it's auto tuned on training start/resume
34
+ lr = 1.251564455569462e-07 # 1e-5
35
+
36
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
+
38
+ # ============================================================
39
+ # LOAD DATA
40
+ #
41
+ # FORMAT:
42
+ # foo.mp3,hello world
43
+ #
44
+ # ============================================================
45
+
46
+ train_df = pd.read_csv("train.csv")
47
+ eval_df = pd.read_csv("test.csv")
48
+
49
+ train_df.columns = ["audio", "sentence"]
50
+ eval_df.columns = ["audio", "sentence"]
51
+
52
+ train_dataset = Dataset.from_pandas(train_df)
53
+ eval_dataset = Dataset.from_pandas(eval_df)
54
+
55
+ train_dataset = train_dataset.cast_column("audio", Audio(sampling_rate=16000))
56
+ eval_dataset = eval_dataset.cast_column("audio", Audio(sampling_rate=16000))
57
+
58
+ # Shuffle the dataset with exact seed control
59
+ train_dataset = train_dataset.shuffle(seed=42) # Default shuffles all
60
+
61
+ # ============================================================
62
+ # PROCESSOR (TOKENIZER + FEATURE EXTRACTOR)
63
+ # ============================================================
64
+
65
+ if RESUME_TRAINING:
66
+ print(f"🔁 Loading processor from {OUTPUT_DIR}")
67
+ processor = WhisperProcessor.from_pretrained(OUTPUT_DIR)
68
+ else:
69
+ print("🆕 Creating new processor")
70
+ processor = WhisperProcessor.from_pretrained(
71
+ BASE_MODEL,
72
+ language="english",
73
+ task="transcribe",
74
+ predict_timestamps=True,
75
+ )
76
+
77
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
78
+ processor.save_pretrained(OUTPUT_DIR) # 🔒 critical
79
+
80
+ # ============================================================
81
+ # DATA PREPARATION
82
+ # ============================================================
83
+
84
+ def prepare_dataset(batch):
85
+ audio = batch["audio"]
86
+
87
+ batch["input_features"] = processor.feature_extractor(
88
+ audio["array"],
89
+ sampling_rate=16000
90
+ ).input_features[0]
91
+
92
+ text = batch["sentence"] if batch["sentence"] else ""
93
+ batch["labels"] = processor.tokenizer(
94
+ text,
95
+ return_tensors="pt"
96
+ ).input_ids[0]
97
+
98
+ del batch["audio"]
99
+ del batch["sentence"]
100
+ return batch
101
+
102
+ train_dataset = train_dataset.map(prepare_dataset, num_proc=1)
103
+ eval_dataset = eval_dataset.map(prepare_dataset, num_proc=1)
104
+
105
+ # ============================================================
106
+ # DATA COLLATOR
107
+ # ============================================================
108
+
109
+ @dataclass
110
+ class DataCollatorSpeechSeq2SeqWithPadding:
111
+ processor: Any
112
+
113
+ def __call__(self, features):
114
+ inputs = [{"input_features": f["input_features"]} for f in features]
115
+ batch = self.processor.feature_extractor.pad(
116
+ inputs, return_tensors="pt"
117
+ )
118
+
119
+ labels = [{"input_ids": f["labels"]} for f in features]
120
+ labels_batch = self.processor.tokenizer.pad(
121
+ labels, return_tensors="pt"
122
+ )
123
+
124
+ labels = labels_batch["input_ids"].masked_fill(
125
+ labels_batch.attention_mask.ne(1), -100
126
+ )
127
+
128
+ if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all():
129
+ labels = labels[:, 1:]
130
+
131
+ batch["labels"] = labels
132
+ return batch
133
+
134
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor)
135
+
136
+ # ============================================================
137
+ # METRICS
138
+ # ============================================================
139
+
140
+ cer_metric = evaluate.load("cer")
141
+
142
+ def compute_metrics(pred):
143
+ pred_ids = pred.predictions
144
+ label_ids = pred.label_ids
145
+
146
+ label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
147
+
148
+ pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
149
+ label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
150
+
151
+ return {
152
+ "cer": 100 * cer_metric.compute(
153
+ predictions=pred_str,
154
+ references=label_str
155
+ )
156
+ }
157
+
158
+ # ============================================================
159
+ # TRAINING ARGUMENTS
160
+ # ============================================================
161
+
162
+
163
+
164
+ training_args = Seq2SeqTrainingArguments(
165
+ output_dir=OUTPUT_DIR,
166
+ per_device_train_batch_size=4,
167
+ per_device_eval_batch_size=1,
168
+ learning_rate=lr,
169
+ warmup_steps=1000,
170
+ max_steps=RESUME_CHECKPOINT_TARGET,
171
+ evaluation_strategy="steps",
172
+ save_strategy="steps",
173
+ logging_steps=10*100,
174
+ eval_steps=10*100,
175
+ save_steps=10*100,
176
+ save_total_limit=3,
177
+ predict_with_generate=True,
178
+ generation_max_length=225,
179
+ fp16=False,
180
+ report_to=["tensorboard"],
181
+ load_best_model_at_end=False,
182
+ metric_for_best_model="cer",
183
+ greater_is_better=False,
184
+ save_safetensors=True, # 🔒 ensure safetensors
185
+ )
186
+
187
+ # ============================================================
188
+ # LOAD MODEL
189
+ # ============================================================
190
+
191
+ if RESUME_TRAINING:
192
+ assert RESUME_CHECKPOINT is not None, "RESUME_CHECKPOINT must be set"
193
+
194
+ checkpoint_path = os.path.join(OUTPUT_DIR, RESUME_CHECKPOINT)
195
+ print(f"🔁 Loading model from {checkpoint_path}")
196
+
197
+ model = WhisperForConditionalGeneration.from_pretrained(
198
+ checkpoint_path,
199
+ torch_dtype=torch.float32,
200
+ )
201
+ else:
202
+ print("🆕 Loading base model")
203
+ model = WhisperForConditionalGeneration.from_pretrained(
204
+ BASE_MODEL,
205
+ torch_dtype=torch.float32,
206
+ )
207
+
208
+ # 🔒 Modified safety check for Transformers version
209
+ # Transformers' Whisper uses different parameter naming
210
+ print(f"✅ Model loaded successfully")
211
+ print(f" Model type: {type(model)}")
212
+ print(f" Device: {next(model.parameters()).device}")
213
+
214
+ model.to(DEVICE)
215
+
216
+ # ============================================================
217
+ # TRAINER
218
+ # ============================================================
219
+
220
+ trainer = Seq2SeqTrainer(
221
+ model=model,
222
+ args=training_args,
223
+ train_dataset=train_dataset,
224
+ eval_dataset=eval_dataset,
225
+ data_collator=data_collator,
226
+ compute_metrics=compute_metrics,
227
+ )
228
+
229
+ # ============================================================
230
+ # TRAIN
231
+ # ============================================================
232
+
233
+ torch.cuda.empty_cache()
234
+ gc.collect()
235
+
236
+ if RESUME_TRAINING:
237
+ trainer.train(resume_from_checkpoint=checkpoint_path)
238
+ else:
239
+ trainer.train()
240
+
241
+ # ============================================================
242
+ # SAVE FINAL
243
+ # ============================================================
244
+
245
+ trainer.save_model(OUTPUT_DIR)
246
+ processor.save_pretrained(OUTPUT_DIR)
247
+
248
+ print("✅ Training completed successfully")
249
+