matbee committed
Commit 820e270 · verified · 1 Parent(s): 56a0ef4

Upload onnx_inference.py with huggingface_hub

Files changed (1): onnx_inference.py +439 -0

onnx_inference.py ADDED
@@ -0,0 +1,439 @@
+ #!/usr/bin/env python3
+ """
+ SAM Audio ONNX Runtime Inference Example
+
+ This script demonstrates how to use the exported ONNX models for audio source
+ separation inference. It shows the complete pipeline from text input to
+ separated audio output.
+
+ Usage:
+     python onnx_inference.py --audio input.wav --text "a person speaking"
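+
+ Expected files in --model-dir (exactly the paths loaded in
+ SAMAudioONNXPipeline.__init__ below): dacvae_encoder.onnx,
+ dacvae_decoder.onnx, t5_encoder.onnx, dit_single_step.onnx, plus a
+ tokenizer/ directory or a tokenizer_config.json whose "model_name"
+ names the Hugging Face tokenizer to download.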
+ """
+
+ import os
+ import argparse
+ import numpy as np
+ import json
+ from typing import Optional
+
+
+ def load_audio(path: str, target_sr: int = 44100) -> np.ndarray:
+     """Load an audio file and resample it to the target sample rate."""
+     try:
+         import librosa
+     except ImportError:
+         raise ImportError("Please install librosa: pip install librosa")
+     audio, _ = librosa.load(path, sr=target_sr, mono=True)
+     return audio.astype(np.float32)
+
+
+ def save_audio(audio: np.ndarray, path: str, sample_rate: int = 44100):
+     """Save audio to a WAV file."""
+     try:
+         import soundfile as sf
+     except ImportError:
+         raise ImportError("Please install soundfile: pip install soundfile")
+     sf.write(path, audio, sample_rate)
+     print(f"Saved audio to {path}")
+
+
+ class SAMAudioONNXPipeline:
+     """
+     ONNX-based SAM Audio inference pipeline.
+
+     This class orchestrates all the ONNX models to perform audio source separation.
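+
+     Pipeline stages (see separate()): encode_audio (DACVAE encoder) ->
+     encode_text (T5) -> ode_solve_midpoint (DiT velocity field) ->
+     decode_audio (DACVAE decoder).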
+     """
+
+     def __init__(
+         self,
+         model_dir: str = ".",
+         device: str = "cpu",
+         num_ode_steps: int = 16,
+     ):
+         import onnxruntime as ort
+
+         self.model_dir = model_dir
+         self.num_ode_steps = num_ode_steps
+         self.step_size = 1.0 / num_ode_steps
+
+         # Set up ONNX Runtime providers
+         if device == "cuda":
+             providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+         else:
+             providers = ["CPUExecutionProvider"]
+
+         # Load models
+         print("Loading ONNX models...")
+
+         self.dacvae_encoder = ort.InferenceSession(
+             os.path.join(model_dir, "dacvae_encoder.onnx"),
+             providers=providers,
+         )
+         print(" ✓ DACVAE encoder loaded")
+
+         self.dacvae_decoder = ort.InferenceSession(
+             os.path.join(model_dir, "dacvae_decoder.onnx"),
+             providers=providers,
+         )
+         print(" ✓ DACVAE decoder loaded")
+
+         self.t5_encoder = ort.InferenceSession(
+             os.path.join(model_dir, "t5_encoder.onnx"),
+             providers=providers,
+         )
+         print(" ✓ T5 encoder loaded")
+
+         self.dit = ort.InferenceSession(
+             os.path.join(model_dir, "dit_single_step.onnx"),
+             providers=providers,
+         )
+         print(" ✓ DiT denoiser loaded")
+
+         # Load tokenizer
+         self._load_tokenizer()
+         print(" ✓ Tokenizer loaded")
+
+         print("All models loaded!")
+
+     def _load_tokenizer(self):
+         """Load the T5 tokenizer."""
+         from transformers import AutoTokenizer
+
+         tokenizer_path = os.path.join(self.model_dir, "tokenizer")
+         if os.path.exists(tokenizer_path):
+             self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+         else:
+             # Fall back to downloading from the Hugging Face Hub
+             with open(os.path.join(self.model_dir, "tokenizer_config.json")) as f:
+                 config = json.load(f)
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 config.get("model_name", "google-t5/t5-base")
+             )
+
+     def encode_audio(self, audio: np.ndarray) -> np.ndarray:
+         """
+         Encode an audio waveform to latent features.
+
+         Args:
+             audio: Audio waveform, shape (samples,), (channels, samples),
+                 or (batch, channels, samples)
+
+         Returns:
+             Latent features, shape (1, latent_dim, time_steps)
+         """
+         # Ensure correct shape (batch, channels, samples)
+         if audio.ndim == 1:
+             audio = audio.reshape(1, 1, -1)
+         elif audio.ndim == 2:
+             audio = audio.reshape(1, *audio.shape)
+
+         outputs = self.dacvae_encoder.run(
+             ["latent_features"],
+             {"audio": audio.astype(np.float32)},
+         )
+         return outputs[0]
+
+     def decode_audio(self, latent: np.ndarray) -> np.ndarray:
+         """
+         Decode latent features to an audio waveform.
+
+         Uses chunked decoding since the DACVAE decoder was exported with
+         fixed 25 time steps. Processes in chunks and concatenates.
+
+         Args:
+             latent: Latent features, shape (1, latent_dim, time_steps)
+
+         Returns:
+             Audio waveform, shape (samples,)
+         """
+         chunk_size = 25  # DACVAE decoder's fixed number of latent time steps
+         hop_length = 1920  # Output samples per latent time step
+
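+         # NOTE (assumption): chunks are decoded independently, so if the
+         # decoder uses context across time steps, boundary artifacts may
+         # appear every chunk_size * hop_length samples.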
+         _, _, time_steps = latent.shape
+
+         audio_chunks = []
+         for start_idx in range(0, time_steps, chunk_size):
+             end_idx = min(start_idx + chunk_size, time_steps)
+             chunk = latent[:, :, start_idx:end_idx]
+
+             # Pad last chunk if needed
+             actual_size = chunk.shape[2]
+             if actual_size < chunk_size:
+                 pad_size = chunk_size - actual_size
+                 chunk = np.pad(chunk, ((0, 0), (0, 0), (0, pad_size)), mode='constant')
+
+             # Decode chunk
+             chunk_audio = self.dacvae_decoder.run(
+                 ["waveform"],
+                 {"latent_features": chunk.astype(np.float32)},
+             )[0]
+
+             # Trim padded output
+             if actual_size < chunk_size:
+                 trim_samples = actual_size * hop_length
+                 chunk_audio = chunk_audio[:, :, :trim_samples]
+
+             audio_chunks.append(chunk_audio)
+
+         # Concatenate all chunks
+         full_audio = np.concatenate(audio_chunks, axis=2)
+         return full_audio.squeeze()
+
+     def encode_text(self, text: str) -> tuple[np.ndarray, np.ndarray]:
+         """
+         Encode a text prompt to features.
+
+         Args:
+             text: Text description of the audio to separate
+
+         Returns:
+             Tuple of (hidden_states, attention_mask)
+         """
+         tokens = self.tokenizer(
+             text,
+             return_tensors="np",
+             padding=True,
+             truncation=True,
+             max_length=77,
+         )
+
+         outputs = self.t5_encoder.run(
+             ["hidden_states"],
+             {
+                 "input_ids": tokens["input_ids"].astype(np.int64),
+                 "attention_mask": tokens["attention_mask"].astype(np.int64),
+             },
+         )
+
+         return outputs[0], tokens["attention_mask"]
+
+     def dit_step(
+         self,
+         noisy_audio: np.ndarray,
+         time: np.ndarray,
+         audio_features: np.ndarray,
+         text_features: np.ndarray,
+         text_mask: np.ndarray,
+         anchor_ids: Optional[np.ndarray] = None,
+         anchor_alignment: Optional[np.ndarray] = None,
+         audio_pad_mask: Optional[np.ndarray] = None,
+     ) -> np.ndarray:
+         """
+         Run one step of the DiT denoiser.
+
+         Args:
+             noisy_audio: Current noisy latent, shape (batch, seq_len, latent_dim*2)
+             time: Current time step, shape (batch,)
+             audio_features: Encoded audio features
+             text_features: Encoded text features
+             text_mask: Text attention mask
+             anchor_ids: Optional anchor IDs
+             anchor_alignment: Optional anchor alignment
+             audio_pad_mask: Optional audio padding mask
+
+         Returns:
+             Velocity prediction for the ODE step
+         """
+         batch_size, seq_len = noisy_audio.shape[:2]
+
+         # Create default values for optional inputs
+         if anchor_ids is None:
+             anchor_ids = np.zeros((batch_size, seq_len), dtype=np.int64)
+         if anchor_alignment is None:
+             anchor_alignment = np.zeros((batch_size, seq_len), dtype=np.int64)
+         if audio_pad_mask is None:
+             audio_pad_mask = np.ones((batch_size, seq_len), dtype=bool)
+
+         # Video features are zeros for audio-only inference
+         vision_dim = 1024
+         masked_video_features = np.zeros(
+             (batch_size, vision_dim, seq_len), dtype=np.float32
+         )
+
+         outputs = self.dit.run(
+             ["velocity"],
+             {
+                 "noisy_audio": noisy_audio.astype(np.float32),
+                 "time": time.astype(np.float32),
+                 "audio_features": audio_features.astype(np.float32),
+                 "text_features": text_features.astype(np.float32),
+                 "text_mask": text_mask.astype(bool),
+                 "masked_video_features": masked_video_features,
+                 "anchor_ids": anchor_ids,
+                 "anchor_alignment": anchor_alignment,
+                 "audio_pad_mask": audio_pad_mask,
+             },
+         )
+         return outputs[0]
+
+     def ode_solve_midpoint(
+         self,
+         initial: np.ndarray,
+         audio_features: np.ndarray,
+         text_features: np.ndarray,
+         text_mask: np.ndarray,
+     ) -> np.ndarray:
+         """
+         Solve the ODE using the midpoint method.
+
+         This implements the same midpoint solver as the PyTorch version,
+         unrolled for ONNX Runtime inference.
+
+         Args:
+             initial: Initial noisy latent (standard Gaussian noise in this
+                 pipeline; see separate())
+             audio_features: Encoded audio features
+             text_features: Encoded text features
+             text_mask: Text attention mask
+
+         Returns:
+             Final denoised latent
+         """
+         dt = self.step_size
+         x = initial.copy()
+
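+         # One RK2 (midpoint) step per iteration:
+         #     x <- x + dt * f(t + dt/2, x + (dt/2) * f(t, x))
+         # where f is the DiT velocity field evaluated by dit_step().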
+         for i in range(self.num_ode_steps):
+             t = np.array([i * dt], dtype=np.float32)
+             t_mid = np.array([t[0] + dt / 2], dtype=np.float32)
+
+             # Midpoint method: k1 = f(t, x)
+             k1 = self.dit_step(x, t, audio_features, text_features, text_mask)
+
+             # Midpoint: x_mid = x + dt/2 * k1
+             x_mid = x + (dt / 2) * k1
+
+             # k2 = f(t + dt/2, x_mid)
+             k2 = self.dit_step(x_mid, t_mid, audio_features, text_features, text_mask)
+
+             # Update: x = x + dt * k2
+             x = x + dt * k2
+
+             print(f" ODE step {i+1}/{self.num_ode_steps}")
+
+         return x
+
+     def separate(
+         self,
+         audio: np.ndarray,
+         text: str,
+         sample_rate: int = 44100,
+     ) -> np.ndarray:
+         """
+         Perform audio source separation.
+
+         Args:
+             audio: Input audio waveform at 44.1 kHz
+             text: Text description of the source to separate
+             sample_rate: Sample rate of the input audio (currently unused)
+
+         Returns:
+             Separated audio waveform
+         """
+         print(f"\nSeparating: '{text}'")
+
+         # 1. Encode audio to latent space
+         print("1. Encoding audio...")
+         audio_latent = self.encode_audio(audio)
+         print(f" Audio latent shape: {audio_latent.shape}")
+
+         # 2. Encode text
+         print("2. Encoding text...")
+         text_features, text_mask = self.encode_text(text)
+         print(f" Text features shape: {text_features.shape}")
+
+         # 3. Prepare initial state and audio features.
+         # SAMAudio._get_audio_features returns
+         # torch.cat([audio_features, audio_features], dim=2).
+         batch_size, latent_dim, time_steps = audio_latent.shape
+         mixture_features = audio_latent.transpose(0, 2, 1)  # (B, T, C=128)
+
+         # The audio features are the mixture latent duplicated
+         # (not [mixture, zeros]!)
+         audio_features = np.concatenate([
+             mixture_features,  # Mixture latent
+             mixture_features,  # Mixture latent (duplicate)
+         ], axis=-1)  # -> (B, T, 256)
+
+         # Initial state is random noise for ODE solving from t=0 to t=1
+         initial = np.random.randn(batch_size, time_steps, latent_dim * 2).astype(np.float32)
+
+         # 4. Run ODE solver
+         print("3. Running ODE solver...")
+         result = self.ode_solve_midpoint(
+             initial, audio_features, text_features, text_mask
+         )
+
+         # 5. Extract the separated (target) latent.
+         # SAMAudio convention: the target occupies the first latent_dim
+         # channels, the residual the second latent_dim channels.
+         separated_latent = result[:, :, :latent_dim].transpose(0, 2, 1)  # (B, C, T)
+         print(f" Separated latent shape: {separated_latent.shape}")
+
+         # 6. Decode to waveform
+         print("4. Decoding audio...")
+         separated_audio = self.decode_audio(separated_latent)
+         print(f" Output audio shape: {separated_audio.shape}")
+
+         return separated_audio
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="SAM Audio ONNX Runtime Inference"
+     )
+     parser.add_argument(
+         "--audio",
+         type=str,
+         required=True,
+         help="Path to input audio file",
+     )
+     parser.add_argument(
+         "--text",
+         type=str,
+         required=True,
+         help="Text description of the source to separate",
+     )
+     parser.add_argument(
+         "--output",
+         type=str,
+         default="separated.wav",
+         help="Path for output audio file",
+     )
+     parser.add_argument(
+         "--model-dir",
+         type=str,
+         default=".",
+         help="Directory containing ONNX models",
+     )
+     parser.add_argument(
+         "--device",
+         type=str,
+         default="cpu",
+         choices=["cpu", "cuda"],
+         help="Device to use for inference",
+     )
+     parser.add_argument(
+         "--ode-steps",
+         type=int,
+         default=16,
+         help="Number of ODE solver steps",
+     )
+
+     args = parser.parse_args()
+
+     # Load pipeline
+     pipeline = SAMAudioONNXPipeline(
+         model_dir=args.model_dir,
+         device=args.device,
+         num_ode_steps=args.ode_steps,
+     )
+
+     # Load input audio
+     print(f"\nLoading audio: {args.audio}")
+     audio = load_audio(args.audio, target_sr=44100)
+     print(f"Audio duration: {len(audio) / 44100:.2f} seconds")
+
+     # Run separation
+     separated = pipeline.separate(audio, args.text)
+
+     # Save output
+     save_audio(separated, args.output, sample_rate=44100)
+     print(f"\n✓ Done! Separated audio saved to {args.output}")
+
+
+ if __name__ == "__main__":
+     main()
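
For reference, a minimal programmatic usage sketch built only from the API
defined in this file (the model directory and audio file names here are
hypothetical placeholders):

    from onnx_inference import SAMAudioONNXPipeline, load_audio, save_audio

    # Load the four ONNX models plus the tokenizer from ./models
    pipeline = SAMAudioONNXPipeline(model_dir="models", num_ode_steps=16)

    # 44.1 kHz mono mixture in, text-prompted separation out
    mixture = load_audio("mixture.wav", target_sr=44100)
    voice = pipeline.separate(mixture, "a person speaking")
    save_audio(voice, "voice.wav", sample_rate=44100)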