#!/usr/bin/env python3
"""
Simple test script to verify training components work.

Run this to test if the system is ready for training.
The process exits with status 0 when every check passes and 1 otherwise.
"""

import os
import sys

import torch

# Add parent directory to path to find models, data, utils
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Project imports are guarded so that a missing or broken module is reported by
# test_imports() and the summary table, instead of aborting the whole script
# with a traceback before any test has had a chance to run.
try:
    from data.polyvore import PolyvoreTripletDataset
    from models.resnet_embedder import ResNetItemEmbedder
    from models.vit_outfit import OutfitCompatibilityModel
    from utils.transforms import build_train_transforms  # noqa: F401 - import check only
except Exception as _import_exc:  # noqa: BLE001 - any import-time failure must be reported
    # The `as` name is unbound when the except block ends, so keep a reference.
    _IMPORT_ERROR = _import_exc
else:
    _IMPORT_ERROR = None


def _device():
    """Return "cuda" when a CUDA device is available, otherwise "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def _imports_ok(test_label):
    """Return True if the project imports succeeded; otherwise print why *test_label* fails."""
    if _IMPORT_ERROR is None:
        return True
    print(f"❌ {test_label} test failed: project imports unavailable ({_IMPORT_ERROR})")
    return False


def test_imports():
    """Report whether the project modules imported at the top of the file loaded.

    Returns:
        True if every guarded project import succeeded, False otherwise.
    """
    print("🔍 Testing imports...")
    if _IMPORT_ERROR is not None:
        print(f"❌ Import failed: {_IMPORT_ERROR}")
        return False
    print("✅ All modules imported successfully at top level")
    return True


def test_models():
    """Test if models can be created and run a forward pass on random input.

    Returns:
        True on success, False if the imports are missing or anything raises.
    """
    print("\n🏗️ Testing model creation...")
    if not _imports_ok("Model"):
        return False
    try:
        # Test ResNet embedder
        device = _device()
        print(f"Using device: {device}")
        resnet = ResNetItemEmbedder(embedding_dim=512).to(device)
        print(f"✅ ResNet created with {sum(p.numel() for p in resnet.parameters()):,} parameters")

        # Test forward pass: a batch of 2 RGB 224x224 images
        dummy_input = torch.randn(2, 3, 224, 224).to(device)
        with torch.no_grad():
            output = resnet(dummy_input)
        print(f"✅ ResNet forward pass: input {dummy_input.shape} -> output {output.shape}")

        # Test ViT outfit model
        vit = OutfitCompatibilityModel(embedding_dim=512).to(device)
        print(f"✅ ViT created with {sum(p.numel() for p in vit.parameters()):,} parameters")

        # Test forward pass: 2 outfits of 4 item embeddings each
        # (assumes the model takes (batch, items, embedding_dim) - confirm against the model)
        dummy_tokens = torch.randn(2, 4, 512).to(device)
        with torch.no_grad():
            output = vit(dummy_tokens)
        print(f"✅ ViT forward pass: input {dummy_tokens.shape} -> output {output.shape}")
        return True
    except Exception as e:
        print(f"❌ Model test failed: {e}")
        return False


def test_dataset():
    """Test if the Polyvore training split can be loaded (if available).

    The data root comes from the POLYVORE_ROOT environment variable. A missing
    training split is reported as a warning but still returns True, so a
    machine without data is not marked as broken.

    Returns:
        True if the data is absent or loads cleanly, False on any load error.
    """
    print("\n📊 Testing dataset loading...")
    data_root = os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore")
    splits_dir = os.path.join(data_root, "splits")
    train_file = os.path.join(splits_dir, "train.json")
    if not os.path.exists(train_file):
        print(f"⚠️ Training data not found at {train_file}")
        print("💡 Dataset preparation may be needed")
        return True  # Not a failure, just not ready
    if not _imports_ok("Dataset"):
        return False
    try:
        dataset = PolyvoreTripletDataset(data_root, split="train")
        print(f"✅ Dataset loaded successfully: {len(dataset)} samples")

        # Test getting one sample
        if len(dataset) > 0:
            sample = dataset[0]
            # Items without a .shape (e.g. ids or paths) are shown by type name
            # instead of raising AttributeError and failing a successful load.
            shapes = [getattr(item, "shape", type(item).__name__) for item in sample]
            print(f"✅ Sample loaded: {len(sample)} tensors with shapes {shapes}")
        else:
            print("⚠️ Dataset is empty")
        return True
    except Exception as e:
        print(f"❌ Dataset test failed: {e}")
        return False


def test_training_components():
    """Test if the optimizer and triplet loss used for training can be created.

    Returns:
        True on success, False if the imports are missing or anything raises.
    """
    print("\n🚀 Testing training components...")
    if not _imports_ok("Training components"):
        return False
    try:
        from torch.nn import TripletMarginLoss
        from torch.optim import AdamW

        # Test optimizer creation
        model = ResNetItemEmbedder(embedding_dim=512).to(_device())
        optimizer = AdamW(model.parameters(), lr=1e-3)  # noqa: F841 - creation is the check
        print("✅ Optimizer created successfully")

        # Test loss function
        criterion = TripletMarginLoss(margin=0.2)  # noqa: F841 - creation is the check
        print("✅ Loss function created successfully")
        return True
    except Exception as e:
        print(f"❌ Training components test failed: {e}")
        return False


def main():
    """Run all tests, print a summary table, and report overall success.

    Returns:
        True only if every test returned a truthy result.
    """
    print("🧪 Starting Dressify Training System Tests\n")

    tests = [
        ("Imports", test_imports),
        ("Models", test_models),
        ("Dataset", test_dataset),
        ("Training Components", test_training_components),
    ]

    results = []
    for test_name, test_func in tests:
        try:
            result = test_func()
        except Exception as e:
            # A crashing test is recorded as a failure; the rest still run.
            print(f"❌ {test_name} test crashed: {e}")
            result = False
        results.append((test_name, result))

    # Summary
    print("\n" + "=" * 50)
    print("📊 TEST RESULTS SUMMARY")
    print("=" * 50)

    for test_name, result in results:
        status = "✅ PASS" if result else "❌ FAIL"
        print(f"{test_name:20} {status}")

    passed = sum(1 for _, result in results if result)
    total = len(results)

    print("=" * 50)
    print(f"Overall: {passed}/{total} tests passed")

    if passed == total:
        print("🎉 All tests passed! System is ready for training.")
        return True
    print("⚠️ Some tests failed. Please check the errors above.")
    return False


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)