import gradio as gr
import torch
from diffusers import DiffusionPipeline
import numpy as np
import cv2
import os
from PIL import Image
import tempfile

# Force CPU usage for better compatibility on HF Spaces
device = "cpu"
torch.set_num_threads(4)  # Optimize for CPU


class VideoGenerator:
    """Wraps the Wan2.1-T2V diffusers pipeline.

    Loads the model once at construction time and converts text prompts
    into browser-playable MP4 files written to the system temp directory.
    If model loading fails, ``self.pipe`` stays ``None`` and generation
    requests return an error message instead of crashing the app.
    """

    def __init__(self):
        self.pipe = None
        self.load_model()

    def load_model(self):
        """Load the Wan2.1-T2V pipeline onto the CPU.

        Any failure (missing weights, no network, incompatible diffusers
        version, ...) is logged and leaves ``self.pipe`` as ``None`` so the
        UI can still start and report the problem per request.
        """
        try:
            print("Loading Wan2.1-T2V model...")
            self.pipe = DiffusionPipeline.from_pretrained(
                "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
                torch_dtype=torch.float32,  # float32: CPU has no usable fp16 kernels
                variant=None,
                use_safetensors=True,
            )
            self.pipe = self.pipe.to(device)

            # Enable memory efficient attention if available
            if hasattr(self.pipe, "enable_attention_slicing"):
                self.pipe.enable_attention_slicing()

            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            self.pipe = None

    def adjust_frame_count(self, num_frames):
        """Return the closest frame count where ``(num_frames - 1) % 4 == 0``.

        The Wan2.1 model only accepts such frame counts. On a tie
        (remainder == 2) we round DOWN to favor faster CPU generation.

        Args:
            num_frames: requested frame count (any positive int).

        Returns:
            A valid frame count, equal to or within 2 of ``num_frames``.
        """
        remainder = (num_frames - 1) % 4
        if remainder == 0:
            return num_frames
        # Round to the nearest valid count; prefer the lower one on a tie.
        if remainder <= 2:
            return num_frames - remainder
        return num_frames + (4 - remainder)

    def generate_video(self, prompt, negative_prompt="", num_frames=16, height=320,
                       width=512, num_inference_steps=20, guidance_scale=7.5):
        """Generate a video for ``prompt`` and encode it to an MP4 file.

        Args:
            prompt: text description of the desired video.
            negative_prompt: concepts to steer away from (may be empty).
            num_frames: requested frame count; auto-adjusted to satisfy the
                model's ``(n - 1) % 4 == 0`` requirement.
            height/width: output resolution in pixels.
            num_inference_steps: diffusion steps (quality vs. speed).
            guidance_scale: classifier-free guidance strength.

        Returns:
            Tuple ``(video_path_or_None, status_message)``.
        """
        if self.pipe is None:
            return None, "Model not loaded properly"

        try:
            # (num_frames - 1) must be divisible by 4 for this model.
            adjusted_frames = self.adjust_frame_count(num_frames)
            if adjusted_frames != num_frames:
                print(f"Adjusted frames from {num_frames} to {adjusted_frames} to satisfy model requirements")

            print(f"Generating video for prompt: {prompt}")
            print(f"Using {adjusted_frames} frames")

            # Generate video; fixed seed makes outputs reproducible per prompt.
            with torch.no_grad():
                result = self.pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_frames=adjusted_frames,
                    height=height,
                    width=width,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    generator=torch.Generator(device=device).manual_seed(42)
                )

            # Extract frames; attribute name differs across diffusers versions.
            if hasattr(result, 'frames'):
                frames = result.frames[0]  # Get first batch
            else:
                frames = result.images

            # Convert frames to video; surface encoding failure to the caller
            # instead of claiming success with a None path.
            video_path = self.frames_to_video(frames)
            if video_path is None:
                return None, "Error generating video: failed to encode video file"
            return video_path, "Video generated successfully!"

        except Exception as e:
            error_msg = f"Error generating video: {str(e)}"
            print(error_msg)
            return None, error_msg

    def frames_to_video(self, frames, fps=8):
        """Encode a sequence of frames (PIL Images or ndarrays) into an MP4.

        Tries the browser-friendly H.264 codec first and falls back to
        ``mp4v`` when the local OpenCV build lacks H.264 support.

        Args:
            frames: list of PIL Images or HxWx3 (RGB) / HxW arrays; float
                arrays are assumed to be in [0, 1] and scaled to uint8.
            fps: output frame rate.

        Returns:
            Path to the written MP4, or ``None`` on failure.
        """
        try:
            if not frames:
                print("Video file creation failed")
                return None

            # NamedTemporaryFile guarantees a unique path; a 4-digit random
            # suffix (the previous scheme) can collide across requests.
            with tempfile.NamedTemporaryFile(
                suffix=".mp4", delete=False, dir=tempfile.gettempdir()
            ) as tmp:
                video_path = tmp.name

            # Get frame dimensions from the first frame.
            if isinstance(frames[0], Image.Image):
                frame_array = np.array(frames[0])
                height, width = frame_array.shape[:2]
            else:
                height, width = frames[0].shape[:2]

            # Use H.264 codec for better browser compatibility.
            fourcc = cv2.VideoWriter_fourcc(*'H264')
            out = cv2.VideoWriter(video_path, fourcc, fps, (width, height))

            # If H264 fails (codec not built into OpenCV), fall back to mp4v.
            if not out.isOpened():
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                out = cv2.VideoWriter(video_path, fourcc, fps, (width, height))

            # Write frames
            for frame in frames:
                if isinstance(frame, Image.Image):
                    frame_array = np.array(frame)
                else:
                    frame_array = frame

                # Ensure frame is uint8 (floats assumed to be in [0, 1]).
                if frame_array.dtype != np.uint8:
                    frame_array = (frame_array * 255).astype(np.uint8)

                # Convert RGB to BGR for OpenCV.
                if len(frame_array.shape) == 3 and frame_array.shape[2] == 3:
                    frame_bgr = cv2.cvtColor(frame_array, cv2.COLOR_RGB2BGR)
                else:
                    frame_bgr = frame_array

                out.write(frame_bgr)

            out.release()

            # Verify the video file was created successfully.
            if os.path.exists(video_path) and os.path.getsize(video_path) > 0:
                return video_path
            print("Video file creation failed")
            return None

        except Exception as e:
            print(f"Error creating video: {e}")
            return None
