#!/usr/bin/env python
"""Decode a saved decoder[4] feature array to a WAV file with an exported iSTFTNet2 decoder.

The script loads a ``[frames, 768]`` ``.npy`` array, runs it through either a
TorchScript artifact or an ONNX model found in ``--decoder-dir``, writes the
result as a 48 kHz WAV and stores a JSON manifest next to it.
"""
from __future__ import annotations

import argparse
import json
from pathlib import Path

import numpy as np
import torch

# Width of one feature frame expected by the exported decoder.
FEATURE_DIM = 768
# Sample rate used for the written WAV file and for the manifest.
SAMPLE_RATE = 48000
# Audio samples kept per feature frame (48000 / 960 = 50 frames per second,
# which matches the "50hz" in the default decoder directory name).
SAMPLES_PER_FRAME = 960

# ONNX runtime name -> onnxruntime execution provider.
ONNX_PROVIDERS = {
    "onnx_rocm": "ROCMExecutionProvider",
    "onnx_cuda": "CUDAExecutionProvider",
    "onnx_cpu": "CPUExecutionProvider",
}
TORCHSCRIPT_RUNTIMES = ("torchscript_cuda", "torchscript_cpu")
# Every value accepted by --runtime.
RUNTIMES = (*TORCHSCRIPT_RUNTIMES, *ONNX_PROVIDERS)


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        Namespace with ``feature``, ``decoder_dir``, ``runtime`` and ``output``.
    """
    parser = argparse.ArgumentParser(
        description="Decode saved decoder[4] features with exported iSTFTNet2."
    )
    parser.add_argument(
        "--feature",
        required=True,
        help=f"Path to .npy feature array with shape [frames, {FEATURE_DIM}].",
    )
    parser.add_argument("--decoder-dir", default="istftnet2_decoder4_50hz")
    # Reject unknown runtimes here, instead of failing later inside a backend.
    parser.add_argument("--runtime", default="torchscript_cuda", choices=RUNTIMES)
    parser.add_argument("--output", default="outputs/decoder4_feature_48k.wav")
    return parser.parse_args(argv)


def run_torchscript(features: torch.Tensor, decoder_dir: Path, runtime: str) -> torch.Tensor:
    """Run the TorchScript decoder artifact on ``features``.

    Args:
        features: Decoder input tensor; it is moved to the selected device.
        decoder_dir: Directory containing ``istftnet2_decoder_cuda.ts`` or
            ``istftnet2_decoder_cpu.ts``.
        runtime: ``"torchscript_cuda"`` or ``"torchscript_cpu"``.

    Returns:
        The decoder output as a detached float32 CPU tensor.

    Raises:
        RuntimeError: CUDA was requested but is not available.
        ValueError: ``runtime`` is not a TorchScript runtime name.
    """
    if runtime == "torchscript_cuda":
        if not torch.cuda.is_available():
            raise RuntimeError(
                "torchscript_cuda requested, but torch.cuda.is_available() is false."
            )
        artifact = decoder_dir / "istftnet2_decoder_cuda.ts"
        device = torch.device("cuda")
    elif runtime == "torchscript_cpu":
        artifact = decoder_dir / "istftnet2_decoder_cpu.ts"
        device = torch.device("cpu")
    else:
        raise ValueError(runtime)

    decoder = torch.jit.load(str(artifact), map_location=device).eval()
    with torch.inference_mode():
        return decoder(features.to(device)).detach().float().cpu()


def run_onnx(features: torch.Tensor, decoder_dir: Path, runtime: str) -> torch.Tensor:
    """Run ``istftnet2_decoder.onnx`` on ``features`` with onnxruntime.

    Args:
        features: Decoder input tensor; it is converted to a float32 NumPy array.
        decoder_dir: Directory containing ``istftnet2_decoder.onnx``.
        runtime: One of the keys of ``ONNX_PROVIDERS``.

    Returns:
        The first session output as a float32 tensor.

    Raises:
        ValueError: ``runtime`` is not a known ONNX runtime name.
        RuntimeError: The matching execution provider is not available.
    """
    provider = ONNX_PROVIDERS.get(runtime)
    if provider is None:
        # Same exception type as run_torchscript uses for an unknown runtime.
        raise ValueError(
            f"Unsupported ONNX runtime {runtime!r}; expected one of {sorted(ONNX_PROVIDERS)}."
        )

    # Imported lazily so TorchScript-only setups do not need onnxruntime.
    import onnxruntime as ort

    available = ort.get_available_providers()
    if provider not in available:
        raise RuntimeError(f"{provider} is unavailable. Available providers: {available}")

    session = ort.InferenceSession(
        str(decoder_dir / "istftnet2_decoder.onnx"), providers=[provider]
    )
    input_name = session.get_inputs()[0].name
    # detach()/cpu() are no-ops for a plain CPU tensor and keep .numpy() from
    # failing if a caller passes a grad-tracking or non-CPU tensor.
    model_input = features.detach().cpu().numpy().astype(np.float32, copy=False)
    audio = session.run(None, {input_name: model_input})[0]
    return torch.from_numpy(audio).float()


def main(argv: list[str] | None = None) -> None:
    """Decode the feature file, write the WAV and its JSON manifest.

    Args:
        argv: Optional argument list forwarded to ``parse_args``; ``None``
            means ``sys.argv[1:]``.

    Raises:
        ValueError: The loaded array is not 2-D with ``FEATURE_DIM`` columns.
    """
    args = parse_args(argv)

    # Imported right after argument parsing: --help works without torchaudio,
    # and a missing install is still reported before any decoding work starts.
    import torchaudio

    decoder_dir = Path(args.decoder_dir)

    arr = np.load(args.feature).astype(np.float32, copy=False)
    if arr.ndim != 2 or arr.shape[1] != FEATURE_DIM:
        raise ValueError(f"Expected feature shape [frames, {FEATURE_DIM}], got {arr.shape}.")
    # [frames, 768] -> [768, frames] -> [1, 768, frames]
    features = torch.from_numpy(arr).T.unsqueeze(0).contiguous()

    if args.runtime.startswith("torchscript_"):
        audio = run_torchscript(features, decoder_dir, args.runtime)
    else:
        audio = run_onnx(features, decoder_dir, args.runtime)

    # First batch item, limited to the valid sample range.
    wav = audio[0].clamp(-1.0, 1.0)
    if wav.ndim == 1:
        # torchaudio.save needs a 2-D tensor; add a channel axis when the
        # decoder output has no channel dimension.
        wav = wav.unsqueeze(0)
    # Drop any samples beyond frames * SAMPLES_PER_FRAME.
    expected_samples = int(arr.shape[0]) * SAMPLES_PER_FRAME
    wav = wav[..., :expected_samples].contiguous()

    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    torchaudio.save(str(output), wav, SAMPLE_RATE)

    manifest = {
        "output": str(output),
        "feature": args.feature,
        "decoder_dir": args.decoder_dir,
        "runtime": args.runtime,
        "sample_rate": SAMPLE_RATE,
        "feature_shape": list(arr.shape),
        "num_samples": int(wav.shape[-1]),
        "duration_sec": float(wav.shape[-1] / SAMPLE_RATE),
    }
    manifest_text = json.dumps(manifest, indent=2)
    # The manifest is written next to the WAV, with a .json suffix.
    output.with_suffix(".json").write_text(manifest_text, encoding="utf-8")
    print(manifest_text)


if __name__ == "__main__":
    main()