File size: 2,142 Bytes
81d99ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Global configuration for the multimodal retrieval MVP."""
from dataclasses import dataclass, field
from pathlib import Path
from types import EllipsisType


@dataclass(frozen=True)
class Paths:
    """Paths used across the project."""

    root: Path = Path(__file__).resolve().parents[1]
    cache_dir: Path = root / "cache"
    embeddings_dir: Path = root / "artifacts" / "embeddings"
    indexes_dir: Path = root / "artifacts" / "indexes"
    omni_metadata_path: Path = root / "artifacts" / "datasets" / "omni_metadata.parquet"

    def ensure(self) -> None:
        for path in [
            self.cache_dir,
            self.embeddings_dir,
            self.indexes_dir,
            self.omni_metadata_path.parent,
        ]:
            path.mkdir(parents=True, exist_ok=True)


@dataclass(frozen=True)
class DatasetConfig:
    """Dataset parameters for sampling the WikiArt corpus."""

    # Hugging Face dataset identifier and split to pull from.
    name: str = "huggan/wikiart"
    split: str = "train"
    # Stream the dataset rather than downloading it in full.
    streaming: bool = True
    # Seed for reproducible shuffling/sampling.
    seed: int = 42
    # Number of examples to sample from the (streamed) split.
    sample_size: int = 5000
    # Buffer size used when shuffling a streaming dataset.
    shuffle_buffer: int = 2048
    # Column names in the source dataset — presumably matching the
    # huggan/wikiart schema; verify against the dataset card.
    image_column: str = "image"
    id_column: str = "id"
    artist_column: str = "artist"
    style_column: str = "style"
    genre_column: str = "genre"


@dataclass(frozen=True)
class ModelConfig:
    """Model identifiers and hyper-parameters."""

    # Hugging Face model id used to embed images (CLIP ViT-B/32).
    image_encoder: str = "openai/clip-vit-base-patch32"
    # Model used to generate captions for images.
    caption_model: str = "Salesforce/blip-image-captioning-large"
    # Vision-language model; same checkpoint as the image encoder here.
    vlm_model: str = "openai/clip-vit-base-patch32"
    # "auto" presumably means pick CUDA/MPS/CPU at runtime — resolved by
    # the consumer of this config; confirm against the loading code.
    device: str = "auto"
    # Inference batch size.
    batch_size: int = 8


@dataclass(frozen=True)
class IndexConfig:
    """Parameters for vector indexes."""

    # Distance metric name ("angular" matches cosine-style similarity).
    metric: str = "angular"
    # Number of trees to build (more trees = better recall, larger index).
    n_trees: int = 64
    # NOTE(review): `...` (Ellipsis) is used as a "no explicit value" sentinel
    # here — presumably translated to the index library's own default by the
    # consumer; confirm before changing to None or an int.
    search_k: int | EllipsisType = ...
    # Number of nearest neighbours to return per query.
    top_k: int = 10


@dataclass(frozen=True)
class RetrievalConfig:
    """Configuration dataclass grouping all project level settings."""

    # default_factory keeps each RetrievalConfig instance with its own
    # sub-config instances (required for mutable-ish defaults in dataclasses).
    paths: Paths = field(default_factory=Paths)
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    models: ModelConfig = field(default_factory=ModelConfig)
    index: IndexConfig = field(default_factory=IndexConfig)

    def prepare(self) -> None:
        """Create all artifact/cache directories (delegates to Paths.ensure)."""
        self.paths.ensure()


CONFIG = RetrievalConfig()