#!/usr/bin/env python
# vbench-i2v / verify_models.py (uploaded by lxl-158, "upload vbench-i2v", commit d0d70d4)
"""
Verify all required models can be loaded for VBench I2V evaluation.
This script downloads and initializes all models without running full evaluation.
"""
import os
import sys
import torch
# Root of every model cache used by this script. The Hugging Face and
# torch.hub caches are nested beneath it so that what they download also
# shows up in the size summary printed by list_downloaded_models().
CACHE_DIR = '/workspace/vbench-i2v/vbench2_beta_i2v/pretrained_models'
os.environ['VBENCH_CACHE_DIR'] = CACHE_DIR
os.environ['HF_HOME'] = f'{CACHE_DIR}/huggingface'
os.environ['TORCH_HOME'] = f'{CACHE_DIR}/torch'
def test_dino():
    """Download (via torch.hub) and instantiate the DINO ViT-B/16 model.

    DINO is used by the i2v_subject and subject_consistency dimensions.

    Returns:
        True if the model loads, False otherwise (the error is printed).
    """
    print("\n[1/7] Testing DINO model (i2v_subject, subject_consistency)...")
    try:
        # torch.hub fetches the repo and weights on first use and caches them
        # under $TORCH_HOME/hub.
        model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
        model.eval()
        print(" ✓ DINO model loaded successfully")
        del model
        torch.cuda.empty_cache()
        return True
    except Exception as e:
        # Best-effort check: report the failure and let main() tally it.
        print(f" ✗ DINO failed: {e}")
        return False
def test_clip():
    """Download and instantiate the CLIP ViT-B/32 and ViT-L/14 models on CUDA.

    CLIP is used by the background_consistency and aesthetic_quality
    dimensions.

    Returns:
        True if both models load, False otherwise (the error is printed).
    """
    print("\n[2/7] Testing CLIP model (background_consistency, aesthetic_quality)...")
    try:
        import clip
        # Keep the CLIP weights inside the configured model cache (like the
        # HF_HOME / TORCH_HOME caches) rather than clip's default
        # ~/.cache/clip, so they are counted by list_downloaded_models().
        clip_root = f'{CACHE_DIR}/clip_model'
        model, _ = clip.load("ViT-B/32", device="cuda", download_root=clip_root)
        print(" ✓ CLIP ViT-B/32 loaded successfully")
        model_l, _ = clip.load("ViT-L/14", device="cuda", download_root=clip_root)
        print(" ✓ CLIP ViT-L/14 loaded successfully")
        del model, model_l
        torch.cuda.empty_cache()
        return True
    except Exception as e:
        # Best-effort check: report the failure and let main() tally it.
        print(f" ✗ CLIP failed: {e}")
        return False
def test_cotracker():
    """Download (via torch.hub) the CoTracker2 model and move it to CUDA.

    CoTracker is used by the camera_motion dimension.

    Returns:
        True if the model loads, False otherwise (the error is printed).
    """
    print("\n[3/7] Testing CoTracker model (camera_motion)...")
    try:
        cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker2")
        cotracker = cotracker.cuda()
        cotracker.eval()
        print(" ✓ CoTracker model loaded successfully")
        del cotracker
        torch.cuda.empty_cache()
        return True
    except Exception as e:
        # Best-effort check: report the failure and let main() tally it.
        print(f" ✗ CoTracker failed: {e}")
        return False
def _check_checkpoint(label, rel_path):
    """Verify that a checkpoint stored under CACHE_DIR exists and deserializes.

    Args:
        label: Model name used in the printed status messages.
        rel_path: Checkpoint path relative to the module-level CACHE_DIR.

    Returns:
        True if the file exists and ``torch.load`` succeeds, False otherwise
        (the reason is printed).
    """
    try:
        # Build the path inside the try so that any error, including in path
        # construction, is reported as a failed check, not a crash of main().
        ckpt_path = f'{CACHE_DIR}/{rel_path}'
        if not os.path.isfile(ckpt_path):
            print(f" ✗ {label} model not found at {ckpt_path}")
            return False
        # Load on CPU: this only proves the file is a readable checkpoint.
        # NOTE(review): relies on torch.load's default ``weights_only``, which
        # became True in PyTorch 2.6 — confirm these checkpoints load under it.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        del ckpt
        size_mb = os.path.getsize(ckpt_path) / 1e6
        print(f" ✓ {label} model checkpoint exists and loadable ({size_mb:.1f} MB)")
        return True
    except Exception as e:
        print(f" ✗ {label} failed: {e}")
        return False


