# CogVideoXInterp / app.py
# AhmadMustafa's picture
# add: demo
# 2fa4732
# raw
# history blame
# 7.64 kB
import os
import tempfile
import gradio as gr
import torch
from diffusers.utils import export_to_video
from PIL import Image
from cogvideox_interpolation.pipeline import CogVideoXInterpolationPipeline
# Lazily-loaded pipeline: stays None until load_model() populates it;
# generate_interpolation() refuses to run while it is None.
pipe = None
# Prefer CUDA when available; otherwise everything runs on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_model(model_path):
    """Load the CogVideoX-Interpolation pipeline into the module-global ``pipe``.

    Args:
        model_path: Local checkpoint directory or Hugging Face Hub repo id.

    Returns:
        A short status string for display in the Gradio status textbox.
    """
    global pipe
    print(f"Loading model from {model_path}...")
    print(f"Using device: {device}")
    # "5b" variants get bfloat16; every other checkpoint defaults to float16.
    is_5b = "5b" in model_path.lower()
    dtype = torch.bfloat16 if is_5b else torch.float16
    pipe = CogVideoXInterpolationPipeline.from_pretrained(model_path, torch_dtype=dtype)
    if device == "cuda":
        # Stream submodules between CPU and GPU to keep VRAM usage low.
        pipe.enable_sequential_cpu_offload()
    else:
        pipe = pipe.to(device)
    # Tiled + sliced VAE decoding bounds peak memory during video decode.
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()
    print("Model loaded successfully!")
    return "✓ Model loaded successfully!"
def generate_interpolation(
    first_image,
    last_image,
    prompt,
    num_frames=49,
    num_inference_steps=50,
    guidance_scale=6.0,
    fps=8,
    seed=42,
):
    """Generate an interpolated video between two keyframe images.

    Args:
        first_image: Start keyframe (PIL image or numpy array).
        last_image: End keyframe (PIL image or numpy array).
        prompt: Text description of the motion/transition.
        num_frames: Number of frames to generate (model expects 4k+1).
        num_inference_steps: Diffusion denoising steps.
        guidance_scale: Classifier-free guidance strength.
        fps: Frame rate of the exported mp4.
        seed: Random seed for reproducible generation.

    Returns:
        Tuple of (path to the generated mp4 or None, status message).
    """
    if pipe is None:
        return None, "⚠️ Please load the model first!"
    if first_image is None or last_image is None:
        return None, "⚠️ Please upload both start and end frame images!"
    # Guard `not prompt` first: an empty Gradio textbox can deliver None,
    # and None.strip() would raise AttributeError.
    if not prompt or not prompt.strip():
        return None, "⚠️ Please provide a text prompt describing the motion!"
    try:
        # Gradio may hand over numpy arrays; the pipeline wants PIL images.
        if not isinstance(first_image, Image.Image):
            first_image = Image.fromarray(first_image)
        if not isinstance(last_image, Image.Image):
            last_image = Image.fromarray(last_image)
        print(f"Generating video with prompt: {prompt}")
        print(
            f"Parameters: frames={num_frames}, steps={num_inference_steps}, guidance={guidance_scale}"
        )
        # gr.Number/gr.Slider values can arrive as floats; manual_seed and the
        # pipeline's integer parameters need ints.
        generator = torch.Generator(device=device).manual_seed(int(seed))
        video = pipe(
            prompt=prompt,
            first_image=first_image,
            last_image=last_image,
            num_videos_per_prompt=1,
            num_inference_steps=int(num_inference_steps),
            num_frames=int(num_frames),
            guidance_scale=guidance_scale,
            generator=generator,
        )[0]
        # Create a closed temp file so export_to_video can reopen the path
        # (required on platforms where an open NamedTemporaryFile is locked).
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        output_path = temp_file.name
        temp_file.close()
        export_to_video(video, output_path, fps=fps)
        status = f"✓ Video generated successfully! ({num_frames} frames at {fps} fps)"
        print(status)
        return output_path, status
    except Exception as e:
        # Surface any pipeline/export failure to the UI instead of crashing.
        error_msg = f"❌ Error: {str(e)}"
        print(error_msg)
        return None, error_msg
