Vicente Alvarez commited on
Commit
dcbdf35
·
1 Parent(s): 9c32fea

Add multi-clip generation with audio looping + high res 1024x640: Generate 1-3 clips, loop to match audio duration (CPU work free)

Browse files
Files changed (2) hide show
  1. app.py +209 -85
  2. requirements.txt +2 -1
app.py CHANGED
@@ -103,7 +103,7 @@ DEFAULT_FRAME_RATE = 24.0
103
 
104
  # Resolution presets: (width, height)
105
  RESOLUTIONS = {
106
- "high": {"16:9": (1536, 1024), "9:16": (1024, 1536), "1:1": (1024, 1024)},
107
  "low": {"16:9": (512, 320), "9:16": (320, 512), "1:1": (512, 512)},
108
  }
109
 
@@ -329,6 +329,73 @@ def apply_gaussian_blur(video_tensor: torch.Tensor, blur_amount: int) -> torch.T
329
  return blurred
330
 
331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  @spaces.GPU(duration=90)
333
  @torch.inference_mode()
334
  def generate_video(
@@ -344,100 +411,152 @@ def generate_video(
344
  negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
345
  blur_amount: int = 0,
346
  remove_music: bool = False,
 
347
  progress=gr.Progress(track_tqdm=True),
348
  ):
349
  try:
350
  torch.cuda.reset_peak_memory_stats()
351
  log_memory("start")
352
 
353
- current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
354
-
355
- frame_rate = DEFAULT_FRAME_RATE
356
- num_frames = int(duration * frame_rate) + 1
357
- num_frames = ((num_frames - 1 + 7) // 8) * 8 + 1
358
-
359
- print(f"Generating: {height}x{width}, {num_frames} frames ({duration}s), seed={current_seed}")
360
-
361
- images = []
362
- output_dir = Path("outputs")
363
- output_dir.mkdir(exist_ok=True)
364
-
365
- if first_image is not None:
366
- temp_first_path = output_dir / f"temp_first_{current_seed}.jpg"
367
- if hasattr(first_image, "save"):
368
- first_image.save(temp_first_path)
369
- else:
370
- temp_first_path = Path(first_image)
371
- images.append(ImageConditioningInput(path=str(temp_first_path), frame_idx=0, strength=1.0))
372
-
373
- if last_image is not None:
374
- temp_last_path = output_dir / f"temp_last_{current_seed}.jpg"
375
- if hasattr(last_image, "save"):
376
- last_image.save(temp_last_path)
377
- else:
378
- temp_last_path = Path(last_image)
379
- images.append(ImageConditioningInput(path=str(temp_last_path), frame_idx=num_frames - 1, strength=1.0))
380
-
381
- from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
382
-
383
- tiling_config = TilingConfig.default()
384
- video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
385
-
386
- log_memory("before pipeline call")
387
-
388
- # Run inference - DistilledPipeline has simpler API
389
- video_frames_iter, audio = pipeline(
390
- prompt=prompt,
391
- seed=current_seed,
392
- height=int(height),
393
- width=int(width),
394
- num_frames=num_frames,
395
- frame_rate=frame_rate,
396
- images=images,
397
- enhance_prompt=enhance_prompt,
398
- )
399
-
400
- # Collect video frames
401
- frames = [frame for frame in video_frames_iter]
402
- video_tensor = torch.cat(frames, dim=0) if len(frames) > 1 else frames[0]
403
-
404
- log_memory("after pipeline call")
405
-
406
- # Apply Gaussian blur if requested (for censoring/teaser effect)
407
- if blur_amount > 0:
408
- print(f"Applying Gaussian blur (amount={blur_amount})...")
409
- video_tensor = apply_gaussian_blur(video_tensor, blur_amount)
410
- log_memory("after blur")
411
-
412
- output_path = tempfile.mktemp(suffix=".mp4")
413
- encode_video(
414
- video=video_tensor,
415
- fps=frame_rate,
416
- audio=audio,
417
- output_path=output_path,
418
- video_chunks_number=video_chunks_number,
419
- )
420
-
421
- log_memory("after encode_video")
422
-
423
- # Remove background music if requested
424
- if remove_music:
425
- print(f"Removing background music with Demucs...")
426
- processed_path = tempfile.mktemp(suffix=".mp4")
427
- success = remove_music_demucs(output_path, processed_path)
428
- if success:
429
- output_path = processed_path
430
- log_memory("after demucs")
431
- else:
432
- print(f"Warning: Music removal failed, using original video")
433
-
434
- return str(output_path), current_seed
 
 
 
 
 
 
 
 
 
435
 
436
  except Exception as e:
437
  import traceback
438
  log_memory("on error")
439
  print(f"Error: {str(e)}\n{traceback.format_exc()}")
440
- return None, current_seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
 
442
 
443
  with gr.Blocks(title="Element-8 Video", delete_cache=(3600, 7200)) as demo: # cleanup: check every 1h, delete files >2h old
@@ -462,6 +581,10 @@ with gr.Blocks(title="Element-8 Video", delete_cache=(3600, 7200)) as demo: # c
462
  )
463
  duration = gr.Slider(label="Duration (seconds)", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
464
 
 
 
 
 
465
  generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
466
 
467
  with gr.Accordion("Advanced Settings", open=False):
@@ -530,10 +653,11 @@ with gr.Blocks(title="Element-8 Video", delete_cache=(3600, 7200)) as demo: # c
530
  )
531
 
532
  generate_btn.click(
533
- fn=generate_video,
534
  inputs=[
535
  first_image, last_image, prompt, duration, enhance_prompt,
536
  seed, randomize_seed, height, width, negative_prompt, blur_amount, remove_music,
 
537
  ],
538
  outputs=[output_video, seed],
539
  )
 
103
 
104
  # Resolution presets: (width, height)
105
  RESOLUTIONS = {
106
+ "high": {"16:9": (1024, 640), "9:16": (640, 1024), "1:1": (1024, 1024)},
107
  "low": {"16:9": (512, 320), "9:16": (320, 512), "1:1": (512, 512)},
108
  }
109
 
 
329
  return blurred
330
 
331
 
332
+ def loop_clips_with_audio_track(clip_paths: list[str], audio_path: str) -> str:
333
+ """Loop video clips to match audio duration. CPU work - free."""
334
+ import subprocess
335
+ from pydub import AudioSegment
336
+
337
+ try:
338
+ # Get audio duration
339
+ audio = AudioSegment.from_file(audio_path)
340
+ audio_duration = len(audio) / 1000.0 # Convert to seconds
341
+
342
+ # Get total clips duration
343
+ clips_duration = 0.0
344
+ for clip in clip_paths:
345
+ probe = subprocess.run([
346
+ 'ffprobe', '-v', 'error', '-show_entries', 'format=duration',
347
+ '-of', 'default=noprint_wrappers=1:nokey=1', clip
348
+ ], capture_output=True, text=True, check=True)
349
+ clips_duration += float(probe.stdout.strip())
350
+
351
+ # Calculate loop count
352
+ loop_count = int(audio_duration / clips_duration) + 1
353
+
354
+ print(f"[loop] Audio: {audio_duration:.2f}s, Clips: {clips_duration:.2f}s, Loops: {loop_count}")
355
+
356
+ # Create concat file with loops
357
+ concat_file = tempfile.mktemp(suffix=".txt")
358
+ with open(concat_file, 'w') as f:
359
+ for _ in range(loop_count):
360
+ for clip in clip_paths:
361
+ f.write(f"file '{clip}'\n")
362
+
363
+ # Concat videos
364
+ concat_video = tempfile.mktemp(suffix=".mp4")
365
+ result = subprocess.run([
366
+ 'ffmpeg', '-y', '-f', 'concat', '-safe', '0', '-i', concat_file,
367
+ '-c', 'copy', concat_video
368
+ ], capture_output=True, text=True)
369
+
370
+ if result.returncode != 0:
371
+ raise Exception(f"Concat failed: {result.stderr[-200:]}")
372
+
373
+ # Replace audio and trim to audio duration
374
+ final_video = tempfile.mktemp(suffix=".mp4")
375
+ result = subprocess.run([
376
+ 'ffmpeg', '-y',
377
+ '-i', concat_video,
378
+ '-i', audio_path,
379
+ '-map', '0:v:0', '-map', '1:a:0',
380
+ '-c:v', 'copy', '-c:a', 'aac', '-b:a', '192k',
381
+ '-t', str(audio_duration),
382
+ '-shortest',
383
+ final_video
384
+ ], capture_output=True, text=True)
385
+
386
+ if result.returncode != 0:
387
+ raise Exception(f"Audio merge failed: {result.stderr[-200:]}")
388
+
389
+ print(f"[loop] Created looped video: {audio_duration:.2f}s")
390
+ return final_video
391
+
392
+ except Exception as e:
393
+ print(f"[loop] Error: {e}")
394
+ import traceback
395
+ traceback.print_exc()
396
+ return clip_paths[0] if clip_paths else None
397
+
398
+
399
  @spaces.GPU(duration=90)
400
  @torch.inference_mode()
401
  def generate_video(
 
411
  negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
412
  blur_amount: int = 0,
413
  remove_music: bool = False,
414
+ num_clips: int = 1,
415
  progress=gr.Progress(track_tqdm=True),
416
  ):
417
  try:
418
  torch.cuda.reset_peak_memory_stats()
419
  log_memory("start")
420
 
421
+ base_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
422
+ generated_clips = []
423
+
424
+ # Generate multiple clips in one GPU session
425
+ for clip_idx in range(num_clips):
426
+ current_seed = base_seed + clip_idx
427
+ print(f"[GPU] Generating clip {clip_idx + 1}/{num_clips}, seed={current_seed}")
428
+
429
+ frame_rate = DEFAULT_FRAME_RATE
430
+ num_frames = int(duration * frame_rate) + 1
431
+ num_frames = ((num_frames - 1 + 7) // 8) * 8 + 1
432
+
433
+ print(f"Generating: {height}x{width}, {num_frames} frames ({duration}s), seed={current_seed}")
434
+
435
+ images = []
436
+ output_dir = Path("outputs")
437
+ output_dir.mkdir(exist_ok=True)
438
+
439
+ if first_image is not None:
440
+ temp_first_path = output_dir / f"temp_first_{current_seed}.jpg"
441
+ if hasattr(first_image, "save"):
442
+ first_image.save(temp_first_path)
443
+ else:
444
+ temp_first_path = Path(first_image)
445
+ images.append(ImageConditioningInput(path=str(temp_first_path), frame_idx=0, strength=1.0))
446
+
447
+ if last_image is not None:
448
+ temp_last_path = output_dir / f"temp_last_{current_seed}.jpg"
449
+ if hasattr(last_image, "save"):
450
+ last_image.save(temp_last_path)
451
+ else:
452
+ temp_last_path = Path(last_image)
453
+ images.append(ImageConditioningInput(path=str(temp_last_path), frame_idx=num_frames - 1, strength=1.0))
454
+
455
+ from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
456
+
457
+ tiling_config = TilingConfig.default()
458
+ video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
459
+
460
+ log_memory("before pipeline call")
461
+
462
+ # Run inference - DistilledPipeline has simpler API
463
+ video_frames_iter, audio = pipeline(
464
+ prompt=prompt,
465
+ seed=current_seed,
466
+ height=int(height),
467
+ width=int(width),
468
+ num_frames=num_frames,
469
+ frame_rate=frame_rate,
470
+ images=images,
471
+ enhance_prompt=enhance_prompt,
472
+ )
473
+
474
+ # Collect video frames
475
+ frames = [frame for frame in video_frames_iter]
476
+ video_tensor = torch.cat(frames, dim=0) if len(frames) > 1 else frames[0]
477
+
478
+ log_memory("after pipeline call")
479
+
480
+ # Apply Gaussian blur if requested (for censoring/teaser effect)
481
+ if blur_amount > 0:
482
+ print(f"Applying Gaussian blur (amount={blur_amount})...")
483
+ video_tensor = apply_gaussian_blur(video_tensor, blur_amount)
484
+ log_memory("after blur")
485
+
486
+ output_path = tempfile.mktemp(suffix=".mp4")
487
+ encode_video(
488
+ video=video_tensor,
489
+ fps=frame_rate,
490
+ audio=audio,
491
+ output_path=output_path,
492
+ video_chunks_number=video_chunks_number,
493
+ )
494
+
495
+ log_memory("after encode_video")
496
+
497
+ # Remove background music if requested
498
+ if remove_music:
499
+ print(f"Removing background music with Demucs...")
500
+ processed_path = tempfile.mktemp(suffix=".mp4")
501
+ success = remove_music_demucs(output_path, processed_path)
502
+ if success:
503
+ output_path = processed_path
504
+ log_memory("after demucs")
505
+ else:
506
+ print(f"Warning: Music removal failed, using original video")
507
+
508
+ generated_clips.append(str(output_path))
509
+
510
+ # Return all generated clips
511
+ return generated_clips, base_seed
512
 
513
  except Exception as e:
514
  import traceback
515
  log_memory("on error")
516
  print(f"Error: {str(e)}\n{traceback.format_exc()}")
517
+ return [], base_seed
518
+
519
+
520
+ def full_generation_process(
521
+ first_image,
522
+ last_image,
523
+ prompt: str,
524
+ duration: float,
525
+ enhance_prompt: bool,
526
+ seed: int,
527
+ randomize_seed: bool,
528
+ height: int,
529
+ width: int,
530
+ negative_prompt: str,
531
+ blur_amount: int,
532
+ remove_music: bool,
533
+ num_clips: int,
534
+ audio_track,
535
+ progress=gr.Progress(track_tqdm=True),
536
+ ):
537
+ """Main entry point: generates clips (GPU) then optionally loops with audio (CPU)."""
538
+
539
+ # Phase 1: Generate clips (GPU time counted)
540
+ clips, final_seed = generate_video(
541
+ first_image, last_image, prompt, duration, enhance_prompt,
542
+ seed, randomize_seed, height, width, negative_prompt,
543
+ blur_amount, remove_music, num_clips, progress
544
+ )
545
+
546
+ if not clips:
547
+ return None, final_seed
548
+
549
+ # Phase 2: CPU work (free) - loop clips with audio if provided
550
+ if audio_track and num_clips > 1:
551
+ print("[CPU] Looping clips to match audio duration...")
552
+ final_video = loop_clips_with_audio_track(clips, audio_track)
553
+ return final_video, final_seed
554
+ elif num_clips == 1:
555
+ # Single clip - return it directly
556
+ return clips[0], final_seed
557
+ else:
558
+ # Multiple clips, no audio - return first clip (could be gallery in future)
559
+ return clips[0], final_seed
560
 
561
 
562
  with gr.Blocks(title="Element-8 Video", delete_cache=(3600, 7200)) as demo: # cleanup: check every 1h, delete files >2h old
 
581
  )
582
  duration = gr.Slider(label="Duration (seconds)", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
583
 
584
+ with gr.Row():
585
+ num_clips = gr.Slider(label="Number of Clips", info="Generate multiple variations", minimum=1, maximum=3, value=1, step=1)
586
+ audio_track = gr.Audio(label="Audio Track (Optional)", type="filepath", sources=["upload"])
587
+
588
  generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
589
 
590
  with gr.Accordion("Advanced Settings", open=False):
 
653
  )
654
 
655
  generate_btn.click(
656
+ fn=full_generation_process,
657
  inputs=[
658
  first_image, last_image, prompt, duration, enhance_prompt,
659
  seed, randomize_seed, height, width, negative_prompt, blur_amount, remove_music,
660
+ num_clips, audio_track,
661
  ],
662
  outputs=[output_video, seed],
663
  )
requirements.txt CHANGED
@@ -11,4 +11,5 @@ scikit-image>=0.25.2
11
  flashpack==0.1.2
12
  torchaudio==2.8.0
13
  demucs
14
- soundfile
 
 
11
  flashpack==0.1.2
12
  torchaudio==2.8.0
13
  demucs
14
+ soundfile
15
+ pydub