| |
| """Manual LoRA fine-tuning of emotion2vec_plus_base for 7-class emotion recognition. |
| |
| Wraps frozen attention layers with low-rank adapters (LoRALinear), |
| replaces the 9-class proj head with a 7-class MLPHead, and trains |
| with FocalLoss + disgust F1 gating. |
| |
| Usage: |
| # Quick smoke test (CPU) |
| python scripts/train_lora_emotion2vec.py \ |
| --train-manifest data/lora_7class/train_manifest.json \ |
| --val-manifest data/lora_7class/val_manifest.json \ |
| --output-dir data/models/lora_emotion2vec_7class \ |
| --epochs 3 --device cpu |
| |
| # Full training (GPU, RTX 5050 8GB) |
| PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ |
| python scripts/train_lora_emotion2vec.py \ |
| --train-manifest data/lora_7class/train_manifest.json \ |
| --val-manifest data/lora_7class/val_manifest.json \ |
| --output-dir data/models/lora_emotion2vec_7class \ |
| --epochs 20 --batch-size 4 --accumulate-steps 8 --device cuda |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| import random |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
|
|
# Configure logging once at import time; all messages in this script go
# through the module-level ``logger`` defined right below.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger(__name__)

# Target label set. The position of a name in this list is its class index.
LABELS_7CLASS = [
    "happiness",
    "anger",
    "disgust",
    "fear",
    "neutral",
    "sadness",
    "surprise",
]
LABEL2IDX = dict(zip(LABELS_7CLASS, range(len(LABELS_7CLASS))))
NUM_CLASSES = len(LABELS_7CLASS)
# Index of the class that is used as a checkpoint gate during training.
DISGUST_IDX = LABELS_7CLASS.index("disgust")
|
|
|
|
| |
| |
| |
|
|
class LoRALinear(nn.Module):
    """Frozen ``nn.Linear`` plus a trainable low-rank residual branch.

    The output is ``original(x) + scaling * lora_B(dropout(lora_A(x)))`` with
    ``scaling = alpha / r``. ``lora_B`` starts at zero, so a freshly wrapped
    layer reproduces the output of the layer it wraps.

    Args:
        original: Linear layer to wrap; its weight and bias are frozen in place.
        r: Adapter rank (width of the bottleneck between ``lora_A`` and ``lora_B``).
        alpha: Numerator of the scaling factor applied to the adapter output.
        dropout: Dropout probability applied between ``lora_A`` and ``lora_B``.
    """

    def __init__(self, original: nn.Linear, r: int = 16, alpha: int = 32, dropout: float = 0.1):
        super().__init__()
        self.original = original
        self.r = r
        self.scaling = alpha / r

        # The wrapped layer never trains; only the adapter matrices do.
        original.weight.requires_grad = False
        if original.bias is not None:
            original.bias.requires_grad = False

        # Down-projection (in -> r) followed by up-projection (r -> out).
        self.lora_A = nn.Linear(original.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, original.out_features, bias=False)
        self.dropout = nn.Dropout(dropout)

        # Random A with an all-zero B makes the adapter a no-op until B is trained.
        nn.init.kaiming_uniform_(self.lora_A.weight)
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the frozen projection of ``x`` plus the scaled adapter term."""
        low_rank = self.lora_A(x)
        delta = self.lora_B(self.dropout(low_rank))
        return self.original(x) + delta * self.scaling
|
|
|
|
def merge_lora_linear(lora: LoRALinear) -> nn.Linear:
    """Fold a LoRA adapter into a standalone ``nn.Linear``.

    The returned layer carries the weight
    ``W_original + (scaling * B.weight) @ A.weight`` and, when the wrapped
    layer has one, a copy of its bias. The adapter's dropout is not part of
    the merged layer.

    Args:
        lora: Adapter-wrapped layer exposing ``original``, ``lora_A``,
            ``lora_B`` and ``scaling``.

    Returns:
        A new ``nn.Linear`` with the same input/output sizes as ``lora.original``.
    """
    base = lora.original
    with torch.no_grad():
        delta = (lora.scaling * lora.lora_B.weight) @ lora.lora_A.weight
        fused_weight = base.weight + delta

    has_bias = base.bias is not None
    merged = nn.Linear(base.in_features, base.out_features, bias=has_bias)
    # Replace the freshly initialised parameters with the fused ones.
    merged.weight = nn.Parameter(fused_weight)
    if has_bias:
        merged.bias = nn.Parameter(base.bias.clone())
    return merged
