Leacb4 commited on
Commit
f2f5c64
·
verified ·
1 Parent(s): 51820f5

Upload config.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. config.py +211 -22
config.py CHANGED
@@ -1,27 +1,216 @@
1
  """
2
- Centralized configuration file for the project.
3
- This file contains all file paths, embedding dimensions, and configuration parameters
4
- used throughout the project (models, datasets, save paths, etc.).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  """
6
 
 
7
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- local_dataset_path = "data/data_with_local_paths.csv"
10
- color_model_path = "models/color_model.pt"
11
- hierarchy_model_path = "models/hierarchy_model.pth"
12
- main_model_path = "models/gap_clip.pth"
13
- tokeniser_path = "tokenizer_vocab.json"
14
- fashion_mnist_test_path = "data/fashion-mnist_test.csv"
15
-
16
- images_dir = "data/images"
17
- evaluation_directory = 'evaluation/'
18
-
19
- column_local_image_path = "local_image_path"
20
- column_url_image = "image_url"
21
- text_column = "text"
22
- color_column = "color"
23
- hierarchy_column = "hierarchy"
24
- hierarchy_emb_dim = 64
25
- color_emb_dim = 16
26
-
27
- device = torch.device("mps")
 
1
  """
2
+ Centralized Configuration Module for GAP-CLIP Project
3
+ ======================================================
4
+
5
+ This module contains all configuration parameters, file paths, and constants
6
+ used throughout the GAP-CLIP project. It provides a single source of truth
7
+ for model paths, embedding dimensions, dataset locations, and device settings.
8
+
9
+ Key Configuration Categories:
10
+ - Model paths: Paths to trained model checkpoints
11
+ - Data paths: Dataset locations and CSV files
12
+ - Embedding dimensions: Size of color and hierarchy embeddings
13
+ - Column names: CSV column identifiers for data loading
14
+ - Device: Hardware accelerator configuration (CUDA, MPS, or CPU)
15
+
16
+ Usage:
17
+ >>> import config
18
+ >>> model_path = config.main_model_path
19
+ >>> device = config.device
20
+ >>> color_dim = config.color_emb_dim
21
+
22
+ Author: Lea Attia Sarfati
23
+ Project: GAP-CLIP (Guaranteed Attribute Positioning in CLIP Embeddings)
24
  """
25
 
26
+ from typing import Final
27
  import torch
28
+ import os
29
+
30
# ---------------------------------------------------------------------------
# Model checkpoint locations
# ---------------------------------------------------------------------------
# Trained weights loaded for inference and fine-tuning.

#: ColorCLIP checkpoint — produces the 16-d color embedding subspace.
color_model_path: Final[str] = "models/color_model.pt"

#: Hierarchy model checkpoint — produces the 64-d category embedding
#: subspace (e.g. dress, shirt, shoes).
hierarchy_model_path: Final[str] = "models/hierarchy_model.pth"

#: Main GAP-CLIP checkpoint — the 512-d CLIP model whose leading
#: dimensions are aligned with the color and hierarchy subspaces.
main_model_path: Final[str] = "models/gap_clip.pth"

#: Vocabulary JSON consumed by the color model's text tokenizer.
tokeniser_path: Final[str] = "tokenizer_vocab.json"

# ---------------------------------------------------------------------------
# Dataset locations
# ---------------------------------------------------------------------------

#: Main training CSV (columns: text, color, hierarchy, local_image_path).
local_dataset_path: Final[str] = "data/data_with_local_paths.csv"

#: Fashion-MNIST test CSV used for zero-shot classification benchmarks.
fashion_mnist_test_path: Final[str] = "data/fashion-mnist_test.csv"

#: Directory holding the dataset's image files.
images_dir: Final[str] = "data/images"

#: Directory for evaluation scripts and their results.
evaluation_directory: Final[str] = "evaluation/"

# ---------------------------------------------------------------------------
# CSV column names
# ---------------------------------------------------------------------------
# Identifiers used when loading the dataset CSVs.

#: Local file path to each image.
column_local_image_path: Final[str] = "local_image_path"

#: Remote URL of each image (when images are not stored locally).
column_url_image: Final[str] = "image_url"

#: Free-text description of the fashion item.
text_column: Final[str] = "text"

#: Color label (e.g. "red", "blue", "black").
color_column: Final[str] = "color"

#: Category label (e.g. "dress", "shirt", "shoes").
hierarchy_column: Final[str] = "hierarchy"

