Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| """ | |
| Comprehensive tests for the Dressify outfit recommendation system. | |
| Run with: python -m pytest tests/test_system.py -v | |
| """ | |
| import os | |
| import sys | |
| import tempfile | |
| import shutil | |
| import json | |
| from pathlib import Path | |
| from unittest.mock import Mock, patch | |
| import pytest | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
# Make the project root (the parent of this tests/ directory) importable
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) | |
| from models.resnet_embedder import ResNetItemEmbedder | |
| from models.vit_outfit import OutfitCompatibilityModel | |
| from utils.transforms import build_inference_transform, build_train_transforms | |
| from utils.triplet_mining import create_triplet_miner | |
class TestModels:
    """Forward-pass sanity checks for the item embedder and outfit scorer.

    Every test seeds the global torch RNG, switches the models to ``eval()``
    and runs them under ``torch.no_grad()``. This keeps the checks
    deterministic: train-mode layers (e.g. dropout or batch-norm, if the
    models contain any) would otherwise change outputs between runs or
    update running statistics during what is meant to be a read-only check.
    """

    # Embedding width shared by the embedder output and the scorer input.
    EMBED_DIM = 512

    def test_resnet_embedder(self):
        """ResNetItemEmbedder maps (B, 3, 224, 224) images to finite (B, EMBED_DIM) vectors."""
        torch.manual_seed(0)
        model = ResNetItemEmbedder(embedding_dim=self.EMBED_DIM)
        model.eval()

        batch_size = 4
        x = torch.randn(batch_size, 3, 224, 224)
        with torch.no_grad():
            output = model(x)

        assert output.shape == (batch_size, self.EMBED_DIM)
        # isfinite rejects both NaN and +/-inf in a single check.
        assert torch.isfinite(output).all()

    def test_vit_outfit_model(self):
        """OutfitCompatibilityModel maps (B, items, EMBED_DIM) inputs to finite (B,) scores."""
        torch.manual_seed(0)
        model = OutfitCompatibilityModel(embedding_dim=self.EMBED_DIM)
        model.eval()

        batch_size, max_items = 2, 6
        x = torch.randn(batch_size, max_items, self.EMBED_DIM)
        with torch.no_grad():
            output = model(x)

        assert output.shape == (batch_size,)
        assert torch.isfinite(output).all()

    def test_model_consistency(self):
        """Embedder output, regrouped per outfit, is accepted by the outfit scorer."""
        torch.manual_seed(0)
        embedder = ResNetItemEmbedder(embedding_dim=self.EMBED_DIM)
        vit_model = OutfitCompatibilityModel(embedding_dim=self.EMBED_DIM)
        embedder.eval()
        vit_model.eval()

        batch_size, num_items = 2, 4
        images = torch.randn(batch_size * num_items, 3, 224, 224)
        with torch.no_grad():
            # (B * items, D) -> (B, items, D). Unlike view, reshape also
            # copes with a non-contiguous embedder output.
            embeddings = embedder(images).reshape(batch_size, num_items, -1)
            scores = vit_model(embeddings)

        assert scores.shape == (batch_size,)
        assert torch.isfinite(scores).all()
class TestTransforms:
    """Check that both image pipelines turn a PIL image into a finite
    float32 tensor of shape (3, 224, 224)."""

    @staticmethod
    def _assert_valid_tensor(tensor):
        """Shared post-conditions for a transformed 224x224 RGB image."""
        assert tensor.shape == (3, 224, 224)
        assert tensor.dtype == torch.float32
        assert not torch.isnan(tensor).any()

    def test_inference_transform(self):
        """The inference pipeline resizes a small solid image to 224x224."""
        pipeline = build_inference_transform(image_size=224)
        sample = Image.new('RGB', (100, 100), color='red')
        self._assert_valid_tensor(pipeline(sample))

    def test_train_transform(self):
        """The training pipeline (augmentations included) yields a 224x224 tensor."""
        pipeline = build_train_transforms(image_size=224)
        sample = Image.new('RGB', (100, 100), color='blue')
        self._assert_valid_tensor(pipeline(sample))
