multimodal-hw / src /config.py
AlekMan's picture
Update src/config.py
81d99ab verified
"""Global configuration for the multimodal retrieval MVP."""
from dataclasses import dataclass, field
from pathlib import Path
from types import EllipsisType
@dataclass(frozen=True)
class Paths:
    """Filesystem locations used across the project.

    ``root`` defaults to the directory one level above the directory that
    contains this file. Every other location defaults to a fixed sub-path of
    the *instance's* ``root``, so ``Paths(root=other)`` relocates the whole
    tree. A location passed explicitly is kept exactly as given.
    """

    root: Path = Path(__file__).resolve().parents[1]
    # ``None`` means "derive from ``root``". ``__post_init__`` always replaces
    # it with a concrete ``Path``, so instances never expose ``None`` here.
    cache_dir: Path = None  # type: ignore[assignment]
    embeddings_dir: Path = None  # type: ignore[assignment]
    indexes_dir: Path = None  # type: ignore[assignment]
    omni_metadata_path: Path = None  # type: ignore[assignment]

    def __post_init__(self) -> None:
        """Fill in every location that was not supplied explicitly."""
        artifacts = self.root / "artifacts"
        derived = {
            "cache_dir": self.root / "cache",
            "embeddings_dir": artifacts / "embeddings",
            "indexes_dir": artifacts / "indexes",
            "omni_metadata_path": artifacts / "datasets" / "omni_metadata.parquet",
        }
        for attr, fallback in derived.items():
            if getattr(self, attr) is None:
                # The dataclass is frozen, so normal assignment would raise;
                # object.__setattr__ is the documented way to initialise
                # fields from __post_init__.
                object.__setattr__(self, attr, fallback)

    def ensure(self) -> None:
        """Create every directory the project writes to.

        Creates ``cache_dir``, ``embeddings_dir``, ``indexes_dir`` and the
        parent directory of ``omni_metadata_path`` (the metadata file itself
        is not created). Missing parents are created and existing
        directories are left untouched.
        """
        required_dirs = (
            self.cache_dir,
            self.embeddings_dir,
            self.indexes_dir,
            self.omni_metadata_path.parent,
        )
        for directory in required_dirs:
            directory.mkdir(parents=True, exist_ok=True)
@dataclass(frozen=True)
class DatasetConfig:
    """Immutable settings for the dataset the project works on.

    NOTE(review): the field names suggest a Hugging Face ``datasets`` loader
    (hub id, split, streaming mode, shuffle buffer). The loading code is not
    part of this module, so confirm how each value is consumed there.
    """

    # --- Which data to read ---
    name: str = "huggan/wikiart"
    split: str = "train"
    streaming: bool = True

    # --- Sampling (presumably a seeded, buffered shuffle; verify in loader) ---
    seed: int = 42
    sample_size: int = 5000
    shuffle_buffer: int = 2048

    # --- Names of the dataset columns the project refers to ---
    image_column: str = "image"
    id_column: str = "id"
    artist_column: str = "artist"
    style_column: str = "style"
    genre_column: str = "genre"
@dataclass(frozen=True)
class ModelConfig:
    """Immutable model identifiers and inference hyper-parameters.

    NOTE(review): the identifier strings look like Hugging Face Hub
    checkpoint names; confirm against the code that loads them.
    """

    # --- Checkpoints ---
    # ``image_encoder`` and ``vlm_model`` default to the same checkpoint.
    image_encoder: str = "openai/clip-vit-base-patch32"
    caption_model: str = "Salesforce/blip-image-captioning-large"
    vlm_model: str = "openai/clip-vit-base-patch32"

    # --- Runtime ---
    # "auto" is presumably resolved to a concrete device by the consumer.
    device: str = "auto"
    batch_size: int = 8
@dataclass(frozen=True)
class IndexConfig:
"""Parameters for vector indexes."""
metric: str = "angular"
n_trees: int = 64
search_k: int | EllipsisType = ...
top_k: int = 10
@dataclass(frozen=True)
class RetrievalConfig:
    """Top-level, immutable bundle of every project setting group.

    Each group is built through a ``default_factory``, so a new
    ``RetrievalConfig()`` gets freshly constructed default sections.
    """

    # Filesystem layout.
    paths: Paths = field(default_factory=Paths)
    # Dataset selection and sampling.
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    # Model checkpoints and runtime options.
    models: ModelConfig = field(default_factory=ModelConfig)
    # Vector index parameters.
    index: IndexConfig = field(default_factory=IndexConfig)

    def prepare(self) -> None:
        """Create the on-disk directories by delegating to ``paths.ensure()``."""
        self.paths.ensure()
# Default, frozen configuration instance built when this module is imported.
# Construction alone does not call ``prepare()``, so nothing is created on
# disk until a caller invokes ``CONFIG.prepare()``.
CONFIG = RetrievalConfig()