matbee committed
Commit 61531ed · verified · 1 Parent(s): 0abf616

Upload onnx_inference.py with huggingface_hub

Files changed (1)
onnx_inference.py +325 -205
onnx_inference.py CHANGED
@@ -17,26 +17,78 @@ import json
 from typing import Optional
 
 
-def load_audio(path: str, target_sr: int = 44100) -> np.ndarray:
-    """Load audio file and resample to target sample rate."""
+def load_audio(path: str, target_sr: int = 48000) -> np.ndarray:
+    """Load audio file and resample to target sample rate. Supports video files via torchaudio/librosa."""
+    # Try torchaudio first as it handles video files well
     try:
-        import librosa
-        audio, sr = librosa.load(path, sr=target_sr, mono=True)
-        return audio.astype(np.float32)
-    except ImportError:
-        raise ImportError("Please install librosa: pip install librosa")
+        import torchaudio
+        import torch
+        wav, sr = torchaudio.load(path)
+        if wav.shape[0] > 1:
+            wav = wav.mean(0, keepdim=True)  # Downmix to mono
+        if sr != target_sr:
+            resampler = torchaudio.transforms.Resample(sr, target_sr)
+            wav = resampler(wav)
+        return wav.squeeze().numpy().astype(np.float32)
+    except Exception:
+        # Fall back to librosa
+        try:
+            import librosa
+            audio, sr = librosa.load(path, sr=target_sr, mono=True)
+            return audio.astype(np.float32)
+        except ImportError:
+            raise ImportError("Please install torchaudio or librosa: pip install torchaudio librosa")
+        except Exception as e2:
+            raise RuntimeError(f"Failed to load audio from {path}: {e2}")
 
 
-def save_audio(audio: np.ndarray, path: str, sample_rate: int = 44100):
+def save_audio(audio: np.ndarray, path: str, sample_rate: int = 48000):
     """Save audio to WAV file."""
     try:
         import soundfile as sf
+        # Ensure audio is 1D for mono output
+        if audio.ndim > 1:
+            audio = audio.flatten()
         sf.write(path, audio, sample_rate)
         print(f"Saved audio to {path}")
     except ImportError:
         raise ImportError("Please install soundfile: pip install soundfile")
 
 
+def save_video_with_audio(frames: np.ndarray, audio: np.ndarray, path: str, sample_rate: int = 48000, fps: float = 24.0):
+    """Save masked video frames and separated audio to a movie file."""
+    try:
+        import torch
+        import torchvision
+
+        # load_video_frames returns frames in [-1, 1]; convert to [0, 255]
+        frames_uint8 = ((frames * 0.5 + 0.5) * 255).astype(np.uint8)
+
+        # torchvision.io.write_video expects [T, H, W, C]
+        video_tensor = torch.from_numpy(frames_uint8).permute(0, 2, 3, 1)
+
+        # Prepare audio as [channels, samples]
+        if audio.ndim == 1:
+            audio = audio[None, :]
+        audio_tensor = torch.from_numpy(audio)
+
+        print(f"Saving merged video to {path}...")
+        torchvision.io.write_video(
+            path,
+            video_tensor,
+            fps=fps,
+            video_codec="libx264",
+            audio_array=audio_tensor,
+            audio_fps=sample_rate,
+            audio_codec="aac",
+        )
+        print(f"  ✓ Video saved to {path}")
+    except Exception as e:
+        print(f"Warning: Failed to save video: {e}")
+
+
 class SAMAudioONNXPipeline:
     """
    ONNX-based SAM Audio inference pipeline.
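
Note: a minimal round-trip of the two I/O helpers above, assuming they are in scope and that a 48 kHz-compatible file named mix.wav exists (the filename is illustrative):

audio = load_audio("mix.wav", target_sr=48000)   # mono float32 array, shape (num_samples,)
print(audio.shape, audio.dtype)
save_audio(audio, "mix_copy.wav", sample_rate=48000)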
@@ -46,7 +98,7 @@ class SAMAudioONNXPipeline:
 
     def __init__(
         self,
-        model_dir: str = ".",
+        model_dir: str = "onnx_models",
         device: str = "cpu",
         num_ode_steps: int = 16,
     ):
@@ -89,6 +141,16 @@
         )
         print("  ✓ DiT denoiser loaded")
 
+        # Load Vision Encoder if available
+        self.vision_encoder = None
+        vision_path = os.path.join(model_dir, "vision_encoder.onnx")
+        if os.path.exists(vision_path):
+            self.vision_encoder = ort.InferenceSession(
+                vision_path,
+                providers=providers,
+            )
+            print("  ✓ Vision encoder loaded")
+
         # Load tokenizer
         self._load_tokenizer()
         print("  ✓ Tokenizer loaded")
@@ -96,17 +158,120 @@
         print("All models loaded!")
 
     def _load_tokenizer(self):
-        """Load the T5 tokenizer."""
-        from transformers import AutoTokenizer
-
-        tokenizer_path = os.path.join(self.model_dir, "tokenizer")
-        if os.path.exists(tokenizer_path):
-            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
-        else:
-            # Fall back to loading from HuggingFace
-            with open(os.path.join(self.model_dir, "tokenizer_config.json")) as f:
-                config = json.load(f)
-            self.tokenizer = AutoTokenizer.from_pretrained(config.get("model_name", "google-t5/t5-base"))
+        """
+        Load the T5 tokenizer using SentencePiece.
+        This avoids the dependency on the 'transformers' library.
+        """
+        try:
+            import sentencepiece as spm
+        except ImportError:
+            raise ImportError("Please install sentencepiece: pip install sentencepiece")
+
+        # Load the sentencepiece model file
+        sp_path = os.path.join(self.model_dir, "tokenizer", "spiece.model")
+        if not os.path.exists(sp_path):
+            sp_path = os.path.join(self.model_dir, "spiece.model")
+
+        if not os.path.exists(sp_path):
+            raise FileNotFoundError(f"SentencePiece model not found at {sp_path}")
+
+        # Create a T5-compatible tokenizer wrapper
+        class T5ONNXTokenizer:
+            def __init__(self, sp_path):
+                self.sp = spm.SentencePieceProcessor()
+                self.sp.load(sp_path)
+
+            def encode(self, text: str) -> np.ndarray:
+                ids = self.sp.encode(text)
+                if len(ids) > 0 and ids[-1] != 1:  # Ensure trailing </s> (ID 1)
+                    ids.append(1)
+                elif len(ids) == 0:
+                    ids = [1]
+                return np.array(ids, dtype=np.int64).reshape(1, -1)
+
+            def decode(self, tokens: np.ndarray) -> str:
+                if tokens.ndim > 1:
+                    tokens = tokens.flatten()
+                return self.sp.decode(tokens.tolist())
+
+        self.tokenizer = T5ONNXTokenizer(sp_path)
+
+    def load_video_frames(self, path: str, num_steps: int, mask_path: Optional[str] = None) -> tuple[np.ndarray, np.ndarray, float]:
+        """
+        Load video frames and align them to audio latent steps.
+        Optionally applies a binary mask for visual prompting.
+        Returns (normalized_frames, visual_frames, fps).
+        """
+        try:
+            from torchcodec.decoders import VideoDecoder
+            import torch
+            import torch.nn.functional as F
+        except ImportError:
+            raise ImportError("Please install torchcodec and torch: pip install torchcodec torch")
+
+        decoder = VideoDecoder(path, dimension_order="NCHW")
+        all_data = decoder.get_frames_in_range(0, len(decoder))
+
+        # Audio feature steps are aligned to timestamps:
+        # the SAM Audio DACVAE runs at 48 kHz with hop_length samples per latent step
+        hop_length = 1536
+        sample_rate = 48000
+        step_timestamps = np.arange(num_steps) * hop_length / sample_rate
+
+        # Get the actual video framerate
+        metadata = decoder.metadata
+        fps = metadata.average_fps if metadata.average_fps is not None else 24.0
+
+        # Find the nearest frame for each latent step
+        diffs = np.abs(all_data.pts_seconds.numpy()[:, None] - step_timestamps[None, :])
+        frame_idxs = np.argmin(diffs, axis=0)
+
+        frames = all_data.data[frame_idxs]  # [num_steps, 3, H, W]
+
+        # Apply mask if provided (SAM3-style masking)
+        if mask_path:
+            print(f"  Applying mask from {mask_path}...")
+            mask_decoder = VideoDecoder(mask_path, dimension_order="NCHW")
+            mask_data = mask_decoder.get_frames_in_range(0, len(mask_decoder))
+
+            # Align mask frames the same way as video frames
+            m_diffs = np.abs(mask_data.pts_seconds.numpy()[:, None] - step_timestamps[None, :])
+            m_frame_idxs = np.argmin(m_diffs, axis=0)
+            masks = mask_data.data[m_frame_idxs]  # [num_steps, C, H, W]
+
+            # Binarize the mask (mean pixel value above 128 counts as masked)
+            # In SAM Audio, masking means zeroing out the object: v * (mask == 0)
+            binary_mask = (masks.float().mean(dim=1, keepdim=True) > 128).float()
+            frames = frames.float() * (1.0 - binary_mask)
+
+        # Resize and normalize as per PerceptionEncoder
+        image_size = 336
+        frames_resized = F.interpolate(frames.float(), size=(image_size, image_size), mode="bicubic")
+        frames_norm = (frames_resized / 255.0 - 0.5) / 0.5
+
+        return frames_norm.numpy(), frames_norm.numpy(), fps
+
+    def encode_video(self, frames: np.ndarray) -> np.ndarray:
+        """Run the vision encoder over the prepared frames."""
+        if self.vision_encoder is None:
+            raise RuntimeError("Vision encoder model not loaded")
+
+        # The vision encoder may have a batch size of 1 hardcoded from export,
+        # so run it frame by frame to be safe
+        all_features = []
+        for i in range(len(frames)):
+            frame = frames[i:i+1]  # [1, 3, H, W]
+            outputs = self.vision_encoder.run(
+                ["vision_features"],
+                {"video_frames": frame},
+            )
+            all_features.append(outputs[0])  # [1, 1024]
+
+        features = np.concatenate(all_features, axis=0)  # [N, 1024]
+
+        # DiT expects (B, 1024, T)
+        return features.transpose(1, 0)[None, :, :]
+
 
     def encode_audio(self, audio: np.ndarray) -> np.ndarray:
         """
@@ -186,189 +351,141 @@
         Returns:
             Tuple of (hidden_states, attention_mask)
         """
-        tokens = self.tokenizer(
-            text,
-            return_tensors="np",
-            padding=True,
-            truncation=True,
-            max_length=77,
-        )
+        input_ids = self.tokenizer.encode(text)
+        attention_mask = np.ones_like(input_ids)
 
         outputs = self.t5_encoder.run(
             ["hidden_states"],
             {
-                "input_ids": tokens["input_ids"].astype(np.int64),
-                "attention_mask": tokens["attention_mask"].astype(np.int64),
+                "input_ids": input_ids.astype(np.int64),
+                "attention_mask": attention_mask.astype(np.int64),
             },
         )
 
-        return outputs[0], tokens["attention_mask"]
+        return outputs[0], attention_mask
 
     def dit_step(
         self,
         noisy_audio: np.ndarray,
-        time: np.ndarray,
+        time: float,
         audio_features: np.ndarray,
         text_features: np.ndarray,
         text_mask: np.ndarray,
-        anchor_ids: Optional[np.ndarray] = None,
-        anchor_alignment: Optional[np.ndarray] = None,
-        audio_pad_mask: Optional[np.ndarray] = None,
+        masked_video_features: Optional[np.ndarray] = None,
     ) -> np.ndarray:
-        """
-        Run one step of the DiT denoiser.
-
-        Args:
-            noisy_audio: Current noisy latent, shape (batch, seq_len, latent_dim*2)
-            time: Current time step, shape (batch,)
-            audio_features: Encoded audio features
-            text_features: Encoded text features
-            text_mask: Text attention mask
-            anchor_ids: Optional anchor IDs
-            anchor_alignment: Optional anchor alignment
-            audio_pad_mask: Optional audio padding mask
-
-        Returns:
-            Velocity prediction for ODE step
-        """
-        batch_size, seq_len = noisy_audio.shape[:2]
-
-        # Create default values for optional inputs
-        if anchor_ids is None:
-            anchor_ids = np.zeros((batch_size, seq_len), dtype=np.int64)
-        if anchor_alignment is None:
-            anchor_alignment = np.zeros((batch_size, seq_len), dtype=np.int64)
-        if audio_pad_mask is None:
-            audio_pad_mask = np.ones((batch_size, seq_len), dtype=bool)
-
-        # Video features are zeros for audio-only inference
-        vision_dim = 1024
-        masked_video_features = np.zeros(
-            (batch_size, vision_dim, seq_len), dtype=np.float32
-        )
-
-        outputs = self.dit.run(
-            ["velocity"],
-            {
-                "noisy_audio": noisy_audio.astype(np.float32),
-                "time": time.astype(np.float32),
-                "audio_features": audio_features.astype(np.float32),
-                "text_features": text_features.astype(np.float32),
-                "text_mask": text_mask.astype(bool),
-                "masked_video_features": masked_video_features,
-                "anchor_ids": anchor_ids,
-                "anchor_alignment": anchor_alignment,
-                "audio_pad_mask": audio_pad_mask,
-            },
-        )
+        """Run a single DiT denoiser step."""
+        batch_size = noisy_audio.shape[0]
+        seq_len = noisy_audio.shape[1]
+
+        # Prepare placeholders for anchors if not used
+        # anchor_ids: <null>=0, <pad>=3. [B, 2]
+        anchor_ids = np.zeros((batch_size, 2), dtype=np.int64)
+        anchor_ids[:, 1] = 3
+
+        # anchor_alignment: 0 for active, 1 for pad. [B, T]
+        anchor_alignment = np.zeros((batch_size, seq_len), dtype=np.int64)
+
+        # audio_pad_mask: True for valid, False for pad. [B, T]
+        audio_pad_mask = np.ones((batch_size, seq_len), dtype=np.bool_)
+
+        # Video features placeholder if not provided
+        if masked_video_features is None:
+            vision_dim = 1024  # Vision dimension for the small model
+            masked_video_features = np.zeros((batch_size, vision_dim, seq_len), dtype=np.float32)
+
+        inputs = {
+            "noisy_audio": noisy_audio.astype(np.float32),
+            "time": np.array([time], dtype=np.float32),
+            "audio_features": audio_features.astype(np.float32),
+            "text_features": text_features.astype(np.float32),
+            "text_mask": text_mask.astype(np.bool_),
+            "masked_video_features": masked_video_features.astype(np.float32),
+            "anchor_ids": anchor_ids.astype(np.int64),
+            "anchor_alignment": anchor_alignment.astype(np.int64),
+            "audio_pad_mask": audio_pad_mask.astype(np.bool_),
+        }
+
+        outputs = self.dit.run(None, inputs)
         return outputs[0]
-
-    def ode_solve_midpoint(
-        self,
-        initial: np.ndarray,
-        audio_features: np.ndarray,
-        text_features: np.ndarray,
-        text_mask: np.ndarray,
-    ) -> np.ndarray:
-        """
-        Solve the ODE using the midpoint method.
-
-        This implements the same midpoint solver as the PyTorch version,
-        unrolled for ONNX Runtime inference.
-
-        Args:
-            initial: Initial noisy latent (usually zeros or noise)
-            audio_features: Encoded audio features
-            text_features: Encoded text features
-            text_mask: Text attention mask
-
-        Returns:
-            Final denoised latent
-        """
-        dt = self.step_size
-        x = initial.copy()
-
-        for i in range(self.num_ode_steps):
-            t = np.array([i * dt], dtype=np.float32)
-            t_mid = np.array([t[0] + dt / 2], dtype=np.float32)
-
-            # Midpoint method: k1 = f(t, x)
-            k1 = self.dit_step(x, t, audio_features, text_features, text_mask)
-
-            # Midpoint: x_mid = x + dt/2 * k1
-            x_mid = x + (dt / 2) * k1
-
-            # k2 = f(t + dt/2, x_mid)
-            k2 = self.dit_step(x_mid, t_mid, audio_features, text_features, text_mask)
-
-            # Update: x = x + dt * k2
-            x = x + dt * k2
-
-            print(f"  ODE step {i+1}/{self.num_ode_steps}")
-
-        return x
+
 
     def separate(
         self,
         audio: np.ndarray,
         text: str,
-        sample_rate: int = 44100,
-    ) -> np.ndarray:
+        video_path: Optional[str] = None,
+        mask_path: Optional[str] = None,
+    ) -> tuple[np.ndarray, Optional[np.ndarray], float]:
         """
-        Perform audio source separation.
+        Perform the full separation pipeline.
 
         Args:
-            audio: Input audio waveform at 44.1kHz
-            text: Text description of the source to separate
-            sample_rate: Sample rate of input audio
+            audio: Input mixture waveform
+            text: Text description of the target source
+            video_path: Optional path to a video for visual conditioning
+            mask_path: Optional path to a video/image mask for visual prompting
 
         Returns:
-            Separated audio waveform
+            Tuple of (separated source waveform, masked video frames if any, fps)
         """
-        print(f"\nSeparating: '{text}'")
-
-        # 1. Encode audio to latent space
+        # 1. Encode audio to latents
         print("1. Encoding audio...")
-        audio_latent = self.encode_audio(audio)
-        print(f"   Audio latent shape: {audio_latent.shape}")
+        latent_features = self.encode_audio(audio)
+        # latent_features is (B, 128, T); the DiT expects (B, T, 128)
+        latent_features = latent_features.transpose(0, 2, 1)
+
+        # Mixture features are duplicated (mixture, mixture) for conditioning,
+        # matching SAMAudio._get_audio_features: torch.cat([audio_features, audio_features], dim=2)
+        audio_features = np.concatenate([latent_features, latent_features], axis=2)
+        print(f"   Audio latent shape: {latent_features.shape}")
 
-        # 2. Encode text
+        # 2. Encode text to features
         print("2. Encoding text...")
         text_features, text_mask = self.encode_text(text)
         print(f"   Text features shape: {text_features.shape}")
 
-        # 3. Prepare initial state and audio features
-        # SAMAudio._get_audio_features: returns torch.cat([audio_features, audio_features], dim=2)
-        batch_size, latent_dim, time_steps = audio_latent.shape
-        mixture_features = audio_latent.transpose(0, 2, 1)  # (B, T, C=128)
-
-        # Audio features is the mixture DUPLICATED (not [mixture, zeros]!)
-        audio_features = np.concatenate([
-            mixture_features,  # Mixture latent
-            mixture_features   # Mixture latent (duplicate)
-        ], axis=-1)  # -> (B, T, 256)
-
-        # Initial state is random noise for ODE solving from t=0 to t=1
-        initial = np.random.randn(batch_size, time_steps, latent_dim * 2).astype(np.float32)
-
-        # 4. Run ODE solver
-        print("3. Running ODE solver...")
-        result = self.ode_solve_midpoint(
-            initial, audio_features, text_features, text_mask
-        )
-
-        # 5. Extract separated audio latent
-        # SAMAudio: target is first 128 dims, residual is second 128 dims
-        target_latent = result[:, :, :latent_dim].transpose(0, 2, 1)  # (B, C, T) - TARGET
-        separated_latent = target_latent
-        print(f"   Separated latent shape: {separated_latent.shape}")
+        # 3. Encode video if provided
+        masked_video_features = None
+        visual_frames = None
+        fps = 24.0
+        if video_path and self.vision_encoder:
+            print("3a. Loading and encoding video...")
+            norm_frames, visual_frames, fps = self.load_video_frames(video_path, latent_features.shape[1], mask_path)
+            masked_video_features = self.encode_video(norm_frames)  # Returns [B, 1024, T]
+            print(f"   Video features shape: {masked_video_features.shape}")
+
+        # 4. Run ODE solver (midpoint method), starting from random noise
+        # Note: audio_features is [B, T, 256], and the DiT output has the same shape
+        print("3. Running ODE solver...")
+        B, T, C = audio_features.shape
+        x = np.random.randn(B, T, C).astype(np.float32)
+
+        steps = self.num_ode_steps
+        dt = 1.0 / steps
+
+        for i in range(steps):
+            t = i * dt
+            print(f"   ODE step {i+1}/{steps}", end="\r")
+
+            k1 = self.dit_step(x, t, audio_features, text_features, text_mask, masked_video_features)
+            x_mid = x + k1 * (dt / 2.0)
+            k2 = self.dit_step(x_mid, t + dt / 2.0, audio_features, text_features, text_mask, masked_video_features)
+
+            x = x + k2 * dt
 
+        # 5. Extract the target source (first 128 latent dimensions)
+        # The DiT produces [B, T, 256]; keep [B, T, 128] and transpose back to [B, 128, T]
+        separated_latent = x[:, :, :128].transpose(0, 2, 1)
+        print(f"\n   Separated latent shape: {separated_latent.shape}")
+
         # 6. Decode to waveform
         print("4. Decoding audio...")
         separated_audio = self.decode_audio(separated_latent)
         print(f"   Output audio shape: {separated_audio.shape}")
 
-        return separated_audio
+        return separated_audio, visual_frames, fps
 
 
 def main():
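
Note: the loop inside separate() is a standard explicit midpoint integrator for dx/dt = v(x, t) over t in [0, 1], replacing the deleted ode_solve_midpoint helper. A self-contained sketch of the same update rule on a scalar ODE with a known solution:

def midpoint_solve(f, x0, steps=16):
    # Integrate dx/dt = f(x, t) from t=0 to t=1 with the explicit midpoint rule
    x, dt = x0, 1.0 / steps
    for i in range(steps):
        t = i * dt
        k1 = f(x, t)                 # slope at the interval start
        x_mid = x + k1 * (dt / 2.0)  # half step using k1
        k2 = f(x_mid, t + dt / 2.0)  # slope at the midpoint
        x = x + k2 * dt              # full step using the midpoint slope
    return x

# dx/dt = x with x(0) = 1 has the exact solution x(1) = e
print(midpoint_solve(lambda x, t: x, 1.0))  # ≈ 2.7169, close to e ≈ 2.71828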
@@ -378,61 +495,64 @@ def main():
     parser.add_argument(
         "--audio",
         type=str,
-        required=True,
-        help="Path to input audio file",
-    )
-    parser.add_argument(
-        "--text",
-        type=str,
-        required=True,
-        help="Text description of the source to separate",
-    )
-    parser.add_argument(
-        "--output",
-        type=str,
-        default="separated.wav",
-        help="Path for output audio file",
-    )
-    parser.add_argument(
-        "--model-dir",
-        type=str,
-        default=".",
-        help="Directory containing ONNX models",
-    )
-    parser.add_argument(
-        "--device",
-        type=str,
-        default="cpu",
-        choices=["cpu", "cuda"],
-        help="Device to use for inference",
-    )
-    parser.add_argument(
-        "--ode-steps",
-        type=int,
-        default=16,
-        help="Number of ODE solver steps",
+        help="Path to input audio file (optional if --video is provided)",
     )
+    parser.add_argument("--text", type=str, default="", help="Text description of the target source (optional if --video is provided)")
+    parser.add_argument("--video", type=str, help="Optional path to video file for conditional separation")
+    parser.add_argument("--mask", type=str, help="Optional path to mask file (visual prompting)")
+    parser.add_argument("--output", type=str, default="separated.wav", help="Output WAV file path")
+    parser.add_argument("--output-video", type=str, help="Optional path to save masked video with separated audio")
+    parser.add_argument("--model-dir", type=str, default="onnx_models", help="Directory containing ONNX models")
+    parser.add_argument("--steps", type=int, default=16, help="Number of ODE solver steps")
+    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"], help="Inference device")
 
     args = parser.parse_args()
 
-    # Load pipeline
+    # 0. Initialize pipeline
     pipeline = SAMAudioONNXPipeline(
         model_dir=args.model_dir,
         device=args.device,
-        num_ode_steps=args.ode_steps,
+        num_ode_steps=args.steps,
     )
 
-    # Load input audio
-    print(f"\nLoading audio: {args.audio}")
-    audio = load_audio(args.audio, target_sr=44100)
-    print(f"Audio duration: {len(audio) / 44100:.2f} seconds")
+    # 1. Resolve audio/video inputs
+    if not args.audio and not args.video:
+        parser.error("At least one of --audio or --video must be provided.")
 
-    # Run separation
-    separated = pipeline.separate(audio, args.text)
+    # Text is only optional when a video provides the conditioning
+    if not args.text and not args.video:
+        parser.error("--text is required for audio-only separation.")
+
+    audio_path = args.audio if args.audio else args.video
+
+    # 2. Load audio
+    print(f"\nLoading audio from: {audio_path}")
+    audio = load_audio(audio_path, target_sr=48000)
+    print(f"Audio duration: {len(audio) / 48000:.2f} seconds")
 
-    # Save output
-    save_audio(separated, args.output, sample_rate=44100)
-    print(f"\n✓ Done! Separated audio saved to {args.output}")
+    # 3. Run separation
+    try:
+        separated_audio, masked_frames, fps = pipeline.separate(
+            audio,
+            args.text,
+            video_path=args.video if args.video else None,
+            mask_path=args.mask,
+        )
+
+        # Save output audio
+        save_audio(separated_audio, args.output, sample_rate=48000)
+
+        # Save output video if requested
+        if args.output_video and masked_frames is not None:
+            save_video_with_audio(masked_frames, separated_audio, args.output_video, sample_rate=48000, fps=fps)
+
+        print(f"\n✓ Done! Separated audio saved to {args.output}")
+
+    except Exception as e:
+        print(f"\nError during separation: {e}")
+        import traceback
+        traceback.print_exc()
 
 
 if __name__ == "__main__":
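
Note: taken together, the updated script supports text-only, video-conditioned, and mask-prompted separation. A sketch of the equivalent Python-level usage, with illustrative file names:

pipeline = SAMAudioONNXPipeline(model_dir="onnx_models", device="cpu", num_ode_steps=16)
audio = load_audio("street.mp4", target_sr=48000)   # audio track extracted from a video
separated, frames, fps = pipeline.separate(
    audio,
    "a car horn",
    video_path="street.mp4",    # optional visual conditioning
    mask_path=None,
)
save_audio(separated, "car_horn.wav", sample_rate=48000)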
 