taresh18 commited on
Commit
c82ec83
·
verified ·
1 Parent(s): 2dd6e38

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +94 -0
inference.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ nano-codec inference: reconstruct audio through the codec.
3
+
4
+ Usage:
5
+ python inference.py --input input.wav --output reconstructed.wav
6
+
7
+ Downloads model weights from HuggingFace on first run.
8
+ """
9
+
10
+ import argparse
11
+ import torch
12
+ import soundfile as sf
13
+ import torchaudio
14
+ import yaml
15
+ from huggingface_hub import hf_hub_download
16
+ from model import RVQCodec
17
+
18
+ REPO_ID = "taresh18/nano-codec"
19
+
20
+
21
def load_model(device="cpu"):
    """Fetch the codec weights and config from the HF Hub and build the model.

    Args:
        device: torch device string the model should be placed on.

    Returns:
        Tuple ``(model, cfg)`` — the eval-mode ``RVQCodec`` and the parsed
        YAML config dict.
    """
    weights_file = hf_hub_download(REPO_ID, "model.pt")
    cfg_file = hf_hub_download(REPO_ID, "config.yaml")

    with open(cfg_file) as fh:
        cfg = yaml.safe_load(fh)

    codec = RVQCodec(
        in_ch=1,
        latent_ch=cfg['latent_dim'],
        K=cfg['codebook_size'],
        num_rvq_levels=cfg['num_rvq_levels'],
        codebook_dim=cfg.get('codebook_dim', 8),
    )

    # weights_only=True restricts torch.load to tensor data (no pickle code exec)
    checkpoint = torch.load(weights_file, map_location=device, weights_only=True)
    codec.load_state_dict(checkpoint)
    codec = codec.to(device)
    codec.eval()

    return codec, cfg
42
+
43
+
44
def reconstruct(model, audio_path, output_path, sample_rate=16000, chunk_size=16384, device="cpu"):
    """Round-trip an audio file through the codec and save the reconstruction.

    Loads ``audio_path``, downmixes to mono, resamples to ``sample_rate``,
    peak-normalizes, zero-pads to a multiple of ``chunk_size``, passes each
    chunk through ``model``, then trims the padding and writes the result
    to ``output_path``.
    """
    data, src_rate = sf.read(audio_path, dtype='float32')
    if data.ndim > 1:
        # downmix multi-channel input to mono
        data = data.mean(axis=1)
    wav = torch.from_numpy(data).unsqueeze(0)

    if src_rate != sample_rate:
        wav = torchaudio.functional.resample(wav, src_rate, sample_rate)

    # peak-normalize; the clamp guards against divide-by-zero on silent input
    wav = wav / wav.abs().max().clamp(min=1e-8)

    n_samples = wav.shape[1]
    remainder = n_samples % chunk_size
    if remainder:
        wav = torch.nn.functional.pad(wav, (0, chunk_size - remainder))

    pieces = []
    with torch.no_grad():
        for offset in range(0, wav.shape[1], chunk_size):
            segment = wav[:, offset:offset + chunk_size].unsqueeze(0).to(device)
            out, _, _, _ = model(segment)
            # the decoder may emit a few extra samples; keep exactly one chunk
            pieces.append(out[..., :chunk_size].cpu())

    # concatenate along time, drop batch dim, and trim the zero padding
    full = torch.cat(pieces, dim=-1)[0, :, :n_samples]

    sf.write(output_path, full[0].float().numpy(), sample_rate)
    print(f"saved: {output_path} ({n_samples / sample_rate:.2f}s)")
73
+
74
+
75
def main():
    """CLI entry point: parse arguments, load the codec, run reconstruction."""
    ap = argparse.ArgumentParser(description="nano-codec inference")
    ap.add_argument("--input", required=True, help="input wav file")
    ap.add_argument("--output", default="reconstructed.wav", help="output wav file")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    opts = ap.parse_args()

    codec, cfg = load_model(device=opts.device)
    reconstruct(
        codec,
        opts.input,
        opts.output,
        sample_rate=cfg['sample_rate'],
        chunk_size=cfg['chunk_size'],
        device=opts.device,
    )


if __name__ == "__main__":
    main()