Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| """ | |
| Improved Wav2Vec2 RAVDESS Emotion Detection Training Script | |
| Fixes: | |
| - 25 epochs for proper convergence | |
| - Feature extractor freeze/unfreeze strategy | |
| - Balanced class weights for imbalanced dataset | |
| - Proper audio normalization (16kHz, amplitude) | |
| - Gaussian noise augmentation | |
| - Correct label mapping | |
| """ | |
| import argparse | |
| import glob | |
| import io | |
| import inspect | |
| import os | |
| from dataclasses import dataclass | |
| from typing import Dict, List | |
| import evaluate | |
| import librosa | |
| import numpy as np | |
| import pyarrow as pa | |
| import pyarrow.parquet as pq | |
| import soundfile as sf | |
| import torch | |
| import torch.nn as nn | |
| from sklearn.utils.class_weight import compute_class_weight | |
| from torch.nn.utils.rnn import pad_sequence | |
| from datasets import Dataset | |
| from huggingface_hub import snapshot_download | |
| from transformers import ( | |
| AutoConfig, | |
| AutoProcessor, | |
| Trainer, | |
| TrainingArguments, | |
| Wav2Vec2ForSequenceClassification, | |
| set_seed, | |
| ) | |
@dataclass
class DataCollatorWithPadding:
    """Collate per-example audio features into a padded batch of tensors.

    Attributes:
        processor: the processor the datasets were featurised with. ``__call__``
            does not use it; it is kept so callers can construct the collator
            with it.
        padding: accepted for interface compatibility; ``__call__`` always pads
            to the longest example in the batch.
    """

    # Quoted so the annotation is not evaluated when the class is defined.
    processor: "AutoProcessor"
    padding: bool = True

    def __call__(self, features: List[Dict[str, np.ndarray]]) -> Dict[str, torch.Tensor]:
        """Pad ``input_values`` (and ``attention_mask`` if present) and stack labels.

        Args:
            features: examples holding a 1-D ``input_values`` sequence, an integer
                ``labels`` value and, optionally, a 1-D ``attention_mask``.

        Returns:
            Dict with ``input_values`` (float32, zero right-padded),
            ``attention_mask`` (long, 1 = real sample, 0 = padding) and
            ``labels`` (long).
        """
        input_tensors = [
            torch.as_tensor(feature["input_values"], dtype=torch.float32)
            for feature in features
        ]
        padded_inputs = pad_sequence(input_tensors, batch_first=True, padding_value=0.0)

        if "attention_mask" in features[0]:
            attention_tensors = [
                torch.as_tensor(feature["attention_mask"], dtype=torch.long)
                for feature in features
            ]
            padded_attention = pad_sequence(attention_tensors, batch_first=True, padding_value=0)
        else:
            # Derive the mask from the true lengths. Testing values against 0.0
            # would also mask genuine silent (exactly-zero) samples in the audio.
            lengths = torch.tensor([tensor.shape[0] for tensor in input_tensors], dtype=torch.long)
            positions = torch.arange(padded_inputs.shape[1], dtype=torch.long)
            padded_attention = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()

        labels = torch.tensor([feature["labels"] for feature in features], dtype=torch.long)
        return {
            "input_values": padded_inputs,
            "attention_mask": padded_attention,
            "labels": labels,
        }
