# vits-tts-mnn / reexport_vits_with_controls.py
# Uploaded by developerabu ("Upload 7 files", commit 6d774ce, verified).
import argparse
import shutil
import subprocess
from pathlib import Path
import onnx
import torch
from transformers import AutoModel, AutoTokenizer
class VitsExportWrapper(torch.nn.Module):
    """Adapt a VITS model to a flat, tensor-only signature for ONNX tracing.

    The exporter hands over positional tensors. This wrapper forwards them to
    the wrapped model as keyword arguments (with ``return_dict=True``) and
    returns only the ``waveform`` attribute of the result, so the traced graph
    has a single output tensor.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        # Switch the wrapped model to eval mode before registering it.
        model.eval()
        self.model = model

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        speaker_id: torch.Tensor,
        emotion_id: torch.Tensor,
    ) -> torch.Tensor:
        """Run the wrapped model and return its ``waveform`` output.

        ``speaker_id`` and ``emotion_id`` are cast to ``torch.long`` before
        being forwarded, regardless of the dtype the caller supplies.
        """
        control_kwargs = {
            "speaker_id": speaker_id.to(torch.long),
            "emotion_id": emotion_id.to(torch.long),
        }
        result = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            **control_kwargs,
        )
        return result.waveform
def _format_value_info(value) -> str:
    """Render one ONNX graph input/output as ``name: shape=[...], elem_type=NAME``.

    Each dimension is shown as its fixed size, its symbolic name for a dynamic
    axis (e.g. ``text_length``), or ``"?"`` when neither is set. The element
    type is shown as its ``TensorProto.DataType`` name (``INT64``, ``FLOAT``,
    ...) instead of the raw protobuf enum integer.
    """
    tensor_type = value.type.tensor_type
    dims = []
    for dim in tensor_type.shape.dim:
        if dim.dim_value:
            dims.append(dim.dim_value)
        elif dim.dim_param:
            dims.append(dim.dim_param)
        else:
            dims.append("?")
    elem_type = onnx.TensorProto.DataType.Name(tensor_type.elem_type)
    return f" {value.name}: shape={dims}, elem_type={elem_type}"


def inspect_onnx(onnx_path: Path) -> None:
    """Print the shape and element type of every input and output of an ONNX model.

    Outputs are listed too, so the dynamic ``audio_length`` axis declared on
    ``waveform`` at export time can be checked alongside the inputs.

    Args:
        onnx_path: Path to the ``.onnx`` file to load.
    """
    model = onnx.load(str(onnx_path))
    print("onnx inputs:")
    for value in model.graph.input:
        print(_format_value_info(value))
    print("onnx outputs:")
    for value in model.graph.output:
        print(_format_value_info(value))
def inspect_mnn(mnn_path: Path) -> None:
    """Print shape, dtype and data format of each expected MNN graph input.

    For every one of the four inputs produced by the ONNX export that is not
    present among the loaded variables, a "missing" line is printed instead.

    Args:
        mnn_path: Path to the converted ``.mnn`` model.
    """
    # Imported inside the function, so the MNN package is only needed when an
    # MNN model is actually inspected.
    import MNN.expr as expr

    graph_vars = expr.load_as_dict(str(mnn_path))
    expected_inputs = ("input_ids", "attention_mask", "speaker_id", "emotion_id")
    for input_name in expected_inputs:
        if input_name in graph_vars:
            var = graph_vars[input_name]
            print(f"mnn input {input_name}: shape={var.shape}, dtype={var.dtype}, format={var.data_format}")
        else:
            print(f"mnn input missing: {input_name}")
def export_onnx(args: argparse.Namespace) -> None:
    """Trace the local VITS model with speaker/emotion inputs and save it as ONNX.

    Loads the tokenizer and model from ``args.model_dir`` (local files only,
    custom model code allowed), tokenizes ``args.text`` as the tracing sample,
    and exports a graph with inputs ``input_ids``, ``attention_mask``,
    ``speaker_id`` and ``emotion_id`` and a single ``waveform`` output. Only the
    text-length and audio-length axes are declared dynamic. After writing the
    file, its inputs and outputs are printed.

    Args:
        args: Parsed CLI namespace. Reads ``model_dir``, ``text``,
            ``speaker_id``, ``style_id``, ``onnx_output`` and ``opset``.

    Raises:
        ValueError: If the tokenizer yields no tokens for ``args.text``.
    """
    model_dir = Path(args.model_dir)
    onnx_path = Path(args.onnx_output)

    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True)
    model = AutoModel.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True)
    wrapper = VitsExportWrapper(model)

    tokenized = tokenizer(text=args.text, return_tensors="pt")
    input_ids = tokenized["input_ids"].to(torch.long)
    if input_ids.numel() == 0:
        # Tracing needs at least one token. Fail here with an actionable message
        # rather than somewhere deep inside the model's forward pass.
        raise ValueError(
            f"tokenizer produced no tokens for sample text {args.text!r}; "
            "pass a different --text"
        )
    attention_mask = tokenized.get("attention_mask")
    if attention_mask is None:
        # No mask returned by the tokenizer: treat every token as real.
        attention_mask = torch.ones_like(input_ids)
    attention_mask = attention_mask.to(torch.long)
    speaker_id = torch.tensor([args.speaker_id], dtype=torch.long)
    # The CLI's --style-id feeds the model's emotion_id input.
    emotion_id = torch.tensor([args.style_id], dtype=torch.long)

    # Make sure the destination directory exists before writing.
    onnx_path.parent.mkdir(parents=True, exist_ok=True)
    torch.onnx.export(
        wrapper,
        (input_ids, attention_mask, speaker_id, emotion_id),
        str(onnx_path),
        input_names=["input_ids", "attention_mask", "speaker_id", "emotion_id"],
        output_names=["waveform"],
        # Only axis 1 (sequence length) is dynamic. Batch dims and the
        # one-element control inputs keep their traced sizes.
        dynamic_axes={
            "input_ids": {1: "text_length"},
            "attention_mask": {1: "text_length"},
            "waveform": {1: "audio_length"},
        },
        opset_version=args.opset,
        do_constant_folding=True,
        dynamo=False,
    )
    print(f"wrote {args.onnx_output}")
    inspect_onnx(onnx_path)
def convert_to_mnn(args: argparse.Namespace) -> None:
    """Convert the exported ONNX file to MNN with the ``MNNConvert`` CLI.

    Does nothing when ``args.mnn_output`` is empty or ``None``. Otherwise runs
    ``MNNConvert -f ONNX`` on ``args.onnx_output``, writes ``args.mnn_output``,
    and prints the inputs of the resulting MNN model.

    Args:
        args: Parsed CLI namespace. Reads ``mnn_output`` and ``onnx_output``.

    Raises:
        FileNotFoundError: If ``MNNConvert`` cannot be found on ``PATH``.
        subprocess.CalledProcessError: If the converter exits with an error.
    """
    mnn_output = args.mnn_output
    if mnn_output:
        converter_path = shutil.which("MNNConvert")
        if converter_path is None:
            raise FileNotFoundError("MNNConvert not found in PATH")
        options = (
            ("-f", "ONNX"),
            ("--modelFile", str(args.onnx_output)),
            ("--MNNModel", str(mnn_output)),
            ("--bizCode", "MNN"),
        )
        command = [converter_path]
        for flag, value in options:
            command.extend((flag, value))
        subprocess.run(command, check=True)
        print(f"wrote {mnn_output}")
        inspect_mnn(Path(mnn_output))
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line options for the re-export script.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, matching the previous zero-argument behaviour.

    Returns:
        The parsed namespace. When ``--skip-mnn`` is given, ``mnn_output`` is
        set to ``None``, which makes the MNN conversion step a no-op. Without
        it, conversion runs by default and requires ``MNNConvert`` on PATH.
    """
    parser = argparse.ArgumentParser(description="Re-export local VITS weights with speaker/style control inputs.")
    parser.add_argument("--model-dir", default=".", help="Directory containing config, tokenizer, custom code, and weights.")
    parser.add_argument("--text", default="வணக்கம்", help="Sample text used to trace the export graph.")
    parser.add_argument("--speaker-id", type=int, default=18, help="Sample speaker ID used during tracing.")
    parser.add_argument("--style-id", type=int, default=0, help="Sample style or emotion ID used during tracing.")
    parser.add_argument("--onnx-output", default="vits_tamil_with_controls.onnx", help="Path for exported ONNX.")
    parser.add_argument(
        "--mnn-output",
        default="vits_tamil_with_controls.mnn",
        help="Output path for converted MNN (requires MNNConvert on PATH; use --skip-mnn to skip conversion).",
    )
    parser.add_argument(
        "--skip-mnn",
        action="store_true",
        help="Only export ONNX; do not run MNNConvert.",
    )
    parser.add_argument("--opset", type=int, default=17, help="ONNX opset version.")
    args = parser.parse_args(argv)
    if args.skip_mnn:
        # The conversion step treats a falsy mnn_output as "do not convert".
        args.mnn_output = None
    return args
def main() -> None:
    """Script entry point: export to ONNX, then run the MNN conversion step.

    The conversion step itself returns early when no MNN output path is set.
    """
    cli_args = parse_args()
    for step in (export_onnx, convert_to_mnn):
        step(cli_args)


if __name__ == "__main__":
    main()