7nglzz committed on
Commit
2fcccb7
·
1 Parent(s): 6b0cf57
Files changed (3) hide show
  1. app.py +249 -16
  2. requirements.txt +14 -7
  3. startup.sh +0 -2
app.py CHANGED
@@ -1,21 +1,254 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import WanPipeline
4
- from diffusers.utils import export_to_video
5
- from io import BytesIO
 
 
 
6
 
7
- # Initialize the model
8
- model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
9
- pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
10
- pipe.enable_attention_slicing() # Reduce memory usage
11
- pipe.enable_sequential_cpu_offload() # Offload layers to CPU sequentially
12
 
13
- def generate_video(prompt):
14
- frames = pipe(prompt=prompt, num_frames=33).frames[0]
15
- buf = BytesIO()
16
- export_to_video(frames, buf, fps=16, codec="mp4v")
17
- return buf.getvalue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # Define Gradio interface
20
- iface = gr.Interface(fn=generate_video, inputs="text", outputs="video")
21
- iface.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
+ from diffusers import DiffusionPipeline
4
+ import numpy as np
5
+ import cv2
6
+ import os
7
+ from PIL import Image
8
+ import tempfile
9
 
10
# Run everything on the CPU — the most compatible option on HF Spaces.
device = "cpu"
# Cap the intra-op thread pool so torch does not oversubscribe the small CPU allocation.
torch.set_num_threads(4)
 
 
13
 
14
class VideoGenerator:
    """Wraps the Wan2.1-T2V diffusers pipeline: loads it once at construction
    and converts text prompts into mp4 files on disk."""

    def __init__(self):
        # Pipeline handle; stays None when loading fails so callers can degrade gracefully.
        self.pipe = None
        self.load_model()

    def load_model(self):
        """Load the Wan2.1-T2V pipeline onto the CPU.

        On any failure, logs the error and leaves ``self.pipe`` as None.
        """
        try:
            print("Loading Wan2.1-T2V model...")
            self.pipe = DiffusionPipeline.from_pretrained(
                "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
                torch_dtype=torch.float32,  # float32: CPUs have no fast fp16 path
                variant=None,
                use_safetensors=True,
            )
            self.pipe = self.pipe.to(device)

            # Attention slicing trades a little speed for much lower peak memory.
            if hasattr(self.pipe, "enable_attention_slicing"):
                self.pipe.enable_attention_slicing()

            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            self.pipe = None

    def generate_video(self, prompt, negative_prompt="", num_frames=16, height=320,
                       width=512, num_inference_steps=20, guidance_scale=7.5):
        """Run the pipeline on *prompt* and encode the frames to mp4.

        Returns a ``(video_path, message)`` tuple; ``video_path`` is None on
        any failure (model missing, inference error, or encoding error).
        """
        if self.pipe is None:
            return None, "Model not loaded properly"

        try:
            print(f"Generating video for prompt: {prompt}")

            with torch.no_grad():
                result = self.pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_frames=num_frames,
                    height=height,
                    width=width,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    # Fixed seed keeps outputs reproducible run-to-run.
                    generator=torch.Generator(device=device).manual_seed(42),
                )

            # Video pipelines expose .frames (batched); fall back to .images.
            if hasattr(result, 'frames'):
                frames = result.frames[0]  # Get first batch
            else:
                frames = result.images

            video_path = self.frames_to_video(frames)

            # BUG FIX: previously reported success even when encoding failed
            # (frames_to_video returns None on error).
            if video_path is None:
                return None, "Error generating video: could not encode frames"

            return video_path, "Video generated successfully!"

        except Exception as e:
            error_msg = f"Error generating video: {str(e)}"
            print(error_msg)
            return None, error_msg

    def frames_to_video(self, frames, fps=8):
        """Encode a sequence of frames (PIL images or ndarrays) to an mp4 file.

        Returns the output path, or None on failure.
        """
        try:
            # BUG FIX: the old np.random.randint(1000, 9999) suffix could
            # collide between concurrent requests; mkstemp guarantees a
            # unique path. We only need the name, so close the fd at once.
            fd, video_path = tempfile.mkstemp(prefix="generated_video_", suffix=".mp4")
            os.close(fd)

            # Frame dimensions come from the first frame.
            first = np.array(frames[0]) if isinstance(frames[0], Image.Image) else frames[0]
            height, width = first.shape[:2]

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(video_path, fourcc, fps, (width, height))

            for frame in frames:
                if isinstance(frame, Image.Image):
                    frame_array = np.array(frame)
                else:
                    frame_array = frame

                # BUG FIX: diffusers may return float frames in [0, 1]; casting
                # those straight to uint8 yields an (almost) all-black video.
                # Scale only when the dtype is floating, so uint8 input is untouched.
                if np.issubdtype(np.asarray(frame_array).dtype, np.floating):
                    frame_array = (np.clip(frame_array, 0.0, 1.0) * 255).round()

                frame_array = np.asarray(frame_array).astype(np.uint8)

                # OpenCV expects BGR channel order; grayscale passes through.
                if len(frame_array.shape) == 3:
                    frame_bgr = cv2.cvtColor(frame_array, cv2.COLOR_RGB2BGR)
                else:
                    frame_bgr = frame_array

                out.write(frame_bgr)

            out.release()
            return video_path

        except Exception as e:
            print(f"Error creating video: {e}")
            return None
