File size: 7,710 Bytes
79a1985 f2f5c64 79a1985 f2f5c64 a48e661 f2f5c64 a48e661 f2f5c64 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 |
"""
Centralized Configuration Module for GAP-CLIP Project
======================================================
This module contains all configuration parameters, file paths, and constants
used throughout the GAP-CLIP project. It provides a single source of truth
for model paths, embedding dimensions, dataset locations, and device settings.
Key Configuration Categories:
- Model paths: Paths to trained model checkpoints
- Data paths: Dataset locations and CSV files
- Embedding dimensions: Size of color and hierarchy embeddings
- Column names: CSV column identifiers for data loading
- Device: Hardware accelerator configuration (CUDA, MPS, or CPU)
Usage:
>>> import config
>>> model_path = config.main_model_path
>>> device = config.device
>>> color_dim = config.color_emb_dim
Author: Lea Attia Sarfati
Project: GAP-CLIP (Guaranteed Attribute Positioning in CLIP Embeddings)
"""
from typing import Final
import torch
import os
# =============================================================================
# MODEL PATHS
# =============================================================================
# Paths to trained model checkpoints used for inference and fine-tuning.
# All paths are relative to the project root — TODO confirm working directory.
#: Path to the trained color model checkpoint (ColorCLIP)
#: This model extracts 16-dimensional color embeddings from images and text
color_model_path: Final[str] = "models/color_model.pt"
#: Path to the trained hierarchy model checkpoint
#: This model extracts 64-dimensional category embeddings (e.g., dress, shirt, shoes)
hierarchy_model_path: Final[str] = "models/hierarchy_model.pth"
#: Path to the main GAP-CLIP model checkpoint
#: This is the primary 512-dimensional CLIP model with aligned color and hierarchy subspaces
main_model_path: Final[str] = "models/gap_clip.pth"
#: Path to the tokenizer vocabulary JSON file
#: Used by the color model's text encoder for tokenization
#: NOTE(review): British spelling "tokeniser" kept — renaming would break callers
tokeniser_path: Final[str] = "tokenizer_vocab.json"
# =============================================================================
# DATASET PATHS
# =============================================================================
# Paths to training, validation, and test datasets
#: Path to the main training dataset with local image paths
#: CSV format with columns: text, color, hierarchy, local_image_path
local_dataset_path: Final[str] = "data/data_with_local_paths.csv"
#: Path to Fashion-MNIST test dataset for evaluation
#: Used for zero-shot classification benchmarking
fashion_mnist_test_path: Final[str] = "data/fashion-mnist_test.csv"
#: Directory containing image files for the dataset
images_dir: Final[str] = "data/images"
#: Directory for evaluation scripts and results
evaluation_directory: Final[str] = "evaluation/"
# =============================================================================
# CSV COLUMN NAMES
# =============================================================================
# Column identifiers used in dataset CSV files. Keep these in sync with the
# headers of the CSV files listed under DATASET PATHS above.
#: Column name for local file paths to images
column_local_image_path: Final[str] = "local_image_path"
#: Column name for image URLs (when using remote images)
column_url_image: Final[str] = "image_url"
#: Column name for text descriptions of fashion items
text_column: Final[str] = "text"
#: Column name for color labels (e.g., "red", "blue", "black")
color_column: Final[str] = "color"
#: Column name for hierarchy/category labels (e.g., "dress", "shirt", "shoes")
hierarchy_column: Final[str] = "hierarchy"
# =============================================================================
# EMBEDDING DIMENSIONS
# =============================================================================
# Dimensionality of various embedding spaces
#: Dimension of color embeddings (positions 0-15 in main model)
#: These dimensions are explicitly trained to encode color information
color_emb_dim: Final[int] = 16
#: Dimension of hierarchy embeddings (positions 16-79 in main model)
#: These dimensions are explicitly trained to encode category information
hierarchy_emb_dim: Final[int] = 64
#: Total dimension of main CLIP embeddings
#: Structure: [color (16) | hierarchy (64) | general CLIP (432)] = 512
main_emb_dim: Final[int] = 512
#: Dimension of general CLIP embeddings (remaining dimensions after color and hierarchy)
#: Derived, not hard-coded, so it stays consistent if the subspace sizes change
general_clip_dim: Final[int] = main_emb_dim - color_emb_dim - hierarchy_emb_dim
# =============================================================================
# DEVICE CONFIGURATION
# =============================================================================
# Hardware accelerator settings for model training and inference
def get_device() -> torch.device:
    """
    Detect and return the best available hardware accelerator.

    Probes backends in priority order — CUDA (NVIDIA GPU) first, then
    MPS (Apple Silicon) — and falls back to the CPU when neither is
    available.

    Returns:
        torch.device: The device to use for tensor operations.

    Examples:
        >>> device = get_device()
        >>> model = model.to(device)
    """
    # Ordered (name, availability-probe) pairs: first hit wins.
    backends = (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    )
    for name, is_available in backends:
        if is_available():
            return torch.device(name)
    return torch.device("cpu")