class WeightedTrainer(Trainer):
    """Trainer that optimises class-weighted cross-entropy for imbalanced classes.

    Args:
        class_weights: optional sequence of per-class weights; position ``i``
            weights class id ``i``. ``None`` gives plain cross-entropy.
        *args, **kwargs: forwarded unchanged to ``Trainer``.
    """

    def __init__(self, class_weights=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Kept on CPU here. compute_loss moves it to the logits' device, so the
        # weights follow the model wherever it runs (CPU, any CUDA device, MPS)
        # instead of assuming the default GPU.
        self.class_weights = (
            None
            if class_weights is None
            else torch.as_tensor(class_weights, dtype=torch.float32)
        )

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Return the (optionally class-weighted) cross-entropy loss.

        ``num_items_in_batch`` and any extra keyword arguments are accepted
        because newer ``Trainer`` releases pass them to ``compute_loss``. This
        loss is a per-batch mean and does not use them.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        weight = None if self.class_weights is None else self.class_weights.to(logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=weight)
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
def compute_metrics(eval_pred):
    """Compute accuracy, macro-F1 and support-weighted F1 for one evaluation pass.

    Everything is computed directly with numpy. No metric object is loaded,
    which avoids re-instantiating one (and hitting the Hub or cache) on every
    evaluation.

    Args:
        eval_pred: ``(predictions, labels)``. ``predictions`` holds per-class
            scores (argmax over the last axis gives the predicted class) or a
            tuple whose first element does. ``labels`` holds integer class ids.

    Returns:
        Dict with float ``accuracy``, ``macro_f1`` and ``weighted_f1``. F1 is
        taken over every class present in either labels or predictions and is 0
        where undefined (zero-division treated as 0). All three are 0.0 for an
        empty evaluation set.
    """
    predictions, labels = eval_pred
    # Models that also return hidden states/attentions produce a tuple; logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    preds = np.argmax(np.asarray(predictions), axis=-1).ravel()
    labels = np.asarray(labels).ravel()
    if labels.size == 0:
        return {"accuracy": 0.0, "macro_f1": 0.0, "weighted_f1": 0.0}

    accuracy = float(np.mean(preds == labels))

    classes = np.union1d(labels, preds)
    true_pos = np.array([np.sum((labels == c) & (preds == c)) for c in classes], dtype=np.float64)
    support = np.array([np.sum(labels == c) for c in classes], dtype=np.float64)
    predicted = np.array([np.sum(preds == c) for c in classes], dtype=np.float64)
    # F1 = 2*TP / (2*TP + FP + FN), and 2*TP + FP + FN == support + predicted.
    denom = support + predicted
    f1 = np.divide(2.0 * true_pos, denom, out=np.zeros_like(denom), where=denom > 0)

    return {
        "accuracy": accuracy,
        "macro_f1": float(f1.mean()),
        # support.sum() == labels.size > 0, so the weights are valid.
        "weighted_f1": float(np.average(f1, weights=support)),
    }
def add_gaussian_noise(audio: np.ndarray, noise_factor: float = 0.01) -> np.ndarray:
    """Return *audio* plus zero-mean Gaussian noise, clipped back into [-1, 1].

    ``noise_factor`` is the noise standard deviation. The noise comes from
    NumPy's global random state (so it is reproducible under a fixed
    ``np.random`` seed) and is cast to float32 before being added.
    """
    perturbation = np.random.normal(loc=0, scale=noise_factor, size=audio.shape)
    noisy = audio + perturbation.astype(np.float32)
    return noisy.clip(-1.0, 1.0)
def prepare_dataset(batch, processor, sampling_rate, augment: bool = False):
    """Turn a batch of encoded audio files into model inputs and integer labels.

    Each item of ``batch["audio_bytes"]`` goes through these steps:

    - Decode it with soundfile as float32.
    - Average it to mono when it has several channels.
    - Resample it to ``sampling_rate`` when its native rate differs.
    - Peak-normalise it so the largest absolute sample becomes 1 (all-zero
      clips are left unchanged).
    - When ``augment`` is true, pass it through ``add_gaussian_noise``
      (std 0.01).

    The noise is drawn when this function runs. When it is applied once via
    ``Dataset.map``, every later epoch therefore reuses the same noisy copy.

    The waveforms are then passed to ``processor`` with an attention mask
    requested. ``input_values`` (float32), ``attention_mask`` (int64, only if
    the processor returned one) and ``labels`` (ints from ``batch["label"]``)
    are written into ``batch``, which is returned.
    """
    def load_waveform(encoded) -> np.ndarray:
        with io.BytesIO(encoded) as stream:
            samples, native_sr = sf.read(stream, dtype='float32')
        if samples.ndim > 1:
            samples = samples.mean(axis=1)
        if native_sr != sampling_rate:
            samples = librosa.resample(
                samples,
                orig_sr=native_sr,
                target_sr=sampling_rate,
                res_type='kaiser_best',
            )
        peak = np.abs(samples).max()
        if peak > 0:
            samples = samples / peak
        samples = samples.astype(np.float32)
        return add_gaussian_noise(samples, noise_factor=0.01) if augment else samples

    waveforms = [load_waveform(encoded) for encoded in batch["audio_bytes"]]
    features = processor(
        waveforms,
        sampling_rate=sampling_rate,
        return_attention_mask=True,
    )
    batch["input_values"] = [
        np.asarray(values, dtype=np.float32) for values in features["input_values"]
    ]
    if "attention_mask" in features:
        batch["attention_mask"] = [
            np.asarray(mask, dtype=np.int64) for mask in features["attention_mask"]
        ]
    batch["labels"] = [int(label) for label in batch["label"]]
    return batch
def parse_args():
    """Read the training configuration from the command line.

    Every option has a default, so the script runs with no arguments. The
    default output directory is ``wav2vec2-ravdess-emotion`` next to this file.

    Returns:
        argparse.Namespace with model/dataset locations, optimisation
        hyper-parameters, optional sample caps and Hub publishing flags.
    """
    parser = argparse.ArgumentParser(description="Train Wav2Vec2 on RAVDESS emotion dataset")
    script_dir = os.path.dirname(__file__)
    option_specs = (
        ("--model_name_or_path", dict(default="facebook/wav2vec2-base-960h")),
        ("--output_dir", dict(default=os.path.join(script_dir, "wav2vec2-ravdess-emotion"))),
        ("--dataset_name", dict(default="confit/ravdess-parquet")),
        ("--dataset_config", dict(default="fold1")),
        ("--train_split", dict(default="train")),
        ("--eval_split", dict(default="test")),
        ("--sampling_rate", dict(type=int, default=16_000)),
        ("--num_train_epochs", dict(type=float, default=25.0)),
        ("--warmup_epochs", dict(type=int, default=3, help="Epochs with frozen feature extractor")),
        ("--per_device_train_batch_size", dict(type=int, default=4)),
        ("--per_device_eval_batch_size", dict(type=int, default=4)),
        ("--learning_rate", dict(type=float, default=3e-5)),
        ("--warmup_ratio", dict(type=float, default=0.1)),
        ("--weight_decay", dict(type=float, default=0.01)),
        ("--gradient_accumulation_steps", dict(type=int, default=2)),
        ("--seed", dict(type=int, default=1337)),
        ("--max_train_samples", dict(type=int, default=None)),
        ("--max_eval_samples", dict(type=int, default=None)),
        ("--push_to_hub", dict(action="store_true")),
        ("--hub_model_id", dict(default=None)),
        ("--hub_private_repo", dict(action="store_true")),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
def _load_parquet_split(split_root: str, split_name: str):
    """Load all ``{split_name}-*.parquet`` shards under *split_root* as plain columns.

    Returns:
        ``None`` if no shard matches. Otherwise a dict with ``audio_bytes`` (the
        ``bytes`` field of each ``audio`` entry), ``label`` (ints), ``emotion``
        and ``file`` lists.
    """
    pattern = os.path.join(split_root, f"{split_name}-*.parquet")
    parquet_files = sorted(glob.glob(pattern))
    if not parquet_files:
        return None
    data = pa.concat_tables([pq.read_table(path) for path in parquet_files]).to_pydict()
    return {
        "audio_bytes": [entry["bytes"] for entry in data["audio"]],
        "label": [int(label) for label in data["label"]],
        "emotion": data["emotion"],
        "file": data["file"],
    }


def _build_label_maps(labels, emotions):
    """Build ``(id2label, label2id)`` from parallel label-id and emotion-name columns.

    Raises:
        ValueError: if there are no labels, if the ids are not exactly
            ``0..K-1``, or if two ids share one name. The classification head
            and the class-weight vector are both indexed by position, so gaps
            or offsets would misalign them or index out of range.
    """
    # The last emotion seen for an id wins, matching a plain loop over the rows.
    label_names = {int(label): emotion for label, emotion in zip(labels, emotions)}
    if not label_names:
        raise ValueError("Training split contains no labels")
    id2label = {idx: label_names[idx] for idx in sorted(label_names)}
    if list(id2label) != list(range(len(id2label))):
        raise ValueError(
            f"Expected contiguous label ids 0..{len(id2label) - 1}, got {list(id2label)}"
        )
    label2id = {name: idx for idx, name in id2label.items()}
    if len(label2id) != len(id2label):
        raise ValueError(f"Several label ids share an emotion name: {id2label}")
    return id2label, label2id


def _set_feature_encoder_trainable(model, trainable: bool) -> None:
    """Set ``requires_grad`` on every parameter of the Wav2Vec2 convolutional feature encoder."""
    encoder = model.wav2vec2.feature_extractor
    for param in encoder.parameters():
        param.requires_grad = trainable
    # NOTE(review): transformers' own freeze helper also clears this private flag
    # on the encoder. It is mirrored here so unfreezing restores the original
    # state. Re-check when upgrading transformers.
    encoder._requires_grad = trainable


def _build_warmup_callback(model, warmup_epochs: int):
    """Return a TrainerCallback that keeps the feature encoder frozen for *warmup_epochs* epochs.

    Freezing happens in ``on_train_begin`` rather than before the Trainer is
    built. The Trainer's default optimizer only registers parameters that
    require grad when it is created, so freezing first would leave the encoder
    out of the optimizer permanently. While frozen, the encoder gets no
    gradients and AdamW skips parameters whose ``grad`` is ``None``.

    NOTE(review): this assumes the Trainer builds its optimizer before firing
    ``on_train_begin`` (true for the current ``Trainer`` training loop).
    Confirm when upgrading transformers.
    """
    from transformers import TrainerCallback

    class _FeatureEncoderWarmup(TrainerCallback):
        def __init__(self):
            self.epochs_started = 0
            self.frozen = False

        def on_train_begin(self, args, state, control, **kwargs):
            if warmup_epochs > 0:
                print(f"🔒 Freezing feature extractor for the first {warmup_epochs} epoch(s)...")
                _set_feature_encoder_trainable(model, False)
                self.frozen = True

        def on_epoch_begin(self, args, state, control, **kwargs):
            # At this point epochs_started equals the number of completed epochs.
            if self.frozen and self.epochs_started >= warmup_epochs:
                print(f"\n✅ Completed {warmup_epochs} warmup epochs")
                print("🔓 Unfreezing feature extractor...")
                _set_feature_encoder_trainable(model, True)
                self.frozen = False
                print("✅ Feature extractor unfrozen!")
            self.epochs_started += 1

    return _FeatureEncoderWarmup()


def _build_training_arguments(args):
    """Build TrainingArguments, keeping only keywords this transformers version accepts.

    Per-epoch evaluation is requested as ``eval_strategy`` or, on older
    releases, ``evaluation_strategy``, whichever the signature has. If neither
    is accepted, the checkpoint-selection settings that depend on per-epoch
    evaluation are dropped too.
    """
    supported = inspect.signature(TrainingArguments).parameters
    eval_key = "eval_strategy" if "eval_strategy" in supported else "evaluation_strategy"
    requested = {
        "output_dir": args.output_dir,
        "per_device_train_batch_size": args.per_device_train_batch_size,
        "per_device_eval_batch_size": args.per_device_eval_batch_size,
        eval_key: "epoch",
        "save_strategy": "epoch",
        "num_train_epochs": args.num_train_epochs,
        "learning_rate": args.learning_rate,
        "warmup_ratio": args.warmup_ratio,
        "weight_decay": args.weight_decay,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "fp16": torch.cuda.is_available(),
        "group_by_length": True,
        "dataloader_num_workers": min(4, os.cpu_count() or 1),
        "logging_steps": 25,
        "save_total_limit": 3,  # Keep only last 3 checkpoints
        "load_best_model_at_end": True,
        "metric_for_best_model": "accuracy",
        "greater_is_better": True,
        "push_to_hub": args.push_to_hub,
        "hub_model_id": args.hub_model_id,
        "hub_private_repo": args.hub_private_repo,
        "report_to": "none",  # Disable wandb/tensorboard
    }
    filtered = {key: value for key, value in requested.items() if key in supported}
    if eval_key not in filtered:
        for key in ("save_strategy", "load_best_model_at_end", "metric_for_best_model"):
            filtered.pop(key, None)
    return TrainingArguments(**filtered)


def main():
    """Fine-tune Wav2Vec2 for RAVDESS emotion classification end to end.

    Steps:

    - Download the parquet snapshot.
    - Build label maps and balanced class weights from the training split.
    - Featurise both splits.
    - Train once: the convolutional feature encoder is frozen for the first
      ``--warmup_epochs`` epochs and trainable afterwards, under a single
      optimizer and LR schedule.
    - Save the model and processor to ``--output_dir``, and optionally push
      them to the Hub.
    """
    args = parse_args()
    set_seed(args.seed)
    print("=" * 80)
    print("Wav2Vec2 RAVDESS Emotion Detection Training")
    print("=" * 80)
    print(f"Model: {args.model_name_or_path}")
    print(f"Epochs: {args.num_train_epochs} (warmup: {args.warmup_epochs})")
    print(f"Learning rate: {args.learning_rate}")
    print(f"Batch size: {args.per_device_train_batch_size} (gradient accumulation: {args.gradient_accumulation_steps})")
    print("=" * 80)

    # Download dataset
    print("\n📥 Downloading RAVDESS dataset...")
    snapshot_path = snapshot_download(
        repo_id=args.dataset_name,
        repo_type="dataset",
        cache_dir=os.getenv("HF_HOME"),
        token=os.getenv("HF_TOKEN"),
    )
    split_root = os.path.join(snapshot_path, args.dataset_config) if args.dataset_config else snapshot_path

    train_dict = _load_parquet_split(split_root, args.train_split)
    if train_dict is None:
        raise ValueError(f"Could not locate parquet files for split '{args.train_split}' in {split_root}")
    eval_dict = _load_parquet_split(split_root, args.eval_split)

    train_dataset = Dataset.from_dict(train_dict)
    if eval_dict is not None:
        eval_dataset = Dataset.from_dict(eval_dict)
    else:
        # No dedicated eval split: hold out 10% of the training data.
        split_dataset = train_dataset.train_test_split(test_size=0.1, seed=args.seed)
        train_dataset = split_dataset["train"]
        eval_dataset = split_dataset["test"]
    print(f"✅ Train samples: {len(train_dataset)}")
    print(f"✅ Eval samples: {len(eval_dataset)}")

    # Label mapping (consistent id2label / label2id), built from the full training split.
    print("\n📊 Building label mapping...")
    train_labels = list(train_dataset["label"])
    id2label, label2id = _build_label_maps(train_labels, train_dataset["emotion"])
    unknown_eval_labels = sorted(set(eval_dataset["label"]) - set(id2label))
    if unknown_eval_labels:
        raise ValueError(f"Eval split has label ids absent from the training split: {unknown_eval_labels}")
    print(f"✅ Labels ({len(id2label)}): {list(id2label.values())}")
    print(f"✅ Label mapping: {id2label}")

    # Balanced class weights. Ids are contiguous 0..K-1, so position i is class i.
    print("\n⚖️ Computing class weights for balanced training...")
    class_weight_list = compute_class_weight(
        "balanced",
        classes=np.arange(len(id2label)),
        y=np.asarray(train_labels),
    ).tolist()
    print(f"✅ Class weights: {dict(zip(id2label.values(), class_weight_list))}")

    # Load processor and config
    print("\n📦 Loading processor and config...")
    processor = AutoProcessor.from_pretrained(
        args.model_name_or_path,
        cache_dir=os.getenv("HF_HOME"),
    )
    config = AutoConfig.from_pretrained(
        args.model_name_or_path,
        num_labels=len(label2id),
        label2id=label2id,
        id2label=id2label,
        finetuning_task="wav2vec2_emotion",
        cache_dir=os.getenv("HF_HOME"),
    )
    print(f"✅ Config labels: {config.id2label}")
    if config.label2id != label2id or config.id2label != id2label:
        raise ValueError("Label mapping mismatch!")

    # Cap sample counts before featurisation, so capped runs do not decode and
    # resample whole splits. The label maps and class weights above still
    # reflect the full training split.
    if args.max_train_samples:
        train_dataset = train_dataset.select(range(min(args.max_train_samples, len(train_dataset))))
    if args.max_eval_samples:
        eval_dataset = eval_dataset.select(range(min(args.max_eval_samples, len(eval_dataset))))

    def featurize(dataset, augment):
        return dataset.map(
            prepare_dataset,
            fn_kwargs=dict(
                processor=processor,
                sampling_rate=args.sampling_rate,
                augment=augment,
            ),
            remove_columns=["audio_bytes", "file", "emotion", "label"],
            batched=True,
            batch_size=8,
            num_proc=1,
        )

    print("\n🔄 Preparing training dataset (with augmentation)...")
    # NOTE: this map runs once, so the noise drawn here is reused every epoch.
    processed_train_dataset = featurize(train_dataset, augment=True)
    print("🔄 Preparing evaluation dataset (no augmentation)...")
    processed_eval_dataset = featurize(eval_dataset, augment=False)

    # Load model
    print("\n🤖 Loading model...")
    model = Wav2Vec2ForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        config=config,
        cache_dir=os.getenv("HF_HOME"),
    )

    data_collator = DataCollatorWithPadding(processor=processor)
    training_args = _build_training_arguments(args)

    # Newer transformers releases take the processor as ``processing_class``;
    # older ones only know ``tokenizer``.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    processor_kwarg = "processing_class" if "processing_class" in trainer_params else "tokenizer"
    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=processed_train_dataset,
        eval_dataset=processed_eval_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[_build_warmup_callback(model, args.warmup_epochs)],
        class_weights=class_weight_list,
        **{processor_kwarg: processor},
    )

    # A single run: the callback freezes the feature encoder for the first
    # warmup epochs and unfreezes it afterwards, keeping one LR schedule.
    print("\n" + "=" * 80)
    print(f"TRAINING: feature extractor FROZEN for {args.warmup_epochs} epochs, then UNFROZEN")
    print("=" * 80)
    if args.warmup_epochs >= args.num_train_epochs:
        print(
            f"⚠️ warmup_epochs ({args.warmup_epochs}) >= num_train_epochs ({args.num_train_epochs}); "
            "the feature extractor stays frozen for the whole run."
        )
    trainer.train()

    # Save final model
    print("\n💾 Saving final model and processor...")
    trainer.save_model()
    processor.save_pretrained(args.output_dir)

    # Verify label mapping is saved correctly
    saved_config = AutoConfig.from_pretrained(args.output_dir)
    print("\n✅ Saved model label mapping:")
    print(f" id2label: {saved_config.id2label}")
    print(f" label2id: {saved_config.label2id}")

    if args.push_to_hub:
        print("\n📤 Pushing to Hugging Face Hub...")
        trainer.push_to_hub()

    print(f"\n✅ Training complete! Model saved to: {args.output_dir}")
    print("=" * 80)


if __name__ == "__main__":
    main()