AlekMan commited on
Commit
81d99ab
·
verified ·
1 Parent(s): 61d3625

Update src/config.py

Browse files
Files changed (1) hide show
  1. src/config.py +78 -78
src/config.py CHANGED
@@ -1,78 +1,78 @@
1
- """Global configuration for the multimodal retrieval MVP."""
2
- from dataclasses import dataclass, field
3
- from pathlib import Path
4
- from types import EllipsisType
5
-
6
-
7
- @dataclass(frozen=True)
8
- class Paths:
9
- """Paths used across the project."""
10
-
11
- root: Path = Path(__file__).resolve().parents[1]
12
- cache_dir: Path = root / "cache"
13
- embeddings_dir: Path = root / "artifacts" / "embeddings"
14
- indexes_dir: Path = root / "artifacts" / "indexes"
15
- omni_metadata_path: Path = root / "artifacts" / "datasets" / "omni_metadata.parquet"
16
-
17
- def ensure(self) -> None:
18
- for path in [
19
- self.cache_dir,
20
- self.embeddings_dir,
21
- self.indexes_dir,
22
- self.omni_metadata_path.parent,
23
- ]:
24
- path.mkdir(parents=True, exist_ok=True)
25
-
26
-
27
- @dataclass(frozen=True)
28
- class DatasetConfig:
29
- """Dataset parameters."""
30
-
31
- name: str = "huggan/wikiart"
32
- split: str = "train"
33
- streaming: bool = True
34
- seed: int = 42
35
- sample_size: int = 5000
36
- shuffle_buffer: int = 2048
37
- image_column: str = "image"
38
- id_column: str = "id"
39
- artist_column: str = "artist"
40
- style_column: str = "style"
41
- genre_column: str = "genre"
42
-
43
-
44
- @dataclass(frozen=True)
45
- class ModelConfig:
46
- """Model identifiers and hyper-parameters."""
47
-
48
- image_encoder: str = "openai/clip-vit-base-patch32"
49
- caption_model: str = "Salesforce/blip-image-captioning-large"
50
- vlm_model: str = "openai/clip-vit-base-patch32"
51
- device: str = "cuda"
52
- batch_size: int = 8
53
-
54
-
55
- @dataclass(frozen=True)
56
- class IndexConfig:
57
- """Parameters for vector indexes."""
58
-
59
- metric: str = "angular"
60
- n_trees: int = 64
61
- search_k: int | EllipsisType = ...
62
- top_k: int = 10
63
-
64
-
65
- @dataclass(frozen=True)
66
- class RetrievalConfig:
67
- """Configuration dataclass grouping all project level settings."""
68
-
69
- paths: Paths = field(default_factory=Paths)
70
- dataset: DatasetConfig = field(default_factory=DatasetConfig)
71
- models: ModelConfig = field(default_factory=ModelConfig)
72
- index: IndexConfig = field(default_factory=IndexConfig)
73
-
74
- def prepare(self) -> None:
75
- self.paths.ensure()
76
-
77
-
78
- CONFIG = RetrievalConfig()
 
1
+ """Global configuration for the multimodal retrieval MVP."""
2
+ from dataclasses import dataclass, field
3
+ from pathlib import Path
4
+ from types import EllipsisType
5
+
6
+
7
+ @dataclass(frozen=True)
8
+ class Paths:
9
+ """Paths used across the project."""
10
+
11
+ root: Path = Path(__file__).resolve().parents[1]
12
+ cache_dir: Path = root / "cache"
13
+ embeddings_dir: Path = root / "artifacts" / "embeddings"
14
+ indexes_dir: Path = root / "artifacts" / "indexes"
15
+ omni_metadata_path: Path = root / "artifacts" / "datasets" / "omni_metadata.parquet"
16
+
17
+ def ensure(self) -> None:
18
+ for path in [
19
+ self.cache_dir,
20
+ self.embeddings_dir,
21
+ self.indexes_dir,
22
+ self.omni_metadata_path.parent,
23
+ ]:
24
+ path.mkdir(parents=True, exist_ok=True)
25
+
26
+
27
+ @dataclass(frozen=True)
28
+ class DatasetConfig:
29
+ """Dataset parameters."""
30
+
31
+ name: str = "huggan/wikiart"
32
+ split: str = "train"
33
+ streaming: bool = True
34
+ seed: int = 42
35
+ sample_size: int = 5000
36
+ shuffle_buffer: int = 2048
37
+ image_column: str = "image"
38
+ id_column: str = "id"
39
+ artist_column: str = "artist"
40
+ style_column: str = "style"
41
+ genre_column: str = "genre"
42
+
43
+
44
+ @dataclass(frozen=True)
45
+ class ModelConfig:
46
+ """Model identifiers and hyper-parameters."""
47
+
48
+ image_encoder: str = "openai/clip-vit-base-patch32"
49
+ caption_model: str = "Salesforce/blip-image-captioning-large"
50
+ vlm_model: str = "openai/clip-vit-base-patch32"
51
+ device: str = "auto"
52
+ batch_size: int = 8
53
+
54
+
55
+ @dataclass(frozen=True)
56
+ class IndexConfig:
57
+ """Parameters for vector indexes."""
58
+
59
+ metric: str = "angular"
60
+ n_trees: int = 64
61
+ search_k: int | EllipsisType = ...
62
+ top_k: int = 10
63
+
64
+
65
+ @dataclass(frozen=True)
66
+ class RetrievalConfig:
67
+ """Configuration dataclass grouping all project level settings."""
68
+
69
+ paths: Paths = field(default_factory=Paths)
70
+ dataset: DatasetConfig = field(default_factory=DatasetConfig)
71
+ models: ModelConfig = field(default_factory=ModelConfig)
72
+ index: IndexConfig = field(default_factory=IndexConfig)
73
+
74
+ def prepare(self) -> None:
75
+ self.paths.ensure()
76
+
77
+
78
+ CONFIG = RetrievalConfig()