Leacb4 committed on
Commit 2027b7a · verified · 1 Parent(s): 9691498

Upload README.md with huggingface_hub

Files changed (1):
  1. README.md +464 -79

README.md CHANGED
@@ -12,12 +12,46 @@ This project implements an advanced fashion search system based on CLIP, with th

 ### Architecture

- The main model combines:
- - **16 dimensions** : Color embeddings (first 16 dimensions)
- - **64 dimensions** : Hierarchy embeddings (dimensions 16-80)
- - **512 dimensions** : Standard CLIP embeddings (following dimensions)
-
- Total : **512 dimensions** for each embedding

 ## 🚀 Installation

@@ -46,43 +80,107 @@ pip install -r requirements.txt

 ```
 .
- ├── color_model.py # Color model definition
- ├── hierarchy_model.py # Hierarchy model definition
- ├── main_model.py # Main CLIP model and training functions
 ├── config.py # Configuration for paths and parameters
- ├── tokenizer_vocab.json # Tokenizer vocabulary
- ├── example_usage.py # Usage examples
 ├── models/
- │ ├── color_model.pt # Trained color model
- │ ├── hierarchy_model.pth # Trained hierarchy model
- │ └── gap_clip.pth # Main CLIP model
- ├── evaluation/ # Evaluation scripts
- │ ├── evaluate_color_embeddings.py
- │ ├── main_model_evaluation.py
- │ └── ...
- ├── data/ # Training data
- │ ├── download_data.py
 ├── requirements.txt # Python dependencies
 └── README.md # This documentation
 ```

 ## 🔧 Configuration

 Main parameters are defined in `config.py`:

 ```python
- color_emb_dim = 16 # Color embedding dimension
- hierarchy_emb_dim = 64 # Hierarchy embedding dimension
 device = torch.device("mps") # Device (cuda, mps, cpu)
 ```

 ### Model Paths

- Default paths are:
- - `models/color_model.pt` : Color model
- - `models/hierarchy_model.pth` : Hierarchy model
- - `models/gap_clip.pth` : Main CLIP model
- - `tokenizer_vocab.json` : Tokenizer vocabulary

 ## 📦 Usage

@@ -142,13 +240,32 @@ with torch.no_grad():

 ### 4. Using the Example Script

 ```bash
 python example_usage.py \
-     --repo-id your-username/your-model \
-     --text "blue jacket" \
     --image path/to/image.jpg
 ```

 ## 🎯 Model Training

 ### Train the Color Model
@@ -179,27 +296,77 @@ train_hierarchy_model(model, train_loader, val_loader, num_epochs=20)

 ### Train the Main CLIP Model

- The main model trains with both specialized models:

 ```bash
- python main_model.py
 ```

- **Training Parameters** (in `main_model.py`):
- - `num_epochs = 20` : Number of epochs
- - `learning_rate = 1e-5` : Learning rate
- - `temperature = 0.07` : Temperature for contrastive loss
- - `alignment_weight = 0.5` : Weight for embedding alignment
 - `batch_size = 32` : Batch size
- - `use_enhanced_loss = True` : Use enhanced loss with alignment

- **Features**:
- - Triple contrastive loss (text-image-attributes)
- - Direct alignment between specialized models and main model
- - Early stopping with patience
- - Automatic learning rate reduction
- - Automatic best model saving

 ## 📊 Models

@@ -217,22 +384,32 @@ python main_model.py

 - **Hierarchy classes** : shirt, dress, pant, shoe, bag, etc.
 - **Usage** : Classify and encode categorical hierarchy

- ### Main CLIP Model

 - **Architecture** : CLIP ViT-B/32 (LAION)
- - **Base** : `laion/CLIP-ViT-B-32-laion2B-s34B-b79K`
- - **Trained with** : Triple contrastive loss including color and hierarchy embeddings
- - **Dimensions** : 592 (512 CLIP + 16 color + 64 hierarchy)
 - **Features** :
- - Text-image search
- - Specialized embeddings (color + hierarchy)
- - Alignment with specialized models

 ## 🔍 Advanced Usage Examples

 ### Search with Combined Embeddings

 ```python
 import torch.nn.functional as F

 # Text query
@@ -243,82 +420,290 @@ text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

 # Main model embeddings
 with torch.no_grad():
     outputs = main_model(**text_inputs)
-     text_features = outputs.text_embeds

 # Extract specialized embeddings from main model
- main_color_emb = text_features[:, :16] # First 16 dimensions
- main_hierarchy_emb = text_features[:, 16:80] # Dimensions 16-80
- main_clip_emb = text_features[:, 80:] # CLIP dimensions (80+)

 # Compare with specialized models
 color_emb = color_model.get_text_embeddings([text_query])
 hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])

- # Cosine similarity
 color_similarity = F.cosine_similarity(color_emb, main_color_emb, dim=1)
 hierarchy_similarity = F.cosine_similarity(hierarchy_emb, main_hierarchy_emb, dim=1)
 ```

 ### Search in an Image Database

 ```python
 import numpy as np

- # Load all images from database
 image_paths = [...] # List of image paths
 image_features_list = []

- for img_path in image_paths:
     image = Image.open(img_path).convert("RGB")
     image_inputs = processor(images=[image], return_tensors="pt")
     image_inputs = {k: v.to(device) for k, v in image_inputs.items()}

     with torch.no_grad():
         outputs = main_model(**image_inputs)
-         features = outputs.image_embeds
-         image_features_list.append(features.cpu().numpy())

- # Convert to numpy array
- image_features = np.vstack(image_features_list)

- # Search
 query = "red dress"
- # ... get text_features ...

- # Calculate similarities
- similarities = F.cosine_similarity(
-     text_features,
-     torch.from_numpy(image_features).to(device),
-     dim=1
- )

- # Sort by similarity
 top_k = 10
- top_indices = similarities.argsort(descending=True)[:top_k]
 ```

 ## 📝 Evaluation

- Evaluation scripts are available in `evaluation/`:

- - `evaluate_color_embeddings.py` : Color embeddings evaluation
- - `main_model_evaluation.py` : Main model evaluation
- - `fashion_search.py` : Fashion search tests

  ## 📄 Citation

- If you use these models, please cite:

 ```bibtex
- @misc{fashion-search-model,
-   title={GAP (Guaranteed Attribute Position) CLIP: A Multi-Loss Framework for Attribute-Aware Fashion Embeddings},
-   author={ },
   year={2024},
-   howpublished={\url{https://huggingface.co/Leacb4/gap-clip}}
 }
 ```

 ## 🤝 Contributing

 Contributions are welcome! Feel free to open an issue or a pull request.


 ### Architecture

+ The main model's embedding structure:
+ - **Dimensions 0-15** (16 dims): Color embeddings aligned with the specialized color model
+ - **Dimensions 16-79** (64 dims): Hierarchy embeddings aligned with the specialized hierarchy model
+ - **Dimensions 80-511** (432 dims): Standard CLIP embeddings for general visual-semantic understanding
+
+ **Total: 512 dimensions** per embedding (text or image)
+
+ **Key Innovation**: The first 80 dimensions are explicitly trained to align with the specialized models through direct MSE and cosine similarity losses, guaranteeing attribute positioning (GAP) while preserving full CLIP capabilities in the remaining dimensions.
+
+ ### Loss Functions
+
+ **1. Enhanced Contrastive Loss** (`enhanced_contrastive_loss`):
+
+ Combines multiple objectives:
+ - **Original Triple Loss**: Text-image-attributes contrastive learning
+ - **Color Alignment**: Forces dims 0-15 to match the color model embeddings
+ - **Hierarchy Alignment**: Forces dims 16-79 to match the hierarchy model embeddings
+ - **Reference Loss**: Optional regularization to stay close to base CLIP
+
+ **2. Alignment Components**:
+ ```python
+ # Color alignment (text & image)
+ color_text_mse = F.mse_loss(main_color_dims, color_model_emb)
+ color_text_cosine = 1 - F.cosine_similarity(main_color_dims, color_model_emb).mean()
+
+ # Hierarchy alignment (text & image)
+ hierarchy_text_mse = F.mse_loss(main_hierarchy_dims, hierarchy_model_emb)
+ hierarchy_text_cosine = 1 - F.cosine_similarity(main_hierarchy_dims, hierarchy_model_emb).mean()
+
+ # Combined alignment
+ alignment_loss = (color_alignment + hierarchy_alignment) / 2
+ ```
+
+ **3. Final Loss**:
+ ```python
+ total_loss = (1 - α) * contrastive_loss + α * alignment_loss + β * reference_loss
+ ```
+ Where:
+ - α (`alignment_weight`) = 0.2 : Balances the contrastive and alignment objectives
+ - β (`reference_weight`) = 0.1 : Keeps the text space close to base CLIP
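
The loss pieces described above combine as follows; this is a minimal, runnable sketch in which random tensors stand in for the real model outputs, and the `alignment` helper is illustrative (it mirrors the MSE-plus-cosine components listed above, not the repository's actual function names):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-ins for real model outputs (batch of 4 embeddings)
main_color_dims = torch.randn(4, 16)       # dims 0-15 of the main model
color_model_emb = torch.randn(4, 16)       # specialized color model output
main_hierarchy_dims = torch.randn(4, 64)   # dims 16-79 of the main model
hierarchy_model_emb = torch.randn(4, 64)   # specialized hierarchy model output
contrastive_loss = torch.tensor(1.0)       # placeholder triple contrastive loss
reference_loss = torch.tensor(0.5)         # placeholder base-CLIP regularizer

def alignment(pred, target):
    """MSE plus (1 - mean cosine similarity), as in the alignment components."""
    mse = F.mse_loss(pred, target)
    cosine = 1 - F.cosine_similarity(pred, target, dim=1).mean()
    return mse + cosine

color_alignment = alignment(main_color_dims, color_model_emb)
hierarchy_alignment = alignment(main_hierarchy_dims, hierarchy_model_emb)
alignment_loss = (color_alignment + hierarchy_alignment) / 2

alpha, beta = 0.2, 0.1
total_loss = (1 - alpha) * contrastive_loss + alpha * alignment_loss + beta * reference_loss
print(f"total_loss={float(total_loss):.4f}")
```

Because both alignment terms are non-negative, the alignment loss can only pull the total up when the subspaces drift away from the specialized models.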

 ## 🚀 Installation

 ```
 .
+ ├── color_model.py # Color model architecture and training
+ ├── hierarchy_model.py # Hierarchy model architecture and training
+ ├── main_model.py # Main GAP-CLIP model with enhanced loss functions
+ ├── train_main_model.py # Training script with optimized hyperparameters
 ├── config.py # Configuration for paths and parameters
+ ├── example_usage.py # Usage examples and HuggingFace loading
+ ├── tokenizer_vocab.json # Tokenizer vocabulary for color model
 ├── models/
+ │ ├── color_model.pt # Trained color model checkpoint
+ │ ├── hierarchy_model.pth # Trained hierarchy model checkpoint
+ │ └── gap_clip.pth # Main GAP-CLIP model checkpoint
+ ├── evaluation/ # Comprehensive evaluation scripts
+ │ ├── main_model_evaluation.py # Main evaluation with 3 datasets
+ │ ├── evaluate_color_embeddings.py # Color embedding analysis
+ │ ├── hierarchy_evaluation.py # Hierarchy classification tests
+ │ ├── fashion_search.py # Interactive search demos
+ │ ├── 0_shot_classification.py # Zero-shot classification
+ │ ├── heatmap_color_similarities.py # Color similarity visualization
+ │ ├── tsne_images.py # t-SNE embedding visualization
+ │ └── basic_test_generalized.py # Basic functionality tests
+ ├── data/
+ │ ├── data_with_local_paths.csv # Training dataset with annotations
+ │ ├── fashion-mnist_test.csv # Fashion-MNIST evaluation data
+ │ ├── download_data.py # Dataset download utilities
+ │ └── get_csv_from_chunks.py # Dataset preprocessing
+ ├── optuna/ # Hyperparameter optimization
+ │ ├── optuna_optimisation.py # Optuna optimization script
+ │ ├── optuna_study.pkl # Saved optimization study
+ │ ├── optuna_results.txt # Best hyperparameters
+ │ ├── optuna_optimization_history.png # Optimization visualization
+ │ ├── optuna_param_importances.png # Parameter importance plot
+ │ └── optuna_guide.md # Optuna usage guide
+ ├── upload_hf/ # HuggingFace Hub upload utilities
+ │ ├── upload_to_huggingface.py # Upload script
+ │ └── GUIDE_UPLOAD_HF.md # Upload guide
 ├── requirements.txt # Python dependencies
 └── README.md # This documentation
 ```

+ ### Key Files Description
+
+ **Core Model Files**:
+ - `color_model.py`: ResNet18-based color embedding model (16 dims)
+ - `hierarchy_model.py`: ResNet18-based hierarchy classification model (64 dims)
+ - `main_model.py`: GAP-CLIP implementation with enhanced contrastive loss
+ - `train_main_model.py`: Training with Optuna-optimized hyperparameters
+
+ **Configuration**:
+ - `config.py`: Central configuration for all paths, dimensions, and device settings
+ - `tokenizer_vocab.json`: Vocabulary for the color model's text encoder
+
+ **Evaluation Suite**:
+ - `main_model_evaluation.py`: Comprehensive evaluation across Fashion-MNIST, KAGL, and local datasets
+ - Other evaluation scripts provide specialized analysis (color, hierarchy, search, etc.)
+
+ **Training Data**:
+ - `data_with_local_paths.csv`: Main training dataset with text, color, hierarchy, and image paths
+ - `fashion-mnist_test.csv`: Evaluation dataset for zero-shot generalization testing
+
 ## 🔧 Configuration

 Main parameters are defined in `config.py`:

 ```python
+ # Embedding dimensions
+ color_emb_dim = 16 # Color embedding dimension (dims 0-15)
+ hierarchy_emb_dim = 64 # Hierarchy embedding dimension (dims 16-79)
+
+ # Device configuration
 device = torch.device("mps") # Device (cuda, mps, cpu)
+
+ # Column names for the dataset
+ text_column = 'text' # Description column
+ color_column = 'color' # Color label column
+ hierarchy_column = 'hierarchy' # Hierarchy category column
+ column_local_image_path = 'local_image_path' # Image path column
 ```

 ### Model Paths

+ Default paths configured in `config.py`:
+ - `models/color_model.pt` : Trained color model checkpoint
+ - `models/hierarchy_model.pth` : Trained hierarchy model checkpoint
+ - `models/gap_clip.pth` : Main GAP-CLIP model checkpoint
+ - `tokenizer_vocab.json` : Tokenizer vocabulary for color model
+ - `data_with_local_paths.csv` : Training/validation dataset
+
+ ### Dataset Format
+
+ The training dataset CSV should contain:
+ - `text`: Text description of the fashion item
+ - `color`: Color label (e.g., "red", "blue", "black")
+ - `hierarchy`: Category label (e.g., "dress", "shirt", "shoes")
+ - `local_image_path`: Path to the image file
+
+ Example:
+ ```csv
+ text,color,hierarchy,local_image_path
+ "red summer dress with floral pattern",red,dress,data/images/001.jpg
+ "blue denim jeans casual style",blue,jeans,data/images/002.jpg
+ ```
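
A quick sanity check for this format can catch schema problems before training; this is a sketch that assumes `pandas` is available, and the `REQUIRED` list is illustrative rather than part of the repository:

```python
import io
import pandas as pd

# Columns the training CSV is expected to contain (see Dataset Format above)
REQUIRED = ["text", "color", "hierarchy", "local_image_path"]

# Inline copy of the example CSV; in practice, read data_with_local_paths.csv
csv_data = io.StringIO(
    'text,color,hierarchy,local_image_path\n'
    '"red summer dress with floral pattern",red,dress,data/images/001.jpg\n'
    '"blue denim jeans casual style",blue,jeans,data/images/002.jpg\n'
)

df = pd.read_csv(csv_data)
missing = [c for c in REQUIRED if c not in df.columns]
assert not missing, f"dataset is missing columns: {missing}"
print(df.shape)  # two rows, four columns
```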

 ## 📦 Usage

 ### 4. Using the Example Script

+ The `example_usage.py` script provides ready-to-use examples for loading and using GAP-CLIP:
+
 ```bash
+ # Load from HuggingFace and search with text
+ python example_usage.py \
+     --repo-id Leacb4/gap-clip \
+     --text "red summer dress"
+
+ # Search with image
+ python example_usage.py \
+     --repo-id Leacb4/gap-clip \
+     --image path/to/image.jpg
+
+ # Both text and image
 python example_usage.py \
+     --repo-id Leacb4/gap-clip \
+     --text "blue denim jeans" \
     --image path/to/image.jpg
 ```

+ This script demonstrates:
+ - Loading models from HuggingFace Hub
+ - Extracting text and image embeddings
+ - Accessing color and hierarchy subspaces
+ - Measuring alignment quality with specialized models

 ## 🎯 Model Training

 ### Train the Color Model

 ### Train the Main CLIP Model

+ The main model trains with both specialized models using an enhanced contrastive loss.
+
+ **Option 1: Train with optimized hyperparameters (recommended)**:
+ ```bash
+ python train_main_model.py
+ ```
+ This uses hyperparameters optimized with Optuna (Trial 29, validation loss ~0.1129).
+
+ **Option 2: Train with default parameters**:
+ ```bash
+ python main_model.py
+ ```
+ This runs the main training loop with manually configured parameters.
+
+ **Default Training Parameters** (in `main_model.py`):
+ - `num_epochs = 20` : Number of training epochs
+ - `learning_rate = 1.5e-5` : Learning rate with AdamW optimizer
+ - `temperature = 0.09` : Temperature for softer contrastive learning
+ - `alignment_weight = 0.2` : Weight for color/hierarchy alignment loss
+ - `weight_decay = 5e-4` : L2 regularization to prevent overfitting
 - `batch_size = 32` : Batch size
+ - `subset_size = 20000` : Dataset size for better generalization
+ - `reference_weight = 0.1` : Weight for base CLIP regularization
+
+ **Enhanced Loss Function**:
+
+ The training uses `enhanced_contrastive_loss`, which combines:
+
+ 1. **Triple Contrastive Loss** (weighted):
+    - Text-Image alignment (70%)
+    - Text-Attributes alignment (15%)
+    - Image-Attributes alignment (15%)
+
+ 2. **Direct Alignment Loss** (combines color & hierarchy):
+    - MSE loss between main model color dims (0-15) and color model embeddings
+    - MSE loss between main model hierarchy dims (16-79) and hierarchy model embeddings
+    - Cosine similarity losses for both color and hierarchy
+    - Applied to both text and image embeddings
+
+ 3. **Reference Model Loss** (optional):
+    - Keeps text embeddings close to base CLIP
+    - Improves cross-domain generalization
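
The weighted triple loss can be sketched as follows; this is an illustrative stand-alone example using random features and a standard symmetric InfoNCE formulation (the 70/15/15 weights and the 0.09 temperature come from the lists above, while the `contrastive` helper and batch size are assumptions, not the repository's actual code):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def contrastive(a, b, temperature=0.09):
    """Symmetric InfoNCE between two batches of embeddings."""
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    logits = a @ b.T / temperature
    targets = torch.arange(a.size(0))  # matching pairs sit on the diagonal
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets)) / 2

# Random stand-ins for text, image, and attribute embeddings (batch of 8)
text_emb, image_emb, attr_emb = (torch.randn(8, 512) for _ in range(3))

triple_loss = (0.70 * contrastive(text_emb, image_emb)
               + 0.15 * contrastive(text_emb, attr_emb)
               + 0.15 * contrastive(image_emb, attr_emb))
```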
+
+ **Training Features**:
+ - Enhanced data augmentation (rotation, color jitter, blur, affine transforms)
+ - Gradient clipping (max_norm=1.0) to prevent exploding gradients
+ - ReduceLROnPlateau scheduler (patience=3, factor=0.5)
+ - Early stopping (patience=7)
+ - Automatic best model saving with checkpoints
+ - Detailed metrics logging (alignment losses, cosine similarities)
+ - Overfitting detection and warnings
+ - Training curves visualization with 3 plots (losses, overfitting gap, comparison)
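
The clipping, scheduling, and early-stopping features listed above fit together roughly like this; a toy linear model and synthetic losses stand in for GAP-CLIP, so treat it as a sketch of the control flow, not the actual training script:

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

torch.manual_seed(0)
model = torch.nn.Linear(512, 512)  # toy stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-5, weight_decay=5e-4)
scheduler = ReduceLROnPlateau(optimizer, patience=3, factor=0.5)

best_val, patience, bad_epochs = float("inf"), 7, 0
for epoch in range(20):
    x = torch.randn(32, 512)
    loss = model(x).pow(2).mean()  # toy training loss
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    val_loss = loss.item()         # stand-in for a real validation pass
    scheduler.step(val_loss)       # halves the LR after 3 stagnant epochs
    if val_loss < best_val:
        best_val, bad_epochs = val_loss, 0  # a checkpoint would be saved here
    else:
        bad_epochs += 1
        if bad_epochs >= patience:          # early stopping after 7 bad epochs
            break
```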
+
+ ### Hyperparameter Optimization
+
+ The project includes Optuna-based hyperparameter optimization in `optuna/`:
+
+ ```bash
+ cd optuna
+ python optuna_optimisation.py
+ ```
+
+ This optimizes:
+ - Learning rate
+ - Temperature for contrastive loss
+ - Alignment weight
+ - Weight decay
+
+ Results are saved in `optuna_study.pkl`, with visualizations in `optuna_optimization_history.png` and `optuna_param_importances.png`.
+
+ The best hyperparameters from Optuna optimization are used in `train_main_model.py`.

 ## 📊 Models

 - **Hierarchy classes** : shirt, dress, pant, shoe, bag, etc.
 - **Usage** : Classify and encode categorical hierarchy

+ ### Main CLIP Model (GAP-CLIP)

 - **Architecture** : CLIP ViT-B/32 (LAION)
+ - **Base Model** : `laion/CLIP-ViT-B-32-laion2B-s34B-b79K`
+ - **Training Approach** : Enhanced contrastive loss with direct attribute alignment
+ - **Embedding Dimensions** : 512 total
+   - Color subspace: dims 0-15 (16 dims)
+   - Hierarchy subspace: dims 16-79 (64 dims)
+   - General CLIP: dims 80-511 (432 dims)
+ - **Training Dataset** : 20,000 fashion items with color and hierarchy annotations
+ - **Validation Split** : 80/20 train-validation split
+ - **Optimizer** : AdamW with weight decay (5e-4)
+ - **Best Checkpoint** : Automatically saved based on validation loss
 - **Features** :
+   - Multi-modal text-image search
+   - Guaranteed attribute positioning (GAP) in specific dimensions
+   - Direct alignment with specialized color and hierarchy models
+   - Maintains general CLIP capabilities for cross-domain tasks
+   - Reduced overfitting through augmentation and regularization

 ## 🔍 Advanced Usage Examples

 ### Search with Combined Embeddings

 ```python
+ import torch
 import torch.nn.functional as F

 # Text query
 
 # Main model embeddings
 with torch.no_grad():
     outputs = main_model(**text_inputs)
+     text_features = outputs.text_embeds # Shape: [1, 512]

 # Extract specialized embeddings from main model
+ main_color_emb = text_features[:, :16] # Color dimensions (0-15)
+ main_hierarchy_emb = text_features[:, 16:80] # Hierarchy dimensions (16-79)
+ main_clip_emb = text_features[:, 80:] # General CLIP dimensions (80-511)

 # Compare with specialized models
 color_emb = color_model.get_text_embeddings([text_query])
 hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])

+ # Measure alignment quality
 color_similarity = F.cosine_similarity(color_emb, main_color_emb, dim=1)
 hierarchy_similarity = F.cosine_similarity(hierarchy_emb, main_hierarchy_emb, dim=1)
+
+ print(f"Color alignment: {color_similarity.item():.4f}")
+ print(f"Hierarchy alignment: {hierarchy_similarity.item():.4f}")
+
+ # For search, you can use different strategies:
+ # 1. Use full embeddings for general search
+ # 2. Use color subspace for color-specific search
+ # 3. Use hierarchy subspace for category search
+ # 4. Weighted combination of subspaces
 ```

 ### Search in an Image Database

 ```python
 import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from tqdm import tqdm

+ # Step 1: Pre-compute image embeddings (do this once)
 image_paths = [...] # List of image paths
 image_features_list = []

+ print("Computing image embeddings...")
+ for img_path in tqdm(image_paths):
     image = Image.open(img_path).convert("RGB")
     image_inputs = processor(images=[image], return_tensors="pt")
     image_inputs = {k: v.to(device) for k, v in image_inputs.items()}

     with torch.no_grad():
         outputs = main_model(**image_inputs)
+         features = outputs.image_embeds # Shape: [1, 512]
+     image_features_list.append(features.cpu())

+ # Stack all features
+ image_features = torch.cat(image_features_list, dim=0) # Shape: [N, 512]

+ # Step 2: Search with text query
 query = "red dress"
+ text_inputs = processor(text=[query], padding=True, return_tensors="pt")
+ text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
+
+ with torch.no_grad():
+     outputs = main_model(**text_inputs)
+     text_features = outputs.text_embeds # Shape: [1, 512]

+ # Step 3: Calculate similarities
+ # Normalize embeddings for cosine similarity
+ text_features_norm = F.normalize(text_features, dim=-1)
+ image_features_norm = F.normalize(image_features.to(device), dim=-1)

+ # Compute cosine similarities
+ similarities = (text_features_norm @ image_features_norm.T).squeeze(0) # Shape: [N]
+
+ # Step 4: Get top-k results
 top_k = 10
+ top_scores, top_indices = similarities.topk(top_k, largest=True)
+
+ # Display results
+ print(f"\nTop {top_k} results for query: '{query}'")
+ for i, (idx, score) in enumerate(zip(top_indices, top_scores)):
+     print(f"{i+1}. {image_paths[idx]} (similarity: {score.item():.4f})")
+
+ # Optional: Filter by color or hierarchy
+ # Extract color embeddings from query
+ query_color_emb = text_features[:, :16]
+ # Extract hierarchy embeddings from query
+ query_hierarchy_emb = text_features[:, 16:80]
+ # Use these for more targeted search
 ```

 ## 📝 Evaluation

+ ### Comprehensive Model Evaluation
+
+ The main evaluation script `evaluation/main_model_evaluation.py` provides extensive testing across multiple datasets:
+
+ ```bash
+ python evaluation/main_model_evaluation.py
+ ```
+
+ **Evaluation Datasets**:
+ 1. **Fashion-MNIST** (~10,000 samples) - Grayscale fashion items
+ 2. **KAGL Marqo** (HuggingFace dataset) - Real fashion images with metadata
+ 3. **Local Validation Dataset** - Custom validation set with local images
+
+ **Evaluation Metrics**:
+
+ For each dataset, the evaluation measures:
+
+ 1. **Color Embeddings Performance** (dimensions 0-15):
+    - Nearest Neighbor (NN) Accuracy: Classification accuracy using the nearest neighbor
+    - Centroid Accuracy: Classification using cluster centroids
+    - Separation Score: How well color embeddings separate different classes
+
+ 2. **Hierarchy Embeddings Performance** (dimensions 16-79):
+    - Nearest Neighbor (NN) Accuracy: Classification accuracy for fashion categories
+    - Centroid Accuracy: Cluster-based classification
+    - Separation Score: Class separation quality
+
+ 3. **Full Embeddings Performance** (all 512 dimensions):
+    - Evaluates the complete embedding space
+    - Compares subspace (color/hierarchy) vs full embedding effectiveness
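
The nearest-neighbor and centroid metrics above can be computed as in this self-contained sketch; the toy, well-separated clusters are illustrative (three "color" classes in a 16-dim space, matching the color subspace shape), not data from the actual evaluation:

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy embeddings: 3 color classes, 10 samples each, 16 dims (as for dims 0-15)
labels = np.repeat(np.arange(3), 10)
emb = rng.normal(size=(30, 16)) + labels[:, None] * 3.0  # separable clusters

# Nearest-neighbor accuracy (leave-one-out)
dists = np.linalg.norm(emb[:, None] - emb[None, :], axis=-1)
np.fill_diagonal(dists, np.inf)  # a sample cannot be its own neighbor
nn_acc = (labels[dists.argmin(axis=1)] == labels).mean()

# Centroid accuracy: assign each sample to the nearest class centroid
centroids = np.stack([emb[labels == c].mean(axis=0) for c in range(3)])
cd = np.linalg.norm(emb[:, None] - centroids[None, :], axis=-1)
centroid_acc = (cd.argmin(axis=1) == labels).mean()

print(nn_acc, centroid_acc)
```

On well-separated clusters like these, both accuracies approach 1.0; on real embeddings they quantify how cleanly the subspace encodes the attribute.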
+
+ **Baseline Comparison**:
+
+ The evaluation includes comparison against `patrickjohncyh/fashion-clip`:
+ - Direct performance comparison on the same datasets
+ - Improvement metrics calculation
+ - Statistical significance analysis
+
+ **Key Evaluation Functions**:
+ - `evaluate_fashion_mnist()` : Test on Fashion-MNIST dataset
+ - `evaluate_kaggle_marqo()` : Test on real fashion images
+ - `evaluate_local_validation()` : Test on local validation set
+ - `evaluate_baseline_fashion_mnist()` : Baseline model on Fashion-MNIST
+ - `evaluate_baseline_kaggle_marqo()` : Baseline model on KAGL
+ - `evaluate_full_embeddings()` : Test complete 512D space
+ - `analyze_baseline_vs_trained_performance()` : Comparative analysis
+ - `compare_subspace_vs_full_embeddings()` : Subspace effectiveness
+
+ **Visualization Outputs** (saved in the analysis directory):
+ - Confusion matrices for color and hierarchy classification
+ - t-SNE projections of embeddings
+ - Similarity heatmaps
+ - Performance comparison charts
+
+ **Other Evaluation Scripts**:
+ - `evaluate_color_embeddings.py` : Focused color embeddings evaluation
+ - `fashion_search.py` : Interactive fashion search tests
+ - `hierarchy_evaluation.py` : Hierarchy classification analysis
+ - `0_shot_classification.py` : Zero-shot classification tests
+
+ ## 📊 Performance & Results
+
+ The evaluation framework (`main_model_evaluation.py`) tests the model across three datasets, with comparison to a baseline fashion CLIP model.
+
+ ### Evaluation Metrics
+
+ **Color Classification** (dimensions 0-15):
+ - Nearest Neighbor Accuracy
+ - Centroid-based Accuracy
+ - Separation Score (class separability)
+
+ **Hierarchy Classification** (dimensions 16-79):
+ - Nearest Neighbor Accuracy
+ - Centroid-based Accuracy
+ - Separation Score
+
+ **Full Embedding Quality** (all 512 dims):
+ - Tests whether the full space maintains performance
+ - Compares subspace vs full embedding effectiveness
+
+ ### Datasets Used for Evaluation
+
+ 1. **Fashion-MNIST**: 10,000 grayscale fashion item images
+    - 10 categories (T-shirt, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle boot)
+    - Mapped to the model's hierarchy classes
+
+ 2. **KAGL Marqo Dataset**: Real-world fashion images from HuggingFace
+    - Diverse fashion items with rich metadata
+    - Color and category annotations
+    - Realistic product images
+
+ 3. **Local Validation Set**: Custom validation dataset
+    - Fashion items with local image paths
+    - Annotated with colors and hierarchies
+    - Domain-specific evaluation
+
+ ### Comparative Analysis
+
+ The evaluation includes:
+ - **Baseline comparison**: GAP-CLIP vs `patrickjohncyh/fashion-clip`
+ - **Subspace analysis**: Dedicated dimensions (0-79) vs full space (0-511)
+ - **Cross-dataset generalization**: Performance consistency across datasets
+ - **Alignment quality**: How well specialized dimensions match the expert models
+
+ All visualizations (confusion matrices, t-SNE plots, heatmaps) are automatically saved in the analysis directory.

 ## 📄 Citation

+ If you use GAP-CLIP in your research, please cite:

 ```bibtex
+ @misc{gap-clip-2024,
+   title={GAP-CLIP: Guaranteed Attribute Positioning in CLIP Embeddings for Fashion Search},
+   author={Sarfati, Lea Attia},
   year={2024},
+   note={A multi-loss framework combining contrastive learning with direct attribute alignment},
+   howpublished={\url{https://huggingface.co/Leacb4/gap-clip}},
+   abstract={GAP-CLIP introduces a novel training approach that guarantees specific embedding
+             dimensions encode color (dims 0-15) and hierarchy (dims 16-79) information through
+             direct alignment with specialized models, while maintaining full CLIP capabilities
+             in the remaining dimensions (80-511).}
 }
 ```

+ ### Key Contributions
+
+ - **Guaranteed Attribute Positioning**: Specific dimensions reliably encode color and hierarchy
+ - **Multi-Loss Training**: Combines contrastive learning with MSE and cosine alignment losses
+ - **Specialized Model Alignment**: Direct supervision from expert color and hierarchy models
+ - **Preserved Generalization**: Maintains base CLIP capabilities for cross-domain tasks
+ - **Comprehensive Evaluation**: Tested across multiple datasets with baseline comparisons
+
+ ## ❓ FAQ & Troubleshooting
+
+ ### Q: What are the minimum hardware requirements?
+
+ **A**:
+ - **GPU**: Recommended for training (CUDA or MPS). CPU training is very slow.
+ - **RAM**: Minimum 16GB, recommended 32GB for training
+ - **Storage**: ~5GB for models and datasets
+
+ ### Q: Why are my embeddings not aligned?
+
+ **A**: Check that:
+ 1. You're using the correct dimension ranges (0-15 for color, 16-79 for hierarchy)
+ 2. The model was trained with `alignment_weight > 0`
+ 3. The color and hierarchy models were properly loaded during training
+
+ ### Q: How do I use only the color or hierarchy subspace for search?
+
+ **A**:
+ ```python
+ # Extract and use only color embeddings
+ text_color_emb = text_features[:, :16]
+ image_color_emb = image_features[:, :16]
+ color_similarity = F.cosine_similarity(text_color_emb, image_color_emb)
+
+ # Extract and use only hierarchy embeddings
+ text_hierarchy_emb = text_features[:, 16:80]
+ image_hierarchy_emb = image_features[:, 16:80]
+ hierarchy_similarity = F.cosine_similarity(text_hierarchy_emb, image_hierarchy_emb)
+ ```
+
+ ### Q: Can I add more attributes beyond color and hierarchy?
+
+ **A**: Yes! The architecture is extensible:
+ 1. Train a new specialized model for your attribute
+ 2. Reserve additional dimensions in the embedding space
+ 3. Add alignment losses for these dimensions in `enhanced_contrastive_loss`
+ 4. Update `config.py` with new dimension ranges
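
One way to keep the dimension bookkeeping honest when reserving new subspaces is a small layout check; this is a sketch, and the slice-based `SUBSPACES` table plus the "fabric" attribute are hypothetical illustrations, not part of the actual `config.py`:

```python
# Illustrative subspace layout; the names and the "fabric" extension are
# hypothetical, not part of the actual config.py
SUBSPACES = {
    "color": slice(0, 16),       # dims 0-15, as in the current model
    "hierarchy": slice(16, 80),  # dims 16-79, as in the current model
    "fabric": slice(80, 112),    # hypothetical new 32-dim attribute
}

def check_layout(subspaces, total_dims=512):
    """Verify that reserved subspaces are disjoint and fit in the embedding."""
    used = []
    for name, s in subspaces.items():
        assert 0 <= s.start < s.stop <= total_dims, f"{name} out of range"
        used.extend(range(s.start, s.stop))
    assert len(used) == len(set(used)), "subspaces overlap"
    return True

check_layout(SUBSPACES)
```

Running this at startup catches overlapping or out-of-range attribute slices before any alignment loss is computed against the wrong dimensions.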
+
+ ### Q: How do I evaluate on my own dataset?
+
+ **A**:
+ 1. Format your dataset as CSV with columns: `text`, `color`, `hierarchy`, `local_image_path`
+ 2. Update `config.local_dataset_path` in `config.py`
+ 3. Run the evaluation: `python evaluation/main_model_evaluation.py`
+
+ ### Q: Training loss is decreasing but validation loss is increasing. What should I do?
+
+ **A**: This indicates overfitting. Try:
+ - Increase `weight_decay` (e.g., from 5e-4 to 1e-3)
+ - Reduce `alignment_weight` (e.g., from 0.2 to 0.1)
+ - Increase the dataset size (`subset_size`)
+ - Add more data augmentation in `CustomDataset`
+ - Enable or increase early stopping patience
+
+ ### Q: Can I fine-tune GAP-CLIP on a specific domain?
+
+ **A**: Yes! Load the checkpoint and continue training:
+ ```python
+ checkpoint = torch.load('models/gap_clip.pth')
+ model.load_state_dict(checkpoint['model_state_dict'])
+ # Continue training with your domain-specific data
+ ```

 ## 🤝 Contributing

 Contributions are welcome! Feel free to open an issue or a pull request.