abreza committed on
Commit
f15498a
·
1 Parent(s): 2cbde32

Load the Wan pipeline at global (module) scope instead of lazily

Browse files
Files changed (1) hide show
  1. app.py +10 -19
app.py CHANGED
@@ -97,23 +97,15 @@ if not hasattr(vggt4track_model, 'infer'):
97
  tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
98
  tracker_model.eval()
99
 
100
- # Lazy loading for Wan to save memory until needed
101
- wan_pipeline = None
 
 
 
 
 
102
 
103
 
104
- def get_wan_pipeline():
105
- global wan_pipeline
106
- if wan_pipeline is None:
107
- print("🚀 Loading Wan TTM Pipeline (14B)...")
108
- wan_pipeline = WanImageToVideoTTMPipeline.from_pretrained(
109
- WAN_MODEL_ID,
110
- torch_dtype=torch.bfloat16
111
- )
112
- wan_pipeline.vae.enable_tiling()
113
- wan_pipeline.vae.enable_slicing()
114
- wan_pipeline.to("cuda")
115
- return wan_pipeline
116
-
117
 
118
  print("✅ Tracking models loaded successfully!")
119
 
@@ -290,7 +282,6 @@ def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path,
290
  return None, "❌ TTM Inputs missing. Please run 3D tracking first."
291
 
292
  progress(0, desc="Loading Wan TTM Pipeline...")
293
- pipe = get_wan_pipeline()
294
 
295
  progress(0.2, desc="Preparing inputs...")
296
  image = load_image(first_frame_path)
@@ -304,8 +295,8 @@ def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path,
304
 
305
  # Match resolution logic from run_wan.py
306
  max_area = 480 * 832
307
- mod_value = pipe.vae_scale_factor_spatial * \
308
- pipe.transformer.config.patch_size[1]
309
  height, width = compute_hw_from_area(
310
  image.height, image.width, max_area, mod_value)
311
  image = image.resize((width, height))
@@ -314,7 +305,7 @@ def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path,
314
  generator = torch.Generator(device="cuda").manual_seed(0)
315
 
316
  with torch.inference_mode():
317
- result = pipe(
318
  image=image,
319
  prompt=prompt,
320
  negative_prompt=negative_prompt,
 
97
  tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
98
  tracker_model.eval()
99
 
100
+ wan_pipeline = WanImageToVideoTTMPipeline.from_pretrained(
101
+ WAN_MODEL_ID,
102
+ torch_dtype=torch.bfloat16
103
+ )
104
+ wan_pipeline.vae.enable_tiling()
105
+ wan_pipeline.vae.enable_slicing()
106
+ wan_pipeline.to("cuda")
107
 
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  print("✅ Tracking models loaded successfully!")
111
 
 
282
  return None, "❌ TTM Inputs missing. Please run 3D tracking first."
283
 
284
  progress(0, desc="Loading Wan TTM Pipeline...")
 
285
 
286
  progress(0.2, desc="Preparing inputs...")
287
  image = load_image(first_frame_path)
 
295
 
296
  # Match resolution logic from run_wan.py
297
  max_area = 480 * 832
298
+ mod_value = wan_pipeline.vae_scale_factor_spatial * \
299
+ wan_pipeline.transformer.config.patch_size[1]
300
  height, width = compute_hw_from_area(
301
  image.height, image.width, max_area, mod_value)
302
  image = image.resize((width, height))
 
305
  generator = torch.Generator(device="cuda").manual_seed(0)
306
 
307
  with torch.inference_mode():
308
+ result = wan_pipeline(
309
  image=image,
310
  prompt=prompt,
311
  negative_prompt=negative_prompt,