Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| """ | |
| Simple test script to verify training components work. | |
| Run this to test if the system is ready for training. | |
| """ | |
| import os | |
| import sys | |
| import torch | |
| # Add parent directory to path to find models, data, utils | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from data.polyvore import PolyvoreTripletDataset | |
| from models.resnet_embedder import ResNetItemEmbedder | |
| from models.vit_outfit import OutfitCompatibilityModel | |
| from utils.transforms import build_train_transforms | |
def test_imports():
    """Report that the project modules are importable.

    All project modules are imported at the top of this file, so merely
    reaching this function means those imports already succeeded. It only
    prints a confirmation and always returns True.
    """
    messages = (
        "π Testing imports...",
        "β All modules imported successfully at top level",
    )
    for message in messages:
        print(message)
    return True
def test_models():
    """Smoke-test construction and one forward pass of each model.

    Builds the ResNet item embedder and the outfit compatibility model on
    CUDA when available (CPU otherwise). Puts both in eval mode, runs a
    single forward pass on random inputs, and prints parameter counts and
    input/output shapes.

    Returns:
        True if both models were built and ran; False if anything raised.
        On failure the full traceback is printed to stderr.
    """
    import traceback

    print("\nποΈ Testing model creation...")
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")

        # ResNet embedder: a batch of 2 RGB 224x224 images.
        # eval() disables dropout and makes BatchNorm use running stats,
        # which is the mode a no_grad forward pass is meant to exercise.
        resnet = ResNetItemEmbedder(embedding_dim=512).to(device).eval()
        print(f"β ResNet created with {sum(p.numel() for p in resnet.parameters()):,} parameters")
        dummy_input = torch.randn(2, 3, 224, 224, device=device)
        with torch.no_grad():
            output = resnet(dummy_input)
        print(f"β ResNet forward pass: input {dummy_input.shape} -> output {output.shape}")

        # Outfit model: 2 outfits x 4 item embeddings of size 512.
        # NOTE(review): assumes input is (batch, items, embedding_dim) - confirm.
        vit = OutfitCompatibilityModel(embedding_dim=512).to(device).eval()
        print(f"β ViT created with {sum(p.numel() for p in vit.parameters()):,} parameters")
        dummy_tokens = torch.randn(2, 4, 512, device=device)
        with torch.no_grad():
            output = vit(dummy_tokens)
        print(f"β ViT forward pass: input {dummy_tokens.shape} -> output {output.shape}")
        return True
    except Exception as e:
        print(f"β Model test failed: {e}")
        # Keep the traceback: the one-line message alone rarely
        # identifies which layer or shape caused the failure.
        traceback.print_exc()
        return False
def test_dataset():
    """Try to load the Polyvore triplet training split, if it is present.

    The dataset root comes from the ``POLYVORE_ROOT`` environment variable
    (default ``/home/user/app/data/Polyvore``). A missing
    ``<root>/splits/train.json`` is reported as "preparation needed" and is
    deliberately not counted as a failure.

    Returns:
        True if the split file is absent, or if the dataset loads and its
        first sample can be read. False if loading raised; in that case
        the traceback is printed to stderr.
    """
    import traceback

    def _describe(item):
        # Tensors and arrays report their shape. Anything else (ids,
        # metadata strings, ...) falls back to its type name instead of
        # raising AttributeError and failing an otherwise valid dataset.
        return getattr(item, "shape", type(item).__name__)

    print("\nπ Testing dataset loading...")
    data_root = os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore")
    splits_dir = os.path.join(data_root, "splits")
    train_file = os.path.join(splits_dir, "train.json")
    if not os.path.exists(train_file):
        print(f"β οΈ Training data not found at {train_file}")
        print("π‘ Dataset preparation may be needed")
        return True  # Not a failure, just not ready
    try:
        dataset = PolyvoreTripletDataset(data_root, split="train")
        print(f"β Dataset loaded successfully: {len(dataset)} samples")
        # Test getting one sample
        if len(dataset) > 0:
            sample = dataset[0]
            # Presumably a tuple of tensors; also tolerate a dict sample
            # (iterating a dict would yield keys) or a single object.
            if isinstance(sample, dict):
                parts = list(sample.values())
            elif isinstance(sample, (list, tuple)):
                parts = list(sample)
            else:
                parts = [sample]
            print(f"β Sample loaded: {len(parts)} tensors with shapes {[_describe(p) for p in parts]}")
        return True
    except Exception as e:
        print(f"β Dataset test failed: {e}")
        traceback.print_exc()
        return False
def test_training_components():
    """Check that the core training pieces work together for one step.

    Builds a ResNet item embedder, an AdamW optimizer (lr=1e-3) and a
    triplet margin loss (margin=0.2). It then runs a single optimisation
    step on a tiny random batch, so the backward pass and the parameter
    update are exercised, not just object construction.

    Returns:
        True on success; False if any step raised. On failure the full
        traceback is printed to stderr.
    """
    import traceback

    print("\nπ Testing training components...")
    try:
        from torch.nn import TripletMarginLoss
        from torch.optim import AdamW

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = ResNetItemEmbedder(embedding_dim=512).to(device)

        # Test optimizer creation
        optimizer = AdamW(model.parameters(), lr=1e-3)
        print("β Optimizer created successfully")

        # Test loss function
        criterion = TripletMarginLoss(margin=0.2)
        print("β Loss function created successfully")

        # One training step on random anchor/positive/negative images
        # (2 of each, same 3x224x224 input shape as test_models uses).
        # A single forward over all 6 images is split back into 3 roles.
        model.train()
        images = torch.randn(6, 3, 224, 224, device=device)
        anchor, positive, negative = model(images).chunk(3, dim=0)
        loss = criterion(anchor, positive, negative)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"β Training step completed (loss={loss.item():.4f})")
        return True
    except Exception as e:
        print(f"β Training components test failed: {e}")
        traceback.print_exc()
        return False
def main():
    """Run every readiness check, print a summary table and report success.

    A check that raises instead of returning is reported as crashed and
    counted as a failure.

    Returns:
        True when every check passed, False otherwise.
    """
    print("π§ͺ Starting Dressify Training System Tests\n")
    checks = (
        ("Imports", test_imports),
        ("Models", test_models),
        ("Dataset", test_dataset),
        ("Training Components", test_training_components),
    )

    outcomes = []
    for label, check in checks:
        try:
            ok = check()
        except Exception as exc:
            print(f"β {label} test crashed: {exc}")
            ok = False
        outcomes.append((label, ok))

    rule = "=" * 50
    print("\n" + rule)
    print("π TEST RESULTS SUMMARY")
    print(rule)
    for label, ok in outcomes:
        status = "β PASS" if ok else "β FAIL"
        print(f"{label:20} {status}")
    passed = sum(1 for _, ok in outcomes if ok)
    total = len(outcomes)
    print(rule)
    print(f"Overall: {passed}/{total} tests passed")

    if passed != total:
        print("β οΈ Some tests failed. Please check the errors above.")
        return False
    print("π All tests passed! System is ready for training.")
    return True
if __name__ == "__main__":
    # Process exit status: 0 when every check passed, 1 otherwise.
    all_ok = main()
    exit_code = 0 if all_ok else 1
    sys.exit(exit_code)