File size: 5,193 Bytes
24ea486
 
 
 
 
 
 
 
 
 
72af8c3
 
 
 
 
 
 
 
24ea486
 
 
72af8c3
24ea486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#!/usr/bin/env python3
"""
Simple test script to verify training components work.
Run this to test if the system is ready for training.
"""

import os
import sys
import torch

# Add parent directory to path to find models, data, utils
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.polyvore import PolyvoreTripletDataset
from models.resnet_embedder import ResNetItemEmbedder
from models.vit_outfit import OutfitCompatibilityModel
from utils.transforms import build_train_transforms

def test_imports():
    """Report that the project modules imported successfully.

    The actual imports (``data.polyvore``, ``models.resnet_embedder``,
    ``models.vit_outfit``, ``utils.transforms``) happen at the top of this
    script. When run as a script, a failing import aborts before this function
    is reached, so getting here means those imports succeeded.

    Returns:
        Always True.
    """
    print("🔍 Testing imports...")
    print("✅ All modules imported successfully at top level")
    return True

def test_models():
    """Build both models and run a no-grad forward pass on dummy input.

    Creates a ``ResNetItemEmbedder`` and feeds it a (2, 3, 224, 224) random
    tensor, then creates an ``OutfitCompatibilityModel`` and feeds it a
    (2, 4, 512) random tensor. Parameter counts and input/output shapes are
    printed. Runs on CUDA when available, otherwise on CPU.

    Returns:
        True if both models were created and ran, False if anything raised
        (the exception message and full traceback are printed).
    """
    import traceback

    print("\n🏗️  Testing model creation...")

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")

        resnet = ResNetItemEmbedder(embedding_dim=512).to(device)
        n_resnet = sum(p.numel() for p in resnet.parameters())
        print(f"✅ ResNet created with {n_resnet:,} parameters")

        # Allocate the dummy batch directly on the target device (no CPU->GPU copy).
        dummy_input = torch.randn(2, 3, 224, 224, device=device)
        with torch.no_grad():
            output = resnet(dummy_input)
        print(f"✅ ResNet forward pass: input {dummy_input.shape} -> output {output.shape}")

        vit = OutfitCompatibilityModel(embedding_dim=512).to(device)
        n_vit = sum(p.numel() for p in vit.parameters())
        print(f"✅ ViT created with {n_vit:,} parameters")

        # NOTE(review): presumably (batch, items_per_outfit, embedding_dim) — confirm
        # against OutfitCompatibilityModel.forward.
        dummy_tokens = torch.randn(2, 4, 512, device=device)
        with torch.no_grad():
            output = vit(dummy_tokens)
        print(f"✅ ViT forward pass: input {dummy_tokens.shape} -> output {output.shape}")

        return True

    except Exception as e:
        print(f"❌ Model test failed: {e}")
        # The one-line message rarely says which layer/shape failed; keep the traceback.
        traceback.print_exc()
        return False

def test_dataset():
    """Try to load the Polyvore training split and fetch one sample.

    The dataset root comes from the ``POLYVORE_ROOT`` environment variable
    (default ``/home/user/app/data/Polyvore``). If
    ``<root>/splits/train.json`` does not exist, the check is skipped and
    counts as a pass: missing data means "not prepared yet", not "broken".

    Returns:
        True if the data is absent or loaded cleanly, False if loading or
        indexing raised (the exception message and traceback are printed).
    """
    import traceback

    print("\n📊 Testing dataset loading...")

    data_root = os.getenv("POLYVORE_ROOT", "/home/user/app/data/Polyvore")
    train_file = os.path.join(data_root, "splits", "train.json")

    if not os.path.exists(train_file):
        print(f"⚠️  Training data not found at {train_file}")
        print("💡 Dataset preparation may be needed")
        return True  # Not a failure, just not ready

    try:
        dataset = PolyvoreTripletDataset(data_root, split="train")
        n_samples = len(dataset)
        print(f"✅ Dataset loaded successfully: {n_samples} samples")

        if n_samples > 0:
            sample = dataset[0]
            # The sample layout is not fixed here (presumably a tuple of tensors —
            # TODO confirm). Tolerate dicts and non-tensor items so a valid dataset
            # is not reported as broken just because an item has no `.shape`.
            items = list(sample.values()) if isinstance(sample, dict) else list(sample)
            shapes = [getattr(item, "shape", type(item).__name__) for item in items]
            print(f"✅ Sample loaded: {len(items)} items with shapes {shapes}")

        return True

    except Exception as e:
        print(f"❌ Dataset test failed: {e}")
        traceback.print_exc()
        return False

def test_training_components():
    """Check that the optimizer and loss used for training can be constructed.

    Imports the torch training utilities and builds a ``ResNetItemEmbedder``
    on the available device (CUDA if present, else CPU). It then wraps the
    model's parameters in ``AdamW(lr=1e-3)`` and creates a
    ``TripletMarginLoss(margin=0.2)``. Nothing is trained; this only verifies
    construction.

    Returns:
        True if everything was created, False if anything raised (the exception
        message and full traceback are printed).
    """
    import traceback

    print("\n🚀 Testing training components...")

    try:
        # DataLoader is imported only to confirm it is available; it is not used here.
        from torch.utils.data import DataLoader  # noqa: F401
        from torch.optim import AdamW
        from torch.nn import TripletMarginLoss

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = ResNetItemEmbedder(embedding_dim=512).to(device)

        AdamW(model.parameters(), lr=1e-3)
        print("✅ Optimizer created successfully")

        TripletMarginLoss(margin=0.2)
        print("✅ Loss function created successfully")

        return True

    except Exception as e:
        print(f"❌ Training components test failed: {e}")
        traceback.print_exc()
        return False

def main():
    """Run every readiness check and print a pass/fail summary.

    Each check is called in turn. A check that raises instead of returning a
    bool is reported as crashed (message plus traceback) and counted as a
    failure.

    Returns:
        True if every check passed, False otherwise.
    """
    import traceback

    print("🧪 Starting Dressify Training System Tests\n")

    tests = [
        ("Imports", test_imports),
        ("Models", test_models),
        ("Dataset", test_dataset),
        ("Training Components", test_training_components),
    ]

    results = []
    for test_name, test_func in tests:
        try:
            result = test_func()
        except Exception as e:
            print(f"❌ {test_name} test crashed: {e}")
            traceback.print_exc()
            result = False
        results.append((test_name, result))

    separator = "=" * 50
    print("\n" + separator)
    print("📊 TEST RESULTS SUMMARY")
    print(separator)

    for test_name, result in results:
        status = "✅ PASS" if result else "❌ FAIL"
        print(f"{test_name:20} {status}")

    passed = sum(1 for _, result in results if result)
    total = len(results)

    print(separator)
    print(f"Overall: {passed}/{total} tests passed")

    if passed == total:
        print("🎉 All tests passed! System is ready for training.")
        return True
    print("⚠️  Some tests failed. Please check the errors above.")
    return False

if __name__ == "__main__":
    # Map the overall result to a shell exit status: 0 = all checks passed, 1 = otherwise.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)