#!/usr/bin/env python3 """Manual LoRA fine-tuning for emotion2vec_plus_base — 7-class emotion. Wraps frozen attention layers with low-rank adapters (LoRALinear), replaces the 9-class proj head with a 7-class MLPHead, and trains with FocalLoss + disgust F1 gating. Usage: # Quick smoke test (CPU) python scripts/train_lora_emotion2vec.py \ --train-manifest data/lora_7class/train_manifest.json \ --val-manifest data/lora_7class/val_manifest.json \ --output-dir data/models/lora_emotion2vec_7class \ --epochs 3 --device cpu # Full training (GPU, RTX 5050 8GB) PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ python scripts/train_lora_emotion2vec.py \ --train-manifest data/lora_7class/train_manifest.json \ --val-manifest data/lora_7class/val_manifest.json \ --output-dir data/models/lora_emotion2vec_7class \ --epochs 20 --batch-size 4 --accumulate-steps 8 --device cuda """ from __future__ import annotations import argparse import json import logging import random import sys import time from pathlib import Path import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger(__name__) # 7-class label taxonomy (matches prepare_lora_dataset.py) LABELS_7CLASS = ["happiness", "anger", "disgust", "fear", "neutral", "sadness", "surprise"] LABEL2IDX = {label: i for i, label in enumerate(LABELS_7CLASS)} NUM_CLASSES = len(LABELS_7CLASS) DISGUST_IDX = LABEL2IDX["disgust"] # 2 # ────────────────────────────────────────────── # LoRA Components # ────────────────────────────────────────────── class LoRALinear(nn.Module): """Low-rank adapter wrapping a frozen nn.Linear. At init, B is zero so LoRA contribution is zero (original behavior preserved). scaling = alpha / r controls the magnitude of the LoRA update. 
""" def __init__(self, original: nn.Linear, r: int = 16, alpha: int = 32, dropout: float = 0.1): super().__init__() self.original = original self.r = r self.scaling = alpha / r # Freeze original weights self.original.weight.requires_grad = False if self.original.bias is not None: self.original.bias.requires_grad = False in_features = original.in_features out_features = original.out_features # Low-rank matrices self.lora_A = nn.Linear(in_features, r, bias=False) self.lora_B = nn.Linear(r, out_features, bias=False) self.dropout = nn.Dropout(dropout) # Init: A = kaiming, B = zero (so initial LoRA output = 0) nn.init.kaiming_uniform_(self.lora_A.weight) nn.init.zeros_(self.lora_B.weight) def forward(self, x: torch.Tensor) -> torch.Tensor: base_out = self.original(x) lora_out = self.lora_B(self.dropout(self.lora_A(x))) * self.scaling return base_out + lora_out def merge_lora_linear(lora: LoRALinear) -> nn.Linear: """Merge LoRA weights into a plain nn.Linear for inference. W_merged = W_original + scaling * B.weight @ A.weight """ with torch.no_grad(): merged_weight = ( lora.original.weight + lora.scaling * lora.lora_B.weight @ lora.lora_A.weight ) bias = lora.original.bias merged = nn.Linear( lora.original.in_features, lora.original.out_features, bias=bias is not None, ) merged.weight = nn.Parameter(merged_weight) if bias is not None: merged.bias = nn.Parameter(bias.clone()) return merged def inject_lora( encoder: nn.Module, r: int = 16, alpha: int = 32, dropout: float = 0.1, ) -> None: """Replace attn.qkv and attn.proj in each block with LoRALinear. emotion2vec uses FUSED QKV: attn.qkv: Linear(768, 2304) and attn.proj: Linear(768, 768). 8 blocks total. MLP layers are NOT wrapped. 
""" for block in encoder.blocks: block.attn.qkv = LoRALinear(block.attn.qkv, r=r, alpha=alpha, dropout=dropout) block.attn.proj = LoRALinear(block.attn.proj, r=r, alpha=alpha, dropout=dropout) # ────────────────────────────────────────────── # Model Components # ────────────────────────────────────────────── class MLPHead(nn.Module): """Multi-layer classification head: 768 → 512 → 256 → num_classes.""" def __init__(self, in_dim: int = 768, num_classes: int = NUM_CLASSES, dropout: float = 0.3): super().__init__() self.net = nn.Sequential( nn.Linear(in_dim, 512), nn.BatchNorm1d(512), nn.GELU(), nn.Dropout(dropout), nn.Linear(512, 256), nn.BatchNorm1d(256), nn.GELU(), nn.Dropout(dropout), nn.Linear(256, num_classes), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x) class FocalLoss(nn.Module): """Focal Loss with optional label smoothing and class weights.""" def __init__(self, weight=None, gamma: float = 2.0, label_smoothing: float = 0.05): super().__init__() self.gamma = gamma self.weight = weight self.label_smoothing = label_smoothing def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: ce_loss = F.cross_entropy( logits, targets, weight=self.weight, label_smoothing=self.label_smoothing, reduction="none", ) pt = torch.exp(-ce_loss) focal_loss = ((1 - pt) ** self.gamma) * ce_loss return focal_loss.mean() # ────────────────────────────────────────────── # Dataset # ────────────────────────────────────────────── class EmotionDataset(Dataset): """Load audio from unified manifest JSON for 7-class LoRA training. Manifest format: list of {"audio_path": str, "label": str, ...} OR {"path": str, "label": str, ...} (for backward compat). 
""" def __init__( self, manifest_path: str, max_duration_sec: float = 8.0, phone_augment_prob: float = 0.0, noise_augment_prob: float = 0.0, ): import torchaudio # lazy import with open(manifest_path, encoding="utf-8") as f: self.samples = json.load(f) self.max_samples = int(max_duration_sec * 16000) self.phone_augment_prob = phone_augment_prob self.noise_augment_prob = noise_augment_prob self._phone_sim = None def _get_phone_sim(self): if self._phone_sim is None: sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src")) from common.phone_simulator import PhoneSimulator, CompandingType self._phone_sim = PhoneSimulator(companding=CompandingType.ALAW) return self._phone_sim def __len__(self): return len(self.samples) def __getitem__(self, idx): import torchaudio sample = self.samples[idx] # Support both "audio_path" and "path" keys audio_path = sample.get("audio_path") or sample.get("path", "") waveform, sr = torchaudio.load(audio_path) # Mono if waveform.shape[0] > 1: waveform = waveform.mean(dim=0, keepdim=True) waveform = waveform.squeeze(0) # (T,) # Resample to 16kHz if sr != 16000: waveform = torchaudio.functional.resample(waveform, sr, 16000) # Truncate if waveform.shape[0] > self.max_samples: waveform = waveform[:self.max_samples] audio = waveform.numpy() # Augmentation r = random.random() if r < self.phone_augment_prob: sim = self._get_phone_sim() audio, _ = sim.process(audio, 16000) import librosa audio = librosa.resample(audio, orig_sr=8000, target_sr=16000) elif r < self.phone_augment_prob + self.noise_augment_prob: snr_db = random.uniform(10, 20) signal_power = np.mean(audio ** 2) noise_power = signal_power / (10 ** (snr_db / 10)) noise = np.random.normal(0, np.sqrt(max(noise_power, 1e-10)), len(audio)).astype(np.float32) audio = audio + noise label = LABEL2IDX[sample["label"]] return torch.tensor(audio, dtype=torch.float32), label def collate_fn(batch): """Pad waveforms to same length in batch.""" waveforms, labels = zip(*batch) max_len = 
max(w.shape[0] for w in waveforms) padded = torch.zeros(len(waveforms), max_len) for i, w in enumerate(waveforms): padded[i, :w.shape[0]] = w return padded, torch.tensor(labels, dtype=torch.long) # ────────────────────────────────────────────── # Model Loading & Forward # ────────────────────────────────────────────── def load_model(device: str, r: int = 16, alpha: int = 32, dropout: float = 0.1): """Load emotion2vec_plus_base, freeze all, inject LoRA, replace proj.""" from funasr import AutoModel logger.info("Loading emotion2vec_plus_base...") fmodel = AutoModel(model="iic/emotion2vec_plus_base", device=device, hub="hf") encoder = fmodel.model # Freeze everything for param in encoder.parameters(): param.requires_grad = False # Inject LoRA adapters into attention layers inject_lora(encoder, r=r, alpha=alpha, dropout=dropout) # Replace 9-class proj with 7-class MLPHead old_proj = encoder.proj encoder.proj = MLPHead(768, NUM_CLASSES) logger.info("Replaced proj: Linear(768, %d) -> MLPHead(768->512->256->%d)", old_proj.out_features, NUM_CLASSES) # Move entire model (including new LoRA params + MLPHead) to device encoder = encoder.to(device) return encoder def forward_pass(encoder, waveforms: torch.Tensor, device: str) -> torch.Tensor: """Differentiable forward pass through emotion2vec with LoRA. 
Args: encoder: emotion2vec model with LoRA injected waveforms: (B, T) float32, 16kHz device: compute device Returns: logits: (B, 7) """ waveforms = waveforms.to(device) # Layer norm (per emotion2vec inference) if encoder.cfg.normalize: normed = [] for i in range(waveforms.shape[0]): normed.append(F.layer_norm(waveforms[i], waveforms[i].shape)) waveforms = torch.stack(normed) # Extract features feats = encoder.extract_features(waveforms, padding_mask=None) x = feats["x"] # (B, T', 768) # Mean pool + classify pooled = x.mean(dim=1) # (B, 768) logits = encoder.proj(pooled) # (B, 7) return logits # ────────────────────────────────────────────── # Validation # ────────────────────────────────────────────── @torch.no_grad() def validate(encoder, val_loader, device, criterion): """Run validation, return metrics including disgust_f1.""" encoder.eval() total_loss = 0 y_true, y_pred = [], [] for waveforms, labels in val_loader: labels = labels.to(device) logits = forward_pass(encoder, waveforms, device) loss = criterion(logits, labels) total_loss += loss.item() * labels.size(0) preds = logits.argmax(dim=-1) y_true.extend(labels.cpu().tolist()) y_pred.extend(preds.cpu().tolist()) from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix accuracy = accuracy_score(y_true, y_pred) _, _, f1_per_class, _ = precision_recall_fscore_support( y_true, y_pred, labels=list(range(NUM_CLASSES)), average=None, zero_division=0, ) macro_f1 = float(np.mean(f1_per_class)) disgust_f1 = float(f1_per_class[DISGUST_IDX]) per_class = {LABELS_7CLASS[i]: round(float(f1_per_class[i]), 4) for i in range(NUM_CLASSES)} avg_loss = total_loss / max(len(y_true), 1) cm = confusion_matrix(y_true, y_pred, labels=list(range(NUM_CLASSES))) return { "loss": round(avg_loss, 4), "accuracy": round(accuracy, 4), "macro_f1": round(macro_f1, 4), "disgust_f1": round(disgust_f1, 4), "per_class_f1": per_class, "confusion_matrix": cm.tolist(), "y_true": y_true, "y_pred": y_pred, } # 
# ──────────────────────────────────────────────
# Checkpoint
# ──────────────────────────────────────────────

def save_lora_checkpoint(encoder, path: Path, epoch: int, metrics: dict,
                         best_f1: float = 0.0, patience_counter: int = 0,
                         optimizer=None, scheduler=None, scaler=None,
                         training_log=None):
    """Save LoRA weights + MLPHead + optimizer state for resume.

    Only the adapter matrices of every LoRALinear and the ``proj`` head are
    stored; the frozen base model is not. The file is written to a temporary
    sibling and then renamed, so an interrupted save cannot leave a truncated
    checkpoint at ``path``.

    Args:
        encoder: Model containing LoRALinear modules and a ``proj`` head.
        path: Destination file.
        epoch: Epoch number that just finished.
        metrics: Validation metrics for that epoch.
        best_f1: Best gated macro-F1 seen so far.
        patience_counter: Current early-stopping counter.
        optimizer, scheduler, scaler: Optional objects whose ``state_dict`` is
            stored for resuming.
        training_log: Optional list of per-epoch log dicts.
    """
    lora_state = {}
    for name, module in encoder.named_modules():
        if isinstance(module, LoRALinear):
            lora_state[f"{name}.lora_A.weight"] = module.lora_A.weight.detach().cpu()
            lora_state[f"{name}.lora_B.weight"] = module.lora_B.weight.detach().cpu()

    state = {
        "lora_weights": lora_state,
        "proj": encoder.proj.state_dict(),
        "epoch": epoch,
        "metrics": metrics,
        "best_f1": best_f1,
        "patience_counter": patience_counter,
        "num_classes": NUM_CLASSES,
        "labels": LABELS_7CLASS,
    }
    if optimizer is not None:
        state["optimizer"] = optimizer.state_dict()
    if scheduler is not None:
        state["scheduler"] = scheduler.state_dict()
    if scaler is not None:
        state["scaler"] = scaler.state_dict()
    if training_log is not None:
        state["training_log"] = training_log

    path = Path(path)
    tmp_path = path.with_name(path.name + ".tmp")
    torch.save(state, str(tmp_path))
    tmp_path.replace(path)


def load_lora_checkpoint(encoder, path: Path, device: str,
                         optimizer=None, scheduler=None, scaler=None):
    """Load LoRA checkpoint and restore training state.

    Adapters present in ``encoder`` but absent from the checkpoint are left at
    their current values and reported with a warning.

    Returns:
        Dict with ``epoch``, ``best_f1``, ``patience_counter`` and
        ``training_log`` (empty list when the checkpoint has none).
    """
    logger.info("Resuming from checkpoint: %s", path)
    # SECURITY: weights_only=False unpickles arbitrary objects. The checkpoint
    # stores plain-Python metrics/optimizer state, so only load files produced
    # by this script from a trusted location.
    ckpt = torch.load(str(path), map_location=device, weights_only=False)

    # Restore LoRA weights
    saved = ckpt["lora_weights"]
    for name, module in encoder.named_modules():
        if not isinstance(module, LoRALinear):
            continue
        a_key = f"{name}.lora_A.weight"
        b_key = f"{name}.lora_B.weight"
        if a_key in saved and b_key in saved:
            module.lora_A.weight.data.copy_(saved[a_key].to(device))
            module.lora_B.weight.data.copy_(saved[b_key].to(device))
        else:
            logger.warning("Checkpoint has no LoRA weights for '%s' — leaving adapter as is", name)

    # Restore MLPHead
    encoder.proj.load_state_dict(ckpt["proj"])

    # Restore optimizer/scheduler/scaler if available
    if optimizer is not None and "optimizer" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer"])
    if scheduler is not None and "scheduler" in ckpt:
        scheduler.load_state_dict(ckpt["scheduler"])
    if scaler is not None and "scaler" in ckpt:
        scaler.load_state_dict(ckpt["scaler"])

    logger.info("Resumed from epoch %d (best_f1=%.4f, patience=%d)",
                ckpt["epoch"], ckpt["best_f1"], ckpt["patience_counter"])

    return {
        "epoch": ckpt["epoch"],
        "best_f1": ckpt["best_f1"],
        "patience_counter": ckpt["patience_counter"],
        "training_log": ckpt.get("training_log", []),
    }


# ──────────────────────────────────────────────
# Confusion Matrix Plot
# ──────────────────────────────────────────────

def plot_confusion_matrix(cm, output_path: Path, epoch: int):
    """Save confusion matrix as PNG.

    Cells show the row-normalised rate with the raw count in parentheses.
    Does nothing (beyond a warning) when matplotlib or seaborn is missing.

    Args:
        cm: ``(NUM_CLASSES, NUM_CLASSES)`` count array, rows = true labels.
        output_path: Destination PNG path.
        epoch: Epoch number shown in the title.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")  # headless backend; no display needed
        import matplotlib.pyplot as plt
        import seaborn as sns
    except ImportError:
        logger.warning("matplotlib/seaborn not available — skipping confusion matrix plot")
        return

    fig, ax = plt.subplots(figsize=(9, 7))
    # Guard against empty rows (classes absent from validation) dividing by 0.
    cm_norm = cm / np.maximum(cm.sum(axis=1, keepdims=True), 1)

    sns.heatmap(
        cm_norm, annot=True, fmt=".2f", cmap="Blues",
        xticklabels=LABELS_7CLASS, yticklabels=LABELS_7CLASS, ax=ax,
    )
    for i in range(NUM_CLASSES):
        for j in range(NUM_CLASSES):
            ax.text(j + 0.5, i + 0.7, f"({cm[i][j]})",
                    ha="center", va="center", fontsize=6, color="gray")

    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title(f"LoRA 7-Class Confusion Matrix (Epoch {epoch})")
    fig.tight_layout()
    fig.savefig(str(output_path), dpi=150)
    plt.close(fig)
    logger.info("Confusion matrix saved to %s", output_path)


