# GAP-CLIP: Guaranteed Attribute Positioning in CLIP Embeddings

[Python 3.8+](https://www.python.org/downloads/)
[PyTorch 2.0+](https://pytorch.org/)
[License: MIT](https://opensource.org/licenses/MIT)
[Model on Hugging Face](https://huggingface.co/Leacb4/gap-clip)

**Advanced multimodal fashion search model combining specialized color embeddings, hierarchical category embeddings, and CLIP for intelligent fashion item retrieval.**

---
### Try It Now (< 2 minutes)

```python
import torch
import torch.nn.functional as F

from example_usage import load_models_from_hf

# Load pre-trained models from Hugging Face
models = load_models_from_hf("Leacb4/gap-clip")

# Search with text
text_query = "red summer dress"
text_inputs = models['processor'](text=[text_query], padding=True, return_tensors="pt")
text_inputs = {k: v.to(models['device']) for k, v in text_inputs.items()}

with torch.no_grad():
    text_features = models['main_model'](**text_inputs).text_embeds

# Extract specialized embeddings
color_emb = text_features[:, :16]       # Color (dims 0-15)
category_emb = text_features[:, 16:80]  # Category (dims 16-79)
general_emb = text_features[:, 80:]     # General CLIP (dims 80-511)

print("✅ Successfully extracted embeddings!")
print(f"   Color: {color_emb.shape}, Category: {category_emb.shape}, General: {general_emb.shape}")
```
---

## Overview

This project implements an advanced fashion search system based on CLIP, built from three specialized models:

1. **Color Model** (`color_model.pt`): Specialized CLIP model that extracts compact color embeddings from text and images
2. **Hierarchy Model** (`hierarchy_model.pth`): Model that classifies and encodes the compact categorical hierarchy of fashion items
3. **Main CLIP Model** (`gap_clip.pth`): Main CLIP model based on LAION, trained jointly with the color and hierarchy embeddings

### Architecture

The main model's embedding structure:

- **Dimensions 0-15** (16 dims): Color embeddings aligned with the specialized color model
- **Dimensions 16-79** (64 dims): Hierarchy embeddings aligned with the specialized hierarchy model
- **Dimensions 80-511** (432 dims): Standard CLIP embeddings for general visual-semantic understanding

**Total: 512 dimensions** per embedding (text or image)

**Key Innovation**: The first 80 dimensions are explicitly trained to align with the specialized models through direct MSE and cosine-similarity losses, ensuring guaranteed attribute positioning (GAP) while the remaining dimensions retain full CLIP capabilities.

### Loss Functions

**1. Enhanced Contrastive Loss** (`enhanced_contrastive_loss`):

Combines multiple objectives:
- **Original Triple Loss**: Text-image-attributes contrastive learning
- **Color Alignment**: Forces dims 0-15 to match the color model embeddings
- **Hierarchy Alignment**: Forces dims 16-79 to match the hierarchy model embeddings
- **Reference Loss**: Optional regularization to stay close to base CLIP

**2. Alignment Components**:
```python
# Color alignment (text & image)
color_text_mse = F.mse_loss(main_color_dims, color_model_emb)
color_text_cosine = 1 - F.cosine_similarity(main_color_dims, color_model_emb).mean()

# Hierarchy alignment (text & image)
hierarchy_text_mse = F.mse_loss(main_hierarchy_dims, hierarchy_model_emb)
hierarchy_text_cosine = 1 - F.cosine_similarity(main_hierarchy_dims, hierarchy_model_emb).mean()

# Combined alignment
alignment_loss = (color_alignment + hierarchy_alignment) / 2
```

**3. Final Loss**:
```python
total_loss = (1 - α) * contrastive_loss + α * alignment_loss + β * reference_loss
```
Where:
- α (`alignment_weight`) = 0.2: Balances the contrastive and alignment objectives
- β (`reference_weight`) = 0.1: Keeps the text space close to base CLIP
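
As a concrete illustration, the alignment and final losses above can be reproduced on toy tensors. The shapes follow the dimension layout described in this README; the contrastive and reference terms here are placeholders, not the real losses from `main_model.py`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-ins for a batch of 4 embeddings (shapes from the README's layout)
main_text = torch.randn(4, 512)        # main model text embeddings
color_target = torch.randn(4, 16)      # specialized color model output
hierarchy_target = torch.randn(4, 64)  # specialized hierarchy model output

def alignment_loss(main_emb, color_tgt, hier_tgt):
    # Slice the guaranteed attribute subspaces out of the 512-D embedding
    color_dims = main_emb[:, :16]
    hier_dims = main_emb[:, 16:80]
    color_align = F.mse_loss(color_dims, color_tgt) + (1 - F.cosine_similarity(color_dims, color_tgt).mean())
    hier_align = F.mse_loss(hier_dims, hier_tgt) + (1 - F.cosine_similarity(hier_dims, hier_tgt).mean())
    return (color_align + hier_align) / 2

align = alignment_loss(main_text, color_target, hierarchy_target)

# Combine with placeholder contrastive and reference losses
contrastive, reference = torch.tensor(1.0), torch.tensor(0.5)
alpha, beta = 0.2, 0.1
total = (1 - alpha) * contrastive + alpha * align + beta * reference
print(total.item())
```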

## 🚀 Installation

### Prerequisites

- Python 3.8 or higher
- PyTorch 2.0+ (CUDA optional but recommended for GPU support)
- 16 GB RAM minimum (32 GB recommended for training)
- ~5 GB disk space for models and data

### Method 1: Install as Package (Recommended)

```bash
# Clone repository
git clone https://github.com/Leacb4/gap-clip.git
cd gap-clip

# Install in development mode
pip install -e .

# Or install with optional dependencies
pip install -e ".[dev]"     # With development tools
pip install -e ".[optuna]"  # With hyperparameter optimization
pip install -e ".[all]"     # With all extras
```

### Method 2: Install Dependencies Only

```bash
pip install -r requirements.txt
```

### Method 3: From Hugging Face (Model Only)

```python
from example_usage import load_models_from_hf

models = load_models_from_hf("Leacb4/gap-clip")
```

### Main Dependencies

| Package | Version | Purpose |
|---------|---------|---------|
| `torch` | ≥2.0.0 | Deep learning framework |
| `transformers` | ≥4.30.0 | Hugging Face CLIP models |
| `huggingface-hub` | ≥0.16.0 | Model download/upload |
| `pillow` | ≥9.0.0 | Image processing |
| `pandas` | ≥1.5.0 | Data manipulation |
| `scikit-learn` | ≥1.3.0 | ML metrics & evaluation |
| `tqdm` | ≥4.65.0 | Progress bars |
| `matplotlib` | ≥3.7.0 | Visualization |

### Verify Installation

```python
# Test that everything works
import config

config.print_config()

# Check device
print(f"Using device: {config.device}")
```

## 📁 Project Structure

```
.
├── color_model.py                       # Color model architecture and training
├── hierarchy_model.py                   # Hierarchy model architecture and training
├── main_model.py                        # Main GAP-CLIP model with enhanced loss functions
├── train_main_model.py                  # Training script with optimized hyperparameters
├── config.py                            # Configuration for paths and parameters
├── example_usage.py                     # Usage examples and Hugging Face loading
├── tokenizer_vocab.json                 # Tokenizer vocabulary for color model
├── models/
│   ├── color_model.pt                   # Trained color model checkpoint
│   ├── hierarchy_model.pth              # Trained hierarchy model checkpoint
│   └── gap_clip.pth                     # Main GAP-CLIP model checkpoint
├── evaluation/                          # Comprehensive evaluation scripts
│   ├── main_model_evaluation.py         # Main evaluation with 3 datasets
│   ├── evaluate_color_embeddings.py     # Color embedding analysis
│   ├── hierarchy_evaluation.py          # Hierarchy classification tests
│   ├── fashion_search.py                # Interactive search demos
│   ├── 0_shot_classification.py         # Zero-shot classification
│   ├── heatmap_color_similarities.py    # Color similarity visualization
│   ├── tsne_images.py                   # t-SNE embedding visualization
│   └── basic_test_generalized.py        # Basic functionality tests
├── data/
│   ├── data_with_local_paths.csv        # Training dataset with annotations
│   ├── fashion-mnist_test.csv           # Fashion-MNIST evaluation data
│   ├── download_data.py                 # Dataset download utilities
│   └── get_csv_from_chunks.py           # Dataset preprocessing
├── optuna/                              # Hyperparameter optimization
│   ├── optuna_optimisation.py           # Optuna optimization script
│   ├── optuna_study.pkl                 # Saved optimization study
│   ├── optuna_results.txt               # Best hyperparameters
│   ├── optuna_optimization_history.png  # Optimization visualization
│   ├── optuna_param_importances.png     # Parameter importance plot
│   └── optuna_guide.md                  # Optuna usage guide
├── upload_hf/                           # Hugging Face Hub upload utilities
│   ├── upload_to_huggingface.py         # Upload script
│   └── README_UPLOAD.md                 # Complete upload guide
├── requirements.txt                     # Python dependencies
├── setup.py                             # Package installation
├── __init__.py                          # Package initialization
├── .gitignore                           # Git ignore rules
└── README.md                            # This documentation
```

### Key Files Description

**Core Model Files**:
- `color_model.py`: ResNet18-based color embedding model (16 dims)
- `hierarchy_model.py`: ResNet18-based hierarchy classification model (64 dims)
- `main_model.py`: GAP-CLIP implementation with the enhanced contrastive loss
- `train_main_model.py`: Training with Optuna-optimized hyperparameters

**Configuration & Setup**:
- `config.py`: Type hints, automatic device detection, validation utilities
- `setup.py`: Package installer with CLI entry points
- `__init__.py`: Package initialization for easy imports
- `.gitignore`: Comprehensive Git ignore rules
- `requirements.txt`: Dependencies organized with comments and categories
- `tokenizer_vocab.json`: Vocabulary for the color model's text encoder

**Upload Tools**:
- `upload_hf/upload_to_huggingface.py`: Upload script featuring:
  - Object-oriented design
  - Multiple authentication methods
  - Category-based uploads (models, code, docs, etc.)
  - Progress tracking
  - Automatic model card generation
  - Detailed error handling
- `upload_hf/README_UPLOAD.md`: Complete upload guide

**Evaluation Suite**:
- `evaluation/main_model_evaluation.py`: Comprehensive evaluation across Fashion-MNIST, KAGL, and local datasets
- `evaluation/run_all_evaluations.py`: Automated evaluation runner with reports
- Other scripts provide specialized analysis (color, hierarchy, search, t-SNE, etc.)

**Training Data**:
- `data_with_local_paths.csv`: Main training dataset with text, color, hierarchy, and image paths
- `fashion-mnist_test.csv`: Evaluation dataset for zero-shot generalization testing

**CLI Commands**:

After installation with `pip install -e .`, you can use:
```bash
gap-clip-train    # Start training
gap-clip-example  # Run usage examples
```

## 🔧 Configuration

Main parameters are defined in `config.py`:

```python
import config

# Automatic device detection (CUDA > MPS > CPU)
device = config.device  # Selects the best available device

# Embedding dimensions
color_emb_dim = config.color_emb_dim          # 16 dims (0-15)
hierarchy_emb_dim = config.hierarchy_emb_dim  # 64 dims (16-79)
main_emb_dim = config.main_emb_dim            # 512 dims total

# Default training hyperparameters
batch_size = config.DEFAULT_BATCH_SIZE        # 32
learning_rate = config.DEFAULT_LEARNING_RATE  # 1.5e-5
temperature = config.DEFAULT_TEMPERATURE      # 0.09

# Utility functions
config.print_config()    # Print current configuration
config.validate_paths()  # Validate that all files exist
```

### Features of config.py

- **Automatic device detection**: Selects CUDA > MPS > CPU automatically
- **Type hints**: Full type annotations for better IDE support
- **Validation**: `validate_paths()` checks that all model files exist
- **Print utility**: `print_config()` shows current settings
- **Constants**: Pre-defined default hyperparameters
- **Documentation**: Comprehensive docstrings for all settings

### Model Paths

Default paths configured in `config.py`:
- `models/color_model.pt`: Trained color model checkpoint
- `models/hierarchy_model.pth`: Trained hierarchy model checkpoint
- `models/gap_clip.pth`: Main GAP-CLIP model checkpoint
- `tokenizer_vocab.json`: Tokenizer vocabulary for the color model
- `data_with_local_paths.csv`: Training/validation dataset

### Dataset Format

The training dataset CSV should contain:
- `text`: Text description of the fashion item
- `color`: Color label (e.g., "red", "blue", "black")
- `hierarchy`: Category label (e.g., "dress", "shirt", "shoes")
- `local_image_path`: Path to the image file

Example:
```csv
text,color,hierarchy,local_image_path
"red summer dress with floral pattern",red,dress,data/images/001.jpg
"blue denim jeans casual style",blue,jeans,data/images/002.jpg
```
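
A quick sanity check of this schema can be sketched with pandas. The frame below is inline for illustration; in practice you would substitute `pd.read_csv("data_with_local_paths.csv")`:

```python
import pandas as pd

# Illustrative stand-in for the training CSV
df = pd.DataFrame({
    "text": ["red summer dress with floral pattern"],
    "color": ["red"],
    "hierarchy": ["dress"],
    "local_image_path": ["data/images/001.jpg"],
})

# Verify that all required columns are present
required = {"text", "color", "hierarchy", "local_image_path"}
missing = required - set(df.columns)
assert not missing, f"Missing columns: {missing}"
print(df["color"].value_counts().to_dict())
```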

## 📦 Usage

### 1. Load Models from Hugging Face

If your models are already uploaded to Hugging Face:

```python
from example_usage import load_models_from_hf

# Load all models
models = load_models_from_hf("your-username/your-model")

color_model = models['color_model']
hierarchy_model = models['hierarchy_model']
main_model = models['main_model']
processor = models['processor']
device = models['device']
```

### 2. Text Search

```python
import torch

# Prepare text query
text_query = "red dress"
text_inputs = processor(text=[text_query], padding=True, return_tensors="pt")
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

# Get main model embeddings
with torch.no_grad():
    outputs = main_model(**text_inputs)
    text_features = outputs.text_embeds

# Get specialized embeddings
color_emb = color_model.get_text_embeddings([text_query])
hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])
```

### 3. Image Search

```python
import torch
from PIL import Image

# Load image
image = Image.open("path/to/image.jpg").convert("RGB")
image_inputs = processor(images=[image], return_tensors="pt")
image_inputs = {k: v.to(device) for k, v in image_inputs.items()}

# Get embeddings
with torch.no_grad():
    outputs = main_model(**image_inputs)
    image_features = outputs.image_embeds
```

### 4. Using the Example Script

`example_usage.py` provides ready-to-use examples for loading and using GAP-CLIP:

```bash
# Load from Hugging Face and search with text
python example_usage.py \
    --repo-id Leacb4/gap-clip \
    --text "red summer dress"

# Search with image
python example_usage.py \
    --repo-id Leacb4/gap-clip \
    --image path/to/image.jpg

# Both text and image
python example_usage.py \
    --repo-id Leacb4/gap-clip \
    --text "blue denim jeans" \
    --image path/to/image.jpg
```

This script demonstrates:
- Loading models from the Hugging Face Hub
- Extracting text and image embeddings
- Accessing the color and hierarchy subspaces
- Measuring alignment quality against the specialized models

## 🎯 Model Training

### Train the Color Model

```python
from color_model import ColorCLIP, train_color_model

# Configuration
model = ColorCLIP(vocab_size=10000, embedding_dim=16)
# ... dataset configuration ...

# Training
train_color_model(model, train_loader, val_loader, num_epochs=20)
```

### Train the Hierarchy Model

```python
from hierarchy_model import Model as HierarchyModel, train_hierarchy_model

# Configuration
model = HierarchyModel(num_hierarchy_classes=10, embed_dim=64)
# ... dataset configuration ...

# Training
train_hierarchy_model(model, train_loader, val_loader, num_epochs=20)
```

### Train the Main CLIP Model

The main model trains against both specialized models using the enhanced contrastive loss.

**Option 1: Train with optimized hyperparameters (recommended)**:
```bash
python train_main_model.py
```
This uses hyperparameters optimized with Optuna (Trial 29, validation loss ~0.1129).

**Option 2: Train with default parameters**:
```bash
python main_model.py
```
This runs the main training loop with manually configured parameters.

**Default Training Parameters** (in `main_model.py`):
- `num_epochs = 20`: Number of training epochs
- `learning_rate = 1.5e-5`: Learning rate with the AdamW optimizer
- `temperature = 0.09`: Temperature for softer contrastive learning
- `alignment_weight = 0.2`: Weight of the color/hierarchy alignment loss
- `weight_decay = 5e-4`: L2 regularization to prevent overfitting
- `batch_size = 32`: Batch size
- `subset_size = 20000`: Dataset size for better generalization
- `reference_weight = 0.1`: Weight of the base-CLIP regularization

**Enhanced Loss Function**:

Training uses `enhanced_contrastive_loss`, which combines:

1. **Triple Contrastive Loss** (weighted):
   - Text-Image alignment (70%)
   - Text-Attributes alignment (15%)
   - Image-Attributes alignment (15%)

2. **Direct Alignment Loss** (combines color & hierarchy):
   - MSE loss between the main model's color dims (0-15) and the color model embeddings
   - MSE loss between the main model's hierarchy dims (16-79) and the hierarchy model embeddings
   - Cosine-similarity losses for both color and hierarchy
   - Applied to both text and image embeddings

3. **Reference Model Loss** (optional):
   - Keeps text embeddings close to base CLIP
   - Improves cross-domain generalization
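
The 70/15/15 weighting can be sketched with a standard symmetric InfoNCE loss. This is a minimal sketch on random tensors; the actual `enhanced_contrastive_loss` lives in `main_model.py`:

```python
import torch
import torch.nn.functional as F

def clip_contrastive(a, b, temperature=0.09):
    """Symmetric InfoNCE between two batches of embeddings."""
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    logits = a @ b.T / temperature
    targets = torch.arange(a.size(0))
    # Average of the a->b and b->a cross-entropy terms
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets)) / 2

torch.manual_seed(0)
text, image, attrs = (torch.randn(8, 512) for _ in range(3))

# Weighted triple loss: text-image 70%, text-attributes 15%, image-attributes 15%
triple = (0.7 * clip_contrastive(text, image)
          + 0.15 * clip_contrastive(text, attrs)
          + 0.15 * clip_contrastive(image, attrs))
print(triple.item())
```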

**Training Features**:
- Enhanced data augmentation (rotation, color jitter, blur, affine transforms)
- Gradient clipping (`max_norm=1.0`) to prevent exploding gradients
- `ReduceLROnPlateau` scheduler (patience=3, factor=0.5)
- Early stopping (patience=7)
- Automatic best-model saving with checkpoints
- Detailed metrics logging (alignment losses, cosine similarities)
- Overfitting detection and warnings
- Training-curve visualization with 3 plots (losses, overfitting gap, comparison)
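
A minimal sketch of how these features fit together in a training loop; the model and losses below are stand-ins, not the real code in `train_main_model.py`:

```python
import torch

model = torch.nn.Linear(8, 8)  # stand-in for GAP-CLIP
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-5, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)

best_val, patience, bad_epochs = float("inf"), 7, 0
for epoch in range(20):
    loss = model(torch.randn(4, 8)).pow(2).mean()  # placeholder training loss
    optimizer.zero_grad()
    loss.backward()
    # Gradient clipping to prevent exploding gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    val_loss = loss.item()  # placeholder validation loss
    scheduler.step(val_loss)  # ReduceLROnPlateau reacts to validation loss
    if val_loss < best_val:
        # Keep the best checkpoint (saved to disk in the real script)
        best_val, bad_epochs = val_loss, 0
        best_state = {k: v.clone() for k, v in model.state_dict().items()}
    else:
        bad_epochs += 1
        if bad_epochs >= patience:  # early stopping
            break
```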

### Hyperparameter Optimization

The project includes Optuna-based hyperparameter optimization in `optuna/`:

```bash
cd optuna
python optuna_optimisation.py
```

This optimizes:
- Learning rate
- Temperature for the contrastive loss
- Alignment weight
- Weight decay

Results are saved in `optuna_study.pkl`, with visualizations in `optuna_optimization_history.png` and `optuna_param_importances.png`.

The best hyperparameters found by Optuna are used in `train_main_model.py`.

## 📊 Models

### Color Model

- **Architecture**: ResNet18 (image encoder) + Embedding (text encoder)
- **Embedding dimension**: 16
- **Trained on**: Fashion data with color annotations
- **Usage**: Extract color embeddings from text or images

### Hierarchy Model

- **Architecture**: ResNet18 (image encoder) + Embedding (hierarchy encoder)
- **Embedding dimension**: 64
- **Hierarchy classes**: shirt, dress, pant, shoe, bag, etc.
- **Usage**: Classify and encode the categorical hierarchy

### Main CLIP Model (GAP-CLIP)

- **Architecture**: CLIP ViT-B/32 (LAION)
- **Base Model**: `laion/CLIP-ViT-B-32-laion2B-s34B-b79K`
- **Training Approach**: Enhanced contrastive loss with direct attribute alignment
- **Embedding Dimensions**: 512 total
  - Color subspace: dims 0-15 (16 dims)
  - Hierarchy subspace: dims 16-79 (64 dims)
  - General CLIP: dims 80-511 (432 dims)
- **Training Dataset**: 20,000 fashion items with color and hierarchy annotations
- **Validation Split**: 80/20 train-validation split
- **Optimizer**: AdamW with weight decay (5e-4)
- **Best Checkpoint**: Automatically saved based on validation loss
- **Features**:
  - Multimodal text-image search
  - Guaranteed attribute positioning (GAP) in specific dimensions
  - Direct alignment with the specialized color and hierarchy models
  - Maintains general CLIP capabilities for cross-domain tasks
  - Reduced overfitting through augmentation and regularization

## 🔍 Advanced Usage Examples

### Search with Combined Embeddings

```python
import torch
import torch.nn.functional as F

# Text query
text_query = "red dress"
text_inputs = processor(text=[text_query], padding=True, return_tensors="pt")
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

# Main model embeddings
with torch.no_grad():
    outputs = main_model(**text_inputs)
    text_features = outputs.text_embeds  # Shape: [1, 512]

# Extract specialized embeddings from the main model
main_color_emb = text_features[:, :16]        # Color dimensions (0-15)
main_hierarchy_emb = text_features[:, 16:80]  # Hierarchy dimensions (16-79)
main_clip_emb = text_features[:, 80:]         # General CLIP dimensions (80-511)

# Embeddings from the specialized models
color_emb = color_model.get_text_embeddings([text_query])
hierarchy_emb = hierarchy_model.get_text_embeddings([text_query])

# Measure alignment between the main model's subspaces and the expert models
color_similarity = F.cosine_similarity(color_emb, main_color_emb, dim=1)
hierarchy_similarity = F.cosine_similarity(hierarchy_emb, main_hierarchy_emb, dim=1)

# Search strategies:
# 1. Use full embeddings for general search
# 2. Use the color subspace for color-specific search
# 3. Use the hierarchy subspace for category search
# 4. Use a weighted combination of subspaces
```
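
Strategy 4 (a weighted combination of subspaces) can be sketched as follows. The subspace boundaries come from this README; the weights are purely illustrative:

```python
import torch
import torch.nn.functional as F

def weighted_subspace_similarity(query, gallery, w_color=0.5, w_hier=0.3, w_general=0.2):
    """Cosine similarity computed per GAP-CLIP subspace, then mixed with user weights.

    query: [1, 512] text embedding; gallery: [N, 512] image embeddings.
    """
    sims = []
    for lo, hi, w in [(0, 16, w_color), (16, 80, w_hier), (80, 512, w_general)]:
        q = F.normalize(query[:, lo:hi], dim=-1)
        g = F.normalize(gallery[:, lo:hi], dim=-1)
        sims.append(w * (q @ g.T))  # per-subspace cosine similarity
    return sum(sims).squeeze(0)  # Shape: [N]

torch.manual_seed(0)
scores = weighted_subspace_similarity(torch.randn(1, 512), torch.randn(100, 512))
print(scores.shape)
```

Raising `w_color` biases retrieval toward color matches; raising `w_hier` biases it toward category matches, without recomputing any embeddings.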

A complete text-to-image search pipeline:

```python
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

# Step 1: Pre-compute image embeddings (do this once)
image_paths = [...]  # List of image paths
image_features_list = []

print("Computing image embeddings...")
for img_path in tqdm(image_paths):
    image = Image.open(img_path).convert("RGB")
    image_inputs = processor(images=[image], return_tensors="pt")
    image_inputs = {k: v.to(device) for k, v in image_inputs.items()}

    with torch.no_grad():
        outputs = main_model(**image_inputs)
        features = outputs.image_embeds  # Shape: [1, 512]
    image_features_list.append(features.cpu())

# Stack all features
image_features = torch.cat(image_features_list, dim=0)  # Shape: [N, 512]

# Step 2: Search with a text query
query = "red dress"
text_inputs = processor(text=[query], padding=True, return_tensors="pt")
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

with torch.no_grad():
    outputs = main_model(**text_inputs)
    text_features = outputs.text_embeds  # Shape: [1, 512]

# Step 3: Calculate similarities
# Normalize embeddings for cosine similarity
text_features_norm = F.normalize(text_features, dim=-1)
image_features_norm = F.normalize(image_features.to(device), dim=-1)

# Compute cosine similarities
similarities = (text_features_norm @ image_features_norm.T).squeeze(0)  # Shape: [N]

# Step 4: Get top-k results
top_k = 10
top_scores, top_indices = similarities.topk(top_k, largest=True)

# Step 5: Display results
for i, (score, idx) in enumerate(zip(top_scores, top_indices)):
    print(f"{i+1}. {image_paths[idx]} (similarity: {score.item():.4f})")

# Optional: restrict the query to a subspace
query_hierarchy_emb = text_features[:, 16:80]
# Use these for more targeted search
```

## 📝 Evaluation

### Comprehensive Model Evaluation

`evaluation/main_model_evaluation.py` evaluates the model on three datasets:

1. **Fashion-MNIST** (~10,000 samples): Grayscale fashion items
2. **KAGL Marqo** (Hugging Face dataset): Real fashion images with metadata
3. **Local Validation Dataset**: Custom validation set with local images

**Evaluation Metrics**:

For each dataset, the evaluation measures:

1. **Color Embeddings Performance** (dimensions 0-15):
   - Nearest Neighbor (NN) Accuracy: Classification accuracy using the nearest neighbor
   - Centroid Accuracy: Classification using cluster centroids
   - Separation Score: How well color embeddings separate different classes

2. **Hierarchy Embeddings Performance** (dimensions 16-79):
   - Nearest Neighbor (NN) Accuracy: Classification accuracy for fashion categories
   - Centroid Accuracy: Cluster-based classification
   - Separation Score: Class-separation quality

3. **Full Embeddings Performance** (all 512 dimensions):
   - Evaluates the complete embedding space
   - Compares subspace (color/hierarchy) vs. full-embedding effectiveness

**Baseline Comparison**:

The evaluation includes a comparison against `patrickjohncyh/fashion-clip`:
- Direct performance comparison on the same datasets
- Improvement-metric calculation
- Statistical significance analysis

**Key Evaluation Functions**:
- `evaluate_fashion_mnist()`: Test on the Fashion-MNIST dataset
- `evaluate_kaggle_marqo()`: Test on real fashion images
- `evaluate_local_validation()`: Test on the local validation set
- `evaluate_baseline_fashion_mnist()`: Baseline model on Fashion-MNIST
- `evaluate_baseline_kaggle_marqo()`: Baseline model on KAGL
- `evaluate_full_embeddings()`: Test the complete 512-D space
- `analyze_baseline_vs_trained_performance()`: Comparative analysis
- `compare_subspace_vs_full_embeddings()`: Subspace effectiveness

**Visualization Outputs** (saved in the analysis directory):
- Confusion matrices for color and hierarchy classification
- t-SNE projections of embeddings
- Similarity heatmaps
- Performance-comparison charts

**Other Evaluation Scripts**:
- `evaluate_color_embeddings.py`: Focused color-embeddings evaluation
- `fashion_search.py`: Interactive fashion-search tests
- `hierarchy_evaluation.py`: Hierarchy-classification analysis
- `0_shot_classification.py`: Zero-shot classification tests
| 700 |
-
|
| 701 |
-
## 📊 Performance & Results
|
| 702 |
-
|
| 703 |
-
The evaluation framework (`main_model_evaluation.py`) tests the model across three datasets with comparison to a baseline fashion CLIP model.
|
| 704 |
-
|
| 705 |
-
### Evaluation Metrics

**Color Classification** (dimensions 0-15):
- Nearest-neighbor accuracy
- Centroid-based accuracy
- Separation score (class separability)

**Hierarchy Classification** (dimensions 16-79):
- Nearest-neighbor accuracy
- Centroid-based accuracy
- Separation score

**Full Embedding Quality** (all 512 dimensions):
- Tests whether the full space maintains performance
- Compares subspace vs. full-embedding effectiveness

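The centroid-based accuracy and separation score can be sketched as follows. This is an illustrative NumPy version, not the repo's implementation; the function names and the exact separation definition (inter-centroid over intra-class distance) are assumptions:

```python
import numpy as np

def centroid_accuracy(embeddings, labels):
    """Classify each embedding by its nearest class centroid (cosine similarity)."""
    emb = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    classes = sorted(set(labels))
    centroids = np.stack([emb[np.array(labels) == c].mean(axis=0) for c in classes])
    centroids /= np.linalg.norm(centroids, axis=1, keepdims=True)
    preds = [classes[i] for i in (emb @ centroids.T).argmax(axis=1)]
    return float(np.mean([p == y for p, y in zip(preds, labels)]))

def separation_score(embeddings, labels):
    """Mean inter-centroid distance over mean intra-class distance;
    higher means better-separated classes (one common definition)."""
    labels = np.array(labels)
    classes = sorted(set(labels.tolist()))
    centroids = {c: embeddings[labels == c].mean(axis=0) for c in classes}
    intra = np.mean([np.linalg.norm(e - centroids[y]) for e, y in zip(embeddings, labels)])
    inter = np.mean([np.linalg.norm(centroids[a] - centroids[b])
                     for i, a in enumerate(classes) for b in classes[i + 1:]])
    return inter / max(intra, 1e-12)
```

Both helpers apply unchanged to the color subspace (columns 0-15) or the hierarchy subspace (columns 16-79).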
### Datasets Used for Evaluation

1. **Fashion-MNIST**: 10,000 grayscale fashion item images
   - 10 categories (T-shirt, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle boot)
   - Mapped to the model's hierarchy classes

2. **KAGL Marqo Dataset**: Real-world fashion images from Hugging Face
   - Diverse fashion items with rich metadata
   - Color and category annotations
   - Realistic product images

3. **Local Validation Set**: Custom validation dataset
   - Fashion items with local image paths
   - Annotated with colors and hierarchies
   - Domain-specific evaluation

### Comparative Analysis

The evaluation includes:
- **Baseline comparison**: GAP-CLIP vs. `patrickjohncyh/fashion-clip`
- **Subspace analysis**: Dedicated dimensions (0-79) vs. the full space (0-511)
- **Cross-dataset generalization**: Performance consistency across datasets
- **Alignment quality**: How well the specialized dimensions match the expert models

All visualizations (confusion matrices, t-SNE plots, heatmaps) are saved automatically in the analysis directory.

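As a sketch of how the subspace analysis can be exploited at search time, per-subspace cosine similarities can be blended into a single score. The weights below are illustrative assumptions, not values used by the project:

```python
import numpy as np

def cosine(a, b):
    """Row-wise cosine similarity between two (n, d) arrays."""
    return np.sum(a * b, axis=1) / (np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1))

def weighted_score(text_feat, image_feat, weights=(0.3, 0.3, 0.4)):
    """Blend color (0-15), hierarchy (16-79), and general (80-511) similarities."""
    spans = [(0, 16), (16, 80), (80, 512)]
    return sum(w * cosine(text_feat[:, lo:hi], image_feat[:, lo:hi])
               for w, (lo, hi) in zip(weights, spans))
```

Raising the color weight biases retrieval toward color matches, which is exactly the kind of control the guaranteed positioning is meant to enable.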
## 📄 Citation

If you use GAP-CLIP in your research, please cite:

```bibtex
@misc{gap-clip-2024,
  title={GAP-CLIP: Guaranteed Attribute Positioning in CLIP Embeddings for Fashion Search},
  author={Sarfati, Lea Attia},
  year={2024},
  howpublished={\url{https://huggingface.co/Leacb4/gap-clip}},
  abstract={GAP-CLIP introduces a novel training approach that guarantees specific embedding
            dimensions encode color (dims 0-15) and hierarchy (dims 16-79) information through
            direct alignment with specialized models, while maintaining full CLIP capabilities
            in the remaining dimensions (80-511).}
}
```

## ✨ Key Features

- **Guaranteed Attribute Positioning**: Specific dimensions reliably encode color and hierarchy
- **Multi-Loss Training**: Combines contrastive learning with MSE and cosine alignment losses
- **Specialized Model Alignment**: Direct supervision from expert color and hierarchy models
- **Preserved Generalization**: Maintains base CLIP capabilities for cross-domain tasks
- **Comprehensive Evaluation**: Tested across multiple datasets with baseline comparisons

## ❓ FAQ & Troubleshooting

### Q: What are the minimum hardware requirements?

**A**:
- **GPU**: Recommended for training (CUDA or MPS); CPU training is very slow.
- **RAM**: Minimum 16 GB; 32 GB recommended for training
- **Storage**: ~5 GB for models and datasets


### Q: Why are my embeddings not aligned?

**A**: Check that:
1. You are using the correct dimension ranges (0-15 for color, 16-79 for hierarchy)
2. The model was trained with `alignment_weight > 0`
3. The color and hierarchy models were properly loaded during training


### Q: How do I use only the color or hierarchy subspace for search?

**A**:
```python
# Use only the color embeddings (dimensions 0-15)
text_color_emb = text_features[:, :16]
image_color_emb = image_features[:, :16]
color_similarity = F.cosine_similarity(text_color_emb, image_color_emb)

# Use only the hierarchy embeddings (dimensions 16-79)
text_hierarchy_emb = text_features[:, 16:80]
image_hierarchy_emb = image_features[:, 16:80]
hierarchy_similarity = F.cosine_similarity(text_hierarchy_emb, image_hierarchy_emb)
```


### Q: Can I add more attributes beyond color and hierarchy?

**A**: Yes, the architecture is extensible:
1. Train a new specialized model for your attribute
2. Reserve additional dimensions in the embedding space
3. Add alignment losses for these dimensions in `enhanced_contrastive_loss`
4. Update `config.py` with the new dimension ranges

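Step 3 amounts to one extra alignment term per reserved sub-span. The NumPy sketch below only mirrors the idea of the MSE-plus-cosine alignment losses; it is not the repo's `enhanced_contrastive_loss`, and the "fabric" attribute in the comment is hypothetical:

```python
import numpy as np

def subspace_alignment_loss(clip_emb, expert_emb, span):
    """MSE + (1 - cosine) between a reserved CLIP sub-span and an expert embedding."""
    lo, hi = span
    sub = clip_emb[:, lo:hi]
    mse = np.mean((sub - expert_emb) ** 2)
    cos = np.sum(sub * expert_emb, axis=1) / (
        np.linalg.norm(sub, axis=1) * np.linalg.norm(expert_emb, axis=1))
    return mse + float(np.mean(1.0 - cos))

# Reserving dims 80-96 for a hypothetical new "fabric" attribute would add:
# total = contrastive + alignment_weight * subspace_alignment_loss(emb, fabric_emb, (80, 96))
```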

### Q: How do I evaluate on my own dataset?

**A**:
1. Format your dataset as a CSV with the columns `text`, `color`, `hierarchy`, and `local_image_path`
2. Update `config.local_dataset_path` in `config.py`
3. Run the evaluation: `python evaluation/main_model_evaluation.py`

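Producing a CSV in that format needs only the standard library; the row values below are made up for illustration:

```python
import csv

rows = [
    {"text": "red summer dress", "color": "red",
     "hierarchy": "dress", "local_image_path": "images/0001.jpg"},
    {"text": "black leather boots", "color": "black",
     "hierarchy": "boots", "local_image_path": "images/0002.jpg"},
]
with open("my_validation_set.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["text", "color", "hierarchy", "local_image_path"])
    writer.writeheader()
    writer.writerows(rows)
```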
### Q: Training loss is decreasing but validation loss is increasing. What should I do?

**A**: This indicates overfitting. Try:
- Increase `weight_decay` (e.g., from 5e-4 to 1e-3)
- Reduce `alignment_weight` (e.g., from 0.2 to 0.1)
- Increase the dataset size (`subset_size`)
- Add more data augmentation in `CustomDataset`
- Enable early stopping, or increase its patience

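The last point can be implemented with a small helper; a generic sketch, not the project's trainer code:

```python
class EarlyStopping:
    """Stop training after `patience` epochs without validation-loss improvement."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

Call `stopper.step(val_loss)` once per epoch and break out of the training loop when it returns True.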
### Q: Can I fine-tune GAP-CLIP on a specific domain?

**A**: Yes. Load the checkpoint and continue training:
```python
checkpoint = torch.load('models/gap_clip.pth')
model.load_state_dict(checkpoint['model_state_dict'])
# Continue training with your domain-specific data
```

## 📦 Upload to Hugging Face

The project includes a professional upload script for easy deployment:

```bash
cd upload_hf

# Authenticate (first time only)
huggingface-cli login

# Upload everything
python upload_to_huggingface.py --repo-id your-username/gap-clip --categories all

# Or upload specific categories
python upload_to_huggingface.py --repo-id your-username/gap-clip --categories models code

# Create a private repository
python upload_to_huggingface.py --repo-id your-username/gap-clip --private
```

**Features**:
- ✨ Object-oriented design with a `HuggingFaceUploader` class
- ✨ Multiple authentication methods (token, saved, interactive)
- ✨ Category-based uploads: models, code, docs, data, optuna, evaluation
- ✨ Progress tracking with tqdm
- ✨ Automatic model card generation
- ✨ Detailed error handling and recovery
- ✨ Upload statistics and summary

See `upload_hf/README_UPLOAD.md` for complete documentation.

## 🧪 Testing & Evaluation

### Quick Test

```bash
# Test the configuration
python -c "import config; config.print_config()"

# Test model loading
python example_usage.py --repo-id Leacb4/gap-clip --text "red dress"
```

### Full Evaluation Suite

```bash
# Run all evaluations
cd evaluation
python run_all_evaluations.py --repo-id Leacb4/gap-clip

# Results are saved to evaluation_results/ with:
# - summary.json: detailed metrics
# - summary_comparison.png: visual comparison
```

## 🐛 Known Issues & Fixes

### Fixed Issues ✨

1. **Color model image loading bug** (fixed in `color_model.py`)
   - Previous: `Image.open(config.column_local_image_path)`
   - Fixed: `Image.open(img_path)`, which correctly reads the path from the dataframe

2. **Function naming in training** (fixed in `main_model.py` and `train_main_model.py`)
   - Previous: `train_one_epoch_enhanced`
   - Fixed: `train_one_epoch`, for consistent naming

3. **Device compatibility** (improved in `config.py`)
   - The best available device is now detected and selected automatically (CUDA > MPS > CPU)

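The device-selection logic in item 3 boils down to something like the following sketch (the function name is illustrative; the actual code in `config.py` may differ):

```python
import torch

def best_device():
    """Prefer CUDA, then Apple MPS, then fall back to CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```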
## 🎓 Learning Resources

### Documentation Files

- **README.md** (this file): Complete project documentation
- **upload_hf/README_UPLOAD.md**: Upload guide for Hugging Face
- **evaluation/**: Multiple evaluation examples

### Code Examples

- **example_usage.py**: Basic usage with the Hugging Face Hub
- **evaluation/fashion_search.py**: Interactive search examples
- **evaluation/tsne_images.py**: Visualization examples

## 🤝 Contributing

We welcome contributions! Here's how:

1. **Report bugs**: Open an issue with a detailed description
2. **Suggest features**: Describe your idea in an issue
3. **Submit a PR**: Fork, create a branch, commit, and open a pull request
4. **Improve docs**: Help make the documentation clearer

### Development Setup

```bash
# Install with dev dependencies
pip install -e ".[dev]"

# Run the tests (if available)
pytest

# Format and lint the code
black .
flake8 .
```

## 📊 Project Statistics

- **Language**: Python 3.8+
- **Framework**: PyTorch 2.0+
- **Models**: 3 specialized models (color, hierarchy, main)
- **Embedding Size**: 512 dimensions
- **Training Data**: 20,000+ fashion items
- **Lines of Code**: 5,000+ (including documentation)
- **Documentation**: Comprehensive docstrings and guides

## 🔗 Links

- **Hugging Face Hub**: [Leacb4/gap-clip](https://huggingface.co/Leacb4/gap-clip)
- **GitHub**: [github.com/Leacb4/gap-clip](https://github.com/Leacb4/gap-clip)
- **Contact**: lea.attia@gmail.com

## 📧 Contact & Support

**Author**: Lea Attia Sarfati  
**Email**: lea.attia@gmail.com  
**Hugging Face**: [@Leacb4](https://huggingface.co/Leacb4)

For questions, issues, or suggestions:
- 🐛 **Bug reports**: Open an issue on GitHub
- 💡 **Feature requests**: Open an issue with a [Feature Request] tag
- 📧 **Direct contact**: lea.attia@gmail.com
- 💬 **Discussions**: Hugging Face Discussions

---

## 📜 License

This project is licensed under the MIT License; see the LICENSE file for details.

## 🙏 Acknowledgments

- The LAION team for the base CLIP model
- Hugging Face for the transformers library and model hosting
- The PyTorch team for the deep learning framework
- The Fashion-MNIST dataset creators
- All contributors and users of this project

---

**⭐ If you find this project useful, please consider giving it a star on GitHub!**

---
language: en
tags:
- fashion
- clip
- multimodal
- image-search
- text-search
- embeddings
- contrastive-learning
license: mit
datasets:
- custom
metrics:
- accuracy
- cosine-similarity
library_name: transformers
---

# GAP-CLIP: Guaranteed Attribute Positioning in CLIP Embeddings

This model is part of the GAP-CLIP project for fashion search with guaranteed attribute positioning.

## Model Description

GAP-CLIP is a multimodal search model for fashion that combines:
- **Color embeddings** (16 dimensions): Specialized for color representation
- **Hierarchy embeddings** (64 dimensions): Specialized for category classification
- **General CLIP embeddings** (432 dimensions): General visual-semantic understanding

**Total embedding size**: 512 dimensions

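The layout above can be captured with slice constants; the names are illustrative, not identifiers from the repo:

```python
# Subspace layout of the 512-dimensional embedding
COLOR_DIMS = slice(0, 16)        # 16 color dimensions
HIERARCHY_DIMS = slice(16, 80)   # 64 hierarchy dimensions
GENERAL_DIMS = slice(80, 512)    # 432 general CLIP dimensions
```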
## Quick Start

```python
from transformers import CLIPProcessor, CLIPModel
import torch

# Load the model and processor
model = CLIPModel.from_pretrained("Leacb4/gap-clip")
processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")

# Process text
text = "red dress"
inputs = processor(text=[text], return_tensors="pt", padding=True)
with torch.no_grad():
    text_features = model.get_text_features(**inputs)

# Extract subspaces
color_emb = text_features[:, :16]        # Color dimensions (0-15)
hierarchy_emb = text_features[:, 16:80]  # Hierarchy dimensions (16-79)
general_emb = text_features[:, 80:]      # General CLIP dimensions (80-511)
```

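When the subspaces are compared separately or blended into one score, it can help to L2-normalize each block on its own so no block dominates. This post-processing step is an illustration, not part of the released model (convert torch tensors with `.numpy()` first):

```python
import numpy as np

def normalize_subspaces(features):
    """L2-normalize the color, hierarchy, and general blocks independently."""
    out = np.array(features, dtype=float, copy=True)
    for lo, hi in [(0, 16), (16, 80), (80, 512)]:
        norms = np.linalg.norm(out[:, lo:hi], axis=1, keepdims=True)
        out[:, lo:hi] /= np.maximum(norms, 1e-12)
    return out
```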
## Citation

```bibtex
@misc{gap-clip-2024,
  title={GAP-CLIP: Guaranteed Attribute Positioning in CLIP Embeddings for Fashion Search},
  author={Sarfati, Lea Attia},
  year={2024},
  url={https://huggingface.co/Leacb4/gap-clip}
}
```

## License

MIT License. See the LICENSE file for details.