EdBanshee committed on
Commit
5523755
·
1 Parent(s): cce6ca9

Rest of code

app.py ADDED
@@ -0,0 +1,335 @@
+ import os
+ # PyTorch 2.8 (temporary hack)
+ os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
+
+ # --- 1. Model Download and Setup (Diffusers Backend) ---
+ import spaces
+ import torch
+ from diffusers import FlowMatchEulerDiscreteScheduler
+ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
+ from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
+ from diffusers.utils.export_utils import export_to_video
+ import gradio as gr
+ import tempfile
+ import numpy as np
+ from PIL import Image
+ import random
+ import gc
+ from gradio_client import Client, handle_file  # Import for API call
+
+ # Import the optimization function from the separate file
+ from optimization import optimize_pipeline_
+
+ # --- Constants and Model Loading ---
+ MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
+
+ # --- Flexible Dimension Constants ---
+ MAX_DIMENSION = 832
+ MIN_DIMENSION = 480
+ DIMENSION_MULTIPLE = 16
+ SQUARE_SIZE = 480
+
+ MAX_SEED = np.iinfo(np.int32).max
+
+ FIXED_FPS = 16
+ MIN_FRAMES_MODEL = 8
+ MAX_FRAMES_MODEL = 81
+
+ MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS, 1)
+ MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS, 1)
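+ # With FIXED_FPS = 16, the slider range works out to round(8/16, 1) = 0.5 s and
+ # round(81/16, 1) = 5.1 s; generate_video() later maps a duration d back to frames
+ # via round(d * 16), clipped to the model's [8, 81] frame range.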
+
+ # Standard Chinese negative prompt for Wan models; roughly: "garish colors, overexposed,
+ # static, blurry details, subtitles, style, artwork, painting, still image, overall gray,
+ # worst quality, low quality, JPEG compression artifacts, ugly, mutilated, extra fingers,
+ # poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused
+ # fingers, motionless frame, cluttered background, three legs, many people in the
+ # background, walking backwards, overexposed"
+ default_negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,"
+
+ print("Loading models into memory. This may take a few minutes...")
+
+ pipe = WanImageToVideoPipeline.from_pretrained(
+     MODEL_ID,
+     transformer=WanTransformer3DModel.from_pretrained(
+         'cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
+         subfolder='transformer',
+         torch_dtype=torch.bfloat16,
+         device_map='cuda',
+     ),
+     transformer_2=WanTransformer3DModel.from_pretrained(
+         'cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
+         subfolder='transformer_2',
+         torch_dtype=torch.bfloat16,
+         device_map='cuda',
+     ),
+     torch_dtype=torch.bfloat16,
+ )
+ pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=8.0)
+ pipe.to('cuda')
+
+ print("Optimizing pipeline...")
+ for i in range(3):
+     gc.collect()
+     torch.cuda.synchronize()
+     torch.cuda.empty_cache()
+
+ optimize_pipeline_(pipe,
+     image=Image.new('RGB', (MAX_DIMENSION, MIN_DIMENSION)),
+     prompt='prompt',
+     height=MIN_DIMENSION,
+     width=MAX_DIMENSION,
+     num_frames=MAX_FRAMES_MODEL,
+ )
+ print("All models loaded and optimized. Gradio app is ready.")
+
+
+ # --- 2. Image Processing and Application Logic ---
+ def generate_end_frame(start_img, gen_prompt, progress=gr.Progress(track_tqdm=True)):
+     """Calls an external Gradio API to generate an image."""
+     if start_img is None:
+         raise gr.Error("Please provide a Start Frame first.")
+
+     hf_token = os.getenv("HF_TOKEN")
+     if not hf_token:
+         raise gr.Error("HF_TOKEN not found in environment variables. Please set it in your Space secrets.")
+
+     with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
+         start_img.save(tmpfile.name)
+         tmp_path = tmpfile.name
+
+     progress(0.1, desc="Connecting to image generation API...")
+     client = Client("multimodalart/nano-banana")
+
+     progress(0.5, desc=f"Generating with prompt: '{gen_prompt}'...")
+     try:
+         result = client.predict(
+             prompt=gen_prompt,
+             images=[
+                 {"image": handle_file(tmp_path)}
+             ],
+             manual_token=hf_token,
+             api_name="/unified_image_generator"
+         )
+     finally:
+         os.remove(tmp_path)
+
+     progress(1.0, desc="Done!")
+     print(result)
+     return result
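+ # Hypothetical standalone usage (names for illustration only):
+ #     end_frame = generate_end_frame(Image.open("poli_tower.png"),
+ #                                    "the same scene 5 seconds later")
+ # The return value is whatever the remote /unified_image_generator endpoint
+ # produces, which Gradio then loads into the End Frame image component.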
+
+ def switch_to_upload_tab():
+     """Returns a gr.Tabs update to switch to the first tab."""
+     return gr.Tabs(selected="upload_tab")
+
+
+ def process_image_for_video(image: Image.Image) -> Image.Image:
+     """
+     Resizes an image based on the following rules for video generation:
+     1. The longest side is scaled down to MAX_DIMENSION if it's larger.
+     2. The shortest side is scaled up to MIN_DIMENSION if it's smaller.
+     3. The final dimensions are rounded to the nearest multiple of DIMENSION_MULTIPLE.
+     4. Square images are resized to a fixed SQUARE_SIZE.
+     The aspect ratio is preserved as closely as possible.
+     """
+     width, height = image.size
+
+     # Rule 4: Handle square images
+     if width == height:
+         return image.resize((SQUARE_SIZE, SQUARE_SIZE), Image.Resampling.LANCZOS)
+
+     # Determine target dimensions while preserving aspect ratio
+     aspect_ratio = width / height
+     new_width, new_height = width, height
+
+     # Rule 1: Scale down if too large
+     if new_width > MAX_DIMENSION or new_height > MAX_DIMENSION:
+         if aspect_ratio > 1:  # Landscape
+             scale = MAX_DIMENSION / new_width
+         else:  # Portrait
+             scale = MAX_DIMENSION / new_height
+         new_width *= scale
+         new_height *= scale
+
+     # Rule 2: Scale up if too small
+     if new_width < MIN_DIMENSION or new_height < MIN_DIMENSION:
+         if aspect_ratio > 1:  # Landscape
+             scale = MIN_DIMENSION / new_height
+         else:  # Portrait
+             scale = MIN_DIMENSION / new_width
+         new_width *= scale
+         new_height *= scale
+
+     # Rule 3: Round to the nearest multiple of DIMENSION_MULTIPLE
+     final_width = int(round(new_width / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)
+     final_height = int(round(new_height / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)
+
+     # Floor both sides at 480 (MIN_DIMENSION and SQUARE_SIZE are both 480, so the
+     # conditional below only matters if those constants ever diverge)
+     final_width = max(final_width, MIN_DIMENSION if aspect_ratio < 1 else SQUARE_SIZE)
+     final_height = max(final_height, MIN_DIMENSION if aspect_ratio > 1 else SQUARE_SIZE)
+
+     return image.resize((final_width, final_height), Image.Resampling.LANCZOS)
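+ # Worked example (not executed): a 1920x1080 input is scaled down to 832x468
+ # (Rule 1), scaled back up to about 853x480 to satisfy MIN_DIMENSION (Rule 2),
+ # and rounded to 848x480 (Rule 3). Note that Rule 2 plus rounding can leave the
+ # longest side slightly above MAX_DIMENSION.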
+
+ def resize_and_crop_to_match(target_image, reference_image):
+     """Resizes and center-crops the target image to match the reference image's dimensions."""
+     ref_width, ref_height = reference_image.size
+     target_width, target_height = target_image.size
+     scale = max(ref_width / target_width, ref_height / target_height)
+     new_width, new_height = int(target_width * scale), int(target_height * scale)
+     resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+     left, top = (new_width - ref_width) // 2, (new_height - ref_height) // 2
+     return resized.crop((left, top, left + ref_width, top + ref_height))
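+ # Worked example (not executed): matching a 1024x1024 end frame to an 832x480
+ # reference gives scale = max(832/1024, 480/1024) = 0.8125, an 832x832 resize,
+ # and a center crop of rows 176..656, i.e. a final 832x480 image.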
+
+ @spaces.GPU(duration=120)
+ def generate_video(
+     start_image_pil,
+     end_image_pil,
+     prompt,
+     negative_prompt=default_negative_prompt,
+     duration_seconds=2.1,
+     steps=8,
+     guidance_scale=1,
+     guidance_scale_2=1,
+     seed=42,
+     randomize_seed=False,
+     progress=gr.Progress(track_tqdm=True)
+ ):
+     """
+     Generates a video by interpolating between a start and end image, guided by a text prompt,
+     using the diffusers Wan2.2 pipeline.
+     """
+     if start_image_pil is None or end_image_pil is None:
+         raise gr.Error("Please upload both a start and an end image.")
+
+     progress(0.1, desc="Preprocessing images...")
+
+     # Step 1: Process the start image to get our target dimensions based on the rules above.
+     processed_start_image = process_image_for_video(start_image_pil)
+
+     # Step 2: Make the end image match the *exact* dimensions of the processed start image.
+     processed_end_image = resize_and_crop_to_match(end_image_pil, processed_start_image)
+
+     target_height, target_width = processed_start_image.height, processed_start_image.width
+
+     # Handle seed and frame count
+     current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+     num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
+
+     progress(0.2, desc=f"Generating {num_frames} frames at {target_width}x{target_height} (seed: {current_seed})...")
+
+     output_frames_list = pipe(
+         image=processed_start_image,
+         last_image=processed_end_image,
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         height=target_height,
+         width=target_width,
+         num_frames=num_frames,
+         guidance_scale=float(guidance_scale),
+         guidance_scale_2=float(guidance_scale_2),
+         num_inference_steps=int(steps),
+         generator=torch.Generator(device="cuda").manual_seed(current_seed),
+     ).frames[0]
+
+     progress(0.9, desc="Encoding and saving video...")
+
+     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
+         video_path = tmpfile.name
+
+     export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
+
+     progress(1.0, desc="Done!")
+     return video_path, current_seed
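+ # generate_video returns the seed it actually used; the UI below wires that
+ # second output back into the seed slider, so randomized runs stay reproducible.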
+
+
+ # --- 3. Gradio User Interface ---
+
+ css = '''
+ .fillable{max-width: 1100px !important}
+ .dark .progress-text {color: white}
+ #general_items{margin-top: 2em}
+ #group_all{overflow:visible}
+ #group_all .styler{overflow:visible}
+ #group_tabs .tabitem{padding: 0}
+ .tab-wrapper{margin-top: -33px;z-index: 999;position: absolute;width: 100%;background-color: var(--block-background-fill);padding: 0;}
+ #component-9-button{width: 50%;justify-content: center}
+ #component-11-button{width: 50%;justify-content: center}
+ #or_item{text-align: center; padding-top: 1em; padding-bottom: 1em; font-size: 1.1em;margin-left: .5em;margin-right: .5em;width: calc(100% - 1em)}
+ #fivesec{margin-top: 5em;margin-left: .5em;margin-right: .5em;width: calc(100% - 1em)}
+ '''
+ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
+     gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
+     gr.Markdown("Based on the [Wan 2.2 First/Last Frame workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/), applied to 🧨 Diffusers + [lightx2v/Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) 8-step LoRA")
+
+     with gr.Row(elem_id="general_items"):
+         with gr.Column():
+             with gr.Group(elem_id="group_all"):
+                 with gr.Row():
+                     start_image = gr.Image(type="pil", label="Start Frame", sources=["upload", "clipboard"])
+                     # Capture the Tabs component in a variable and assign IDs to tabs
+                     with gr.Tabs(elem_id="group_tabs") as tabs:
+                         with gr.TabItem("Upload", id="upload_tab"):
+                             end_image = gr.Image(type="pil", label="End Frame", sources=["upload", "clipboard"])
+                         with gr.TabItem("Generate", id="generate_tab"):
+                             generate_5seconds = gr.Button("Generate scene 5 seconds in the future", elem_id="fivesec")
+                             gr.Markdown("Generate a custom end-frame with an edit model like [Nano Banana](https://huggingface.co/spaces/multimodalart/nano-banana) or [Qwen Image Edit](https://huggingface.co/spaces/multimodalart/Qwen-Image-Edit-Fast)", elem_id="or_item")
+             prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images")
+
+             with gr.Accordion("Advanced Settings", open=False):
+                 duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=2.1, label="Video Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
+                 negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
+                 steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=8, label="Inference Steps")
+                 guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - high noise")
+                 guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - low noise")
+                 with gr.Row():
+                     seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
+                     randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
+
+             generate_button = gr.Button("Generate Video", variant="primary")
+
+         with gr.Column():
+             output_video = gr.Video(label="Generated Video", autoplay=True)
+
+     # Main video generation button
+     ui_inputs = [
+         start_image,
+         end_image,
+         prompt,
+         negative_prompt_input,
+         duration_seconds_input,
+         steps_slider,
+         guidance_scale_input,
+         guidance_scale_2_input,
+         seed_input,
+         randomize_seed_checkbox
+     ]
+     ui_outputs = [output_video, seed_input]
+
+     generate_button.click(
+         fn=generate_video,
+         inputs=ui_inputs,
+         outputs=ui_outputs
+     )
+
+     generate_5seconds.click(
+         fn=switch_to_upload_tab,
+         inputs=None,
+         outputs=[tabs]
+     ).then(
+         fn=lambda img: generate_end_frame(img, "this image is a still frame from a movie. generate a new frame with what happens on this scene 5 seconds in the future"),
+         inputs=[start_image],
+         outputs=[end_image]
+     ).success(
+         fn=generate_video,
+         inputs=ui_inputs,
+         outputs=ui_outputs
+     )
+
+     gr.Examples(
+         examples=[
+             ["poli_tower.png", "tower_takes_off.png", "the man turns around"],
+             ["ugly_sonic.jpg", "squatting_sonic.png", "the character dodges the missiles"],
+             ["capyabara_zoomed.png", "capyabara.webp", "a dramatic dolly zoom"],
+         ],
+         inputs=[start_image, end_image, prompt],
+         outputs=ui_outputs,
+         fn=generate_video,
+         cache_examples="lazy",
+     )
+
+ if __name__ == "__main__":
+     app.launch(share=True)
capyabara.webp ADDED

Git LFS Details

  • SHA256: 26f8ee938a1f453a81e85c2035e3787b1e5ddbb9a92acb01688b39abd987c1e8
  • Pointer size: 131 Bytes
  • Size of remote file: 467 kB
capyabara_zoomed.png ADDED

Git LFS Details

  • SHA256: 37c27e972f09ab9b1c7df8aaa4b7c2cdbb702466e5bb0fecf5cb502ee531a26c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.58 MB
optimization.py ADDED
@@ -0,0 +1,133 @@
+ """
+ Ahead-of-time compilation helpers for the Wan 2.2 image-to-video pipeline:
+ fuses the Lightning LoRAs, quantizes the transformers, and AOT-compiles them
+ with dynamic shapes.
+ """
+
+ from typing import Any
+ from typing import Callable
+ from typing import ParamSpec
+
+ import spaces
+ import torch
+ from torch.utils._pytree import tree_map_only
+ from torchao.quantization import quantize_
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
+ from torchao.quantization import Int8WeightOnlyConfig
+
+ from optimization_utils import capture_component_call
+ from optimization_utils import aoti_compile
+ from optimization_utils import drain_module_parameters
+
+
+ P = ParamSpec('P')
+
+ # --- CORRECTED DYNAMIC SHAPING ---
+
+ # VAE temporal scale factor is 1, latent_frames = num_frames. Range is [8, 81].
+ LATENT_FRAMES_DIM = torch.export.Dim('num_latent_frames', min=8, max=81)
+
+ # The transformer has a patch_size of (1, 2, 2), which means the input latent height and width
+ # are effectively divided by 2. This creates constraints that fail if the symbolic tracer
+ # assumes odd numbers are possible.
+ #
+ # To solve this, we define the dynamic dimension for the *patched* (i.e., post-division) size,
+ # and then express the input shape as 2 * this dimension. This mathematically guarantees
+ # to the compiler that the input latent dimensions are always even, satisfying the constraints.
+
+ # App range for pixel dimensions: [480, 832]. VAE scale factor is 8.
+ # Latent dimension range: [480/8, 832/8] = [60, 104].
+ # Patched latent dimension range: [60/2, 104/2] = [30, 52].
+ LATENT_PATCHED_HEIGHT_DIM = torch.export.Dim('latent_patched_height', min=30, max=52)
+ LATENT_PATCHED_WIDTH_DIM = torch.export.Dim('latent_patched_width', min=30, max=52)
+
+ # Now, we define the dynamic shapes for the transformer's `hidden_states` input,
+ # which has the shape (batch_size, channels, num_frames, height, width).
+ TRANSFORMER_DYNAMIC_SHAPES = {
+     'hidden_states': {
+         2: LATENT_FRAMES_DIM,
+         3: 2 * LATENT_PATCHED_HEIGHT_DIM,  # Guarantees even height
+         4: 2 * LATENT_PATCHED_WIDTH_DIM,   # Guarantees even width
+     },
+ }
+
+ # --- END OF CORRECTION ---
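+ # Example of the arithmetic (illustrative): a 480x832-pixel video yields a 60x104
+ # latent, i.e. patched dims of 30x52, so dims 3 and 4 of `hidden_states` are
+ # expressed as 2*30 and 2*52, which the exporter can prove are even.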
+
+
+ INDUCTOR_CONFIGS = {
+     'conv_1x1_as_mm': True,
+     'epilogue_fusion': False,
+     'coordinate_descent_tuning': True,
+     'coordinate_descent_check_all_directions': True,
+     'max_autotune': True,
+     'triton.cudagraphs': True,
+ }
+
+
+ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
+
+     @spaces.GPU(duration=1500)
+     def compile_transformer():
+
+         # Fuse the Lightning LoRA into both transformers, then unload the adapters
+         pipeline.load_lora_weights(
+             "Kijai/WanVideo_comfy",
+             weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+             adapter_name="lightx2v"
+         )
+         pipeline.load_lora_weights(
+             "Kijai/WanVideo_comfy",
+             weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+             adapter_name="lightx2v_2",
+             load_into_transformer_2=True,
+         )
+         pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1., 1.])
+         pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3., components=["transformer"])
+         pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1., components=["transformer_2"])
+         pipeline.unload_lora_weights()
+
+         # Capture a single call to get the args/kwargs structure
+         with capture_component_call(pipeline, 'transformer') as call:
+             pipeline(*args, **kwargs)
+
+         dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
+         dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
+
+         # Quantize both transformers to float8 dynamic-activation / float8-weight
+         quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
+         quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
+
+         # Export and AOT-compile each transformer with the shared dynamic shapes
+         exported_1 = torch.export.export(
+             mod=pipeline.transformer,
+             args=call.args,
+             kwargs=call.kwargs,
+             dynamic_shapes=dynamic_shapes,
+         )
+
+         exported_2 = torch.export.export(
+             mod=pipeline.transformer_2,
+             args=call.args,
+             kwargs=call.kwargs,
+             dynamic_shapes=dynamic_shapes,
+         )
+
+         compiled_1 = aoti_compile(exported_1, INDUCTOR_CONFIGS)
+         compiled_2 = aoti_compile(exported_2, INDUCTOR_CONFIGS)
+
+         # Return the two compiled models
+         return compiled_1, compiled_2
+
+     # Quantize the text encoder to int8 weight-only
+     quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
+
+     # Get the two dynamically-shaped compiled models
+     compiled_transformer_1, compiled_transformer_2 = compile_transformer()
+
+     pipeline.transformer.forward = compiled_transformer_1
+     drain_module_parameters(pipeline.transformer)
+
+     pipeline.transformer_2.forward = compiled_transformer_2
+     drain_module_parameters(pipeline.transformer_2)
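+ # After compilation, each eager module's `forward` is replaced by the
+ # AOTI-compiled callable, and `drain_module_parameters` swaps every parameter
+ # for an empty placeholder: the compiled archives carry their own weights (see
+ # ZeroGPUWeights in optimization_utils.py), so keeping the originals around
+ # would only waste memory.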
optimization_utils.py ADDED
@@ -0,0 +1,107 @@
+ """
+ Utilities for AOT-Inductor compilation on ZeroGPU: packaging compiled archives,
+ keeping weights picklable across processes, and capturing component calls.
+ """
+ import contextlib
+ from contextvars import ContextVar
+ from io import BytesIO
+ from typing import Any
+ from typing import cast
+ from unittest.mock import patch
+
+ import torch
+ from torch._inductor.package.package import package_aoti
+ from torch.export.pt2_archive._package import AOTICompiledModel
+ from torch.export.pt2_archive._package_weights import Weights
+
+
+ INDUCTOR_CONFIGS_OVERRIDES = {
+     'aot_inductor.package_constants_in_so': False,
+     'aot_inductor.package_constants_on_disk': True,
+     'aot_inductor.package': True,
+ }
+
+
+ class ZeroGPUWeights:
+     def __init__(self, constants_map: dict[str, torch.Tensor], to_cuda: bool = False):
+         if to_cuda:
+             self.constants_map = {name: tensor.to('cuda') for name, tensor in constants_map.items()}
+         else:
+             self.constants_map = constants_map
+
+     def __reduce__(self):
+         constants_map: dict[str, torch.Tensor] = {}
+         for name, tensor in self.constants_map.items():
+             tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
+             constants_map[name] = tensor_.copy_(tensor).detach().share_memory_()
+         return ZeroGPUWeights, (constants_map, True)
+
+
+ class ZeroGPUCompiledModel:
+     def __init__(self, archive_file: torch.types.FileLike, weights: ZeroGPUWeights):
+         self.archive_file = archive_file
+         self.weights = weights
+         self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)
+
+     def __call__(self, *args, **kwargs):
+         if (compiled_model := self.compiled_model.get()) is None:
+             compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
+             compiled_model.load_constants(self.weights.constants_map, check_full_update=True, user_managed=True)
+             self.compiled_model.set(compiled_model)
+         return compiled_model(*args, **kwargs)
+
+     def __reduce__(self):
+         return ZeroGPUCompiledModel, (self.archive_file, self.weights)
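+ # Both classes define __reduce__ so they can be pickled into a fresh ZeroGPU
+ # worker process: the weights are staged as shared, pinned CPU tensors and moved
+ # to CUDA on unpickling, while the AOTI archive itself is loaded lazily (and
+ # cached in a ContextVar) on the first call.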
+
+
+ def aoti_compile(
+     exported_program: torch.export.ExportedProgram,
+     inductor_configs: dict[str, Any] | None = None,
+ ):
+     inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
+     gm = cast(torch.fx.GraphModule, exported_program.module())
+     assert exported_program.example_inputs is not None
+     args, kwargs = exported_program.example_inputs
+     artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
+     archive_file = BytesIO()
+     files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
+     package_aoti(archive_file, files)
+     weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
+     zerogpu_weights = ZeroGPUWeights({name: weights.get_weight(name)[0] for name in weights})
+     return ZeroGPUCompiledModel(archive_file, zerogpu_weights)
+
+
+ @contextlib.contextmanager
+ def capture_component_call(
+     pipeline: Any,
+     component_name: str,
+     component_method='forward',
+ ):
+
+     class CapturedCallException(Exception):
+         def __init__(self, *args, **kwargs):
+             super().__init__()
+             self.args = args
+             self.kwargs = kwargs
+
+     class CapturedCall:
+         def __init__(self):
+             self.args: tuple[Any, ...] = ()
+             self.kwargs: dict[str, Any] = {}
+
+     component = getattr(pipeline, component_name)
+     captured_call = CapturedCall()
+
+     def capture_call(*args, **kwargs):
+         raise CapturedCallException(*args, **kwargs)
+
+     with patch.object(component, component_method, new=capture_call):
+         try:
+             yield captured_call
+         except CapturedCallException as e:
+             captured_call.args = e.args
+             captured_call.kwargs = e.kwargs
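+ # Usage sketch (as in optimization.py): the patched method raises on its first
+ # invocation, which both records the exact args/kwargs the pipeline passes to
+ # its transformer and aborts the rest of the expensive pipeline run:
+ #
+ #     with capture_component_call(pipe, 'transformer') as call:
+ #         pipe(image=..., prompt=...)
+ #     exported = torch.export.export(pipe.transformer, call.args, call.kwargs)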
+
+
+ def drain_module_parameters(module: torch.nn.Module):
+     # Remember each tensor's device/dtype, then replace the real parameters with
+     # empty placeholders so the module keeps its structure but frees its memory.
+     state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
+     state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
+     module.load_state_dict(state_dict, assign=True)
+     for name, param in state_dict.items():
+         meta = state_dict_meta[name]
+         param.data = torch.Tensor([]).to(**meta)
poli_tower.png ADDED

Git LFS Details

  • SHA256: 96bc0e056b5aee2d2f1ed7723bab4f9c928dfb519ec21380aff4bbb12d22b849
  • Pointer size: 132 Bytes
  • Size of remote file: 3.49 MB
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ git+https://github.com/linoytsaban/diffusers.git@wan22-loras
+
+ transformers
+ accelerate
+ safetensors
+ sentencepiece
+ peft
+ ftfy
+ imageio-ffmpeg
+ opencv-python
+ torchao==0.11.0
squatting_sonic.png ADDED

Git LFS Details

  • SHA256: d5675e8192c6274c22b07cb60af92b8577d9fcf26f79a10450a325e385e17e18
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
tower_takes_off.png ADDED

Git LFS Details

  • SHA256: 3f824eae87d73d1b841354fcb96cfe5f7d08f8f2d6410bfaed864ecaf1500499
  • Pointer size: 132 Bytes
  • Size of remote file: 1.43 MB
ugly_sonic.jpg ADDED

Git LFS Details

  • SHA256: 37f76cf1cbb3a3fa0a6eb26898c8f89f71fa280d13f30fcc9dfdd3709cb9824d
  • Pointer size: 131 Bytes
  • Size of remote file: 290 kB