#!/usr/bin/env python
"""
Improved Wav2Vec2 RAVDESS Emotion Detection Training Script

Fixes:
- 25 epochs for proper convergence
- Feature extractor freeze/unfreeze strategy
- Balanced class weights for imbalanced dataset
- Proper audio normalization (16kHz, amplitude)
- Gaussian noise augmentation
- Correct label mapping

Training runs in two phases: ``--warmup_epochs`` epochs with the convolutional
feature encoder frozen, followed by the remaining epochs with it unfrozen.
"""

import argparse
import functools
import glob
import io
import inspect
import os
from dataclasses import dataclass
from typing import Dict, List, Optional

import evaluate
import librosa
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import soundfile as sf
import torch
import torch.nn as nn
from sklearn.metrics import classification_report
from sklearn.utils.class_weight import compute_class_weight
from torch.nn.utils.rnn import pad_sequence

from datasets import Dataset
from huggingface_hub import snapshot_download
from transformers import (
    AutoConfig,
    AutoProcessor,
    Trainer,
    TrainingArguments,
    Wav2Vec2ForSequenceClassification,
    set_seed,
)


@dataclass
class DataCollatorWithPadding:
    """Pad variable-length audio features into a rectangular batch.

    Attributes:
        processor: Processor associated with the model (kept for API parity;
            padding itself is done with ``pad_sequence``).
        padding: Unused flag kept for backward compatibility.
    """

    processor: AutoProcessor
    padding: bool = True

    def __call__(self, features: List[Dict[str, np.ndarray]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` into padded ``input_values``/``attention_mask``/``labels``."""
        input_tensors = [
            torch.as_tensor(feature["input_values"], dtype=torch.float32)
            for feature in features
        ]
        padded_inputs = pad_sequence(
            input_tensors,
            batch_first=True,
            padding_value=0.0,
        )

        if "attention_mask" in features[0]:
            attention_tensors = [
                torch.as_tensor(feature["attention_mask"], dtype=torch.long)
                for feature in features
            ]
            padded_attention = pad_sequence(
                attention_tensors,
                batch_first=True,
                padding_value=0,
            )
        else:
            # Derive the mask from the true sequence lengths. Comparing the
            # padded values against 0.0 would wrongly mask genuine samples
            # whose value happens to be exactly zero.
            lengths = torch.tensor([tensor.shape[0] for tensor in input_tensors])
            positions = torch.arange(padded_inputs.shape[1]).unsqueeze(0)
            padded_attention = (positions < lengths.unsqueeze(1)).long()

        labels = torch.tensor(
            [feature["labels"] for feature in features], dtype=torch.long
        )
        return {
            "input_values": padded_inputs,
            "attention_mask": padded_attention,
            "labels": labels,
        }


class WeightedTrainer(Trainer):
    """Trainer with weighted loss for imbalanced classes"""

    def __init__(self, class_weights=None, *args, **kwargs):
        """Store optional per-class weights; remaining arguments go to ``Trainer``."""
        super().__init__(*args, **kwargs)
        # Kept on CPU here; moved to the logits' device when the loss is
        # computed so it works regardless of where the model actually lives.
        self.class_weights = (
            None
            if class_weights is None
            else torch.as_tensor(class_weights, dtype=torch.float32)
        )

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute (optionally class-weighted) cross-entropy loss.

        ``num_items_in_batch`` is accepted because newer ``Trainer`` versions
        pass it as a keyword argument; it is not needed for a mean-reduced loss.
        """
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")

        weight = None
        if self.class_weights is not None:
            weight = self.class_weights.to(logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=weight)
        # Cast to float32 so half-precision logits never clash with the
        # float32 class-weight tensor.
        loss = loss_fct(
            logits.float().view(-1, self.model.config.num_labels),
            labels.view(-1),
        )
        return (loss, outputs) if return_outputs else loss


@functools.lru_cache(maxsize=None)
def _load_accuracy_metric():
    """Load the ``evaluate`` accuracy metric once and reuse it across evaluations."""
    return evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Return accuracy, macro-F1 and weighted-F1 for an evaluation pass."""
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        # Some model outputs carry extra tensors; logits come first.
        predictions = predictions[0]
    preds = np.argmax(predictions, axis=1)

    report = classification_report(labels, preds, output_dict=True, zero_division=0)
    accuracy = _load_accuracy_metric().compute(predictions=preds, references=labels)
    return {
        "accuracy": accuracy["accuracy"],
        "macro_f1": report.get("macro avg", {}).get("f1-score", 0.0),
        "weighted_f1": report.get("weighted avg", {}).get("f1-score", 0.0),
    }


def add_gaussian_noise(audio: np.ndarray, noise_factor: float = 0.01) -> np.ndarray:
    """Add small Gaussian noise for augmentation.

    The noise has zero mean and standard deviation ``noise_factor``; the result
    is clipped back into ``[-1, 1]``.
    """
    noise = np.random.normal(0, noise_factor, audio.shape).astype(np.float32)
    return np.clip(audio + noise, -1.0, 1.0)


def prepare_dataset(batch, processor, sampling_rate, augment: bool = False):
    """
    Prepare dataset with proper audio normalization and optional augmentation.

    - Enforces 16kHz resampling
    - Normalizes amplitude to [-1, 1]
    - Optionally adds small Gaussian noise

    Args:
        batch: Batched mapping with ``audio_bytes`` (encoded audio files) and
            ``label`` columns.
        processor: Callable feature extractor/processor applied to the waveforms.
        sampling_rate: Target sampling rate in Hz.
        augment: When true, Gaussian noise is added to every waveform.

    Returns:
        The same ``batch`` with ``input_values``, ``labels`` and (when the
        processor returns one) ``attention_mask`` added.
    """
    audio_arrays: List[np.ndarray] = []

    for audio_bytes in batch["audio_bytes"]:
        # Read audio
        with io.BytesIO(audio_bytes) as buffer:
            waveform, source_sr = sf.read(buffer, dtype="float32")

        # Ensure mono
        if waveform.ndim > 1:
            waveform = np.mean(waveform, axis=1)

        # Enforce 16kHz resampling
        if source_sr != sampling_rate:
            waveform = librosa.resample(
                waveform,
                orig_sr=source_sr,
                target_sr=sampling_rate,
                res_type="kaiser_best",
            )

        # Normalize amplitude to [-1, 1] range (skip silent/empty clips)
        max_val = np.abs(waveform).max() if waveform.size else 0.0
        if max_val > 0:
            waveform = waveform / max_val

        # Ensure float32
        waveform = waveform.astype(np.float32)

        # Apply augmentation (only for training)
        if augment:
            waveform = add_gaussian_noise(waveform, noise_factor=0.01)

        audio_arrays.append(waveform)

    # Process with feature extractor
    processed = processor(
        audio_arrays,
        sampling_rate=sampling_rate,
        return_attention_mask=True,
    )

    batch["input_values"] = [
        np.asarray(array, dtype=np.float32) for array in processed["input_values"]
    ]
    if "attention_mask" in processed:
        batch["attention_mask"] = [
            np.asarray(mask, dtype=np.int64) for mask in processed["attention_mask"]
        ]
    batch["labels"] = [int(label) for label in batch["label"]]
    return batch


def parse_args():
    """Parse command-line arguments for the training script."""
    parser = argparse.ArgumentParser(description="Train Wav2Vec2 on RAVDESS emotion dataset")
    parser.add_argument("--model_name_or_path", default="facebook/wav2vec2-base-960h")
    default_output_dir = os.path.join(os.path.dirname(__file__), "wav2vec2-ravdess-emotion")
    parser.add_argument("--output_dir", default=default_output_dir)
    parser.add_argument("--dataset_name", default="confit/ravdess-parquet")
    parser.add_argument("--dataset_config", default="fold1")
    parser.add_argument("--train_split", default="train")
    parser.add_argument("--eval_split", default="test")
    parser.add_argument("--sampling_rate", type=int, default=16_000)
    parser.add_argument("--num_train_epochs", type=float, default=25.0)
    parser.add_argument("--warmup_epochs", type=int, default=3, help="Epochs with frozen feature extractor")
    parser.add_argument("--per_device_train_batch_size", type=int, default=4)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=3e-5)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
    parser.add_argument("--seed", type=int, default=1337)
    parser.add_argument("--max_train_samples", type=int, default=None)
    parser.add_argument("--max_eval_samples", type=int, default=None)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_model_id", default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    return parser.parse_args()


def _load_split(split_root: str, split_name: str) -> Optional[Dict[str, list]]:
    """Read every ``<split_name>-*.parquet`` under ``split_root``; ``None`` if absent."""
    pattern = os.path.join(split_root, f"{split_name}-*.parquet")
    parquet_files = sorted(glob.glob(pattern))
    if not parquet_files:
        return None

    table = pa.concat_tables([pq.read_table(path) for path in parquet_files])
    data = table.to_pydict()
    return {
        "audio_bytes": [entry["bytes"] for entry in data["audio"]],
        "label": [int(label) for label in data["label"]],
        "emotion": data["emotion"],
        "file": data["file"],
    }


def _limit_samples(dataset, max_samples):
    """Keep at most ``max_samples`` leading rows; falsy ``max_samples`` means no limit."""
    if not max_samples:
        return dataset
    return dataset.select(range(min(max_samples, len(dataset))))


def _freeze_feature_encoder(model) -> None:
    """Freeze the convolutional feature encoder using whichever API the model offers."""
    # ``freeze_feature_extractor`` is the deprecated name of
    # ``freeze_feature_encoder``; prefer the current one when it exists.
    freeze = getattr(model, "freeze_feature_encoder", None) or model.freeze_feature_extractor
    freeze()


def _unfreeze_feature_encoder(model) -> None:
    """Make the convolutional feature encoder trainable again.

    Wav2Vec2 models expose no public "unfreeze" method, so this reverses what
    the freeze call does: re-enable gradients on the encoder parameters and
    reset the encoder's internal ``_requires_grad`` flag.
    """
    encoder = model.wav2vec2.feature_extractor
    for param in encoder.parameters():
        param.requires_grad = True
    if hasattr(encoder, "_requires_grad"):
        encoder._requires_grad = True


def _build_training_arguments(args, *, output_dir, num_train_epochs, push_to_hub):
    """Create ``TrainingArguments``, adapting to the installed transformers version."""
    requested = dict(
        output_dir=output_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        num_train_epochs=num_train_epochs,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        fp16=torch.cuda.is_available(),
        group_by_length=True,
        dataloader_num_workers=min(4, os.cpu_count() or 1),
        logging_steps=25,
        save_total_limit=3,  # Keep only last 3 checkpoints
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
        push_to_hub=push_to_hub,
        hub_model_id=args.hub_model_id,
        hub_private_repo=args.hub_private_repo,
        report_to="none",  # Disable wandb/tensorboard
    )

    supported_names = inspect.signature(TrainingArguments).parameters
    # Newer transformers renamed ``evaluation_strategy`` to ``eval_strategy``.
    if "eval_strategy" in supported_names:
        requested["eval_strategy"] = requested.pop("evaluation_strategy")

    # Filter to supported arguments
    supported = {
        key: value for key, value in requested.items() if key in supported_names
    }
    if "evaluation_strategy" not in supported and "eval_strategy" not in supported:
        # Without per-epoch evaluation these options cannot be honoured.
        for key in ("save_strategy", "load_best_model_at_end", "metric_for_best_model"):
            supported.pop(key, None)
    return TrainingArguments(**supported)


def _build_trainer(model, training_args, train_dataset, eval_dataset, processor,
                   data_collator, class_weights):
    """Create a ``WeightedTrainer`` using the processor keyword the installed version expects."""
    trainer_kwargs = dict(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        class_weights=class_weights,
    )
    # ``tokenizer`` was superseded by ``processing_class`` in newer releases.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    processor_kwarg = "processing_class" if "processing_class" in trainer_params else "tokenizer"
    trainer_kwargs[processor_kwarg] = processor
    return WeightedTrainer(**trainer_kwargs)


def main():
    """Download RAVDESS, preprocess it and fine-tune Wav2Vec2 in two phases.

    Phase 1 trains ``--warmup_epochs`` epochs with the feature encoder frozen;
    phase 2 trains the remaining epochs with it unfrozen. Each phase uses its
    own Trainer (and therefore its own optimizer and LR schedule) so that the
    newly unfrozen parameters are actually optimized.

    Raises:
        ValueError: If the training split is missing, ``--num_train_epochs`` is
            not positive, or the label mapping is inconsistent.
    """
    args = parse_args()
    set_seed(args.seed)

    if args.num_train_epochs <= 0:
        raise ValueError(f"--num_train_epochs must be positive, got {args.num_train_epochs}")

    print("=" * 80)
    print("Wav2Vec2 RAVDESS Emotion Detection Training")
    print("=" * 80)
    print(f"Model: {args.model_name_or_path}")
    print(f"Epochs: {args.num_train_epochs} (warmup: {args.warmup_epochs})")
    print(f"Learning rate: {args.learning_rate}")
    print(f"Batch size: {args.per_device_train_batch_size} (gradient accumulation: {args.gradient_accumulation_steps})")
    print("=" * 80)

    # Download dataset
    print("\n📥 Downloading RAVDESS dataset...")
    snapshot_path = snapshot_download(
        repo_id=args.dataset_name,
        repo_type="dataset",
        cache_dir=os.getenv("HF_HOME"),
        token=os.getenv("HF_TOKEN"),
    )
    split_root = (
        os.path.join(snapshot_path, args.dataset_config)
        if args.dataset_config
        else snapshot_path
    )

    train_dict = _load_split(split_root, args.train_split)
    if train_dict is None:
        raise ValueError(
            f"Could not locate parquet files for split '{args.train_split}' in {split_root}"
        )
    eval_dict = _load_split(split_root, args.eval_split)

    train_dataset = Dataset.from_dict(train_dict)
    if eval_dict is not None:
        eval_dataset = Dataset.from_dict(eval_dict)
    else:
        # No dedicated eval split: hold out 10% of the training data.
        split_dataset = train_dataset.train_test_split(test_size=0.1, seed=args.seed)
        train_dataset = split_dataset["train"]
        eval_dataset = split_dataset["test"]

    print(f"✅ Train samples: {len(train_dataset)}")
    print(f"✅ Eval samples: {len(eval_dataset)}")

    # Build label mapping (consistent id2label / label2id)
    print("\n📊 Building label mapping...")
    label_names = {}
    for label, emotion in zip(train_dataset["label"], train_dataset["emotion"]):
        label_names[int(label)] = emotion

    # Ensure consistent ordering
    id2label = {idx: label_names[idx] for idx in sorted(label_names)}
    label2id = {name: idx for idx, name in id2label.items()}

    # The classification head indexes classes 0..N-1, so the ids must be
    # contiguous and every id needs a distinct name.
    if list(id2label) != list(range(len(id2label))):
        raise ValueError(f"Label ids must be contiguous starting at 0, got {list(id2label)}")
    if len(label2id) != len(id2label):
        raise ValueError(f"Emotion names must be unique per label id, got {id2label}")

    print(f"✅ Labels ({len(id2label)}): {list(id2label.values())}")
    print(f"✅ Label mapping: {id2label}")

    # Compute class weights for balanced training
    print("\n⚖️ Computing class weights for balanced training...")
    labels_array = np.array(train_dataset["label"])
    unique_labels = np.unique(labels_array)  # sorted, equals 0..N-1 (validated above)
    class_weights = compute_class_weight(
        "balanced",
        classes=unique_labels,
        y=labels_array,
    )
    class_weight_list = [float(weight) for weight in class_weights]
    weight_by_name = dict(zip((id2label[int(i)] for i in unique_labels), class_weight_list))
    print(f"✅ Class weights: {weight_by_name}")

    # Load processor and config
    print("\n📦 Loading processor and config...")
    processor = AutoProcessor.from_pretrained(
        args.model_name_or_path,
        cache_dir=os.getenv("HF_HOME"),
    )
    config = AutoConfig.from_pretrained(
        args.model_name_or_path,
        num_labels=len(label2id),
        label2id=label2id,
        id2label=id2label,
        finetuning_task="wav2vec2_emotion",
        cache_dir=os.getenv("HF_HOME"),
    )

    # Verify label mapping in config (explicit raise: asserts vanish under -O)
    print(f"✅ Config labels: {config.id2label}")
    if config.label2id != label2id or config.id2label != id2label:
        raise ValueError("Label mapping mismatch!")

    # Apply sample limits before preprocessing so only the rows that will
    # actually be used are decoded and feature-extracted.
    train_dataset = _limit_samples(train_dataset, args.max_train_samples)
    eval_dataset = _limit_samples(eval_dataset, args.max_eval_samples)

    raw_columns = ["audio_bytes", "file", "emotion", "label"]

    # Prepare datasets with proper normalization
    print("\n🔄 Preparing training dataset (with augmentation)...")
    processed_train_dataset = train_dataset.map(
        prepare_dataset,
        fn_kwargs=dict(
            processor=processor,
            sampling_rate=args.sampling_rate,
            augment=True,  # Add noise augmentation for training
        ),
        remove_columns=raw_columns,
        batched=True,
        batch_size=8,
        num_proc=1,
    )

    print("🔄 Preparing evaluation dataset (no augmentation)...")
    processed_eval_dataset = eval_dataset.map(
        prepare_dataset,
        fn_kwargs=dict(
            processor=processor,
            sampling_rate=args.sampling_rate,
            augment=False,  # No augmentation for eval
        ),
        remove_columns=raw_columns,
        batched=True,
        batch_size=8,
        num_proc=1,
    )

    # Load model
    print("\n🤖 Loading model...")
    model = Wav2Vec2ForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        config=config,
        cache_dir=os.getenv("HF_HOME"),
    )

    data_collator = DataCollatorWithPadding(processor=processor)

    # Split the epoch budget between the frozen and unfrozen phases.
    frozen_epochs = min(float(max(args.warmup_epochs, 0)), args.num_train_epochs)
    unfrozen_epochs = args.num_train_epochs - frozen_epochs
    phases = []
    if frozen_epochs > 0:
        phases.append(("FROZEN", True, frozen_epochs))
    if unfrozen_epochs > 0:
        phases.append(("UNFROZEN", False, unfrozen_epochs))

    trainer = None
    for phase_number, (state_name, freeze, epochs) in enumerate(phases, start=1):
        is_last_phase = phase_number == len(phases)

        print("\n" + "=" * 80)
        print(f"PHASE {phase_number}: Training with {state_name} feature extractor ({epochs} epochs)")
        print("=" * 80)

        if freeze:
            print("🔒 Freezing feature extractor for warmup...")
            _freeze_feature_encoder(model)
        else:
            print("🔓 Unfreezing feature extractor...")
            _unfreeze_feature_encoder(model)
            print("✅ Feature extractor unfrozen!")

        # Intermediate phases checkpoint into a sub-directory so their
        # checkpoint-N folders never collide with the final phase, and only
        # the final phase talks to the Hub.
        phase_output_dir = (
            args.output_dir
            if is_last_phase
            else os.path.join(args.output_dir, "phase1-frozen-warmup")
        )
        training_args = _build_training_arguments(
            args,
            output_dir=phase_output_dir,
            num_train_epochs=epochs,
            push_to_hub=args.push_to_hub and is_last_phase,
        )

        # A fresh Trainer per phase builds a fresh optimizer, which is what
        # picks up parameters that have just become trainable.
        trainer = _build_trainer(
            model,
            training_args,
            processed_train_dataset,
            processed_eval_dataset,
            processor,
            data_collator,
            class_weight_list,
        )
        trainer.train()
        print(f"\n✅ Completed phase {phase_number} ({epochs} epochs)")

    # Save final model
    print("\n💾 Saving final model and processor...")
    trainer.save_model()
    processor.save_pretrained(args.output_dir)

    # Verify label mapping is saved correctly
    saved_config = AutoConfig.from_pretrained(args.output_dir)
    print("\n✅ Saved model label mapping:")
    print(f" id2label: {saved_config.id2label}")
    print(f" label2id: {saved_config.label2id}")

    if args.push_to_hub:
        print("\n📤 Pushing to Hugging Face Hub...")
        trainer.push_to_hub()

    print(f"\n✅ Training complete! Model saved to: {args.output_dir}")
    print("=" * 80)


if __name__ == "__main__":
    main()