File size: 2,168 Bytes
08a811f
 
 
 
 
 
 
 
 
2cec50c
 
 
 
08a811f
 
 
 
85020ae
08a811f
ad9e267
08a811f
 
 
 
 
 
 
 
 
ad9e267
08a811f
 
ad9e267
08a811f
 
936bc6b
08a811f
 
 
 
ad9e267
08a811f
 
936bc6b
 
08a811f
 
936bc6b
08a811f
 
936bc6b
08a811f
 
936bc6b
08a811f
 
 
 
 
 
ad9e267
08a811f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

import os
from functools import lru_cache
from pathlib import Path
from typing import Literal, Optional

from dotenv import load_dotenv
from pydantic import BaseModel, Field

# Resolve .env against the project root (the parent of the directory holding
# this file) rather than the current working directory, so configuration is
# picked up the same way no matter where the process was started -- e.g. from
# a Kaggle notebook whose CWD is /kaggle/working/.
_PROJECT_ROOT = Path(__file__).resolve().parents[1]
load_dotenv(dotenv_path=_PROJECT_ROOT.joinpath(".env"))


# Allowed values for the Literal-typed fields below; keep in sync with the
# Literal annotations on Settings.environment / Settings.quantization.
_ENVIRONMENTS = ("local", "kaggle", "production")
_QUANTIZATIONS = ("none", "4bit")


def _env_choice(name: str, default: str, allowed: tuple) -> str:
    """
    Read environment variable *name* and check it against *allowed*.

    An unset or empty variable falls back to *default*. Pydantic does not
    validate values produced by ``default_factory`` unless explicitly asked
    to, so without this check an unrecognised value (e.g. a typo in .env)
    would be stored even though the field is annotated with a ``Literal``.

    Raises:
        ValueError: if the resulting value is not one of *allowed*.
    """
    value = os.getenv(name) or default
    if value not in allowed:
        raise ValueError(
            f"{name}={value!r} is not valid; expected one of: {', '.join(allowed)}"
        )
    return value


class Settings(BaseModel):
    """
    All configuration for AMR-Guard, read from environment variables.

    Supports three deployment targets via MEDIC_ENV: local, kaggle, production.
    Every field is populated by a ``default_factory``, so environment variables
    are read at instantiation time; an invalid MEDIC_ENV or MEDIC_QUANTIZATION
    raises ``ValueError`` instead of being stored silently.
    """

    # Deployment target; unset/empty MEDIC_ENV means "local".
    environment: Literal["local", "kaggle", "production"] = Field(
        default_factory=lambda: _env_choice("MEDIC_ENV", "local", _ENVIRONMENTS)  # type: ignore[arg-type]
    )
    # Same location the module uses to find .env.
    project_root: Path = Field(default_factory=lambda: _PROJECT_ROOT)
    # NOTE(review): relative defaults such as "data" resolve against the
    # current working directory, not project_root (unlike the module's .env
    # loading, which is anchored at the project root) -- confirm this is
    # intended for the kaggle target.
    data_dir: Path = Field(
        default_factory=lambda: Path(os.getenv("MEDIC_DATA_DIR", "data"))
    )
    chroma_db_dir: Path = Field(
        default_factory=lambda: Path(os.getenv("MEDIC_CHROMA_DB_DIR", "data/chroma_db"))
    )

    # 4-bit quantization via bitsandbytes; unset/empty MEDIC_QUANTIZATION means "4bit".
    quantization: Literal["none", "4bit"] = Field(
        default_factory=lambda: _env_choice("MEDIC_QUANTIZATION", "4bit", _QUANTIZATIONS)  # type: ignore[arg-type]
    )
    embedding_model_name: str = Field(
        default_factory=lambda: os.getenv("MEDIC_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
    )

    # Local HuggingFace model paths. An unset or empty variable yields None so
    # callers can use a plain `is None` check for "not configured".
    medgemma_4b_model: Optional[str] = Field(
        default_factory=lambda: os.getenv("MEDIC_LOCAL_MEDGEMMA_4B_MODEL") or None
    )
    medgemma_27b_model: Optional[str] = Field(
        default_factory=lambda: os.getenv("MEDIC_LOCAL_MEDGEMMA_27B_MODEL") or None
    )
    txgemma_9b_model: Optional[str] = Field(
        default_factory=lambda: os.getenv("MEDIC_LOCAL_TXGEMMA_9B_MODEL") or None
    )
    txgemma_2b_model: Optional[str] = Field(
        default_factory=lambda: os.getenv("MEDIC_LOCAL_TXGEMMA_2B_MODEL") or None
    )


@lru_cache(1)
def get_settings() -> Settings:
    """
    Build ``Settings`` on the first call and return the cached instance afterwards.

    Use this accessor rather than constructing ``Settings`` directly, so that
    every caller shares a single configuration object.
    """
    settings = Settings()
    return settings