ustwo-api / scripts /train_lora_emotion2vec.py
asdfasdfqrqwer's picture
Deploy from GitHub 2026-04-23T03:56:31Z
c857b85
Raw
History Blame Contribute Delete
29 kB
#!/usr/bin/env python3
"""Manual LoRA fine-tuning for emotion2vec_plus_base — 7-class emotion.
Wraps the frozen attention projections with low-rank adapters (LoRALinear),
replaces the 9-class proj head with a 7-class MLPHead, and trains
with FocalLoss plus a disgust-F1 gate on best-checkpoint selection.
Usage:
# Quick smoke test (CPU)
python scripts/train_lora_emotion2vec.py \
--train-manifest data/lora_7class/train_manifest.json \
--val-manifest data/lora_7class/val_manifest.json \
--output-dir data/models/lora_emotion2vec_7class \
--epochs 3 --device cpu
# Full training (GPU, RTX 5050 8GB)
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
python scripts/train_lora_emotion2vec.py \
--train-manifest data/lora_7class/train_manifest.json \
--val-manifest data/lora_7class/val_manifest.json \
--output-dir data/models/lora_emotion2vec_7class \
--epochs 20 --batch-size 4 --accumulate-steps 8 --device cuda
"""
from __future__ import annotations
import argparse
import json
import logging
import random
import sys
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

# 7-class label taxonomy (matches prepare_lora_dataset.py).
# List order defines the integer class index used everywhere in this script.
LABELS_7CLASS = ["happiness", "anger", "disgust", "fear", "neutral", "sadness", "surprise"]
LABEL2IDX = dict(zip(LABELS_7CLASS, range(len(LABELS_7CLASS))))
NUM_CLASSES = len(LABELS_7CLASS)
DISGUST_IDX = LABEL2IDX["disgust"]  # == 2; used for the disgust-F1 gate and loss boost
# ──────────────────────────────────────────────
# LoRA Components
# ──────────────────────────────────────────────
class LoRALinear(nn.Module):
    """A frozen ``nn.Linear`` plus a trainable low-rank (LoRA) update.

    Output: ``original(x) + scaling * lora_B(dropout(lora_A(x)))`` with
    ``scaling = alpha / r``. ``lora_B`` is zero-initialized, so a freshly
    wrapped layer returns exactly what the wrapped layer would.

    Args:
        original: layer to wrap; its weight (and bias, if present) get
            ``requires_grad = False`` in place.
        r: adapter rank (output width of ``lora_A`` / input width of ``lora_B``).
        alpha: numerator of the ``scaling`` factor.
        dropout: dropout probability applied to the rank-``r`` activation
            between ``lora_A`` and ``lora_B``.
    """

    def __init__(self, original: nn.Linear, r: int = 16, alpha: int = 32, dropout: float = 0.1):
        super().__init__()
        self.original = original
        self.r = r
        self.scaling = alpha / r

        # Only the adapter learns; the wrapped projection stays frozen.
        for frozen in (original.weight, original.bias):
            if frozen is not None:
                frozen.requires_grad = False

        self.lora_A = nn.Linear(original.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, original.out_features, bias=False)
        self.dropout = nn.Dropout(dropout)

        # A: kaiming-uniform; B: zeros -> the LoRA branch contributes nothing at init.
        nn.init.kaiming_uniform_(self.lora_A.weight)
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the frozen layer and add the scaled low-rank correction."""
        delta = self.lora_B(self.dropout(self.lora_A(x)))
        return self.original(x) + delta * self.scaling
def merge_lora_linear(lora: LoRALinear) -> nn.Linear:
    """Fold a LoRALinear's adapter into one plain ``nn.Linear`` for inference.

    The merged weight is ``W + scaling * B @ A`` (B = ``lora_B.weight``,
    A = ``lora_A.weight``); the bias, if any, is copied unchanged. Dropout
    plays no part, so the result corresponds to the adapter in eval mode.
    """
    src = lora.original
    has_bias = src.bias is not None
    with torch.no_grad():
        fused_weight = src.weight + lora.scaling * lora.lora_B.weight @ lora.lora_A.weight
        fused = nn.Linear(src.in_features, src.out_features, bias=has_bias)
        fused.weight = nn.Parameter(fused_weight)
        if has_bias:
            fused.bias = nn.Parameter(src.bias.clone())
    return fused
def inject_lora(
    encoder: nn.Module,
    r: int = 16,
    alpha: int = 32,
    dropout: float = 0.1,
) -> None:
    """Wrap every block's attention projections with LoRALinear, in place.

    For each element of ``encoder.blocks``, ``attn.qkv`` and ``attn.proj`` are
    replaced by LoRALinear wrappers (which freeze the wrapped weights). MLP
    layers are left untouched.

    Per this script's notes, emotion2vec uses a fused QKV projection
    (``attn.qkv``: Linear(768, 2304)) plus ``attn.proj``: Linear(768, 768),
    across 8 blocks — NOTE(review): confirm against the loaded checkpoint.
    """
    adapter_kwargs = dict(r=r, alpha=alpha, dropout=dropout)
    for block in encoder.blocks:
        attn = block.attn
        attn.qkv = LoRALinear(attn.qkv, **adapter_kwargs)
        attn.proj = LoRALinear(attn.proj, **adapter_kwargs)
# ──────────────────────────────────────────────
# Model Components
# ──────────────────────────────────────────────
class MLPHead(nn.Module):
    """Classification head: in_dim -> 512 -> 256 -> num_classes.

    Each hidden stage is Linear -> BatchNorm1d -> GELU -> Dropout; the final
    Linear emits unnormalized logits. Intended for pooled (batch, in_dim)
    embeddings. Because of BatchNorm1d, train mode needs batches of more than
    one sample.
    """

    def __init__(self, in_dim: int = 768, num_classes: int = NUM_CLASSES, dropout: float = 0.3):
        super().__init__()
        layers = []
        width = in_dim
        for hidden in (512, 256):
            layers += [
                nn.Linear(width, hidden),
                nn.BatchNorm1d(hidden),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            width = hidden
        layers.append(nn.Linear(width, num_classes))
        # Same layer order/indices as a literal Sequential, so saved state_dicts load.
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map pooled features to class logits."""
        return self.net(x)
