xinliu commited on
Commit
fa7872b
·
verified ·
1 Parent(s): bf7988a

Upload 2 files

Browse files
Files changed (2) hide show
  1. inference_dnr_onnx.py +194 -0
  2. inference_onnx.py +99 -0
inference_dnr_onnx.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ US-DNR-003: Pure onnxruntime inference for TIGER-DnR (Dialog/Effect/Music separation).
4
+
5
+ Uses only onnxruntime + audio I/O, no look2hear import.
6
+ STFT/ISTFT performed in Python, separator network runs in ONNX.
7
+ """
8
+
9
+ import argparse
10
+ import os
11
+ import sys
12
+
13
+ import torch
14
+ import torchaudio
15
+ import numpy as np
16
+ import onnxruntime as ort
17
+
18
+
19
+ def load_audio(audio_path, target_sr=44100):
20
+ """Load and preprocess audio to 44.1kHz."""
21
+ waveform, sr = torchaudio.load(audio_path)
22
+
23
+ # Resample if needed
24
+ if sr != target_sr:
25
+ resampler = torchaudio.transforms.Resample(sr, target_sr)
26
+ waveform = resampler(waveform)
27
+
28
+ # Convert to mono if stereo
29
+ if waveform.shape[0] > 1:
30
+ waveform = waveform.mean(dim=0, keepdim=True)
31
+
32
+ return waveform, target_sr
33
+
34
+
35
+ def save_audio(audio_tensor, output_path, sample_rate=44100):
36
+ """Save audio tensor to file."""
37
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
38
+ torchaudio.save(output_path, audio_tensor, sample_rate)
39
+
40
+
41
+ def onnx_separate(onnx_path, audio_tensor, win=2048, stride=512):
42
+ """
43
+ Separate audio using ONNX model.
44
+
45
+ Args:
46
+ onnx_path: Path to ONNX separator model
47
+ audio_tensor: [C, T] audio tensor
48
+ win: STFT window size
49
+ stride: STFT hop length
50
+
51
+ Returns:
52
+ Tuple of (dialog, effect, music) tensors, each [C, T]
53
+ """
54
+ # Create ONNX session
55
+ sess_options = ort.SessionOptions()
56
+ sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
57
+
58
+ # Try CUDA first, fallback to CPU
59
+ providers = []
60
+ if 'CUDAExecutionProvider' in ort.get_available_providers():
61
+ providers.append('CUDAExecutionProvider')
62
+ print(f"[inference_dnr_onnx] Using CUDAExecutionProvider")
63
+ else:
64
+ providers.append('CPUExecutionProvider')
65
+ print(f"[inference_dnr_onnx] Using CPUExecutionProvider")
66
+
67
+ session = ort.InferenceSession(onnx_path, sess_options, providers=providers)
68
+
69
+ # Ensure [C, T] shape
70
+ if audio_tensor.ndim == 1:
71
+ audio_tensor = audio_tensor.unsqueeze(0)
72
+
73
+ nch = audio_tensor.shape[0]
74
+ original_length = audio_tensor.shape[-1]
75
+ audio_flat = audio_tensor.view(-1) # Flatten to [nch*T]
76
+
77
+ # Compute STFT
78
+ print(f"[inference_dnr_onnx] Computing STFT...")
79
+ window = torch.hann_window(win).type(audio_flat.dtype)
80
+ spec = torch.stft(
81
+ audio_flat,
82
+ n_fft=win,
83
+ hop_length=stride,
84
+ window=window,
85
+ return_complex=True
86
+ ) # [F, T_frames]
87
+
88
+ # Extract real and imaginary parts
89
+ spec_real = spec.real.unsqueeze(0).numpy() # [1, F, T_frames]
90
+ spec_imag = spec.imag.unsqueeze(0).numpy() # [1, F, T_frames]
91
+ print(f"[inference_dnr_onnx] STFT shape: {spec_real.shape}")
92
+
93
+ # Run ONNX inference
94
+ print(f"[inference_dnr_onnx] Running ONNX separator...")
95
+ outputs = session.run(
96
+ None,
97
+ {
98
+ 'spec_real': spec_real.astype(np.float32),
99
+ 'spec_imag': spec_imag.astype(np.float32)
100
+ }
101
+ )
102
+
103
+ # outputs: [dialog_real, dialog_imag, effect_real, effect_imag, music_real, music_imag]
104
+ dialog_real, dialog_imag, effect_real, effect_imag, music_real, music_imag = outputs
105
+
106
+ # Convert back to complex spectrograms
107
+ dialog_spec = torch.complex(
108
+ torch.from_numpy(dialog_real).squeeze(0),
109
+ torch.from_numpy(dialog_imag).squeeze(0)
110
+ )
111
+ effect_spec = torch.complex(
112
+ torch.from_numpy(effect_real).squeeze(0),
113
+ torch.from_numpy(effect_imag).squeeze(0)
114
+ )
115
+ music_spec = torch.complex(
116
+ torch.from_numpy(music_real).squeeze(0),
117
+ torch.from_numpy(music_imag).squeeze(0)
118
+ )
119
+
120
+ # ISTFT to get time-domain signals
121
+ print(f"[inference_dnr_onnx] Computing ISTFT...")
122
+ dialog = torch.istft(
123
+ dialog_spec,
124
+ n_fft=win,
125
+ hop_length=stride,
126
+ window=window,
127
+ length=original_length
128
+ )
129
+ effect = torch.istft(
130
+ effect_spec,
131
+ n_fft=win,
132
+ hop_length=stride,
133
+ window=window,
134
+ length=original_length
135
+ )
136
+ music = torch.istft(
137
+ music_spec,
138
+ n_fft=win,
139
+ hop_length=stride,
140
+ window=window,
141
+ length=original_length
142
+ )
143
+
144
+ # Reshape to [C, T]
145
+ dialog = dialog.view(nch, -1)
146
+ effect = effect.view(nch, -1)
147
+ music = music.view(nch, -1)
148
+
149
+ return dialog, effect, music
150
+
151
+
152
+ def main():
153
+ parser = argparse.ArgumentParser(description="TIGER-DnR ONNX inference (no look2hear)")
154
+ parser.add_argument("--audio_path", default="test/test_mixture_466.wav", help="Input audio file")
155
+ parser.add_argument("--output_dir", default="separated_audio_dnr_onnx", help="Output directory")
156
+ parser.add_argument("--onnx_path", default="onnx/tiger_dnr_separator.onnx", help="ONNX model path")
157
+ args = parser.parse_args()
158
+
159
+ print(f"[inference_dnr_onnx] TIGER-DnR ONNX Inference")
160
+ print(f"[inference_dnr_onnx] Input: {args.audio_path}")
161
+ print(f"[inference_dnr_onnx] Output: {args.output_dir}")
162
+ print(f"[inference_dnr_onnx] Model: {args.onnx_path}")
163
+
164
+ # Check inputs
165
+ if not os.path.exists(args.audio_path):
166
+ print(f"[inference_dnr_onnx] ERROR: Audio file not found: {args.audio_path}")
167
+ sys.exit(1)
168
+
169
+ if not os.path.exists(args.onnx_path):
170
+ print(f"[inference_dnr_onnx] ERROR: ONNX model not found: {args.onnx_path}")
171
+ sys.exit(1)
172
+
173
+ # Load audio
174
+ print(f"[inference_dnr_onnx] Loading audio...")
175
+ audio, sr = load_audio(args.audio_path)
176
+ print(f"[inference_dnr_onnx] Audio shape: {audio.shape}, sample rate: {sr}")
177
+
178
+ # Separate
179
+ dialog, effect, music = onnx_separate(args.onnx_path, audio)
180
+
181
+ # Save outputs
182
+ print(f"[inference_dnr_onnx] Saving separated audio...")
183
+ save_audio(dialog, os.path.join(args.output_dir, "dialog.wav"), sr)
184
+ save_audio(effect, os.path.join(args.output_dir, "effect.wav"), sr)
185
+ save_audio(music, os.path.join(args.output_dir, "music.wav"), sr)
186
+
187
+ print(f"[inference_dnr_onnx] Saved dialog.wav")
188
+ print(f"[inference_dnr_onnx] Saved effect.wav")
189
+ print(f"[inference_dnr_onnx] Saved music.wav")
190
+ print(f"[inference_dnr_onnx] SUCCESS")
191
+
192
+
193
+ if __name__ == "__main__":
194
+ main()
inference_onnx.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+
5
+ import numpy as np
6
+ import onnxruntime as ort
7
+ import torch
8
+ import torchaudio
9
+ import torchaudio.transforms as T
10
+
11
+
12
+ TARGET_SR = 16000
13
+ CHUNK_LEN = TARGET_SR * 4 # must match dummy length in export_onnx.py
14
+
15
+
16
+ def parse_args():
17
+ p = argparse.ArgumentParser(description="Pure onnxruntime TIGER-speech inference.")
18
+ p.add_argument("--audio_path", default="test/mix.wav",
19
+ help="Path to mixture wav.")
20
+ p.add_argument("--output_dir", default="separated_audio_onnx",
21
+ help="Directory to save separated spkN.wav files.")
22
+ p.add_argument("--onnx_path", default="onnx/tiger_speech.onnx",
23
+ help="Exported ONNX model (from export_onnx.py).")
24
+ return p.parse_args()
25
+
26
+
27
+ def load_audio(audio_path):
28
+ waveform, original_sr = torchaudio.load(audio_path)
29
+ print(f"Loaded {audio_path}: sr={original_sr}, shape={tuple(waveform.shape)}")
30
+ if original_sr != TARGET_SR:
31
+ print(f"Resampling {original_sr} Hz -> {TARGET_SR} Hz")
32
+ waveform = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR)(waveform)
33
+ if waveform.dim() == 1:
34
+ waveform = waveform.unsqueeze(0)
35
+ if waveform.shape[0] > 1:
36
+ print(f"Downmixing {waveform.shape[0]} channels -> mono")
37
+ waveform = waveform.mean(dim=0, keepdim=True)
38
+ return waveform # [1, T]
39
+
40
+
41
+ def build_session(onnx_path):
42
+ available = ort.get_available_providers()
43
+ if "CUDAExecutionProvider" in available:
44
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
45
+ else:
46
+ providers = ["CPUExecutionProvider"]
47
+ sess = ort.InferenceSession(onnx_path, providers=providers)
48
+ chosen = sess.get_providers()[0]
49
+ print(f"onnxruntime provider: {chosen}")
50
+ return sess
51
+
52
+
53
+ def run_chunks(sess, mono_wave):
54
+ in_name = sess.get_inputs()[0].name
55
+ out_name = sess.get_outputs()[0].name
56
+ total = mono_wave.shape[-1]
57
+ outputs = []
58
+ for start in range(0, total, CHUNK_LEN):
59
+ end = min(start + CHUNK_LEN, total)
60
+ chunk = mono_wave[:, start:end]
61
+ pad = CHUNK_LEN - chunk.shape[-1]
62
+ if pad > 0:
63
+ chunk = torch.nn.functional.pad(chunk, (0, pad))
64
+ x = chunk.unsqueeze(0).contiguous().numpy().astype(np.float32) # [1,1,CHUNK_LEN]
65
+ y = sess.run([out_name], {in_name: x})[0] # [1, num_spk, CHUNK_LEN]
66
+ if pad > 0:
67
+ y = y[..., : CHUNK_LEN - pad]
68
+ outputs.append(y[0]) # [num_spk, chunk_len]
69
+ return np.concatenate(outputs, axis=-1) # [num_spk, total]
70
+
71
+
72
+ def main():
73
+ args = parse_args()
74
+
75
+ if not os.path.isfile(args.audio_path):
76
+ print(f"ERROR: audio not found: {args.audio_path}")
77
+ sys.exit(1)
78
+ if not os.path.isfile(args.onnx_path):
79
+ print(f"ERROR: onnx not found: {args.onnx_path}")
80
+ sys.exit(1)
81
+
82
+ waveform = load_audio(args.audio_path) # [1, T]
83
+ print(f"Preprocessed shape: {tuple(waveform.shape)} (16 kHz mono)")
84
+
85
+ sess = build_session(args.onnx_path)
86
+ estimates = run_chunks(sess, waveform) # [num_spk, T]
87
+ num_spk = estimates.shape[0]
88
+ print(f"Separation complete: num_spk={num_spk}, samples={estimates.shape[-1]}")
89
+
90
+ os.makedirs(args.output_dir, exist_ok=True)
91
+ for i in range(num_spk):
92
+ out_path = os.path.join(args.output_dir, f"spk{i+1}.wav")
93
+ track = torch.from_numpy(estimates[i]).unsqueeze(0)
94
+ torchaudio.save(out_path, track, TARGET_SR)
95
+ print(f"Saved spk{i+1} -> {out_path}")
96
+
97
+
98
+ if __name__ == "__main__":
99
+ main()