# gap-clip / config.py
# Uploaded by Leacb4 via huggingface_hub (commit f2f5c64, verified)
"""
Centralized Configuration Module for GAP-CLIP Project
======================================================
This module contains all configuration parameters, file paths, and constants
used throughout the GAP-CLIP project. It provides a single source of truth
for model paths, embedding dimensions, dataset locations, and device settings.
Key Configuration Categories:
- Model paths: Paths to trained model checkpoints
- Data paths: Dataset locations and CSV files
- Embedding dimensions: Size of color and hierarchy embeddings
- Column names: CSV column identifiers for data loading
- Device: Hardware accelerator configuration (CUDA, MPS, or CPU)
Usage:
>>> import config
>>> model_path = config.main_model_path
>>> device = config.device
>>> color_dim = config.color_emb_dim
Author: Lea Attia Sarfati
Project: GAP-CLIP (Guaranteed Attribute Positioning in CLIP Embeddings)
"""
from typing import Final
import torch
import os
# =============================================================================
# MODEL PATHS
# =============================================================================
# Paths to trained model checkpoints used for inference and fine-tuning.
# All paths are relative, so they resolve against the process's current
# working directory at run time.
#: Path to the trained color model checkpoint (ColorCLIP)
#: This model extracts 16-dimensional color embeddings from images and text
color_model_path: Final[str] = "models/color_model.pt"
#: Path to the trained hierarchy model checkpoint
#: This model extracts 64-dimensional category embeddings (e.g., dress, shirt, shoes)
hierarchy_model_path: Final[str] = "models/hierarchy_model.pth"
#: Path to the main GAP-CLIP model checkpoint
#: This is the primary 512-dimensional CLIP model with aligned color and hierarchy subspaces
main_model_path: Final[str] = "models/gap_clip.pth"
#: Path to the tokenizer vocabulary JSON file
#: Used by the color model's text encoder for tokenization
#: NOTE(review): variable uses British spelling ("tokeniser") while the file
#: name uses American spelling — kept as-is since renaming would break importers
tokeniser_path: Final[str] = "tokenizer_vocab.json"
# =============================================================================
# DATASET PATHS
# =============================================================================
# Paths to training, validation, and test datasets (relative to the
# working directory, like the model paths above).
#: Path to the main training dataset with local image paths
#: CSV format with columns: text, color, hierarchy, local_image_path
local_dataset_path: Final[str] = "data/data_with_local_paths.csv"
#: Path to Fashion-MNIST test dataset for evaluation
#: Used for zero-shot classification benchmarking
fashion_mnist_test_path: Final[str] = "data/fashion-mnist_test.csv"
#: Directory containing image files for the dataset
#: presumably referenced by the local_image_path CSV column — verify in loaders
images_dir: Final[str] = "data/images"
#: Directory for evaluation scripts and results (note trailing slash)
evaluation_directory: Final[str] = "evaluation/"
# =============================================================================
# CSV COLUMN NAMES
# =============================================================================
# Column identifiers used in dataset CSV files. Centralized here so data
# loaders and the CSV schema stay in sync via a single definition.
#: Column name for local file paths to images
column_local_image_path: Final[str] = "local_image_path"
#: Column name for image URLs (when using remote images)
column_url_image: Final[str] = "image_url"
#: Column name for text descriptions of fashion items
text_column: Final[str] = "text"
#: Column name for color labels (e.g., "red", "blue", "black")
color_column: Final[str] = "color"
#: Column name for hierarchy/category labels (e.g., "dress", "shirt", "shoes")
hierarchy_column: Final[str] = "hierarchy"
# =============================================================================
# EMBEDDING DIMENSIONS
# =============================================================================
# Dimensionality of various embedding spaces. The main embedding is laid out
# as contiguous subspaces: color first, then hierarchy, then general CLIP.
#: Dimension of color embeddings (positions 0-15 in main model)
#: These dimensions are explicitly trained to encode color information
color_emb_dim: Final[int] = 16
#: Dimension of hierarchy embeddings (positions 16-79 in main model)
#: These dimensions are explicitly trained to encode category information
hierarchy_emb_dim: Final[int] = 64
#: Total dimension of main CLIP embeddings
#: Structure: [color (16) | hierarchy (64) | general CLIP (432)] = 512
main_emb_dim: Final[int] = 512
#: Dimension of general CLIP embeddings (remaining dimensions after color and
#: hierarchy). Derived, not hard-coded, so it stays consistent with the values
#: above: 512 - 16 - 64 = 432.
general_clip_dim: Final[int] = main_emb_dim - color_emb_dim - hierarchy_emb_dim
# =============================================================================
# DEVICE CONFIGURATION
# =============================================================================
# Hardware accelerator settings for model training and inference
def get_device() -> torch.device:
    """
    Pick the best available hardware accelerator.
    Backends are probed in priority order:
    1. CUDA (NVIDIA GPU) if available
    2. MPS (Apple Silicon) if available
    3. CPU as fallback
    Returns:
        torch.device: The device to use for tensor operations
    Examples:
        >>> device = get_device()
        >>> model = model.to(device)
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)
#: Primary device for model operations, resolved once at import time.
#: Automatically selects CUDA > MPS > CPU.
device: torch.device = get_device()
# =============================================================================
# TRAINING HYPERPARAMETERS (DEFAULT VALUES)
# =============================================================================
# Default training parameters - can be overridden in training scripts.
# These are defaults only; nothing in this module consumes them directly.
#: Default batch size for training
DEFAULT_BATCH_SIZE: Final[int] = 32
#: Default number of training epochs
DEFAULT_NUM_EPOCHS: Final[int] = 20
#: Default learning rate for optimizer
DEFAULT_LEARNING_RATE: Final[float] = 1.5e-5
#: Default temperature for contrastive loss
DEFAULT_TEMPERATURE: Final[float] = 0.09
#: Default weight for alignment loss
DEFAULT_ALIGNMENT_WEIGHT: Final[float] = 0.2
#: Default weight decay for L2 regularization
DEFAULT_WEIGHT_DECAY: Final[float] = 5e-4
# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
def validate_paths(paths=None) -> bool:
    """
    Validate that the given paths (or, by default, all critical model/tokenizer
    paths defined in this module) exist and are accessible.

    Args:
        paths: Optional iterable of path strings to check. When None (the
            default), the module's critical paths are checked:
            color_model_path, hierarchy_model_path, main_model_path,
            tokeniser_path. Passing an explicit list keeps the old behavior
            intact while letting callers validate arbitrary path sets.

    Returns:
        bool: True if all paths exist, False otherwise.

    Note:
        This function never raises; missing files are reported with a printed
        warning and a False return value. (An earlier docstring claimed a
        FileNotFoundError was raised — it never was.)
    """
    if paths is None:
        # Default to the checkpoints/tokenizer this project cannot run without.
        paths = [
            color_model_path,
            hierarchy_model_path,
            main_model_path,
            tokeniser_path
        ]
    missing_paths = [p for p in paths if not os.path.exists(p)]
    if missing_paths:
        print(f"⚠️ Warning: Missing files: {', '.join(missing_paths)}")
        return False
    return True
def print_config() -> None:
    """
    Print a formatted summary of the current configuration.
    Useful for debugging and logging training runs.
    """
    separator = "=" * 80
    print(separator)
    print("GAP-CLIP Configuration")
    print(separator)
    # Label/value pairs, emitted in a fixed order to keep log output stable.
    settings = [
        ("Device", device),
        ("Color embedding dim", color_emb_dim),
        ("Hierarchy embedding dim", hierarchy_emb_dim),
        ("Main embedding dim", main_emb_dim),
        ("Main model path", main_model_path),
        ("Color model path", color_model_path),
        ("Hierarchy model path", hierarchy_model_path),
        ("Dataset path", local_dataset_path),
    ]
    for label, value in settings:
        print(f"{label}: {value}")
    print(separator)
# Print a configuration summary and check critical paths when this module is
# executed directly (e.g. ``python config.py``). Importing the module does NOT
# trigger these calls; the only import-time side effect is resolving ``device``.
if __name__ == "__main__":
    print_config()
    validate_paths()