Spaces:
Sleeping
Sleeping
Update src/config.py
Browse files- src/config.py +78 -78
src/config.py
CHANGED
|
@@ -1,78 +1,78 @@
|
|
| 1 |
-
"""Global configuration for the multimodal retrieval MVP."""
|
| 2 |
-
from dataclasses import dataclass, field
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
from types import EllipsisType
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
@dataclass(frozen=True)
class Paths:
    """Filesystem locations used throughout the project.

    All paths hang off ``root``, which is resolved relative to this file
    (two levels up), so the layout is stable regardless of the CWD.
    """

    root: Path = Path(__file__).resolve().parents[1]
    cache_dir: Path = root / "cache"
    embeddings_dir: Path = root / "artifacts" / "embeddings"
    indexes_dir: Path = root / "artifacts" / "indexes"
    omni_metadata_path: Path = root / "artifacts" / "datasets" / "omni_metadata.parquet"

    def ensure(self) -> None:
        """Create every directory this configuration points at (idempotent)."""
        required = (
            self.cache_dir,
            self.embeddings_dir,
            self.indexes_dir,
            self.omni_metadata_path.parent,
        )
        for directory in required:
            directory.mkdir(parents=True, exist_ok=True)
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
@dataclass(frozen=True)
class DatasetConfig:
    """Dataset parameters.

    Identifies the Hugging Face dataset to load, how it is sampled, and the
    column names used to read records from it.
    """

    # Hugging Face dataset id and split to load.
    name: str = "huggan/wikiart"
    split: str = "train"
    # Stream records instead of downloading the full dataset up front.
    streaming: bool = True
    # seed/sample_size/shuffle_buffer presumably feed the datasets
    # shuffle-and-take sampling path -- TODO confirm against the loader.
    seed: int = 42
    sample_size: int = 5000
    shuffle_buffer: int = 2048
    # Column names as they appear in the source dataset schema.
    image_column: str = "image"
    id_column: str = "id"
    artist_column: str = "artist"
    style_column: str = "style"
    genre_column: str = "genre"
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
@dataclass(frozen=True)
class ModelConfig:
    """Model identifiers and hyper-parameters."""

    # Hugging Face model ids for each stage of the pipeline.
    image_encoder: str = "openai/clip-vit-base-patch32"
    caption_model: str = "Salesforce/blip-image-captioning-large"
    vlm_model: str = "openai/clip-vit-base-patch32"
    # Fix: the original literal was truncated (`device: str = "`), a syntax
    # error. "auto" matches the committed fix on the addition side of the
    # diff; presumably it is resolved to a concrete device (cuda/cpu) by the
    # model loader -- TODO confirm against the consumer.
    device: str = "auto"
    batch_size: int = 8
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
@dataclass(frozen=True)
class IndexConfig:
    """Parameters for vector indexes."""

    # "angular" plus n_trees/search_k matches the Annoy API surface --
    # TODO confirm which index library consumes these.
    metric: str = "angular"
    n_trees: int = 64
    # Ellipsis (`...`) acts as a "not set" sentinel; consumers are expected
    # to substitute a library default when search_k is left unset.
    # `EllipsisType` in annotations requires Python 3.10+.
    search_k: int | EllipsisType = ...
    top_k: int = 10
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
@dataclass(frozen=True)
class RetrievalConfig:
    """Configuration dataclass grouping all project level settings."""

    # default_factory gives each instance its own (frozen) sub-configs
    # rather than sharing class-level defaults.
    paths: Paths = field(default_factory=Paths)
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    models: ModelConfig = field(default_factory=ModelConfig)
    index: IndexConfig = field(default_factory=IndexConfig)

    def prepare(self) -> None:
        """Create all directories required before the pipeline runs."""
        self.paths.ensure()
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
# Shared module-level configuration instance; import this rather than
# constructing RetrievalConfig ad hoc.
CONFIG = RetrievalConfig()
|
|
|
|
| 1 |
+
"""Global configuration for the multimodal retrieval MVP."""
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from types import EllipsisType
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass(frozen=True)
class Paths:
    """Paths used across the project."""

    # Project root: two levels above this file, so layout is CWD-independent.
    root: Path = Path(__file__).resolve().parents[1]
    cache_dir: Path = root / "cache"
    embeddings_dir: Path = root / "artifacts" / "embeddings"
    indexes_dir: Path = root / "artifacts" / "indexes"
    omni_metadata_path: Path = root / "artifacts" / "datasets" / "omni_metadata.parquet"

    def ensure(self) -> None:
        """Create every directory this configuration points at (idempotent)."""
        for path in [
            self.cache_dir,
            self.embeddings_dir,
            self.indexes_dir,
            # Parent only: the parquet file itself is written later.
            self.omni_metadata_path.parent,
        ]:
            path.mkdir(parents=True, exist_ok=True)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass(frozen=True)
class DatasetConfig:
    """Dataset parameters.

    Identifies the Hugging Face dataset to load, how it is sampled, and the
    column names used to read records from it.
    """

    # Hugging Face dataset id and split to load.
    name: str = "huggan/wikiart"
    split: str = "train"
    # Stream records instead of downloading the full dataset up front.
    streaming: bool = True
    # seed/sample_size/shuffle_buffer presumably feed the datasets
    # shuffle-and-take sampling path -- TODO confirm against the loader.
    seed: int = 42
    sample_size: int = 5000
    shuffle_buffer: int = 2048
    # Column names as they appear in the source dataset schema.
    image_column: str = "image"
    id_column: str = "id"
    artist_column: str = "artist"
    style_column: str = "style"
    genre_column: str = "genre"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dataclass(frozen=True)
class ModelConfig:
    """Model identifiers and hyper-parameters."""

    # Hugging Face model ids for each stage of the pipeline.
    image_encoder: str = "openai/clip-vit-base-patch32"
    caption_model: str = "Salesforce/blip-image-captioning-large"
    vlm_model: str = "openai/clip-vit-base-patch32"
    # "auto" is presumably resolved to a concrete device (cuda/cpu) by the
    # model loader -- TODO confirm against the consumer.
    device: str = "auto"
    batch_size: int = 8
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclass(frozen=True)
class IndexConfig:
    """Parameters for vector indexes."""

    # "angular" plus n_trees/search_k matches the Annoy API surface --
    # TODO confirm which index library consumes these.
    metric: str = "angular"
    n_trees: int = 64
    # Ellipsis (`...`) acts as a "not set" sentinel; consumers are expected
    # to substitute a library default when search_k is left unset.
    # `EllipsisType` in annotations requires Python 3.10+.
    search_k: int | EllipsisType = ...
    top_k: int = 10
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@dataclass(frozen=True)
class RetrievalConfig:
    """Configuration dataclass grouping all project level settings."""

    # default_factory gives each instance its own (frozen) sub-configs
    # rather than sharing class-level defaults.
    paths: Paths = field(default_factory=Paths)
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    models: ModelConfig = field(default_factory=ModelConfig)
    index: IndexConfig = field(default_factory=IndexConfig)

    def prepare(self) -> None:
        """Create all directories required before the pipeline runs."""
        self.paths.ensure()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Shared module-level configuration instance; import this rather than
# constructing RetrievalConfig ad hoc.
CONFIG = RetrievalConfig()
|