#: Primary device for model operations, resolved once at import time.
#: Automatically selects CUDA > MPS > CPU (see get_device above).
device: torch.device = get_device()
# =============================================================================
# TRAINING HYPERPARAMETERS (DEFAULT VALUES)
# =============================================================================
# Default training parameters - can be overridden in training scripts.
# These are defaults only; nothing in this module enforces them.
#: Default batch size for training
DEFAULT_BATCH_SIZE: Final[int] = 32
#: Default number of training epochs
DEFAULT_NUM_EPOCHS: Final[int] = 20
#: Default learning rate for optimizer
DEFAULT_LEARNING_RATE: Final[float] = 1.5e-5
#: Default temperature for contrastive loss (scales logits before softmax)
DEFAULT_TEMPERATURE: Final[float] = 0.09
#: Default weight for alignment loss
DEFAULT_ALIGNMENT_WEIGHT: Final[float] = 0.2
#: Default weight decay for L2 regularization
DEFAULT_WEIGHT_DECAY: Final[float] = 5e-4
# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
def validate_paths() -> bool:
    """
    Validate that all critical model and tokenizer files exist on disk.

    Checks every checkpoint path plus the tokenizer vocabulary. When any
    file is missing, a single warning listing all missing paths is
    printed and False is returned; this function never raises, leaving
    the caller to decide how to react.

    (Fixed: the previous docstring claimed a FileNotFoundError was
    raised for missing model files, but no exception was ever raised.)

    Returns:
        bool: True if all critical paths exist, False otherwise.
    """
    critical_paths = [
        color_model_path,
        hierarchy_model_path,
        main_model_path,
        tokeniser_path,
    ]
    missing_paths = [p for p in critical_paths if not os.path.exists(p)]
    if missing_paths:
        print(f"⚠️ Warning: Missing files: {', '.join(missing_paths)}")
        return False
    return True
def print_config() -> None:
    """
    Print a formatted summary of the current configuration.
    Useful for debugging and logging training runs.
    """
    separator = "=" * 80
    # Assemble the report first, then emit it line by line; the printed
    # output is identical to a sequence of individual print calls.
    summary = [
        separator,
        "GAP-CLIP Configuration",
        separator,
        f"Device: {device}",
        f"Color embedding dim: {color_emb_dim}",
        f"Hierarchy embedding dim: {hierarchy_emb_dim}",
        f"Main embedding dim: {main_emb_dim}",
        f"Main model path: {main_model_path}",
        f"Color model path: {color_model_path}",
        f"Hierarchy model path: {hierarchy_model_path}",
        f"Dataset path: {local_dataset_path}",
        separator,
    ]
    for line in summary:
        print(line)
# Print and validate the configuration when this module is executed as a
# script (python config.py). Nothing here runs on a plain import; the
# validate_paths() result is intentionally ignored — missing files only
# produce the printed warning.
if __name__ == "__main__":
    print_config()
    validate_paths()