# aura-emotion-api / train_ravdess.py
# monishaaura: "Update training script with retry logic and resampy dependency" (commit 90c88ff)
#!/usr/bin/env python
"""
Improved Wav2Vec2 RAVDESS Emotion Detection Training Script
Fixes:
- 25 epochs for proper convergence
- Feature extractor freeze/unfreeze strategy
- Balanced class weights for imbalanced dataset
- Proper audio normalization (16kHz, amplitude)
- Gaussian noise augmentation
- Correct label mapping
"""
import argparse
import glob
import io
import inspect
import os
from dataclasses import dataclass
from typing import Dict, List
import evaluate
import librosa
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import soundfile as sf
import torch
import torch.nn as nn
from sklearn.utils.class_weight import compute_class_weight
from torch.nn.utils.rnn import pad_sequence
from datasets import Dataset
from huggingface_hub import snapshot_download
from transformers import (
AutoConfig,
AutoProcessor,
Trainer,
TrainingArguments,
Wav2Vec2ForSequenceClassification,
set_seed,
)
@dataclass
class DataCollatorWithPadding:
    """Collate variable-length audio features into right-padded batch tensors.

    ``input_values`` are padded with 0.0 and ``attention_mask`` with 0 up to
    the longest clip in the batch; ``labels`` are stacked into a 1-D long
    tensor. If no attention mask is supplied, one is derived from the padded
    inputs (non-zero samples count as attended).
    """

    processor: AutoProcessor
    padding: bool = True

    def __call__(self, features: List[Dict[str, np.ndarray]]) -> Dict[str, torch.Tensor]:
        waveforms = [
            torch.as_tensor(item["input_values"], dtype=torch.float32)
            for item in features
        ]
        batch_inputs = pad_sequence(waveforms, batch_first=True, padding_value=0.0)

        if "attention_mask" in features[0]:
            masks = [
                torch.as_tensor(item["attention_mask"], dtype=torch.long)
                for item in features
            ]
            batch_masks = pad_sequence(masks, batch_first=True, padding_value=0)
        else:
            # Fallback: treat every non-zero sample as real audio.
            # NOTE(review): normalized audio can contain genuine zeros, so this
            # heuristic may occasionally mask real samples — confirm callers
            # always provide an explicit attention_mask.
            batch_masks = (batch_inputs != 0.0).long()

        batch_labels = torch.tensor(
            [item["labels"] for item in features], dtype=torch.long
        )
        return {
            "input_values": batch_inputs,
            "attention_mask": batch_masks,
            "labels": batch_labels,
        }