class TestTripletMining:
    """Tests for the batch triplet miners built by ``create_triplet_miner``.

    Labels are assigned round-robin (``arange(batch) % num_classes``) rather
    than drawn at random. Every class therefore has several members in the
    batch, each anchor always has at least one positive and one negative
    available, and the checks cannot pass vacuously on an unlucky draw.
    The global RNG is seeded so the embeddings are reproducible.
    """

    @staticmethod
    def _balanced_labels(batch_size, num_classes):
        """Return ``batch_size`` int64 labels cycling through ``0..num_classes-1``."""
        return torch.arange(batch_size) % num_classes

    @staticmethod
    def _assert_valid_triplets(anchors, positives, negatives, labels):
        """Check that mined triplets are well-formed index triples.

        Args:
            anchors, positives, negatives: index sequences returned by
                ``mine_batch_triplets`` (tensor or array-like of ints).
            labels: 1-D tensor of class labels for the mined batch.

        The three sequences must have equal length. Every index must lie in
        ``[0, len(labels))``. Each positive must share its anchor's label and
        no negative may share it. An empty result passes trivially.
        """
        assert len(anchors) == len(positives) == len(negatives)
        if len(anchors) == 0:
            return

        batch_size = len(labels)
        a, p, n = (
            torch.as_tensor(idx, dtype=torch.long)
            for idx in (anchors, positives, negatives)
        )
        for idx in (a, p, n):
            assert idx.min().item() >= 0
            assert idx.max().item() < batch_size

        assert torch.equal(labels[a], labels[p]), "positive label differs from anchor"
        assert not (labels[a] == labels[n]).any(), "negative shares the anchor label"

    def test_semi_hard_miner(self):
        """Semi-hard mining returns in-range, label-consistent triplets."""
        torch.manual_seed(0)
        miner = create_triplet_miner(strategy="semi_hard", margin=0.2)

        batch_size, embed_dim, num_classes = 32, 128, 8
        embeddings = torch.randn(batch_size, embed_dim)
        labels = self._balanced_labels(batch_size, num_classes)

        anchors, positives, negatives = miner.mine_batch_triplets(embeddings, labels)
        # An empty result is tolerated: semi-hard mining can presumably find
        # no negative inside the margin band for some random batches.
        # TODO(review): confirm against the miner implementation.
        self._assert_valid_triplets(anchors, positives, negatives, labels)

    def test_random_miner(self):
        """Random mining yields at least one valid triplet when every class has members."""
        torch.manual_seed(0)
        miner = create_triplet_miner(strategy="random", margin=0.2)

        batch_size, embed_dim, num_classes = 16, 64, 4
        embeddings = torch.randn(batch_size, embed_dim)
        labels = self._balanced_labels(batch_size, num_classes)

        anchors, positives, negatives = miner.mine_batch_triplets(embeddings, labels)
        # Balanced labels guarantee that valid triplets exist, so an empty
        # result would mean the miner failed to find any.
        assert len(anchors) > 0
        self._assert_valid_triplets(anchors, positives, negatives, labels)
class TestDataPreparation:
    """Exercise the helpers in ``scripts/prepare_polyvore.py`` on a tiny in-memory dataset."""

    def test_prepare_polyvore_script(self):
        """Normalize two outfits, collect their items, then build a capped triplet list."""
        from scripts.prepare_polyvore import (
            _normalize_outfits,
            build_triplets,
            collect_all_items,
        )

        # One outfit lists bare item ids, the other lists {"item_id": ...}
        # dicts. Both shapes go through normalization.
        raw_outfits = [
            {"items": ["item1", "item2", "item3"]},
            {"items": [{"item_id": "item4"}, {"item_id": "item5"}]},
        ]
        outfits = _normalize_outfits(raw_outfits)
        assert len(outfits) == 2
        first, second = outfits[0], outfits[1]
        assert "items" in first
        assert "items" in second

        # Five distinct ids across the two outfits.
        items = collect_all_items(outfits)
        assert len(items) == 5
        assert "item1" in items

        cap = 10
        triplets = build_triplets(outfits, items, max_triplets=cap)
        assert len(triplets) <= cap
        if triplets:
            head = triplets[0]
            for key in ("anchor", "positive", "negative"):
                assert key in head
class TestInference:
    """Tests around ``inference.InferenceService``."""

    def test_inference_service_creation(self, mock_load_vit, mock_load_resnet):
        """Constructing InferenceService should succeed and select a known device."""
        # NOTE(review): this method has no @patch decorators, and this file
        # defines no fixtures named `mock_load_vit` / `mock_load_resnet`.
        # Unless they are provided elsewhere (e.g. a conftest.py), pytest
        # fails this test at setup with "fixture ... not found". The mocks
        # are only given return values here; whether they replace the real
        # model loaders depends entirely on how they are injected.
        # TODO: add @patch("inference.<loader>") decorators for the actual
        # loader functions (the bottom decorator supplies the first argument).
        mock_load_resnet.return_value = Mock()
        mock_load_vit.return_value = Mock()

        from inference import InferenceService

        service = InferenceService()  # construction itself must not raise
        assert service.device in ("cuda", "mps", "cpu")

    def test_image_embedding(self):
        """Embedding three images yields three 512-d vectors (mocked)."""
        # NOTE(review): `embed_images` is patched and then the mock itself is
        # called, so only the canned return value below is checked. No
        # InferenceService code runs. A meaningful version would call the
        # method on a service instance with stubbed models.
        images = [Image.new('RGB', (224, 224), color='red') for _ in range(3)]
        with patch('inference.InferenceService.embed_images') as fake_embed:
            fake_embed.return_value = [np.random.randn(512) for _ in range(3)]
            result = fake_embed(images)
        assert len(result) == 3
        for vector in result:
            assert vector.shape == (512,)