# ---------------------------------------------------------------------------
# Embedding dimensions
# ---------------------------------------------------------------------------
# Layout of the main embedding: [color | hierarchy | general CLIP].

#: Width of the color subspace (positions 0-15 of the main embedding).
color_emb_dim: Final[int] = 16

#: Width of the hierarchy subspace (positions 16-79 of the main embedding).
hierarchy_emb_dim: Final[int] = 64

#: Total width of the main CLIP embedding.
main_emb_dim: Final[int] = 512

#: Width left over for general CLIP features (512 - 16 - 64 = 432).
general_clip_dim: Final[int] = main_emb_dim - color_emb_dim - hierarchy_emb_dim
109
+
110
+ # =============================================================================
111
+ # DEVICE CONFIGURATION
112
+ # =============================================================================
113
+ # Hardware accelerator settings for model training and inference
114
+
115
def get_device() -> torch.device:
    """Select the best available hardware accelerator.

    Probes backends in priority order — CUDA (NVIDIA GPU), then MPS
    (Apple Silicon) — and falls back to the CPU when neither is present.

    Returns:
        torch.device: Device handle for ``tensor.to(...)`` / ``model.to(...)``.

    Examples:
        >>> device = get_device()
        >>> model = model.to(device)
    """
    # Probe table keeps the priority order explicit and easy to extend.
    backends = (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    )
    for name, is_available in backends:
        if is_available():
            return torch.device(name)
    return torch.device("cpu")


#: Primary device for model operations, chosen once at import time
#: (CUDA > MPS > CPU).
device: torch.device = get_device()
141
+
142
# ---------------------------------------------------------------------------
# Default training hyper-parameters
# ---------------------------------------------------------------------------
# Baseline values; individual training scripts may override any of them.

#: Mini-batch size used when none is specified.
DEFAULT_BATCH_SIZE: Final[int] = 32

#: Number of passes over the training set.
DEFAULT_NUM_EPOCHS: Final[int] = 20

#: Optimizer step size.
DEFAULT_LEARNING_RATE: Final[float] = 1.5e-5

#: Temperature of the contrastive loss.
DEFAULT_TEMPERATURE: Final[float] = 0.09

#: Relative weight of the alignment loss term.
DEFAULT_ALIGNMENT_WEIGHT: Final[float] = 0.2

#: L2 regularization strength (optimizer weight decay).
DEFAULT_WEIGHT_DECAY: Final[float] = 5e-4
164
+
165
+ # =============================================================================
166
+ # UTILITY FUNCTIONS
167
+ # =============================================================================
168
+
169
def validate_paths(paths: "list[str] | None" = None) -> bool:
    """Check that required files exist and are accessible.

    Note: this function never raises; missing files are reported with a
    printed warning and a False return value. (The original docstring
    claimed a FileNotFoundError could be raised, which was never true.)

    Args:
        paths: Optional explicit list of paths to check. When ``None``
            (the default, preserving the original behavior), the four
            critical model/tokenizer files are checked.

    Returns:
        bool: True if every path exists, False otherwise.
    """
    if paths is None:
        # Default: the checkpoints and vocabulary the project cannot
        # run without.
        paths = [
            color_model_path,
            hierarchy_model_path,
            main_model_path,
            tokeniser_path,
        ]

    missing_paths = [p for p in paths if not os.path.exists(p)]

    if missing_paths:
        print(f"⚠️ Warning: Missing files: {', '.join(missing_paths)}")
        return False

    return True
193
+
194
def print_config() -> None:
    """Print a formatted summary of the active configuration.

    Useful for debugging and for recording settings at the start of a
    training run.
    """
    rule = "=" * 80
    summary = [
        rule,
        "GAP-CLIP Configuration",
        rule,
        f"Device: {device}",
        f"Color embedding dim: {color_emb_dim}",
        f"Hierarchy embedding dim: {hierarchy_emb_dim}",
        f"Main embedding dim: {main_emb_dim}",
        f"Main model path: {main_model_path}",
        f"Color model path: {color_model_path}",
        f"Hierarchy model path: {hierarchy_model_path}",
        f"Dataset path: {local_dataset_path}",
        rule,
    ]
    for line in summary:
        print(line)
212
 
213
# Print the configuration summary and validate paths only when this module
# is executed directly as a script (this does NOT run on import).
if __name__ == "__main__":
    print_config()
    validate_paths()