class FocalLoss(nn.Module):
    """Focal loss with optional class weights and label smoothing.

    Per sample: ``loss_i = (1 - p_i) ** gamma * CE_i``. Here ``p_i`` is the
    softmax probability of the true class, and ``CE_i`` is the per-sample
    cross-entropy computed with ``weight`` and ``label_smoothing``. The batch
    loss is the plain mean of ``loss_i``.

    Args:
        weight: optional per-class weight tensor forwarded to
            ``F.cross_entropy`` (must live on the same device as the logits).
        gamma: focusing parameter. ``gamma=0`` reduces to the plain mean of
            the per-sample (weighted, smoothed) cross-entropy.
        label_smoothing: smoothing factor forwarded to ``F.cross_entropy``.
    """

    def __init__(self, weight=None, gamma: float = 2.0, label_smoothing: float = 0.05):
        super().__init__()
        self.gamma = gamma
        self.weight = weight
        self.label_smoothing = label_smoothing

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the scalar focal loss.

        ``logits`` has classes on dim 1 (N, C, ...); ``targets`` holds integer
        class indices (N, ...).
        """
        ce_loss = F.cross_entropy(
            logits, targets, weight=self.weight,
            label_smoothing=self.label_smoothing, reduction="none",
        )
        # p_t must be the true-class probability. Deriving it as exp(-ce_loss)
        # is only valid for unweighted, unsmoothed CE. With class weights it
        # becomes p_t ** w_y, so heavily weighted classes (e.g. the boosted
        # disgust class) look harder than they are and light ones look easier.
        log_pt = F.log_softmax(logits.float(), dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        return focal_loss.mean()
# ──────────────────────────────────────────────
# Dataset
# ──────────────────────────────────────────────
class EmotionDataset(Dataset):
    """Audio/label dataset read from a JSON manifest, for 7-class LoRA training.

    The manifest is a JSON list of records. Each record carries a ``"label"``
    (a key of LABEL2IDX) and an audio path under ``"audio_path"``, or
    ``"path"`` for older manifests.

    Items are ``(waveform, label_idx)``. ``waveform`` is a 1-D float32 tensor
    at 16 kHz, mixed to mono and truncated to ``max_duration_sec`` before
    augmentation. ``label_idx`` is an int class index.

    Augmentation uses one uniform draw per item and applies at most one of:
      * probability ``phone_augment_prob``: telephone-channel simulation
        (``common.phone_simulator.PhoneSimulator`` with A-law companding),
        then a librosa resample treating the output as 8 kHz back to 16 kHz;
      * probability ``noise_augment_prob``: additive Gaussian noise at a
        random SNR drawn from [10, 20] dB.
    """

    def __init__(
        self,
        manifest_path: str,
        max_duration_sec: float = 8.0,
        phone_augment_prob: float = 0.0,
        noise_augment_prob: float = 0.0,
    ):
        import torchaudio  # imported here so a missing torchaudio fails at construction
        with open(manifest_path, encoding="utf-8") as fh:
            self.samples = json.load(fh)
        self.max_samples = int(max_duration_sec * 16000)
        self.phone_augment_prob = phone_augment_prob
        self.noise_augment_prob = noise_augment_prob
        self._phone_sim = None  # created lazily by _get_phone_sim()

    def _get_phone_sim(self):
        """Return the cached PhoneSimulator, building it on first use."""
        if self._phone_sim is not None:
            return self._phone_sim
        # PhoneSimulator lives in the project's src/ tree, one level above scripts/.
        src_dir = Path(__file__).resolve().parent.parent / "src"
        sys.path.insert(0, str(src_dir))
        from common.phone_simulator import PhoneSimulator, CompandingType
        self._phone_sim = PhoneSimulator(companding=CompandingType.ALAW)
        return self._phone_sim

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        import torchaudio
        record = self.samples[idx]
        path = record.get("audio_path") or record.get("path", "")  # either key is accepted
        wav, sample_rate = torchaudio.load(path)

        if wav.shape[0] > 1:  # multi-channel -> mono
            wav = wav.mean(dim=0, keepdim=True)
        wav = wav.squeeze(0)  # (T,)
        if sample_rate != 16000:
            wav = torchaudio.functional.resample(wav, sample_rate, 16000)
        wav = wav[: self.max_samples]
        audio = wav.numpy()

        draw = random.random()
        phone_cut = self.phone_augment_prob
        noise_cut = phone_cut + self.noise_augment_prob
        if draw < phone_cut:
            audio, _ = self._get_phone_sim().process(audio, 16000)
            import librosa
            audio = librosa.resample(audio, orig_sr=8000, target_sr=16000)
        elif draw < noise_cut:
            snr_db = random.uniform(10, 20)
            noise_power = np.mean(audio ** 2) / (10 ** (snr_db / 10))
            noise_std = np.sqrt(max(noise_power, 1e-10))  # floor avoids zero std on silence
            audio = audio + np.random.normal(0, noise_std, len(audio)).astype(np.float32)

        return torch.tensor(audio, dtype=torch.float32), LABEL2IDX[record["label"]]
def collate_fn(batch):
    """Combine ``(waveform, label)`` items into padded batch tensors.

    Each 1-D waveform is right-padded with zeros to the longest one in the batch.

    Returns:
        ``(padded, labels)``: a (B, max_len) float tensor and a (B,) int64 tensor.
    """
    waves = [wave for wave, _ in batch]
    padded = torch.zeros(len(waves), max(w.shape[0] for w in waves))
    for row, wave in zip(padded, waves):  # rows are views into `padded`
        row[: wave.shape[0]] = wave
    labels = torch.tensor([label for _, label in batch], dtype=torch.long)
    return padded, labels
# ──────────────────────────────────────────────
# Model Loading & Forward
# ──────────────────────────────────────────────
def load_model(device: str, r: int = 16, alpha: int = 32, dropout: float = 0.1):
    """Build the trainable model: emotion2vec_plus_base + LoRA + 7-class head.

    Steps:
      1. Load ``iic/emotion2vec_plus_base`` through FunASR (``hub="hf"``).
      2. Freeze every pretrained parameter.
      3. Inject LoRALinear adapters, which stay trainable.
      4. Replace ``encoder.proj`` with a new MLPHead(768, NUM_CLASSES).
      5. Move the whole model to ``device``.

    Returns:
        ``AutoModel(...).model`` with adapters and the new head installed.
    """
    from funasr import AutoModel

    logger.info("Loading emotion2vec_plus_base...")
    encoder = AutoModel(model="iic/emotion2vec_plus_base", device=device, hub="hf").model

    # Freeze BEFORE injecting, so only the newly created adapter/head params train.
    for p in encoder.parameters():
        p.requires_grad = False

    inject_lora(encoder, r=r, alpha=alpha, dropout=dropout)

    previous_head = encoder.proj
    encoder.proj = MLPHead(768, NUM_CLASSES)
    logger.info("Replaced proj: Linear(768, %d) -> MLPHead(768->512->256->%d)",
                previous_head.out_features, NUM_CLASSES)

    # Adapters and head were created on the default device; move everything.
    return encoder.to(device)
def forward_pass(encoder, waveforms: torch.Tensor, device: str) -> torch.Tensor:
    """Differentiable utterance-level forward pass through LoRA-adapted emotion2vec.

    Pipeline: optional per-utterance layer norm -> ``extract_features`` ->
    mean over frames -> ``encoder.proj``.

    Args:
        encoder: emotion2vec model with LoRA injected. Must expose
            ``cfg.normalize``, ``extract_features`` and ``proj``.
        waveforms: (B, T) float32 batch at 16 kHz.
        device: device the batch is moved to.

    Returns:
        Logits (B, num_classes) from ``encoder.proj``.

    Note:
        ``padding_mask=None`` is passed. Any zero padding in the batch
        therefore takes part in both the layer norm and the mean pool.
    """
    batch = waveforms.to(device)
    if encoder.cfg.normalize:
        # Per-waveform layer norm over time, as in emotion2vec inference.
        batch = torch.stack([F.layer_norm(wav, wav.shape) for wav in batch])
    frames = encoder.extract_features(batch, padding_mask=None)["x"]  # (B, T', D)
    return encoder.proj(frames.mean(dim=1))
# ──────────────────────────────────────────────
# Validation
# ──────────────────────────────────────────────
@torch.no_grad()
def validate(encoder, val_loader, device, criterion):
    """Evaluate ``encoder`` on ``val_loader`` and summarize the metrics.

    Switches the model to eval mode and does not switch it back. Runs
    without autograd.

    Returns:
        dict with these keys (floats rounded to 4 places):
          * ``loss``: sample-weighted mean of per-batch ``criterion``;
          * ``accuracy``;
          * ``macro_f1``: unweighted mean over all NUM_CLASSES classes,
            with undefined scores counted as 0;
          * ``disgust_f1``;
          * ``per_class_f1``: label -> F1;
          * ``confusion_matrix``: nested list;
          * ``y_true`` / ``y_pred``: raw class-index lists.
    """
    encoder.eval()
    loss_sum = 0
    truth, guesses = [], []
    for waveforms, labels in val_loader:
        labels = labels.to(device)
        logits = forward_pass(encoder, waveforms, device)
        batch_loss = criterion(logits, labels)
        loss_sum += batch_loss.item() * labels.size(0)
        truth.extend(labels.cpu().tolist())
        guesses.extend(logits.argmax(dim=-1).cpu().tolist())

    from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix

    class_ids = list(range(NUM_CLASSES))
    accuracy = accuracy_score(truth, guesses)
    _, _, f1_scores, _ = precision_recall_fscore_support(
        truth, guesses, labels=class_ids, average=None, zero_division=0,
    )
    matrix = confusion_matrix(truth, guesses, labels=class_ids)
    return {
        "loss": round(loss_sum / max(len(truth), 1), 4),
        "accuracy": round(accuracy, 4),
        "macro_f1": round(float(np.mean(f1_scores)), 4),
        "disgust_f1": round(float(f1_scores[DISGUST_IDX]), 4),
        "per_class_f1": {
            label: round(float(f1_scores[i]), 4) for i, label in enumerate(LABELS_7CLASS)
        },
        "confusion_matrix": matrix.tolist(),
        "y_true": truth,
        "y_pred": guesses,
    }
# ──────────────────────────────────────────────
# Checkpoint
# ──────────────────────────────────────────────
def save_lora_checkpoint(encoder, path: Path, epoch: int, metrics: dict,
                         best_f1: float = 0.0, patience_counter: int = 0,
                         optimizer=None, scheduler=None, scaler=None,
                         training_log=None):
    """Save LoRA weights, the MLPHead, and optional training state for resume.

    Stored keys:
      * ``lora_weights``: ``"{module}.lora_A.weight"`` / ``"{module}.lora_B.weight"``
        -> CPU tensors, for every LoRALinear in ``encoder``;
      * ``proj``: the head's state_dict;
      * ``epoch``, ``metrics``, ``best_f1``, ``patience_counter``,
        ``num_classes``, ``labels``;
      * ``optimizer`` / ``scheduler`` / ``scaler`` / ``training_log``: each
        only when the corresponding argument is not None.

    The file is written to ``<path>.tmp`` first and then renamed over
    ``path``. An interrupted save therefore never leaves a truncated
    checkpoint behind; this matters for ``last_lora.pt``, which resume reads.
    """
    lora_state = {}
    for name, module in encoder.named_modules():
        if isinstance(module, LoRALinear):
            lora_state[f"{name}.lora_A.weight"] = module.lora_A.weight.detach().cpu()
            lora_state[f"{name}.lora_B.weight"] = module.lora_B.weight.detach().cpu()
    state = {
        "lora_weights": lora_state,
        "proj": encoder.proj.state_dict(),
        "epoch": epoch,
        "metrics": metrics,
        "best_f1": best_f1,
        "patience_counter": patience_counter,
        "num_classes": NUM_CLASSES,
        "labels": LABELS_7CLASS,
    }
    if optimizer is not None:
        state["optimizer"] = optimizer.state_dict()
    if scheduler is not None:
        state["scheduler"] = scheduler.state_dict()
    if scaler is not None:
        state["scaler"] = scaler.state_dict()
    if training_log is not None:
        state["training_log"] = training_log

    target = Path(path)
    tmp_path = target.with_name(target.name + ".tmp")
    try:
        torch.save(state, str(tmp_path))
        tmp_path.replace(target)  # atomic rename (os.replace semantics)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
def load_lora_checkpoint(encoder, path: Path, device: str,
                         optimizer=None, scheduler=None, scaler=None):
    """Restore a checkpoint written by ``save_lora_checkpoint``.

    Behavior:
      * Copies saved LoRA A/B matrices into every matching LoRALinear in
        ``encoder``. Adapters absent from the checkpoint keep their current
        values, and a warning is logged.
      * Loads the ``proj`` head state_dict.
      * Restores optimizer / scheduler / GradScaler state when the object is
        passed and the checkpoint contains it.

    Returns:
        dict with ``epoch``, ``best_f1``, ``patience_counter`` and
        ``training_log`` (empty list if absent).

    Raises:
        KeyError: if a module has a saved ``lora_A`` weight but no ``lora_B``,
            or a required top-level key is missing.
    """
    logger.info("Resuming from checkpoint: %s", path)
    # weights_only=False: the checkpoint also holds plain-Python training state
    # (metrics, training_log). This unpickles arbitrary objects, so only load
    # checkpoints you produced yourself or otherwise trust.
    ckpt = torch.load(str(path), map_location=device, weights_only=False)

    # Restore LoRA weights
    saved_lora = ckpt["lora_weights"]
    not_restored = []
    for name, module in encoder.named_modules():
        if not isinstance(module, LoRALinear):
            continue
        a_key = f"{name}.lora_A.weight"
        b_key = f"{name}.lora_B.weight"
        if a_key not in saved_lora:
            # Don't skip silently: resuming with untrained adapters would
            # otherwise go unnoticed.
            not_restored.append(name)
            continue
        module.lora_A.weight.data.copy_(saved_lora[a_key].to(device))
        module.lora_B.weight.data.copy_(saved_lora[b_key].to(device))
    if not_restored:
        logger.warning(
            "Checkpoint has no LoRA weights for %d adapter module(s) (first: %s); "
            "they keep their current values",
            len(not_restored), not_restored[0],
        )

    # Restore MLPHead
    encoder.proj.load_state_dict(ckpt["proj"])
    # Restore optimizer/scheduler/scaler if available
    if optimizer is not None and "optimizer" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer"])
    if scheduler is not None and "scheduler" in ckpt:
        scheduler.load_state_dict(ckpt["scheduler"])
    if scaler is not None and "scaler" in ckpt:
        scaler.load_state_dict(ckpt["scaler"])
    logger.info("Resumed from epoch %d (best_f1=%.4f, patience=%d)",
                ckpt["epoch"], ckpt["best_f1"], ckpt["patience_counter"])
    return {
        "epoch": ckpt["epoch"],
        "best_f1": ckpt["best_f1"],
        "patience_counter": ckpt["patience_counter"],
        "training_log": ckpt.get("training_log", []),
    }
# ──────────────────────────────────────────────
# Confusion Matrix Plot
# ──────────────────────────────────────────────
def plot_confusion_matrix(cm, output_path: Path, epoch: int):
    """Write a row-normalized confusion-matrix heatmap to a PNG file.

    Each cell shows the per-true-class rate (two decimals), with the raw count
    in small grey text just below it. If matplotlib or seaborn cannot be
    imported, a warning is logged and nothing is written.

    Args:
        cm: NumPy array of counts. Rows are true classes and columns are
            predicted classes, both in LABELS_7CLASS order.
        output_path: destination PNG path.
        epoch: epoch number shown in the title.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")  # non-interactive backend; no display required
        import matplotlib.pyplot as plt
        import seaborn as sns

        _, ax = plt.subplots(figsize=(9, 7))
        # Normalize by row totals; the floor of 1 keeps empty rows at zero.
        row_totals = np.maximum(cm.sum(axis=1, keepdims=True), 1)
        sns.heatmap(
            cm / row_totals,
            annot=True, fmt=".2f", cmap="Blues",
            xticklabels=LABELS_7CLASS, yticklabels=LABELS_7CLASS, ax=ax,
        )
        for row in range(NUM_CLASSES):
            for col in range(NUM_CLASSES):
                ax.text(col + 0.5, row + 0.7, f"({cm[row][col]})",
                        ha="center", va="center", fontsize=6, color="gray")
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        ax.set_title(f"LoRA 7-Class Confusion Matrix (Epoch {epoch})")
        plt.tight_layout()
        plt.savefig(str(output_path), dpi=150)
        plt.close()
        logger.info("Confusion matrix saved to %s", output_path)
    except ImportError:
        logger.warning("matplotlib/seaborn not available β€” skipping confusion matrix plot")