# ──────────────────────────────────────────────
# Training
# ──────────────────────────────────────────────

def train(args):
    """Run the full LoRA fine-tuning loop described by ``args``.

    Builds model, data loaders, class-weighted FocalLoss, AdamW with separate
    learning rates for LoRA and head parameters, and a OneCycleLR schedule;
    then trains with gradient accumulation (and AMP on CUDA). After each epoch
    it validates, writes ``last_lora.pt`` and ``training_log.json``, and writes
    ``best_lora.pt`` when macro-F1 improves AND disgust F1 passes the gate.
    Epochs that fail the gate do not count toward early-stopping patience.

    Args:
        args: Parsed CLI namespace (see ``main``).
    """
    device = args.device
    use_amp = (device == "cuda")
    accumulate_steps = args.accumulate_steps

    # Load model
    encoder = load_model(device, r=args.lora_r, alpha=args.lora_alpha, dropout=args.lora_dropout)

    # Log trainable vs total
    trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
    total = sum(p.numel() for p in encoder.parameters())
    logger.info("Trainable: %dK / Total: %dK (%.1f%%)",
                trainable // 1000, total // 1000, 100 * trainable / total)

    # Datasets
    train_dataset = EmotionDataset(
        args.train_manifest,
        max_duration_sec=8.0,
        phone_augment_prob=args.phone_augment_prob,
        noise_augment_prob=args.noise_augment_prob,
    )
    val_dataset = EmotionDataset(args.val_manifest, max_duration_sec=8.0)
    logger.info("Train: %d samples, Val: %d samples", len(train_dataset), len(val_dataset))
    logger.info("AMP: %s, Accumulate: %d, Effective batch: %d",
                use_amp, accumulate_steps, args.batch_size * accumulate_steps)

    n_workers = 2
    use_pin = (device == "cuda")
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        collate_fn=collate_fn, num_workers=n_workers, drop_last=True,
        pin_memory=use_pin, persistent_workers=(n_workers > 0),
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        collate_fn=collate_fn, num_workers=n_workers,
        pin_memory=use_pin, persistent_workers=(n_workers > 0),
    )
    num_batches = len(train_loader)

    # Class weights: inverse frequency + disgust 2.5x boost
    class_counts = np.zeros(NUM_CLASSES)
    for sample in train_dataset.samples:
        class_counts[LABEL2IDX[sample["label"]]] += 1
    class_weights = 1.0 / np.maximum(class_counts, 1)
    class_weights = class_weights / class_weights.sum() * NUM_CLASSES
    class_weights[DISGUST_IDX] *= 2.5  # Disgust boost
    logger.info("Class weights: %s",
                {LABELS_7CLASS[i]: round(float(class_weights[i]), 3) for i in range(NUM_CLASSES)})

    criterion = FocalLoss(
        weight=torch.tensor(class_weights, dtype=torch.float32).to(device),
        gamma=2.0,
        label_smoothing=0.05,
    )

    # Optimizer: differential LR
    proj_params = list(encoder.proj.parameters())
    proj_ids = {id(p) for p in proj_params}
    lora_params = [
        param for _, param in encoder.named_parameters()
        if param.requires_grad and id(param) not in proj_ids
    ]

    optimizer = torch.optim.AdamW([
        {"params": lora_params, "lr": args.lora_lr},
        {"params": proj_params, "lr": args.proj_lr},
    ], weight_decay=args.weight_decay)

    # OneCycleLR scheduler. The loop below also steps on the final partial
    # accumulation group of each epoch, so optimizer steps per epoch is the
    # CEILING of batches / accumulate_steps. Floor division under-counts and
    # makes OneCycleLR raise once it is stepped past total_steps.
    steps_per_epoch = max((num_batches + accumulate_steps - 1) // accumulate_steps, 1)
    total_steps = steps_per_epoch * args.epochs
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=[args.lora_lr, args.proj_lr],
        total_steps=total_steps,
        pct_start=0.1,
        anneal_strategy="cos",
    )

    # Output dir
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # AMP scaler
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    # Training state
    training_log = []
    best_f1 = 0.0
    patience_counter = 0
    start_epoch = 1
    disgust_gate = 0.3

    # Resume from checkpoint if requested
    if args.resume:
        resume_path = output_dir / "last_lora.pt"
        if resume_path.exists():
            resume_state = load_lora_checkpoint(
                encoder, resume_path, device,
                optimizer=optimizer, scheduler=scheduler, scaler=scaler,
            )
            start_epoch = resume_state["epoch"] + 1
            best_f1 = resume_state["best_f1"]
            patience_counter = resume_state["patience_counter"]
            training_log = resume_state["training_log"]
            logger.info("Resuming training from epoch %d (best_f1=%.4f)", start_epoch, best_f1)
        else:
            logger.warning("--resume set but no checkpoint found at %s, starting fresh", resume_path)

    # The set of trainable parameters never changes during training, so build
    # the list used for gradient clipping once instead of on every step.
    trainable_params = [p for p in encoder.parameters() if p.requires_grad]

    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start = time.time()
        encoder.train()
        total_loss = 0
        correct = 0
        total_samples = 0
        optimizer.zero_grad()

        for batch_idx, (waveforms, labels) in enumerate(train_loader):
            labels = labels.to(device)

            with torch.amp.autocast("cuda", enabled=use_amp):
                logits = forward_pass(encoder, waveforms, device)
                # Scale so accumulated gradients average over the group.
                loss = criterion(logits, labels) / accumulate_steps

            scaler.scale(loss).backward()

            is_group_end = (batch_idx + 1) % accumulate_steps == 0
            is_last_batch = (batch_idx + 1) == num_batches
            if is_group_end or is_last_batch:
                # Unscale first so clipping sees true gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            # Undo the accumulation scaling for reporting.
            total_loss += loss.item() * accumulate_steps * labels.size(0)
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total_samples += labels.size(0)

            if (batch_idx + 1) % 10 == 0:
                cur_lr = optimizer.param_groups[0]["lr"]
                logger.info(" Epoch %d [%d/%d] loss=%.4f lr=%.2e",
                            epoch, batch_idx + 1, num_batches,
                            loss.item() * accumulate_steps, cur_lr)

        train_loss = total_loss / max(total_samples, 1)
        train_acc = correct / max(total_samples, 1)

        # Validate
        val_metrics = validate(encoder, val_loader, device, criterion)
        epoch_time = time.time() - epoch_start

        logger.info(
            "Epoch %d/%d (%.0fs): train_loss=%.4f train_acc=%.3f | "
            "val_loss=%.4f val_f1=%.3f val_acc=%.3f disgust_f1=%.3f",
            epoch, args.epochs, epoch_time, train_loss, train_acc,
            val_metrics["loss"], val_metrics["macro_f1"], val_metrics["accuracy"],
            val_metrics["disgust_f1"],
        )
        logger.info(" Per-class F1: %s", val_metrics["per_class_f1"])

        epoch_log = {
            "epoch": epoch,
            "train_loss": round(train_loss, 4),
            "train_acc": round(train_acc, 4),
            "val_loss": val_metrics["loss"],
            "val_accuracy": val_metrics["accuracy"],
            "val_macro_f1": val_metrics["macro_f1"],
            "val_disgust_f1": val_metrics["disgust_f1"],
            "val_per_class_f1": val_metrics["per_class_f1"],
            "epoch_time_sec": round(epoch_time, 1),
        }
        training_log.append(epoch_log)

        # Disgust F1 Gate + best model save
        gate_pass = val_metrics["disgust_f1"] >= disgust_gate
        if not gate_pass:
            logger.warning("[GATE FAIL] disgust_f1=%.3f < %.1f — not saving as best",
                           val_metrics["disgust_f1"], disgust_gate)

        if val_metrics["macro_f1"] > best_f1 and gate_pass:
            best_f1 = val_metrics["macro_f1"]
            patience_counter = 0
            save_lora_checkpoint(
                encoder, output_dir / "best_lora.pt", epoch, val_metrics, best_f1,
                optimizer=optimizer, scheduler=scheduler, scaler=scaler,
                training_log=training_log,
            )
            logger.info(" New best! macro_f1=%.4f (gate passed) saved to best_lora.pt", best_f1)

            if "confusion_matrix" in val_metrics:
                cm = np.array(val_metrics["confusion_matrix"])
                plot_confusion_matrix(cm, output_dir / f"confusion_matrix_epoch{epoch}.png", epoch)
        elif gate_pass:
            patience_counter += 1
        # Gate fail does not increment patience

        # Always save last (with full state for resume)
        save_lora_checkpoint(
            encoder, output_dir / "last_lora.pt", epoch, val_metrics, best_f1,
            patience_counter,
            optimizer=optimizer, scheduler=scheduler, scaler=scaler,
            training_log=training_log,
        )

        # Save training log
        with open(output_dir / "training_log.json", "w", encoding="utf-8") as f:
            json.dump(training_log, f, indent=2)

        # Early stopping (only on gate-passing epochs)
        if patience_counter >= args.patience:
            logger.info("Early stopping at epoch %d (patience=%d)", epoch, args.patience)
            break

        if device == "cuda":
            torch.cuda.empty_cache()

    # Save config
    config = {
        "model": "iic/emotion2vec_plus_base",
        "method": "LoRA",
        "lora_r": args.lora_r,
        "lora_alpha": args.lora_alpha,
        "num_classes": NUM_CLASSES,
        "labels": LABELS_7CLASS,
        "label2idx": LABEL2IDX,
        "best_val_f1": best_f1,
        "training_args": vars(args),
    }
    with open(output_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)

    best_path = output_dir / "best_lora.pt"
    if not best_path.exists():
        logger.warning("No epoch passed the disgust F1 gate — %s was not written", best_path)
    logger.info("Training complete. Best F1=%.4f at %s", best_f1, best_path)


