#!/usr/bin/env python3
r"""ONNX export for LoRA-finetuned emotion models.

Two modes are supported:

* ``audio``: emotion2vec 7-class model with custom LoRA adapters. The LoRA
  weights are merged into the base encoder, the encoder is wrapped as a single
  waveform-to-logits module, and the result is exported to ONNX with dynamic
  batch/time axes.
* ``text``: KcELECTRA with a PEFT LoRA adapter. The adapter is merged into the
  base model and exported as an (input_ids, attention_mask) -> logits graph
  with dynamic batch/sequence axes.

Usage:
    python scripts/export_lora_onnx.py --mode audio \
        --checkpoint data/models/lora_emotion2vec_7class/best_lora.pt \
        --output data/models/lora_emotion2vec_7class/emotion2vec_lora.onnx \
        --device cpu

    python scripts/export_lora_onnx.py --mode text \
        --model-dir <path/to/peft_adapter_dir> \
        --output <path/to/kcelectra_lora.onnx>
"""
from __future__ import annotations

import argparse
import json
import logging
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

# Import LoRA components from the sibling training script. This file is run as
# a script, so its own directory is put on sys.path explicitly.
sys.path.insert(0, str(Path(__file__).resolve().parent))
from train_lora_emotion2vec import (  # noqa: E402
    LoRALinear,
    MLPHead,
    inject_lora,
    merge_lora_linear,
    LABELS_7CLASS,
    NUM_CLASSES,
)

# Label order written to the text-model metadata; its length is also used as
# num_labels when building the classification head.
KCELECTRA_LABELS = ["happiness", "anger", "disgust", "fear", "neutral", "sadness", "surprise"]


def merge_all_lora(encoder: nn.Module) -> None:
    """Replace LoRA-wrapped attention projections with merged ``nn.Linear`` layers.

    Only ``block.attn.qkv`` and ``block.attn.proj`` of each entry in
    ``encoder.blocks`` are inspected. The merged replacement comes from
    ``merge_lora_linear`` in the training script. The encoder is modified in place.
    """
    for block in encoder.blocks:
        attn = block.attn
        if isinstance(attn.qkv, LoRALinear):
            attn.qkv = merge_lora_linear(attn.qkv)
        if isinstance(attn.proj, LoRALinear):
            attn.proj = merge_lora_linear(attn.proj)


