Spaces:
Running on Zero
Running on Zero
prismaudio-project committed on
Commit ·
8be4220
1
Parent(s): 6a864cd
fix
Browse files
app.py
CHANGED
|
@@ -246,30 +246,32 @@ def extract_video_frames(video_path: str):
|
|
| 246 |
|
| 247 |
|
| 248 |
# ==================== Feature Extraction ====================
|
| 249 |
-
|
| 250 |
-
def
|
| 251 |
-
"""Reuses globally loaded FeaturesUtils — no reload per call."""
|
| 252 |
model = _MODELS["feature_extractor"]
|
| 253 |
-
assert model is not None, "FeaturesUtils not initialized."
|
| 254 |
-
|
| 255 |
info = {}
|
| 256 |
with torch.no_grad():
|
|
|
|
|
|
|
| 257 |
text_features = model.encode_t5_text([caption])
|
| 258 |
info['text_features'] = text_features[0].cpu()
|
|
|
|
| 259 |
|
| 260 |
-
clip_input = torch.from_numpy(clip_chunk).unsqueeze(0)
|
| 261 |
-
video_feat, frame_embed, _, text_feat = \
|
| 262 |
-
model.encode_video_and_text_with_videoprism(clip_input, [caption])
|
| 263 |
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
info['global_text_features'] = torch.tensor(np.array(text_feat)).squeeze(0).cpu()
|
| 267 |
-
|
| 268 |
-
sync_input = sync_chunk.unsqueeze(0).to(DEVICE)
|
| 269 |
info['sync_features'] = model.encode_video_with_sync(sync_input)[0].cpu()
|
|
|
|
| 270 |
|
| 271 |
return info
|
| 272 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
# ==================== Build Meta ====================
|
| 275 |
|
|
@@ -288,7 +290,7 @@ def build_meta(info: dict, duration: float, caption: str):
|
|
| 288 |
|
| 289 |
|
| 290 |
# ==================== Diffusion Sampling ====================
|
| 291 |
-
|
| 292 |
def run_diffusion(audio_latent: torch.Tensor, meta: dict, duration: float) -> torch.Tensor:
|
| 293 |
"""Reuses globally loaded diffusion model — no reload per call."""
|
| 294 |
from PrismAudio.inference.sampling import sample, sample_discrete_euler
|
|
@@ -296,20 +298,22 @@ def run_diffusion(audio_latent: torch.Tensor, meta: dict, duration: float) -> to
|
|
| 296 |
|
| 297 |
diffusion = _MODELS["diffusion"]
|
| 298 |
model_config = _MODELS["model_config"]
|
|
|
|
|
|
|
| 299 |
assert diffusion is not None, "Diffusion model not initialized."
|
| 300 |
|
| 301 |
diffusion_objective = model_config["model"]["diffusion"]["diffusion_objective"]
|
| 302 |
latent_length = round(SAMPLE_RATE * duration / 2048)
|
| 303 |
|
| 304 |
meta_on_device = {
|
| 305 |
-
k: v.to(
|
| 306 |
for k, v in meta.items()
|
| 307 |
}
|
| 308 |
metadata = (meta_on_device,)
|
| 309 |
|
| 310 |
with torch.no_grad():
|
| 311 |
with torch.amp.autocast('cuda'):
|
| 312 |
-
conditioning = diffusion.conditioner(metadata,
|
| 313 |
|
| 314 |
video_exist = torch.stack([item['video_exist'] for item in metadata], dim=0)
|
| 315 |
if 'metaclip_features' in conditioning:
|
|
@@ -320,7 +324,7 @@ def run_diffusion(audio_latent: torch.Tensor, meta: dict, duration: float) -> to
|
|
| 320 |
diffusion.model.model.empty_sync_feat
|
| 321 |
|
| 322 |
cond_inputs = diffusion.get_conditioning_inputs(conditioning)
|
| 323 |
-
noise = torch.randn([1, diffusion.io_channels, latent_length]).to(
|
| 324 |
|
| 325 |
with torch.amp.autocast('cuda'):
|
| 326 |
if diffusion_objective == "v":
|
|
@@ -339,6 +343,7 @@ def run_diffusion(audio_latent: torch.Tensor, meta: dict, duration: float) -> to
|
|
| 339 |
if diffusion.pretransform is not None:
|
| 340 |
fakes = diffusion.pretransform.decode(fakes)
|
| 341 |
|
|
|
|
| 342 |
return (
|
| 343 |
fakes.to(torch.float32)
|
| 344 |
.div(torch.max(torch.abs(fakes)))
|
|
@@ -351,14 +356,9 @@ def run_diffusion(audio_latent: torch.Tensor, meta: dict, duration: float) -> to
|
|
| 351 |
|
| 352 |
# ==================== Full Inference Pipeline ====================
|
| 353 |
|
| 354 |
-
|
| 355 |
def generate_audio_core(video_file, caption):
|
| 356 |
|
| 357 |
-
global DEVICE
|
| 358 |
-
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 359 |
-
_MODELS["feature_extractor"].to(DEVICE)
|
| 360 |
-
_MODELS["diffusion"].to(DEVICE)
|
| 361 |
-
|
| 362 |
total_start_time = time.time()
|
| 363 |
|
| 364 |
if video_file is None:
|
|
|
|
| 246 |
|
| 247 |
|
| 248 |
# ==================== Feature Extraction ====================
|
| 249 |
+
@spaces.GPU
def extract_features_gpu(clip_chunk, sync_chunk, caption):
    """Compute the text and sync features that are encoded on the GPU.

    Each sub-model of the globally loaded feature extractor is moved to CUDA
    only for the duration of its own encoding step and is always moved back
    to the CPU afterwards, even when encoding raises, so a failed request
    does not leave a module on the GPU.

    Args:
        clip_chunk: Not used by this function; kept so the signature matches
            the other feature-extraction helpers.
        sync_chunk: Video input for the sync encoder. It is given a leading
            dimension via ``unsqueeze(0)`` and moved to CUDA.
            NOTE(review): presumably a ``torch.Tensor`` of frames - confirm
            against the caller.
        caption: Text prompt passed to the T5 encoder.

    Returns:
        dict: ``'text_features'`` and ``'sync_features'``, both moved to CPU.

    Raises:
        RuntimeError: If the feature extractor has not been loaded into
            ``_MODELS``.
    """
    model = _MODELS["feature_extractor"]
    if model is None:
        # Explicit raise instead of `assert`, which is stripped under -O.
        raise RuntimeError("FeaturesUtils not initialized.")

    info = {}
    with torch.no_grad():
        # --- T5 text features ---
        try:
            model.t5_model.to('cuda')
            text_features = model.encode_t5_text([caption])
            info['text_features'] = text_features[0].cpu()
        finally:
            # Always move the encoder back to CPU, success or failure.
            model.t5_model.to('cpu')

        # --- Sync features ---
        try:
            model.synchformer.to('cuda')
            sync_input = sync_chunk.unsqueeze(0).to('cuda')
            info['sync_features'] = model.encode_video_with_sync(sync_input)[0].cpu()
        finally:
            model.synchformer.to('cpu')

    return info
|
| 267 |
|
| 268 |
+
def extract_features(clip_chunk, sync_chunk, caption):
    """Collect all conditioning features for one clip into a single dict.

    Runs the CPU-side extractor first, then merges the GPU-side results into
    the same dict; keys produced by the GPU step overwrite equal keys from
    the CPU step.

    NOTE(review): ``extract_features_cpu`` is not defined in the visible part
    of the file - confirm it exists before relying on this function.

    Args:
        clip_chunk: Forwarded unchanged to both extractors.
        sync_chunk: Forwarded unchanged to both extractors.
        caption: Forwarded unchanged to both extractors.

    Returns:
        dict: The dict returned by the CPU extractor, updated in place with
        the GPU extractor's entries.
    """
    # CPU-side features come first and form the base dict.
    features = extract_features_cpu(clip_chunk, sync_chunk, caption)

    # GPU-side features are merged on top of the CPU ones.
    gpu_features = extract_features_gpu(clip_chunk, sync_chunk, caption)
    features.update(gpu_features)

    return features
|
| 274 |
+
|
| 275 |
|
| 276 |
# ==================== Build Meta ====================
|
| 277 |
|
|
|
|
| 290 |
|
| 291 |
|
| 292 |
# ==================== Diffusion Sampling ====================
|
| 293 |
+
@spaces.GPU
|
| 294 |
def run_diffusion(audio_latent: torch.Tensor, meta: dict, duration: float) -> torch.Tensor:
|
| 295 |
"""Reuses globally loaded diffusion model — no reload per call."""
|
| 296 |
from PrismAudio.inference.sampling import sample, sample_discrete_euler
|
|
|
|
| 298 |
|
| 299 |
diffusion = _MODELS["diffusion"]
|
| 300 |
model_config = _MODELS["model_config"]
|
| 301 |
+
device = 'cuda'
|
| 302 |
+
diffusion.to("cuda")
|
| 303 |
assert diffusion is not None, "Diffusion model not initialized."
|
| 304 |
|
| 305 |
diffusion_objective = model_config["model"]["diffusion"]["diffusion_objective"]
|
| 306 |
latent_length = round(SAMPLE_RATE * duration / 2048)
|
| 307 |
|
| 308 |
meta_on_device = {
|
| 309 |
+
k: v.to(device) if isinstance(v, torch.Tensor) else v
|
| 310 |
for k, v in meta.items()
|
| 311 |
}
|
| 312 |
metadata = (meta_on_device,)
|
| 313 |
|
| 314 |
with torch.no_grad():
|
| 315 |
with torch.amp.autocast('cuda'):
|
| 316 |
+
conditioning = diffusion.conditioner(metadata, device)
|
| 317 |
|
| 318 |
video_exist = torch.stack([item['video_exist'] for item in metadata], dim=0)
|
| 319 |
if 'metaclip_features' in conditioning:
|
|
|
|
| 324 |
diffusion.model.model.empty_sync_feat
|
| 325 |
|
| 326 |
cond_inputs = diffusion.get_conditioning_inputs(conditioning)
|
| 327 |
+
noise = torch.randn([1, diffusion.io_channels, latent_length]).to(device)
|
| 328 |
|
| 329 |
with torch.amp.autocast('cuda'):
|
| 330 |
if diffusion_objective == "v":
|
|
|
|
| 343 |
if diffusion.pretransform is not None:
|
| 344 |
fakes = diffusion.pretransform.decode(fakes)
|
| 345 |
|
| 346 |
+
diffusion.to('cpu')
|
| 347 |
return (
|
| 348 |
fakes.to(torch.float32)
|
| 349 |
.div(torch.max(torch.abs(fakes)))
|
|
|
|
| 356 |
|
| 357 |
# ==================== Full Inference Pipeline ====================
|
| 358 |
|
| 359 |
+
|
| 360 |
def generate_audio_core(video_file, caption):
|
| 361 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
total_start_time = time.time()
|
| 363 |
|
| 364 |
if video_file is None:
|