# ──────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────

def main():
    """Parse command-line arguments and launch training."""
    parser = argparse.ArgumentParser(description="LoRA fine-tune emotion2vec 7-class")
    parser.add_argument("--train-manifest", required=True)
    parser.add_argument("--val-manifest", required=True)
    parser.add_argument("--output-dir", default="data/models/lora_emotion2vec_7class")
    parser.add_argument("--device", default="cpu", choices=["cpu", "cuda"])
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--accumulate-steps", type=int, default=8)
    parser.add_argument("--lora-r", type=int, default=16)
    parser.add_argument("--lora-alpha", type=int, default=32)
    parser.add_argument("--lora-dropout", type=float, default=0.1)
    parser.add_argument("--lora-lr", type=float, default=2e-4)
    parser.add_argument("--proj-lr", type=float, default=2e-3)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--patience", type=int, default=7)
    parser.add_argument("--phone-augment-prob", type=float, default=0.3)
    parser.add_argument("--noise-augment-prob", type=float, default=0.15)
    parser.add_argument("--resume", action="store_true",
                        help="Resume from last_lora.pt checkpoint in output-dir")
    args = parser.parse_args()

    # Both values are used as divisors / loader sizes; reject them up front
    # with a usage error instead of failing deep inside training.
    if args.accumulate_steps < 1:
        parser.error("--accumulate-steps must be >= 1")
    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")

    train(args)


if __name__ == "__main__":
    main()