def test_amt():
    """Check the AMT checkpoint (motion_smoothness); return True on success."""
    print("\n[4/7] Testing AMT model (motion_smoothness)...")
    return _check_checkpoint('AMT', 'amt_model/amt-s.pth')


def test_raft():
    """Check the RAFT checkpoint (dynamic_degree); return True on success."""
    print("\n[5/7] Testing RAFT model (dynamic_degree)...")
    return _check_checkpoint('RAFT', 'raft_model/models/raft-things.pth')


def test_musiq():
    """Check the MUSIQ checkpoint (imaging_quality); return True on success."""
    print("\n[6/7] Testing MUSIQ model (imaging_quality)...")
    return _check_checkpoint('MUSIQ', 'pyiqa_model/musiq_spaq_ckpt-358bb6af.pth')
def test_pyiqa():
    """Instantiate pyiqa's MUSIQ architecture from the local checkpoint on CUDA.

    This exercises the same MUSIQ class and checkpoint file as VBench's
    imaging_quality dimension, rather than only deserializing the file.

    Returns:
        True if the model loads, False otherwise (the error is printed).
    """
    print("\n[7/7] Testing PyIQA library (imaging_quality)...")
    try:
        from pyiqa.archs.musiq_arch import MUSIQ
        model_path = f'{CACHE_DIR}/pyiqa_model/musiq_spaq_ckpt-358bb6af.pth'
        model = MUSIQ(pretrained_model_path=model_path)
        model = model.cuda()
        model.eval()
        print(" ✓ PyIQA MUSIQ model loaded successfully")
        del model
        torch.cuda.empty_cache()
        return True
    except Exception as e:
        # Best-effort check: report the failure and let main() tally it.
        print(f" ✗ PyIQA failed: {e}")
        return False
def list_downloaded_models(cache_dir=None):
    """Print every file larger than 1 MB under the model cache, plus the total size.

    Args:
        cache_dir: Directory to scan. Defaults to the module-level CACHE_DIR.

    A file reachable through several paths (symlinks or hard links, e.g. a
    symlinked cache layout) is listed and counted once, so the total reflects
    actual disk usage. Dangling symlinks and unreadable entries are skipped
    instead of aborting the report.
    """
    root = CACHE_DIR if cache_dir is None else cache_dir
    print("\n" + "=" * 60)
    print("Downloaded models summary:")
    print("=" * 60)
    total_size = 0
    seen = set()  # (st_dev, st_ino) of files already counted
    for dirpath, dirnames, filenames in os.walk(root):
        dirnames.sort()  # deterministic traversal / output order
        for fname in sorted(filenames):
            fpath = os.path.join(dirpath, fname)
            try:
                st = os.stat(fpath)  # follows symlinks to the real file
            except OSError:
                # Dangling symlink or file removed mid-walk: nothing to count.
                continue
            key = (st.st_dev, st.st_ino)
            if key in seen:
                continue
            seen.add(key)
            total_size += st.st_size
            if st.st_size > 1e6:  # Only show files > 1MB
                rel_path = os.path.relpath(fpath, root)
                print(f" {rel_path}: {st.st_size/1e6:.1f} MB")
    print(f"\nTotal size: {total_size/1e9:.2f} GB")
def main():
    """Run every model check, print a pass/fail summary, and return an exit code.

    Returns:
        0 if every check passed, 1 if at least one failed.
    """
    print("=" * 60)
    print("VBench I2V Model Verification")
    print("=" * 60)
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"Model cache: {CACHE_DIR}")

    # Checks run in this order; each prints its own "[i/7]" progress line.
    checks = (
        ('dino', test_dino),
        ('clip', test_clip),
        ('cotracker', test_cotracker),
        ('amt', test_amt),
        ('raft', test_raft),
        ('musiq', test_musiq),
        ('pyiqa', test_pyiqa),
    )
    results = {name: check() for name, check in checks}

    list_downloaded_models()

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for name, passed in results.items():
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f" {name}: {status}")

    all_passed = all(results.values())
    if all_passed:
        print("\n✓ All models verified successfully!")
        print(" You can now run the full evaluation with: python run_i2v_eval.py")
    else:
        print("\n✗ Some models failed verification. Please check the errors above.")
    return 0 if all_passed else 1
if __name__ == "__main__":
sys.exit(main())