# app.py — Hugging Face Space "videeooo" (commit 323451b)
import gradio as gr
import torch
from diffusers import DiffusionPipeline
import numpy as np
import cv2
import os
from PIL import Image
import tempfile
# Force CPU usage for better compatibility on HF Spaces
device = "cpu"  # every pipeline call below targets this device
torch.set_num_threads(4) # Optimize for CPU: cap intra-op threads to avoid oversubscription
class VideoGenerator:
    """Text-to-video generator backed by the Wan2.1-T2V-1.3B diffusers pipeline on CPU."""

    def __init__(self):
        # self.pipe stays None if loading fails; callers must check before generating.
        self.pipe = None
        self.load_model()

    def load_model(self):
        """Load the Wan2.1-T2V pipeline onto the CPU.

        On any failure, logs the error and leaves ``self.pipe`` as ``None``.
        """
        try:
            print("Loading Wan2.1-T2V model...")
            self.pipe = DiffusionPipeline.from_pretrained(
                "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
                torch_dtype=torch.float32,  # Use float32 for CPU (no efficient fp16 path)
                variant=None,
                use_safetensors=True,
            )
            self.pipe = self.pipe.to(device)
            # Enable memory efficient attention if available (reduces peak RAM).
            if hasattr(self.pipe, "enable_attention_slicing"):
                self.pipe.enable_attention_slicing()
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            self.pipe = None

    def adjust_frame_count(self, num_frames):
        """Return the nearest frame count where (num_frames - 1) is divisible by 4.

        Ties (remainder == 2) round down, preferring fewer frames for speed.
        """
        remainder = (num_frames - 1) % 4
        if remainder == 0:
            return num_frames
        # Round to the closest valid count; prefer the lower one on a tie.
        if remainder <= 2:
            return num_frames - remainder
        return num_frames + (4 - remainder)

    def generate_video(self, prompt, negative_prompt="", num_frames=16, height=320, width=512, num_inference_steps=20, guidance_scale=7.5):
        """Run the pipeline and encode the result to an .mp4 file.

        Returns a ``(video_path_or_None, status_message)`` tuple; never raises.
        """
        if self.pipe is None:
            return None, "Model not loaded properly"
        try:
            # (num_frames - 1) must be divisible by 4 for this model.
            adjusted_frames = self.adjust_frame_count(num_frames)
            if adjusted_frames != num_frames:
                print(f"Adjusted frames from {num_frames} to {adjusted_frames} to satisfy model requirements")
            print(f"Generating video for prompt: {prompt}")
            print(f"Using {adjusted_frames} frames")
            # Fixed seed (42) makes results reproducible for identical settings.
            with torch.no_grad():
                result = self.pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_frames=adjusted_frames,
                    height=height,
                    width=width,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    generator=torch.Generator(device=device).manual_seed(42)
                )
            # Video pipelines expose .frames (batched); image pipelines expose .images.
            if hasattr(result, 'frames'):
                frames = result.frames[0]  # first (and only) batch
            else:
                frames = result.images
            video_path = self.frames_to_video(frames)
            # Bug fix: the original reported success even when encoding failed.
            if video_path is None:
                return None, "Video file creation failed"
            return video_path, "Video generated successfully!"
        except Exception as e:
            error_msg = f"Error generating video: {str(e)}"
            print(error_msg)
            return None, error_msg

    def frames_to_video(self, frames, fps=8):
        """Encode a sequence of frames (PIL Images or uint8/float ndarrays) to .mp4.

        Returns the file path on success, ``None`` on failure; never raises.
        """
        try:
            # mkstemp guarantees a unique path (the old 4-digit random suffix
            # could silently overwrite a concurrent user's video).
            fd, video_path = tempfile.mkstemp(prefix="generated_video_", suffix=".mp4")
            os.close(fd)  # OpenCV opens the path itself
            # Get frame dimensions from the first frame.
            if isinstance(frames[0], Image.Image):
                first = np.array(frames[0])
            else:
                first = frames[0]
            height, width = first.shape[:2]
            # Use H.264 codec for better browser compatibility.
            fourcc = cv2.VideoWriter_fourcc(*'H264')
            out = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
            # If H264 is unavailable in this OpenCV build, fall back to mp4v.
            if not out.isOpened():
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                out = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
            try:
                for frame in frames:
                    if isinstance(frame, Image.Image):
                        frame_array = np.array(frame)
                    else:
                        frame_array = frame
                    if frame_array.dtype != np.uint8:
                        # Assumes float frames in [0, 1] — clip to avoid uint8 wrap-around
                        # for values slightly outside the range.
                        frame_array = np.clip(frame_array * 255, 0, 255).astype(np.uint8)
                    # Convert RGB to BGR: OpenCV writes BGR channel order.
                    if frame_array.ndim == 3 and frame_array.shape[2] == 3:
                        frame_array = cv2.cvtColor(frame_array, cv2.COLOR_RGB2BGR)
                    out.write(frame_array)
            finally:
                # Always release the writer, even if a frame conversion raises,
                # so the file handle is not leaked.
                out.release()
            # Verify the video file was actually written.
            if os.path.exists(video_path) and os.path.getsize(video_path) > 0:
                return video_path
            print("Video file creation failed")
            if os.path.exists(video_path):
                os.remove(video_path)  # don't leave a zero-byte temp file behind
            return None
        except Exception as e:
            print(f"Error creating video: {e}")
            return None
