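"""ONNX export for LoRA-finetuned emotion models.

Audio mode: merges LoRA weights into the emotion2vec 7-class base model, wraps
it as a single waveform-to-logits module, and exports to ONNX with dynamic
batch/time axes.

Text mode: merges a PEFT LoRA adapter into KcELECTRA and exports an
(input_ids, attention_mask) -> logits ONNX model.

Usage (``--mode`` is required):
    python scripts/export_lora_onnx.py --mode audio \
        --checkpoint data/models/lora_emotion2vec_7class/best_lora.pt \
        --output data/models/lora_emotion2vec_7class/emotion2vec_lora.onnx \
        --device cpu

    python scripts/export_lora_onnx.py --mode text \
        --model-dir <peft_adapter_dir> \
        --output <output_path>.onnx
"""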
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.INFO)

# Put this script's own directory first on the import path so the sibling
# module `train_lora_emotion2vec` (imported below) can be found.
import sys  # noqa: E402

_SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(_SCRIPT_DIR))
|
|
| from train_lora_emotion2vec import ( |
| LoRALinear, |
| MLPHead, |
| inject_lora, |
| merge_lora_linear, |
| LABELS_7CLASS, |
| NUM_CLASSES, |
| ) |
|
|
|
|
| def merge_all_lora(encoder: nn.Module) -> None: |
| """Walk encoder.blocks, replace each LoRALinear with merged nn.Linear. |
| |
| Modifies the encoder in-place. |
| """ |
| for block in encoder.blocks: |
| if isinstance(block.attn.qkv, LoRALinear): |
| block.attn.qkv = merge_lora_linear(block.attn.qkv) |
| if isinstance(block.attn.proj, LoRALinear): |
| block.attn.proj = merge_lora_linear(block.attn.proj) |
|
|
|
|
class Emotion2vecONNXWrapper(nn.Module):
    """Expose an emotion2vec encoder as a single waveform -> logits module for ONNX.

    The forward pass:
    1. optionally standardizes each waveform, when ``encoder.cfg.normalize`` is set;
    2. runs ``encoder.extract_features`` without a padding mask;
    3. mean-pools the returned ``"x"`` features over dim 1;
    4. applies ``encoder.proj``.

    forward(waveform: (B, T)) -> logits: (B, C), where C is the output size of
    ``encoder.proj``.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Map a batch of raw waveforms to classification logits.

        Args:
            waveform: (B, T) float32 audio, 16kHz.

        Returns:
            logits: (B, C) as produced by ``encoder.proj``.
        """
        enc = self.encoder

        if enc.cfg.normalize:
            # Per-row zero-mean / unit-variance scaling: biased variance,
            # epsilon 1e-5 inside the square root.
            mu = waveform.mean(dim=-1, keepdim=True)
            sigma = torch.sqrt(waveform.var(dim=-1, keepdim=True, unbiased=False) + 1e-5)
            waveform = (waveform - mu) / sigma

        # No padding mask: every sample in the batch is treated as full length.
        # "x" is assumed to be (B, frames, dim); TODO confirm against funasr.
        frames = enc.extract_features(waveform, padding_mask=None)["x"]

        return enc.proj(frames.mean(dim=1))
|
|
|
|
def _load_lora_weights(encoder: nn.Module, lora_weights: dict) -> None:
    """Copy checkpoint LoRA A/B weights into every LoRALinear inside *encoder*.

    Checkpoint keys are looked up as ``"<module name>.lora_A.weight"`` and
    ``"<module name>.lora_B.weight"``. Module names come from
    ``encoder.named_modules()``.

    Adapters without a matching key keep the values ``inject_lora`` gave them.
    Missing keys and unused checkpoint keys are logged rather than silently
    ignored.

    Raises:
        ValueError: if the encoder has LoRA adapters but not one of their keys
            is present in *lora_weights*. Without this check, the export would
            proceed with only the adapters' initial values.
    """
    loaded = set()
    missing = []
    for name, module in encoder.named_modules():
        if not isinstance(module, LoRALinear):
            continue
        for key, target in (
            (f"{name}.lora_A.weight", module.lora_A.weight),
            (f"{name}.lora_B.weight", module.lora_B.weight),
        ):
            if key in lora_weights:
                target.data.copy_(lora_weights[key])
                loaded.add(key)
            else:
                missing.append(key)

    if missing and not loaded:
        raise ValueError(
            f"None of the {len(missing)} expected LoRA tensors were found in the "
            f"checkpoint (e.g. {missing[0]!r}); check key naming and LoRA config."
        )
    if missing:
        logger.warning(
            "%d LoRA tensors missing from checkpoint (left at initial values), e.g. %s",
            len(missing), missing[:3],
        )
    unused = sorted(set(lora_weights) - loaded)
    if unused:
        logger.warning(
            "%d checkpoint LoRA tensors matched no module, e.g. %s",
            len(unused), unused[:3],
        )


def export_onnx(checkpoint_path: str, output_path: str, device: str = "cpu"):
    """Load base emotion2vec, inject LoRA, load checkpoint, merge, export ONNX.

    Also writes a JSON metadata file next to the ONNX file, with the same stem
    and a ``.json`` suffix.

    Args:
        checkpoint_path: path to LoRA checkpoint (.pt). It must contain
            ``"lora_weights"`` and ``"proj"``. It may contain ``"num_classes"``
            and ``"labels"``.
        output_path: path for output ONNX file (parent dirs are created).
        device: "cpu" or "cuda".

    Raises:
        ValueError: no LoRA tensor in the checkpoint matched an injected adapter.
        RuntimeError: LoRALinear modules remain after merging.
    """
    from funasr import AutoModel

    checkpoint_path = Path(checkpoint_path)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    logger.info("Loading checkpoint: %s", checkpoint_path)
    ckpt = torch.load(str(checkpoint_path), map_location=device, weights_only=True)

    logger.info("Loading emotion2vec_plus_base...")
    fmodel = AutoModel(model="iic/emotion2vec_plus_base", device=device, hub="hf")
    encoder = fmodel.model

    # Freeze the base weights, then wrap the target layers with LoRA adapters.
    # r/alpha presumably must match the training run; TODO confirm/persist them
    # in the checkpoint.
    for param in encoder.parameters():
        param.requires_grad = False
    inject_lora(encoder, r=16, alpha=32, dropout=0.0)
    # Freshly created adapter layers may not live on `device` yet. Move the
    # whole encoder so the LoRA merge never mixes devices.
    encoder.to(device)

    # Replace the classification head with the trained MLP head.
    num_classes = ckpt.get("num_classes", NUM_CLASSES)
    encoder.proj = MLPHead(768, num_classes, dropout=0.0).to(device)

    _load_lora_weights(encoder, ckpt["lora_weights"])
    encoder.proj.load_state_dict(ckpt["proj"])

    logger.info("Merging LoRA weights...")
    merge_all_lora(encoder)

    # Explicit check instead of `assert` so it still runs under `python -O`.
    lora_count = sum(1 for m in encoder.modules() if isinstance(m, LoRALinear))
    if lora_count:
        raise RuntimeError(f"Merge failed: {lora_count} LoRALinear remain")

    wrapper = Emotion2vecONNXWrapper(encoder)
    wrapper.eval()

    # One second of noise at 16 kHz is enough to trace; batch and time axes
    # are declared dynamic below.
    dummy_input = torch.randn(1, 16000, device=device)

    logger.info("Exporting ONNX to %s ...", output_path)
    torch.onnx.export(
        wrapper,
        dummy_input,
        str(output_path),
        opset_version=17,
        input_names=["waveform"],
        output_names=["logits"],
        dynamic_axes={
            "waveform": {0: "batch", 1: "time"},
            "logits": {0: "batch"},
        },
    )
    logger.info("ONNX export complete: %s", output_path)

    labels = ckpt.get("labels", LABELS_7CLASS)
    if len(labels) != num_classes:
        logger.warning(
            "Label count (%d) does not match num_classes (%d)", len(labels), num_classes,
        )

    meta_path = output_path.with_suffix(".json")
    meta = {
        "model": "emotion2vec_plus_base + LoRA (merged)",
        "num_classes": num_classes,
        "labels": labels,
        "input": "waveform: (batch, time) float32, 16kHz mono",
        "output": f"logits: (batch, {num_classes}) float32",
        "checkpoint": str(checkpoint_path),
    }
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)
    logger.info("Metadata saved: %s", meta_path)
|
|
|
|
def export_kcelectra_onnx(model_dir: str, output_path: str, base_model_id: str = "beomi/KcELECTRA-base-v2022"):
    """Export LoRA-finetuned KcELECTRA to ONNX.

    Loads ``base_model_id`` with a sequence-classification head sized to the
    label list and applies the PEFT adapter from ``model_dir``. It then merges
    the adapter into the base weights and exports a text -> logits graph with
    inputs ``(input_ids, attention_mask)``. A JSON metadata file with the same
    stem is written next to the ONNX file.

    Args:
        model_dir: PEFT adapter directory; the tokenizer is loaded from here too.
        output_path: destination ONNX path (parent directories are created).
        base_model_id: Hugging Face model id of the base model.
    """
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    from peft import PeftModel

    # Label names written to the metadata. They are presumably index-aligned
    # with the logits; confirm against the training script.
    labels = ["happiness", "anger", "disgust", "fear", "neutral", "sadness", "surprise"]
    max_length = 128

    output_path = Path(output_path)
    model_dir = Path(model_dir)

    logger.info("Loading base model: %s", base_model_id)
    base_model = AutoModelForSequenceClassification.from_pretrained(
        base_model_id, num_labels=len(labels),
    )
    tokenizer = AutoTokenizer.from_pretrained(str(model_dir))

    logger.info("Loading PEFT adapter: %s", model_dir)
    model = PeftModel.from_pretrained(base_model, str(model_dir))

    logger.info("Merging LoRA weights...")
    model = model.merge_and_unload()
    model.eval()

    # Tracing input: a short Korean sentence ("This is a test sentence."),
    # padded to max_length. The batch and seq_len axes are still exported as
    # dynamic below.
    dummy = tokenizer(
        "테스트 문장입니다",
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )

    output_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info("Exporting ONNX to %s", output_path)

    torch.onnx.export(
        model,
        (dummy["input_ids"], dummy["attention_mask"]),
        str(output_path),
        opset_version=17,
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "seq_len"},
            "attention_mask": {0: "batch", 1: "seq_len"},
            "logits": {0: "batch"},
        },
    )

    meta_path = output_path.with_suffix(".json")
    meta = {
        "model": f"{base_model_id} + LoRA (merged)",
        "num_classes": len(labels),
        "labels": labels,
        "input": "input_ids: (batch, seq_len) int64, attention_mask: (batch, seq_len) int64",
        "output": f"logits: (batch, {len(labels)}) float32",
        "max_length": max_length,
    }
    # ensure_ascii=False can emit non-ASCII characters, so pin the encoding
    # instead of relying on the platform default.
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)

    logger.info("ONNX export complete: %s (+ %s)", output_path, meta_path)
|
|
|
|
def main():
    """CLI entry point: parse arguments and dispatch to the audio or text exporter."""
    parser = argparse.ArgumentParser(description="Export LoRA models to ONNX")
    add = parser.add_argument

    add("--mode", required=True, choices=["audio", "text"],
        help="audio: emotion2vec, text: KcELECTRA")
    # Audio-mode option.
    add("--checkpoint", help="Path to LoRA checkpoint (.pt) for audio mode")
    # Text-mode options.
    add("--model-dir", help="Path to PEFT model directory for text mode")
    add("--base-model", default="beomi/KcELECTRA-base-v2022")
    # Shared options (--device is only forwarded to the audio exporter).
    add("--output", required=True, help="Output ONNX path")
    add("--device", default="cpu", choices=["cpu", "cuda"])

    args = parser.parse_args()

    # `choices` restricts mode to "audio" / "text", so anything not "text" is audio.
    if args.mode == "text":
        if not args.model_dir:
            parser.error("--model-dir required for text mode")
        export_kcelectra_onnx(args.model_dir, args.output, args.base_model)
        return

    if not args.checkpoint:
        parser.error("--checkpoint required for audio mode")
    export_onnx(args.checkpoint, args.output, args.device)


if __name__ == "__main__":
    main()
|
|