tester343 committed on
Commit
2292603
·
verified ·
1 Parent(s): b51f7ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -64
app.py CHANGED
@@ -8,30 +8,31 @@ import numpy as np
8
  import gradio as gr
9
  from PIL import Image
10
 
 
 
 
 
11
  # =========================================================
12
- # 1. CONFIGURATION - USE SMALLER 1.3B MODEL
13
  # =========================================================
14
- # The 14B model is too large for ZeroGPU free tier
15
- MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers" # Or use 1.3B if available
16
- LORA_REPO = "Kijai/WanVideo_comfy"
17
- LORA_NAME = "Lightx2v/lightx2v_I2V_480p_bf16.safetensors"
18
-
19
  HF_TOKEN = os.environ.get("HF_TOKEN")
20
 
 
21
  MAX_DIM = 480
22
  MIN_DIM = 480
23
  MULTIPLE_OF = 16
24
  MAX_SEED = np.iinfo(np.int32).max
25
  FIXED_FPS = 16
26
 
27
- # Global pipeline holder
28
- _pipe = None
29
 
30
  # =========================================================
31
  # 2. HELPER FUNCTIONS
32
  # =========================================================
33
  def resize_image(image: Image.Image) -> Image.Image:
34
- """Resize image to safe dimensions."""
35
  width, height = image.size
36
  aspect = width / height
37
 
@@ -42,77 +43,96 @@ def resize_image(image: Image.Image) -> Image.Image:
42
  w = MIN_DIM
43
  h = int(w / aspect)
44
 
 
45
  w = (round(w / MULTIPLE_OF) * MULTIPLE_OF)
46
  h = (round(h / MULTIPLE_OF) * MULTIPLE_OF)
 
 
47
  w = min(max(w, MIN_DIM), MAX_DIM)
48
  h = min(max(h, MIN_DIM), MAX_DIM)
49
 
50
  return image.resize((w, h), Image.LANCZOS)
51
 
52
  def cleanup():
 
53
  gc.collect()
54
  if torch.cuda.is_available():
55
  torch.cuda.empty_cache()
56
 
57
  # =========================================================
58
- # 3. SIMPLE GENERATION - MINIMAL LOADING
59
  # =========================================================
60
- @spaces.GPU(duration=180)
61
  def generate(
62
  image_path: str,
63
  prompt: str,
64
  duration: float = 3.0,
65
- steps: int = 4,
66
- guidance: float = 1.0,
67
  seed: int = 42,
68
  randomize: bool = True,
69
  progress=gr.Progress(track_tqdm=True)
70
  ):
71
- """Generate video with minimal overhead."""
72
- global _pipe
73
 
74
  if not image_path:
75
  raise gr.Error("Please upload an image.")
76
 
77
- try:
78
- progress(0.1, desc="Initializing...")
79
-
80
- # Import inside function
81
- from diffusers import AutoPipelineForImage2Video
82
- from diffusers.utils import export_to_video
83
 
84
- # Load pipeline only once
85
- if _pipe is None:
86
- progress(0.2, desc="Loading model (first run)...")
87
- print("⏳ Loading pipeline...")
88
-
89
- _pipe = AutoPipelineForImage2Video.from_pretrained(
90
  MODEL_ID,
91
  torch_dtype=torch.bfloat16,
92
  token=HF_TOKEN,
93
  )
