| | import gradio as gr |
| | import torch |
| | import torch.nn.functional as F |
| | from diffusers import StableDiffusionImg2ImgPipeline, DDIMScheduler |
| | from PIL import Image |
| | import numpy as np |
| | from typing import List, Optional, Dict, Any |
| | from collections import deque |
| | import cv2 |
| | import os |
| | import tempfile |
| | import imageio |
| | from datetime import datetime |
| |
|
class SimpleTemporalBuffer:
    """Rolling window of recent frames plus a scalar motion history.

    Keeps the last ``buffer_size`` frames (and any embeddings supplied with
    them). For every consecutive pair of frames it records one motion
    magnitude, obtained by tracking the image-centre point with pyramidal
    Lucas-Kanade optical flow.
    """

    def __init__(self, buffer_size: int = 6):
        """Create empty frame, embedding and motion histories.

        Args:
            buffer_size: Maximum number of frames (and embeddings) retained.
        """
        self.buffer_size = buffer_size
        self.frames = deque(maxlen=buffer_size)
        self.frame_embeddings = deque(maxlen=buffer_size)
        # One value per consecutive frame pair. Clamp at zero so a
        # buffer_size of 0 does not ask deque for a negative maxlen.
        self.motion_vectors = deque(maxlen=max(buffer_size - 1, 0))

    def add_frame(self, frame: Image.Image, embedding: Optional[torch.Tensor] = None):
        """Append ``frame`` and, when possible, its motion relative to the previous frame.

        Motion estimation is best effort: any failure is printed and the
        frame is still stored, only without a matching motion entry.

        Args:
            frame: Frame to store.
            embedding: Optional embedding stored alongside the frame.
        """
        if self.frames:
            try:
                motion = self._estimate_motion(self.frames[-1], frame)
            except Exception as e:
                # Deliberately broad: a failed motion estimate must never
                # prevent the frame itself from being buffered.
                print(f"Motion calculation error: {e}")
            else:
                if motion is not None:
                    self.motion_vectors.append(motion)

        self.frames.append(frame)
        if embedding is not None:
            self.frame_embeddings.append(embedding)

    @staticmethod
    def _estimate_motion(prev_frame: Image.Image, curr_frame: Image.Image) -> Optional[float]:
        """Return how far the image-centre point moved between two frames.

        Returns ``None`` when the tracker reports that the point was lost, so
        that a meaningless coordinate is never turned into a motion value.
        """
        prev_gray = cv2.cvtColor(np.array(prev_frame), cv2.COLOR_RGB2GRAY)
        curr_gray = cv2.cvtColor(np.array(curr_frame), cv2.COLOR_RGB2GRAY)

        center = np.array(
            [[curr_frame.width // 2, curr_frame.height // 2]], dtype=np.float32
        )
        next_pts, status, _err = cv2.calcOpticalFlowPyrLK(prev_gray, curr_gray, center, None)

        # status is 1 only for points whose flow was actually found.
        if next_pts is None or status is None or not status.ravel()[0]:
            return None

        displacement = next_pts.reshape(-1, 2)[0] - center[0]
        return float(np.linalg.norm(displacement))

    def get_reference_frame(self) -> Optional[Image.Image]:
        """Return the most recently added frame, or ``None`` if the buffer is empty."""
        return self.frames[-1] if self.frames else None

    def get_motion_context(self) -> Dict[str, Any]:
        """Summarise the latest motion values for the next generation step.

        Returns:
            ``{"has_motion": False, "predicted_motion": 0.0}`` when no motion
            has been recorded. Otherwise a dict with ``has_motion`` (True),
            ``current_motion`` (mean of up to the last three values),
            ``motion_trend`` (newest minus oldest of those values),
            ``predicted_motion`` (mean plus half the trend) and
            ``motion_history`` (the values themselves).
        """
        if not self.motion_vectors:
            return {"has_motion": False, "predicted_motion": 0.0}

        recent_motion = list(self.motion_vectors)[-3:]
        avg_motion = float(np.mean(recent_motion))
        motion_trend = recent_motion[-1] - recent_motion[0] if len(recent_motion) > 1 else 0

        predicted_motion = avg_motion + motion_trend * 0.5

        return {
            "has_motion": True,
            "current_motion": avg_motion,
            "predicted_motion": predicted_motion,
            "motion_trend": motion_trend,
            "motion_history": recent_motion,
        }
| |
|
class SD15FlexibleI2VGenerator:
    """Image-to-video generator built on the SD1.5 img2img pipeline.

    Frames are produced one at a time: each frame is the img2img result of
    the previous one, lightly blended with the previous frame, while a
    ``SimpleTemporalBuffer`` supplies motion statistics that adjust the
    denoising strength and the prompt.
    """

    def __init__(
        self,
        model_id: str = "runwayml/stable-diffusion-v1-5",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Record configuration; the pipeline itself is loaded by ``load_model``.

        Args:
            model_id: Model identifier handed to ``from_pretrained``.
            device: Torch device string the pipeline is moved to.
        """
        self.model_id = model_id
        self.device = device
        self.pipe = None
        self.temporal_buffer = SimpleTemporalBuffer()
        self.is_loaded = False

    def load_model(self):
        """Load the img2img pipeline once and return a human-readable status string.

        Errors are reported through the returned string rather than raised.
        """
        if self.is_loaded:
            return "Model already loaded"

        try:
            print(f"π Loading SD1.5 pipeline on {self.device}...")

            use_cuda = self.device == "cuda"
            self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                self.model_id,  # honour the constructor argument
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                safety_checker=None,
                requires_safety_checker=False,
            )

            self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
            self.pipe = self.pipe.to(self.device)

            if use_cuda:
                self.pipe.enable_attention_slicing()
                try:
                    self.pipe.enable_xformers_memory_efficient_attention()
                except Exception:
                    # xformers is optional; keep going with default attention.
                    print("⚠️ xformers not available, using standard attention")

            self.is_loaded = True
            return "✅ Model loaded successfully!"

        except Exception as e:
            return f"❌ Error loading model: {str(e)}"

    def calculate_adaptive_strength(self, motion_context: Dict[str, Any], base_strength: float = 0.75) -> float:
        """Scale the denoising strength down as recent motion grows.

        Motion is normalised as ``current_motion / 50`` clipped to [0, 1];
        the base strength is reduced by up to 30% of itself and the result is
        clipped to [0.3, 0.9]. Without motion data ``base_strength`` is
        returned unchanged.
        """
        if not motion_context.get("has_motion", False):
            return base_strength

        motion = motion_context["current_motion"]
        motion_factor = np.clip(motion / 50.0, 0.0, 1.0)
        adaptive_strength = base_strength * (1.0 - motion_factor * 0.3)

        # Return a plain float, as annotated, instead of a numpy scalar.
        return float(np.clip(adaptive_strength, 0.3, 0.9))

    def enhance_prompt_with_motion(self, base_prompt: str, motion_context: Dict[str, Any]) -> str:
        """Append a motion description chosen from the current motion and its trend.

        Returns ``base_prompt`` unchanged when no motion data is available.
        """
        if not motion_context.get("has_motion", False):
            return base_prompt

        motion = motion_context["current_motion"]
        trend = motion_context.get("motion_trend", 0)

        if motion > 30:
            if trend > 5:
                motion_desc = ", fast movement, dynamic motion, motion blur"
            else:
                motion_desc = ", steady movement, continuous motion"
        elif motion > 10:
            motion_desc = ", gentle movement, smooth transition"
        else:
            motion_desc = ", subtle movement, slight change"

        return base_prompt + motion_desc

    def blend_frames(self, current_frame: Image.Image, reference_frame: Image.Image, blend_ratio: float = 0.15) -> Image.Image:
        """Linearly mix ``reference_frame`` into ``current_frame``.

        Args:
            current_frame: Newly generated frame.
            reference_frame: Earlier frame to mix in.
            blend_ratio: Weight of the reference frame; the current frame
                gets ``1 - blend_ratio``.

        Returns:
            The blended image, with the size and mode of ``current_frame``.
        """
        # The pixel arrays must have identical shapes to be mixed; align the
        # reference to the current frame instead of failing on a mismatch.
        if reference_frame.mode != current_frame.mode:
            reference_frame = reference_frame.convert(current_frame.mode)
        if reference_frame.size != current_frame.size:
            reference_frame = reference_frame.resize(current_frame.size)

        current_array = np.array(current_frame, dtype=np.float32)
        reference_array = np.array(reference_frame, dtype=np.float32)

        blended_array = current_array * (1 - blend_ratio) + reference_array * blend_ratio
        blended_array = np.clip(blended_array, 0, 255).astype(np.uint8)

        return Image.fromarray(blended_array)

    @torch.no_grad()
    def generate_frame_batch(
        self,
        init_image: Image.Image,
        prompt: str,
        num_frames: int = 1,
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 20,
        generator: Optional[torch.Generator] = None,
        progress_callback=None,
    ) -> List[Image.Image]:
        """Generate ``num_frames`` frames, each one derived from the previous.

        Every generated frame is blended with the newest frame in the
        temporal buffer (when there is one), stored in the buffer, and used
        as the input image for the next iteration.

        Args:
            init_image: Input image for the first frame of the batch.
            prompt: Base prompt; a motion description may be appended.
            num_frames: Number of frames to generate.
            strength: Base denoising strength before motion adaptation.
            guidance_scale: Forwarded to the pipeline.
            num_inference_steps: Forwarded to the pipeline.
            generator: Optional torch generator forwarded to the pipeline.
            progress_callback: Optional callable receiving a status string
                before each frame.

        Returns:
            The generated frames, in order.

        Raises:
            ValueError: If the model has not been loaded.
        """
        if not self.is_loaded:
            raise ValueError("Model not loaded. Please load the model first.")

        frames = []
        current_image = init_image

        for i in range(num_frames):
            if progress_callback:
                progress_callback(f"Generating frame {i+1}/{num_frames}")

            motion_context = self.temporal_buffer.get_motion_context()
            adaptive_strength = self.calculate_adaptive_strength(motion_context, strength)
            enhanced_prompt = self.enhance_prompt_with_motion(prompt, motion_context)

            result = self.pipe(
                prompt=enhanced_prompt,
                image=current_image,
                strength=adaptive_strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
            )
            generated_frame = result.images[0]

            reference_frame = self.temporal_buffer.get_reference_frame()
            if reference_frame is not None:
                # Mix in less of the previous frame when motion is high.
                blend_ratio = 0.1 if motion_context.get("current_motion", 0) > 20 else 0.2
                generated_frame = self.blend_frames(generated_frame, reference_frame, blend_ratio)

            self.temporal_buffer.add_frame(generated_frame)
            frames.append(generated_frame)
            current_image = generated_frame

        return frames

    def generate_i2v_sequence(
        self,
        init_image: Image.Image,
        prompt: str,
        total_frames: int = 16,
        frames_per_batch: int = 2,
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 20,
        seed: Optional[int] = None,
        progress_callback=None,
    ) -> List[Image.Image]:
        """Generate a full sequence in batches of ``frames_per_batch`` frames.

        The temporal buffer is reset and primed with ``init_image``, which is
        also the first element of the returned list.

        Args:
            init_image: First frame of the sequence.
            prompt: Base prompt used for every frame.
            total_frames: Target sequence length, including ``init_image``.
            frames_per_batch: Frames requested per ``generate_frame_batch`` call.
            strength: Base denoising strength.
            guidance_scale: Forwarded to the pipeline.
            num_inference_steps: Forwarded to the pipeline.
            seed: Seed for reproducible output; ``None`` draws a fresh seed.
            progress_callback: Optional callable receiving status strings.

        Returns:
            ``init_image`` followed by the generated frames.

        Raises:
            ValueError: If the model is not loaded or ``frames_per_batch`` is
                smaller than 1.
        """
        if not self.is_loaded:
            raise ValueError("Model not loaded. Please load the model first.")

        total_frames = int(total_frames)
        frames_per_batch = int(frames_per_batch)
        if frames_per_batch < 1:
            # A non-positive batch size would never advance the loop below.
            raise ValueError("frames_per_batch must be at least 1")

        generator = torch.Generator(device=self.device)
        if seed is not None:
            generator.manual_seed(int(seed))
        else:
            # A new Generator starts from a fixed default seed; reseed it so
            # that "no seed" really produces a different sequence each run.
            generator.seed()

        self.temporal_buffer = SimpleTemporalBuffer()
        self.temporal_buffer.add_frame(init_image)

        all_frames = [init_image]
        frames_generated = 1
        current_reference = init_image

        while frames_generated < total_frames:
            remaining_frames = total_frames - frames_generated
            current_batch_size = min(frames_per_batch, remaining_frames)

            if progress_callback:
                progress_callback(f"Batch: Generating frames {frames_generated+1}-{frames_generated+current_batch_size}")

            batch_frames = self.generate_frame_batch(
                init_image=current_reference,
                prompt=prompt,
                num_frames=current_batch_size,
                strength=strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                progress_callback=progress_callback,
            )

            all_frames.extend(batch_frames)
            frames_generated += current_batch_size
            current_reference = batch_frames[-1]

        return all_frames
| |
|
| | |
# Single shared generator instance used by every Gradio callback below.
generator = SD15FlexibleI2VGenerator()
| |
|
def load_model_interface():
    """Load the shared generator's pipeline and hand back its status message."""
    return generator.load_model()
| |
|
def create_frames_to_gif(frames: List[Image.Image], duration: int = 200) -> str:
    """Write ``frames`` as an endlessly looping GIF inside a fresh temp directory.

    Args:
        frames: Frames in playback order; the first frame performs the save
            and the remaining ones are appended to it.
        duration: Per-frame duration forwarded to ``save``.

    Returns:
        Path of the GIF file that was written.
    """
    out_dir = tempfile.mkdtemp()
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    gif_path = os.path.join(out_dir, f"i2v_sequence_{stamp}.gif")

    first_frame = frames[0]
    save_options = {
        "save_all": True,
        "append_images": frames[1:],
        "duration": duration,
        "loop": 0,  # 0 = repeat forever
    }
    first_frame.save(gif_path, **save_options)

    return gif_path
| |
|
def create_frames_to_video(frames: List[Image.Image], fps: int = 8) -> str:
    """Write ``frames`` to an MP4 in a fresh temp directory, falling back to GIF.

    Args:
        frames: Frames in playback order.
        fps: Frames per second passed to the video writer.

    Returns:
        Path of the MP4 file, or of a GIF created by ``create_frames_to_gif``
        when no usable video backend is available.
    """
    temp_dir = tempfile.mkdtemp()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    video_path = os.path.join(temp_dir, f"i2v_sequence_{timestamp}.mp4")

    try:
        with imageio.get_writer(video_path, fps=fps) as writer:
            for frame in frames:
                writer.append_data(np.array(frame))
    except (ImportError, ValueError):
        # NOTE(review): a missing ffmpeg backend is believed to surface as
        # ImportError on older imageio releases and as ValueError on newer
        # ones; confirm against the pinned imageio version. Either way the
        # caller still gets a playable file.
        return create_frames_to_gif(frames, duration=int(1000 / fps))

    return video_path
| |
|
def generate_i2v_interface(
    init_image,
    prompt,
    total_frames,
    frames_per_batch,
    strength,
    guidance_scale,
    num_inference_steps,
    seed,
    output_format,
    progress=gr.Progress(),
):
    """Gradio callback for fixed-batch image-to-video generation.

    Args:
        init_image: Uploaded starting image, or ``None`` when nothing was uploaded.
        prompt: Text prompt.
        total_frames: Sequence length including the starting image.
        frames_per_batch: Frames generated per batch.
        strength: Base denoising strength.
        guidance_scale: Forwarded to the generator.
        num_inference_steps: Forwarded to the generator.
        seed: Non-negative value for a fixed seed; negative or empty for a
            random one.
        output_format: ``"GIF"`` writes a GIF, anything else an MP4.
        progress: Gradio progress tracker.

    Returns:
        ``(last_frame, output_path, status_message)``. On invalid input or
        failure the first two entries are ``None`` and the message explains
        the problem.
    """
    if init_image is None:
        return None, None, "❌ Please upload an initial image"

    if not prompt or not prompt.strip():
        return None, None, "❌ Please enter a prompt"

    try:
        def update_progress(message):
            progress(0.5, desc=message)

        progress(0.1, desc="Starting generation...")

        if init_image.size != (512, 512):
            init_image = init_image.resize((512, 512), Image.Resampling.LANCZOS)

        # The UI documents -1 as "random", so every non-negative value
        # (including 0) is a real seed. The widget may deliver a float or
        # nothing at all; the generator needs an int or None.
        seed_value = int(seed) if seed is not None and seed >= 0 else None

        frames = generator.generate_i2v_sequence(
            init_image=init_image,
            prompt=prompt,
            # Frame counts and step counts must be ints downstream.
            total_frames=int(total_frames),
            frames_per_batch=int(frames_per_batch),
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            seed=seed_value,
            progress_callback=update_progress,
        )

        progress(0.8, desc="Creating output file...")

        if output_format == "GIF":
            output_path = create_frames_to_gif(frames, duration=200)
        else:
            output_path = create_frames_to_video(frames, fps=8)

        progress(1.0, desc="Complete!")

        return frames[-1], output_path, f"✅ Generated {len(frames)} frames successfully!"

    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return None, None, f"❌ Error: {str(e)}"
| |
|
def generate_variable_pattern_interface(
    init_image,
    prompt,
    total_frames,
    batch_pattern_str,
    strength,
    guidance_scale,
    num_inference_steps,
    seed,
    output_format,
    progress=gr.Progress(),
):
    """Gradio callback that generates frames using a repeating batch-size pattern.

    Args:
        init_image: Uploaded starting image, or ``None`` when nothing was uploaded.
        prompt: Text prompt.
        total_frames: Sequence length including the starting image.
        batch_pattern_str: Comma-separated positive integers, e.g. ``"1,2,3"``;
            the pattern is cycled until enough frames exist.
        strength: Base denoising strength.
        guidance_scale: Forwarded to the generator.
        num_inference_steps: Forwarded to the generator.
        seed: Non-negative value for a fixed seed; negative or empty for a
            random one.
        output_format: ``"GIF"`` writes a GIF, anything else an MP4.
        progress: Gradio progress tracker.

    Returns:
        ``(last_frame, output_path, status_message)``. On invalid input or
        failure the first two entries are ``None`` and the message explains
        the problem.
    """
    if init_image is None:
        return None, None, "❌ Please upload an initial image"

    if not prompt or not prompt.strip():
        return None, None, "❌ Please enter a prompt"

    try:
        # Blank entries (e.g. a trailing comma) are ignored; anything that is
        # not a positive integer is rejected with one clear message.
        tokens = [token.strip() for token in (batch_pattern_str or "").split(",")]
        try:
            batch_pattern = [int(token) for token in tokens if token]
        except ValueError:
            raise ValueError("Invalid batch pattern") from None
        if not batch_pattern or any(x <= 0 for x in batch_pattern):
            raise ValueError("Invalid batch pattern")

        total_frames = int(total_frames)
        num_inference_steps = int(num_inference_steps)

        progress(0.1, desc="Starting variable pattern generation...")

        if init_image.size != (512, 512):
            init_image = init_image.resize((512, 512), Image.Resampling.LANCZOS)

        frames = [init_image]
        frames_generated = 1
        current_reference = init_image
        pattern_idx = 0

        generator.temporal_buffer = SimpleTemporalBuffer()
        generator.temporal_buffer.add_frame(init_image)

        gen = torch.Generator(device=generator.device)
        if seed is not None and seed >= 0:
            gen.manual_seed(int(seed))
        else:
            # A new Generator starts from a fixed default seed; reseed it so
            # that "random" runs actually differ from each other.
            gen.seed()

        while frames_generated < total_frames:
            current_batch_size = batch_pattern[pattern_idx % len(batch_pattern)]
            remaining_frames = total_frames - frames_generated
            actual_batch_size = min(current_batch_size, remaining_frames)

            progress(
                frames_generated / total_frames,
                desc=f"Pattern step {pattern_idx+1}: {actual_batch_size} frames",
            )

            batch_frames = generator.generate_frame_batch(
                init_image=current_reference,
                prompt=prompt,
                num_frames=actual_batch_size,
                strength=strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=gen,
            )

            frames.extend(batch_frames)
            frames_generated += actual_batch_size
            current_reference = batch_frames[-1]
            pattern_idx += 1

        progress(0.9, desc="Creating output file...")

        # frames already includes the starting image, so the cap is exactly
        # total_frames.
        final_frames = frames[:total_frames]
        if output_format == "GIF":
            output_path = create_frames_to_gif(final_frames, duration=200)
        else:
            output_path = create_frames_to_video(final_frames, fps=8)

        progress(1.0, desc="Complete!")

        return final_frames[-1], output_path, f"✅ Generated {len(final_frames)} frames with pattern {batch_pattern}!"

    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return None, None, f"❌ Error: {str(e)}"
| |
|
| | |
def _build_advanced_settings():
    """Create the collapsed "Advanced Settings" accordion shared by both tabs.

    Returns:
        The strength, guidance-scale, inference-steps, seed and output-format
        components, in that order.
    """
    with gr.Accordion("Advanced Settings", open=False):
        strength = gr.Slider(
            label="Strength", minimum=0.3, maximum=0.9, value=0.75, step=0.05
        )
        guidance_scale = gr.Slider(
            label="Guidance Scale", minimum=3.0, maximum=15.0, value=7.5, step=0.5
        )
        num_inference_steps = gr.Slider(
            label="Inference Steps", minimum=10, maximum=50, value=20, step=5
        )
        # precision=0 makes Gradio deliver the seed as an int, which is what
        # torch's manual_seed requires.
        seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
        output_format = gr.Radio(
            label="Output Format", choices=["GIF", "MP4"], value="GIF"
        )
    return strength, guidance_scale, num_inference_steps, seed, output_format


def _build_output_column():
    """Create the right-hand result column and return (preview, file, status)."""
    with gr.Column(scale=1):
        preview = gr.Image(label="Last Frame Preview", height=300)
        output_file = gr.File(label="Download Generated Video/GIF")
        status = gr.Textbox(label="Status", interactive=False)
    return preview, output_file, status


def create_gradio_app():
    """Build the Gradio Blocks UI and return it without launching.

    The app has a model-loading row, one tab for fixed-batch generation, one
    tab for variable batch patterns, and a collapsible tips section.
    """
    # NOTE(review): labels that still begin with a bare "π" appear to be
    # mis-decoded emoji whose original code point cannot be recovered from
    # this file; restore them from the upstream source.
    with gr.Blocks(title="SD1.5 Flexible I2V Generator", theme=gr.themes.Soft()) as app:

        gr.Markdown("""
        # 🎬 SD1.5 Flexible I2V Generator

        Generate image-to-video sequences with **flexible batch processing** and **temporal consistency**!

        ## Key Features:
        - 🎯 **Flexible Batch Sizes**: Generate 1, 2, 3+ frames at a time
        - π **Motion-Aware Processing**: Adapts based on detected motion
        - 🎨 **Temporal Consistency**: Smooth transitions between frames
        - π **Variable Patterns**: Dynamic batch sizing patterns
        """)

        # Model loading row
        with gr.Row():
            load_btn = gr.Button("π Load SD1.5 Model", variant="primary", size="lg")
            model_status = gr.Textbox(
                label="Model Status",
                value="Model not loaded. Click 'Load SD1.5 Model' to start.",
                interactive=False,
            )

        load_btn.click(load_model_interface, outputs=model_status)

        with gr.Tabs():

            # Tab 1: fixed number of frames per batch
            with gr.Tab("🎯 Fixed Batch Generation"):
                with gr.Row():
                    with gr.Column(scale=1):
                        init_image_1 = gr.Image(
                            label="Initial Image",
                            type="pil",
                            height=300,
                        )
                        prompt_1 = gr.Textbox(
                            label="Prompt",
                            placeholder="e.g., a cat walking through a peaceful garden, cinematic lighting",
                            lines=3,
                        )

                        with gr.Row():
                            total_frames_1 = gr.Slider(
                                label="Total Frames",
                                minimum=4,
                                maximum=32,
                                value=12,
                                step=1,
                            )
                            frames_per_batch_1 = gr.Slider(
                                label="Frames per Batch (Key Parameter!)",
                                minimum=1,
                                maximum=4,
                                value=2,
                                step=1,
                            )

                        (
                            strength_1,
                            guidance_scale_1,
                            num_inference_steps_1,
                            seed_1,
                            output_format_1,
                        ) = _build_advanced_settings()

                        generate_btn_1 = gr.Button("🎬 Generate I2V Sequence", variant="primary", size="lg")

                    preview_1, output_file_1, status_1 = _build_output_column()

                generate_btn_1.click(
                    generate_i2v_interface,
                    inputs=[
                        init_image_1, prompt_1, total_frames_1, frames_per_batch_1,
                        strength_1, guidance_scale_1, num_inference_steps_1, seed_1, output_format_1,
                    ],
                    outputs=[preview_1, output_file_1, status_1],
                )

            # Tab 2: batch size follows a user-supplied repeating pattern
            with gr.Tab("π Variable Pattern Generation"):
                with gr.Row():
                    with gr.Column(scale=1):
                        init_image_2 = gr.Image(
                            label="Initial Image",
                            type="pil",
                            height=300,
                        )
                        prompt_2 = gr.Textbox(
                            label="Prompt",
                            placeholder="e.g., smooth camera movement through a scene",
                            lines=3,
                        )

                        total_frames_2 = gr.Slider(
                            label="Total Frames",
                            minimum=6,
                            maximum=40,
                            value=16,
                            step=1,
                        )

                        batch_pattern_2 = gr.Textbox(
                            label="Batch Pattern (comma-separated)",
                            value="1,2,3,2,1",
                            placeholder="e.g., 1,2,3,2,1 or 2,4,2",
                        )

                        gr.Markdown("""
                        **Pattern Examples:**
                        - `1,2,3,2,1` - Start slow, ramp up, slow down
                        - `2,2,2,2` - Consistent 2-frame batches
                        - `1,3,1,3` - Alternating single and triple
                        """)

                        (
                            strength_2,
                            guidance_scale_2,
                            num_inference_steps_2,
                            seed_2,
                            output_format_2,
                        ) = _build_advanced_settings()

                        generate_btn_2 = gr.Button("🎨 Generate with Pattern", variant="primary", size="lg")

                    preview_2, output_file_2, status_2 = _build_output_column()

                generate_btn_2.click(
                    generate_variable_pattern_interface,
                    inputs=[
                        init_image_2, prompt_2, total_frames_2, batch_pattern_2,
                        strength_2, guidance_scale_2, num_inference_steps_2, seed_2, output_format_2,
                    ],
                    outputs=[preview_2, output_file_2, status_2],
                )

        with gr.Accordion("π Example Prompts & Tips", open=False):
            gr.Markdown("""
            ## 🎯 Good Prompts for I2V:
            - `a peaceful lake with gentle ripples, soft sunlight, cinematic`
            - `a cat slowly walking through a garden, smooth movement`
            - `camera slowly panning across a mountain landscape`
            - `a flower blooming in timelapse, natural lighting`
            - `gentle waves on a beach, golden hour lighting`

            ## π Parameter Tips:
            - **Frames per Batch**:
              - `1` = Maximum consistency, slower generation
              - `2-3` = Balanced quality and speed
              - `4+` = Faster but less consistent
            - **Strength**:
              - `0.6-0.7` = Subtle changes
              - `0.7-0.8` = Moderate animation
              - `0.8-0.9` = More dramatic changes
            - **Batch Patterns**:
              - Use `1,2,3,2,1` for organic acceleration/deceleration
              - Use consistent values like `2,2,2` for steady pacing
            """)

        gr.Markdown("""
        ---

        ## π **Innovation Highlights:**

        This app demonstrates **flexible batch processing** for I2V generation:
        - Generate multiple frames simultaneously with `frames_per_batch`
        - Motion-aware strength adaptation based on optical flow
        - Temporal consistency through intelligent frame blending
        - Variable stepping patterns for dynamic control

        **Built with SD1.5 img2img pipeline + custom temporal processing!**
        """)

    return app
| |
|
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on every network
    # interface at port 7860 with debug output and no public share link.
    app = create_gradio_app()
    app.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)