ustwo-api / scripts /export_lora_onnx.py
asdfasdfqrqwer's picture
Deploy from GitHub 2026-04-23T03:56:31Z
c857b85
Raw
History Blame Contribute Delete
9.01 kB
#!/usr/bin/env python3
"""ONNX export for LoRA-finetuned emotion2vec 7-class model.
Merges LoRA weights into base model, wraps as a single waveform-to-logits
module, and exports to ONNX with dynamic batch/time axes.
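Usage:
    # Audio model (emotion2vec + LoRA). --mode is required by the CLI.
    python scripts/export_lora_onnx.py \
        --mode audio \
        --checkpoint data/models/lora_emotion2vec_7class/best_lora.pt \
        --output data/models/lora_emotion2vec_7class/emotion2vec_lora.onnx \
        --device cpu

    # Text model (KcELECTRA + PEFT LoRA adapter).
    python scripts/export_lora_onnx.py \
        --mode text \
        --model-dir <peft_adapter_dir> \
        --output <output.onnx>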
"""
from __future__ import annotations
import argparse
import json
import logging
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
# Import LoRA components
import sys
sys.path.insert(0, str(Path(__file__).resolve().parent))
from train_lora_emotion2vec import (
LoRALinear,
MLPHead,
inject_lora,
merge_lora_linear,
LABELS_7CLASS,
NUM_CLASSES,
)
def merge_all_lora(encoder: nn.Module) -> None:
    """Fold the attention LoRA adapters of every encoder block into plain layers.

    For each block in ``encoder.blocks`` the ``qkv`` and ``proj`` attributes of
    ``block.attn`` are inspected; whichever of them is a ``LoRALinear`` is
    replaced by the module returned from ``merge_lora_linear``.

    Args:
        encoder: model exposing an iterable ``blocks`` attribute whose items
            have an ``attn`` submodule with ``qkv`` and ``proj`` attributes.

    Returns:
        None. The encoder is mutated in place.
    """
    for layer in encoder.blocks:
        attention = layer.attn
        # Same order as the attention data flow: fused qkv first, then output proj.
        for attr_name in ("qkv", "proj"):
            candidate = getattr(attention, attr_name)
            if isinstance(candidate, LoRALinear):
                setattr(attention, attr_name, merge_lora_linear(candidate))
class Emotion2vecONNXWrapper(nn.Module):
    """Single-module waveform-to-logits adapter around an emotion2vec encoder.

    The forward pass chains: optional per-sample waveform normalization,
    ``encoder.extract_features``, mean pooling over the frame axis, and the
    encoder's ``proj`` classification head.
    """

    def __init__(self, encoder: nn.Module):
        """Store the encoder whose ``cfg``, ``extract_features`` and ``proj`` are used."""
        super().__init__()
        self.encoder = encoder

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Map a batch of raw waveforms to class logits.

        Args:
            waveform: (B, T) float32 audio, 16kHz.

        Returns:
            logits: (B, num_classes) tensor produced by ``encoder.proj``.
        """
        if self.encoder.cfg.normalize:
            # Hand-written normalization over the time axis (zero mean, unit
            # biased variance, eps=1e-5) so it is expressed with basic ops.
            mu = waveform.mean(dim=-1, keepdim=True)
            sigma_sq = waveform.var(dim=-1, keepdim=True, unbiased=False)
            centered = waveform - mu
            waveform = centered / torch.sqrt(sigma_sq + 1e-5)

        # Frame-level features live under the "x" key: (B, T', 768).
        frame_feats = self.encoder.extract_features(waveform, padding_mask=None)["x"]

        # Collapse the frame axis to one utterance-level vector per sample: (B, 768).
        utterance_vec = torch.mean(frame_feats, dim=1)

        # Classification head.
        return self.encoder.proj(utterance_vec)
def export_onnx(checkpoint_path: str, output_path: str, device: str = "cpu"):
    """Export the LoRA-finetuned emotion2vec classifier to a single ONNX file.

    Pipeline: load the checkpoint, load ``iic/emotion2vec_plus_base`` through
    funasr, inject LoRA layers, swap the projection for an ``MLPHead``, copy the
    checkpoint's LoRA and head weights in, merge LoRA into the base weights,
    wrap the encoder as waveform -> logits and export it with dynamic
    batch/time axes. A JSON file with label metadata is written next to the
    ONNX file (same path, ``.json`` suffix).

    The checkpoint is read for the keys ``"lora_weights"`` and ``"proj"``
    (required) and ``"num_classes"`` / ``"labels"`` (optional, falling back to
    ``NUM_CLASSES`` / ``LABELS_7CLASS``).

    Args:
        checkpoint_path: path to LoRA checkpoint (.pt)
        output_path: path for output ONNX file
        device: "cpu" or "cuda"

    Raises:
        RuntimeError: if any ``LoRALinear`` module is still present after the
            merge step.
    """
    from funasr import AutoModel

    checkpoint_path = Path(checkpoint_path)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Load checkpoint to get config
    logger.info("Loading checkpoint: %s", checkpoint_path)
    ckpt = torch.load(str(checkpoint_path), map_location=device, weights_only=True)

    # Load base model
    logger.info("Loading emotion2vec_plus_base...")
    fmodel = AutoModel(model="iic/emotion2vec_plus_base", device=device, hub="hf")
    encoder = fmodel.model

    # Freeze + inject LoRA (dropout=0 for inference).
    # NOTE(review): r=16 / alpha=32 are hard-coded here and presumably must
    # match the values used at training time - confirm against the trainer.
    for param in encoder.parameters():
        param.requires_grad = False
    inject_lora(encoder, r=16, alpha=32, dropout=0.0)

    # Replace proj with MLPHead
    num_classes = ckpt.get("num_classes", NUM_CLASSES)
    encoder.proj = MLPHead(768, num_classes, dropout=0.0).to(device)

    # Load LoRA weights. Tensors absent from the checkpoint are left at their
    # initial values, but that is reported instead of being skipped silently:
    # a name mismatch would otherwise export the un-adapted base model.
    lora_weights = ckpt["lora_weights"]
    missing_keys = []
    for name, module in encoder.named_modules():
        if not isinstance(module, LoRALinear):
            continue
        for part in ("lora_A", "lora_B"):
            key = f"{name}.{part}.weight"
            if key in lora_weights:
                getattr(module, part).weight.data.copy_(lora_weights[key])
            else:
                missing_keys.append(key)
    if missing_keys:
        logger.warning(
            "%d LoRA tensor(s) not found in checkpoint; left at initial values: %s",
            len(missing_keys),
            missing_keys,
        )

    # Load proj state
    encoder.proj.load_state_dict(ckpt["proj"])

    # Merge LoRA into base weights
    logger.info("Merging LoRA weights...")
    merge_all_lora(encoder)

    # Verify no LoRALinear remains. An explicit raise (not ``assert``) so the
    # check still runs under ``python -O``.
    lora_count = sum(1 for m in encoder.modules() if isinstance(m, LoRALinear))
    if lora_count != 0:
        raise RuntimeError(f"Merge failed: {lora_count} LoRALinear remain")

    # Wrap for ONNX
    wrapper = Emotion2vecONNXWrapper(encoder)
    wrapper.eval()

    # Dummy input (1 second of audio at 16kHz)
    dummy_input = torch.randn(1, 16000, device=device)

    # Export
    logger.info("Exporting ONNX to %s ...", output_path)
    torch.onnx.export(
        wrapper,
        dummy_input,
        str(output_path),
        opset_version=17,
        input_names=["waveform"],
        output_names=["logits"],
        dynamic_axes={
            "waveform": {0: "batch", 1: "time"},
            "logits": {0: "batch"},
        },
    )
    logger.info("ONNX export complete: %s", output_path)

    # Save label metadata JSON alongside
    meta_path = output_path.with_suffix(".json")
    meta = {
        "model": "emotion2vec_plus_base + LoRA (merged)",
        "num_classes": num_classes,
        "labels": ckpt.get("labels", LABELS_7CLASS),
        "input": "waveform: (batch, time) float32, 16kHz mono",
        "output": f"logits: (batch, {num_classes}) float32",
        "checkpoint": str(checkpoint_path),
    }
    # Explicit encoding keeps the file identical across platforms.
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
    logger.info("Metadata saved: %s", meta_path)
def export_kcelectra_onnx(model_dir: str, output_path: str, base_model_id: str = "beomi/KcELECTRA-base-v2022"):
    """Export LoRA-finetuned KcELECTRA to ONNX.

    Merges PEFT LoRA into base model, then exports text → logits ONNX with
    dynamic batch and sequence-length axes. A JSON metadata file (labels,
    input/output description, max_length) is written next to the ONNX file
    using the same path with a ``.json`` suffix.

    Args:
        model_dir: directory holding the PEFT adapter; the tokenizer is also
            loaded from this directory.
        output_path: destination path of the ONNX file; parent directories are
            created if needed.
        base_model_id: identifier of the base sequence-classification model the
            adapter is applied to.
    """
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    from peft import PeftModel

    # Single source of truth for the class count and the tokenizer length used
    # for the tracing input; both are also recorded in the metadata below.
    labels = ["happiness", "anger", "disgust", "fear", "neutral", "sadness", "surprise"]
    num_labels = len(labels)
    max_length = 128

    output_path = Path(output_path)
    model_dir = Path(model_dir)

    logger.info("Loading base model: %s", base_model_id)
    base_model = AutoModelForSequenceClassification.from_pretrained(
        base_model_id, num_labels=num_labels,
    )
    tokenizer = AutoTokenizer.from_pretrained(str(model_dir))

    logger.info("Loading PEFT adapter: %s", model_dir)
    model = PeftModel.from_pretrained(base_model, str(model_dir))

    logger.info("Merging LoRA weights...")
    model = model.merge_and_unload()
    model.eval()

    # Dummy input (Korean for "This is a test sentence"); it is padded to
    # max_length, so only its tensor shape matters for tracing.
    dummy = tokenizer(
        "테스트 문장입니다",
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )

    output_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info("Exporting ONNX to %s", output_path)
    torch.onnx.export(
        model,
        (dummy["input_ids"], dummy["attention_mask"]),
        str(output_path),
        opset_version=17,
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "seq_len"},
            "attention_mask": {0: "batch", 1: "seq_len"},
            "logits": {0: "batch"},
        },
    )

    # Save metadata
    meta_path = output_path.with_suffix(".json")
    meta = {
        "model": f"{base_model_id} + LoRA (merged)",
        "num_classes": num_labels,
        "labels": labels,
        "input": "input_ids: (batch, seq_len) int64, attention_mask: (batch, seq_len) int64",
        "output": f"logits: (batch, {num_labels}) float32",
        "max_length": max_length,
    }
    # ensure_ascii=False may emit non-ASCII characters, so the file encoding is
    # pinned to UTF-8 rather than left to the platform default.
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)
    logger.info("ONNX export complete: %s (+ %s)", output_path, meta_path)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser covering both export modes."""
    cli = argparse.ArgumentParser(description="Export LoRA models to ONNX")
    cli.add_argument(
        "--mode",
        required=True,
        choices=["audio", "text"],
        help="audio: emotion2vec, text: KcELECTRA",
    )
    # Audio-only option
    cli.add_argument("--checkpoint", help="Path to LoRA checkpoint (.pt) for audio mode")
    # Text-only options
    cli.add_argument("--model-dir", help="Path to PEFT model directory for text mode")
    cli.add_argument("--base-model", default="beomi/KcELECTRA-base-v2022")
    # Options shared by both modes
    cli.add_argument("--output", required=True, help="Output ONNX path")
    cli.add_argument("--device", default="cpu", choices=["cpu", "cuda"])
    return cli


def main():
    """CLI entry point: parse arguments and run the exporter selected by --mode.

    Exits through ``parser.error`` when the option required by the chosen mode
    (``--checkpoint`` for audio, ``--model-dir`` for text) is missing.
    """
    cli = _build_arg_parser()
    opts = cli.parse_args()

    if opts.mode != "audio":
        # The only other accepted choice is "text".
        if not opts.model_dir:
            cli.error("--model-dir required for text mode")
        export_kcelectra_onnx(opts.model_dir, opts.output, opts.base_model)
        return

    if not opts.checkpoint:
        cli.error("--checkpoint required for audio mode")
    export_onnx(opts.checkpoint, opts.output, opts.device)


if __name__ == "__main__":
    main()