|
|
""" |
|
|
Centralized Configuration Module for GAP-CLIP Project |
|
|
====================================================== |
|
|
|
|
|
This module contains all configuration parameters, file paths, and constants |
|
|
used throughout the GAP-CLIP project. It provides a single source of truth |
|
|
for model paths, embedding dimensions, dataset locations, and device settings. |
|
|
|
|
|
Key Configuration Categories: |
|
|
- Model paths: Paths to trained model checkpoints |
|
|
- Data paths: Dataset locations and CSV files |
|
|
- Embedding dimensions: Size of color and hierarchy embeddings |
|
|
- Column names: CSV column identifiers for data loading |
|
|
- Device: Hardware accelerator configuration (CUDA, MPS, or CPU) |
|
|
|
|
|
Usage: |
|
|
>>> import config |
|
|
>>> model_path = config.main_model_path |
|
|
>>> device = config.device |
|
|
>>> color_dim = config.color_emb_dim |
|
|
|
|
|
Author: Lea Attia Sarfati |
|
|
Project: GAP-CLIP (Guaranteed Attribute Positioning in CLIP Embeddings) |
|
|
""" |
|
|
|
|
|
from typing import Final |
|
|
import torch |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Model checkpoint paths (relative to the project root) ---

# Checkpoint for the color-attribute embedding model.
color_model_path: Final[str] = "models/color_model.pt"


# Checkpoint for the hierarchy (category) embedding model.
hierarchy_model_path: Final[str] = "models/hierarchy_model.pth"


# Checkpoint for the main GAP-CLIP model.
main_model_path: Final[str] = "models/gap_clip.pth"


# Tokenizer vocabulary file (JSON). NOTE(review): British spelling
# "tokeniser" in the name vs. "tokenizer" in the filename — kept as-is
# because other modules may import this name.
tokeniser_path: Final[str] = "tokenizer_vocab.json"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Dataset and output locations ---

# CSV whose rows reference locally stored image files
# (see column_local_image_path below).
local_dataset_path: Final[str] = "data/data_with_local_paths.csv"


# Fashion-MNIST test split in CSV form.
fashion_mnist_test_path: Final[str] = "data/fashion-mnist_test.csv"


# Directory holding the image files referenced by the dataset CSV.
images_dir: Final[str] = "data/images"


# Directory where evaluation artifacts are written.
evaluation_directory: Final[str] = "evaluation/"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- CSV column identifiers used when loading the dataset ---

# Column holding the local filesystem path of each image.
column_local_image_path: Final[str] = "local_image_path"


# Column holding the original remote URL of each image.
column_url_image: Final[str] = "image_url"


# Column holding the caption/description text.
text_column: Final[str] = "text"


# Column holding the color attribute label.
color_column: Final[str] = "color"


# Column holding the hierarchy (category) label.
hierarchy_column: Final[str] = "hierarchy"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Embedding dimensions ---

# Dimensionality of the color embedding subspace.
color_emb_dim: Final[int] = 16


# Dimensionality of the hierarchy embedding subspace.
hierarchy_emb_dim: Final[int] = 64


# Total dimensionality of the main GAP-CLIP embedding.
main_emb_dim: Final[int] = 512


# Dimensions left for the general CLIP representation after reserving
# the color and hierarchy subspaces (512 - 16 - 64 = 432).
general_clip_dim: Final[int] = main_emb_dim - color_emb_dim - hierarchy_emb_dim
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_device() -> torch.device:
    """
    Select and return the best available hardware accelerator.

    Preference order:
    1. CUDA (NVIDIA GPU) if available
    2. MPS (Apple Silicon) if available
    3. CPU as fallback

    Returns:
        torch.device: The device to use for tensor operations

    Examples:
        >>> device = get_device()
        >>> model = model.to(device)
    """
    # Probe accelerators in priority order; fall through to CPU.
    candidates = (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    )
    for name, is_available in candidates:
        if is_available():
            return torch.device(name)
    return torch.device("cpu")
|
|
|
|
|
|
|
|
|
|
|
# Resolved once at import time; import this module and use `config.device`.
device: torch.device = get_device()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Default training hyperparameters ---

# Mini-batch size for training.
DEFAULT_BATCH_SIZE: Final[int] = 32


# Number of passes over the training set.
DEFAULT_NUM_EPOCHS: Final[int] = 20


# Optimizer learning rate.
DEFAULT_LEARNING_RATE: Final[float] = 1.5e-5


# Temperature parameter — presumably for the contrastive loss softmax;
# TODO confirm against the training code.
DEFAULT_TEMPERATURE: Final[float] = 0.09


# Weight of the alignment loss term — NOTE(review): verify which loss
# component this scales in the training loop.
DEFAULT_ALIGNMENT_WEIGHT: Final[float] = 0.2


# Optimizer weight decay (L2 regularization strength).
DEFAULT_WEIGHT_DECAY: Final[float] = 5e-4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_paths() -> bool:
    """
    Validate that all critical model and tokenizer files exist.

    Checks the four module-level checkpoint/vocabulary paths. This
    function never raises: missing files are reported with a printed
    warning and a ``False`` return value, letting callers decide how
    to react. (The previous docstring claimed a ``FileNotFoundError``
    was raised, which was never the case.)

    Returns:
        bool: True if all critical paths exist, False otherwise
    """
    critical_paths = [
        color_model_path,
        hierarchy_model_path,
        main_model_path,
        tokeniser_path,
    ]

    missing_paths = [p for p in critical_paths if not os.path.exists(p)]

    if missing_paths:
        # Non-fatal by design: training/eval scripts may still want to
        # run with a subset of the artifacts present.
        print(f"⚠️ Warning: Missing files: {', '.join(missing_paths)}")
        return False

    return True
|
|
|
|
|
def print_config() -> None:
    """
    Print a formatted summary of the current configuration.

    Useful for debugging and logging training runs.
    """
    rule = "=" * 80
    # Label/value pairs, printed in a fixed order below.
    entries = (
        ("Device", device),
        ("Color embedding dim", color_emb_dim),
        ("Hierarchy embedding dim", hierarchy_emb_dim),
        ("Main embedding dim", main_emb_dim),
        ("Main model path", main_model_path),
        ("Color model path", color_model_path),
        ("Hierarchy model path", hierarchy_model_path),
        ("Dataset path", local_dataset_path),
    )

    print(rule)
    print("GAP-CLIP Configuration")
    print(rule)
    for label, value in entries:
        print(f"{label}: {value}")
    print(rule)
|
|
|
|
|
|
|
|
# Smoke test: dump the active configuration and check file availability
# when this module is executed directly.
if __name__ == "__main__":
    print_config()
    validate_paths()