class WeightedTrainer(Trainer):
    """Trainer that applies class-weighted cross-entropy for imbalanced datasets.

    Args:
        class_weights: Optional sequence of per-class weights, ordered by
            label id (as produced by ``compute_class_weight``). ``None``
            falls back to unweighted cross-entropy.
    """

    def __init__(self, class_weights=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if class_weights is not None:
            # Keep on CPU here; moved to the logits device lazily in
            # compute_loss, which works for CUDA, MPS and CPU alike
            # (the original unconditional .cuda() broke non-CUDA setups).
            self.class_weights = torch.tensor(class_weights, dtype=torch.float32)
        else:
            self.class_weights = None

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments newer transformers versions pass
        # (e.g. num_items_in_batch in >=4.46); without it this override crashes.
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        if self.class_weights is not None:
            # Align the weight tensor with wherever the model placed its logits.
            if self.class_weights.device != logits.device:
                self.class_weights = self.class_weights.to(logits.device)
            loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
        else:
            loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
def compute_metrics(eval_pred):
    """Compute accuracy, macro-F1 and weighted-F1 for a Trainer evaluation.

    Args:
        eval_pred: ``(predictions, labels)`` pair, where ``predictions`` are
            raw logits of shape ``(n_samples, n_classes)`` and ``labels`` are
            integer class ids.

    Returns:
        Dict with float values for ``accuracy``, ``macro_f1`` and
        ``weighted_f1``.
    """
    predictions, labels = eval_pred
    labels = np.asarray(labels)
    preds = np.argmax(predictions, axis=1)
    # Metrics computed directly with numpy. The previous version re-ran
    # evaluate.load("accuracy") on every evaluation pass (a metric-script
    # lookup each epoch) and imported confusion_matrix without using it.
    # Semantics match sklearn's classification_report(zero_division=0):
    # classes are the union of true and predicted labels; an all-miss class
    # contributes an F1 of 0.
    classes = np.union1d(labels, preds)
    f1_scores = []
    supports = []
    for cls in classes:
        true_positives = np.sum((preds == cls) & (labels == cls))
        denom = np.sum(preds == cls) + np.sum(labels == cls)
        f1_scores.append(2.0 * true_positives / denom if denom else 0.0)
        supports.append(int(np.sum(labels == cls)))
    total_support = sum(supports)
    macro_f1 = float(np.mean(f1_scores)) if f1_scores else 0.0
    weighted_f1 = (
        float(np.dot(f1_scores, supports) / total_support) if total_support else 0.0
    )
    return {
        "accuracy": float(np.mean(preds == labels)),
        "macro_f1": macro_f1,
        "weighted_f1": weighted_f1,
    }
def add_gaussian_noise(audio: np.ndarray, noise_factor: float = 0.01) -> np.ndarray:
    """Return *audio* perturbed with zero-mean Gaussian noise, clipped to [-1, 1].

    Args:
        audio: Waveform samples (any shape); expected in [-1, 1].
        noise_factor: Standard deviation of the additive noise.
    """
    perturbation = np.random.normal(0, noise_factor, audio.shape).astype(np.float32)
    noisy = audio + perturbation
    return np.clip(noisy, -1.0, 1.0)
def prepare_dataset(batch, processor, sampling_rate, augment: bool = False):
    """
    Turn raw audio bytes in *batch* into model-ready features.

    Per clip: decode, downmix to mono, resample to *sampling_rate*,
    peak-normalize to [-1, 1], and (when *augment* is set) add a small amount
    of Gaussian noise. The batch gains ``input_values``, ``attention_mask``
    (when the processor returns one) and integer ``labels``.
    """

    def _decode(raw_bytes: bytes) -> np.ndarray:
        # Decode one clip from its raw container bytes.
        with io.BytesIO(raw_bytes) as buffer:
            samples, source_rate = sf.read(buffer, dtype='float32')
        # Downmix multi-channel audio to mono.
        if samples.ndim > 1:
            samples = np.mean(samples, axis=1)
        # Resample to the target rate if needed.
        if source_rate != sampling_rate:
            samples = librosa.resample(
                samples,
                orig_sr=source_rate,
                target_sr=sampling_rate,
                res_type='kaiser_best'
            )
        # Peak-normalize into [-1, 1] (skip silent clips to avoid div-by-zero).
        peak = np.abs(samples).max()
        if peak > 0:
            samples = samples / peak
        samples = samples.astype(np.float32)
        # Noise augmentation is only requested for the training split.
        if augment:
            samples = add_gaussian_noise(samples, noise_factor=0.01)
        return samples

    waveforms = [_decode(raw) for raw in batch["audio_bytes"]]

    # Let the feature extractor handle framing/normalization and masking.
    extracted = processor(
        waveforms,
        sampling_rate=sampling_rate,
        return_attention_mask=True,
    )
    batch["input_values"] = [
        np.asarray(values, dtype=np.float32) for values in extracted["input_values"]
    ]
    if "attention_mask" in extracted:
        batch["attention_mask"] = [
            np.asarray(mask, dtype=np.int64) for mask in extracted["attention_mask"]
        ]
    batch["labels"] = [int(value) for value in batch["label"]]
    return batch
def parse_args():
    """Define and parse the command-line options for this training script."""
    p = argparse.ArgumentParser(description="Train Wav2Vec2 on RAVDESS emotion dataset")
    # Model and output locations.
    p.add_argument("--model_name_or_path", default="facebook/wav2vec2-base-960h")
    p.add_argument(
        "--output_dir",
        default=os.path.join(os.path.dirname(__file__), "wav2vec2-ravdess-emotion"),
    )
    # Dataset selection.
    p.add_argument("--dataset_name", default="confit/ravdess-parquet")
    p.add_argument("--dataset_config", default="fold1")
    p.add_argument("--train_split", default="train")
    p.add_argument("--eval_split", default="test")
    p.add_argument("--sampling_rate", type=int, default=16_000)
    # Optimization schedule.
    p.add_argument("--num_train_epochs", type=float, default=25.0)
    p.add_argument("--warmup_epochs", type=int, default=3, help="Epochs with frozen feature extractor")
    p.add_argument("--per_device_train_batch_size", type=int, default=4)
    p.add_argument("--per_device_eval_batch_size", type=int, default=4)
    p.add_argument("--learning_rate", type=float, default=3e-5)
    p.add_argument("--warmup_ratio", type=float, default=0.1)
    p.add_argument("--weight_decay", type=float, default=0.01)
    p.add_argument("--gradient_accumulation_steps", type=int, default=2)
    p.add_argument("--seed", type=int, default=1337)
    # Debugging / dataset subsetting.
    p.add_argument("--max_train_samples", type=int, default=None)
    p.add_argument("--max_eval_samples", type=int, default=None)
    # Hugging Face Hub publishing.
    p.add_argument("--push_to_hub", action="store_true")
    p.add_argument("--hub_model_id", default=None)
    p.add_argument("--hub_private_repo", action="store_true")
    return p.parse_args()
def main():
    """Run the full pipeline: download RAVDESS, preprocess, fine-tune, save.

    Configuration comes from the command line (see ``parse_args``) and the
    HF_HOME / HF_TOKEN environment variables.
    """
    args = parse_args()
    set_seed(args.seed)
    print("=" * 80)
    print("Wav2Vec2 RAVDESS Emotion Detection Training")
    print("=" * 80)
    print(f"Model: {args.model_name_or_path}")
    print(f"Epochs: {args.num_train_epochs} (warmup: {args.warmup_epochs})")
    print(f"Learning rate: {args.learning_rate}")
    print(f"Batch size: {args.per_device_train_batch_size} (gradient accumulation: {args.gradient_accumulation_steps})")
    print("=" * 80)
    # Download dataset
    print("\n📥 Downloading RAVDESS dataset...")
    snapshot_path = snapshot_download(
        repo_id=args.dataset_name,
        repo_type="dataset",
        cache_dir=os.getenv("HF_HOME"),
        token=os.getenv("HF_TOKEN"),
    )
    # Parquet shards live under "<snapshot>/<config>/" when a fold is selected.
    split_root = os.path.join(snapshot_path, args.dataset_config) if args.dataset_config else snapshot_path

    def load_split(split_name: str):
        # Load all "<split>-*.parquet" shards as plain Python lists;
        # returns None when no shard matches (caller decides the fallback).
        pattern = os.path.join(split_root, f"{split_name}-*.parquet")
        parquet_files = sorted(glob.glob(pattern))
        if not parquet_files:
            return None
        tables = [pq.read_table(path) for path in parquet_files]
        table = pa.concat_tables(tables)
        data = table.to_pydict()
        # NOTE(review): assumes the "audio" column holds {"bytes": ...}
        # structs — confirm against the confit/ravdess-parquet schema.
        return {
            "audio_bytes": [entry["bytes"] for entry in data["audio"]],
            "label": [int(label) for label in data["label"]],
            "emotion": data["emotion"],
            "file": data["file"],
        }

    train_dict = load_split(args.train_split)
    if train_dict is None:
        raise ValueError(f"Could not locate parquet files for split '{args.train_split}' in {split_root}")
    eval_dict = load_split(args.eval_split)
    train_dataset = Dataset.from_dict(train_dict)
    if eval_dict is not None:
        eval_dataset = Dataset.from_dict(eval_dict)
    else:
        # No eval shards found: carve a 10% hold-out from the training split.
        split_dataset = train_dataset.train_test_split(test_size=0.1, seed=args.seed)
        train_dataset = split_dataset["train"]
        eval_dataset = split_dataset["test"]
    print(f"✅ Train samples: {len(train_dataset)}")
    print(f"✅ Eval samples: {len(eval_dataset)}")
    # Build label mapping (consistent id2label / label2id)
    print("\n📊 Building label mapping...")
    label_names = {}
    for label, emotion in zip(train_dataset["label"], train_dataset["emotion"]):
        label_names[int(label)] = emotion
    # Ensure consistent ordering
    id2label = {idx: label_names[idx] for idx in sorted(label_names)}
    label2id = {name: idx for idx, name in id2label.items()}
    print(f"✅ Labels ({len(id2label)}): {list(id2label.values())}")
    print(f"✅ Label mapping: {id2label}")
    # Compute class weights for balanced training
    print("\n⚖️ Computing class weights for balanced training...")
    labels_array = np.array(train_dataset["label"])
    unique_labels = np.unique(labels_array)
    class_weights = compute_class_weight(
        'balanced',
        classes=unique_labels,
        y=labels_array
    )
    class_weight_dict = dict(zip(unique_labels, class_weights))
    # Order weights by label id so they line up with the model's logit columns.
    class_weight_list = [class_weight_dict[i] for i in sorted(unique_labels)]
    print(f"✅ Class weights: {dict(zip([id2label[i] for i in sorted(unique_labels)], class_weight_list))}")
    # Load processor and config
    print("\n📦 Loading processor and config...")
    processor = AutoProcessor.from_pretrained(
        args.model_name_or_path,
        cache_dir=os.getenv("HF_HOME"),
    )
    config = AutoConfig.from_pretrained(
        args.model_name_or_path,
        num_labels=len(label2id),
        label2id=label2id,
        id2label=id2label,
        finetuning_task="wav2vec2_emotion",
        cache_dir=os.getenv("HF_HOME"),
    )
    # Verify label mapping in config
    print(f"✅ Config labels: {config.id2label}")
    assert config.label2id == label2id, "Label mapping mismatch!"
    assert config.id2label == id2label, "Label mapping mismatch!"
    # Prepare datasets with proper normalization
    print("\n🔄 Preparing training dataset (with augmentation)...")
    processed_train_dataset = train_dataset.map(
        prepare_dataset,
        fn_kwargs=dict(
            processor=processor,
            sampling_rate=args.sampling_rate,
            augment=True,  # Add noise augmentation for training
        ),
        remove_columns=["audio_bytes", "file", "emotion", "label"],
        batched=True,
        batch_size=8,
        num_proc=1,
    )
    print("🔄 Preparing evaluation dataset (no augmentation)...")
    processed_eval_dataset = eval_dataset.map(
        prepare_dataset,
        fn_kwargs=dict(
            processor=processor,
            sampling_rate=args.sampling_rate,
            augment=False,  # No augmentation for eval
        ),
        remove_columns=["audio_bytes", "file", "emotion", "label"],
        batched=True,
        batch_size=8,
        num_proc=1,
    )
    # NOTE(review): sample capping happens AFTER full preprocessing, so
    # --max_*_samples does not reduce preprocessing time, only training time.
    if args.max_train_samples:
        processed_train_dataset = processed_train_dataset.select(range(args.max_train_samples))
    if args.max_eval_samples:
        processed_eval_dataset = processed_eval_dataset.select(range(args.max_eval_samples))
    # Load model
    print("\n🤖 Loading model...")
    model = Wav2Vec2ForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        config=config,
        cache_dir=os.getenv("HF_HOME"),
    )
    # Freeze feature extractor initially
    # NOTE(review): freeze_feature_extractor() is deprecated in newer
    # transformers releases in favor of freeze_feature_encoder() — verify
    # against the pinned transformers version.
    print("🔒 Freezing feature extractor for warmup...")
    model.freeze_feature_extractor()
    data_collator = DataCollatorWithPadding(processor=processor)
    # Training arguments
    requested_training_arguments = dict(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        num_train_epochs=args.num_train_epochs,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        fp16=torch.cuda.is_available(),
        group_by_length=True,
        dataloader_num_workers=min(4, os.cpu_count() or 1),
        logging_steps=25,
        save_total_limit=3,  # Keep only last 3 checkpoints
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        hub_private_repo=args.hub_private_repo,
        report_to="none",  # Disable wandb/tensorboard
    )
    # Filter to supported arguments — keeps the script working across
    # transformers versions that renamed/removed TrainingArguments parameters.
    training_args_signature = inspect.signature(TrainingArguments)
    supported_training_arguments = {
        key: value
        for key, value in requested_training_arguments.items()
        if key in training_args_signature.parameters
    }
    if "evaluation_strategy" not in supported_training_arguments:
        # Without per-epoch evaluation we cannot track or select a best model.
        supported_training_arguments.pop("save_strategy", None)
        supported_training_arguments.pop("load_best_model_at_end", None)
        supported_training_arguments.pop("metric_for_best_model", None)
    training_args = TrainingArguments(**supported_training_arguments)
    # Create trainer with weighted loss
    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=processed_train_dataset,
        eval_dataset=processed_eval_dataset,
        tokenizer=processor,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        class_weights=class_weight_list,
    )
    # Phase 1: Train with frozen feature extractor (warmup)
    print("\n" + "=" * 80)
    print(f"PHASE 1: Training with FROZEN feature extractor ({args.warmup_epochs} epochs)")
    print("=" * 80)
    # Calculate steps for warmup
    # NOTE(review): the three values below are computed but never used.
    total_steps = len(processed_train_dataset) // (args.per_device_train_batch_size * args.gradient_accumulation_steps) * args.num_train_epochs
    warmup_steps = int(total_steps * args.warmup_ratio)
    warmup_epochs_steps = len(processed_train_dataset) // (args.per_device_train_batch_size * args.gradient_accumulation_steps) * args.warmup_epochs
    # Train for warmup epochs
    # NOTE(review): BUG — this call runs the FULL num_train_epochs schedule
    # with the feature extractor frozen, not just warmup_epochs, so the check
    # below always succeeds and a SECOND full schedule runs unfrozen
    # (~2x num_train_epochs total). A max_steps/epoch-callback based phase
    # split is needed to implement the intended warmup strategy.
    trainer.train()
    # Check if we've completed warmup epochs
    current_epoch = trainer.state.epoch
    if current_epoch >= args.warmup_epochs:
        print(f"\n✅ Completed {args.warmup_epochs} warmup epochs")
        print("🔓 Unfreezing feature extractor...")
        # NOTE(review): confirm unfreeze_feature_extractor() exists on
        # Wav2Vec2ForSequenceClassification in the pinned transformers version.
        model.unfreeze_feature_extractor()
        print("✅ Feature extractor unfrozen!")
        # Phase 2: Continue training with unfrozen feature extractor
        print("\n" + "=" * 80)
        print(f"PHASE 2: Training with UNFROZEN feature extractor (remaining epochs)")
        print("=" * 80)
        # Continue training (restarts a fresh optimizer/LR schedule, see BUG note above)
        trainer.train()
    else:
        print(f"\n⚠️ Training stopped before warmup completed. Current epoch: {current_epoch}")
    # Save final model
    print("\n💾 Saving final model and processor...")
    trainer.save_model()
    processor.save_pretrained(args.output_dir)
    # Verify label mapping is saved correctly
    saved_config = AutoConfig.from_pretrained(args.output_dir)
    print(f"\n✅ Saved model label mapping:")
    print(f" id2label: {saved_config.id2label}")
    print(f" label2id: {saved_config.label2id}")
    if args.push_to_hub:
        print("\n📤 Pushing to Hugging Face Hub...")
        trainer.push_to_hub()
    print(f"\n✅ Training complete! Model saved to: {args.output_dir}")
    print("=" * 80)
if __name__ == "__main__":
    # Only launch training when executed as a script, not on import.
    main()