tester343 committed on
Commit cef6409 · verified · 1 Parent(s): 5fa715d

Update app.py

Files changed (1)
  1. app.py +332 -208
app.py CHANGED
@@ -1,223 +1,347 @@
- # CACHE_BUSTER = "2025-12-15-ROBUST"
- # ======================================================
- # Wan I2V – ROBUST, CLEAN, HF-SPACES SAFE IMPLEMENTATION
- # ======================================================
- # Goals:
- # - Zero syntax errors (Python 3 only)
- # - No Gradio hot-reload issues
- # - Robust video export (diffusers → imageio → ffmpeg → opencv)
- # - Safe long-video chunking
- # - Clear logging
- # - HF Spaces compatible (ZeroGPU, CPU, CUDA)
-
  import os
- import uuid
- import time
- import shutil
- import logging
- from typing import List, Tuple
-
  import numpy as np
  from PIL import Image
- import gradio as gr
-
- # ------------------------------------------------------
- # Optional backends
- # ------------------------------------------------------
- try:
-     from diffusers.utils import export_to_video as diffusers_export_to_video
-     DIFFUSERS_AVAILABLE = True
- except Exception:
-     DIFFUSERS_AVAILABLE = False
-
- try:
-     import imageio
-     IMAGEIO_AVAILABLE = True
- except Exception:
-     IMAGEIO_AVAILABLE = False
-
- try:
-     import cv2
-     CV2_AVAILABLE = True
- except Exception:
-     CV2_AVAILABLE = False
-
- # ------------------------------------------------------
- # Logging
- # ------------------------------------------------------
- logging.basicConfig(
-     level=logging.INFO,
-     format="[%(asctime)s] [%(levelname)s] %(message)s",
  )
- logger = logging.getLogger("wan_i2v")
-
- # ------------------------------------------------------
- # Frame normalization
- # ------------------------------------------------------
-
- def normalize_frame(frame: np.ndarray) -> np.ndarray:
-     frame = np.asarray(frame)
-
-     if frame.ndim == 4 and frame.shape[0] == 1:
-         frame = frame[0]
-
-     if frame.ndim == 2:
-         frame = np.stack([frame] * 3, axis=-1)
-
-     if frame.ndim == 3 and frame.shape[0] in (1, 3):
-         frame = np.transpose(frame, (1, 2, 0))
-
-     if frame.dtype != np.uint8:
-         if frame.min() < 0:
-             frame = (frame + 1.0) / 2.0
-         frame = np.clip(frame * 255.0, 0, 255).astype(np.uint8)
-
-     return frame
-
- # ------------------------------------------------------
- # Robust video export
- # ------------------------------------------------------
-
- def export_video(frames: List[np.ndarray], out_path: str, fps: int) -> str:
-     frames = [normalize_frame(f) for f in frames]
-
-     # 1. diffusers
-     if DIFFUSERS_AVAILABLE:
-         try:
-             diffusers_export_to_video(frames, out_path, fps=fps)
-             return out_path
-         except Exception as e:
-             logger.warning(f"Diffusers export failed: {e}")
-
-     # 2. imageio
-     if IMAGEIO_AVAILABLE:
-         try:
-             writer = imageio.get_writer(out_path, fps=fps, macro_block_size=None)
-             for f in frames:
-                 writer.append_data(f)
-             writer.close()
-             return out_path
-         except Exception as e:
-             logger.warning(f"ImageIO export failed: {e}")
-
-     # 3. ffmpeg raw pipe
-     ffmpeg = shutil.which("ffmpeg")
-     if ffmpeg:
-         try:
-             import subprocess
-             h, w = frames[0].shape[:2]
-             cmd = [
-                 ffmpeg, "-y",
-                 "-f", "rawvideo",
-                 "-pix_fmt", "rgb24",
-                 "-s", f"{w}x{h}",
-                 "-r", str(fps),
-                 "-i", "-",
-                 "-an",
-                 "-c:v", "libx264",
-                 "-pix_fmt", "yuv420p",
-                 out_path,
-             ]
-             p = subprocess.Popen(cmd, stdin=subprocess.PIPE)
-             for f in frames:
-                 p.stdin.write(f.tobytes())
-             p.stdin.close()
-             p.wait()
-             return out_path
-         except Exception as e:
-             logger.warning(f"FFmpeg export failed: {e}")
-
-     # 4. OpenCV
-     if CV2_AVAILABLE:
-         try:
-             h, w = frames[0].shape[:2]
-             writer = cv2.VideoWriter(
-                 out_path,
-                 cv2.VideoWriter_fourcc(*"mp4v"),
-                 float(fps),
-                 (w, h),
-             )
-             for f in frames:
-                 writer.write(cv2.cvtColor(f, cv2.COLOR_RGB2BGR))
-             writer.release()
-             return out_path
-         except Exception as e:
-             logger.warning(f"OpenCV export failed: {e}")
-
-     raise RuntimeError("All video export backends failed")
-
- # ------------------------------------------------------
- # Dummy inference (SAFE placeholder)
- # ------------------------------------------------------
-
- def infer_frames(image: Image.Image, num_frames: int) -> List[np.ndarray]:
-     base = np.asarray(image.convert("RGB"))
-     frames = []
-     for i in range(num_frames):
-         f = base.copy()
-         f[:, :, 0] = np.clip(f[:, :, 0] + i, 0, 255)
-         frames.append(f)
-     return frames
-
- # ------------------------------------------------------
- # Long-video generator
- # ------------------------------------------------------
-
- def generate_video(
-     image: Image.Image,
-     total_frames: int,
-     fps: int,
-     seed: int,
- ) -> Tuple[str, int]:
-
-     if image is None:
-         raise gr.Error("Please upload an input image")
-
-     np.random.seed(int(seed))
-
-     CHUNK = 80
-     frames_all: List[np.ndarray] = []
-     remaining = total_frames
-
-     while remaining > 0:
-         take = min(CHUNK, remaining)
-         frames_all.extend(infer_frames(image, take))
-         remaining -= take
-
-     out_path = f"/tmp/wan_{uuid.uuid4().hex}.mp4"
-     export_video(frames_all, out_path, fps)
-
-     return out_path, seed
-
- # ------------------------------------------------------
- # Gradio UI (single demo object – avoids hot-reload bug)
- # ------------------------------------------------------
-
- demo = gr.Blocks(title="Wan I2V Robust")
-
- with demo:
-     gr.Markdown("## Wan I2V – Robust Stable Build")
-
      with gr.Row():
-         input_image = gr.Image(type="pil", label="Input Image")
-         fps = gr.Slider(8, 60, value=24, step=1, label="FPS")
-
-     total_frames = gr.Slider(16, 480, value=80, step=1, label="Total Frames")
-     seed = gr.Number(value=1234, label="Seed")
-
-     run_btn = gr.Button("Generate Video")
-     output_video = gr.Video(label="Output Video")
-     output_seed = gr.Number(label="Seed Used")
-
-     run_btn.click(
-         fn=generate_video,
-         inputs=[input_image, total_frames, fps, seed],
-         outputs=[output_video, output_seed],
-     )
-
- # ------------------------------------------------------
- # Launch
- # ------------------------------------------------------
-
  if __name__ == "__main__":
-     demo.launch(server_name="0.0.0.0", server_port=7860)
  import os
+ import spaces
+ import torch
+ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
+ from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
+ from diffusers.utils.export_utils import export_to_video
+ import gradio as gr
+ import tempfile
  import numpy as np
  from PIL import Image
+ import random
+ import gc
+ from torchao.quantization import quantize_
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
+ import aoti
+
+ # =========================================================
+ # MODEL CONFIGURATION
+ # =========================================================
+ MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+
+ MAX_DIM = 832
+ MIN_DIM = 480
+ SQUARE_DIM = 640
+ MULTIPLE_OF = 16
+ MAX_SEED = np.iinfo(np.int32).max
+
+ FIXED_FPS = 16
+ # We will generate in chunks of ~5 seconds (81 frames) to reach 20s
+ CHUNK_DURATION = 5.0
+ TOTAL_DURATION_TARGET = 20.0
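+ # Chunking arithmetic: 5.0 s * 16 fps = 80 frames, plus the starting frame
+ # = 81 per chunk (Wan-family pipelines typically expect 4k+1 frame counts),
+ # and a 20 s target takes ceil(20 / 5) = 4 generation passes.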
+
+ # =========================================================
+ # LOAD PIPELINE
+ # =========================================================
+ print("Loading pipeline components...")
+
+ # Load models in bfloat16
+ transformer = WanTransformer3DModel.from_pretrained(
+     MODEL_ID,
+     subfolder="transformer",
+     torch_dtype=torch.bfloat16,
+     token=HF_TOKEN
  )
+
+ transformer_2 = WanTransformer3DModel.from_pretrained(
+     MODEL_ID,
+     subfolder="transformer_2",
+     torch_dtype=torch.bfloat16,
+     token=HF_TOKEN
+ )
+
+ print("Assembling pipeline...")
+ pipe = WanImageToVideoPipeline.from_pretrained(
+     MODEL_ID,
+     transformer=transformer,
+     transformer_2=transformer_2,
+     torch_dtype=torch.bfloat16,
+     token=HF_TOKEN
+ )
+
+ print("Moving to CUDA...")
+ pipe = pipe.to("cuda")
+
+ # =========================================================
+ # LOAD LORA ADAPTERS
+ # =========================================================
+ print("Loading LoRA adapters...")
+ try:
+     pipe.load_lora_weights(
+         "Kijai/WanVideo_comfy",
+         weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+         adapter_name="lightx2v"
+     )
+     pipe.load_lora_weights(
+         "Kijai/WanVideo_comfy",
+         weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+         adapter_name="lightx2v_2",
+         load_into_transformer_2=True
+     )
+
+     pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1., 1.])
+     pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3., components=["transformer"])
+     pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1., components=["transformer_2"])
+     pipe.unload_lora_weights()
+     print("LoRA loaded and fused successfully.")
+ except Exception as e:
+     print(f"Warning: Failed to load LoRA. Continuing without it. Error: {e}")
+
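+ # fuse_lora() bakes the adapter deltas into the base weights and
+ # unload_lora_weights() then drops the adapter modules, so the
+ # quantization below operates on plain fused layers.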
+ # =========================================================
+ # QUANTIZATION & AOT OPTIMIZATION
+ # =========================================================
+ print("Applying quantization...")
+ torch.cuda.empty_cache()
+ gc.collect()
+
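+ # Int8 weight-only shrinks the text encoder; FP8 dynamic-activation /
+ # FP8-weight quantization targets both DiT stages. aoti_blocks_load (from
+ # the Space-local aoti helper imported above) then swaps in ahead-of-time
+ # compiled transformer blocks; the 'fp8da' variant is presumably built to
+ # match the FP8 quantization applied here.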
+ try:
+     quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
+     quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
+     quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
+
+     print("Loading AOTI blocks...")
+     aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
+     aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
+ except Exception as e:
+     print(f"Warning: Quantization/AOTI failed. Running in standard mode might OOM. Error: {e}")
+
+ # =========================================================
+ # DEFAULT PROMPTS
+ # =========================================================
+ default_prompt_i2v = "Make this image come alive with dynamic, cinematic human motion. Create smooth, natural, lifelike animation with fluid transitions, expressive body movement, realistic physics, and elegant camera flow. Deliver a polished, high-quality motion style that feels immersive, artistic, and visually captivating."
+
+ default_negative_prompt = (
+     "low quality, worst quality, motion artifacts, unstable motion, jitter, frame jitter, wobbling limbs, motion distortion, inconsistent movement, robotic movement, animation-like motion, awkward transitions, incorrect body mechanics, unnatural posing, off-balance poses, broken motion paths, frozen frames, duplicated frames, frame skipping, warped motion, stretching artifacts, bad anatomy, incorrect proportions, deformed body, twisted torso, broken joints, dislocated limbs, distorted neck, unnatural spine curvature, malformed hands, extra fingers, missing fingers, fused fingers, distorted legs, extra limbs, collapsed feet, floating feet, foot sliding, foot jitter, backward walking, unnatural gait, blurry details, long exposure blur, ghosting, shadow trails, smearing, washed-out colors, overexposure, underexposure, excessive contrast, blown highlights, poorly rendered clothing, fabric glitches, texture warping, clothing merging with body, incorrect cloth physics, ugly background, cluttered scene, crowded background, random objects, unwanted text, subtitles, logos, graffiti, grain, noise, static artifacts, compression noise, jpeg artifacts, image-like stillness, painting-like look, cartoon texture, low-resolution textures"
+ )
+
+ # =========================================================
+ # IMAGE RESIZING LOGIC
+ # =========================================================
+ def resize_image(image: Image.Image) -> Image.Image:
+     width, height = image.size
+     if width == height:
+         return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)
+
+     aspect_ratio = width / height
+     MAX_ASPECT_RATIO = MAX_DIM / MIN_DIM
+     MIN_ASPECT_RATIO = MIN_DIM / MAX_DIM
+
+     image_to_resize = image
+     if aspect_ratio > MAX_ASPECT_RATIO:
+         crop_width = int(round(height * MAX_ASPECT_RATIO))
+         left = (width - crop_width) // 2
+         image_to_resize = image.crop((left, 0, left + crop_width, height))
+     elif aspect_ratio < MIN_ASPECT_RATIO:
+         crop_height = int(round(width / MIN_ASPECT_RATIO))
+         top = (height - crop_height) // 2
+         image_to_resize = image.crop((0, top, width, top + crop_height))
+
+     if width > height:
+         target_w = MAX_DIM
+         target_h = int(round(target_w / aspect_ratio))
+     else:
+         target_h = MAX_DIM
+         target_w = int(round(target_h * aspect_ratio))
+
+     final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
+     final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF
+
+     final_w = max(MIN_DIM, min(MAX_DIM, final_w))
+     final_h = max(MIN_DIM, min(MAX_DIM, final_h))
+
+     return image_to_resize.resize((final_w, final_h), Image.LANCZOS)
+
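+ # resize_image example: a 1920x1080 input has aspect 1.78 > 832/480 = 1.73,
+ # so it is center-cropped to 1872x1080, scaled toward 832 wide, snapped to
+ # a multiple of 16, and clamped, landing at 832x480.
+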
+ def get_num_frames(duration_seconds: float):
+     return 1 + int(np.clip(int(round(duration_seconds * FIXED_FPS)), 8, 300))
+
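+ # e.g. get_num_frames(5.0) = 1 + clip(round(5.0 * 16), 8, 300) = 81 frames,
+ # matching the 81-frame chunks noted above.
+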
+ # =========================================================
+ # MAIN GENERATION FUNCTION (REWRITTEN FOR LONG VIDEO)
+ # =========================================================
+ @spaces.GPU(duration=300)  # Increased timeout for long generation
+ def generate_video(
+     input_image_path,
+     prompt,
+     steps=4,
+     negative_prompt=default_negative_prompt,
+     duration_seconds=20.0,  # Defaulting to 20s
+     guidance_scale=1,
+     guidance_scale_2=1,
+     seed=42,
+     randomize_seed=False,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     # Cleanup memory
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     try:
+         # 1. Validation checks
+         if not input_image_path:
+             raise gr.Error("Please upload an input image.")
+         if not os.path.exists(input_image_path):
+             raise gr.Error("Image file not found! Please re-upload the image.")
+
+         # 2. Setup
+         original_input_image = Image.open(input_image_path).convert("RGB")
+         current_input_image = resize_image(original_input_image)
+         current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+
+         # Determine number of iterations needed for 20s
+         # 20s / 5s chunks = 4 iterations
+         chunk_duration = 5.0
+         total_duration = float(duration_seconds)
+         iterations = int(np.ceil(total_duration / chunk_duration))
+
+         print(f"Starting Long Video Generation: {total_duration}s ({iterations} iterations)")
+
+         all_frames = []
+
+         # 3. The Autoregressive Loop
+         for i in range(iterations):
+             progress(i / iterations, desc=f"Generating Part {i+1} of {iterations}...")
+
+             # Calculate frames for this chunk
+             num_frames = get_num_frames(chunk_duration)
+
+             print(f"--- Generative Pass {i+1}: Seed {current_seed} ---")
+
+             output_frames_list = pipe(
+                 image=current_input_image,
+                 prompt=prompt,
+                 negative_prompt=negative_prompt,
+                 height=current_input_image.height,
+                 width=current_input_image.width,
+                 num_frames=num_frames,
+                 guidance_scale=float(guidance_scale),
+                 guidance_scale_2=float(guidance_scale_2),
+                 num_inference_steps=int(steps),
+                 generator=torch.Generator(device="cuda").manual_seed(current_seed),
+             ).frames[0]
+
+             # Store frames
+             # If this is not the first chunk, we drop the first frame of the new chunk
+             # because it is (theoretically) identical to the last frame of the previous chunk
+             if i > 0:
+                 all_frames.extend(output_frames_list[1:])
+             else:
+                 all_frames.extend(output_frames_list)
+
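+             # Running total: 81 frames from the first pass, 80 from each
+             # later pass, so 4 passes yield 81 + 3 * 80 = 321 frames,
+             # i.e. ~20 s at 16 fps.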
+             # Prepare for next iteration
+             # The last frame of this video becomes the input for the next video
+             # We convert the numpy frame back to PIL for the pipeline
+             last_frame = output_frames_list[-1]
+             if isinstance(last_frame, np.ndarray):
+                 if last_frame.dtype != np.uint8:
+                     # diffusers returns float frames in [0, 1]
+                     last_frame = (last_frame * 255).round().astype(np.uint8)
+                 last_frame = Image.fromarray(last_frame)
+             current_input_image = last_frame
+
+             # Optional: Slightly shift seed per chunk to prevent looping artifacts,
+             # or keep it same for consistency. Keeping same is usually safer for style.
+
+             # Cleanup per chunk
+             del output_frames_list
+             torch.cuda.empty_cache()
+
+         # 4. Save Final Long Video
+         print(f"Stitching {len(all_frames)} frames...")
+         with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
+             video_path = tmpfile.name
+
+         export_to_video(all_frames, video_path, fps=FIXED_FPS)
+
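+         # delete=False keeps the temp file on disk after the handle closes,
+         # so Gradio can serve the finished mp4 from video_path.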
+         # Final Cleanup
+         del all_frames
+         del current_input_image
+         torch.cuda.empty_cache()
+         gc.collect()
+
+         return video_path, current_seed
+
+     except Exception as e:
+         print(f"Error during generation: {e}")
+         raise gr.Error(f"Generation failed: {str(e)}")
+
+ # =========================================================
+ # GRADIO UI
+ # =========================================================
+
+ # Google Analytics Script
+ ga_script = """
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-1TD40BVM04"></script>
+ <script>
+   window.dataLayer = window.dataLayer || [];
+   function gtag(){dataLayer.push(arguments);}
+   gtag('js', new Date());
+   gtag('config', 'G-1TD40BVM04');
+ </script>
+ """
+
+ with gr.Blocks(theme=gr.themes.Soft(), head=ga_script) as demo:
+
+     # --- PROFESSIONAL YOUTUBE EMBED SECTION ---
+     gr.HTML("""
+     <div style="background: linear-gradient(135deg, #b90000 0%, #ff0000 100%); color: white; padding: 25px; border-radius: 16px; text-align: center; margin-bottom: 25px; box-shadow: 0 10px 30px rgba(185, 0, 0, 0.3);">
+         <div style="display: flex; align-items: center; justify-content: center; gap: 25px; flex-wrap: wrap; margin-bottom: 20px;">
+             <div style="display: flex; align-items: center; gap: 15px;">
+                 <div style="background: white; width: 50px; height: 50px; border-radius: 50%; display: flex; align-items: center; justify-content: center; box-shadow: 0 4px 8px rgba(0,0,0,0.2);">
+                     <span style="font-size: 24px;">▶️</span>
+                 </div>
+                 <div style="text-align: left;">
+                     <h3 style="margin: 0; font-weight: 800; font-size: 22px; letter-spacing: 0.5px;">Imagination Engineering</h3>
+                     <p style="margin: 4px 0 0 0; opacity: 0.95; font-size: 14px; font-weight: 400;">Mastering AI & Creative Tech</p>
+                 </div>
+             </div>
+             <a href="https://www.youtube.com/@ImaginationEngineering" target="_blank" style="text-decoration: none;">
+                 <button style="background-color: white; color: #cc0000; border: none; padding: 10px 28px; border-radius: 30px; font-weight: 700; cursor: pointer; transition: transform 0.2s, box-shadow 0.2s; font-size: 15px; box-shadow: 0 4px 12px rgba(0,0,0,0.2);">
+                     SUBSCRIBE & WATCH 📺
+                 </button>
+             </a>
+         </div>
+     </div>
+     """)

      with gr.Row():
+         with gr.Column(scale=1):
+             image_input = gr.Image(type="filepath", label="Input Image", elem_id="input_image")
+             prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v, lines=3)
+             negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
+
+             with gr.Accordion("Advanced Settings", open=False):
+                 # Set default to 20 seconds
+                 duration_slider = gr.Slider(minimum=5.0, maximum=20.0, step=5.0, value=20.0, label="Duration (Seconds)")
+                 steps_slider = gr.Slider(minimum=2, maximum=50, step=1, value=4, label="Inference Steps (per chunk)")
+                 cfg_slider = gr.Slider(minimum=1.0, maximum=10.0, step=0.1, value=1.0, label="Guidance Scale")
+                 cfg_slider_2 = gr.Slider(minimum=1.0, maximum=10.0, step=0.1, value=1.0, label="Guidance Scale 2")
+                 seed_input = gr.Number(label="Seed", value=42, precision=0)
+                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+
+             generate_button = gr.Button("GENERATE LONG VIDEO (20s)", variant="primary", size="lg")
+
+         with gr.Column(scale=1):
+             video_output = gr.Video(label="Generated Video")
+
+     ui_inputs = [
+         image_input, prompt_input, steps_slider, negative_prompt_input,
+         duration_slider, cfg_slider, cfg_slider_2, seed_input, randomize_seed
+     ]
+
+     generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])
+
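+     # The click handler returns (video_path, current_seed), so the seed that
+     # was actually used is written back into the Seed box.
+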
+     # --- BOTTOM ADVERTISEMENT BANNER ---
+     gr.HTML("""
+     <div style="background: linear-gradient(90deg, #4f46e5, #9333ea); color: white; padding: 15px; border-radius: 10px; text-align: center; margin-top: 20px; box-shadow: 0 4px 15px rgba(0,0,0,0.1);">
+         <div style="display: flex; align-items: center; justify-content: center; gap: 20px; flex-wrap: wrap;">
+             <div style="text-align: left;">
+                 <h3 style="margin: 0; font-weight: bold; font-size: 18px;">✨ New: Dream Hub Pro (All-in-One)</h3>
+                 <p style="margin: 5px 0 0 0; opacity: 0.9; font-size: 14px;">Access all your pro tools (Wan2.1, Qwen, Audio, Video Enhance) in one place!</p>
+             </div>
+             <a href="https://huggingface.co/spaces/dream2589632147/Dream-Hub-Pro" target="_blank" style="text-decoration: none;">
+                 <button style="background-color: white; color: #4f46e5; border: none; padding: 10px 25px; border-radius: 25px; font-weight: bold; cursor: pointer; transition: all 0.2s; font-size: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.2);">
+                     🚀 Open Hub Pro Now
+                 </button>
+             </a>
+         </div>
+     </div>
+     """)

  if __name__ == "__main__":
+     demo.queue().launch()