File size: 10,428 Bytes
8bcf79a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
#!/usr/bin/env python3
"""
Comprehensive tests for the Dressify outfit recommendation system.
Run with: python -m pytest tests/test_system.py -v
"""

import os
import sys
import tempfile
import shutil
import json
from pathlib import Path
from unittest.mock import Mock, patch

import pytest
import torch
import numpy as np
from PIL import Image

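# Add the parent of this file's directory to the import path so the
# project packages imported below (models, utils, ...) can be resolved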
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from models.resnet_embedder import ResNetItemEmbedder
from models.vit_outfit import OutfitCompatibilityModel
from utils.transforms import build_inference_transform, build_train_transforms
from utils.triplet_mining import create_triplet_miner


class TestModels:
    """Test model architectures and forward passes.

    Each test seeds torch's global RNG so the random inputs (and any random
    weight initialisation) are reproducible, and runs the forward pass under
    ``torch.no_grad()`` because only output shape and finiteness are checked.
    """

    # Embedding width shared by the item embedder and the outfit model.
    EMBEDDING_DIM = 512

    def test_resnet_embedder(self):
        """Embedder output has shape (batch, EMBEDDING_DIM) and is finite."""
        torch.manual_seed(0)
        model = ResNetItemEmbedder(embedding_dim=self.EMBEDDING_DIM)

        batch_size = 4
        x = torch.randn(batch_size, 3, 224, 224)
        with torch.no_grad():
            output = model(x)

        assert output.shape == (batch_size, self.EMBEDDING_DIM)
        # isfinite rejects both NaN and +/-inf in one pass.
        assert torch.isfinite(output).all()

    def test_vit_outfit_model(self):
        """Outfit model yields one finite score per outfit in the batch."""
        torch.manual_seed(0)
        model = OutfitCompatibilityModel(embedding_dim=self.EMBEDDING_DIM)

        batch_size = 2
        max_items = 6
        x = torch.randn(batch_size, max_items, self.EMBEDDING_DIM)
        with torch.no_grad():
            output = model(x)

        assert output.shape == (batch_size,)
        assert torch.isfinite(output).all()

    def test_model_consistency(self):
        """Embedder output can be reshaped and fed to the outfit model."""
        torch.manual_seed(0)
        embedder = ResNetItemEmbedder(embedding_dim=self.EMBEDDING_DIM)
        vit_model = OutfitCompatibilityModel(embedding_dim=self.EMBEDDING_DIM)

        # Dummy outfits: batch_size outfits of num_items images each,
        # flattened into one image batch for the embedder.
        batch_size = 2
        num_items = 4
        images = torch.randn(batch_size * num_items, 3, 224, 224)

        with torch.no_grad():
            embeddings = embedder(images)
            # Regroup the flat per-image embeddings into per-outfit sequences.
            embeddings = embeddings.view(batch_size, num_items, -1)
            scores = vit_model(embeddings)

        assert scores.shape == (batch_size,)
        assert torch.isfinite(scores).all()


class TestTransforms:
    """Exercise the inference and training image transform pipelines."""

    @staticmethod
    def _check_output(tensor):
        """Assert a transformed image is a NaN-free float32 3x224x224 tensor."""
        assert tensor.shape == (3, 224, 224)
        assert tensor.dtype == torch.float32
        assert not torch.isnan(tensor).any()

    def test_inference_transform(self):
        """Run the inference pipeline on a solid red 100x100 image."""
        pipeline = build_inference_transform(image_size=224)
        source = Image.new('RGB', (100, 100), color='red')
        self._check_output(pipeline(source))

    def test_train_transform(self):
        """Run the training pipeline on a solid blue 100x100 image."""
        pipeline = build_train_transforms(image_size=224)
        source = Image.new('RGB', (100, 100), color='blue')
        self._check_output(pipeline(source))


class TestTripletMining:
    """Test triplet mining utilities."""

    @staticmethod
    def _make_batch(batch_size, embed_dim, num_classes, seed=0):
        """Return seeded random ``(embeddings, labels)`` for a mining call.

        Seeding keeps the mined triplets (and therefore any failure)
        reproducible from run to run.
        """
        torch.manual_seed(seed)
        embeddings = torch.randn(batch_size, embed_dim)
        labels = torch.randint(0, num_classes, (batch_size,))
        return embeddings, labels

    def test_semi_hard_miner(self):
        """Semi-hard mining returns aligned, in-range triplet indices."""
        miner = create_triplet_miner(strategy="semi_hard", margin=0.2)

        batch_size = 32
        embeddings, labels = self._make_batch(
            batch_size, embed_dim=128, num_classes=8
        )

        anchors, positives, negatives = miner.mine_batch_triplets(embeddings, labels)

        # Checked unconditionally: the three outputs must line up even when
        # the miner finds no triplets at all.
        assert len(anchors) == len(positives) == len(negatives)

        # min()/max() are only meaningful on a non-empty result.
        if len(anchors) > 0:
            for indices in (anchors, positives, negatives):
                assert indices.min() >= 0
                assert indices.max() < batch_size

    def test_random_miner(self):
        """Random mining returns three equally long index collections."""
        miner = create_triplet_miner(strategy="random", margin=0.2)

        embeddings, labels = self._make_batch(
            batch_size=16, embed_dim=64, num_classes=4
        )

        anchors, positives, negatives = miner.mine_batch_triplets(embeddings, labels)

        # Holds trivially for an empty result, so no guard is needed.
        assert len(anchors) == len(positives) == len(negatives)


class TestDataPreparation:
    """Checks for the dataset preparation helpers."""

    def test_prepare_polyvore_script(self):
        """Walk two raw outfits through normalisation, item collection and triplet building."""
        from scripts.prepare_polyvore import (
            build_triplets,
            collect_all_items,
            _normalize_outfits,
        )

        # One outfit lists bare item ids, the other wraps each id in a dict.
        raw_outfits = [
            {"items": ["item1", "item2", "item3"]},
            {"items": [{"item_id": "item4"}, {"item_id": "item5"}]}
        ]

        # Step 1: normalisation keeps both outfits and their "items" key.
        normalized = _normalize_outfits(raw_outfits)
        assert len(normalized) == 2
        for position in (0, 1):
            assert "items" in normalized[position]

        # Step 2: five distinct items are collected across the two outfits.
        item_pool = collect_all_items(normalized)
        assert len(item_pool) == 5
        assert "item1" in item_pool

        # Step 3: triplet count respects the cap; when any triplet is
        # produced, the first one carries all three roles.
        mined = build_triplets(normalized, item_pool, max_triplets=10)
        assert len(mined) <= 10
        if mined:
            first = mined[0]
            for role in ("anchor", "positive", "negative"):
                assert role in first


class TestInference:
    """Checks around the inference service."""

    @patch('inference.InferenceService._load_resnet')
    @patch('inference.InferenceService._load_vit')
    def test_inference_service_creation(self, mock_load_vit, mock_load_resnet):
        """Construct the service with both model loaders stubbed out."""
        # Stacked patch decorators inject mocks bottom-up, hence the
        # (vit, resnet) parameter order.
        mock_load_resnet.return_value = Mock()
        mock_load_vit.return_value = Mock()

        from inference import InferenceService

        # Construction itself must succeed, and the chosen device must be
        # one of the three expected names.
        service = InferenceService()
        assert service.device in ["cuda", "mps", "cpu"]

    def test_image_embedding(self):
        """Check the shape of the (stubbed) embedding output for three images.

        NOTE(review): the call below goes straight to the mock, so no real
        ``embed_images`` code appears to be exercised here -- confirm intent.
        """
        pictures = [Image.new('RGB', (224, 224), color='red') for _ in range(3)]
        canned_vectors = [np.random.randn(512) for _ in range(3)]

        with patch(
            'inference.InferenceService.embed_images',
            return_value=canned_vectors,
        ) as mock_embed:
            vectors = mock_embed(pictures)

        assert len(vectors) == 3
        assert all(vec.shape == (512,) for vec in vectors)


class TestIntegration:
    """High-level checks spanning the whole system."""

    def test_end_to_end_pipeline(self):
        """Compose outfits for a four-item dummy wardrobe via a stubbed service.

        NOTE(review): ``compose_outfits`` is patched and the mock is called
        directly, so this verifies the canned result's shape only; a real
        end-to-end run would need trained models.
        """
        catalog = (
            ("item1", "upper"),
            ("item2", "bottom"),
            ("item3", "shoes"),
            ("item4", "accessory"),
        )
        # Each wardrobe entry carries a random 512-d embedding as a plain list.
        wardrobe = [
            {
                "id": item_id,
                "category": category,
                "embedding": np.random.randn(512).tolist(),
            }
            for item_id, category in catalog
        ]

        canned_outfits = [
            {"item_ids": ["item1", "item2", "item3"], "score": 0.85},
            {"item_ids": ["item1", "item2", "item4"], "score": 0.78},
        ]

        with patch(
            'inference.InferenceService.compose_outfits',
            return_value=canned_outfits,
        ) as mock_compose:
            outfits = mock_compose(wardrobe, context={"occasion": "casual"})

        assert len(outfits) == 2
        best = outfits[0]
        assert "item_ids" in best
        assert "score" in best


class TestConfiguration:
    """Test configuration files."""

    @staticmethod
    def _load_config(filename):
        """Parse ``configs/<filename>`` located two directories above this file.

        The calling test is skipped (instead of silently passing) when the
        file does not exist, so a missing config shows up in the test report.

        Returns:
            The parsed YAML document as a dict.
        """
        import yaml

        config_path = Path(__file__).parent.parent / "configs" / filename
        if not config_path.exists():
            pytest.skip(f"config file not found: {config_path}")

        with open(config_path, encoding="utf-8") as f:
            config = yaml.safe_load(f)

        # safe_load returns None for an empty file; fail with a clear message
        # rather than a TypeError on the membership checks that follow.
        assert isinstance(config, dict), f"{filename} did not parse to a mapping"
        return config

    def test_item_config(self):
        """Test item training configuration."""
        config = self._load_config("item.yaml")

        for section in ("model", "training", "data"):
            assert section in config
        assert config["model"]["embedding_dim"] == 512

    def test_outfit_config(self):
        """Test outfit training configuration."""
        config = self._load_config("outfit.yaml")

        for section in ("model", "training", "loss"):
            assert section in config
        assert config["model"]["embedding_dim"] == 512


class TestUtilities:
    """Checks for small helper modules."""

    def test_hf_utils(self):
        """Creating the Hugging Face manager without a username must raise ValueError."""
        from utils.hf_utils import HFModelManager

        missing_username = None
        with pytest.raises(ValueError):
            HFModelManager(username=missing_username)

    def test_export_utils(self):
        """The export helper hands back a path to an existing directory."""
        from utils.export import ensure_export_dir

        # A throwaway directory keeps the test from touching the project tree.
        with tempfile.TemporaryDirectory() as scratch:
            result_path = ensure_export_dir(scratch)
            assert os.path.exists(result_path)
            assert os.path.isdir(result_path)


if __name__ == "__main__":
    # Run this file's tests and propagate pytest's exit code, so that running
    # the script directly exits non-zero when any test fails.
    sys.exit(pytest.main([__file__, "-v"]))