MOSS-TTS-PNY / run_decoder4_features.py
ZDisket's picture
Upload MOSS-TTS Clipper iSTFTNet2 checkpoint bundle
dd51636 verified
Raw
History Blame Contribute Delete
3.47 kB
#!/usr/bin/env python
from __future__ import annotations
import argparse
import json
from pathlib import Path
import numpy as np
import torch
import torchaudio
import onnxruntime as ort
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for decoding saved decoder[4] features.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before.

    Returns:
        Namespace with ``feature``, ``decoder_dir``, ``runtime`` and ``output``.
    """
    parser = argparse.ArgumentParser(description="Decode saved decoder[4] features with exported iSTFTNet2.")
    parser.add_argument("--feature", required=True, help="Path to .npy feature array with shape [frames, 768].")
    parser.add_argument(
        "--decoder-dir",
        default="istftnet2_decoder4_50hz",
        help="Directory containing the exported iSTFTNet2 artifacts (.ts / .onnx).",
    )
    # Restrict to the runtimes main() can actually dispatch, so a typo fails
    # here with a clear usage message instead of deep inside a decoder runner.
    parser.add_argument(
        "--runtime",
        default="torchscript_cuda",
        choices=("torchscript_cuda", "torchscript_cpu", "onnx_rocm", "onnx_cuda", "onnx_cpu"),
        help="Decoder backend to use.",
    )
    parser.add_argument(
        "--output",
        default="outputs/decoder4_feature_48k.wav",
        help="Path of the WAV file to write; a .json manifest is written alongside it.",
    )
    return parser.parse_args(argv)
def run_torchscript(features: torch.Tensor, decoder_dir: Path, runtime: str) -> torch.Tensor:
    """Decode features with a TorchScript iSTFTNet2 export and return the output on CPU.

    Args:
        features: Decoder input tensor; moved to the runtime's device before the call.
        decoder_dir: Directory holding ``istftnet2_decoder_cuda.ts`` and/or
            ``istftnet2_decoder_cpu.ts``.
        runtime: ``"torchscript_cuda"`` or ``"torchscript_cpu"``.

    Returns:
        The decoder output as a float32 tensor on the CPU.

    Raises:
        ValueError: if ``runtime`` is not one of the two TorchScript runtimes.
        RuntimeError: if CUDA was requested but is not available.
    """
    # runtime name -> (artifact file inside decoder_dir, torch device string)
    targets = {
        "torchscript_cuda": ("istftnet2_decoder_cuda.ts", "cuda"),
        "torchscript_cpu": ("istftnet2_decoder_cpu.ts", "cpu"),
    }
    target = targets.get(runtime)
    if target is None:
        raise ValueError(runtime)
    artifact_name, device_name = target
    if device_name == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("torchscript_cuda requested, but torch.cuda.is_available() is false.")

    target_device = torch.device(device_name)
    module = torch.jit.load(str(decoder_dir / artifact_name), map_location=target_device)
    module.eval()
    with torch.inference_mode():
        decoded = module(features.to(target_device))
        return decoded.detach().float().cpu()
def run_onnx(features: torch.Tensor, decoder_dir: Path, runtime: str) -> torch.Tensor:
    """Decode features with the exported ONNX iSTFTNet2 model via ONNX Runtime.

    Args:
        features: Decoder input tensor; fed to the session's first input as float32.
        decoder_dir: Directory containing ``istftnet2_decoder.onnx``.
        runtime: One of ``"onnx_rocm"``, ``"onnx_cuda"`` or ``"onnx_cpu"``.

    Returns:
        The session's first output as a float32 CPU tensor.

    Raises:
        ValueError: if ``runtime`` is not a supported ONNX runtime (previously
            this surfaced as a bare ``KeyError``).
        RuntimeError: if the matching execution provider is not available.
    """
    provider_map = {
        "onnx_rocm": "ROCMExecutionProvider",
        "onnx_cuda": "CUDAExecutionProvider",
        "onnx_cpu": "CPUExecutionProvider",
    }
    provider = provider_map.get(runtime)
    if provider is None:
        raise ValueError(f"Unsupported ONNX runtime {runtime!r}; expected one of {sorted(provider_map)}.")
    available = ort.get_available_providers()
    if provider not in available:
        raise RuntimeError(f"{provider} is unavailable. Available providers: {available}")
    session = ort.InferenceSession(str(decoder_dir / "istftnet2_decoder.onnx"), providers=[provider])
    input_name = session.get_inputs()[0].name
    # ONNX Runtime consumes host numpy arrays: detach/cpu make the conversion
    # valid for tensors that require grad or live on an accelerator.
    feed = np.ascontiguousarray(features.detach().cpu().numpy(), dtype=np.float32)
    audio = session.run(None, {input_name: feed})[0]
    return torch.from_numpy(audio).float()
def main() -> None:
    """Decode a saved decoder[4] feature array to a 48 kHz WAV plus a JSON manifest.

    Reads CLI options via ``parse_args``, validates the ``[frames, 768]``
    feature array, dispatches to the TorchScript or ONNX runner, clamps the
    waveform to [-1, 1], trims it to ``frames * 960`` samples, and writes the
    WAV to ``--output`` and a manifest next to it with a ``.json`` suffix.
    The manifest is also printed to stdout.

    Raises:
        ValueError: if the feature array is not 2-D with 768 columns, or has
            no frames.
    """
    # 48000 samples/s / 960 samples per frame = 50 frames/s (cf. the "50hz"
    # default decoder directory name).
    sample_rate = 48000
    samples_per_frame = 960
    feature_dim = 768

    args = parse_args()
    decoder_dir = Path(args.decoder_dir)

    arr = np.load(args.feature).astype(np.float32, copy=False)
    if arr.ndim != 2 or arr.shape[1] != feature_dim:
        raise ValueError(f"Expected feature shape [frames, 768], got {arr.shape}.")
    if arr.shape[0] == 0:
        raise ValueError(f"Feature array {args.feature} contains no frames.")
    # [frames, 768] -> [1, 768, frames]: a batch of one, feature dim before time.
    features = torch.from_numpy(arr).T.unsqueeze(0).contiguous()

    if args.runtime.startswith("torchscript_"):
        audio = run_torchscript(features, decoder_dir, args.runtime)
    else:
        audio = run_onnx(features, decoder_dir, args.runtime)

    # Drop the batch dimension (features were sent with batch size 1).
    wav = audio[0].clamp(-1.0, 1.0)
    # torchaudio.save requires a 2-D [channels, time] tensor; promote a bare
    # [time] vector (decoder output shaped [batch, time]) to mono.
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    # Trim any tail beyond frames * samples_per_frame; make contiguous after
    # slicing so the saved buffer is dense.
    expected_samples = int(arr.shape[0]) * samples_per_frame
    wav = wav[..., :expected_samples].contiguous()

    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    torchaudio.save(str(output), wav, sample_rate)

    manifest = {
        "output": str(output),
        "feature": args.feature,
        "decoder_dir": args.decoder_dir,
        "runtime": args.runtime,
        "sample_rate": sample_rate,
        "feature_shape": list(arr.shape),
        "num_samples": int(wav.shape[-1]),
        "duration_sec": float(wav.shape[-1] / sample_rate),
    }
    manifest_text = json.dumps(manifest, indent=2)
    output.with_suffix(".json").write_text(manifest_text, encoding="utf-8")
    print(manifest_text)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()