File size: 3,659 Bytes
79a28b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from dataclasses import dataclass, field
from pathlib import Path
import os


def _load_dotenv(dotenv_path: Path) -> None:
    if not dotenv_path.exists():
        return

    for raw_line in dotenv_path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue

        key, value = line.split("=", 1)
        key = key.strip()
        value = value.strip().strip('"').strip("'")
        os.environ.setdefault(key, value)


def _get_env(name: str, default: str, aliases: tuple[str, ...] = ()) -> str:
    for key in (name, *aliases):
        value = os.getenv(key)
        if value is not None and value != "":
            return value
    return default


def _to_int(value: str, default: int) -> int:
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def _to_float(value: str, default: float) -> float:
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _to_bool(value: str, default: bool) -> bool:
    if value is None:
        return default
    normalized = value.strip().lower()
    if normalized in {"1", "true", "yes", "on"}:
        return True
    if normalized in {"0", "false", "no", "off"}:
        return False
    return default


# Directory containing this module. It is used to locate the ".env" file
# below and is also the default Settings.base_dir.
_BASE_DIR = Path(__file__).resolve().parent
# Load the .env file now, before the Settings class body runs. That class
# computes its field defaults via _get_env at class-definition time, so
# values from the file must already be in os.environ. Variables set in the
# real process environment take precedence, because _load_dotenv only uses
# os.environ.setdefault.
_load_dotenv(_BASE_DIR / ".env")


@dataclass
class Settings:
    """Runtime configuration for the RAG API, sourced from environment variables.

    Every field default below is computed when this class body executes,
    i.e. once at import time. That happens after the module has loaded its
    optional ``.env`` file. Changing ``os.environ`` afterwards therefore does
    not affect new ``Settings()`` instances; pass keyword arguments to
    override a value instead.

    Numeric and boolean settings fall back to the hard-coded default when the
    environment value cannot be parsed.
    """

    # Display title and model identifiers (Hugging Face-style repo ids by
    # default). MODEL_NAME and EMBEDDING_MODEL are accepted as aliases.
    app_title: str = _get_env("APP_TITLE", "RAG API")
    model_id: str = _get_env("MODEL_ID", "Qwen/Qwen2.5-1.5B-Instruct", aliases=("MODEL_NAME",))
    embedding_model_id: str = _get_env(
        "EMBEDDING_MODEL_ID",
        "hkunlp/instructor-base",
        aliases=("EMBEDDING_MODEL",),
    )

    # Asset locations: the models directory name, the vector index and chunk
    # file names, and the Hugging Face repo/subdirectory they may be fetched
    # from (downloading can be disabled with ALLOW_HF_ASSETS_DOWNLOAD).
    models_dir: str = _get_env("MODELS_DIR", "Models")
    vector_db_file: str = _get_env("VECTOR_DB_FILE", "vector_db.index", aliases=("VECTOR_STORE_PATH",))
    chunks_file: str = _get_env("CHUNKS_FILE", "chunks.pkl")
    hf_assets_repo_id: str = _get_env("HF_ASSETS_REPO_ID", "Pujan-Dev/faiss_emb")
    hf_assets_subdir: str = _get_env("HF_ASSETS_SUBDIR", "")
    allow_hf_assets_download: bool = _to_bool(_get_env("ALLOW_HF_ASSETS_DOWNLOAD", "true"), True)

    # Instruction text used with the embedding model for retrieval queries
    # (presumably paired with each query by the retriever — confirm there).
    retrieval_instruction: str = _get_env(
        "RETRIEVAL_INSTRUCTION",
        "Represent the question for retrieving relevant documents",
    )

    # Prompt/generation limits and sampling parameters.
    max_context_tokens: int = _to_int(_get_env("MAX_CONTEXT_TOKENS", "2048"), 2048)
    max_new_tokens: int = _to_int(_get_env("MAX_NEW_TOKENS", "220"), 220)
    max_chars_per_chunk: int = _to_int(_get_env("MAX_CHARS_PER_CHUNK", "1400"), 1400)
    do_sample: bool = _to_bool(_get_env("DO_SAMPLE", "false"), False)
    temperature: float = _to_float(_get_env("TEMPERATURE", "0.3"), 0.3)
    repetition_penalty: float = _to_float(_get_env("REPETITION_PENALTY", "1.3"), 1.3)

    # Number of retrieved chunks: the default and its min/max bounds.
    default_top_k: int = _to_int(_get_env("DEFAULT_TOP_K", "3"), 3)
    min_top_k: int = _to_int(_get_env("MIN_TOP_K", "1"), 1)
    max_top_k: int = _to_int(_get_env("MAX_TOP_K", "10"), 10)

    # Server bind address. NOTE: "0.0.0.0" listens on all network
    # interfaces; set HOST=127.0.0.1 for local-only access.
    host: str = _get_env("HOST", "0.0.0.0", aliases=("API_HOST",))
    port: int = _to_int(_get_env("PORT", "8000", aliases=("API_PORT",)), 8000)

    # Anchor directory for relative asset lookups (defaults to this module's directory).
    base_dir: Path = field(default_factory=lambda: _BASE_DIR)

    @property
    def data_search_roots(self) -> list[Path]:
        """Candidate base directories to search, in priority order.

        The candidates are ``models_dir`` and the bare directory at each of
        three levels: ``base_dir``, a sibling ``RAG_pipeline`` directory, and
        ``base_dir``'s parent.

        Duplicates are removed, keeping the first occurrence, so a first-match
        search yields the same result as before. Duplicates arise when
        ``models_dir`` is absolute, or empty/".", or when ``base_dir`` is the
        filesystem root.
        """
        models_path = Path(self.models_dir)
        parent = self.base_dir.parent
        candidates = [
            self.base_dir / models_path,
            self.base_dir,
            parent / models_path,
            parent / "RAG_pipeline" / models_path,
            parent / "RAG_pipeline",
            parent,
        ]
        # dict preserves insertion order, so this deduplicates without reordering.
        return list(dict.fromkeys(candidates))


# Shared module-level instance, built once at import. Its values are the
# environment-derived defaults computed when the Settings class was defined.
settings = Settings()