DawnC committed on
Commit
cf1b6d7
·
verified ·
1 Parent(s): 7550e7c

Upload 2 files

Browse files
Files changed (2) hide show
  1. VideoEngine.py +6 -12
  2. app.py +6 -0
VideoEngine.py CHANGED
@@ -29,13 +29,6 @@ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
29
  from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
30
  from diffusers.utils.export_utils import export_to_video
31
 
32
- # Optional dependency for AOTI (HF Spaces deployment only)
33
- try:
34
- from spaces.zero.torch.aoti import aoti_blocks_load
35
- HAS_AOTI = True
36
- except ImportError:
37
- HAS_AOTI = False
38
-
39
 
40
  class VideoEngine:
41
  """
@@ -167,14 +160,13 @@ class VideoEngine:
167
  print("→ [6/7] Skipping AOTI (testing mode, not needed)")
168
  return
169
 
170
- if not HAS_AOTI:
171
- print("⚠ [6/7] Skipping AOTI (spaces.zero.torch.aoti not available)")
172
- return
173
-
174
  print("→ [6/7] Loading AOTI pre-compiled blocks...")
175
  try:
 
 
 
176
  # Determine variant based on GPU capability
177
- variant = 'int8' # Default
178
  if torch.cuda.is_available():
179
  cuda_cap = torch.cuda.get_device_capability()
180
  fp8_supported = cuda_cap[0] > 8 or (cuda_cap[0] == 8 and cuda_cap[1] >= 9)
@@ -185,6 +177,8 @@ class VideoEngine:
185
  aoti_blocks_load(self.pipeline.transformer, 'zerogpu-aoti/Wan2', variant=variant)
186
  aoti_blocks_load(self.pipeline.transformer_2, 'zerogpu-aoti/Wan2', variant=variant)
187
  print(f"✓ AOTI blocks loaded (variant: {variant}, 60-70% speedup)")
 
 
188
  except Exception as e:
189
  print(f"⚠ AOTI load failed (falling back to standard inference): {e}")
190
  print(" This is not critical, speed will be slightly slower")
 
29
  from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
30
  from diffusers.utils.export_utils import export_to_video
31
 
 
 
 
 
 
 
 
32
 
33
  class VideoEngine:
34
  """
 
160
  print("→ [6/7] Skipping AOTI (testing mode, not needed)")
161
  return
162
 
 
 
 
 
163
  print("→ [6/7] Loading AOTI pre-compiled blocks...")
164
  try:
165
+ # Lazy import to avoid CUDA initialization at module load
166
+ from spaces.zero.torch.aoti import aoti_blocks_load
167
+
168
  # Determine variant based on GPU capability
169
+ variant = 'int8'
170
  if torch.cuda.is_available():
171
  cuda_cap = torch.cuda.get_device_capability()
172
  fp8_supported = cuda_cap[0] > 8 or (cuda_cap[0] == 8 and cuda_cap[1] >= 9)
 
177
  aoti_blocks_load(self.pipeline.transformer, 'zerogpu-aoti/Wan2', variant=variant)
178
  aoti_blocks_load(self.pipeline.transformer_2, 'zerogpu-aoti/Wan2', variant=variant)
179
  print(f"✓ AOTI blocks loaded (variant: {variant}, 60-70% speedup)")
180
+ except ImportError:
181
+ print("⚠ [6/7] Skipping AOTI (spaces.zero.torch.aoti not available)")
182
  except Exception as e:
183
  print(f"⚠ AOTI load failed (falling back to standard inference): {e}")
184
  print(" This is not critical, speed will be slightly slower")
app.py CHANGED
@@ -1,6 +1,12 @@
1
  import os
2
  import sys
3
 
 
 
 
 
 
 
4
  sys.stdout.flush()
5
  import functools
6
  print = functools.partial(print, flush=True)
 
1
  import os
2
  import sys
3
 
4
+ # CRITICAL: Import spaces FIRST before any CUDA initialization
5
+ try:
6
+ import spaces
7
+ except ImportError:
8
+ pass
9
+
10
  sys.stdout.flush()
11
  import functools
12
  print = functools.partial(print, flush=True)