|
|
|
|
def inject_lora(
    encoder: nn.Module,
    r: int = 16,
    alpha: int = 32,
    dropout: float = 0.1,
) -> None:
    """Wrap the attention projections of every encoder block with LoRA.

    For each element of ``encoder.blocks`` the ``attn.qkv`` and ``attn.proj``
    layers are replaced, in place, by ``LoRALinear`` wrappers around the
    existing layers. Nothing else in the block is touched.

    According to the original notes for this script, emotion2vec uses a fused
    QKV projection (``Linear(768, 2304)``), an output projection
    ``Linear(768, 768)`` and 8 blocks in total; those shapes are not checked here.

    Args:
        encoder: Model exposing an iterable ``blocks`` attribute.
        r: Adapter rank passed to ``LoRALinear``.
        alpha: Scaling numerator passed to ``LoRALinear``.
        dropout: Adapter dropout passed to ``LoRALinear``.
    """
    adapter_kwargs = {"r": r, "alpha": alpha, "dropout": dropout}
    for block in encoder.blocks:
        attn = block.attn
        attn.qkv = LoRALinear(attn.qkv, **adapter_kwargs)
        attn.proj = LoRALinear(attn.proj, **adapter_kwargs)
|
|
|
|
| |
| |
| |
|
|
class MLPHead(nn.Module):
    """Classification head: in_dim -> 512 -> 256 -> num_classes.

    Each hidden stage is ``Linear -> BatchNorm1d -> GELU -> Dropout``; the
    final ``Linear`` emits raw logits.

    Args:
        in_dim: Size of the input feature vector.
        num_classes: Number of output logits.
        dropout: Dropout probability used after each hidden stage.
    """

    def __init__(self, in_dim: int = 768, num_classes: int = NUM_CLASSES, dropout: float = 0.3):
        super().__init__()
        stages = []
        width = in_dim
        for hidden in (512, 256):
            stages += [
                nn.Linear(width, hidden),
                nn.BatchNorm1d(hidden),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            width = hidden
        stages.append(nn.Linear(width, num_classes))
        # Kept as a single Sequential named ``net`` so parameter names stay ``net.<index>.*``.
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of feature vectors to class logits."""
        return self.net(x)
|
|
|
|
class FocalLoss(nn.Module):
    """Focal loss with optional class weights and label smoothing.

    Per sample the loss is ``(1 - p_t) ** gamma * CE``, where ``CE`` is the
    (optionally class-weighted and label-smoothed) cross-entropy and ``p_t``
    is the softmax probability assigned to the true class. The batch mean is
    returned.

    ``p_t`` is taken from the raw softmax rather than from ``exp(-CE)``:
    once class weights or label smoothing are applied, ``exp(-CE)`` is no
    longer the true-class probability, so the focusing term would change
    with the weights instead of with how well the sample is classified.

    Args:
        weight: Optional 1-D tensor of per-class weights, or ``None``.
        gamma: Focusing exponent; ``0`` reduces the loss to cross-entropy.
        label_smoothing: Smoothing factor forwarded to ``F.cross_entropy``.
    """

    def __init__(self, weight=None, gamma: float = 2.0, label_smoothing: float = 0.05):
        super().__init__()
        self.gamma = gamma
        self.label_smoothing = label_smoothing
        # Non-persistent buffer: still reachable as ``self.weight`` and absent
        # from ``state_dict()``, but now moved by ``.to(device)`` with the module.
        self.register_buffer("weight", weight, persistent=False)

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the mean focal loss.

        Args:
            logits: Unnormalised scores with classes on the last dimension.
            targets: Integer class indices, one per row of ``logits``.
        """
        ce_loss = F.cross_entropy(
            logits, targets, weight=self.weight,
            label_smoothing=self.label_smoothing, reduction="none",
        )
        # True-class probability, independent of class weights and smoothing.
        log_pt = F.log_softmax(logits, dim=-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        pt = log_pt.exp()
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        return focal_loss.mean()
|
|
|
|
| |
| |
| |
|
|
class EmotionDataset(Dataset):
    """Audio/label dataset backed by a JSON manifest.

    The manifest is a JSON list of records. Each record names its audio file
    under ``"audio_path"`` (or ``"path"`` as a fallback) and its class under
    ``"label"``, which must be a key of ``LABEL2IDX``.

    Every item is a ``(waveform, label_index)`` pair. The waveform is a 1-D
    float32 tensor: down-mixed to mono, resampled to 16 kHz when needed and
    truncated to ``max_duration_sec`` before any augmentation is applied.

    Args:
        manifest_path: Path of the manifest JSON file.
        max_duration_sec: Maximum clip length, in seconds at 16 kHz.
        phone_augment_prob: Probability of applying the phone-channel simulation.
        noise_augment_prob: Probability of adding Gaussian noise (only
            considered when the phone augmentation was not selected).
    """

    # Sample rate every waveform is brought to.
    _TARGET_SR = 16000

    def __init__(
        self,
        manifest_path: str,
        max_duration_sec: float = 8.0,
        phone_augment_prob: float = 0.0,
        noise_augment_prob: float = 0.0,
    ):
        # Not used in this method; presumably imported so a missing torchaudio
        # fails at construction time rather than on the first item.
        import torchaudio  # noqa: F401

        self.samples = json.loads(Path(manifest_path).read_text(encoding="utf-8"))
        self.max_samples = int(max_duration_sec * 16000)
        self.phone_augment_prob = phone_augment_prob
        self.noise_augment_prob = noise_augment_prob
        # Created on first use by ``_get_phone_sim``.
        self._phone_sim = None

    def _get_phone_sim(self):
        """Return the cached phone simulator, building it on first use."""
        if self._phone_sim is not None:
            return self._phone_sim
        # The simulator lives in the project's ``src`` tree, two levels up from this file.
        src_dir = Path(__file__).resolve().parent.parent / "src"
        sys.path.insert(0, str(src_dir))
        from common.phone_simulator import PhoneSimulator, CompandingType
        self._phone_sim = PhoneSimulator(companding=CompandingType.ALAW)
        return self._phone_sim

    def __len__(self):
        """Number of records in the manifest."""
        return len(self.samples)

    @staticmethod
    def _add_gaussian_noise(audio):
        """Add white Gaussian noise at a random SNR drawn from 10-20 dB."""
        snr_db = random.uniform(10, 20)
        signal_power = np.mean(audio ** 2)
        noise_power = signal_power / (10 ** (snr_db / 10))
        # Floor the variance so silent clips still get a valid (tiny) sigma.
        sigma = np.sqrt(max(noise_power, 1e-10))
        noise = np.random.normal(0, sigma, len(audio)).astype(np.float32)
        return audio + noise

    def __getitem__(self, idx):
        """Load, preprocess and optionally augment the ``idx``-th clip."""
        import torchaudio

        record = self.samples[idx]
        audio_path = record.get("audio_path") or record.get("path", "")

        waveform, sr = torchaudio.load(audio_path)
        # Down-mix multi-channel audio, then drop the channel dimension.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform.squeeze(0)

        if sr != self._TARGET_SR:
            waveform = torchaudio.functional.resample(waveform, sr, 16000)

        # Keep at most ``max_samples`` samples from the start of the clip.
        waveform = waveform[: self.max_samples]
        audio = waveform.numpy()

        # A single draw picks at most one augmentation: phone first, then noise.
        draw = random.random()
        noise_cutoff = self.phone_augment_prob + self.noise_augment_prob
        if draw < self.phone_augment_prob:
            audio, _ = self._get_phone_sim().process(audio, 16000)
            import librosa
            # Presumably the simulator returns 8 kHz audio - confirm against PhoneSimulator.
            audio = librosa.resample(audio, orig_sr=8000, target_sr=16000)
        elif draw < noise_cutoff:
            audio = self._add_gaussian_noise(audio)

        label = LABEL2IDX[record["label"]]
        return torch.tensor(audio, dtype=torch.float32), label
|
|
|
|
def collate_fn(batch):
    """Zero-pad variable-length waveforms into one batch tensor.

    Args:
        batch: Sequence of ``(waveform, label)`` pairs, each waveform 1-D.

    Returns:
        ``(padded, labels)`` where ``padded`` is a float32 tensor of shape
        ``(len(batch), longest_waveform)`` padded on the right with zeros and
        ``labels`` is a ``torch.long`` tensor of the labels.
    """
    waveforms, labels = zip(*batch)
    lengths = [wave.shape[0] for wave in waveforms]
    padded = torch.zeros(len(waveforms), max(lengths))
    # Iterating ``padded`` yields row views, so writing into them fills the batch.
    for row, wave, length in zip(padded, waveforms, lengths):
        row[:length] = wave
    return padded, torch.tensor(labels, dtype=torch.long)
|
|
|
|
| |
| |
| |
|
|
def load_model(device: str, r: int = 16, alpha: int = 32, dropout: float = 0.1):
    """Build the trainable emotion2vec model.

    Loads ``iic/emotion2vec_plus_base`` through funasr, freezes every
    pretrained parameter, injects LoRA adapters into the attention layers and
    swaps the classification head for a fresh ``MLPHead``.

    Args:
        device: Device string passed to funasr and used for the final ``.to``.
        r: LoRA rank.
        alpha: LoRA scaling numerator.
        dropout: LoRA dropout probability.

    Returns:
        The modified model, moved to ``device``.
    """
    from funasr import AutoModel

    logger.info("Loading emotion2vec_plus_base...")
    wrapper = AutoModel(model="iic/emotion2vec_plus_base", device=device, hub="hf")
    encoder = wrapper.model

    # Freeze everything that came with the checkpoint; modules added below
    # are created afterwards and therefore stay trainable.
    for pretrained_param in encoder.parameters():
        pretrained_param.requires_grad = False

    inject_lora(encoder, r=r, alpha=alpha, dropout=dropout)

    # Swap the pretrained head for the 7-class MLP head.
    previous_head = encoder.proj
    encoder.proj = MLPHead(768, NUM_CLASSES)
    logger.info("Replaced proj: Linear(768, %d) -> MLPHead(768->512->256->%d)",
                previous_head.out_features, NUM_CLASSES)

    # The new adapter and head modules were created on the default device.
    return encoder.to(device)
|
|
|
|
def forward_pass(encoder, waveforms: torch.Tensor, device: str) -> torch.Tensor:
    """Run waveforms through the encoder and classification head.

    No padding mask is passed to the encoder, so any padding present in
    ``waveforms`` takes part in both the normalisation and the pooling.

    Args:
        encoder: Model exposing ``cfg.normalize``, ``extract_features`` and ``proj``.
        waveforms: Batch of raw audio, one waveform per row.
        device: Device the batch is moved to before the forward pass.

    Returns:
        The output of ``encoder.proj`` applied to the time-averaged features.
    """
    waveforms = waveforms.to(device)

    # Optional per-utterance normalisation over each whole waveform.
    if encoder.cfg.normalize:
        waveforms = torch.stack([F.layer_norm(wave, wave.shape) for wave in waveforms])

    features = encoder.extract_features(waveforms, padding_mask=None)

    # Average over dimension 1 (the frame axis of ``features["x"]``) before classification.
    utterance_embedding = features["x"].mean(dim=1)
    return encoder.proj(utterance_embedding)
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def validate(encoder, val_loader, device, criterion):
    """Evaluate the model on ``val_loader`` and collect classification metrics.

    The encoder is switched to eval mode and left there.

    Returns:
        Dict with the sample-weighted mean ``loss``, ``accuracy``,
        ``macro_f1``, ``disgust_f1`` (all rounded to 4 decimals),
        ``per_class_f1`` keyed by label name, the ``confusion_matrix`` as
        nested lists, and the raw ``y_true`` / ``y_pred`` index lists.
    """
    encoder.eval()
    total_loss = 0
    y_true, y_pred = [], []

    for waveforms, labels in val_loader:
        labels = labels.to(device)
        logits = forward_pass(encoder, waveforms, device)
        batch_loss = criterion(logits, labels)
        # The criterion returns a batch mean; weight it by batch size.
        total_loss += batch_loss.item() * labels.size(0)

        y_true.extend(labels.cpu().tolist())
        y_pred.extend(logits.argmax(dim=-1).cpu().tolist())

    from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix

    # Passing every class id keeps the metric arrays at NUM_CLASSES entries
    # even when a class is absent from the validation data.
    class_ids = list(range(NUM_CLASSES))
    accuracy = accuracy_score(y_true, y_pred)
    _, _, f1_per_class, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=class_ids, average=None, zero_division=0,
    )
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    per_class = {
        label: round(float(score), 4)
        for label, score in zip(LABELS_7CLASS, f1_per_class)
    }
    macro_f1 = float(np.mean(f1_per_class))
    disgust_f1 = float(f1_per_class[DISGUST_IDX])
    avg_loss = total_loss / max(len(y_true), 1)

    return {
        "loss": round(avg_loss, 4),
        "accuracy": round(accuracy, 4),
        "macro_f1": round(macro_f1, 4),
        "disgust_f1": round(disgust_f1, 4),
        "per_class_f1": per_class,
        "confusion_matrix": cm.tolist(),
        "y_true": y_true,
        "y_pred": y_pred,
    }
|
|
|
|
| |
| |
| |
|
|
def save_lora_checkpoint(encoder, path: Path, epoch: int, metrics: dict,
                         best_f1: float = 0.0, patience_counter: int = 0,
                         optimizer=None, scheduler=None, scaler=None,
                         training_log=None):
    """Write a resumable checkpoint to ``path``.

    Only the adapter matrices of ``LoRALinear`` modules and the state of
    ``encoder.proj`` are stored for the model; the wrapped base weights are
    not. Optimizer, scheduler, scaler and training-log entries are added only
    when the corresponding argument is given.
    """
    # Adapter weights, keyed by module path, copied to CPU.
    lora_state = {}
    for name, module in encoder.named_modules():
        if not isinstance(module, LoRALinear):
            continue
        for part in ("lora_A", "lora_B"):
            lora_state[f"{name}.{part}.weight"] = getattr(module, part).weight.data.cpu()

    state = {
        "lora_weights": lora_state,
        "proj": encoder.proj.state_dict(),
        "epoch": epoch,
        "metrics": metrics,
        "best_f1": best_f1,
        "patience_counter": patience_counter,
        "num_classes": NUM_CLASSES,
        "labels": LABELS_7CLASS,
    }

    # Optional training state needed to resume a run.
    stateful_parts = (("optimizer", optimizer), ("scheduler", scheduler), ("scaler", scaler))
    for key, component in stateful_parts:
        if component is not None:
            state[key] = component.state_dict()
    if training_log is not None:
        state["training_log"] = training_log

    torch.save(state, str(path))
|
|
|
|
def load_lora_checkpoint(encoder, path: Path, device: str,
                         optimizer=None, scheduler=None, scaler=None):
    """Restore model and training state from a checkpoint file.

    Adapter weights are copied into every ``LoRALinear`` module whose keys are
    present in the checkpoint; adapters without an entry keep their current
    weights and are reported in a warning. ``encoder.proj`` is always loaded.
    Optimizer, scheduler and scaler are restored only when both the object is
    passed and the checkpoint contains the matching entry.

    Args:
        encoder: Model that already has LoRA adapters injected.
        path: Checkpoint file to read.
        device: Device used as ``map_location`` and for the copied weights.
        optimizer: Optional optimizer to restore.
        scheduler: Optional LR scheduler to restore.
        scaler: Optional gradient scaler to restore.

    Returns:
        Dict with ``epoch``, ``best_f1``, ``patience_counter`` and
        ``training_log`` (empty list when the checkpoint has none).
    """
    logger.info("Resuming from checkpoint: %s", path)
    # SECURITY: ``weights_only=False`` unpickles arbitrary objects, so this
    # must only be pointed at checkpoints from a trusted source.
    ckpt = torch.load(str(path), map_location=device, weights_only=False)

    saved_adapters = ckpt["lora_weights"]
    missing = []
    for name, module in encoder.named_modules():
        if not isinstance(module, LoRALinear):
            continue
        a_key = f"{name}.lora_A.weight"
        b_key = f"{name}.lora_B.weight"
        if a_key in saved_adapters:
            module.lora_A.weight.data.copy_(saved_adapters[a_key].to(device))
            module.lora_B.weight.data.copy_(saved_adapters[b_key].to(device))
        else:
            missing.append(name)
    if missing:
        # Previously skipped silently; a partial restore is worth surfacing.
        logger.warning("Checkpoint %s has no LoRA weights for %d adapter(s): %s",
                       path, len(missing), ", ".join(missing))

    encoder.proj.load_state_dict(ckpt["proj"])

    if optimizer is not None and "optimizer" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer"])
    if scheduler is not None and "scheduler" in ckpt:
        scheduler.load_state_dict(ckpt["scheduler"])
    if scaler is not None and "scaler" in ckpt:
        scaler.load_state_dict(ckpt["scaler"])

    logger.info("Resumed from epoch %d (best_f1=%.4f, patience=%d)",
                ckpt["epoch"], ckpt["best_f1"], ckpt["patience_counter"])
    return {
        "epoch": ckpt["epoch"],
        "best_f1": ckpt["best_f1"],
        "patience_counter": ckpt["patience_counter"],
        "training_log": ckpt.get("training_log", []),
    }
|
|
|
|
| |
| |
| |
|
|
def plot_confusion_matrix(cm, output_path: Path, epoch: int):
    """Save a row-normalised confusion-matrix heatmap as a PNG.

    Each cell shows the row-normalised rate with the raw count underneath in
    parentheses. Plotting is best-effort: if matplotlib or seaborn cannot be
    imported, a warning is logged and nothing is written.

    Args:
        cm: Square NumPy array of counts with ``NUM_CLASSES`` rows and columns.
        output_path: Destination image path.
        epoch: Epoch number shown in the title.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")  # file-only backend, no display required
        import matplotlib.pyplot as plt
        import seaborn as sns

        fig, ax = plt.subplots(figsize=(9, 7))
        try:
            # Guard against division by zero for classes with no samples.
            cm_norm = cm / np.maximum(cm.sum(axis=1, keepdims=True), 1)
            sns.heatmap(
                cm_norm, annot=True, fmt=".2f", cmap="Blues",
                xticklabels=LABELS_7CLASS, yticklabels=LABELS_7CLASS, ax=ax,
            )
            # Raw counts, placed slightly below each cell's centre.
            for i in range(NUM_CLASSES):
                for j in range(NUM_CLASSES):
                    ax.text(j + 0.5, i + 0.7, f"({cm[i][j]})",
                            ha="center", va="center", fontsize=6, color="gray")

            ax.set_xlabel("Predicted")
            ax.set_ylabel("True")
            ax.set_title(f"LoRA 7-Class Confusion Matrix (Epoch {epoch})")
            fig.tight_layout()
            fig.savefig(str(output_path), dpi=150)
        finally:
            # Close this figure even when plotting or saving fails, so a
            # failure does not leave it open.
            plt.close(fig)
        logger.info("Confusion matrix saved to %s", output_path)
    except ImportError:
        logger.warning("matplotlib/seaborn not available - skipping confusion matrix plot")
|
|
|
|
| |
| |
| |
|
|
def train(args):
    """Run LoRA fine-tuning end to end with the settings in ``args``.

    Builds the model, data loaders, class-weighted focal loss, a two-group
    AdamW optimizer (LoRA parameters vs. classification head) and a OneCycleLR
    schedule, then trains with gradient accumulation and, on CUDA, mixed
    precision. After every epoch the model is validated and these files are
    written into ``args.output_dir``:

    * ``last_lora.pt`` - always; used by ``--resume``.
    * ``best_lora.pt`` and ``confusion_matrix_epoch<N>.png`` - when macro F1
      improves and the disgust F1 gate passes.
    * ``training_log.json`` - per-epoch metrics so far.

    ``config.json`` is written once training ends.

    Args:
        args: Namespace with ``device``, ``epochs``, ``batch_size``,
            ``accumulate_steps``, ``lora_r``, ``lora_alpha``, ``lora_dropout``,
            ``lora_lr``, ``proj_lr``, ``weight_decay``, ``patience``,
            ``phone_augment_prob``, ``noise_augment_prob``, ``train_manifest``,
            ``val_manifest``, ``output_dir`` and ``resume``.

    Raises:
        ValueError: If ``args.accumulate_steps`` is smaller than 1.
    """
    device = args.device
    use_amp = (device == "cuda")
    accumulate_steps = args.accumulate_steps
    if accumulate_steps < 1:
        raise ValueError(f"--accumulate-steps must be >= 1, got {accumulate_steps}")

    # Model: frozen backbone + LoRA adapters + new classification head.
    encoder = load_model(device, r=args.lora_r, alpha=args.lora_alpha, dropout=args.lora_dropout)

    trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
    total = sum(p.numel() for p in encoder.parameters())
    logger.info("Trainable: %dK / Total: %dK (%.1f%%)",
                trainable // 1000, total // 1000, 100 * trainable / total)

    # Data: augmentation is enabled for the training split only.
    train_dataset = EmotionDataset(
        args.train_manifest,
        max_duration_sec=8.0,
        phone_augment_prob=args.phone_augment_prob,
        noise_augment_prob=args.noise_augment_prob,
    )
    val_dataset = EmotionDataset(args.val_manifest, max_duration_sec=8.0)

    logger.info("Train: %d samples, Val: %d samples", len(train_dataset), len(val_dataset))
    logger.info("AMP: %s, Accumulate: %d, Effective batch: %d",
                use_amp, accumulate_steps, args.batch_size * accumulate_steps)

    n_workers = 2
    use_pin = (device == "cuda")
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        collate_fn=collate_fn, num_workers=n_workers, drop_last=True,
        pin_memory=use_pin, persistent_workers=(n_workers > 0),
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        collate_fn=collate_fn, num_workers=n_workers,
        pin_memory=use_pin, persistent_workers=(n_workers > 0),
    )
    num_batches = len(train_loader)

    # Class weights: inverse frequency, rescaled to sum to NUM_CLASSES, with
    # the disgust class boosted by a further 2.5x.
    class_counts = np.zeros(NUM_CLASSES)
    for sample in train_dataset.samples:
        class_counts[LABEL2IDX[sample["label"]]] += 1
    class_weights = 1.0 / np.maximum(class_counts, 1)
    class_weights = class_weights / class_weights.sum() * NUM_CLASSES
    class_weights[DISGUST_IDX] *= 2.5
    logger.info("Class weights: %s",
                {LABELS_7CLASS[i]: round(float(class_weights[i]), 3) for i in range(NUM_CLASSES)})

    criterion = FocalLoss(
        weight=torch.tensor(class_weights, dtype=torch.float32).to(device),
        gamma=2.0,
        label_smoothing=0.05,
    )

    # Two parameter groups so the head can use a different learning rate
    # than the LoRA adapters.
    proj_params = list(encoder.proj.parameters())
    proj_ids = {id(p) for p in proj_params}
    lora_params = [
        p for p in encoder.parameters()
        if p.requires_grad and id(p) not in proj_ids
    ]
    # Everything that receives gradients; used for clipping in the loop.
    trainable_params = [p for p in encoder.parameters() if p.requires_grad]

    optimizer = torch.optim.AdamW([
        {"params": lora_params, "lr": args.lora_lr},
        {"params": proj_params, "lr": args.proj_lr},
    ], weight_decay=args.weight_decay)

    # The loop below also steps on a trailing partial accumulation window, so
    # an epoch performs ceil(num_batches / accumulate_steps) scheduler steps.
    # Floor division under-counted them and OneCycleLR raised once it was
    # stepped past ``total_steps``.
    steps_per_epoch = max((num_batches + accumulate_steps - 1) // accumulate_steps, 1)
    total_steps = steps_per_epoch * args.epochs
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=[args.lora_lr, args.proj_lr],
        total_steps=total_steps,
        pct_start=0.1,
        anneal_strategy="cos",
    )

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # With ``enabled=False`` (CPU runs) the scaler is a pass-through.
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    training_log = []
    best_f1 = 0.0
    patience_counter = 0
    start_epoch = 1
    # Minimum validation disgust F1 required before a checkpoint may become "best".
    disgust_gate = 0.3

    if args.resume:
        resume_path = output_dir / "last_lora.pt"
        if resume_path.exists():
            resume_state = load_lora_checkpoint(
                encoder, resume_path, device,
                optimizer=optimizer, scheduler=scheduler, scaler=scaler,
            )
            start_epoch = resume_state["epoch"] + 1
            best_f1 = resume_state["best_f1"]
            patience_counter = resume_state["patience_counter"]
            training_log = resume_state["training_log"]
            logger.info("Resuming training from epoch %d (best_f1=%.4f)", start_epoch, best_f1)
        else:
            logger.warning("--resume set but no checkpoint found at %s, starting fresh", resume_path)

    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start = time.time()
        encoder.train()
        total_loss = 0
        correct = 0
        total_samples = 0
        optimizer.zero_grad()

        for batch_idx, (waveforms, labels) in enumerate(train_loader):
            labels = labels.to(device)

            with torch.amp.autocast("cuda", enabled=use_amp):
                logits = forward_pass(encoder, waveforms, device)
                # NOTE: a trailing partial window is still divided by
                # ``accumulate_steps``, so its gradient is scaled down.
                loss = criterion(logits, labels) / accumulate_steps

            scaler.scale(loss).backward()

            window_full = (batch_idx + 1) % accumulate_steps == 0
            last_batch = (batch_idx + 1) == num_batches
            if window_full or last_batch:
                # Unscale first so clipping sees true gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            # Undo the accumulation scaling for reporting.
            total_loss += loss.item() * accumulate_steps * labels.size(0)
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total_samples += labels.size(0)

            if (batch_idx + 1) % 10 == 0:
                cur_lr = optimizer.param_groups[0]["lr"]
                logger.info("  Epoch %d [%d/%d] loss=%.4f lr=%.2e",
                            epoch, batch_idx + 1, num_batches,
                            loss.item() * accumulate_steps, cur_lr)

        train_loss = total_loss / max(total_samples, 1)
        train_acc = correct / max(total_samples, 1)

        val_metrics = validate(encoder, val_loader, device, criterion)

        epoch_time = time.time() - epoch_start
        logger.info(
            "Epoch %d/%d (%.0fs): train_loss=%.4f train_acc=%.3f | "
            "val_loss=%.4f val_f1=%.3f val_acc=%.3f disgust_f1=%.3f",
            epoch, args.epochs, epoch_time,
            train_loss, train_acc,
            val_metrics["loss"], val_metrics["macro_f1"],
            val_metrics["accuracy"], val_metrics["disgust_f1"],
        )
        logger.info("  Per-class F1: %s", val_metrics["per_class_f1"])

        epoch_log = {
            "epoch": epoch,
            "train_loss": round(train_loss, 4),
            "train_acc": round(train_acc, 4),
            "val_loss": val_metrics["loss"],
            "val_accuracy": val_metrics["accuracy"],
            "val_macro_f1": val_metrics["macro_f1"],
            "val_disgust_f1": val_metrics["disgust_f1"],
            "val_per_class_f1": val_metrics["per_class_f1"],
            "epoch_time_sec": round(epoch_time, 1),
        }
        training_log.append(epoch_log)

        # A checkpoint only counts as "best" if disgust F1 clears the gate.
        gate_pass = val_metrics["disgust_f1"] >= disgust_gate
        if not gate_pass:
            logger.warning("[GATE FAIL] disgust_f1=%.3f < %.1f - not saving as best",
                           val_metrics["disgust_f1"], disgust_gate)

        if gate_pass and val_metrics["macro_f1"] > best_f1:
            best_f1 = val_metrics["macro_f1"]
            patience_counter = 0
            save_lora_checkpoint(
                encoder, output_dir / "best_lora.pt", epoch, val_metrics, best_f1,
                optimizer=optimizer, scheduler=scheduler, scaler=scaler,
                training_log=training_log,
            )
            logger.info("  New best! macro_f1=%.4f (gate passed) saved to best_lora.pt", best_f1)

            if "confusion_matrix" in val_metrics:
                cm = np.array(val_metrics["confusion_matrix"])
                plot_confusion_matrix(cm, output_dir / f"confusion_matrix_epoch{epoch}.png", epoch)
        elif gate_pass:
            patience_counter += 1
        # When the gate fails the patience counter is deliberately left unchanged.

        # Always keep the latest state so an interrupted run can be resumed.
        save_lora_checkpoint(
            encoder, output_dir / "last_lora.pt", epoch, val_metrics, best_f1, patience_counter,
            optimizer=optimizer, scheduler=scheduler, scaler=scaler,
            training_log=training_log,
        )

        with open(output_dir / "training_log.json", "w", encoding="utf-8") as f:
            json.dump(training_log, f, indent=2)

        if patience_counter >= args.patience:
            logger.info("Early stopping at epoch %d (patience=%d)", epoch, args.patience)
            break

        if device == "cuda":
            torch.cuda.empty_cache()

    config = {
        "model": "iic/emotion2vec_plus_base",
        "method": "LoRA",
        "lora_r": args.lora_r,
        "lora_alpha": args.lora_alpha,
        "num_classes": NUM_CLASSES,
        "labels": LABELS_7CLASS,
        "label2idx": LABEL2IDX,
        "best_val_f1": best_f1,
        "training_args": vars(args),
    }
    # ``ensure_ascii=False`` can emit non-ASCII text (e.g. paths in the
    # arguments), so the file is written as UTF-8 rather than in the platform
    # default encoding.
    with open(output_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)

    logger.info("Training complete. Best F1=%.4f at %s", best_f1, output_dir / "best_lora.pt")
|
|
|
|
| |
| |
| |
|
|
def main():
    """Parse command-line options and start training."""
    parser = argparse.ArgumentParser(description="LoRA fine-tune emotion2vec 7-class")

    # Paths and device.
    parser.add_argument("--train-manifest", required=True)
    parser.add_argument("--val-manifest", required=True)
    parser.add_argument("--output-dir", default="data/models/lora_emotion2vec_7class")
    parser.add_argument("--device", default="cpu", choices=["cpu", "cuda"])

    # Numeric hyper-parameters as (flag, type, default), registered in this order.
    numeric_options = (
        ("--epochs", int, 20),
        ("--batch-size", int, 4),
        ("--accumulate-steps", int, 8),
        ("--lora-r", int, 16),
        ("--lora-alpha", int, 32),
        ("--lora-dropout", float, 0.1),
        ("--lora-lr", float, 2e-4),
        ("--proj-lr", float, 2e-3),
        ("--weight-decay", float, 0.01),
        ("--patience", int, 7),
        ("--phone-augment-prob", float, 0.3),
        ("--noise-augment-prob", float, 0.15),
    )
    for flag, value_type, default in numeric_options:
        parser.add_argument(flag, type=value_type, default=default)

    parser.add_argument("--resume", action="store_true",
                        help="Resume from last_lora.pt checkpoint in output-dir")

    train(parser.parse_args())


if __name__ == "__main__":
    main()
|
|