Hugging Face Space (status: Sleeping)
| import gradio as gr | |
| import torch | |
| from diffusers import DiffusionPipeline | |
| import numpy as np | |
| import cv2 | |
| import os | |
| from PIL import Image | |
| import tempfile | |
# Force CPU usage for better compatibility on HF Spaces
device = "cpu"  # the pipeline and RNG generator below are placed on this device
torch.set_num_threads(4)  # Optimize for CPU: cap intra-op threads on shared hardware
class VideoGenerator:
    """Text-to-video generator backed by the Wan2.1-T2V-1.3B diffusers pipeline.

    The model is loaded once at construction time. If loading fails,
    ``self.pipe`` stays ``None`` and ``generate_video`` returns an error
    message instead of raising, so the surrounding app keeps running.
    """

    def __init__(self):
        self.pipe = None  # set by load_model(); None means "model unavailable"
        self.load_model()

    def load_model(self):
        """Load the Wan2.1 pipeline onto the configured device (CPU here)."""
        try:
            print("Loading Wan2.1-T2V model...")
            self.pipe = DiffusionPipeline.from_pretrained(
                "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
                torch_dtype=torch.float32,  # Use float32 for CPU
                variant=None,
                use_safetensors=True,
            )
            self.pipe = self.pipe.to(device)
            # Enable memory efficient attention if available
            if hasattr(self.pipe, "enable_attention_slicing"):
                self.pipe.enable_attention_slicing()
            print("Model loaded successfully!")
        except Exception as e:
            # Boundary handler: report and fall back to a disabled generator.
            print(f"Error loading model: {e}")
            self.pipe = None

    def adjust_frame_count(self, num_frames):
        """Adjust frame count so that (num_frames - 1) is divisible by 4.

        Rounds to the nearest valid count; on a tie (remainder == 2) the
        lower count wins because fewer frames generate faster on CPU.
        """
        remainder = (num_frames - 1) % 4
        if remainder == 0:
            return num_frames
        option1 = num_frames - remainder        # round down to valid count
        option2 = num_frames + (4 - remainder)  # round up to valid count
        # Choose the closest option, but prefer lower count for performance
        if remainder <= 2:
            return option1
        else:
            return option2

    def generate_video(self, prompt, negative_prompt="", num_frames=16, height=320,
                       width=512, num_inference_steps=20, guidance_scale=7.5):
        """Run the pipeline and return ``(video_path, status_message)``.

        ``video_path`` is ``None`` on any failure; the message explains why.
        """
        if self.pipe is None:
            return None, "Model not loaded properly"
        try:
            # The model requires (num_frames - 1) to be divisible by 4.
            adjusted_frames = self.adjust_frame_count(num_frames)
            if adjusted_frames != num_frames:
                print(f"Adjusted frames from {num_frames} to {adjusted_frames} to satisfy model requirements")
            print(f"Generating video for prompt: {prompt}")
            print(f"Using {adjusted_frames} frames")
            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                result = self.pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_frames=adjusted_frames,
                    height=height,
                    width=width,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    # Fixed seed makes output reproducible across runs.
                    generator=torch.Generator(device=device).manual_seed(42)
                )
            # diffusers versions expose output as .frames or .images.
            if hasattr(result, 'frames'):
                frames = result.frames[0]  # first (and only) batch
            else:
                frames = result.images
            video_path = self.frames_to_video(frames)
            # BUGFIX: previously reported success even when encoding failed.
            if video_path is None:
                return None, "Error encoding video file"
            return video_path, "Video generated successfully!"
        except Exception as e:
            error_msg = f"Error generating video: {str(e)}"
            print(error_msg)
            return None, error_msg

    def frames_to_video(self, frames, fps=8):
        """Encode frames (PIL Images or ndarrays) into an MP4 file.

        Returns the file path, or ``None`` if encoding failed. Uses H.264
        when available for browser compatibility, falling back to mp4v.
        """
        out = None
        try:
            # mkstemp yields a unique path, avoiding the collision risk of
            # the previous random 4-digit suffix under concurrent requests.
            fd, video_path = tempfile.mkstemp(prefix="generated_video_", suffix=".mp4")
            os.close(fd)
            # Get frame dimensions from the first frame.
            if isinstance(frames[0], Image.Image):
                first = np.array(frames[0])
            else:
                first = frames[0]
            height, width = first.shape[:2]
            # Use H.264 codec for better browser compatibility
            fourcc = cv2.VideoWriter_fourcc(*'H264')
            out = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
            # If H264 fails, fall back to mp4v
            if not out.isOpened():
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                out = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
            for frame in frames:
                if isinstance(frame, Image.Image):
                    frame_array = np.array(frame)
                else:
                    frame_array = frame
                # Model output may be float in [0, 1]; OpenCV expects uint8.
                if frame_array.dtype != np.uint8:
                    frame_array = (frame_array * 255).astype(np.uint8)
                # Convert RGB to BGR for OpenCV
                if len(frame_array.shape) == 3 and frame_array.shape[2] == 3:
                    frame_bgr = cv2.cvtColor(frame_array, cv2.COLOR_RGB2BGR)
                else:
                    frame_bgr = frame_array
                out.write(frame_bgr)
            out.release()
            out = None  # mark released so finally does not double-release
            # Verify the video file was written with actual content.
            if os.path.exists(video_path) and os.path.getsize(video_path) > 0:
                return video_path
            print("Video file creation failed")
            return None
        except Exception as e:
            print(f"Error creating video: {e}")
            return None
        finally:
            # BUGFIX: release the writer even if encoding raised mid-way,
            # so the file handle is not leaked.
            if out is not None:
                out.release()
