Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -259,6 +259,12 @@ def initialize_models():
|
|
| 259 |
clip_image_encoder=clip_image_encoder,
|
| 260 |
)
|
| 261 |
pipeline.to(device=device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
print("✅ Pipeline created and moved to device")
|
| 263 |
|
| 264 |
print("🔄 Loading Wav2Vec models...")
|
|
@@ -343,6 +349,15 @@ def generate_video(
|
|
| 343 |
audio_features = extract_audio_features(audio_path, wav2vec_processor, wav2vec_model)
|
| 344 |
audio_embeds = audio_features.unsqueeze(0).to(device=device, dtype=config.weight_dtype)
|
| 345 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
video_length = int(audio_clip.duration * fps)
|
| 347 |
video_length = (
|
| 348 |
int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1
|
|
@@ -405,7 +420,6 @@ def generate_video(
|
|
| 405 |
audio_start_frame = init_frames * 2
|
| 406 |
audio_end_frame = (init_frames + current_partial_length) * 2
|
| 407 |
|
| 408 |
-
# Ensure audio embeds are long enough
|
| 409 |
if audio_embeds.shape[1] < audio_end_frame:
|
| 410 |
repeat_times = (audio_end_frame // audio_embeds.shape[1]) + 1
|
| 411 |
audio_embeds = audio_embeds.repeat(1, repeat_times, 1)
|
|
@@ -414,9 +428,9 @@ def generate_video(
|
|
| 414 |
|
| 415 |
with torch.no_grad():
|
| 416 |
sample = pipeline(
|
| 417 |
-
|
|
|
|
| 418 |
num_frames=current_partial_length,
|
| 419 |
-
negative_prompt=negative_prompt,
|
| 420 |
audio_embeds=partial_audio_embeds,
|
| 421 |
audio_scale=audio_scale,
|
| 422 |
ip_mask=ip_mask,
|
|
|
|
| 259 |
clip_image_encoder=clip_image_encoder,
|
| 260 |
)
|
| 261 |
pipeline.to(device=device)
|
| 262 |
+
|
| 263 |
+
if torch.__version__ >= "2.0":
|
| 264 |
+
print("🔄 Compiling the pipeline with torch.compile()...")
|
| 265 |
+
pipeline.transformer = torch.compile(pipeline.transformer, mode="reduce-overhead", fullgraph=True)
|
| 266 |
+
print("✅ Pipeline transformer compiled!")
|
| 267 |
+
|
| 268 |
print("✅ Pipeline created and moved to device")
|
| 269 |
|
| 270 |
print("π Loading Wav2Vec models...")
|
|
|
|
| 349 |
audio_features = extract_audio_features(audio_path, wav2vec_processor, wav2vec_model)
|
| 350 |
audio_embeds = audio_features.unsqueeze(0).to(device=device, dtype=config.weight_dtype)
|
| 351 |
|
| 352 |
+
progress(0.25, desc="Encoding prompts...")
|
| 353 |
+
prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
|
| 354 |
+
prompt,
|
| 355 |
+
device=device,
|
| 356 |
+
num_images_per_prompt=1,
|
| 357 |
+
do_classifier_free_guidance=(guidance_scale > 1.0),
|
| 358 |
+
negative_prompt=negative_prompt
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
video_length = int(audio_clip.duration * fps)
|
| 362 |
video_length = (
|
| 363 |
int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1
|
|
|
|
| 420 |
audio_start_frame = init_frames * 2
|
| 421 |
audio_end_frame = (init_frames + current_partial_length) * 2
|
| 422 |
|
|
|
|
| 423 |
if audio_embeds.shape[1] < audio_end_frame:
|
| 424 |
repeat_times = (audio_end_frame // audio_embeds.shape[1]) + 1
|
| 425 |
audio_embeds = audio_embeds.repeat(1, repeat_times, 1)
|
|
|
|
| 428 |
|
| 429 |
with torch.no_grad():
|
| 430 |
sample = pipeline(
|
| 431 |
+
prompt_embeds=prompt_embeds,
|
| 432 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 433 |
num_frames=current_partial_length,
|
|
|
|
| 434 |
audio_embeds=partial_audio_embeds,
|
| 435 |
audio_scale=audio_scale,
|
| 436 |
ip_mask=ip_mask,
|