recomendation / training /test_training.py
Ali Mohsin
folder reorganise
72af8c3
#!/usr/bin/env python3
"""
Simple test script to verify training components work.
Run this to test if the system is ready for training.
"""
import os
import sys
import torch
# Add parent directory to path to find models, data, utils
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.polyvore import PolyvoreTripletDataset
from models.resnet_embedder import ResNetItemEmbedder
from models.vit_outfit import OutfitCompatibilityModel
from utils.transforms import build_train_transforms
def test_imports():
    """Report that the project modules imported cleanly.

    The actual import check happens when this file is loaded: the
    ``data``/``models``/``utils`` imports at the top of the module run before
    any test function, so merely reaching this function means they succeeded.
    This function therefore only prints a confirmation.

    Returns:
        Always True.
    """
    messages = (
        "πŸ” Testing imports...",
        "βœ… All modules imported successfully at top level",
    )
    for message in messages:
        print(message)
    return True
def test_models():
    """Smoke-test that both models build and run one forward pass.

    Constructs ``ResNetItemEmbedder`` and ``OutfitCompatibilityModel`` with
    ``embedding_dim=512`` on CUDA when available (otherwise CPU). It prints
    their parameter counts, then feeds each a random batch under
    ``torch.no_grad()`` and prints the input/output shapes.

    Returns:
        True if both models were created and ran. False if any step raised;
        in that case the exception message and the full traceback are printed
        so the failing line can be located.
    """
    import traceback

    print("\nπŸ—οΈ Testing model creation...")
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")

        # --- ResNet item embedder ---
        resnet = ResNetItemEmbedder(embedding_dim=512).to(device)
        # Inference-only check: eval mode keeps any BatchNorm/Dropout layers
        # deterministic and independent of the tiny dummy batch.
        resnet.eval()
        print(f"βœ… ResNet created with {sum(p.numel() for p in resnet.parameters()):,} parameters")
        # Random batch shaped (2, 3, 224, 224); assumes the embedder takes
        # NCHW image tensors -- confirm against models.resnet_embedder.
        dummy_input = torch.randn(2, 3, 224, 224).to(device)
        with torch.no_grad():
            output = resnet(dummy_input)
        print(f"βœ… ResNet forward pass: input {dummy_input.shape} -> output {output.shape}")

        # --- ViT outfit compatibility model ---
        vit = OutfitCompatibilityModel(embedding_dim=512).to(device)
        vit.eval()
        print(f"βœ… ViT created with {sum(p.numel() for p in vit.parameters()):,} parameters")
        # Random batch shaped (2, 4, 512); presumably (batch, items, embedding_dim)
        # -- TODO confirm against models.vit_outfit.
        dummy_tokens = torch.randn(2, 4, 512).to(device)
        with torch.no_grad():
            output = vit(dummy_tokens)
        print(f"βœ… ViT forward pass: input {dummy_tokens.shape} -> output {output.shape}")
        return True
    except Exception as e:
        print(f"❌ Model test failed: {e}")
        # str(e) alone is often uninformative; show where it actually failed.
        traceback.print_exc()
        return False
def _describe_sample(sample):
    """Return a printable summary of one dataset sample without raising.

    Tuples and lists are summarised element-wise, dicts by their values, and
    any other object as a single item. Items lacking a ``shape`` attribute
    (e.g. integer labels or strings) are reported by type name instead of
    raising AttributeError.
    """
    if isinstance(sample, dict):
        items = list(sample.values())
    elif isinstance(sample, (tuple, list)):
        items = list(sample)
    else:
        items = [sample]
    shapes = [getattr(item, "shape", type(item).__name__) for item in items]
    return f"{len(items)} items with shapes {shapes}"


def test_dataset():
    """Try to load the Polyvore training split, if it is present.

    The dataset root is read from the ``POLYVORE_ROOT`` environment variable
    (default ``/home/user/app/data/Polyvore``). When
    ``<root>/splits/train.json`` does not exist, the data is treated as "not
    prepared yet" and the check passes with a warning. Otherwise the function
    builds ``PolyvoreTripletDataset(root, split="train")``, prints its length,
    and summarises the first sample.

    Returns:
        True if the data is absent or loaded successfully. False if loading
        raised; in that case the message and the full traceback are printed.
    """
    import traceback

    print("\nπŸ“Š Testing dataset loading...")
    data_root = os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore")
    train_file = os.path.join(data_root, "splits", "train.json")
    if not os.path.exists(train_file):
        print(f"⚠️ Training data not found at {train_file}")
        print("πŸ’‘ Dataset preparation may be needed")
        return True  # Not a failure, just not ready
    try:
        dataset = PolyvoreTripletDataset(data_root, split="train")
        print(f"βœ… Dataset loaded successfully: {len(dataset)} samples")
        if len(dataset) > 0:
            # Summarise defensively: a sample containing non-tensor items
            # must not turn a successful load into a reported failure.
            sample = dataset[0]
            print(f"βœ… Sample loaded: {_describe_sample(sample)}")
        return True
    except Exception as e:
        print(f"❌ Dataset test failed: {e}")
        traceback.print_exc()
        return False
def test_training_components():
    """Smoke-test that the training-loop building blocks can be created.

    The function checks three things:

    - ``DataLoader``, ``AdamW`` and ``TripletMarginLoss`` can be imported.
    - An ``AdamW`` optimizer (lr=1e-3) can be attached to a freshly built
      ``ResNetItemEmbedder(embedding_dim=512)``.
    - A ``TripletMarginLoss(margin=0.2)`` can be constructed.

    Returns:
        True on success. False if any step raised; in that case the message
        and the full traceback are printed.
    """
    import traceback

    print("\nπŸš€ Testing training components...")
    try:
        # Imported inside the try so an ImportError is reported as a test
        # failure instead of crashing the script. DataLoader is only
        # import-checked here.
        from torch.utils.data import DataLoader  # noqa: F401
        from torch.optim import AdamW
        from torch.nn import TripletMarginLoss

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = ResNetItemEmbedder(embedding_dim=512).to(device)
        optimizer = AdamW(model.parameters(), lr=1e-3)  # noqa: F841 (creation check)
        print("βœ… Optimizer created successfully")

        criterion = TripletMarginLoss(margin=0.2)  # noqa: F841 (creation check)
        print("βœ… Loss function created successfully")
        return True
    except Exception as e:
        print(f"❌ Training components test failed: {e}")
        traceback.print_exc()
        return False
def main():
    """Run every smoke test, print a summary table, and report overall status.

    Each test is called in turn. A test that raises instead of returning is
    reported as crashed and recorded as a failure, so one broken check does
    not stop the others from running.

    Returns:
        True if every test returned a truthy result, otherwise False.
    """
    print("πŸ§ͺ Starting Dressify Training System Tests\n")
    checks = (
        ("Imports", test_imports),
        ("Models", test_models),
        ("Dataset", test_dataset),
        ("Training Components", test_training_components),
    )

    outcomes = []
    for label, check in checks:
        try:
            outcome = check()
        except Exception as exc:
            print(f"❌ {label} test crashed: {exc}")
            outcome = False
        outcomes.append((label, outcome))

    # Summary table
    rule = "=" * 50
    print("\n" + rule)
    print("πŸ“Š TEST RESULTS SUMMARY")
    print(rule)
    for label, outcome in outcomes:
        verdict = "βœ… PASS" if outcome else "❌ FAIL"
        print(f"{label:20} {verdict}")
    passed = sum(1 for _, outcome in outcomes if outcome)
    total = len(outcomes)
    print(rule)
    print(f"Overall: {passed}/{total} tests passed")

    if passed != total:
        print("⚠️ Some tests failed. Please check the errors above.")
        return False
    print("πŸŽ‰ All tests passed! System is ready for training.")
    return True
if __name__ == "__main__":
    # Exit code 0 when every test passed, 1 otherwise.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)