abreza committed on
Commit
2cbde32
·
1 Parent(s): 7e8e6f1
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -97,15 +97,23 @@ if not hasattr(vggt4track_model, 'infer'):
97
  tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
98
  tracker_model.eval()
99
 
100
- wan_pipeline = WanImageToVideoTTMPipeline.from_pretrained(
101
- WAN_MODEL_ID,
102
- torch_dtype=torch.bfloat16
103
- )
104
- wan_pipeline.vae.enable_tiling()
105
- wan_pipeline.vae.enable_slicing()
106
- wan_pipeline.to("cuda")
107
 
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  print("✅ Tracking models loaded successfully!")
111
 
@@ -282,6 +290,7 @@ def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path,
282
  return None, "❌ TTM Inputs missing. Please run 3D tracking first."
283
 
284
  progress(0, desc="Loading Wan TTM Pipeline...")
 
285
 
286
  progress(0.2, desc="Preparing inputs...")
287
  image = load_image(first_frame_path)
@@ -295,8 +304,8 @@ def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path,
295
 
296
  # Match resolution logic from run_wan.py
297
  max_area = 480 * 832
298
- mod_value = wan_pipeline.vae_scale_factor_spatial * \
299
- wan_pipeline.transformer.config.patch_size[1]
300
  height, width = compute_hw_from_area(
301
  image.height, image.width, max_area, mod_value)
302
  image = image.resize((width, height))
@@ -305,7 +314,7 @@ def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path,
305
  generator = torch.Generator(device="cuda").manual_seed(0)
306
 
307
  with torch.inference_mode():
308
- result = wan_pipeline(
309
  image=image,
310
  prompt=prompt,
311
  negative_prompt=negative_prompt,
 
97
  tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
98
  tracker_model.eval()
99
 
100
+ # Lazy loading for Wan to save memory until needed
101
+ wan_pipeline = None
 
 
 
 
 
102
 
103
 
104
+ def get_wan_pipeline():
105
+ global wan_pipeline
106
+ if wan_pipeline is None:
107
+ print("🚀 Loading Wan TTM Pipeline (14B)...")
108
+ wan_pipeline = WanImageToVideoTTMPipeline.from_pretrained(
109
+ WAN_MODEL_ID,
110
+ torch_dtype=torch.bfloat16
111
+ )
112
+ wan_pipeline.vae.enable_tiling()
113
+ wan_pipeline.vae.enable_slicing()
114
+ wan_pipeline.to("cuda")
115
+ return wan_pipeline
116
+
117
 
118
  print("✅ Tracking models loaded successfully!")
119
 
 
290
  return None, "❌ TTM Inputs missing. Please run 3D tracking first."
291
 
292
  progress(0, desc="Loading Wan TTM Pipeline...")
293
+ pipe = get_wan_pipeline()
294
 
295
  progress(0.2, desc="Preparing inputs...")
296
  image = load_image(first_frame_path)
 
304
 
305
  # Match resolution logic from run_wan.py
306
  max_area = 480 * 832
307
+ mod_value = pipe.vae_scale_factor_spatial * \
308
+ pipe.transformer.config.patch_size[1]
309
  height, width = compute_hw_from_area(
310
  image.height, image.width, max_area, mod_value)
311
  image = image.resize((width, height))
 
314
  generator = torch.Generator(device="cuda").manual_seed(0)
315
 
316
  with torch.inference_mode():
317
+ result = pipe(
318
  image=image,
319
  prompt=prompt,
320
  negative_prompt=negative_prompt,