# Initialize the generator once at import time so the model is loaded before
# the first request (loading can take minutes; on failure pipe stays None).
print("Initializing video generator...")
generator = VideoGenerator()
def generate_video_interface(prompt, negative_prompt, num_frames, height, width, steps, guidance_scale):
    """Gradio callback: validate the prompt, then delegate to the global generator.

    Returns ``(video_path, status_message)`` exactly as produced by
    ``VideoGenerator.generate_video``.
    """
    # Guard clause: reject empty or whitespace-only prompts before any work.
    if not prompt.strip():
        return None, "Please enter a prompt"
    # Gradio sliders deliver floats; coerce to the types the pipeline expects.
    return generator.generate_video(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_frames=int(num_frames),
        height=int(height),
        width=int(width),
        num_inference_steps=int(steps),
        guidance_scale=float(guidance_scale),
    )
# Create Gradio interface
def create_interface():
    """Build and return the Gradio Blocks UI wired to generate_video_interface."""
    with gr.Blocks(title="Wan2.1 Text-to-Video Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎬 Wan2.1 Text-to-Video Generator")
        gr.Markdown("Generate videos from text prompts using the Wan2.1-T2V-1.3B model")
        with gr.Row():
            # Left column: all generation inputs and the trigger button.
            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate...",
                    lines=3,
                    value="A cat playing with a ball of yarn"
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt (Optional)",
                    placeholder="What you don't want in the video...",
                    lines=2,
                    value="blurry, low quality, distorted"
                )
                with gr.Row():
                    num_frames = gr.Slider(
                        label="Number of Frames",
                        minimum=5,
                        maximum=33,
                        value=17,
                        step=1,
                        info="Will be auto-adjusted so (frames-1) is divisible by 4"
                    )
                    steps = gr.Slider(
                        label="Inference Steps",
                        minimum=3,
                        maximum=50,
                        value=20,
                        step=5
                    )
                with gr.Row():
                    # step=64 keeps dimensions compatible with the model's latent grid.
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=64
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=576,
                        value=320,
                        step=64
                    )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=15.0,
                    value=7.5,
                    step=0.5
                )
                generate_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg")
            # Right column: outputs (video player + status line).
            with gr.Column(scale=1):
                output_video = gr.Video(
                    label="Generated Video",
                    height=400
                )
                status_text = gr.Textbox(
                    label="Status",
                    lines=2,
                    interactive=False
                )
        # Examples
        gr.Markdown("## 📝 Example Prompts")
        examples = gr.Examples(
            examples=[
                ["A cute cat playing with a red ball", "blurry, low quality"],
                ["A beautiful sunset over the ocean with waves", "dark, gloomy"],
                ["A person walking in a forest with sunlight filtering through trees", "scary, horror"],
                ["Colorful flowers blooming in a garden", "wilted, dead"],
                ["A bird flying in the sky with clouds", "static, motionless"]
            ],
            inputs=[prompt, negative_prompt]
        )
        # Event handlers
        generate_btn.click(
            fn=generate_video_interface,
            inputs=[prompt, negative_prompt, num_frames, height, width, steps, guidance_scale],
            outputs=[output_video, status_text],
            show_progress=True
        )
        # Info
        gr.Markdown("""
        ### ℹ️ Tips:
        - **Lower resolution and fewer frames** = faster generation
        - **Higher inference steps** = better quality but slower
        - **Guidance scale 7-9** usually works best
        - Be descriptive in your prompts for better results
        - Generation may take 2-5 minutes on CPU
        """)
    return demo
if __name__ == "__main__":
    demo = create_interface()
    # Bind to all interfaces on port 7860, the standard HF Spaces port;
    # show_error surfaces tracebacks in the UI for easier debugging.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )