File size: 7,710 Bytes
79a1985
f2f5c64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79a1985
 
f2f5c64
a48e661
f2f5c64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a48e661
f2f5c64
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""
Centralized Configuration Module for GAP-CLIP Project
======================================================

This module contains all configuration parameters, file paths, and constants
used throughout the GAP-CLIP project. It provides a single source of truth
for model paths, embedding dimensions, dataset locations, and device settings.

Key Configuration Categories:
    - Model paths: Paths to trained model checkpoints
    - Data paths: Dataset locations and CSV files
    - Embedding dimensions: Size of color and hierarchy embeddings
    - Column names: CSV column identifiers for data loading
    - Device: Hardware accelerator configuration (CUDA, MPS, or CPU)

Usage:
    >>> import config
    >>> model_path = config.main_model_path
    >>> device = config.device
    >>> color_dim = config.color_emb_dim

Author: Lea Attia Sarfati
Project: GAP-CLIP (Guaranteed Attribute Positioning in CLIP Embeddings)
"""

from typing import Final
import torch
import os

# =============================================================================
# MODEL PATHS
# =============================================================================
# Paths to trained model checkpoints used for inference and fine-tuning
# NOTE(review): all paths below are relative to the process working directory,
# not to this file's location — runs started from another directory will not
# resolve them; confirm callers chdir appropriately.

#: Path to the trained color model checkpoint (ColorCLIP)
#: This model extracts 16-dimensional color embeddings from images and text
color_model_path: Final[str] = "models/color_model.pt"

#: Path to the trained hierarchy model checkpoint
#: This model extracts 64-dimensional category embeddings (e.g., dress, shirt, shoes)
hierarchy_model_path: Final[str] = "models/hierarchy_model.pth"

#: Path to the main GAP-CLIP model checkpoint
#: This is the primary 512-dimensional CLIP model with aligned color and hierarchy subspaces
main_model_path: Final[str] = "models/gap_clip.pth"

#: Path to the tokenizer vocabulary JSON file
#: Used by the color model's text encoder for tokenization
#: (name intentionally uses the British spelling "tokeniser" — keep in sync with call sites)
tokeniser_path: Final[str] = "tokenizer_vocab.json"

# =============================================================================
# DATASET PATHS
# =============================================================================
# Paths to training, validation, and test datasets

#: Path to the main training dataset with local image paths
#: CSV format with columns: text, color, hierarchy, local_image_path
local_dataset_path: Final[str] = "data/data_with_local_paths.csv"

#: Path to Fashion-MNIST test dataset for evaluation
#: Used for zero-shot classification benchmarking
fashion_mnist_test_path: Final[str] = "data/fashion-mnist_test.csv"

#: Directory containing image files for the dataset
images_dir: Final[str] = "data/images"

#: Directory for evaluation scripts and results
evaluation_directory: Final[str] = "evaluation/"

# =============================================================================
# CSV COLUMN NAMES
# =============================================================================
# Column identifiers used in dataset CSV files

#: Column name for local file paths to images
column_local_image_path: Final[str] = "local_image_path"

#: Column name for image URLs (when using remote images)
column_url_image: Final[str] = "image_url"

#: Column name for text descriptions of fashion items
text_column: Final[str] = "text"

#: Column name for color labels (e.g., "red", "blue", "black")
color_column: Final[str] = "color"

#: Column name for hierarchy/category labels (e.g., "dress", "shirt", "shoes")
hierarchy_column: Final[str] = "hierarchy"

# =============================================================================
# EMBEDDING DIMENSIONS
# =============================================================================
# Dimensionality of various embedding spaces

#: Dimension of color embeddings (positions 0-15 in main model)
#: These dimensions are explicitly trained to encode color information
color_emb_dim: Final[int] = 16

#: Dimension of hierarchy embeddings (positions 16-79 in main model)
#: These dimensions are explicitly trained to encode category information
hierarchy_emb_dim: Final[int] = 64

#: Total dimension of main CLIP embeddings
#: Structure: [color (16) | hierarchy (64) | general CLIP (432)] = 512
main_emb_dim: Final[int] = 512

#: Dimension of general CLIP embeddings (remaining dimensions after color and hierarchy)
#: Derived, never hard-coded, so it stays consistent: 512 - 16 - 64 = 432
general_clip_dim: Final[int] = main_emb_dim - color_emb_dim - hierarchy_emb_dim

# =============================================================================
# DEVICE CONFIGURATION
# =============================================================================
# Hardware accelerator settings for model training and inference

def get_device() -> torch.device:
    """
    Automatically detect and return the best available device.

    Priority order:
        1. CUDA (NVIDIA GPU) if available
        2. MPS (Apple Silicon) if available
        3. CPU as fallback

    Returns:
        torch.device: The device to use for tensor operations

    Examples:
        >>> device = get_device()
        >>> model = model.to(device)
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Guard the attribute lookup: torch builds older than 1.12 do not expose
    # torch.backends.mps at all, so touching it directly would raise
    # AttributeError instead of falling back to CPU.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")

