prismaudio-project committed
Commit 8be4220 · 1 Parent(s): 6a864cd
Files changed (1): app.py (+23 -23)

app.py CHANGED
@@ -246,30 +246,32 @@ def extract_video_frames(video_path: str):
 
 
 # ==================== Feature Extraction ====================
-
-def extract_features(clip_chunk: torch.Tensor, sync_chunk: torch.Tensor, caption: str) -> dict:
-    """Reuses globally loaded FeaturesUtils — no reload per call."""
+@spaces.GPU
+def extract_features_gpu(clip_chunk, sync_chunk, caption):
     model = _MODELS["feature_extractor"]
-    assert model is not None, "FeaturesUtils not initialized."
-
     info = {}
     with torch.no_grad():
+
+        model.t5_model.to('cuda')
         text_features = model.encode_t5_text([caption])
         info['text_features'] = text_features[0].cpu()
+        model.t5_model.to('cpu')
 
-        clip_input = torch.from_numpy(clip_chunk).unsqueeze(0)
-        video_feat, frame_embed, _, text_feat = \
-            model.encode_video_and_text_with_videoprism(clip_input, [caption])
 
-        info['global_video_features'] = torch.tensor(np.array(video_feat)).squeeze(0).cpu()
-        info['video_features'] = torch.tensor(np.array(frame_embed)).squeeze(0).cpu()
-        info['global_text_features'] = torch.tensor(np.array(text_feat)).squeeze(0).cpu()
-
-        sync_input = sync_chunk.unsqueeze(0).to(DEVICE)
+        model.synchformer.to('cuda')
+        sync_input = sync_chunk.unsqueeze(0).to('cuda')
         info['sync_features'] = model.encode_video_with_sync(sync_input)[0].cpu()
+        model.synchformer.to('cpu')
 
     return info
 
+def extract_features(clip_chunk, sync_chunk, caption):
+
+    info = extract_features_cpu(clip_chunk, sync_chunk, caption)
+
+    info.update(extract_features_gpu(clip_chunk, sync_chunk, caption))
+    return info
+
 
 # ==================== Build Meta ====================
 
@@ -288,7 +290,7 @@ def build_meta(info: dict, duration: float, caption: str):
 
 
 # ==================== Diffusion Sampling ====================
-
+@spaces.GPU
 def run_diffusion(audio_latent: torch.Tensor, meta: dict, duration: float) -> torch.Tensor:
     """Reuses globally loaded diffusion model — no reload per call."""
     from PrismAudio.inference.sampling import sample, sample_discrete_euler
@@ -296,20 +298,22 @@ def run_diffusion(audio_latent: torch.Tensor, meta: dict, duration: float) -> torch.Tensor:
 
     diffusion = _MODELS["diffusion"]
     model_config = _MODELS["model_config"]
+    device = 'cuda'
+    diffusion.to("cuda")
     assert diffusion is not None, "Diffusion model not initialized."
 
     diffusion_objective = model_config["model"]["diffusion"]["diffusion_objective"]
     latent_length = round(SAMPLE_RATE * duration / 2048)
 
     meta_on_device = {
-        k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v
+        k: v.to(device) if isinstance(v, torch.Tensor) else v
         for k, v in meta.items()
     }
     metadata = (meta_on_device,)
 
     with torch.no_grad():
         with torch.amp.autocast('cuda'):
-            conditioning = diffusion.conditioner(metadata, DEVICE)
+            conditioning = diffusion.conditioner(metadata, device)
 
     video_exist = torch.stack([item['video_exist'] for item in metadata], dim=0)
     if 'metaclip_features' in conditioning:
@@ -320,7 +324,7 @@ def run_diffusion(audio_latent: torch.Tensor, meta: dict, duration: float) -> torch.Tensor:
             diffusion.model.model.empty_sync_feat
 
     cond_inputs = diffusion.get_conditioning_inputs(conditioning)
-    noise = torch.randn([1, diffusion.io_channels, latent_length]).to(DEVICE)
+    noise = torch.randn([1, diffusion.io_channels, latent_length]).to(device)
 
     with torch.amp.autocast('cuda'):
         if diffusion_objective == "v":
@@ -339,6 +343,7 @@ def run_diffusion(audio_latent: torch.Tensor, meta: dict, duration: float) -> torch.Tensor:
     if diffusion.pretransform is not None:
         fakes = diffusion.pretransform.decode(fakes)
 
+    diffusion.to('cpu')
     return (
         fakes.to(torch.float32)
         .div(torch.max(torch.abs(fakes)))
@@ -351,14 +356,9 @@ def run_diffusion(audio_latent: torch.Tensor, meta: dict, duration: float) -> torch.Tensor:
 
 # ==================== Full Inference Pipeline ====================
 
-@spaces.GPU(duration=120)
+
 def generate_audio_core(video_file, caption):
 
-    global DEVICE
-    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
-    _MODELS["feature_extractor"].to(DEVICE)
-    _MODELS["diffusion"].to(DEVICE)
-
     total_start_time = time.time()
 
     if video_file is None:
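Note: the new `extract_features` wrapper calls `extract_features_cpu`, which is not defined anywhere in this diff. Judging from the VideoPrism lines deleted out of the old `extract_features` (none of them touched `DEVICE`), it plausibly carries that CPU-side encoding; the sketch below is a reconstruction from the removed lines, not code from the commit.

```python
# Hypothetical reconstruction of extract_features_cpu (not part of this
# commit). Assumed to hold the VideoPrism encoding that the diff removed
# from the GPU path; nothing here needs CUDA, so it can run outside
# @spaces.GPU without consuming ZeroGPU time.
import numpy as np
import torch

def extract_features_cpu(clip_chunk, sync_chunk, caption):
    model = _MODELS["feature_extractor"]  # globally loaded FeaturesUtils
    info = {}
    with torch.no_grad():
        # VideoPrism video+text encoding, as in the removed lines.
        clip_input = torch.from_numpy(clip_chunk).unsqueeze(0)
        video_feat, frame_embed, _, text_feat = \
            model.encode_video_and_text_with_videoprism(clip_input, [caption])
        info['global_video_features'] = torch.tensor(np.array(video_feat)).squeeze(0).cpu()
        info['video_features'] = torch.tensor(np.array(frame_embed)).squeeze(0).cpu()
        info['global_text_features'] = torch.tensor(np.array(text_feat)).squeeze(0).cpu()
    return info
```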
 
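More broadly, the commit trades the single `@spaces.GPU(duration=120)` reservation around `generate_audio_core` for short per-stage `@spaces.GPU` functions, each of which moves one submodule onto CUDA, runs it, and moves it back to CPU (`t5_model`, `synchformer`, and the diffusion model all get this bracketing). A generic sketch of that pattern, with a hypothetical `on_gpu` helper that is not part of the commit:

```python
from contextlib import contextmanager
import torch

@contextmanager
def on_gpu(module: torch.nn.Module, device: str = 'cuda'):
    """Temporarily host a module on the GPU, mirroring the paired
    .to('cuda') / .to('cpu') calls the diff wraps around each encoder."""
    module.to(device)
    try:
        yield module
    finally:
        module.to('cpu')           # free VRAM for the next stage
        torch.cuda.empty_cache()   # optionally release cached blocks

# Usage, paralleling extract_features_gpu:
#     with on_gpu(model.t5_model):
#         info['text_features'] = model.encode_t5_text([caption])[0].cpu()
```

Inside a `@spaces.GPU` call ZeroGPU guarantees a CUDA device, which is why `run_diffusion` can now hardcode `device = 'cuda'` and drop the old `torch.cuda.is_available()` fallback.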
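One small numeric detail in `run_diffusion`: `latent_length = round(SAMPLE_RATE * duration / 2048)` converts the requested duration into latent frames, 2048 presumably being the pretransform's samples-per-latent-frame factor. A quick sanity check, assuming `SAMPLE_RATE = 44100` (the constant is defined elsewhere in app.py and may differ):

```python
# Illustrative values only; SAMPLE_RATE is defined outside this diff.
SAMPLE_RATE = 44100
duration = 8.0

latent_length = round(SAMPLE_RATE * duration / 2048)
print(latent_length)  # 352800 / 2048 = 172.265625 -> 172

# noise is then drawn with shape [1, io_channels, 172], and
# pretransform.decode upsamples it back to roughly 8 s of audio.
```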