# Source: ColabWan/shared/mps/test_mps_inference.py
# Uploaded by 1ripon1 ("Upload folder using huggingface_hub"), commit 7344bef (verified), 2.58 kB.
"""MPS smoke test for Wan2GP.

Applies the MPS device patch, checks that the Wan model handler can be
imported, and lists any model-weight files found under ``ckpts/`` (or under
the repository root, two directory levels deep, if ``ckpts/`` is missing).

NOTE: despite the filename, this script does not load model weights or run
a forward pass; it only verifies the environment is ready for one.
"""
import gc
import os
import sys
import time

# The repository root is two directories above this file (shared/mps/ ->
# repo root). Put it on sys.path so the `shared.*` and `models.*` imports
# below resolve when the script is run directly.
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

print("=" * 60)
print("Wan2GP MPS Inference Test")
print("=" * 60)
# Step 1: Apply the MPS patch early, before the Wan model modules are
# imported in step 2.
import torch
from shared.mps.device_patch import apply_mps_patch

apply_mps_patch()
print(f"\n[1] PyTorch {torch.__version__}, MPS: {torch.backends.mps.is_available()}")
# torch.get_default_device() only exists in newer PyTorch releases; report
# its absence instead of crashing with AttributeError on older installs.
_get_default_device = getattr(torch, "get_default_device", None)
if _get_default_device is not None:
    print(f" Default device: {_get_default_device()}")
else:
    print(" Default device: <unavailable: torch.get_default_device() missing in this PyTorch>")
# Step 2: Try to load the wan handler.
# Exits with status 0 (a "skip", not a failure) if the handler cannot be
# imported at all.
print("\n[2] Loading wan handler...")
try:
    from models.wan.wan_handler import family_handler
    from models.wan.configs import WAN_CONFIGS
    print(" family_handler imported OK")
except ImportError as first_err:
    # Either module may be the one that failed. Retry the handler on its own
    # so that a problem confined to models.wan.configs does not hide a usable
    # handler.
    print(f" family_handler/WAN_CONFIGS import failed ({first_err}), retrying family_handler alone...")
    WAN_CONFIGS = None
    try:
        from models.wan.wan_handler import family_handler
        print(" family_handler imported OK")
    except ImportError as e:
        print(f" SKIPPED: {e} (may need model weights installed)")
        sys.exit(0)
# Bind the `WanHandler` alias too, so the same names exist whichever
# import path succeeded.
WanHandler = family_handler
# Step 3: Check available model files.
print("\n[3] Checking model files...")
_WEIGHT_EXTENSIONS = ('.safetensors', '.pt', '.bin', '.pth')


def _print_weight_files(base_dir, max_depth=2):
    """Print each weight file under *base_dir* together with its size in GB.

    Only *base_dir* itself and subdirectories at most *max_depth* levels
    below it are searched. Returns the number of weight files whose size
    could be read.
    """
    found = 0
    for root, dirs, files in os.walk(base_dir, topdown=True):
        rel = os.path.relpath(root, base_dir)
        depth = 0 if rel == os.curdir else rel.count(os.sep) + 1
        if depth >= max_depth:
            # With topdown=True, clearing `dirs` in place stops os.walk from
            # descending further; files at this depth are still listed.
            dirs.clear()
        for name in files:
            if not name.endswith(_WEIGHT_EXTENSIONS):
                continue
            fp = os.path.join(root, name)
            try:
                size = os.path.getsize(fp) / (1024**3)
            except OSError as err:
                # e.g. a dangling symlink or an unreadable file: report it
                # rather than aborting the whole scan.
                print(f" {fp}: <size unavailable: {err}>")
                continue
            print(f" {fp}: {size:.2f}GB")
            found += 1
    return found


model_dir = os.path.join(REPO_ROOT, "ckpts")
if not os.path.isdir(model_dir):
    print(f" Model dir {model_dir} does not exist!")
    print(" Need to download Wan2.1 1.3B model weights first.")
    # List whatever weight files exist elsewhere in the repo.
    n_weight_files = _print_weight_files(REPO_ROOT)
else:
    n_weight_files = _print_weight_files(model_dir)
if n_weight_files == 0:
    print(" No model weight files found.")
# Step 4: closing hint — the user inspects the listing printed in step 3.
print(
    "\n[4] Done. Check if any model weights were found above.",
    " If no weights found, you need to download Wan2.1 1.3B first.",
    sep="\n",
)