print("Initializing video generator...")
generator = VideoGenerator()


def generate_video_interface(prompt, negative_prompt, num_frames, height, width, steps, guidance_scale):
    """Gradio callback: validate the prompt and delegate to the generator.

    Slider values arrive as floats, so they are coerced to the int/float
    types the pipeline expects.

    Returns:
        Tuple ``(video_path_or_None, status_message)``.
    """
    if not prompt.strip():
        return None, "Please enter a prompt"

    video_path, message = generator.generate_video(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_frames=int(num_frames),
        height=int(height),
        width=int(width),
        num_inference_steps=int(steps),
        guidance_scale=float(guidance_scale)
    )

    return video_path, message


def create_interface():
    """Build and return the Gradio Blocks UI for the generator."""
    with gr.Blocks(title="Wan2.1 Text-to-Video Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# đŸŽŦ Wan2.1 Text-to-Video Generator")
        gr.Markdown("Generate videos from text prompts using the Wan2.1-T2V-1.3B model")

        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate...",
                    lines=3,
                    value="A cat playing with a ball of yarn"
                )

                negative_prompt = gr.Textbox(
                    label="Negative Prompt (Optional)",
                    placeholder="What you don't want in the video...",
                    lines=2,
                    value="blurry, low quality, distorted"
                )

                with gr.Row():
                    num_frames = gr.Slider(
                        label="Number of Frames",
                        minimum=5,
                        maximum=33,
                        value=17,
                        step=1,
                        info="Will be auto-adjusted so (frames-1) is divisible by 4"
                    )

                    steps = gr.Slider(
                        label="Inference Steps",
                        minimum=3,
                        maximum=50,
                        value=20,
                        step=5
                    )

                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=64
                    )

                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=576,
                        value=320,
                        step=64
                    )

                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=15.0,
                    value=7.5,
                    step=0.5
                )

                generate_btn = gr.Button("đŸŽŦ Generate Video", variant="primary", size="lg")

            with gr.Column(scale=1):
                output_video = gr.Video(
                    label="Generated Video",
                    height=400
                )

                status_text = gr.Textbox(
                    label="Status",
                    lines=2,
                    interactive=False
                )

        # Examples — registered with the Blocks context; no binding needed.
        gr.Markdown("## 📝 Example Prompts")
        gr.Examples(
            examples=[
                ["A cute cat playing with a red ball", "blurry, low quality"],
                ["A beautiful sunset over the ocean with waves", "dark, gloomy"],
                ["A person walking in a forest with sunlight filtering through trees", "scary, horror"],
                ["Colorful flowers blooming in a garden", "wilted, dead"],
                ["A bird flying in the sky with clouds", "static, motionless"]
            ],
            inputs=[prompt, negative_prompt]
        )

        # Event handlers
        generate_btn.click(
            fn=generate_video_interface,
            inputs=[prompt, negative_prompt, num_frames, height, width, steps, guidance_scale],
            outputs=[output_video, status_text],
            show_progress=True
        )

        # Info
        gr.Markdown("""
        ### â„šī¸ Tips:
        - **Lower resolution and fewer frames** = faster generation
        - **Higher inference steps** = better quality but slower
        - **Guidance scale 7-9** usually works best
        - Be descriptive in your prompts for better results
        - Generation may take 2-5 minutes on CPU
        """)

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )