DawnC committed on
Commit
c56db8d
·
verified ·
1 Parent(s): 20e73d1

Upload 2 files

Browse files
Files changed (2) hide show
  1. FlowFacade.py +7 -8
  2. VideoEngine_optimized.py +355 -0
FlowFacade.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  import numpy as np
4
  from PIL import Image
5
  from typing import Tuple, Optional
6
- from VideoEngine import VideoEngine
7
  from TextProcessor import TextProcessor
8
 
9
  try:
@@ -29,7 +29,7 @@ class FlowFacade:
29
  def _calculate_gpu_duration(self, image: Image.Image, duration_seconds: float,
30
  num_inference_steps: int, enable_prompt_expansion: bool, **kwargs) -> int:
31
  BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
32
- BASE_STEP_DURATION = 20 # Sequential CPU offload (conservative estimate)
33
 
34
  resized_image = self.video_engine.resize_image(image)
35
  width, height = resized_image.width, resized_image.height
@@ -39,16 +39,15 @@ class FlowFacade:
39
  step_duration = BASE_STEP_DURATION * factor ** 1.5
40
  total_duration = int(num_inference_steps) * step_duration
41
 
42
- # Add overhead for first-time model loading (CPU LoRA fusion)
43
  if not self.video_engine.is_loaded:
44
- total_duration += 90 # ~90s for CPU LoRA fusion
45
 
46
  if enable_prompt_expansion:
47
- total_duration += 60
48
 
49
- # Conservative minimum: 300 seconds (5 minutes)
50
- # No more NVML errors! Just need enough time for sequential offload
51
- return max(int(total_duration), 300)
52
 
53
  @spaces.GPU(duration=_calculate_gpu_duration)
54
  def generate_video_from_image(self, image: Image.Image, user_instruction: str,
 
3
  import numpy as np
4
  from PIL import Image
5
  from typing import Tuple, Optional
6
+ from VideoEngine_optimized import VideoEngine
7
  from TextProcessor import TextProcessor
8
 
9
  try:
 
29
  def _calculate_gpu_duration(self, image: Image.Image, duration_seconds: float,
30
  num_inference_steps: int, enable_prompt_expansion: bool, **kwargs) -> int:
31
  BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
32
+ BASE_STEP_DURATION = 8 # FP8 + AOTI optimized (fast direct GPU)
33
 
34
  resized_image = self.video_engine.resize_image(image)
35
  width, height = resized_image.width, resized_image.height
 
39
  step_duration = BASE_STEP_DURATION * factor ** 1.5
40
  total_duration = int(num_inference_steps) * step_duration
41
 
42
+ # Add overhead for first-time model loading (FP8 quantization + AOTI)
43
  if not self.video_engine.is_loaded:
44
+ total_duration += 60 # ~60s for FP8 quantization and AOTI loading
45
 
46
  if enable_prompt_expansion:
47
+ total_duration += 40
48
 
49
+ # Optimized minimum: 90 seconds (FP8 + AOTI is much faster)
50
+ return max(int(total_duration), 90)
 
51
 
52
  @spaces.GPU(duration=_calculate_gpu_duration)
53
  def generate_video_from_image(self, image: Image.Image, user_instruction: str,
VideoEngine_optimized.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DeltaFlow - Video Engine (FP8 + AOTI Optimized)
3
+ Ultra-fast Image-to-Video generation using Wan2.2-I2V-A14B
4
+ Features: Lightning LoRA + FP8 Quantization + AOTI Compilation
5
+ ~30-40s inference (vs 150s baseline)
6
+ """
7
+
8
+ import warnings
9
+ warnings.filterwarnings('ignore', category=FutureWarning)
10
+ warnings.filterwarnings('ignore', category=DeprecationWarning)
11
+
12
+ import gc
13
+ import os
14
+ import tempfile
15
+ import traceback
16
+ from typing import Optional
17
+
18
+ import torch
19
+ import numpy as np
20
+ from PIL import Image
21
+
22
+ # Critical dependencies
23
+ import ftfy
24
+ import sentencepiece
25
+
26
+ # Diffusers imports
27
+ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
28
+ from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
29
+ from diffusers.utils.export_utils import export_to_video
30
+
31
+
32
class VideoEngine:
    """
    Ultra-fast video generation with FP8 quantization and AOTI compilation.
    30-40s inference time (compared to 150s baseline).
    """

    # Hugging Face repositories: base I2V pipeline, a bf16 transformer
    # re-export, and the Lightning step-distillation LoRA weights.
    MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
    TRANSFORMER_REPO = "cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers"
    LORA_REPO = "Kijai/WanVideo_comfy"
    LORA_WEIGHT = "Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors"

    # Model parameters
    MAX_DIM = 832        # longest allowed output edge, px
    MIN_DIM = 480        # shortest allowed output edge, px
    SQUARE_DIM = 640     # fixed size used for exactly-square inputs
    MULTIPLE_OF = 16     # output dims are snapped to this multiple
    FIXED_FPS = 16       # output video frame rate
    MIN_FRAMES = 8       # lower clamp for the frame-count calculation
    MAX_FRAMES = 81      # upper frame limit (81 frames = ~5s @ 16fps)
52
+ def __init__(self):
53
+ """Initialize VideoEngine."""
54
+ self.is_spaces = os.environ.get('SPACE_ID') is not None
55
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
56
+ self.pipeline: Optional[WanImageToVideoPipeline] = None
57
+ self.is_loaded = False
58
+ self.use_aoti = False
59
+
60
+ print(f"✓ VideoEngine initialized ({self.device})")
61
+
62
+ def _check_xformers_available(self) -> bool:
63
+ """Check if xFormers is available."""
64
+ try:
65
+ import xformers
66
+ return True
67
+ except ImportError:
68
+ return False
69
+
70
+ def load_model(self) -> None:
71
+ """Load model with FP8 quantization and AOTI compilation."""
72
+ if self.is_loaded:
73
+ print("⚠ VideoEngine already loaded")
74
+ return
75
+
76
+ try:
77
+ print("=" * 60)
78
+ print("Loading Wan2.2 I2V Engine with FP8 + AOTI")
79
+ print("=" * 60)
80
+
81
+ # Stage 1: Load base pipeline to CPU
82
+ print("→ [1/5] Loading base pipeline to CPU...")
83
+ self.pipeline = WanImageToVideoPipeline.from_pretrained(
84
+ self.MODEL_ID,
85
+ transformer=WanTransformer3DModel.from_pretrained(
86
+ self.TRANSFORMER_REPO,
87
+ subfolder='transformer',
88
+ torch_dtype=torch.bfloat16,
89
+ ),
90
+ transformer_2=WanTransformer3DModel.from_pretrained(
91
+ self.TRANSFORMER_REPO,
92
+ subfolder='transformer_2',
93
+ torch_dtype=torch.bfloat16,
94
+ ),
95
+ torch_dtype=torch.bfloat16,
96
+ )
97
+ print("✓ Base pipeline loaded to CPU")
98
+
99
+ # Stage 2: Load and fuse Lightning LoRA
100
+ print("→ [2/5] Loading Lightning LoRA...")
101
+ self.pipeline.load_lora_weights(
102
+ self.LORA_REPO, weight_name=self.LORA_WEIGHT,
103
+ adapter_name="lightx2v"
104
+ )
105
+ kwargs_lora = {"load_into_transformer_2": True}
106
+ self.pipeline.load_lora_weights(
107
+ self.LORA_REPO, weight_name=self.LORA_WEIGHT,
108
+ adapter_name="lightx2v_2", **kwargs_lora
109
+ )
110
+ self.pipeline.set_adapters(
111
+ ["lightx2v", "lightx2v_2"],
112
+ adapter_weights=[1., 1.]
113
+ )
114
+ self.pipeline.fuse_lora(
115
+ adapter_names=["lightx2v"], lora_scale=3.,
116
+ components=["transformer"]
117
+ )
118
+ self.pipeline.fuse_lora(
119
+ adapter_names=["lightx2v_2"], lora_scale=1.,
120
+ components=["transformer_2"]
121
+ )
122
+ self.pipeline.unload_lora_weights()
123
+ print("✓ Lightning LoRA fused")
124
+
125
+ # Stage 3: FP8 Quantization
126
+ print("→ [3/5] Applying FP8 quantization...")
127
+ try:
128
+ from torchao.quantization import quantize_
129
+ from torchao.quantization import (
130
+ Float8DynamicActivationFloat8WeightConfig,
131
+ Int8WeightOnlyConfig
132
+ )
133
+
134
+ # Quantize text encoder (INT8)
135
+ quantize_(self.pipeline.text_encoder, Int8WeightOnlyConfig())
136
+
137
+ # Quantize transformers (FP8)
138
+ quantize_(
139
+ self.pipeline.transformer,
140
+ Float8DynamicActivationFloat8WeightConfig()
141
+ )
142
+ quantize_(
143
+ self.pipeline.transformer_2,
144
+ Float8DynamicActivationFloat8WeightConfig()
145
+ )
146
+
147
+ print("✓ FP8 quantization applied (50% memory reduction)")
148
+ except Exception as e:
149
+ print(f"⚠ Quantization failed: {e}")
150
+ raise RuntimeError("FP8 quantization required for this optimized version")
151
+
152
+ # Stage 4: Load AOTI blocks
153
+ print("→ [4/5] Loading AOTI blocks...")
154
+ try:
155
+ import aoti
156
+
157
+ aoti.aoti_blocks_load(
158
+ self.pipeline.transformer,
159
+ 'zerogpu-aoti/Wan2',
160
+ variant='fp8da'
161
+ )
162
+ aoti.aoti_blocks_load(
163
+ self.pipeline.transformer_2,
164
+ 'zerogpu-aoti/Wan2',
165
+ variant='fp8da'
166
+ )
167
+ print("✓ AOTI blocks loaded (1.5-1.8x speedup)")
168
+ self.use_aoti = True
169
+ except Exception as e:
170
+ print(f"⚠ AOTI loading failed: {e}")
171
+ print(" Continuing without AOTI (FP8 only)")
172
+ self.use_aoti = False
173
+
174
+ # Stage 5: Move to GPU and enable optimizations
175
+ print("→ [5/5] Moving to GPU...")
176
+ gc.collect()
177
+ if torch.cuda.is_available():
178
+ torch.cuda.empty_cache()
179
+
180
+ self.pipeline = self.pipeline.to('cuda')
181
+
182
+ # Enable VAE optimizations
183
+ self.pipeline.enable_vae_tiling()
184
+ self.pipeline.enable_vae_slicing()
185
+
186
+ # Enable TF32
187
+ if torch.cuda.is_available():
188
+ torch.backends.cuda.matmul.allow_tf32 = True
189
+ torch.backends.cudnn.allow_tf32 = True
190
+
191
+ # Enable xFormers
192
+ try:
193
+ if self._check_xformers_available():
194
+ self.pipeline.enable_xformers_memory_efficient_attention()
195
+ print(" • xFormers enabled")
196
+ except:
197
+ pass
198
+
199
+ self.is_loaded = True
200
+ mode = "FP8 + AOTI" if self.use_aoti else "FP8 only"
201
+ print("=" * 60)
202
+ print(f"✓ VideoEngine Ready - {mode}")
203
+ print(f" • Device: {self.device}")
204
+ print(f" • Quantization: FP8 (50% memory reduction)")
205
+ print(f" • AOTI: {'Enabled (1.5-1.8x speedup)' if self.use_aoti else 'Disabled'}")
206
+ print(f" • Expected inference: {'~30-40s' if self.use_aoti else '~60-70s'}")
207
+ print("=" * 60)
208
+
209
+ except Exception as e:
210
+ print(f"\n{'='*60}")
211
+ print("✗ FATAL ERROR LOADING VIDEO ENGINE")
212
+ print(f"{'='*60}")
213
+ print(f"Error Type: {type(e).__name__}")
214
+ print(f"Error Message: {str(e)}")
215
+ print(f"\nFull Traceback:")
216
+ print(traceback.format_exc())
217
+ print(f"{'='*60}")
218
+ raise
219
+
220
+ def resize_image(self, image: Image.Image) -> Image.Image:
221
+ """Resize image to fit model constraints while preserving aspect ratio."""
222
+ width, height = image.size
223
+
224
+ if width == height:
225
+ return image.resize((self.SQUARE_DIM, self.SQUARE_DIM), Image.LANCZOS)
226
+
227
+ aspect_ratio = width / height
228
+ MAX_ASPECT_RATIO = self.MAX_DIM / self.MIN_DIM
229
+ MIN_ASPECT_RATIO = self.MIN_DIM / self.MAX_DIM
230
+
231
+ image_to_resize = image
232
+
233
+ if aspect_ratio > MAX_ASPECT_RATIO:
234
+ target_w, target_h = self.MAX_DIM, self.MIN_DIM
235
+ crop_width = int(round(height * MAX_ASPECT_RATIO))
236
+ left = (width - crop_width) // 2
237
+ image_to_resize = image.crop((left, 0, left + crop_width, height))
238
+ elif aspect_ratio < MIN_ASPECT_RATIO:
239
+ target_w, target_h = self.MIN_DIM, self.MAX_DIM
240
+ crop_height = int(round(width / MIN_ASPECT_RATIO))
241
+ top = (height - crop_height) // 2
242
+ image_to_resize = image.crop((0, top, width, top + crop_height))
243
+ else:
244
+ if width > height:
245
+ target_w = self.MAX_DIM
246
+ target_h = int(round(target_w / aspect_ratio))
247
+ else:
248
+ target_h = self.MAX_DIM
249
+ target_w = int(round(target_h * aspect_ratio))
250
+
251
+ final_w = round(target_w / self.MULTIPLE_OF) * self.MULTIPLE_OF
252
+ final_h = round(target_h / self.MULTIPLE_OF) * self.MULTIPLE_OF
253
+ final_w = max(self.MIN_DIM, min(self.MAX_DIM, final_w))
254
+ final_h = max(self.MIN_DIM, min(self.MAX_DIM, final_h))
255
+
256
+ return image_to_resize.resize((final_w, final_h), Image.LANCZOS)
257
+
258
+ def get_num_frames(self, duration_seconds: float) -> int:
259
+ """Calculate frame count from duration."""
260
+ return 1 + int(np.clip(
261
+ int(round(duration_seconds * self.FIXED_FPS)),
262
+ self.MIN_FRAMES,
263
+ self.MAX_FRAMES,
264
+ ))
265
+
266
    def generate_video(
        self,
        image: Image.Image,
        prompt: str,
        duration_seconds: float = 3.0,
        num_inference_steps: int = 4,
        guidance_scale: float = 1.0,
        guidance_scale_2: float = 1.0,
        seed: int = 42,
    ) -> str:
        """Generate video from image with FP8 + AOTI optimization.

        Args:
            image: Source frame; resized/cropped via resize_image() first.
            prompt: Text description of the desired motion/content.
            duration_seconds: Clip length, converted to a frame count at
                FIXED_FPS by get_num_frames().
            num_inference_steps: Diffusion steps (default 4 -- presumably
                suits the fused step-distilled LoRA; confirm).
            guidance_scale: Guidance passed to the first transformer stage.
            guidance_scale_2: Guidance passed to the second transformer stage.
            seed: RNG seed; also used to name the output file.

        Returns:
            Path to the exported MP4 in the system temp directory.

        Raises:
            RuntimeError: if load_model() has not been called.
            Exception: any generation error is logged with a traceback
                and re-raised.
        """
        if not self.is_loaded:
            raise RuntimeError("VideoEngine not loaded. Call load_model() first.")

        try:
            resized_image = self.resize_image(image)
            num_frames = self.get_num_frames(duration_seconds)

            print(f"\n→ Generating video:")
            print(f" • Prompt: {prompt}")
            print(f" • Resolution: {resized_image.width}x{resized_image.height}")
            print(f" • Frames: {num_frames} ({duration_seconds}s @ {self.FIXED_FPS}fps)")
            print(f" • Steps: {num_inference_steps}")
            print(f" • Mode: {'FP8 + AOTI' if self.use_aoti else 'FP8 only'}")

            # Memory cleanup before inference; synchronize so the cache
            # release actually completes before the big allocation.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

            with torch.no_grad():
                # Use CUDA generator for optimized version
                # (the pipeline lives on the GPU after load_model()).
                generator = torch.Generator(device="cuda").manual_seed(seed)

                # .frames is batched; [0] extracts the single generated clip.
                output_frames = self.pipeline(
                    image=resized_image,
                    prompt=prompt,
                    height=resized_image.height,
                    width=resized_image.width,
                    num_frames=num_frames,
                    guidance_scale=float(guidance_scale),
                    guidance_scale_2=float(guidance_scale_2),
                    num_inference_steps=int(num_inference_steps),
                    generator=generator,
                ).frames[0]

            # Cleanup after generation
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Export video. NOTE(review): the path is keyed on seed only,
            # so two runs with the same seed overwrite each other's file.
            temp_dir = tempfile.gettempdir()
            output_path = os.path.join(temp_dir, f"deltaflow_{seed}.mp4")
            export_to_video(output_frames, output_path, fps=self.FIXED_FPS)

            print(f"✓ Video generated: {output_path}")
            return output_path

        except Exception as e:
            print(f"\n{'='*60}")
            print("✗ FATAL ERROR DURING VIDEO GENERATION")
            print(f"{'='*60}")
            print(f"Error Type: {type(e).__name__}")
            print(f"Error Message: {str(e)}")
            print(f"\nFull Traceback:")
            print(traceback.format_exc())
            print(f"{'='*60}")
            raise
336
+
337
+ def unload_model(self) -> None:
338
+ """Unload pipeline and free memory."""
339
+ if not self.is_loaded:
340
+ return
341
+
342
+ try:
343
+ if self.pipeline is not None:
344
+ del self.pipeline
345
+ self.pipeline = None
346
+
347
+ gc.collect()
348
+ if torch.cuda.is_available():
349
+ torch.cuda.empty_cache()
350
+
351
+ self.is_loaded = False
352
+ print("✓ VideoEngine unloaded")
353
+
354
+ except Exception as e:
355
+ print(f"⚠ Error during unload: {str(e)}")