#: Primary device for model operations
#: Automatically selects CUDA > MPS > CPU
#: NOTE: evaluated once at import time — later hardware/state changes
#: (e.g. CUDA becoming unavailable) are not reflected here.
device: torch.device = get_device()

# =============================================================================
# TRAINING HYPERPARAMETERS (DEFAULT VALUES)
# =============================================================================
# Default training parameters - can be overridden in training scripts

#: Default batch size for training
DEFAULT_BATCH_SIZE: Final[int] = 32

#: Default number of training epochs
DEFAULT_NUM_EPOCHS: Final[int] = 20

#: Default learning rate for optimizer
DEFAULT_LEARNING_RATE: Final[float] = 1.5e-5

#: Default temperature for contrastive loss
DEFAULT_TEMPERATURE: Final[float] = 0.09

#: Default weight for alignment loss
DEFAULT_ALIGNMENT_WEIGHT: Final[float] = 0.2

#: Default weight decay for L2 regularization
DEFAULT_WEIGHT_DECAY: Final[float] = 5e-4

# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

def validate_paths() -> bool:
    """
    Validate that all critical paths exist and are accessible.

    Checks the model checkpoints and the tokenizer vocabulary file.
    Missing files are reported with a printed warning; this function
    never raises — callers decide how to react to a False result.
    (Previous docstring claimed FileNotFoundError was raised; it is not.)

    Returns:
        bool: True if all paths exist, False otherwise
    """
    # Files required before any inference/fine-tuning run can start.
    critical_paths = [
        color_model_path,
        hierarchy_model_path,
        main_model_path,
        tokeniser_path
    ]
    
    missing_paths = [p for p in critical_paths if not os.path.exists(p)]
    
    if missing_paths:
        # Warn instead of raising so import-time checks stay non-fatal.
        print(f"⚠️  Warning: Missing files: {', '.join(missing_paths)}")
        return False
    
    return True

def print_config() -> None:
    """
    Print a formatted summary of the current configuration.

    Useful for debugging and logging training runs.
    """
    banner = "=" * 80
    # Table-driven layout: one (label, value) pair per printed line.
    entries = (
        ("Device", device),
        ("Color embedding dim", color_emb_dim),
        ("Hierarchy embedding dim", hierarchy_emb_dim),
        ("Main embedding dim", main_emb_dim),
        ("Main model path", main_model_path),
        ("Color model path", color_model_path),
        ("Hierarchy model path", hierarchy_model_path),
        ("Dataset path", local_dataset_path),
    )
    print(banner)
    print("GAP-CLIP Configuration")
    print(banner)
    for label, value in entries:
        print(f"{label}: {value}")
    print(banner)

# When executed directly (`python config.py`), print the configuration
# summary and check that critical files exist. Merely importing this
# module does NOT trigger these calls (the __name__ guard prevents it).
if __name__ == "__main__":
    print_config()
    validate_paths()