samsammysammy commited on
Commit
dcd7d2e
·
verified ·
1 Parent(s): e050b88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -208
app.py CHANGED
@@ -1,47 +1,91 @@
1
  # IMPORTANT: spaces must be imported first to avoid CUDA initialization issues
2
  import spaces
3
 
4
- # Standard library imports
5
  import os
6
-
7
- # Third-party imports (non-CUDA)
8
  import numpy as np
9
  from PIL import Image
10
  import gradio as gr
11
 
12
- # CUDA-related imports (must come after spaces)
13
  import torch
14
  from diffusers import WanPipeline, AutoencoderKLWan
15
  from diffusers.utils import export_to_video
16
 
17
- # Model configuration
 
 
 
18
  MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
19
  dtype = torch.bfloat16
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
- # Global pipeline variable
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  pipe = None
 
24
 
25
  def initialize_pipeline():
26
- """Initialize the Wan2.2 pipeline"""
27
- global pipe
28
- if pipe is None:
29
- print("Loading Wan2.2-TI2V-5B model...")
30
- vae = AutoencoderKLWan.from_pretrained(
31
- MODEL_ID,
32
- subfolder="vae",
33
- torch_dtype=torch.float32
34
- )
35
- pipe = WanPipeline.from_pretrained(
36
- MODEL_ID,
37
- vae=vae,
38
- torch_dtype=dtype
39
- )
40
- pipe.to(device)
41
- print("Model loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  return pipe
43
 
44
- @spaces.GPU(duration=180) # Allocate GPU for 3 minutes (max allowed for Pro)
45
  def generate_video(
46
  prompt: str,
47
  image: Image.Image = None,
@@ -50,31 +94,41 @@ def generate_video(
50
  num_frames: int = 73,
51
  num_inference_steps: int = 35,
52
  guidance_scale: float = 5.0,
53
- seed: int = -1
 
 
54
  ):
55
- """
56
- Generate video from text prompt and optional image
57
-
58
- Args:
59
- prompt: Text description of the video to generate
60
- image: Optional input image for image-to-video generation
61
- width: Video width (default: 1280)
62
- height: Video height (default: 704)
63
- num_frames: Number of frames to generate (default: 73 for 3 seconds at 24fps)
64
- num_inference_steps: Number of denoising steps (default: 35 for faster generation)
65
- guidance_scale: Guidance scale for generation (default: 5.0)
66
- seed: Random seed for reproducibility (-1 for random)
67
- """
68
  try:
69
- # Initialize pipeline
70
  pipeline = initialize_pipeline()
71
 
72
- # Set seed for reproducibility
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  if seed == -1:
74
  seed = torch.randint(0, 2**32 - 1, (1,)).item()
75
  generator = torch.Generator(device=device).manual_seed(seed)
76
 
77
- # Prepare generation parameters
78
  gen_params = {
79
  "prompt": prompt,
80
  "height": height,
@@ -85,212 +139,96 @@ def generate_video(
85
  "generator": generator,
86
  }
87
 
88
- # Add image if provided (for image-to-video)
89
  if image is not None:
90
  gen_params["image"] = image
91
 
92
- # Generate video
93
- print(f"Generating video with prompt: {prompt}")
94
- print(f"Parameters: {width}x{height}, {num_frames} frames, seed: {seed}")
95
-
96
  output = pipeline(**gen_params).frames[0]
97
 
98
- # Export to video file
99
  output_path = "output.mp4"
100
  export_to_video(output, output_path, fps=24)
101
 
102
- return output_path, f"Video generated successfully! Seed used: {seed}"
 
 
103
 
104
- except Exception as e:
105
- error_msg = f"Error generating video: {str(e)}"
106
- print(error_msg)
107
- return None, error_msg
108
 
109
- # Create Gradio interface
110
- with gr.Blocks(title="Wan2.2 Video Generation") as demo:
111
- gr.Markdown(
112
- """
113
- # Wan2.2 Video Generation
114
 
115
- Generate high-quality videos from text prompts or images using Wan2.2-TI2V-5B model.
116
- This model supports both **Text-to-Video** and **Image-to-Video** generation at 720P/24fps.
 
117
 
118
- **Note:** Generation takes 2-3 minutes. Settings are optimized for Zero GPU 3-minute limit.
119
- """
120
- )
 
 
121
 
122
  with gr.Row():
123
  with gr.Column():
124
- # Input controls
125
  prompt_input = gr.Textbox(
126
- label="Prompt",
127
- placeholder="Describe the video you want to generate...",
128
- lines=3,
129
  value="Two anthropomorphic cats in comfy boxing gear fight on stage"
130
  )
131
-
132
- image_input = gr.Image(
133
- label="Input Image (Optional - for Image-to-Video)",
134
- type="pil",
135
- sources=["upload"]
136
- )
137
 
138
  with gr.Accordion("Advanced Settings", open=False):
139
  with gr.Row():
140
- width_input = gr.Slider(
141
- label="Width",
142
- minimum=512,
143
- maximum=1920,
144
- step=64,
145
- value=1280
146
- )
147
- height_input = gr.Slider(
148
- label="Height",
149
- minimum=512,
150
- maximum=1080,
151
- step=64,
152
- value=704
153
- )
154
-
155
- num_frames_input = gr.Slider(
156
- label="Number of Frames (more frames = longer video)",
157
- minimum=25,
158
- maximum=145,
159
- step=24,
160
- value=73,
161
- info="73 frames ≈ 3 seconds at 24fps (optimized for Zero GPU limits)"
162
- )
163
-
164
- num_steps_input = gr.Slider(
165
- label="Inference Steps (more steps = better quality, slower)",
166
- minimum=20,
167
- maximum=60,
168
- step=5,
169
- value=35
170
- )
171
-
172
- guidance_scale_input = gr.Slider(
173
- label="Guidance Scale (higher = closer to prompt)",
174
- minimum=1.0,
175
- maximum=15.0,
176
- step=0.5,
177
- value=5.0
178
- )
179
-
180
- seed_input = gr.Number(
181
- label="Seed (-1 for random)",
182
- value=-1,
183
- precision=0
184
  )
 
 
185
 
186
  generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
187
 
188
  with gr.Column():
189
- # Output
190
- video_output = gr.Video(
191
- label="Generated Video",
192
- autoplay=True
193
- )
194
- status_output = gr.Textbox(
195
- label="Status",
196
- lines=2
197
- )
198
 
199
- # Examples
200
  gr.Examples(
201
  examples=[
202
- [
203
- "Two anthropomorphic cats in comfy boxing gear fight on stage",
204
- None,
205
- 1280,
206
- 704,
207
- 73,
208
- 35,
209
- 5.0,
210
- 42
211
- ],
212
- [
213
- "A serene underwater scene with colorful coral reefs and tropical fish swimming gracefully",
214
- None,
215
- 1280,
216
- 704,
217
- 73,
218
- 35,
219
- 5.0,
220
- 123
221
- ],
222
- [
223
- "A bustling futuristic city at night with neon lights and flying cars",
224
- None,
225
- 1280,
226
- 704,
227
- 73,
228
- 35,
229
- 5.0,
230
- 456
231
- ],
232
- [
233
- "A peaceful mountain landscape with snow-capped peaks and a flowing river",
234
- None,
235
- 1280,
236
- 704,
237
- 73,
238
- 35,
239
- 5.0,
240
- 789
241
- ],
242
- ],
243
- inputs=[
244
- prompt_input,
245
- image_input,
246
- width_input,
247
- height_input,
248
- num_frames_input,
249
- num_steps_input,
250
- guidance_scale_input,
251
- seed_input
252
  ],
 
 
253
  outputs=[video_output, status_output],
254
  fn=generate_video,
255
  cache_examples=False,
256
  )
257
 
258
- # Connect generate button
259
  generate_btn.click(
260
- fn=generate_video,
261
- inputs=[
262
- prompt_input,
263
- image_input,
264
- width_input,
265
- height_input,
266
- num_frames_input,
267
- num_steps_input,
268
- guidance_scale_input,
269
- seed_input
270
- ],
271
  outputs=[video_output, status_output]
272
  )
273
 
274
- gr.Markdown(
275
- """
276
- ## Tips for Best Results:
277
- - Use detailed, descriptive prompts
278
- - For image-to-video: Upload a clear image that matches your prompt
279
- - Higher inference steps = better quality but slower generation
280
- - Adjust guidance scale to balance creativity vs. prompt adherence
281
- - Use the same seed to reproduce results
282
- - Keep generation under 3 minutes to fit Zero GPU limits
283
-
284
- ## Model Information:
285
- - Model: Wan2.2-TI2V-5B (5B parameters)
286
- - Resolution: 720P (1280x704 or custom)
287
- - Frame Rate: 24 fps
288
- - Default Duration: 3 seconds (optimized for Zero GPU)
289
- - Generation Time: ~2-3 minutes on Zero GPU (with optimized settings)
290
- """
291
- )
292
 
293
- # Launch the app
294
  if __name__ == "__main__":
295
- demo.queue(max_size=20)
296
- demo.launch()
 
1
  # IMPORTANT: spaces must be imported first to avoid CUDA initialization issues
2
  import spaces
3
 
 
4
  import os
 
 
5
  import numpy as np
6
  from PIL import Image
7
  import gradio as gr
8
 
 
9
  import torch
10
  from diffusers import WanPipeline, AutoencoderKLWan
11
  from diffusers.utils import export_to_video
12
 
13
+ # ────────────────────────────────────────────────
14
+ # Model + LoRA configuration
15
+ # ────────────────────────────────────────────────
16
+
17
  MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
18
  dtype = torch.bfloat16
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
+ AVAILABLE_LORAS = [
22
+ {
23
+ "name": "Lightning (Fast 4-step)",
24
+ "repo_id": "lightx2v/Wan2.2-Distill-Loras",
25
+ "filename": "wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors",
26
+ "default_strength": 1.0,
27
+ },
28
+ {
29
+ "name": "General NSFW",
30
+ "repo_id": "lopi999/Wan2.2-I2V_General-NSFW-LoRA",
31
+ "filename": "pytorch_lora_weights.safetensors",
32
+ "default_strength": 0.8,
33
+ },
34
+ # Add more LoRAs here — they will be pre-loaded automatically
35
+ ]
36
+
37
+ # Global pipeline + pre-loaded adapter info
38
  pipe = None
39
+ lora_adapters = {} # name → {"adapter_name": str, "strength": float}
40
 
41
  def initialize_pipeline():
42
+ global pipe, lora_adapters
43
+
44
+ if pipe is not None:
45
+ return pipe
46
+
47
+ print("Loading Wan2.2-TI2V-5B base model...")
48
+ vae = AutoencoderKLWan.from_pretrained(
49
+ MODEL_ID,
50
+ subfolder="vae",
51
+ torch_dtype=torch.float32
52
+ )
53
+ pipe = WanPipeline.from_pretrained(
54
+ MODEL_ID,
55
+ vae=vae,
56
+ torch_dtype=dtype
57
+ )
58
+ pipe.to(device)
59
+ print("Base model loaded.")
60
+
61
+ # ── Pre-load ALL available LoRAs once ───────────────
62
+ print("Pre-loading LoRAs...")
63
+ for lora in AVAILABLE_LORAS:
64
+ name = lora["name"]
65
+ try:
66
+ print(f" → {name}")
67
+ pipe.load_lora_weights(
68
+ lora["repo_id"],
69
+ weight_name=lora["filename"],
70
+ adapter_name=name, # unique identifier
71
+ )
72
+ # Store for later hot-swapping
73
+ lora_adapters[name] = {
74
+ "adapter_name": name,
75
+ "strength": lora["default_strength"]
76
+ }
77
+ except Exception as e:
78
+ print(f" Failed to load {name}: {e}")
79
+
80
+ # Fuse once → best inference performance
81
+ if lora_adapters:
82
+ pipe.fuse_lora()
83
+ print("All LoRAs fused.")
84
+
85
+ print("Pipeline fully initialized.")
86
  return pipe
87
 
88
+ @spaces.GPU(duration=180)
89
  def generate_video(
90
  prompt: str,
91
  image: Image.Image = None,
 
94
  num_frames: int = 73,
95
  num_inference_steps: int = 35,
96
  guidance_scale: float = 5.0,
97
+ seed: int = -1,
98
+ enabled_loras: list = None,
99
+ lora_strength_multiplier: float = 1.0
100
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  try:
 
102
  pipeline = initialize_pipeline()
103
 
104
+ # ── Hot-swap / enable only selected LoRAs ───────
105
+ active_adapters = []
106
+ active_strengths = []
107
+
108
+ enabled = enabled_loras or []
109
+
110
+ for lora_name in enabled:
111
+ if lora_name in lora_adapters:
112
+ strength = lora_adapters[lora_name]["strength"] * lora_strength_multiplier
113
+ active_adapters.append(lora_name)
114
+ active_strengths.append(strength)
115
+
116
+ if active_adapters:
117
+ pipeline.set_adapters(active_adapters, adapter_strengths=active_strengths)
118
+ print(f"Activated LoRAs: {', '.join(active_adapters)}")
119
+ else:
120
+ pipeline.disable_lora() # important: turn off if none selected
121
+
122
+ # Lightning auto-optimization
123
+ if "Lightning (Fast 4-step)" in enabled and num_inference_steps > 8:
124
+ num_inference_steps = 4
125
+ print("Lightning LoRA → reduced to 4 steps")
126
+
127
+ # Seed
128
  if seed == -1:
129
  seed = torch.randint(0, 2**32 - 1, (1,)).item()
130
  generator = torch.Generator(device=device).manual_seed(seed)
131
 
 
132
  gen_params = {
133
  "prompt": prompt,
134
  "height": height,
 
139
  "generator": generator,
140
  }
141
 
 
142
  if image is not None:
143
  gen_params["image"] = image
144
 
145
+ print(f"Generating: {width}x{height}, {num_frames} frames, steps={num_inference_steps}")
 
 
 
146
  output = pipeline(**gen_params).frames[0]
147
 
 
148
  output_path = "output.mp4"
149
  export_to_video(output, output_path, fps=24)
150
 
151
+ status = f"Done! Seed: {seed}"
152
+ if active_adapters:
153
+ status += f"\nLoRAs: {', '.join(active_adapters)} @ {lora_strength_multiplier:.2f}x"
154
 
155
+ return output_path, status
 
 
 
156
 
157
+ except Exception as e:
158
+ msg = f"Error: {str(e)}"
159
+ print(msg)
160
+ return None, msg
 
161
 
162
+ # ────────────────────────────────────────────────
163
+ # Gradio UI
164
+ # ────────────────────────────────────────────────
165
 
166
+ with gr.Blocks(title="Wan2.2 Video + Fast LoRA") as demo:
167
+ gr.Markdown("""
168
+ # Wan2.2-TI2V-5B Video Generation
169
+ **Optimized LoRA loading** — all LoRAs pre-loaded at startup, then hot-swapped instantly.
170
+ """)
171
 
172
  with gr.Row():
173
  with gr.Column():
 
174
  prompt_input = gr.Textbox(
175
+ label="Prompt", lines=3,
 
 
176
  value="Two anthropomorphic cats in comfy boxing gear fight on stage"
177
  )
178
+ image_input = gr.Image(label="Input Image (optional)", type="pil", sources=["upload"])
 
 
 
 
 
179
 
180
  with gr.Accordion("Advanced Settings", open=False):
181
  with gr.Row():
182
+ width_input = gr.Slider(512, 1920, step=64, value=1280, label="Width")
183
+ height_input = gr.Slider(512, 1080, step=64, value=704, label="Height")
184
+ num_frames_input = gr.Slider(25, 145, step=24, value=73, label="Frames")
185
+ num_steps_input = gr.Slider(4, 60, step=1, value=35, label="Inference Steps",
186
+ info="Lightning LoRA → try 4–8 steps")
187
+ guidance_scale_input = gr.Slider(1.0, 15.0, 0.5, value=5.0, label="Guidance Scale")
188
+ seed_input = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)
189
+
190
+ with gr.Accordion("LoRA Controls", open=True):
191
+ lora_checkbox = gr.CheckboxGroup(
192
+ choices=[l["name"] for l in AVAILABLE_LORAS],
193
+ label="Enable LoRAs",
194
+ value=[]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  )
196
+ lora_strength = gr.Slider(0.1, 1.5, 0.05, value=1.0,
197
+ label="Global Strength Multiplier")
198
 
199
  generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
200
 
201
  with gr.Column():
202
+ video_output = gr.Video(label="Generated Video", autoplay=True)
203
+ status_output = gr.Textbox(label="Status", lines=3)
 
 
 
 
 
 
 
204
 
205
+ # Examples with LoRA usage
206
  gr.Examples(
207
  examples=[
208
+ ["Two anthropomorphic cats in comfy boxing gear fight on stage", None, 1280, 704, 73, 35, 5.0, 42, [], 1.0],
209
+ ["A serene underwater scene with colorful coral reefs...", None, 1280, 704, 73, 4, 5.0, 123, ["Lightning (Fast 4-step)"], 1.0],
210
+ ["Explicit adult scene, detailed", None, 1280, 704, 73, 30, 6.0, 999, ["General NSFW"], 0.9],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  ],
212
+ inputs=[prompt_input, image_input, width_input, height_input, num_frames_input,
213
+ num_steps_input, guidance_scale_input, seed_input, lora_checkbox, lora_strength],
214
  outputs=[video_output, status_output],
215
  fn=generate_video,
216
  cache_examples=False,
217
  )
218
 
 
219
  generate_btn.click(
220
+ generate_video,
221
+ inputs=[prompt_input, image_input, width_input, height_input, num_frames_input,
222
+ num_steps_input, guidance_scale_input, seed_input, lora_checkbox, lora_strength],
 
 
 
 
 
 
 
 
223
  outputs=[video_output, status_output]
224
  )
225
 
226
+ gr.Markdown("""
227
+ ## Performance Notes
228
+ - LoRAs are **pre-loaded once** → first generation may take ~10–30s longer, later ones are fast.
229
+ - Lightning LoRA: use **4–8 steps** → generation can finish in <60s.
230
+ - Add new LoRAs by appending to `AVAILABLE_LORAS` they auto-load at startup.
231
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
232
 
 
233
  if __name__ == "__main__":
234
+ demo.queue(max_size=20).launch()