114
 
115
# Build the module-level generator eagerly at import time so the model is
# already loaded (or has already failed) before the UI accepts requests.
print("Initializing video generator...")
generator = VideoGenerator()
118
+
119
def generate_video_interface(prompt, negative_prompt, num_frames, height, width, steps, guidance_scale):
    """Gradio callback: reject empty prompts, coerce the slider values to the
    numeric types the pipeline expects, and delegate to the module-level
    generator. Returns ``(video_path, status_message)``."""
    if not prompt.strip():
        return None, "Please enter a prompt"

    # generator.generate_video already returns the (path, message) pair
    # in the shape Gradio expects, so hand its result straight back.
    return generator.generate_video(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_frames=int(num_frames),
        height=int(height),
        width=int(width),
        num_inference_steps=int(steps),
        guidance_scale=float(guidance_scale),
    )
135
+
136
+ # Create Gradio interface
137
def create_interface():
    """Assemble and return the Gradio Blocks UI for the generator.

    Left column: prompt inputs and generation controls. Right column: the
    rendered video plus a status line. Wiring happens via the button's
    click handler; nothing is launched here.
    """
    with gr.Blocks(title="Wan2.1 Text-to-Video Generator", theme=gr.themes.Soft()) as ui:
        gr.Markdown("# 🎬 Wan2.1 Text-to-Video Generator")
        gr.Markdown("Generate videos from text prompts using the Wan2.1-T2V-1.3B model")

        with gr.Row():
            # ---- Input column -------------------------------------------------
            with gr.Column(scale=1):
                prompt_in = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate...",
                    lines=3,
                    value="A cat playing with a ball of yarn",
                )
                negative_in = gr.Textbox(
                    label="Negative Prompt (Optional)",
                    placeholder="What you don't want in the video...",
                    lines=2,
                    value="blurry, low quality, distorted",
                )

                with gr.Row():
                    frames_slider = gr.Slider(
                        label="Number of Frames",
                        minimum=8,
                        maximum=32,
                        value=16,
                        step=4,
                    )
                    steps_slider = gr.Slider(
                        label="Inference Steps",
                        minimum=10,
                        maximum=50,
                        value=20,
                        step=5,
                    )

                with gr.Row():
                    width_slider = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=64,
                    )
                    height_slider = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=576,
                        value=320,
                        step=64,
                    )

                guidance_slider = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=15.0,
                    value=7.5,
                    step=0.5,
                )

                run_button = gr.Button("🎬 Generate Video", variant="primary", size="lg")

            # ---- Output column ------------------------------------------------
            with gr.Column(scale=1):
                video_out = gr.Video(
                    label="Generated Video",
                    height=400,
                )
                status_out = gr.Textbox(
                    label="Status",
                    lines=2,
                    interactive=False,
                )

        # Clickable example prompts (fill the two text boxes).
        gr.Markdown("## 📝 Example Prompts")
        gr.Examples(
            examples=[
                ["A cute cat playing with a red ball", "blurry, low quality"],
                ["A beautiful sunset over the ocean with waves", "dark, gloomy"],
                ["A person walking in a forest with sunlight filtering through trees", "scary, horror"],
                ["Colorful flowers blooming in a garden", "wilted, dead"],
                ["A bird flying in the sky with clouds", "static, motionless"],
            ],
            inputs=[prompt_in, negative_in],
        )

        # Wire the button to the generation callback.
        run_button.click(
            fn=generate_video_interface,
            inputs=[prompt_in, negative_in, frames_slider, height_slider, width_slider, steps_slider, guidance_slider],
            outputs=[video_out, status_out],
            show_progress=True,
        )

        # Usage hints for visitors.
        gr.Markdown("""
        ### ℹ️ Tips:
        - **Lower resolution and fewer frames** = faster generation
        - **Higher inference steps** = better quality but slower
        - **Guidance scale 7-9** usually works best
        - Be descriptive in your prompts for better results
        - Generation may take 2-5 minutes on CPU
        """)

    return ui
246
+
247
if __name__ == "__main__":
    # Bind to all interfaces on the standard HF Spaces port; keep the
    # app private (no public Gradio share link) and surface errors in the UI.
    app = create_interface()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )
requirements.txt CHANGED
@@ -1,7 +1,14 @@
1
- diffusers==0.33.0
2
- torch==2.4.0
3
- gradio==3.44.0
4
- huggingface_hub==0.17.0
5
- transformers==4.33.2
6
- imageio==2.31.1
7
- imageio-ffmpeg==0.4.8
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision
3
+ diffusers>=0.25.0
4
+ transformers>=4.35.0
5
+ accelerate>=0.20.0
6
+ gradio>=4.0.0
7
+ opencv-python-headless
8
+ pillow>=9.0.0
9
+ numpy>=1.21.0
10
+ safetensors>=0.3.0
11
+ huggingface-hub>=0.16.0
12
+ scipy
13
+ ftfy
14
+ regex
startup.sh DELETED
@@ -1,2 +0,0 @@
1
- #!/bin/bash
2
- pip install -r requirements.txt