| import argparse |
| import shutil |
| import subprocess |
| from pathlib import Path |
|
|
| import onnx |
| import torch |
| from transformers import AutoModel, AutoTokenizer |
|
|
|
|
class VitsExportWrapper(torch.nn.Module):
    """Adapter that exposes the HF VITS model as plain tensor-in / tensor-out.

    ONNX tracing needs positional tensor inputs and a single tensor output;
    this wrapper flattens the keyword-style HF call and unwraps the result.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        # Freeze dropout / norm statistics so the traced graph is inference-only.
        self.model = model.eval()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        speaker_id: torch.Tensor,
        emotion_id: torch.Tensor,
    ) -> torch.Tensor:
        """Synthesize audio and return only the waveform tensor."""
        # Control IDs are coerced to int64 so callers may pass any integer dtype.
        call_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "speaker_id": speaker_id.to(torch.long),
            "emotion_id": emotion_id.to(torch.long),
            "return_dict": True,
        }
        return self.model(**call_kwargs).waveform
|
|
|
|
def inspect_onnx(onnx_path: Path) -> None:
    """Print every graph input of the exported ONNX model with its shape/dtype."""
    graph = onnx.load(str(onnx_path)).graph
    print("onnx inputs:")
    for graph_input in graph.input:
        tensor_type = graph_input.type.tensor_type
        # A dim carries either a fixed size, a symbolic name, or nothing ("?").
        dims = [
            dim.dim_value if dim.dim_value else (dim.dim_param if dim.dim_param else "?")
            for dim in tensor_type.shape.dim
        ]
        print(f" {graph_input.name}: shape={dims}, elem_type={tensor_type.elem_type}")
|
|
|
|
def inspect_mnn(mnn_path: Path) -> None:
    """Load the converted MNN graph and report each expected control input."""
    # Imported lazily so the script works without MNN when conversion is skipped.
    import MNN.expr as expr

    variables = expr.load_as_dict(str(mnn_path))
    for input_name in ("input_ids", "attention_mask", "speaker_id", "emotion_id"):
        if input_name in variables:
            var = variables[input_name]
            print(f"mnn input {input_name}: shape={var.shape}, dtype={var.dtype}, format={var.data_format}")
        else:
            print(f"mnn input missing: {input_name}")
|
|
|
|
def export_onnx(args: argparse.Namespace) -> None:
    """Trace the local checkpoint with sample inputs and write an ONNX graph.

    Loads tokenizer + model strictly from ``args.model_dir`` (no hub access),
    tokenizes ``args.text`` as the tracing example, and exports with dynamic
    text/audio length axes so other inputs can be used at inference time.
    """
    model_dir = Path(args.model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True)
    model = AutoModel.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True)

    encoding = tokenizer(text=args.text, return_tensors="pt")
    ids = encoding["input_ids"].to(torch.long)
    # Some tokenizers omit the mask; fall back to an all-ones mask of matching shape.
    mask = encoding.get("attention_mask", torch.ones_like(ids)).to(torch.long)
    sample_inputs = (
        ids,
        mask,
        torch.tensor([args.speaker_id], dtype=torch.long),
        torch.tensor([args.style_id], dtype=torch.long),
    )

    torch.onnx.export(
        VitsExportWrapper(model),
        sample_inputs,
        str(args.onnx_output),
        input_names=["input_ids", "attention_mask", "speaker_id", "emotion_id"],
        output_names=["waveform"],
        dynamic_axes={
            "input_ids": {1: "text_length"},
            "attention_mask": {1: "text_length"},
            "waveform": {1: "audio_length"},
        },
        opset_version=args.opset,
        do_constant_folding=True,
        dynamo=False,
    )
    print(f"wrote {args.onnx_output}")
    inspect_onnx(Path(args.onnx_output))
|
|
|
|
def convert_to_mnn(args: argparse.Namespace) -> None:
    """Convert the exported ONNX file to MNN via the MNNConvert CLI.

    Silently does nothing when ``args.mnn_output`` is empty (conversion is
    optional); raises ``FileNotFoundError`` if the converter binary is absent,
    and ``subprocess.CalledProcessError`` if the conversion itself fails.
    """
    if not args.mnn_output:
        # Empty/None target path means the caller opted out of MNN conversion.
        return

    converter = shutil.which("MNNConvert")
    if converter is None:
        raise FileNotFoundError("MNNConvert not found in PATH")

    # List-form argv (shell=False) avoids any quoting issues in the paths.
    subprocess.run(
        [
            converter,
            "-f",
            "ONNX",
            "--modelFile",
            str(args.onnx_output),
            "--MNNModel",
            str(args.mnn_output),
            "--bizCode",
            "MNN",
        ],
        check=True,
    )
    print(f"wrote {args.mnn_output}")
    inspect_mnn(Path(args.mnn_output))
|
|
|
|
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Build the CLI parser and evaluate the arguments.

    Args:
        argv: Optional explicit argument list. When ``None`` (the default,
            matching the original behavior) arguments are read from
            ``sys.argv[1:]``. Passing a list makes the function testable
            and reusable without touching process-global state.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Re-export local VITS weights with speaker/style control inputs.")
    parser.add_argument("--model-dir", default=".", help="Directory containing config, tokenizer, custom code, and weights.")
    parser.add_argument("--text", default="வணக்கம்", help="Sample text used to trace the export graph.")
    parser.add_argument("--speaker-id", type=int, default=18, help="Sample speaker ID used during tracing.")
    parser.add_argument("--style-id", type=int, default=0, help="Sample style or emotion ID used during tracing.")
    parser.add_argument("--onnx-output", default="vits_tamil_with_controls.onnx", help="Path for exported ONNX.")
    parser.add_argument("--mnn-output", default="vits_tamil_with_controls.mnn", help="Optional output path for converted MNN.")
    parser.add_argument("--opset", type=int, default=17, help="ONNX opset version.")
    return parser.parse_args(argv)
|
|
|
|
def main() -> None:
    """Run the pipeline: parse the CLI, export ONNX, then optionally convert to MNN."""
    parsed = parse_args()
    export_onnx(parsed)
    convert_to_mnn(parsed)
|
|
|
|
# Run the export pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()