# ──────────────────────────────────────────────
# Training
# ──────────────────────────────────────────────
def train(args):
    """Fine-tune emotion2vec with LoRA adapters and a 7-class MLP head.

    Files written to ``args.output_dir``:
      * ``best_lora.pt``: best validation macro-F1 among epochs whose disgust
        F1 passes the gate (``disgust_gate``);
      * ``last_lora.pt``: every epoch, with full resume state;
      * ``training_log.json``: per-epoch metrics;
      * ``confusion_matrix_epoch{N}.png``: for each new best epoch;
      * ``config.json``: model/label metadata plus the CLI arguments.

    Early stopping counts only gate-passing epochs that fail to improve
    macro-F1. Gate-failing epochs neither save a best checkpoint nor consume
    patience.

    Args:
        args: ``argparse.Namespace`` built by ``main()``.
    """
    device = args.device
    use_amp = (device == "cuda")
    accumulate_steps = args.accumulate_steps

    # Load model
    encoder = load_model(device, r=args.lora_r, alpha=args.lora_alpha, dropout=args.lora_dropout)
    # Log trainable vs total
    trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
    total = sum(p.numel() for p in encoder.parameters())
    logger.info("Trainable: %dK / Total: %dK (%.1f%%)",
                trainable // 1000, total // 1000, 100 * trainable / total)

    # Datasets (augmentation only on the training split)
    train_dataset = EmotionDataset(
        args.train_manifest,
        max_duration_sec=8.0,
        phone_augment_prob=args.phone_augment_prob,
        noise_augment_prob=args.noise_augment_prob,
    )
    val_dataset = EmotionDataset(args.val_manifest, max_duration_sec=8.0)
    logger.info("Train: %d samples, Val: %d samples", len(train_dataset), len(val_dataset))
    logger.info("AMP: %s, Accumulate: %d, Effective batch: %d",
                use_amp, accumulate_steps, args.batch_size * accumulate_steps)
    n_workers = 2
    use_pin = (device == "cuda")
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        collate_fn=collate_fn, num_workers=n_workers, drop_last=True,
        pin_memory=use_pin, persistent_workers=(n_workers > 0),
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        collate_fn=collate_fn, num_workers=n_workers,
        pin_memory=use_pin, persistent_workers=(n_workers > 0),
    )

    # Class weights: inverse frequency rescaled to mean 1, then a 2.5x disgust boost
    class_counts = np.zeros(NUM_CLASSES)
    for sample in train_dataset.samples:
        class_counts[LABEL2IDX[sample["label"]]] += 1
    class_weights = 1.0 / np.maximum(class_counts, 1)
    class_weights = class_weights / class_weights.sum() * NUM_CLASSES
    class_weights[DISGUST_IDX] *= 2.5  # Disgust boost
    logger.info("Class weights: %s",
                {LABELS_7CLASS[i]: round(float(class_weights[i]), 3) for i in range(NUM_CLASSES)})
    criterion = FocalLoss(
        weight=torch.tensor(class_weights, dtype=torch.float32).to(device),
        gamma=2.0,
        label_smoothing=0.05,
    )

    # Optimizer: differential LR (LoRA adapters vs. the freshly initialized head)
    proj_params = list(encoder.proj.parameters())
    proj_ids = {id(p) for p in proj_params}
    lora_params = [
        p for p in encoder.parameters() if p.requires_grad and id(p) not in proj_ids
    ]
    optimizer = torch.optim.AdamW([
        {"params": lora_params, "lr": args.lora_lr},
        {"params": proj_params, "lr": args.proj_lr},
    ], weight_decay=args.weight_decay)

    # OneCycleLR scheduler.
    # The loop below calls scheduler.step() once per full accumulation group
    # AND once for a trailing partial group (the `(batch_idx + 1) ==
    # len(train_loader)` condition). That is ceil(len / accumulate_steps)
    # steps per epoch. Floor division under-counted whenever the loader length
    # was not a multiple of accumulate_steps, and OneCycleLR raises ValueError
    # once stepped past total_steps.
    steps_per_epoch = max((len(train_loader) + accumulate_steps - 1) // accumulate_steps, 1)
    total_steps = steps_per_epoch * args.epochs
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=[args.lora_lr, args.proj_lr],
        total_steps=total_steps,
        pct_start=0.1,
        anneal_strategy="cos",
    )

    # Output dir
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # AMP scaler (no-op when disabled)
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    # Training state
    training_log = []
    best_f1 = 0.0
    patience_counter = 0
    start_epoch = 1
    disgust_gate = 0.3

    # Resume from checkpoint if requested
    if args.resume:
        resume_path = output_dir / "last_lora.pt"
        if resume_path.exists():
            resume_state = load_lora_checkpoint(
                encoder, resume_path, device,
                optimizer=optimizer, scheduler=scheduler, scaler=scaler,
            )
            start_epoch = resume_state["epoch"] + 1
            best_f1 = resume_state["best_f1"]
            patience_counter = resume_state["patience_counter"]
            training_log = resume_state["training_log"]
            logger.info("Resuming training from epoch %d (best_f1=%.4f)", start_epoch, best_f1)
        else:
            logger.warning("--resume set but no checkpoint found at %s, starting fresh", resume_path)

    # Parameters to clip. requires_grad never changes during training, so build once.
    trainable_params = [p for p in encoder.parameters() if p.requires_grad]

    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start = time.time()
        encoder.train()
        total_loss = 0
        correct = 0
        total_samples = 0
        optimizer.zero_grad()
        for batch_idx, (waveforms, labels) in enumerate(train_loader):
            labels = labels.to(device)
            with torch.amp.autocast("cuda", enabled=use_amp):
                logits = forward_pass(encoder, waveforms, device)
                loss = criterion(logits, labels) / accumulate_steps
            scaler.scale(loss).backward()
            # Step at the end of each accumulation group and for a trailing partial group.
            if (batch_idx + 1) % accumulate_steps == 0 or (batch_idx + 1) == len(train_loader):
                scaler.unscale_(optimizer)  # so clipping sees true gradient magnitudes
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()
            total_loss += loss.item() * accumulate_steps * labels.size(0)
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total_samples += labels.size(0)
            if (batch_idx + 1) % 10 == 0:
                cur_lr = optimizer.param_groups[0]["lr"]
                logger.info("  Epoch %d [%d/%d] loss=%.4f lr=%.2e",
                            epoch, batch_idx + 1, len(train_loader),
                            loss.item() * accumulate_steps, cur_lr)
        train_loss = total_loss / max(total_samples, 1)
        train_acc = correct / max(total_samples, 1)

        # Validate
        val_metrics = validate(encoder, val_loader, device, criterion)
        epoch_time = time.time() - epoch_start
        logger.info(
            "Epoch %d/%d (%.0fs): train_loss=%.4f train_acc=%.3f | "
            "val_loss=%.4f val_f1=%.3f val_acc=%.3f disgust_f1=%.3f",
            epoch, args.epochs, epoch_time,
            train_loss, train_acc,
            val_metrics["loss"], val_metrics["macro_f1"],
            val_metrics["accuracy"], val_metrics["disgust_f1"],
        )
        logger.info("  Per-class F1: %s", val_metrics["per_class_f1"])
        epoch_log = {
            "epoch": epoch,
            "train_loss": round(train_loss, 4),
            "train_acc": round(train_acc, 4),
            "val_loss": val_metrics["loss"],
            "val_accuracy": val_metrics["accuracy"],
            "val_macro_f1": val_metrics["macro_f1"],
            "val_disgust_f1": val_metrics["disgust_f1"],
            "val_per_class_f1": val_metrics["per_class_f1"],
            "epoch_time_sec": round(epoch_time, 1),
        }
        training_log.append(epoch_log)

        # Disgust F1 Gate + best model save
        gate_pass = val_metrics["disgust_f1"] >= disgust_gate
        if not gate_pass:
            logger.warning("[GATE FAIL] disgust_f1=%.3f < %.1f β€” not saving as best",
                           val_metrics["disgust_f1"], disgust_gate)
        if val_metrics["macro_f1"] > best_f1 and gate_pass:
            best_f1 = val_metrics["macro_f1"]
            patience_counter = 0
            save_lora_checkpoint(
                encoder, output_dir / "best_lora.pt", epoch, val_metrics, best_f1,
                optimizer=optimizer, scheduler=scheduler, scaler=scaler,
                training_log=training_log,
            )
            logger.info("  New best! macro_f1=%.4f (gate passed) saved to best_lora.pt", best_f1)
            if "confusion_matrix" in val_metrics:
                cm = np.array(val_metrics["confusion_matrix"])
                plot_confusion_matrix(cm, output_dir / f"confusion_matrix_epoch{epoch}.png", epoch)
        elif gate_pass:
            patience_counter += 1
        # Gate-failing epochs deliberately leave patience unchanged.

        # Always save last (with full state for resume)
        save_lora_checkpoint(
            encoder, output_dir / "last_lora.pt", epoch, val_metrics, best_f1, patience_counter,
            optimizer=optimizer, scheduler=scheduler, scaler=scaler,
            training_log=training_log,
        )
        # Save training log
        with open(output_dir / "training_log.json", "w") as f:
            json.dump(training_log, f, indent=2)

        # Early stopping (only on gate-passing epochs)
        if patience_counter >= args.patience:
            logger.info("Early stopping at epoch %d (patience=%d)", epoch, args.patience)
            break
        if device == "cuda":
            torch.cuda.empty_cache()

    # Save config
    config = {
        "model": "iic/emotion2vec_plus_base",
        "method": "LoRA",
        "lora_r": args.lora_r,
        "lora_alpha": args.lora_alpha,
        "num_classes": NUM_CLASSES,
        "labels": LABELS_7CLASS,
        "label2idx": LABEL2IDX,
        "best_val_f1": best_f1,
        "training_args": vars(args),
    }
    with open(output_dir / "config.json", "w") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    if not (output_dir / "best_lora.pt").exists():
        logger.warning("No epoch passed the disgust gate (%.2f); best_lora.pt was not written",
                       disgust_gate)
    logger.info("Training complete. Best F1=%.4f at %s", best_f1, output_dir / "best_lora.pt")
