Revrse committed on
Commit 6c356c6 · verified · 1 parent: a688b6e

Create app.py

Files changed (1)
  app.py +316 -0
app.py ADDED
@@ -0,0 +1,316 @@
import os
# PyTorch 2.8 (temporary hack)
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')

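# NOTE: installing dependencies at import time like this is fragile; it assumes the runtime
# has network access and that a cu126 nightly wheel matching "torch<2.9" still exists.
# Pinning these in requirements.txt would be the more conventional route.
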
# --- 1. Model Download and Setup (Diffusers Backend) ---
import spaces
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.utils.export_utils import export_to_video
import gradio as gr
import tempfile
import numpy as np
from PIL import Image
import random
import gc
from gradio_client import Client, handle_file  # Import for API call

# Import the optimization function from the separate file
from optimization import optimize_pipeline_

# --- Constants and Model Loading ---
MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"

# --- Flexible Dimension Constants ---
MAX_DIMENSION = 832
MIN_DIMENSION = 480
DIMENSION_MULTIPLE = 16
SQUARE_SIZE = 480

MAX_SEED = np.iinfo(np.int32).max

FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81

MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
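# With FIXED_FPS = 16 these bounds work out to 0.5 s (8 frames) and 5.1 s (81 frames).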

# Default negative prompt (kept in Chinese, as Wan models are commonly prompted this way).
# Rough English gloss: "garish colors, overexposed, static, blurry details, subtitles, style,
# artwork, painting, frame, motionless, overall gray, worst quality, low quality, JPEG
# compression artifacts, ugly, mutilated, extra fingers, poorly drawn hands, poorly drawn face,
# deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background,
# three legs, many people in the background, walking backwards, overexposed"
default_negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,"

print("Loading models into memory. This may take a few minutes...")

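# Wan2.2 A14B is a two-expert design: `transformer` denoises the high-noise steps and
# `transformer_2` the low-noise steps (hence the separate high-noise / low-noise guidance
# sliders in the UI below). Both experts are loaded in bf16 from a community conversion
# of the official weights.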
pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID,
    transformer=WanTransformer3DModel.from_pretrained(
        'cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
        subfolder='transformer',
        torch_dtype=torch.bfloat16,
        device_map='cuda',
    ),
    transformer_2=WanTransformer3DModel.from_pretrained(
        'cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
        subfolder='transformer_2',
        torch_dtype=torch.bfloat16,
        device_map='cuda',
    ),
    torch_dtype=torch.bfloat16,
)
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=8.0)
pipe.to('cuda')
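# shift=8.0 biases the flow-matching schedule toward high-noise timesteps; presumably
# chosen to suit the few-step inference used here (8 steps by default). Tune alongside `steps`.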

print("Optimizing pipeline...")
for i in range(3):
    gc.collect()
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

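# Warm-up pass: optimize_pipeline_ (from the local optimization.py) runs the pipeline once
# at the largest landscape shape (832x480) and the maximum frame count, presumably so that
# later calls with the same shapes hit an already-compiled/optimized path.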
optimize_pipeline_(
    pipe,
    image=Image.new('RGB', (MAX_DIMENSION, MIN_DIMENSION)),
    prompt='prompt',
    height=MIN_DIMENSION,
    width=MAX_DIMENSION,
    num_frames=MAX_FRAMES_MODEL,
)
print("All models loaded and optimized. Gradio app is ready.")


# --- 2. Image Processing and Application Logic ---
def generate_end_frame(start_img, gen_prompt, progress=gr.Progress(track_tqdm=True)):
    """Calls an external Gradio API to generate an image."""
    if start_img is None:
        raise gr.Error("Please provide a Start Frame first.")

    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise gr.Error("HF_TOKEN not found in environment variables. Please set it in your Space secrets.")

    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
        start_img.save(tmpfile.name)
        tmp_path = tmpfile.name

    progress(0.1, desc="Connecting to image generation API...")
    client = Client("multimodalart/nano-banana-private")

    progress(0.5, desc=f"Generating with prompt: '{gen_prompt}'...")
    try:
        result = client.predict(
            prompt=gen_prompt,
            images=[
                {"image": handle_file(tmp_path)}
            ],
            manual_token=hf_token,
            api_name="/unified_image_generator"
        )
    finally:
        os.remove(tmp_path)

    progress(1.0, desc="Done!")
    print(result)
    return result

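# NOTE: generate_end_frame depends on a private Space ("multimodalart/nano-banana-private"),
# so HF_TOKEN must belong to an account with access to it. The raw `result` (presumably an
# image file path returned by that API) is passed straight through to the End Frame input.
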
def switch_to_upload_tab():
    """Returns a gr.Tabs update selecting the first tab.

    Leftover from an earlier tabbed layout: the compact UI below has no gr.Tabs,
    so this helper is no longer wired to any event.
    """
    return gr.Tabs(selected="upload_tab")


def process_image_for_video(image: Image.Image) -> Image.Image:
    """
    Resizes an image based on the following rules for video generation:
    1. The longest side will be scaled down to MAX_DIMENSION if it's larger.
    2. The shortest side will be scaled up to MIN_DIMENSION if it's smaller.
    3. The final dimensions will be rounded to the nearest multiple of DIMENSION_MULTIPLE.
    4. Square images are resized to a fixed SQUARE_SIZE.
    The aspect ratio is preserved as closely as possible.
    """
    width, height = image.size

    # Rule 4: Handle square images
    if width == height:
        return image.resize((SQUARE_SIZE, SQUARE_SIZE), Image.Resampling.LANCZOS)

    # Determine target dimensions while preserving aspect ratio
    aspect_ratio = width / height
    new_width, new_height = width, height

    # Rule 1: Scale down if too large
    if new_width > MAX_DIMENSION or new_height > MAX_DIMENSION:
        if aspect_ratio > 1:  # Landscape
            scale = MAX_DIMENSION / new_width
        else:  # Portrait
            scale = MAX_DIMENSION / new_height
        new_width *= scale
        new_height *= scale

    # Rule 2: Scale up if too small
    if new_width < MIN_DIMENSION or new_height < MIN_DIMENSION:
        if aspect_ratio > 1:  # Landscape
            scale = MIN_DIMENSION / new_height
        else:  # Portrait
            scale = MIN_DIMENSION / new_width
        new_width *= scale
        new_height *= scale

    # Rule 3: Round to the nearest multiple of DIMENSION_MULTIPLE
    final_width = int(round(new_width / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)
    final_height = int(round(new_height / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)

    # Ensure final dimensions are at least the minimum
    final_width = max(final_width, MIN_DIMENSION if aspect_ratio < 1 else SQUARE_SIZE)
    final_height = max(final_height, MIN_DIMENSION if aspect_ratio > 1 else SQUARE_SIZE)

    return image.resize((final_width, final_height), Image.Resampling.LANCZOS)
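
# Worked example: a 1920x1080 input is scaled down to 832x468 (rule 1), bumped to
# ~853x480 to meet the 480 minimum (rule 2), and rounded to 848x480 (rule 3). Note that
# rule 2 can push the long side slightly past MAX_DIMENSION; the rounding does not re-clamp it.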

def resize_and_crop_to_match(target_image, reference_image):
    """Resizes and center-crops the target image to match the reference image's dimensions."""
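    # "Cover" fit: max(...) picks the scale that makes the resized target at least as large
    # as the reference in both dimensions; the center crop then trims the overflow evenly.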
    ref_width, ref_height = reference_image.size
    target_width, target_height = target_image.size
    scale = max(ref_width / target_width, ref_height / target_height)
    new_width, new_height = int(target_width * scale), int(target_height * scale)
    resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    left, top = (new_width - ref_width) // 2, (new_height - ref_height) // 2
    return resized.crop((left, top, left + ref_width, top + ref_height))

@spaces.GPU(duration=120)
def generate_video(
    start_image_pil,
    end_image_pil,
    prompt,
    negative_prompt=default_negative_prompt,
    duration_seconds=2.1,
    steps=8,
    guidance_scale=1,
    guidance_scale_2=1,
    seed=42,
    randomize_seed=False,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Generates a video by interpolating between a start and end image, guided by a text prompt,
    using the diffusers Wan2.2 pipeline.
    """
    if start_image_pil is None or end_image_pil is None:
        raise gr.Error("Please upload both a start and an end image.")

    progress(0.1, desc="Preprocessing images...")

    # Step 1: Process the start image to get our target dimensions based on the sizing rules.
    processed_start_image = process_image_for_video(start_image_pil)

    # Step 2: Make the end image match the *exact* dimensions of the processed start image.
    processed_end_image = resize_and_crop_to_match(end_image_pil, processed_start_image)

    target_height, target_width = processed_start_image.height, processed_start_image.width

    # Handle seed and frame count
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)

    progress(0.2, desc=f"Generating {num_frames} frames at {target_width}x{target_height} (seed: {current_seed})...")
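    # guidance_scale applies to the high-noise expert and guidance_scale_2 to the low-noise
    # expert, matching the two sliders in the UI; at the default of 1.0, classifier-free
    # guidance is effectively disabled, which suits the fast 8-step default.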
    output_frames_list = pipe(
        image=processed_start_image,
        last_image=processed_end_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=target_height,
        width=target_width,
        num_frames=num_frames,
        guidance_scale=float(guidance_scale),
        guidance_scale_2=float(guidance_scale_2),
        num_inference_steps=int(steps),
        generator=torch.Generator(device="cuda").manual_seed(current_seed),
    ).frames[0]

    progress(0.9, desc="Encoding and saving video...")

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name

    export_to_video(output_frames_list, video_path, fps=FIXED_FPS)

    progress(1.0, desc="Done!")
    return video_path, current_seed


# --- 3. Simplified Gradio User Interface (Examples removed) ---

css = '''
.fillable{max-width: 1100px !important}
.section-title {font-size: 20px; margin-bottom: 8px;}
.kv {margin-bottom: 8px;}
.controls {gap: 8px;}
'''
with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
    # Header
    with gr.Row():
        gr.Markdown("### Wan 2.2 — First & Last Frame Video (Diffusers)")
        gr.Markdown("Compact UI — examples removed.")

    # Main layout: left inputs, right preview
    with gr.Row():
        with gr.Column(scale=6):
            gr.Markdown("<div class='section-title'>Inputs</div>", elem_id="inputs_title")
            # `sources=[...]` is the Gradio 4+ spelling; the old `source="upload"` kwarg would
            # raise a TypeError on the recent Gradio required by the Citrus theme.
            start_image = gr.Image(type="pil", label="Start Frame", sources=["upload"], elem_classes=["kv"])
            end_image = gr.Image(type="pil", label="End Frame", sources=["upload"], elem_classes=["kv"])

            prompt = gr.Textbox(label="Prompt", placeholder="Describe the transition between frames", lines=2, elem_classes=["kv"])

            # Quick generate button that creates an end frame 5s after the start
            with gr.Row():
                generate_5seconds = gr.Button("Generate End Frame (5s later)", elem_classes=["kv"])
                generate_button = gr.Button("Generate Video", variant="primary", elem_classes=["kv"])

            # Advanced settings collapsed in an accordion to keep the UI lean
            with gr.Accordion("Advanced Settings (click to open)", open=False):
                duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=2.1, label="Video Duration (s)")
                negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=8, label="Inference Steps")
                guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - high noise")
                guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - low noise")
                with gr.Row():
                    seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                    randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)

        with gr.Column(scale=6):
            gr.Markdown("<div class='section-title'>Output</div>", elem_id="output_title")
            output_video = gr.Video(label="Generated Video", autoplay=True)
            seed_display = gr.Textbox(label="Seed Used (for reproducibility)", interactive=False)

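    # seed_display is populated via a follow-up .then() on the generate click below.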
    # Hook up events
    ui_inputs = [
        start_image,
        end_image,
        prompt,
        negative_prompt_input,
        duration_seconds_input,
        steps_slider,
        guidance_scale_input,
        guidance_scale_2_input,
        seed_input,
        randomize_seed_checkbox
    ]
    ui_outputs = [output_video, seed_input]

    generate_button.click(
        fn=generate_video,
        inputs=ui_inputs,
        outputs=ui_outputs
    ).then(
        # generate_video writes the seed it actually used back into seed_input; mirror it
        # into the read-only display box, which was otherwise never populated.
        fn=lambda s: str(s),
        inputs=[seed_input],
        outputs=[seed_display]
    )

    # Generate an end frame from the start image and drop it into the End Frame input.
    # (An earlier tabbed layout first switched tabs via switch_to_upload_tab; this compact
    # UI has no gr.Tabs, so that step is omitted here.)
    generate_5seconds.click(
        fn=lambda img: generate_end_frame(img, "this image is a still frame from a movie. generate a new frame with what happens on this scene 5 seconds in the future"),
        inputs=[start_image],
        outputs=[end_image]
    )

if __name__ == "__main__":
    app.launch(share=True)
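
# Note: on Hugging Face Spaces the share=True flag is ignored (Gradio warns and serves the
# app normally); it only matters when this file is run locally.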