Spaces:
Sleeping
Sleeping
load wan in global
Browse files
app.py
CHANGED
|
@@ -97,23 +97,15 @@ if not hasattr(vggt4track_model, 'infer'):
|
|
| 97 |
tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
|
| 98 |
tracker_model.eval()
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
|
| 104 |
-
def get_wan_pipeline():
|
| 105 |
-
global wan_pipeline
|
| 106 |
-
if wan_pipeline is None:
|
| 107 |
-
print("🚀 Loading Wan TTM Pipeline (14B)...")
|
| 108 |
-
wan_pipeline = WanImageToVideoTTMPipeline.from_pretrained(
|
| 109 |
-
WAN_MODEL_ID,
|
| 110 |
-
torch_dtype=torch.bfloat16
|
| 111 |
-
)
|
| 112 |
-
wan_pipeline.vae.enable_tiling()
|
| 113 |
-
wan_pipeline.vae.enable_slicing()
|
| 114 |
-
wan_pipeline.to("cuda")
|
| 115 |
-
return wan_pipeline
|
| 116 |
-
|
| 117 |
|
| 118 |
print("✅ Tracking models loaded successfully!")
|
| 119 |
|
|
@@ -290,7 +282,6 @@ def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path,
|
|
| 290 |
return None, "❌ TTM Inputs missing. Please run 3D tracking first."
|
| 291 |
|
| 292 |
progress(0, desc="Loading Wan TTM Pipeline...")
|
| 293 |
-
pipe = get_wan_pipeline()
|
| 294 |
|
| 295 |
progress(0.2, desc="Preparing inputs...")
|
| 296 |
image = load_image(first_frame_path)
|
|
@@ -304,8 +295,8 @@ def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path,
|
|
| 304 |
|
| 305 |
# Match resolution logic from run_wan.py
|
| 306 |
max_area = 480 * 832
|
| 307 |
-
mod_value =
|
| 308 |
-
|
| 309 |
height, width = compute_hw_from_area(
|
| 310 |
image.height, image.width, max_area, mod_value)
|
| 311 |
image = image.resize((width, height))
|
|
@@ -314,7 +305,7 @@ def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path,
|
|
| 314 |
generator = torch.Generator(device="cuda").manual_seed(0)
|
| 315 |
|
| 316 |
with torch.inference_mode():
|
| 317 |
-
result =
|
| 318 |
image=image,
|
| 319 |
prompt=prompt,
|
| 320 |
negative_prompt=negative_prompt,
|
|
|
|
| 97 |
tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
|
| 98 |
tracker_model.eval()
|
| 99 |
|
| 100 |
+
wan_pipeline = WanImageToVideoTTMPipeline.from_pretrained(
|
| 101 |
+
WAN_MODEL_ID,
|
| 102 |
+
torch_dtype=torch.bfloat16
|
| 103 |
+
)
|
| 104 |
+
wan_pipeline.vae.enable_tiling()
|
| 105 |
+
wan_pipeline.vae.enable_slicing()
|
| 106 |
+
wan_pipeline.to("cuda")
|
| 107 |
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
print("✅ Tracking models loaded successfully!")
|
| 111 |
|
|
|
|
| 282 |
return None, "❌ TTM Inputs missing. Please run 3D tracking first."
|
| 283 |
|
| 284 |
progress(0, desc="Loading Wan TTM Pipeline...")
|
|
|
|
| 285 |
|
| 286 |
progress(0.2, desc="Preparing inputs...")
|
| 287 |
image = load_image(first_frame_path)
|
|
|
|
| 295 |
|
| 296 |
# Match resolution logic from run_wan.py
|
| 297 |
max_area = 480 * 832
|
| 298 |
+
mod_value = wan_pipeline.vae_scale_factor_spatial * \
|
| 299 |
+
wan_pipeline.transformer.config.patch_size[1]
|
| 300 |
height, width = compute_hw_from_area(
|
| 301 |
image.height, image.width, max_area, mod_value)
|
| 302 |
image = image.resize((width, height))
|
|
|
|
| 305 |
generator = torch.Generator(device="cuda").manual_seed(0)
|
| 306 |
|
| 307 |
with torch.inference_mode():
|
| 308 |
+
result = wan_pipeline(
|
| 309 |
image=image,
|
| 310 |
prompt=prompt,
|
| 311 |
negative_prompt=negative_prompt,
|