Upload README.md with huggingface_hub

This project implements an advanced fashion search system based on CLIP.
## 🤝 Contributing
Contributions are welcome! Feel free to open an issue or a pull request.

### Architecture

The main model's embedding structure:

- **Dimensions 0-15** (16 dims): Color embeddings aligned with specialized color model
- **Dimensions 16-79** (64 dims): Hierarchy embeddings aligned with specialized hierarchy model
- **Dimensions 80-511** (432 dims): Standard CLIP embeddings for general visual-semantic understanding

**Total: 512 dimensions** per embedding (text or image)

**Key Innovation**: The first 80 dimensions are explicitly trained to align with specialized models through direct MSE and cosine similarity losses, ensuring guaranteed attribute positioning (GAP) while maintaining full CLIP capabilities in the remaining dimensions.

### Loss Functions

**1. Enhanced Contrastive Loss** (`enhanced_contrastive_loss`):

Combines multiple objectives:

- **Original Triple Loss**: Text-image-attributes contrastive learning
- **Color Alignment**: Forces dims 0-15 to match color model embeddings
- **Hierarchy Alignment**: Forces dims 16-79 to match hierarchy model embeddings
- **Reference Loss**: Optional regularization to stay close to base CLIP

**2. Alignment Components**:

```python
# Color alignment (text & image)
color_text_mse = F.mse_loss(main_color_dims, color_model_emb)
color_text_cosine = 1 - F.cosine_similarity(main_color_dims, color_model_emb).mean()

# Hierarchy alignment (text & image)
hierarchy_text_mse = F.mse_loss(main_hierarchy_dims, hierarchy_model_emb)
hierarchy_text_cosine = 1 - F.cosine_similarity(main_hierarchy_dims, hierarchy_model_emb).mean()

# Combined alignment
alignment_loss = (color_alignment + hierarchy_alignment) / 2
```

**3. Final Loss**:

```python
total_loss = (1 - α) * contrastive_loss + α * alignment_loss + β * reference_loss
```

Where:

- α (`alignment_weight`) = 0.2 : Balances contrastive and alignment objectives
- β (`reference_weight`) = 0.1 : Keeps text space close to base CLIP

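For intuition, the alignment term can be reproduced outside the training code. The following is a minimal NumPy sketch (not the project's implementation) of the MSE-plus-cosine alignment applied to the color (dims 0-15) and hierarchy (dims 16-79) subspaces of a 512-dim embedding:

```python
import numpy as np

def alignment_loss(main_emb, color_emb, hierarchy_emb):
    """Sketch of the direct alignment loss on one 512-dim embedding.

    Dims 0-15 are compared to the color model's 16-dim embedding and
    dims 16-79 to the hierarchy model's 64-dim embedding (MSE + cosine).
    """
    def mse(a, b):
        return float(np.mean((a - b) ** 2))

    def cosine_loss(a, b):
        cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
        return 1.0 - float(cos)

    color_dims = main_emb[:16]
    hierarchy_dims = main_emb[16:80]
    color_alignment = mse(color_dims, color_emb) + cosine_loss(color_dims, color_emb)
    hierarchy_alignment = mse(hierarchy_dims, hierarchy_emb) + cosine_loss(hierarchy_dims, hierarchy_emb)
    return (color_alignment + hierarchy_alignment) / 2

rng = np.random.default_rng(0)
main = rng.normal(size=512)
loss = alignment_loss(main, rng.normal(size=16), rng.normal(size=64))
```

When the subspaces exactly match the specialized embeddings, the sketch's loss drops to zero, which is the property the training objective pushes toward.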
## 🚀 Installation

```bash
pip install -r requirements.txt
```

Project structure:

```
.
├── color_model.py                     # Color model architecture and training
├── hierarchy_model.py                 # Hierarchy model architecture and training
├── main_model.py                      # Main GAP-CLIP model with enhanced loss functions
├── train_main_model.py                # Training script with optimized hyperparameters
├── config.py                          # Configuration for paths and parameters
├── example_usage.py                   # Usage examples and HuggingFace loading
├── tokenizer_vocab.json               # Tokenizer vocabulary for color model
├── models/
│   ├── color_model.pt                 # Trained color model checkpoint
│   ├── hierarchy_model.pth            # Trained hierarchy model checkpoint
│   └── gap_clip.pth                   # Main GAP-CLIP model checkpoint
├── evaluation/                        # Comprehensive evaluation scripts
│   ├── main_model_evaluation.py       # Main evaluation with 3 datasets
│   ├── evaluate_color_embeddings.py   # Color embedding analysis
│   ├── hierarchy_evaluation.py        # Hierarchy classification tests
│   ├── fashion_search.py              # Interactive search demos
│   ├── 0_shot_classification.py       # Zero-shot classification
│   ├── heatmap_color_similarities.py  # Color similarity visualization
│   ├── tsne_images.py                 # t-SNE embedding visualization
│   └── basic_test_generalized.py      # Basic functionality tests
├── data/
│   ├── data_with_local_paths.csv      # Training dataset with annotations
│   ├── fashion-mnist_test.csv         # Fashion-MNIST evaluation data
│   ├── download_data.py               # Dataset download utilities
│   └── get_csv_from_chunks.py         # Dataset preprocessing
├── optuna/                            # Hyperparameter optimization
│   ├── optuna_optimisation.py         # Optuna optimization script
│   ├── optuna_study.pkl               # Saved optimization study
│   ├── optuna_results.txt             # Best hyperparameters
│   ├── optuna_optimization_history.png  # Optimization visualization
│   ├── optuna_param_importances.png   # Parameter importance plot
│   └── optuna_guide.md                # Optuna usage guide
├── upload_hf/                         # HuggingFace Hub upload utilities
│   ├── upload_to_huggingface.py       # Upload script
│   └── GUIDE_UPLOAD_HF.md             # Upload guide
├── requirements.txt                   # Python dependencies
└── README.md                          # This documentation
```


### Key Files Description

**Core Model Files**:

- `color_model.py`: ResNet18-based color embedding model (16 dims)
- `hierarchy_model.py`: ResNet18-based hierarchy classification model (64 dims)
- `main_model.py`: GAP-CLIP implementation with enhanced contrastive loss
- `train_main_model.py`: Training with Optuna-optimized hyperparameters

**Configuration**:

- `config.py`: Central configuration for all paths, dimensions, and device settings
- `tokenizer_vocab.json`: Vocabulary for the color model's text encoder

**Evaluation Suite**:

- `main_model_evaluation.py`: Comprehensive evaluation across Fashion-MNIST, KAGL, and local datasets
- Other evaluation scripts provide specialized analysis (color, hierarchy, search, etc.)

**Training Data**:

- `data_with_local_paths.csv`: Main training dataset with text, color, hierarchy, and image paths
- `fashion-mnist_test.csv`: Evaluation dataset for zero-shot generalization testing

## 🔧 Configuration

Main parameters are defined in `config.py`:

```python
# Embedding dimensions
color_emb_dim = 16       # Color embedding dimension (dims 0-15)
hierarchy_emb_dim = 64   # Hierarchy embedding dimension (dims 16-79)

# Device configuration
device = torch.device("mps")  # Device (cuda, mps, cpu)

# Column names for dataset
text_column = 'text'                          # Description column
color_column = 'color'                        # Color label column
hierarchy_column = 'hierarchy'                # Hierarchy category column
column_local_image_path = 'local_image_path'  # Image path column
```

### Model Paths

Default paths configured in `config.py`:

- `models/color_model.pt` : Trained color model checkpoint
- `models/hierarchy_model.pth` : Trained hierarchy model checkpoint
- `models/gap_clip.pth` : Main GAP-CLIP model checkpoint
- `tokenizer_vocab.json` : Tokenizer vocabulary for color model
- `data_with_local_paths.csv` : Training/validation dataset

### Dataset Format

The training dataset CSV should contain:

- `text`: Text description of the fashion item
- `color`: Color label (e.g., "red", "blue", "black")
- `hierarchy`: Category label (e.g., "dress", "shirt", "shoes")
- `local_image_path`: Path to the image file

Example:

```csv
text,color,hierarchy,local_image_path
"red summer dress with floral pattern",red,dress,data/images/001.jpg
"blue denim jeans casual style",blue,jeans,data/images/002.jpg
```

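A quick way to sanity-check a dataset file against this format is to load it with pandas and verify the required columns; a sketch, using an inline sample that mirrors the example above in place of the real CSV file:

```python
import io

import pandas as pd

# Inline sample in the documented format (stand-in for data_with_local_paths.csv)
csv_text = """text,color,hierarchy,local_image_path
"red summer dress with floral pattern",red,dress,data/images/001.jpg
"blue denim jeans casual style",blue,jeans,data/images/002.jpg
"""

df = pd.read_csv(io.StringIO(csv_text))

# Verify the required columns are present before training
required = {"text", "color", "hierarchy", "local_image_path"}
missing = required - set(df.columns)
assert not missing, f"dataset is missing columns: {missing}"
```

For the real dataset, replace the `io.StringIO` object with the CSV path.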
## 📦 Usage

### 4. Using the Example Script

`example_usage.py` provides ready-to-use examples for loading and using GAP-CLIP:

```bash
# Load from HuggingFace and search with text
python example_usage.py \
    --repo-id Leacb4/gap-clip \
    --text "red summer dress"

# Search with image
python example_usage.py \
    --repo-id Leacb4/gap-clip \
    --image path/to/image.jpg

# Both text and image
python example_usage.py \
    --repo-id Leacb4/gap-clip \
    --text "blue denim jeans" \
    --image path/to/image.jpg
```

This script demonstrates:

- Loading models from HuggingFace Hub
- Extracting text and image embeddings
- Accessing color and hierarchy subspaces
- Measuring alignment quality with specialized models

## 🎯 Model Training

### Train the Color Model

### Train the Main CLIP Model

The main model trains with both specialized models using an enhanced contrastive loss.

**Option 1: Train with optimized hyperparameters (recommended)**:

```bash
python train_main_model.py
```

This uses hyperparameters optimized with Optuna (Trial 29, validation loss ~0.1129).

**Option 2: Train with default parameters**:

```bash
python main_model.py
```

This runs the main training loop with manually configured parameters.

**Default Training Parameters** (in `main_model.py`):

- `num_epochs = 20` : Number of training epochs
- `learning_rate = 1.5e-5` : Learning rate with AdamW optimizer
- `temperature = 0.09` : Temperature for softer contrastive learning
- `alignment_weight = 0.2` : Weight for color/hierarchy alignment loss
- `weight_decay = 5e-4` : L2 regularization to prevent overfitting
- `batch_size = 32` : Batch size
- `subset_size = 20000` : Dataset size for better generalization
- `reference_weight = 0.1` : Weight for base CLIP regularization

**Enhanced Loss Function**:

The training uses `enhanced_contrastive_loss` which combines:

1. **Triple Contrastive Loss** (weighted):
   - Text-Image alignment (70%)
   - Text-Attributes alignment (15%)
   - Image-Attributes alignment (15%)

2. **Direct Alignment Loss** (combines color & hierarchy):
   - MSE loss between main model color dims (0-15) and color model embeddings
   - MSE loss between main model hierarchy dims (16-79) and hierarchy model embeddings
   - Cosine similarity losses for both color and hierarchy
   - Applied to both text and image embeddings

3. **Reference Model Loss** (optional):
   - Keeps text embeddings close to base CLIP
   - Improves cross-domain generalization

**Training Features**:

- Enhanced data augmentation (rotation, color jitter, blur, affine transforms)
- Gradient clipping (max_norm=1.0) to prevent exploding gradients
- ReduceLROnPlateau scheduler (patience=3, factor=0.5)
- Early stopping (patience=7)
- Automatic best model saving with checkpoints
- Detailed metrics logging (alignment losses, cosine similarities)
- Overfitting detection and warnings
- Training curves visualization with 3 plots (losses, overfitting gap, comparison)

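The scheduler, clipping, checkpointing, and early-stopping pieces listed above fit together as in the following sketch on a toy model. This is not the project's training loop (that lives in `main_model.py`); the model and loss here are stand-ins so the control flow runs standalone:

```python
import torch
import torch.nn as nn

# Toy model/optimizer standing in for GAP-CLIP and its real AdamW setup
model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-5, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)

best_val, patience, bad_epochs = float("inf"), 7, 0
for epoch in range(20):
    loss = model(torch.randn(32, 8)).pow(2).mean()  # stand-in training loss
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
    optimizer.step()

    val_loss = loss.item()  # stand-in for a real validation pass
    scheduler.step(val_loss)  # reduces LR by `factor` after `patience` bad epochs
    if val_loss < best_val:
        best_val, bad_epochs = val_loss, 0
        # torch.save(model.state_dict(), "best.pth")  # best-model checkpointing
    else:
        bad_epochs += 1
        if bad_epochs >= patience:  # early stopping
            break
```

The same skeleton applies regardless of model size; only the loss computation and validation pass change.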

### Hyperparameter Optimization

The project includes Optuna-based hyperparameter optimization in `optuna/`:

```bash
cd optuna
python optuna_optimisation.py
```

This optimizes:

- Learning rate
- Temperature for contrastive loss
- Alignment weight
- Weight decay

Results are saved in `optuna_study.pkl` and visualizations in `optuna_optimization_history.png` and `optuna_param_importances.png`.

The best hyperparameters from Optuna optimization are used in `train_main_model.py`.

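The optimization follows a propose-evaluate-keep-best loop over the four parameters above. A dependency-free random-search sketch of that loop (the project itself uses Optuna; `validation_loss` here is a stand-in for an actual training run, and the search ranges are illustrative assumptions):

```python
import random

def validation_loss(params):
    # Stand-in for training + validation; a real run would train the model
    # with these parameters and return the validation loss.
    return abs(params["temperature"] - 0.09) + abs(params["alignment_weight"] - 0.2)

# Hypothetical search ranges mirroring the parameters listed above
search_space = {
    "learning_rate": lambda: 10 ** random.uniform(-6, -4),
    "temperature": lambda: random.uniform(0.05, 0.2),
    "alignment_weight": lambda: random.uniform(0.0, 0.5),
    "weight_decay": lambda: 10 ** random.uniform(-5, -2),
}

random.seed(0)
best_params, best_loss = None, float("inf")
for trial in range(30):
    params = {name: sample() for name, sample in search_space.items()}
    loss = validation_loss(params)
    if loss < best_loss:
        best_params, best_loss = params, loss
```

Optuna replaces the uniform sampling with an adaptive sampler and persists the study, but the overall loop is the same.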
## 📊 Models

- **Hierarchy classes** : shirt, dress, pant, shoe, bag, etc.
- **Usage** : Classify and encode categorical hierarchy

### Main CLIP Model (GAP-CLIP)

- **Architecture** : CLIP ViT-B/32 (LAION)
- **Base Model** : `laion/CLIP-ViT-B-32-laion2B-s34B-b79K`
- **Training Approach** : Enhanced contrastive loss with direct attribute alignment
- **Embedding Dimensions** : 512 total
  - Color subspace: dims 0-15 (16 dims)
  - Hierarchy subspace: dims 16-79 (64 dims)
  - General CLIP: dims 80-511 (432 dims)
- **Training Dataset** : 20,000 fashion items with color and hierarchy annotations
- **Validation Split** : 80/20 train-validation split
- **Optimizer** : AdamW with weight decay (5e-4)
- **Best Checkpoint** : Automatically saved based on validation loss
- **Features** :
  - Multi-modal text-image search
  - Guaranteed attribute positioning (GAP) in specific dimensions
  - Direct alignment with specialized color and hierarchy models
  - Maintains general CLIP capabilities for cross-domain tasks
  - Reduced overfitting through augmentation and regularization

## 🔍 Advanced Usage Examples

### Search with Combined Embeddings

```python
import torch
import torch.nn.functional as F

# Text query
text_query = "red dress"
text_inputs = processor(text=[text_query], padding=True, return_tensors="pt")
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

# Main model embeddings
with torch.no_grad():
    outputs = main_model(**text_inputs)
    text_features = outputs.text_embeds  # Shape: [1, 512]

# Extract specialized embeddings from main model
main_color_emb = text_features[:, :16]        # Color dimensions (0-15)
main_hierarchy_emb = text_features[:, 16:80]  # Hierarchy dimensions (16-79)
main_clip_emb = text_features[:, 80:]         # General CLIP dimensions (80-511)

# Compare with specialized models
color_emb = color_model.get_text_embeddings([text_query])
hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])

# Measure alignment quality
color_similarity = F.cosine_similarity(color_emb, main_color_emb, dim=1)
hierarchy_similarity = F.cosine_similarity(hierarchy_emb, main_hierarchy_emb, dim=1)

print(f"Color alignment: {color_similarity.item():.4f}")
print(f"Hierarchy alignment: {hierarchy_similarity.item():.4f}")

# For search, you can use different strategies:
# 1. Use full embeddings for general search
# 2. Use color subspace for color-specific search
# 3. Use hierarchy subspace for category search
# 4. Weighted combination of subspaces
```

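The weighted-combination strategy can be sketched with NumPy: compute a cosine similarity per subspace, then blend the three scores. The weights below are illustrative assumptions, not tuned values from the project:

```python
import numpy as np

def cosine(a, b):
    """Cosine similarity of two 1-D vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def weighted_subspace_similarity(text_emb, image_emb, w_color=0.3, w_hier=0.3, w_clip=0.4):
    """Blend per-subspace cosine similarities of two 512-dim GAP-CLIP embeddings."""
    return (
        w_color * cosine(text_emb[:16], image_emb[:16])       # color subspace (0-15)
        + w_hier * cosine(text_emb[16:80], image_emb[16:80])  # hierarchy subspace (16-79)
        + w_clip * cosine(text_emb[80:], image_emb[80:])      # general CLIP subspace (80-511)
    )

rng = np.random.default_rng(0)
t, v = rng.normal(size=512), rng.normal(size=512)
score = weighted_subspace_similarity(t, v)
```

Raising `w_color` biases ranking toward color matches; raising `w_hier` biases it toward category matches.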

### Search in an Image Database

```python
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

# Step 1: Pre-compute image embeddings (do this once)
image_paths = [...]  # List of image paths
image_features_list = []

print("Computing image embeddings...")
for img_path in tqdm(image_paths):
    image = Image.open(img_path).convert("RGB")
    image_inputs = processor(images=[image], return_tensors="pt")
    image_inputs = {k: v.to(device) for k, v in image_inputs.items()}

    with torch.no_grad():
        outputs = main_model(**image_inputs)
        features = outputs.image_embeds  # Shape: [1, 512]
    image_features_list.append(features.cpu())

# Stack all features
image_features = torch.cat(image_features_list, dim=0)  # Shape: [N, 512]

# Step 2: Search with text query
query = "red dress"
text_inputs = processor(text=[query], padding=True, return_tensors="pt")
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

with torch.no_grad():
    outputs = main_model(**text_inputs)
    text_features = outputs.text_embeds  # Shape: [1, 512]

# Step 3: Calculate similarities
# Normalize embeddings for cosine similarity
text_features_norm = F.normalize(text_features, dim=-1)
image_features_norm = F.normalize(image_features.to(device), dim=-1)

# Compute cosine similarities
similarities = (text_features_norm @ image_features_norm.T).squeeze(0)  # Shape: [N]

# Step 4: Get top-k results
top_k = 10
top_scores, top_indices = similarities.topk(top_k, largest=True)

# Display results
print(f"\nTop {top_k} results for query: '{query}'")
for i, (idx, score) in enumerate(zip(top_indices, top_scores)):
    print(f"{i+1}. {image_paths[idx]} (similarity: {score.item():.4f})")

# Optional: Filter by color or hierarchy
query_color_emb = text_features[:, :16]        # Color subspace of the query
query_hierarchy_emb = text_features[:, 16:80]  # Hierarchy subspace of the query
# Use these for more targeted search
```

## 📝 Evaluation

### Comprehensive Model Evaluation

The main evaluation script `evaluation/main_model_evaluation.py` provides extensive testing across multiple datasets:

```bash
python evaluation/main_model_evaluation.py
```

**Evaluation Datasets**:

1. **Fashion-MNIST** (~10,000 samples) - Grayscale fashion items
2. **KAGL Marqo** (HuggingFace dataset) - Real fashion images with metadata
3. **Local Validation Dataset** - Custom validation set with local images

**Evaluation Metrics**:

For each dataset, the evaluation measures:

1. **Color Embeddings Performance** (dimensions 0-15):
   - Nearest Neighbor (NN) Accuracy: Classification accuracy using nearest neighbor
   - Centroid Accuracy: Classification using cluster centroids
   - Separation Score: How well color embeddings separate different classes

2. **Hierarchy Embeddings Performance** (dimensions 16-79):
   - Nearest Neighbor (NN) Accuracy: Classification accuracy for fashion categories
   - Centroid Accuracy: Cluster-based classification
   - Separation Score: Class separation quality

3. **Full Embeddings Performance** (all 512 dimensions):
   - Evaluates the complete embedding space
   - Compares subspace (color/hierarchy) vs full embedding effectiveness

**Baseline Comparison**:

The evaluation includes comparison against `patrickjohncyh/fashion-clip`:

- Direct performance comparison on the same datasets
- Improvement metrics calculation
- Statistical significance analysis

**Key Evaluation Functions**:

- `evaluate_fashion_mnist()` : Test on Fashion-MNIST dataset
- `evaluate_kaggle_marqo()` : Test on real fashion images
- `evaluate_local_validation()` : Test on local validation set
- `evaluate_baseline_fashion_mnist()` : Baseline model on Fashion-MNIST
- `evaluate_baseline_kaggle_marqo()` : Baseline model on KAGL
- `evaluate_full_embeddings()` : Test complete 512D space
- `analyze_baseline_vs_trained_performance()` : Comparative analysis
- `compare_subspace_vs_full_embeddings()` : Subspace effectiveness

**Visualization Outputs** (saved in the analysis directory):

- Confusion matrices for color and hierarchy classification
- t-SNE projections of embeddings
- Similarity heatmaps
- Performance comparison charts

**Other Evaluation Scripts**:

- `evaluate_color_embeddings.py` : Focused color embeddings evaluation
- `fashion_search.py` : Interactive fashion search tests
- `hierarchy_evaluation.py` : Hierarchy classification analysis
- `0_shot_classification.py` : Zero-shot classification tests

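The two accuracy metrics can be reproduced outside the evaluation scripts. A NumPy sketch (not the project's implementation) of leave-one-out nearest-neighbor accuracy and centroid accuracy, run on a synthetic, well-separated two-class problem in a 16-dim "color" subspace:

```python
import numpy as np

def nn_accuracy(emb, labels):
    """Leave-one-out nearest-neighbor classification accuracy."""
    d = np.linalg.norm(emb[:, None, :] - emb[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)  # exclude each point itself
    return float(np.mean(labels[d.argmin(axis=1)] == labels))

def centroid_accuracy(emb, labels):
    """Classify each point by its nearest class centroid."""
    classes = np.unique(labels)
    centroids = np.stack([emb[labels == c].mean(axis=0) for c in classes])
    d = np.linalg.norm(emb[:, None, :] - centroids[None, :, :], axis=-1)
    return float(np.mean(classes[d.argmin(axis=1)] == labels))

rng = np.random.default_rng(0)
# Two well-separated synthetic "color" clusters in a 16-dim subspace
emb = np.concatenate([rng.normal(0, 0.1, (50, 16)), rng.normal(3, 0.1, (50, 16))])
labels = np.array([0] * 50 + [1] * 50)
```

On real embeddings, `emb` would be the relevant slice (e.g., dims 0-15) of the model's outputs and `labels` the dataset annotations.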
## 📊 Performance & Results

The evaluation framework (`main_model_evaluation.py`) tests the model across three datasets with comparison to a baseline fashion CLIP model.

### Evaluation Metrics

**Color Classification** (dimensions 0-15):

- Nearest Neighbor Accuracy
- Centroid-based Accuracy
- Separation Score (class separability)

**Hierarchy Classification** (dimensions 16-79):

- Nearest Neighbor Accuracy
- Centroid-based Accuracy
- Separation Score

**Full Embedding Quality** (all 512 dims):

- Tests whether the full space maintains performance
- Compares subspace vs full embedding effectiveness

### Datasets Used for Evaluation

1. **Fashion-MNIST**: 10,000 grayscale fashion item images
   - 10 categories (T-shirt, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle boot)
   - Mapped to the model's hierarchy classes

2. **KAGL Marqo Dataset**: Real-world fashion images from HuggingFace
   - Diverse fashion items with rich metadata
   - Color and category annotations
   - Realistic product images

3. **Local Validation Set**: Custom validation dataset
   - Fashion items with local image paths
   - Annotated with colors and hierarchies
   - Domain-specific evaluation

### Comparative Analysis

The evaluation includes:

- **Baseline comparison**: GAP-CLIP vs `patrickjohncyh/fashion-clip`
- **Subspace analysis**: Dedicated dimensions (0-79) vs full space (0-511)
- **Cross-dataset generalization**: Performance consistency across datasets
- **Alignment quality**: How well specialized dimensions match expert models

All visualizations (confusion matrices, t-SNE plots, heatmaps) are automatically saved in the analysis directory.

## 📄 Citation
|
| 618 |
|
| 619 |
+
If you use GAP-CLIP in your research, please cite:

```bibtex
@misc{gap-clip-2024,
  title={GAP-CLIP: Guaranteed Attribute Positioning in CLIP Embeddings for Fashion Search},
  author={Sarfati, Lea Attia},
  year={2024},
  note={A multi-loss framework combining contrastive learning with direct attribute alignment},
  howpublished={\url{https://huggingface.co/Leacb4/gap-clip}},
  abstract={GAP-CLIP introduces a novel training approach that guarantees specific embedding
            dimensions encode color (dims 0-15) and hierarchy (dims 16-79) information through
            direct alignment with specialized models, while maintaining full CLIP capabilities
            in the remaining dimensions (80-511).}
}
```

### Key Contributions

- **Guaranteed Attribute Positioning**: Specific dimensions reliably encode color and hierarchy
- **Multi-Loss Training**: Combines contrastive learning with MSE and cosine alignment losses
- **Specialized Model Alignment**: Direct supervision from expert color and hierarchy models
- **Preserved Generalization**: Maintains base CLIP capabilities for cross-domain tasks
- **Comprehensive Evaluation**: Tested across multiple datasets with baseline comparisons

## ❓ FAQ & Troubleshooting

### Q: What are the minimum hardware requirements?

**A**:
- **GPU**: Recommended for training (CUDA or MPS). CPU training is very slow.
- **RAM**: Minimum 16GB, recommended 32GB for training
- **Storage**: ~5GB for models and datasets

### Q: Why are my embeddings not aligned?

**A**: Check that:
1. You're using the correct dimension ranges (0-15 for color, 16-79 for hierarchy)
2. The model was trained with `alignment_weight` > 0
3. The color and hierarchy models were properly loaded during training
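
A quick numerical sanity check is to measure the mean cosine similarity between a reserved sub-block and the corresponding expert model's output; a well-aligned model should score well above zero, while a score near zero suggests the alignment losses were never applied. The sketch below uses NumPy stand-in arrays; in practice you would pass `image_features[:, :16]` and the color model's embeddings:

```python
import numpy as np

def mean_cosine_alignment(subspace_emb: np.ndarray, expert_emb: np.ndarray) -> float:
    """Mean per-sample cosine similarity between a GAP-CLIP sub-block
    and the expert model's embedding (both of shape (n, d))."""
    a = subspace_emb / np.linalg.norm(subspace_emb, axis=1, keepdims=True)
    b = expert_emb / np.linalg.norm(expert_emb, axis=1, keepdims=True)
    return float(np.mean(np.sum(a * b, axis=1)))

# Stand-in arrays simulating an aligned and a misaligned color sub-block
rng = np.random.default_rng(1)
expert_color = rng.normal(size=(32, 16))
aligned = expert_color + 0.1 * rng.normal(size=(32, 16))
misaligned = rng.normal(size=(32, 16))

print(mean_cosine_alignment(aligned, expert_color))     # close to 1
print(mean_cosine_alignment(misaligned, expert_color))  # close to 0
```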

### Q: How do I use only the color or hierarchy subspace for search?

**A**:
```python
import torch.nn.functional as F

# Extract and use only the color embeddings (dims 0-15)
text_color_emb = text_features[:, :16]
image_color_emb = image_features[:, :16]
color_similarity = F.cosine_similarity(text_color_emb, image_color_emb)

# Extract and use only the hierarchy embeddings (dims 16-79)
text_hierarchy_emb = text_features[:, 16:80]
image_hierarchy_emb = image_features[:, 16:80]
hierarchy_similarity = F.cosine_similarity(text_hierarchy_emb, image_hierarchy_emb)
```

### Q: Can I add more attributes beyond color and hierarchy?

**A**: Yes! The architecture is extensible:
1. Train a new specialized model for your attribute
2. Reserve additional dimensions in the embedding space
3. Add alignment losses for these dimensions in `enhanced_contrastive_loss`
4. Update `config.py` with the new dimension ranges
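
Step 3 can be sketched as one extra alignment term, mirroring the existing MSE and cosine losses. This is a NumPy illustration; the real `enhanced_contrastive_loss` operates on PyTorch tensors, and the dimension range 80-95 used here is purely hypothetical:

```python
import numpy as np

def attribute_alignment_loss(clip_emb: np.ndarray, expert_emb: np.ndarray,
                             dims: slice) -> float:
    """Alignment term for one attribute: MSE plus (1 - cosine similarity)
    between the reserved CLIP sub-block and the expert embedding."""
    sub = clip_emb[:, dims]
    mse = float(np.mean((sub - expert_emb) ** 2))
    a = sub / np.linalg.norm(sub, axis=1, keepdims=True)
    b = expert_emb / np.linalg.norm(expert_emb, axis=1, keepdims=True)
    cos = float(np.mean(np.sum(a * b, axis=1)))
    return mse + (1.0 - cos)

rng = np.random.default_rng(2)
clip_emb = rng.normal(size=(8, 512))
expert = clip_emb[:, 80:96].copy()  # perfectly aligned expert output

loss = attribute_alignment_loss(clip_emb, expert, slice(80, 96))
print(loss)  # ~0 when the sub-block matches the expert exactly
```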

### Q: How do I evaluate on my own dataset?

**A**:
1. Format your dataset as a CSV with the columns `text`, `color`, `hierarchy`, `local_image_path`
2. Update `config.local_dataset_path` in `config.py`
3. Run the evaluation: `python evaluation/main_model_evaluation.py`
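
For step 1, the expected CSV layout can be illustrated as follows. The rows and class names are made-up examples; only the column names matter:

```python
import csv
import io

# Hypothetical rows showing the expected CSV layout
rows = [
    {"text": "red floral summer dress", "color": "red",
     "hierarchy": "dresses", "local_image_path": "images/dress_001.jpg"},
    {"text": "white leather sneakers", "color": "white",
     "hierarchy": "shoes", "local_image_path": "images/sneaker_042.jpg"},
]

buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=["text", "color", "hierarchy", "local_image_path"])
writer.writeheader()
writer.writerows(rows)
print(buf.getvalue())
```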

### Q: Training loss is decreasing but validation loss is increasing. What should I do?

**A**: This indicates overfitting. Try:
- Increasing `weight_decay` (e.g., from 5e-4 to 1e-3)
- Reducing `alignment_weight` (e.g., from 0.2 to 0.1)
- Increasing the dataset size (`subset_size`)
- Adding more data augmentation in `CustomDataset`
- Enabling early stopping, or increasing its patience

### Q: Can I fine-tune GAP-CLIP on a specific domain?

**A**: Yes! Load the checkpoint and continue training:
```python
# map_location lets the checkpoint load onto the device set in config.py
checkpoint = torch.load('models/gap_clip.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# Continue training with your domain-specific data
```

## 🤝 Contributing

Contributions are welcome! Feel free to open an issue or a pull request.