Spaces:
Paused
Paused
| import os | |
| import tempfile | |
| import gradio as gr | |
| import torch | |
| from diffusers.utils import export_to_video | |
| from PIL import Image | |
| from cogvideox_interpolation.pipeline import CogVideoXInterpolationPipeline | |
# Pipeline singleton; populated by load_model() and read by the
# generation callback.
pipe = None

# Run on the GPU when one is present; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def load_model(model_path):
    """Load the CogVideoX-Interpolation pipeline into the global ``pipe``.

    Args:
        model_path: Local checkpoint directory or Hugging Face repo id.

    Returns:
        A human-readable status string for the Gradio status textbox.
        The function never raises: failures are caught and returned as an
        error message so the UI always gets feedback.
    """
    global pipe
    try:
        print(f"Loading model from {model_path}...")
        print(f"Using device: {device}")
        # The 5B variant is distributed in bf16; smaller variants use fp16.
        dtype = torch.bfloat16 if "5b" in model_path.lower() else torch.float16
        pipe = CogVideoXInterpolationPipeline.from_pretrained(model_path, torch_dtype=dtype)
        # Memory optimization: sequential offload keeps only the active
        # submodule on the GPU; VAE tiling/slicing bound peak memory.
        if device == "cuda":
            pipe.enable_sequential_cpu_offload()
        else:
            pipe = pipe.to(device)
        pipe.vae.enable_tiling()
        pipe.vae.enable_slicing()
        print("Model loaded successfully!")
        return "✓ Model loaded successfully!"
    except Exception as e:
        # Surface load failures (bad path, OOM, missing files) in the UI
        # instead of crashing the callback and leaving a half-loaded pipe.
        pipe = None
        error_msg = f"❌ Error loading model: {e}"
        print(error_msg)
        return error_msg
def generate_interpolation(
    first_image,
    last_image,
    prompt,
    num_frames=49,
    num_inference_steps=50,
    guidance_scale=6.0,
    fps=8,
    seed=42,
):
    """Generate an interpolated video between two keyframe images.

    Args:
        first_image: Start frame (PIL image or numpy array from Gradio).
        last_image: End frame (PIL image or numpy array from Gradio).
        prompt: Text description of the motion/transition.
        num_frames: Output length; the model expects 4k+1 frames.
        num_inference_steps: Number of diffusion denoising steps.
        guidance_scale: Classifier-free guidance strength.
        fps: Frame rate of the exported mp4.
        seed: RNG seed for reproducible sampling.

    Returns:
        Tuple ``(video_path, status_message)``; ``video_path`` is None
        when validation or generation fails.
    """
    # Validate inputs up front so the UI gets actionable messages.
    if pipe is None:
        return None, "⚠️ Please load the model first!"
    if first_image is None or last_image is None:
        return None, "⚠️ Please upload both start and end frame images!"
    # `prompt` can be None (textbox never touched / API caller), which
    # would make a bare `prompt.strip()` raise AttributeError.
    if not prompt or not prompt.strip():
        return None, "⚠️ Please provide a text prompt describing the motion!"
    try:
        # Gradio may hand over numpy arrays; the pipeline wants PIL images.
        if not isinstance(first_image, Image.Image):
            first_image = Image.fromarray(first_image)
        if not isinstance(last_image, Image.Image):
            last_image = Image.fromarray(last_image)
        print(f"Generating video with prompt: {prompt}")
        print(
            f"Parameters: frames={num_frames}, steps={num_inference_steps}, guidance={guidance_scale}"
        )
        # Seed a device-local generator. gr.Number / API clients may
        # deliver floats, and manual_seed() requires an int.
        generator = torch.Generator(device=device).manual_seed(int(seed))
        # int() casts guard against float values arriving from sliders
        # via the HTTP API.
        video = pipe(
            prompt=prompt,
            first_image=first_image,
            last_image=last_image,
            num_videos_per_prompt=1,
            num_inference_steps=int(num_inference_steps),
            num_frames=int(num_frames),
            guidance_scale=guidance_scale,
            generator=generator,
        )[0]
        # Write to a unique temp file; close the handle first so
        # export_to_video can reopen the path (required on Windows).
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        output_path = temp_file.name
        temp_file.close()
        export_to_video(video, output_path, fps=fps)
        status = f"✓ Video generated successfully! ({num_frames} frames at {fps} fps)"
        print(status)
        return output_path, status
    except Exception as e:
        # Report any pipeline/export failure to the UI instead of
        # crashing the Gradio worker.
        error_msg = f"❌ Error: {str(e)}"
        print(error_msg)
        return None, error_msg