class TestIntegration:
    """Integration tests for the complete system."""

    def test_end_to_end_pipeline(self):
        """Compose outfits from a four-item wardrobe and check the result shape."""
        # NOTE(review): `compose_outfits` is patched and then the mock itself
        # is called, so this checks only the canned outfits below. No
        # InferenceService code runs. A real integration test would call
        # through a service instance, using trained or stubbed models.
        catalog = [
            ("item1", "upper"),
            ("item2", "bottom"),
            ("item3", "shoes"),
            ("item4", "accessory"),
        ]
        # One random 512-d embedding per item, stored as a plain list.
        wardrobe = [
            {"id": item_id, "category": category, "embedding": np.random.randn(512).tolist()}
            for item_id, category in catalog
        ]

        canned = [
            {"item_ids": ["item1", "item2", "item3"], "score": 0.85},
            {"item_ids": ["item1", "item2", "item4"], "score": 0.78},
        ]
        with patch('inference.InferenceService.compose_outfits') as fake_compose:
            fake_compose.return_value = canned
            outfits = fake_compose(wardrobe, context={"occasion": "casual"})

        assert len(outfits) == 2
        top = outfits[0]
        assert "item_ids" in top
        assert "score" in top
class TestConfiguration:
    """Sanity-check the YAML training configs under ``configs/``.

    A missing config file is reported as a skip rather than a silent pass,
    so test output distinguishes "validated" from "not present".
    """

    @staticmethod
    def _config_path(filename):
        """Return ``<project root>/configs/<filename>``, where the root is this file's grandparent."""
        return Path(__file__).parent.parent / "configs" / filename

    @classmethod
    def _load_config(cls, filename):
        """Parse ``configs/<filename>`` with ``yaml.safe_load``.

        Skips the calling test if the file does not exist. Fails if the file
        does not parse to a mapping; for example, an empty file parses to
        ``None``, which would otherwise surface as a TypeError on ``in``.
        """
        import yaml

        config_path = cls._config_path(filename)
        if not config_path.exists():
            pytest.skip(f"{config_path} not found")
        with open(config_path, encoding="utf-8") as f:
            config = yaml.safe_load(f)
        assert isinstance(config, dict), f"{config_path} did not parse to a mapping"
        return config

    def test_item_config(self):
        """The item config has model/training/data sections and a 512-d embedding."""
        config = self._load_config("item.yaml")
        for section in ("model", "training", "data"):
            assert section in config
        assert config["model"]["embedding_dim"] == 512

    def test_outfit_config(self):
        """The outfit config has model/training/loss sections and a 512-d embedding."""
        config = self._load_config("outfit.yaml")
        for section in ("model", "training", "loss"):
            assert section in config
        assert config["model"]["embedding_dim"] == 512
class TestUtilities:
    """Test utility functions."""

    def test_hf_utils(self):
        """Constructing HFModelManager with ``username=None`` raises ValueError."""
        from utils.hf_utils import HFModelManager

        # No HF token or username is supplied, so construction must be rejected.
        with pytest.raises(ValueError):
            HFModelManager(username=None)

    def test_export_utils(self):
        """ensure_export_dir creates a missing directory and tolerates repeat calls."""
        from utils.export import ensure_export_dir

        with tempfile.TemporaryDirectory() as temp_dir:
            # Target a path that does not exist yet. Passing the temp dir
            # itself could not detect a function that never creates anything.
            target = os.path.join(temp_dir, "exports")
            export_dir = ensure_export_dir(target)
            assert os.path.isdir(export_dir)
            # A second call, with the directory now present, must not raise.
            assert os.path.isdir(ensure_export_dir(target))
if __name__ == "__main__":
    # pytest.main() returns an exit code instead of exiting. Pass it to
    # sys.exit so `python tests/test_system.py` exits non-zero when any test
    # fails, as `python -m pytest` does.
    sys.exit(pytest.main([__file__, "-v"]))