monishaaura committed · Commit 90c88ff · 1 Parent(s): 0f6a073

Update training script with retry logic and resampy dependency

Files changed (1): train_ravdess.py (+224 -35)
train_ravdess.py CHANGED
@@ -1,4 +1,16 @@
 #!/usr/bin/env python
+"""
+Improved Wav2Vec2 RAVDESS Emotion Detection Training Script
+
+Fixes:
+- 25 epochs for proper convergence
+- Feature extractor freeze/unfreeze strategy
+- Balanced class weights for imbalanced dataset
+- Proper audio normalization (16kHz, amplitude)
+- Gaussian noise augmentation
+- Correct label mapping
+"""
+
 import argparse
 import glob
 import io
@@ -14,6 +26,8 @@ import pyarrow as pa
 import pyarrow.parquet as pq
 import soundfile as sf
 import torch
+import torch.nn as nn
+from sklearn.utils.class_weight import compute_class_weight
 from torch.nn.utils.rnn import pad_sequence
 from datasets import Dataset
 from huggingface_hub import snapshot_download
@@ -65,43 +79,117 @@ class DataCollatorWithPadding:
         }
 
 
+class WeightedTrainer(Trainer):
+    """Trainer with weighted loss for imbalanced classes"""
+
+    def __init__(self, class_weights=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.class_weights = class_weights
+        if class_weights is not None:
+            self.class_weights = torch.tensor(class_weights, dtype=torch.float32)
+            if torch.cuda.is_available():
+                self.class_weights = self.class_weights.cuda()
+
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that
+        # newer transformers versions pass to compute_loss.
+        labels = inputs.get("labels")
+        outputs = model(**inputs)
+        logits = outputs.get("logits")
+
+        if self.class_weights is not None:
+            loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
+        else:
+            loss_fct = nn.CrossEntropyLoss()
+
+        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
+        return (loss, outputs) if return_outputs else loss
+
+
 def compute_metrics(eval_pred):
     accuracy_metric = evaluate.load("accuracy")
     predictions, labels = eval_pred
     preds = np.argmax(predictions, axis=1)
-    return accuracy_metric.compute(predictions=preds, references=labels)
+
+    # Also compute per-class metrics
+    from sklearn.metrics import classification_report
+    report = classification_report(labels, preds, output_dict=True, zero_division=0)
+
+    return {
+        "accuracy": accuracy_metric.compute(predictions=preds, references=labels)["accuracy"],
+        "macro_f1": report.get("macro avg", {}).get("f1-score", 0.0),
+        "weighted_f1": report.get("weighted avg", {}).get("f1-score", 0.0),
+    }
 
 
-def prepare_dataset(batch, processor, sampling_rate):
+def add_gaussian_noise(audio: np.ndarray, noise_factor: float = 0.01) -> np.ndarray:
+    """Add small Gaussian noise for augmentation"""
+    noise = np.random.normal(0, noise_factor, audio.shape).astype(np.float32)
+    return np.clip(audio + noise, -1.0, 1.0)
+
+
+def prepare_dataset(batch, processor, sampling_rate, augment: bool = False):
+    """
+    Prepare dataset with proper audio normalization and optional augmentation.
+
+    - Enforces 16kHz resampling
+    - Normalizes amplitude to [-1, 1]
+    - Optionally adds small Gaussian noise
+    """
     audio_arrays: List[np.ndarray] = []
+
     for audio_bytes in batch["audio_bytes"]:
+        # Read audio
         with io.BytesIO(audio_bytes) as buffer:
-            waveform, source_sr = sf.read(buffer)
+            waveform, source_sr = sf.read(buffer, dtype='float32')
+
+        # Ensure mono
         if waveform.ndim > 1:
             waveform = np.mean(waveform, axis=1)
+
+        # Enforce 16kHz resampling
         if source_sr != sampling_rate:
-            waveform = librosa.resample(waveform, orig_sr=source_sr, target_sr=sampling_rate)
-        audio_arrays.append(waveform.astype(np.float32))
-
+            waveform = librosa.resample(
+                waveform,
+                orig_sr=source_sr,
+                target_sr=sampling_rate,
+                res_type='kaiser_best'
+            )
+
+        # Normalize amplitude to [-1, 1] range
+        max_val = np.abs(waveform).max()
+        if max_val > 0:
+            waveform = waveform / max_val
+
+        # Ensure float32
+        waveform = waveform.astype(np.float32)
+
+        # Apply augmentation (only for training)
+        if augment:
+            waveform = add_gaussian_noise(waveform, noise_factor=0.01)
+
+        audio_arrays.append(waveform)
+
+    # Process with feature extractor
     processed = processor(
         audio_arrays,
         sampling_rate=sampling_rate,
         return_attention_mask=True,
     )
+
     batch["input_values"] = [
         np.asarray(array, dtype=np.float32) for array in processed["input_values"]
     ]
+
     if "attention_mask" in processed:
         batch["attention_mask"] = [
             np.asarray(mask, dtype=np.int64) for mask in processed["attention_mask"]
         ]
-
+
     batch["labels"] = [int(label) for label in batch["label"]]
     return batch
 
 
 def parse_args():
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(description="Train Wav2Vec2 on RAVDESS emotion dataset")
     parser.add_argument("--model_name_or_path", default="facebook/wav2vec2-base-960h")
     default_output_dir = os.path.join(os.path.dirname(__file__), "wav2vec2-ravdess-emotion")
     parser.add_argument("--output_dir", default=default_output_dir)
@@ -110,10 +198,11 @@ def parse_args():
     parser.add_argument("--train_split", default="train")
     parser.add_argument("--eval_split", default="test")
     parser.add_argument("--sampling_rate", type=int, default=16_000)
-    parser.add_argument("--num_train_epochs", type=float, default=10.0)
-    parser.add_argument("--per_device_train_batch_size", type=int, default=8)
-    parser.add_argument("--per_device_eval_batch_size", type=int, default=8)
-    parser.add_argument("--learning_rate", type=float, default=2e-5)
+    parser.add_argument("--num_train_epochs", type=float, default=25.0)
+    parser.add_argument("--warmup_epochs", type=int, default=3, help="Epochs with frozen feature extractor")
+    parser.add_argument("--per_device_train_batch_size", type=int, default=4)
+    parser.add_argument("--per_device_eval_batch_size", type=int, default=4)
+    parser.add_argument("--learning_rate", type=float, default=3e-5)
     parser.add_argument("--warmup_ratio", type=float, default=0.1)
     parser.add_argument("--weight_decay", type=float, default=0.01)
     parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
@@ -129,16 +218,27 @@ def parse_args():
 def main():
     args = parse_args()
     set_seed(args.seed)
-
+
+    print("=" * 80)
+    print("Wav2Vec2 RAVDESS Emotion Detection Training")
+    print("=" * 80)
+    print(f"Model: {args.model_name_or_path}")
+    print(f"Epochs: {args.num_train_epochs} (warmup: {args.warmup_epochs})")
+    print(f"Learning rate: {args.learning_rate}")
+    print(f"Batch size: {args.per_device_train_batch_size} (gradient accumulation: {args.gradient_accumulation_steps})")
+    print("=" * 80)
+
+    # Download dataset
+    print("\n📥 Downloading RAVDESS dataset...")
     snapshot_path = snapshot_download(
         repo_id=args.dataset_name,
         repo_type="dataset",
         cache_dir=os.getenv("HF_HOME"),
         token=os.getenv("HF_TOKEN"),
     )
-
+
     split_root = os.path.join(snapshot_path, args.dataset_config) if args.dataset_config else snapshot_path
-
+
     def load_split(split_name: str):
         pattern = os.path.join(split_root, f"{split_name}-*.parquet")
         parquet_files = sorted(glob.glob(pattern))
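The retry logic mentioned in the commit title is not visible in these hunks; a minimal sketch of what a retrying wrapper around snapshot_download could look like (the helper name, retry count, and backoff are assumptions):

```python
import time
from huggingface_hub import snapshot_download

def snapshot_download_with_retries(max_retries: int = 3, **kwargs):
    """Hypothetical wrapper: retry transient download failures with backoff."""
    for attempt in range(1, max_retries + 1):
        try:
            return snapshot_download(**kwargs)
        except Exception as err:  # network errors, rate limits, etc.
            if attempt == max_retries:
                raise
            wait = 2 ** attempt
            print(f"Download failed ({err}); retrying in {wait}s...")
            time.sleep(wait)
```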
@@ -153,13 +253,13 @@ def main():
             "emotion": data["emotion"],
             "file": data["file"],
         }
-
+
     train_dict = load_split(args.train_split)
     if train_dict is None:
         raise ValueError(f"Could not locate parquet files for split '{args.train_split}' in {split_root}")
-
+
     eval_dict = load_split(args.eval_split)
-
+
     train_dataset = Dataset.from_dict(train_dict)
     if eval_dict is not None:
         eval_dataset = Dataset.from_dict(eval_dict)
@@ -167,17 +267,44 @@ def main():
     split_dataset = train_dataset.train_test_split(test_size=0.1, seed=args.seed)
     train_dataset = split_dataset["train"]
     eval_dataset = split_dataset["test"]
-
+
+    print(f"✅ Train samples: {len(train_dataset)}")
+    print(f"✅ Eval samples: {len(eval_dataset)}")
+
+    # Build label mapping (consistent id2label / label2id)
+    print("\n📊 Building label mapping...")
     label_names = {}
     for label, emotion in zip(train_dataset["label"], train_dataset["emotion"]):
         label_names[int(label)] = emotion
+
+    # Ensure consistent ordering
     id2label = {idx: label_names[idx] for idx in sorted(label_names)}
     label2id = {name: idx for idx, name in id2label.items()}
-
+
+    print(f"✅ Labels ({len(id2label)}): {list(id2label.values())}")
+    print(f"✅ Label mapping: {id2label}")
+
+    # Compute class weights for balanced training
+    print("\n⚖️ Computing class weights for balanced training...")
+    labels_array = np.array(train_dataset["label"])
+    unique_labels = np.unique(labels_array)
+    class_weights = compute_class_weight(
+        'balanced',
+        classes=unique_labels,
+        y=labels_array
+    )
+    class_weight_dict = dict(zip(unique_labels, class_weights))
+    class_weight_list = [class_weight_dict[i] for i in sorted(unique_labels)]
+
+    print(f"✅ Class weights: {dict(zip([id2label[i] for i in sorted(unique_labels)], class_weight_list))}")
+
+    # Load processor and config
+    print("\n📦 Loading processor and config...")
     processor = AutoProcessor.from_pretrained(
         args.model_name_or_path,
         cache_dir=os.getenv("HF_HOME"),
     )
+
     config = AutoConfig.from_pretrained(
         args.model_name_or_path,
         num_labels=len(label2id),
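For reference, sklearn's 'balanced' mode used above assigns each class the weight n_samples / (n_classes * n_c), so rarer classes get larger weights; a tiny worked example:

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# 6 samples: class 0 appears 4 times, class 1 twice.
y = np.array([0, 0, 0, 0, 1, 1])
weights = compute_class_weight("balanced", classes=np.unique(y), y=y)
print(weights)  # [0.75 1.5] == 6 / (2 * np.array([4, 2]))
```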
@@ -186,45 +313,61 @@ def main():
         finetuning_task="wav2vec2_emotion",
         cache_dir=os.getenv("HF_HOME"),
     )
-
+
+    # Verify label mapping in config
+    print(f"✅ Config labels: {config.id2label}")
+    assert config.label2id == label2id, "Label mapping mismatch!"
+    assert config.id2label == id2label, "Label mapping mismatch!"
+
+    # Prepare datasets with proper normalization
+    print("\n🔄 Preparing training dataset (with augmentation)...")
     processed_train_dataset = train_dataset.map(
         prepare_dataset,
         fn_kwargs=dict(
             processor=processor,
             sampling_rate=args.sampling_rate,
+            augment=True,  # Add noise augmentation for training
         ),
         remove_columns=["audio_bytes", "file", "emotion", "label"],
         batched=True,
         batch_size=8,
         num_proc=1,
     )
-
+
+    print("🔄 Preparing evaluation dataset (no augmentation)...")
     processed_eval_dataset = eval_dataset.map(
         prepare_dataset,
         fn_kwargs=dict(
             processor=processor,
             sampling_rate=args.sampling_rate,
+            augment=False,  # No augmentation for eval
         ),
         remove_columns=["audio_bytes", "file", "emotion", "label"],
         batched=True,
         batch_size=8,
         num_proc=1,
     )
-
+
     if args.max_train_samples:
         processed_train_dataset = processed_train_dataset.select(range(args.max_train_samples))
     if args.max_eval_samples:
         processed_eval_dataset = processed_eval_dataset.select(range(args.max_eval_samples))
-
+
+    # Load model
+    print("\n🤖 Loading model...")
     model = Wav2Vec2ForSequenceClassification.from_pretrained(
         args.model_name_or_path,
         config=config,
         cache_dir=os.getenv("HF_HOME"),
    )
+
+    # Freeze feature extractor initially
+    print("🔒 Freezing feature extractor for warmup...")
     model.freeze_feature_extractor()
-
+
     data_collator = DataCollatorWithPadding(processor=processor)
-
+
+    # Training arguments
     requested_training_arguments = dict(
         output_dir=args.output_dir,
         per_device_train_batch_size=args.per_device_train_batch_size,
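Note that freeze_feature_extractor() still works but is deprecated in recent transformers releases in favor of freeze_feature_encoder(); a self-contained equivalent (the num_labels value here is illustrative only):

```python
from transformers import Wav2Vec2ForSequenceClassification

model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-base-960h", num_labels=8  # num_labels illustrative
)
model.freeze_feature_encoder()  # non-deprecated spelling
# The conv feature encoder's parameters are now excluded from training:
assert all(not p.requires_grad for p in model.wav2vec2.feature_extractor.parameters())
```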
@@ -240,27 +383,33 @@ def main():
         group_by_length=True,
         dataloader_num_workers=min(4, os.cpu_count() or 1),
         logging_steps=25,
+        save_total_limit=3,  # Keep only last 3 checkpoints
         load_best_model_at_end=True,
         metric_for_best_model="accuracy",
+        greater_is_better=True,
         push_to_hub=args.push_to_hub,
         hub_model_id=args.hub_model_id,
         hub_private_repo=args.hub_private_repo,
+        report_to="none",  # Disable wandb/tensorboard
     )
+
+    # Filter to supported arguments
     training_args_signature = inspect.signature(TrainingArguments)
     supported_training_arguments = {
         key: value
         for key, value in requested_training_arguments.items()
         if key in training_args_signature.parameters
     }
-
+
     if "evaluation_strategy" not in supported_training_arguments:
         supported_training_arguments.pop("save_strategy", None)
         supported_training_arguments.pop("load_best_model_at_end", None)
         supported_training_arguments.pop("metric_for_best_model", None)
-
+
     training_args = TrainingArguments(**supported_training_arguments)
-
-    trainer = Trainer(
+
+    # Create trainer with weighted loss
+    trainer = WeightedTrainer(
         model=model,
         args=training_args,
         train_dataset=processed_train_dataset,
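The inspect.signature filtering above guards against TrainingArguments fields drifting across transformers versions; one concrete case is evaluation_strategy, which newer releases rename to eval_strategy. The same idiom can pick whichever spelling the installed version accepts (a sketch, not part of this commit):

```python
import inspect
from transformers import TrainingArguments

params = inspect.signature(TrainingArguments).parameters
eval_key = "eval_strategy" if "eval_strategy" in params else "evaluation_strategy"
training_args = TrainingArguments(output_dir="out", **{eval_key: "epoch"})
```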
@@ -268,18 +417,58 @@ def main():
         tokenizer=processor,
         data_collator=data_collator,
         compute_metrics=compute_metrics,
+        class_weights=class_weight_list,
     )
-
+
+    # Phase 1: train with the feature extractor frozen (warmup)
+    print("\n" + "=" * 80)
+    print(f"PHASE 1: Training with FROZEN feature extractor ({args.warmup_epochs} epochs)")
+    print("=" * 80)
+
+    trainer.args.num_train_epochs = args.warmup_epochs
     trainer.train()
+
+    # Phase 2: unfreeze the feature extractor and fine-tune end to end for the
+    # remaining epochs. transformers has no unfreeze helper, so re-enable the
+    # gradients directly; resetting the optimizer makes the next train() call
+    # rebuild it over the newly trainable parameters.
+    print("\n" + "=" * 80)
+    print("PHASE 2: Training with UNFROZEN feature extractor (remaining epochs)")
+    print("=" * 80)
+
+    print("🔓 Unfreezing feature extractor...")
+    for param in model.wav2vec2.feature_extractor.parameters():
+        param.requires_grad = True
+
+    trainer.args.num_train_epochs = args.num_train_epochs - args.warmup_epochs
+    trainer.optimizer = None
+    trainer.lr_scheduler = None
+    trainer.train()
+
+    # Save final model
+    print("\n💾 Saving final model and processor...")
     trainer.save_model()
     processor.save_pretrained(args.output_dir)
-
+
+    # Verify label mapping is saved correctly
+    saved_config = AutoConfig.from_pretrained(args.output_dir)
+    print("\n✅ Saved model label mapping:")
+    print(f"    id2label: {saved_config.id2label}")
+    print(f"    label2id: {saved_config.label2id}")
+
     if args.push_to_hub:
+        print("\n📤 Pushing to Hugging Face Hub...")
         trainer.push_to_hub()
-
-    print(f"Model and processor saved to {args.output_dir}")
+
+    print(f"\n✅ Training complete! Model saved to: {args.output_dir}")
+    print("=" * 80)
 
 
 if __name__ == "__main__":
     main()
-
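Once training finishes, the "Correct label mapping" fix can be exercised end to end with a quick inference pass against the saved directory; a minimal sketch (the local path mirrors the script's default output_dir):

```python
import numpy as np
import torch
from transformers import AutoProcessor, Wav2Vec2ForSequenceClassification

model_dir = "wav2vec2-ravdess-emotion"  # the script's default output_dir
processor = AutoProcessor.from_pretrained(model_dir)
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_dir).eval()

# One second of 16 kHz silence, just to exercise the pipeline end to end.
waveform = np.zeros(16_000, dtype=np.float32)
inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[int(logits.argmax(-1))])
```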
 
1
  #!/usr/bin/env python
2
+ """
3
+ Improved Wav2Vec2 RAVDESS Emotion Detection Training Script
4
+
5
+ Fixes:
6
+ - 25 epochs for proper convergence
7
+ - Feature extractor freeze/unfreeze strategy
8
+ - Balanced class weights for imbalanced dataset
9
+ - Proper audio normalization (16kHz, amplitude)
10
+ - Gaussian noise augmentation
11
+ - Correct label mapping
12
+ """
13
+
14
  import argparse
15
  import glob
16
  import io
 
26
  import pyarrow.parquet as pq
27
  import soundfile as sf
28
  import torch
29
+ import torch.nn as nn
30
+ from sklearn.utils.class_weight import compute_class_weight
31
  from torch.nn.utils.rnn import pad_sequence
32
  from datasets import Dataset
33
  from huggingface_hub import snapshot_download
 
79
  }
80
 
81
 
82
+ class WeightedTrainer(Trainer):
83
+ """Trainer with weighted loss for imbalanced classes"""
84
+
85
+ def __init__(self, class_weights=None, *args, **kwargs):
86
+ super().__init__(*args, **kwargs)
87
+ self.class_weights = class_weights
88
+ if class_weights is not None:
89
+ self.class_weights = torch.tensor(class_weights, dtype=torch.float32)
90
+ if torch.cuda.is_available():
91
+ self.class_weights = self.class_weights.cuda()
92
+
93
+ def compute_loss(self, model, inputs, return_outputs=False):
94
+ labels = inputs.get("labels")
95
+ outputs = model(**inputs)
96
+ logits = outputs.get("logits")
97
+
98
+ if self.class_weights is not None:
99
+ loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
100
+ else:
101
+ loss_fct = nn.CrossEntropyLoss()
102
+
103
+ loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
104
+ return (loss, outputs) if return_outputs else loss
105
+
106
+
107
  def compute_metrics(eval_pred):
108
  accuracy_metric = evaluate.load("accuracy")
109
  predictions, labels = eval_pred
110
  preds = np.argmax(predictions, axis=1)
111
+
112
+ # Also compute per-class metrics
113
+ from sklearn.metrics import classification_report, confusion_matrix
114
+ report = classification_report(labels, preds, output_dict=True, zero_division=0)
115
+
116
+ return {
117
+ "accuracy": accuracy_metric.compute(predictions=preds, references=labels)["accuracy"],
118
+ "macro_f1": report.get("macro avg", {}).get("f1-score", 0.0),
119
+ "weighted_f1": report.get("weighted avg", {}).get("f1-score", 0.0),
120
+ }
121
 
122
 
123
+ def add_gaussian_noise(audio: np.ndarray, noise_factor: float = 0.01) -> np.ndarray:
124
+ """Add small Gaussian noise for augmentation"""
125
+ noise = np.random.normal(0, noise_factor, audio.shape).astype(np.float32)
126
+ return np.clip(audio + noise, -1.0, 1.0)
127
+
128
+
129
+ def prepare_dataset(batch, processor, sampling_rate, augment: bool = False):
130
+ """
131
+ Prepare dataset with proper audio normalization and optional augmentation.
132
+
133
+ - Enforces 16kHz resampling
134
+ - Normalizes amplitude to [-1, 1]
135
+ - Optionally adds small Gaussian noise
136
+ """
137
  audio_arrays: List[np.ndarray] = []
138
+
139
  for audio_bytes in batch["audio_bytes"]:
140
+ # Read audio
141
  with io.BytesIO(audio_bytes) as buffer:
142
+ waveform, source_sr = sf.read(buffer, dtype='float32')
143
+
144
+ # Ensure mono
145
  if waveform.ndim > 1:
146
  waveform = np.mean(waveform, axis=1)
147
+
148
+ # Enforce 16kHz resampling
149
  if source_sr != sampling_rate:
150
+ waveform = librosa.resample(
151
+ waveform,
152
+ orig_sr=source_sr,
153
+ target_sr=sampling_rate,
154
+ res_type='kaiser_best'
155
+ )
156
+
157
+ # Normalize amplitude to [-1, 1] range
158
+ max_val = np.abs(waveform).max()
159
+ if max_val > 0:
160
+ waveform = waveform / max_val
161
+
162
+ # Ensure float32
163
+ waveform = waveform.astype(np.float32)
164
+
165
+ # Apply augmentation (only for training)
166
+ if augment:
167
+ waveform = add_gaussian_noise(waveform, noise_factor=0.01)
168
+
169
+ audio_arrays.append(waveform)
170
+
171
+ # Process with feature extractor
172
  processed = processor(
173
  audio_arrays,
174
  sampling_rate=sampling_rate,
175
  return_attention_mask=True,
176
  )
177
+
178
  batch["input_values"] = [
179
  np.asarray(array, dtype=np.float32) for array in processed["input_values"]
180
  ]
181
+
182
  if "attention_mask" in processed:
183
  batch["attention_mask"] = [
184
  np.asarray(mask, dtype=np.int64) for mask in processed["attention_mask"]
185
  ]
186
+
187
  batch["labels"] = [int(label) for label in batch["label"]]
188
  return batch
189
 
190
 
191
  def parse_args():
192
+ parser = argparse.ArgumentParser(description="Train Wav2Vec2 on RAVDESS emotion dataset")
193
  parser.add_argument("--model_name_or_path", default="facebook/wav2vec2-base-960h")
194
  default_output_dir = os.path.join(os.path.dirname(__file__), "wav2vec2-ravdess-emotion")
195
  parser.add_argument("--output_dir", default=default_output_dir)
 
198
  parser.add_argument("--train_split", default="train")
199
  parser.add_argument("--eval_split", default="test")
200
  parser.add_argument("--sampling_rate", type=int, default=16_000)
201
+ parser.add_argument("--num_train_epochs", type=float, default=25.0)
202
+ parser.add_argument("--warmup_epochs", type=int, default=3, help="Epochs with frozen feature extractor")
203
+ parser.add_argument("--per_device_train_batch_size", type=int, default=4)
204
+ parser.add_argument("--per_device_eval_batch_size", type=int, default=4)
205
+ parser.add_argument("--learning_rate", type=float, default=3e-5)
206
  parser.add_argument("--warmup_ratio", type=float, default=0.1)
207
  parser.add_argument("--weight_decay", type=float, default=0.01)
208
  parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
 
218
  def main():
219
  args = parse_args()
220
  set_seed(args.seed)
221
+
222
+ print("=" * 80)
223
+ print("Wav2Vec2 RAVDESS Emotion Detection Training")
224
+ print("=" * 80)
225
+ print(f"Model: {args.model_name_or_path}")
226
+ print(f"Epochs: {args.num_train_epochs} (warmup: {args.warmup_epochs})")
227
+ print(f"Learning rate: {args.learning_rate}")
228
+ print(f"Batch size: {args.per_device_train_batch_size} (gradient accumulation: {args.gradient_accumulation_steps})")
229
+ print("=" * 80)
230
+
231
+ # Download dataset
232
+ print("\nπŸ“₯ Downloading RAVDESS dataset...")
233
  snapshot_path = snapshot_download(
234
  repo_id=args.dataset_name,
235
  repo_type="dataset",
236
  cache_dir=os.getenv("HF_HOME"),
237
  token=os.getenv("HF_TOKEN"),
238
  )
239
+
240
  split_root = os.path.join(snapshot_path, args.dataset_config) if args.dataset_config else snapshot_path
241
+
242
  def load_split(split_name: str):
243
  pattern = os.path.join(split_root, f"{split_name}-*.parquet")
244
  parquet_files = sorted(glob.glob(pattern))
 
253
  "emotion": data["emotion"],
254
  "file": data["file"],
255
  }
256
+
257
  train_dict = load_split(args.train_split)
258
  if train_dict is None:
259
  raise ValueError(f"Could not locate parquet files for split '{args.train_split}' in {split_root}")
260
+
261
  eval_dict = load_split(args.eval_split)
262
+
263
  train_dataset = Dataset.from_dict(train_dict)
264
  if eval_dict is not None:
265
  eval_dataset = Dataset.from_dict(eval_dict)
 
267
  split_dataset = train_dataset.train_test_split(test_size=0.1, seed=args.seed)
268
  train_dataset = split_dataset["train"]
269
  eval_dataset = split_dataset["test"]
270
+
271
+ print(f"βœ… Train samples: {len(train_dataset)}")
272
+ print(f"βœ… Eval samples: {len(eval_dataset)}")
273
+
274
+ # Build label mapping (consistent id2label / label2id)
275
+ print("\nπŸ“Š Building label mapping...")
276
  label_names = {}
277
  for label, emotion in zip(train_dataset["label"], train_dataset["emotion"]):
278
  label_names[int(label)] = emotion
279
+
280
+ # Ensure consistent ordering
281
  id2label = {idx: label_names[idx] for idx in sorted(label_names)}
282
  label2id = {name: idx for idx, name in id2label.items()}
283
+
284
+ print(f"βœ… Labels ({len(id2label)}): {list(id2label.values())}")
285
+ print(f"βœ… Label mapping: {id2label}")
286
+
287
+ # Compute class weights for balanced training
288
+ print("\nβš–οΈ Computing class weights for balanced training...")
289
+ labels_array = np.array(train_dataset["label"])
290
+ unique_labels = np.unique(labels_array)
291
+ class_weights = compute_class_weight(
292
+ 'balanced',
293
+ classes=unique_labels,
294
+ y=labels_array
295
+ )
296
+ class_weight_dict = dict(zip(unique_labels, class_weights))
297
+ class_weight_list = [class_weight_dict[i] for i in sorted(unique_labels)]
298
+
299
+ print(f"βœ… Class weights: {dict(zip([id2label[i] for i in sorted(unique_labels)], class_weight_list))}")
300
+
301
+ # Load processor and config
302
+ print("\nπŸ“¦ Loading processor and config...")
303
  processor = AutoProcessor.from_pretrained(
304
  args.model_name_or_path,
305
  cache_dir=os.getenv("HF_HOME"),
306
  )
307
+
308
  config = AutoConfig.from_pretrained(
309
  args.model_name_or_path,
310
  num_labels=len(label2id),
 
313
  finetuning_task="wav2vec2_emotion",
314
  cache_dir=os.getenv("HF_HOME"),
315
  )
316
+
317
+ # Verify label mapping in config
318
+ print(f"βœ… Config labels: {config.id2label}")
319
+ assert config.label2id == label2id, "Label mapping mismatch!"
320
+ assert config.id2label == id2label, "Label mapping mismatch!"
321
+
322
+ # Prepare datasets with proper normalization
323
+ print("\nπŸ”„ Preparing training dataset (with augmentation)...")
324
  processed_train_dataset = train_dataset.map(
325
  prepare_dataset,
326
  fn_kwargs=dict(
327
  processor=processor,
328
  sampling_rate=args.sampling_rate,
329
+ augment=True, # Add noise augmentation for training
330
  ),
331
  remove_columns=["audio_bytes", "file", "emotion", "label"],
332
  batched=True,
333
  batch_size=8,
334
  num_proc=1,
335
  )
336
+
337
+ print("πŸ”„ Preparing evaluation dataset (no augmentation)...")
338
  processed_eval_dataset = eval_dataset.map(
339
  prepare_dataset,
340
  fn_kwargs=dict(
341
  processor=processor,
342
  sampling_rate=args.sampling_rate,
343
+ augment=False, # No augmentation for eval
344
  ),
345
  remove_columns=["audio_bytes", "file", "emotion", "label"],
346
  batched=True,
347
  batch_size=8,
348
  num_proc=1,
349
  )
350
+
351
  if args.max_train_samples:
352
  processed_train_dataset = processed_train_dataset.select(range(args.max_train_samples))
353
  if args.max_eval_samples:
354
  processed_eval_dataset = processed_eval_dataset.select(range(args.max_eval_samples))
355
+
356
+ # Load model
357
+ print("\nπŸ€– Loading model...")
358
  model = Wav2Vec2ForSequenceClassification.from_pretrained(
359
  args.model_name_or_path,
360
  config=config,
361
  cache_dir=os.getenv("HF_HOME"),
362
  )
363
+
364
+ # Freeze feature extractor initially
365
+ print("πŸ”’ Freezing feature extractor for warmup...")
366
  model.freeze_feature_extractor()
367
+
368
  data_collator = DataCollatorWithPadding(processor=processor)
369
+
370
+ # Training arguments
371
  requested_training_arguments = dict(
372
  output_dir=args.output_dir,
373
  per_device_train_batch_size=args.per_device_train_batch_size,
 
383
  group_by_length=True,
384
  dataloader_num_workers=min(4, os.cpu_count() or 1),
385
  logging_steps=25,
386
+ save_total_limit=3, # Keep only last 3 checkpoints
387
  load_best_model_at_end=True,
388
  metric_for_best_model="accuracy",
389
+ greater_is_better=True,
390
  push_to_hub=args.push_to_hub,
391
  hub_model_id=args.hub_model_id,
392
  hub_private_repo=args.hub_private_repo,
393
+ report_to="none", # Disable wandb/tensorboard
394
  )
395
+
396
+ # Filter to supported arguments
397
  training_args_signature = inspect.signature(TrainingArguments)
398
  supported_training_arguments = {
399
  key: value
400
  for key, value in requested_training_arguments.items()
401
  if key in training_args_signature.parameters
402
  }
403
+
404
  if "evaluation_strategy" not in supported_training_arguments:
405
  supported_training_arguments.pop("save_strategy", None)
406
  supported_training_arguments.pop("load_best_model_at_end", None)
407
  supported_training_arguments.pop("metric_for_best_model", None)
408
+
409
  training_args = TrainingArguments(**supported_training_arguments)
410
+
411
+ # Create trainer with weighted loss
412
+ trainer = WeightedTrainer(
413
  model=model,
414
  args=training_args,
415
  train_dataset=processed_train_dataset,
 
417
  tokenizer=processor,
418
  data_collator=data_collator,
419
  compute_metrics=compute_metrics,
420
+ class_weights=class_weight_list,
421
  )
422
+
423
+ # Phase 1: Train with frozen feature extractor (warmup)
424
+ print("\n" + "=" * 80)
425
+ print(f"PHASE 1: Training with FROZEN feature extractor ({args.warmup_epochs} epochs)")
426
+ print("=" * 80)
427
+
428
+ # Calculate steps for warmup
429
+ total_steps = len(processed_train_dataset) // (args.per_device_train_batch_size * args.gradient_accumulation_steps) * args.num_train_epochs
430
+ warmup_steps = int(total_steps * args.warmup_ratio)
431
+ warmup_epochs_steps = len(processed_train_dataset) // (args.per_device_train_batch_size * args.gradient_accumulation_steps) * args.warmup_epochs
432
+
433
+ # Train for warmup epochs
434
  trainer.train()
435
+
436
+ # Check if we've completed warmup epochs
437
+ current_epoch = trainer.state.epoch
438
+ if current_epoch >= args.warmup_epochs:
439
+ print(f"\nβœ… Completed {args.warmup_epochs} warmup epochs")
440
+ print("πŸ”“ Unfreezing feature extractor...")
441
+ model.unfreeze_feature_extractor()
442
+ print("βœ… Feature extractor unfrozen!")
443
+
444
+ # Phase 2: Continue training with unfrozen feature extractor
445
+ print("\n" + "=" * 80)
446
+ print(f"PHASE 2: Training with UNFROZEN feature extractor (remaining epochs)")
447
+ print("=" * 80)
448
+
449
+ # Continue training
450
+ trainer.train()
451
+ else:
452
+ print(f"\n⚠️ Training stopped before warmup completed. Current epoch: {current_epoch}")
453
+
454
+ # Save final model
455
+ print("\nπŸ’Ύ Saving final model and processor...")
456
  trainer.save_model()
457
  processor.save_pretrained(args.output_dir)
458
+
459
+ # Verify label mapping is saved correctly
460
+ saved_config = AutoConfig.from_pretrained(args.output_dir)
461
+ print(f"\nβœ… Saved model label mapping:")
462
+ print(f" id2label: {saved_config.id2label}")
463
+ print(f" label2id: {saved_config.label2id}")
464
+
465
  if args.push_to_hub:
466
+ print("\nπŸ“€ Pushing to Hugging Face Hub...")
467
  trainer.push_to_hub()
468
+
469
+ print(f"\nβœ… Training complete! Model saved to: {args.output_dir}")
470
+ print("=" * 80)
471
 
472
 
473
  if __name__ == "__main__":
474
  main()