# Create Gradio interface.
# NOTE(review): component creation order and `with` nesting determine the
# rendered layout, so the statement order below is load-bearing.
with gr.Blocks(title="CogVideoX Keyframe Interpolation") as demo:
    # Page header with usage instructions.
    gr.Markdown(
        """
    # 🎬 CogVideoX Keyframe Interpolation
    Generate smooth video transitions between two keyframe images using AI.
    **Instructions:**
    1. First, load the model by providing the path to your checkpoint
    2. Upload start and end frame images
    3. Describe the motion/transition in the text prompt
    4. Adjust parameters and generate!
    """
    )
    # --- Model setup row: checkpoint path, load button, status readout ---
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🔧 Model Setup")
            model_path_input = gr.Textbox(
                label="Model Path",
                placeholder="e.g., /path/to/CogVideoX-5b-I2V-inter or feizhengcong/CogvideoX-Interpolation",
                value="feizhengcong/CogvideoX-Interpolation",
            )
            load_btn = gr.Button("Load Model", variant="primary")
            model_status = gr.Textbox(label="Status", interactive=False)
    gr.Markdown("---")
    # --- Inputs row: keyframes on the left, generation settings on the right ---
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🖼️ Input Keyframes")
            first_image_input = gr.Image(label="Start Frame", type="pil", height=300)
            last_image_input = gr.Image(label="End Frame", type="pil", height=300)
        with gr.Column():
            gr.Markdown("### ⚙️ Generation Settings")
            prompt_input = gr.Textbox(
                label="Motion Description",
                placeholder="Describe the motion/transition between the frames...",
                lines=4,
            )
            with gr.Row():
                # Frame count is stepped by 4 from 13 so every stop lands
                # on the 4k+1 lengths the model expects.
                num_frames_slider = gr.Slider(
                    label="Number of Frames",
                    minimum=13,
                    maximum=49,
                    step=4,
                    value=49,
                    info="Must be 4k+1 format (13, 17, 21, ..., 49)",
                )
                fps_slider = gr.Slider(
                    label="FPS", minimum=4, maximum=16, step=2, value=8
                )
            with gr.Row():
                num_steps_slider = gr.Slider(
                    label="Inference Steps",
                    minimum=20,
                    maximum=100,
                    step=5,
                    value=50,
                    info="More steps = better quality but slower",
                )
                guidance_slider = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=15.0,
                    step=0.5,
                    value=6.0,
                    info="Higher = stronger prompt following",
                )
            # precision=0 makes Gradio round the seed to an integer.
            seed_input = gr.Number(label="Random Seed", value=42, precision=0)
            generate_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg")
    gr.Markdown("---")
    # --- Output row: generated video and status message ---
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🎥 Generated Video")
            output_video = gr.Video(label="Output")
            generation_status = gr.Textbox(label="Generation Status", interactive=False)
    # Examples: clicking one fills the prompt textbox only.
    gr.Markdown("---")
    gr.Markdown("### 💡 Example Prompts")
    gr.Examples(
        examples=[
            [
                "A person walks forward slowly, their body moving naturally with each step."
            ],
            ["The camera smoothly pans from left to right, revealing the scene."],
            ["A dancer gracefully transitions from one pose to another."],
            ["The sun sets gradually, changing the lighting and colors of the scene."],
            ["A car accelerates down the street, moving from standstill to motion."],
        ],
        inputs=prompt_input,
        label="Click to use example prompts",
    )
    # Event handlers: wire the buttons to the module-level callbacks.
    load_btn.click(fn=load_model, inputs=[model_path_input], outputs=[model_status])
    generate_btn.click(
        fn=generate_interpolation,
        inputs=[
            first_image_input,
            last_image_input,
            prompt_input,
            num_frames_slider,
            num_steps_slider,
            guidance_slider,
            fps_slider,
            seed_input,
        ],
        outputs=[output_video, generation_status],
    )
if __name__ == "__main__":
    # Print a startup banner with device diagnostics, then serve the UI.
    banner = "=" * 50
    print(banner)
    print("CogVideoX Keyframe Interpolation Gradio App")
    print(banner)
    print(f"Device: {device}")
    cuda_available = torch.cuda.is_available()
    print(f"CUDA available: {cuda_available}")
    if cuda_available:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        gpu_props = torch.cuda.get_device_properties(0)
        print(f"GPU Memory: {gpu_props.total_memory / 1024**3:.2f} GB")
    print(banner)
    demo.launch()