# ──────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────
def main():
    """Parse command-line options and hand them to ``train``."""
    cli = argparse.ArgumentParser(description="LoRA fine-tune emotion2vec 7-class")
    add = cli.add_argument

    # Data / output / hardware
    add("--train-manifest", required=True)
    add("--val-manifest", required=True)
    add("--output-dir", default="data/models/lora_emotion2vec_7class")
    add("--device", default="cpu", choices=["cpu", "cuda"])

    # Schedule and batching
    add("--epochs", type=int, default=20)
    add("--batch-size", type=int, default=4)
    add("--accumulate-steps", type=int, default=8)

    # LoRA adapter shape
    add("--lora-r", type=int, default=16)
    add("--lora-alpha", type=int, default=32)
    add("--lora-dropout", type=float, default=0.1)

    # Optimization and early stopping
    add("--lora-lr", type=float, default=2e-4)
    add("--proj-lr", type=float, default=2e-3)
    add("--weight-decay", type=float, default=0.01)
    add("--patience", type=int, default=7)

    # Training-time augmentation probabilities
    add("--phone-augment-prob", type=float, default=0.3)
    add("--noise-augment-prob", type=float, default=0.15)

    add("--resume", action="store_true",
        help="Resume from last_lora.pt checkpoint in output-dir")

    train(cli.parse_args())


if __name__ == "__main__":
    main()