94
- _pipe.to("cuda")
95
- print("✅ Pipeline loaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- # Prepare inputs
98
- progress(0.4, desc="Processing...")
99
  img = Image.open(image_path).convert("RGB")
100
  img = resize_image(img)
101
 
102
  final_seed = random.randint(0, MAX_SEED) if randomize else int(seed)
103
- num_frames = max(8, min(int(duration * FIXED_FPS), 49))
104
 
105
- print(f"📐 {img.size}, frames={num_frames}, seed={final_seed}")
106
-
107
- # Generate
108
- progress(0.5, desc="Generating video...")
109
- cleanup()
 
 
 
 
 
 
 
110
 
111
  with torch.inference_mode():
112
- output = _pipe(
113
  image=img,
114
  prompt=prompt,
115
- negative_prompt="low quality, blur, distortion",
116
  height=img.height,
117
  width=img.width,
118
  num_frames=num_frames,
@@ -123,7 +143,7 @@ def generate(
123
 
124
  frames = output.frames[0]
125
 
126
- # Save
127
  progress(0.9, desc="Saving...")
128
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
129
  video_path = f.name
@@ -131,18 +151,16 @@ def generate(
131
  export_to_video(frames, video_path, fps=FIXED_FPS)
132
 
133
  cleanup()
134
- print(f"✅ Saved: {video_path}")
135
-
136
  return video_path, final_seed
137
 
138
  except Exception as e:
139
  cleanup()
140
- error_msg = str(e)
141
- print(f"❌ {error_msg}")
142
-
143
- if "out of memory" in error_msg.lower():
144
- raise gr.Error("Out of memory. Try shorter duration or smaller image.")
145
- raise gr.Error(f"Error: {error_msg[:150]}")
146
 
147
  # =========================================================
148
  # 4. GRADIO UI
@@ -151,48 +169,49 @@ with gr.Blocks() as demo:
151
  gr.HTML("""
152
  <div style="text-align:center; padding:20px; background:linear-gradient(135deg,#1e3c72,#2a5298);
153
  color:white; border-radius:12px; margin-bottom:20px;">
154
- <h1>🎬 Wan Video Generator</h1>
155
- <p>Image to Video • Optimized for ZeroGPU</p>
156
  </div>
157
  """)
158
 
159
  with gr.Row():
160
  with gr.Column():
161
- img_in = gr.Image(type="filepath", label="📷 Image")
162
  prompt = gr.Textbox(
163
  label="✍️ Prompt",
164
- value="Smooth cinematic motion, high quality, natural movement",
165
  lines=2
166
  )
167
 
168
  with gr.Row():
169
- duration = gr.Slider(1, 5, value=3, step=0.5, label="Duration (s)")
170
- steps = gr.Slider(2, 8, value=4, step=1, label="Steps")
 
171
 
172
  with gr.Row():
173
  seed = gr.Number(value=42, label="Seed", precision=0)
174
- randomize = gr.Checkbox(value=True, label="Random")
175
 
176
- btn = gr.Button("🚀 Generate", variant="primary")
177
 
178
  with gr.Column():
179
  video_out = gr.Video(label="🎥 Result")
180
- seed_out = gr.Number(label="Seed", precision=0)
181
 
182
  gr.HTML("""
183
- <div style="background:#e7f3ff; padding:12px; border-radius:8px; margin-top:10px;">
184
- <b>💡 Tips:</b><br>
185
- Keep duration short (2-3s) for best results<br>
186
- First generation takes longer (loading model)<br>
187
- If error, wait a moment and retry
188
  </div>
189
  """)
190
 
191
  btn.click(
192
  fn=generate,
193
- inputs=[img_in, prompt, duration, steps, gr.Number(value=1.0, visible=False), seed, randomize],
194
  outputs=[video_out, seed_out]
195
  )
196
 
197
  if __name__ == "__main__":
198
- demo.queue(max_size=1).launch()
 
8
  import gradio as gr
9
  from PIL import Image
10
 
11
+ # Use the specific pipeline class for Wan models
12
+ from diffusers import WanImageToVideoPipeline
13
+ from diffusers.utils import export_to_video
14
+
15
  # =========================================================
16
+ # 1. CONFIGURATION
17
  # =========================================================
18
+ MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
 
 
 
 
19
  HF_TOKEN = os.environ.get("HF_TOKEN")
20
 
21
+ # Strict dimensions for the 14B model to prevent crashes
22
  MAX_DIM = 480
23
  MIN_DIM = 480
24
  MULTIPLE_OF = 16
25
  MAX_SEED = np.iinfo(np.int32).max
26
  FIXED_FPS = 16
27
 
28
+ # Global variable to hold the model in memory between runs
29
+ global_pipe = None
30
 
31
  # =========================================================
32
  # 2. HELPER FUNCTIONS
33
  # =========================================================
34
  def resize_image(image: Image.Image) -> Image.Image:
35
+ """Resize image to exactly 480p to keep the 14B model happy."""
36
  width, height = image.size
37
  aspect = width / height
38
 
 
43
  w = MIN_DIM
44
  h = int(w / aspect)
45
 
46
+ # Enforce multiples of 16
47
  w = (round(w / MULTIPLE_OF) * MULTIPLE_OF)
48
  h = (round(h / MULTIPLE_OF) * MULTIPLE_OF)
49
+
50
+ # Hard cap
51
  w = min(max(w, MIN_DIM), MAX_DIM)
52
  h = min(max(h, MIN_DIM), MAX_DIM)
53
 
54
  return image.resize((w, h), Image.LANCZOS)
55
 
56
  def cleanup():
57
+ """Force garbage collection to free VRAM."""
58
  gc.collect()
59
  if torch.cuda.is_available():
60
  torch.cuda.empty_cache()
61
 
62
  # =========================================================
63
+ # 3. GENERATION LOGIC
64
  # =========================================================
65
+ @spaces.GPU(duration=240) # 4 Minute timeout
66
  def generate(
67
  image_path: str,
68
  prompt: str,
69
  duration: float = 3.0,
70
+ steps: int = 15, # Increased slightly for quality
71
+ guidance: float = 5.0,
72
  seed: int = 42,
73
  randomize: bool = True,
74
  progress=gr.Progress(track_tqdm=True)
75
  ):
76
+ global global_pipe
 
77
 
78
  if not image_path:
79
  raise gr.Error("Please upload an image.")
80
 
81
+ # 1. LOAD MODEL (Lazy Loading)
82
+ # We only load it once. If it's already loaded, we skip this.
83
+ if global_pipe is None:
84
+ print("⏳ Loading Wan 14B Pipeline... (This happens only once)")
85
+ progress(0.1, desc="Loading Model (One-time setup)...")
 
86
 
87
+ try:
88
+ # Load in bfloat16 to save memory
89
+ global_pipe = WanImageToVideoPipeline.from_pretrained(
 
 
 
90
  MODEL_ID,
91
  torch_dtype=torch.bfloat16,
92
  token=HF_TOKEN,
93
  )
94
+
95
+ # CRITICAL OPTIMIZATION FOR ZERO GPU:
96
+ # 1. CPU Offload: Moves layers to CPU when not in use. Essential for 14B.
97
+ global_pipe.enable_model_cpu_offload()
98
+
99
+ # 2. VAE Tiling: Prevents VRAM explosion during decoding.
100
+ global_pipe.enable_vae_tiling()
101
+
102
+ print("✅ Model loaded and optimized.")
103
+
104
+ except Exception as e:
105
+ print(f"❌ Load Error: {e}")
106
+ raise gr.Error(f"Failed to load model: {e}")
107
+
108
+ # 2. PROCESS INPUT
109
+ try:
110
+ progress(0.3, desc="Processing Image...")
111
+ cleanup()
112
 
 
 
113
  img = Image.open(image_path).convert("RGB")
114
  img = resize_image(img)
115
 
116
  final_seed = random.randint(0, MAX_SEED) if randomize else int(seed)
 
117
 
118
+ # Wan generally produces 16fps.
119
+ # 5 seconds = 81 frames usually.
120
+ # NOTE: no frame cap is applied here; the frame count is bounded only by the duration slider in the UI (max 5 s).
121
+ num_frames = int(duration * FIXED_FPS)
122
+ # Round num_frames up to the next value of the form 4k + 1 (i.e. (num_frames - 1) divisible by 4), the frame-count shape used with Wan.
123
+ if (num_frames - 1) % 4 != 0:
124
+ num_frames += (4 - ((num_frames - 1) % 4))
125
+
126
+ print(f"🎬 Generating: {img.size} | Frames: {num_frames} | Seed: {final_seed}")
127
+
128
+ # 3. RUN INFERENCE
129
+ progress(0.4, desc="Dreaming...")
130
 
131
  with torch.inference_mode():
132
+ output = global_pipe(
133
  image=img,
134
  prompt=prompt,
135
+ negative_prompt="low quality, blur, distortion, morphing, jitter, artifacts",
136
  height=img.height,
137
  width=img.width,
138
  num_frames=num_frames,
 
143
 
144
  frames = output.frames[0]
145
 
146
+ # 4. SAVE VIDEO
147
  progress(0.9, desc="Saving...")
148
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
149
  video_path = f.name
 
151
  export_to_video(frames, video_path, fps=FIXED_FPS)
152
 
153
  cleanup()
154
+ print(f"✅ Video saved: {video_path}")
 
155
  return video_path, final_seed
156
 
157
  except Exception as e:
158
  cleanup()
159
+ print(f"❌ Error: {e}")
160
+ # Detect memory errors
161
+ if "out of memory" in str(e).lower():
162
+ raise gr.Error("GPU Out of Memory. Try a shorter duration.")
163
+ raise gr.Error(f"Generation Error: {str(e)[:200]}")
 
164
 
165
  # =========================================================
166
  # 4. GRADIO UI
 
169
  gr.HTML("""
170
  <div style="text-align:center; padding:20px; background:linear-gradient(135deg,#1e3c72,#2a5298);
171
  color:white; border-radius:12px; margin-bottom:20px;">
172
+ <h1>🎬 Wan 14B Video Generator</h1>
173
+ <p>Image to Video • Optimized for ZeroGPU • 14B Parameters</p>
174
  </div>
175
  """)
176
 
177
  with gr.Row():
178
  with gr.Column():
179
+ img_in = gr.Image(type="filepath", label="📷 Input Image")
180
  prompt = gr.Textbox(
181
  label="✍️ Prompt",
182
+ value="Cinematic slow motion, high quality, natural movement, 4k",
183
  lines=2
184
  )
185
 
186
  with gr.Row():
187
+ # Limited duration for safety on free tier
188
+ duration = gr.Slider(2, 5, value=4, step=1, label="Duration (seconds)")
189
+ steps = gr.Slider(10, 30, value=15, step=1, label="Quality Steps")
190
 
191
  with gr.Row():
192
  seed = gr.Number(value=42, label="Seed", precision=0)
193
+ randomize = gr.Checkbox(value=True, label="Randomize Seed")
194
 
195
+ btn = gr.Button("🚀 Generate Video", variant="primary")
196
 
197
  with gr.Column():
198
  video_out = gr.Video(label="🎥 Result")
199
+ seed_out = gr.Number(label="Used Seed", precision=0)
200
 
201
  gr.HTML("""
202
+ <div style="background:#f0f0f0; padding:12px; border-radius:8px; margin-top:10px; color:#333;">
203
+ <b>💡 Notes:</b><br>
204
+ <b>First Run:</b> Takes ~60s to load the model.<br>
205
+ <b>Subsequent Runs:</b> Much faster.<br>
206
+ <b>Limit:</b> Max 5 seconds recommended to avoid crashes.
207
  </div>
208
  """)
209
 
210
  btn.click(
211
  fn=generate,
212
+ inputs=[img_in, prompt, duration, steps, gr.Number(value=5.0, visible=False), seed, randomize],
213
  outputs=[video_out, seed_out]
214
  )
215
 
216
  if __name__ == "__main__":
217
+ demo.queue().launch()