| |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torchaudio |
| import onnxruntime as ort |
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for decoding a saved feature file.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        Namespace with ``feature``, ``decoder_dir``, ``runtime`` and ``output``.

    Raises:
        SystemExit: Raised by argparse when ``--feature`` is missing or
            ``--runtime`` is not one of the supported names.
    """
    # Every runtime name accepted by the TorchScript and ONNX runners.
    # Validating here turns a typo into a usage error instead of a late
    # KeyError/ValueError after the feature file has already been loaded.
    runtimes = (
        "torchscript_cuda",
        "torchscript_cpu",
        "onnx_rocm",
        "onnx_cuda",
        "onnx_cpu",
    )
    parser = argparse.ArgumentParser(
        description="Decode saved decoder[4] features with exported iSTFTNet2.",
    )
    parser.add_argument(
        "--feature",
        required=True,
        help="Path to .npy feature array with shape [frames, 768].",
    )
    parser.add_argument(
        "--decoder-dir",
        default="istftnet2_decoder4_50hz",
        help="Directory containing the exported decoder artifacts.",
    )
    parser.add_argument(
        "--runtime",
        default="torchscript_cuda",
        choices=runtimes,
        help="Inference backend used to run the decoder.",
    )
    parser.add_argument(
        "--output",
        default="outputs/decoder4_feature_48k.wav",
        help="Destination path of the decoded WAV file.",
    )
    return parser.parse_args(argv)
|
|
|
|
def run_torchscript(features: torch.Tensor, decoder_dir: Path, runtime: str) -> torch.Tensor:
    """Run the exported TorchScript decoder and return its output on the CPU.

    Args:
        features: Input tensor; it is moved to the selected device before
            being passed to the decoder.
        decoder_dir: Directory holding the ``.ts`` artifacts.
        runtime: ``"torchscript_cuda"`` or ``"torchscript_cpu"``.

    Returns:
        The decoder output, detached, cast to float32 and moved to the CPU.

    Raises:
        ValueError: If ``runtime`` is neither of the two supported names.
        RuntimeError: If the CUDA runtime is requested but CUDA is unavailable.
    """
    # runtime name -> (artifact file name, torch device string)
    targets = {
        "torchscript_cuda": ("istftnet2_decoder_cuda.ts", "cuda"),
        "torchscript_cpu": ("istftnet2_decoder_cpu.ts", "cpu"),
    }
    if runtime not in targets:
        raise ValueError(runtime)
    if runtime == "torchscript_cuda" and not torch.cuda.is_available():
        raise RuntimeError("torchscript_cuda requested, but torch.cuda.is_available() is false.")

    file_name, device_name = targets[runtime]
    target_device = torch.device(device_name)
    model = torch.jit.load(str(decoder_dir / file_name), map_location=target_device)
    model = model.eval()
    with torch.inference_mode():
        result = model(features.to(target_device))
        return result.detach().float().cpu()
|
|
|
|
def run_onnx(features: torch.Tensor, decoder_dir: Path, runtime: str) -> torch.Tensor:
    """Run the exported ONNX decoder through ONNX Runtime.

    Args:
        features: Input tensor; it is converted to a float32 NumPy array and
            fed to the model's first input.
        decoder_dir: Directory containing ``istftnet2_decoder.onnx``.
        runtime: ``"onnx_rocm"``, ``"onnx_cuda"`` or ``"onnx_cpu"``.

    Returns:
        The model's first output as a float32 CPU tensor.

    Raises:
        ValueError: If ``runtime`` is not one of the supported names.
        RuntimeError: If the matching execution provider is not available in
            the installed ONNX Runtime build.
    """
    provider_map = {
        "onnx_rocm": "ROCMExecutionProvider",
        "onnx_cuda": "CUDAExecutionProvider",
        "onnx_cpu": "CPUExecutionProvider",
    }
    # Fail with an explicit, actionable error (matching run_torchscript's
    # ValueError) rather than a bare KeyError from the dict lookup.
    provider = provider_map.get(runtime)
    if provider is None:
        supported = ", ".join(provider_map)
        raise ValueError(f"Unsupported runtime {runtime!r}. Expected one of: {supported}")

    available = ort.get_available_providers()
    if provider not in available:
        raise RuntimeError(f"{provider} is unavailable. Available providers: {available}")

    session = ort.InferenceSession(str(decoder_dir / "istftnet2_decoder.onnx"), providers=[provider])
    input_name = session.get_inputs()[0].name
    # detach()/cpu() return the same data for a plain CPU tensor, but keep
    # .numpy() from raising on tensors that track gradients or live off-CPU.
    model_input = features.detach().cpu().numpy().astype(np.float32, copy=False)
    audio = session.run(None, {input_name: model_input})[0]
    return torch.from_numpy(audio).float()
|
|
|
|
def main() -> None:
    """Decode a saved feature array to a 48 kHz WAV file plus a JSON manifest.

    Loads the ``.npy`` file named by ``--feature``, runs it through the decoder
    selected by ``--runtime``, trims the result to 960 samples per feature
    frame, and writes the audio to ``--output``. A manifest describing the run
    is written next to the audio (same path, ``.json`` suffix) and printed.

    Raises:
        ValueError: If the loaded array is not 2-D with 768 columns.
    """
    feature_dim = 768
    sample_rate = 48000
    # Each feature frame accounts for 960 output samples (50 frames/s at 48 kHz).
    samples_per_frame = 960

    args = parse_args()
    decoder_dir = Path(args.decoder_dir)
    arr = np.load(args.feature).astype(np.float32, copy=False)
    if arr.ndim != 2 or arr.shape[1] != feature_dim:
        raise ValueError(f"Expected feature shape [frames, {feature_dim}], got {arr.shape}.")

    # [frames, 768] -> [1, 768, frames]
    features = torch.from_numpy(arr).T.unsqueeze(0).contiguous()
    if args.runtime.startswith("torchscript_"):
        audio = run_torchscript(features, decoder_dir, args.runtime)
    else:
        audio = run_onnx(features, decoder_dir, args.runtime)

    wav = audio[0]
    if wav.ndim == 1:
        # torchaudio.save expects a 2-D [channels, samples] tensor; a decoder
        # that returns [batch, samples] would otherwise leave a 1-D waveform.
        wav = wav.unsqueeze(0)
    # Trim before clamping so the saved tensor is exactly the expected length
    # and contiguous.
    expected_samples = int(arr.shape[0]) * samples_per_frame
    wav = wav[..., :expected_samples].clamp(-1.0, 1.0).contiguous()

    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    torchaudio.save(str(output), wav, sample_rate)

    num_samples = int(wav.shape[-1])
    manifest = {
        "output": str(output),
        "feature": args.feature,
        "decoder_dir": args.decoder_dir,
        "runtime": args.runtime,
        "sample_rate": sample_rate,
        "feature_shape": list(arr.shape),
        "num_samples": num_samples,
        "duration_sec": float(num_samples / sample_rate),
    }
    manifest_text = json.dumps(manifest, indent=2)
    output.with_suffix(".json").write_text(manifest_text, encoding="utf-8")
    print(manifest_text)
|
|
|
|
# Script entry point: run the decode pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|