# ---------------------------------------------------------------------------
# Gradio UI, built at import time so the module-level `demo` exists for
# demo.launch() (and for hosts that import `demo` directly).
# ---------------------------------------------------------------------------
with gr.Blocks(title="CogVideoX Keyframe Interpolation") as demo:
    gr.Markdown(
        """
    # 🎬 CogVideoX Keyframe Interpolation
    Generate smooth video transitions between two keyframe images using AI.
    **Instructions:**
    1. First, load the model by providing the path to your checkpoint
    2. Upload start and end frame images
    3. Describe the motion/transition in the text prompt
    4. Adjust parameters and generate!
    """
    )
    # --- Model setup row: checkpoint path + load button + status readout ---
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🔧 Model Setup")
            model_path_input = gr.Textbox(
                label="Model Path",
                placeholder="e.g., /path/to/CogVideoX-5b-I2V-inter or feizhengcong/CogvideoX-Interpolation",
                value="feizhengcong/CogvideoX-Interpolation",
            )
            load_btn = gr.Button("Load Model", variant="primary")
            # Read-only: populated by load_model's return value.
            model_status = gr.Textbox(label="Status", interactive=False)
    gr.Markdown("---")
    # --- Inputs row: keyframe images (left) and generation settings (right) ---
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🖼️ Input Keyframes")
            # type="pil" so generate_interpolation receives PIL images directly.
            first_image_input = gr.Image(label="Start Frame", type="pil", height=300)
            last_image_input = gr.Image(label="End Frame", type="pil", height=300)
        with gr.Column():
            gr.Markdown("### ⚙️ Generation Settings")
            prompt_input = gr.Textbox(
                label="Motion Description",
                placeholder="Describe the motion/transition between the frames...",
                lines=4,
            )
            with gr.Row():
                # step=4 starting at 13 keeps values on the 4k+1 grid the model expects.
                num_frames_slider = gr.Slider(
                    label="Number of Frames",
                    minimum=13,
                    maximum=49,
                    step=4,
                    value=49,
                    info="Must be 4k+1 format (13, 17, 21, ..., 49)",
                )
                fps_slider = gr.Slider(
                    label="FPS", minimum=4, maximum=16, step=2, value=8
                )
            with gr.Row():
                num_steps_slider = gr.Slider(
                    label="Inference Steps",
                    minimum=20,
                    maximum=100,
                    step=5,
                    value=50,
                    info="More steps = better quality but slower",
                )
                guidance_slider = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=15.0,
                    step=0.5,
                    value=6.0,
                    info="Higher = stronger prompt following",
                )
            # precision=0 makes Gradio coerce the value to an integer.
            seed_input = gr.Number(label="Random Seed", value=42, precision=0)
            generate_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg")
    gr.Markdown("---")
    # --- Output row: generated video player + generation status readout ---
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🎥 Generated Video")
            output_video = gr.Video(label="Output")
            generation_status = gr.Textbox(label="Generation Status", interactive=False)
    # Examples
    gr.Markdown("---")
    gr.Markdown("### 💡 Example Prompts")
    # Clicking an example fills only the prompt textbox; images stay as-is.
    gr.Examples(
        examples=[
            [
                "A person walks forward slowly, their body moving naturally with each step."
            ],
            ["The camera smoothly pans from left to right, revealing the scene."],
            ["A dancer gracefully transitions from one pose to another."],
            ["The sun sets gradually, changing the lighting and colors of the scene."],
            ["A car accelerates down the street, moving from standstill to motion."],
        ],
        inputs=prompt_input,
        label="Click to use example prompts",
    )
    # Event handlers
    load_btn.click(fn=load_model, inputs=[model_path_input], outputs=[model_status])
    # Input order must match generate_interpolation's parameter order:
    # (first, last, prompt, num_frames, num_inference_steps, guidance, fps, seed).
    generate_btn.click(
        fn=generate_interpolation,
        inputs=[
            first_image_input,
            last_image_input,
            prompt_input,
            num_frames_slider,
            num_steps_slider,
            guidance_slider,
            fps_slider,
            seed_input,
        ],
        outputs=[output_video, generation_status],
    )
if __name__ == "__main__":
    # Startup banner: device summary, then launch the Gradio server.
    banner = "=" * 50
    print(banner)
    print("CogVideoX Keyframe Interpolation Gradio App")
    print(banner)
    print(f"Device: {device}")
    cuda_ok = torch.cuda.is_available()
    print(f"CUDA available: {cuda_ok}")
    if cuda_ok:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"GPU Memory: {total_gb:.2f} GB")
    print(banner)
    demo.launch()