# Initialize the generator at import time so the model is loaded once,
# before the UI starts serving requests.
print("Initializing video generator...")
generator = VideoGenerator()
def generate_video_interface(prompt, negative_prompt, num_frames, height, width, steps, guidance_scale):
    """Gradio callback: validate the prompt and delegate to the global generator.

    Returns a ``(video_path_or_None, status_message)`` tuple for the two outputs.
    """
    # Bug fix: a cleared textbox can yield None, and None.strip() raised
    # AttributeError in the original.
    if not prompt or not prompt.strip():
        return None, "Please enter a prompt"
    # Sliders deliver floats; the pipeline expects ints for discrete settings.
    video_path, message = generator.generate_video(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_frames=int(num_frames),
        height=int(height),
        width=int(width),
        num_inference_steps=int(steps),
        guidance_scale=float(guidance_scale)
    )
    return video_path, message
# Create Gradio interface
def create_interface():
    """Build and return the Gradio Blocks UI for the text-to-video demo.

    Widget creation order defines the page layout, so statements here must
    not be reordered.
    """
    with gr.Blocks(title="Wan2.1 Text-to-Video Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎬 Wan2.1 Text-to-Video Generator")
        gr.Markdown("Generate videos from text prompts using the Wan2.1-T2V-1.3B model")
        with gr.Row():
            # Left column: all generation controls.
            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate...",
                    lines=3,
                    value="A cat playing with a ball of yarn"
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt (Optional)",
                    placeholder="What you don't want in the video...",
                    lines=2,
                    value="blurry, low quality, distorted"
                )
                with gr.Row():
                    # Default 17 already satisfies (frames - 1) % 4 == 0.
                    num_frames = gr.Slider(
                        label="Number of Frames",
                        minimum=5,
                        maximum=33,
                        value=17,
                        step=1,
                        info="Will be auto-adjusted so (frames-1) is divisible by 4"
                    )
                    steps = gr.Slider(
                        label="Inference Steps",
                        minimum=3,
                        maximum=50,
                        value=20,
                        step=5
                    )
                with gr.Row():
                    # Steps of 64 keep dimensions model-friendly.
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=64
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=576,
                        value=320,
                        step=64
                    )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=15.0,
                    value=7.5,
                    step=0.5
                )
                generate_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg")
            # Right column: generated video and status output.
            with gr.Column(scale=1):
                output_video = gr.Video(
                    label="Generated Video",
                    height=400
                )
                status_text = gr.Textbox(
                    label="Status",
                    lines=2,
                    interactive=False
                )
        # Examples: clicking one fills the prompt and negative-prompt boxes.
        gr.Markdown("## 📝 Example Prompts")
        examples = gr.Examples(
            examples=[
                ["A cute cat playing with a red ball", "blurry, low quality"],
                ["A beautiful sunset over the ocean with waves", "dark, gloomy"],
                ["A person walking in a forest with sunlight filtering through trees", "scary, horror"],
                ["Colorful flowers blooming in a garden", "wilted, dead"],
                ["A bird flying in the sky with clouds", "static, motionless"]
            ],
            inputs=[prompt, negative_prompt]
        )
        # Event handlers: wire the button to the module-level interface function.
        generate_btn.click(
            fn=generate_video_interface,
            inputs=[prompt, negative_prompt, num_frames, height, width, steps, guidance_scale],
            outputs=[output_video, status_text],
            show_progress=True
        )
        # Info
        gr.Markdown("""
        ### ℹ️ Tips:
        - **Lower resolution and fewer frames** = faster generation
        - **Higher inference steps** = better quality but slower
        - **Guidance scale 7-9** usually works best
        - Be descriptive in your prompts for better results
        - Generation may take 2-5 minutes on CPU
        """)
    return demo
if __name__ == "__main__":
    # Build the UI and serve it on all interfaces at the standard Spaces port.
    demo = create_interface()
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all interfaces
        "server_port": 7860,       # default Hugging Face Spaces port
        "share": False,            # no public gradio.live tunnel
        "show_error": True,        # surface tracebacks in the UI
    }
    demo.launch(**launch_options)