class Emotion2vecONNXWrapper(nn.Module):
    """Waveform-to-logits wrapper around an emotion2vec encoder for ONNX export.

    ``forward(waveform: (B, T)) -> logits: (B, num_classes)``

    The forward pass runs these steps in order:

    1. Optional per-sample normalization, when ``encoder.cfg.normalize`` is set.
    2. ``encoder.extract_features``.
    3. Mean pooling over the frame axis.
    4. The ``encoder.proj`` classification head.
    """

    def __init__(self, encoder: nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: (B, T) float32, 16 kHz mono audio.

        Returns:
            logits: (B, num_classes) produced by ``encoder.proj``.
        """
        # Per-sample normalization along the time axis, written with explicit
        # mean/var ops to stay ONNX-compatible. For each row this equals a
        # layer norm without affine parameters (biased variance, eps=1e-5).
        if self.encoder.cfg.normalize:
            mean = waveform.mean(dim=-1, keepdim=True)
            var = waveform.var(dim=-1, keepdim=True, unbiased=False)
            waveform = (waveform - mean) / torch.sqrt(var + 1e-5)

        feats = self.encoder.extract_features(waveform, padding_mask=None)
        x = feats["x"]  # (B, T', 768) frame features

        pooled = x.mean(dim=1)  # average over frames -> (B, 768)
        return self.encoder.proj(pooled)


def _load_lora_weights(encoder: nn.Module, lora_weights: dict) -> None:
    """Copy checkpoint LoRA A/B matrices into every ``LoRALinear`` of *encoder*.

    Keys are expected as ``"<module name>.lora_A.weight"`` and
    ``"<module name>.lora_B.weight"``. Missing keys are logged as a warning.

    Raises:
        RuntimeError: if the encoder has LoRA modules but not a single checkpoint
            tensor matched them. Exporting in that state would silently drop the
            fine-tuned adaptation.
    """
    loaded = 0
    missing: list[str] = []
    for name, module in encoder.named_modules():
        if not isinstance(module, LoRALinear):
            continue
        for attr in ("lora_A", "lora_B"):
            key = f"{name}.{attr}.weight"
            tensor = lora_weights.get(key)
            if tensor is None:
                missing.append(key)
                continue
            getattr(module, attr).weight.data.copy_(tensor)
            loaded += 1

    if missing and loaded == 0:
        raise RuntimeError(
            f"No LoRA weights matched the injected modules (e.g. {missing[0]!r} "
            f"not in checkpoint); checkpoint keys and module names disagree"
        )
    if missing:
        logger.warning("%d LoRA tensors missing from checkpoint: %s", len(missing), missing)
    logger.info("Loaded %d LoRA tensors", loaded)


def export_onnx(
    checkpoint_path: str,
    output_path: str,
    device: str = "cpu",
    *,
    lora_r: int = 16,
    lora_alpha: int = 32,
) -> None:
    """Load base emotion2vec, inject LoRA, load checkpoint, merge, export ONNX.

    Args:
        checkpoint_path: path to the LoRA checkpoint (.pt). It must contain
            ``"lora_weights"`` and ``"proj"``. ``"num_classes"`` and
            ``"labels"`` are optional.
        output_path: path for the output ONNX file. A ``.json`` metadata file
            with the same stem is written next to it.
        device: "cpu" or "cuda".
        lora_r: LoRA rank used during training. Must match the checkpoint.
        lora_alpha: LoRA alpha used during training. Must match the checkpoint.

    Raises:
        RuntimeError: if no checkpoint LoRA tensor matches the injected modules,
            or if any ``LoRALinear`` survives the merge.
    """
    from funasr import AutoModel

    checkpoint_path = Path(checkpoint_path)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    logger.info("Loading checkpoint: %s", checkpoint_path)
    ckpt = torch.load(str(checkpoint_path), map_location=device, weights_only=True)

    logger.info("Loading emotion2vec_plus_base...")
    fmodel = AutoModel(model="iic/emotion2vec_plus_base", device=device, hub="hf")
    encoder = fmodel.model

    # Freeze the base, then inject LoRA with dropout disabled (inference only).
    for param in encoder.parameters():
        param.requires_grad = False
    inject_lora(encoder, r=lora_r, alpha=lora_alpha, dropout=0.0)

    # Replace the stock projection with the trained classification head.
    num_classes = ckpt.get("num_classes", NUM_CLASSES)
    encoder.proj = MLPHead(768, num_classes, dropout=0.0).to(device)

    _load_lora_weights(encoder, ckpt["lora_weights"])
    encoder.proj.load_state_dict(ckpt["proj"])

    logger.info("Merging LoRA weights...")
    merge_all_lora(encoder)
    # Explicit check (not assert) so it still runs under `python -O`.
    lora_count = sum(1 for m in encoder.modules() if isinstance(m, LoRALinear))
    if lora_count:
        raise RuntimeError(f"Merge failed: {lora_count} LoRALinear remain")

    wrapper = Emotion2vecONNXWrapper(encoder)
    wrapper.eval()

    # Trace with one second of 16 kHz audio; batch and time are declared dynamic.
    dummy_input = torch.randn(1, 16000, device=device)

    logger.info("Exporting ONNX to %s ...", output_path)
    with torch.no_grad():
        torch.onnx.export(
            wrapper,
            dummy_input,
            str(output_path),
            opset_version=17,
            input_names=["waveform"],
            output_names=["logits"],
            dynamic_axes={
                "waveform": {0: "batch", 1: "time"},
                "logits": {0: "batch"},
            },
        )
    logger.info("ONNX export complete: %s", output_path)

    # Save label metadata JSON alongside the ONNX file.
    meta_path = output_path.with_suffix(".json")
    meta = {
        "model": "emotion2vec_plus_base + LoRA (merged)",
        "num_classes": num_classes,
        "labels": ckpt.get("labels", LABELS_7CLASS),
        "input": "waveform: (batch, time) float32, 16kHz mono",
        "output": f"logits: (batch, {num_classes}) float32",
        "checkpoint": str(checkpoint_path),
    }
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
    logger.info("Metadata saved: %s", meta_path)


def export_kcelectra_onnx(
    model_dir: str,
    output_path: str,
    base_model_id: str = "beomi/KcELECTRA-base-v2022",
    device: str = "cpu",
) -> None:
    """Export LoRA-finetuned KcELECTRA to ONNX.

    Merges the PEFT LoRA adapter into the base model, then exports an
    (input_ids, attention_mask) -> logits ONNX graph. A ``.json`` metadata file
    is written next to the ONNX file.

    Args:
        model_dir: directory holding the PEFT adapter and the tokenizer.
        output_path: path for the output ONNX file.
        base_model_id: Hugging Face id of the base sequence-classification model.
        device: "cpu" or "cuda"; the device the model is traced on.
    """
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    from peft import PeftModel

    output_path = Path(output_path)
    model_dir = Path(model_dir)
    num_labels = len(KCELECTRA_LABELS)
    max_length = 128

    logger.info("Loading base model: %s", base_model_id)
    base_model = AutoModelForSequenceClassification.from_pretrained(
        base_model_id, num_labels=num_labels,
    )
    tokenizer = AutoTokenizer.from_pretrained(str(model_dir))

    logger.info("Loading PEFT adapter: %s", model_dir)
    model = PeftModel.from_pretrained(base_model, str(model_dir))

    logger.info("Merging LoRA weights...")
    model = model.merge_and_unload().to(device)
    model.eval()

    # Dummy input; the Korean text means "This is a test sentence."
    dummy = tokenizer("테스트 문장입니다", return_tensors="pt",
                      max_length=max_length, truncation=True, padding="max_length")
    input_ids = dummy["input_ids"].to(device)
    attention_mask = dummy["attention_mask"].to(device)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info("Exporting ONNX to %s", output_path)
    with torch.no_grad():
        torch.onnx.export(
            model,
            (input_ids, attention_mask),
            str(output_path),
            opset_version=17,
            input_names=["input_ids", "attention_mask"],
            output_names=["logits"],
            dynamic_axes={
                "input_ids": {0: "batch", 1: "seq_len"},
                "attention_mask": {0: "batch", 1: "seq_len"},
                "logits": {0: "batch"},
            },
        )

    meta_path = output_path.with_suffix(".json")
    meta = {
        "model": f"{base_model_id} + LoRA (merged)",
        "num_classes": num_labels,
        "labels": KCELECTRA_LABELS,
        "input": "input_ids: (batch, seq_len) int64, attention_mask: (batch, seq_len) int64",
        "output": f"logits: (batch, {num_labels}) float32",
        "max_length": max_length,
    }
    # ensure_ascii=False may emit non-ASCII text, so pin the file encoding.
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)
    logger.info("ONNX export complete: %s (+ %s)", output_path, meta_path)


def main() -> None:
    """Parse CLI arguments and dispatch to the audio or text exporter."""
    parser = argparse.ArgumentParser(description="Export LoRA models to ONNX")
    parser.add_argument("--mode", required=True, choices=["audio", "text"],
                        help="audio: emotion2vec, text: KcELECTRA")
    # Audio args
    parser.add_argument("--checkpoint",
                        help="Path to LoRA checkpoint (.pt) for audio mode")
    parser.add_argument("--lora-r", type=int, default=16,
                        help="LoRA rank used during training (audio mode)")
    parser.add_argument("--lora-alpha", type=int, default=32,
                        help="LoRA alpha used during training (audio mode)")
    # Text args
    parser.add_argument("--model-dir",
                        help="Path to PEFT model directory for text mode")
    parser.add_argument("--base-model", default="beomi/KcELECTRA-base-v2022",
                        help="Base model id for text mode")
    # Common
    parser.add_argument("--output", required=True, help="Output ONNX path")
    parser.add_argument("--device", default="cpu", choices=["cpu", "cuda"],
                        help="Device used to load and trace the model")
    args = parser.parse_args()

    if args.mode == "audio":
        if not args.checkpoint:
            parser.error("--checkpoint required for audio mode")
        export_onnx(args.checkpoint, args.output, args.device,
                    lora_r=args.lora_r, lora_alpha=args.lora_alpha)
    else:
        if not args.model_dir:
            parser.error("--model-dir required for text mode")
        export_kcelectra_onnx(args.model_dir, args.output, args.base_model,
                              device=args.device)


if __name__ == "__main__":
    main()