| from dataclasses import dataclass, field |
| from pathlib import Path |
| import os |
|
|
|
|
def _load_dotenv(dotenv_path: Path) -> None:
    """Populate ``os.environ`` from a simple ``KEY=VALUE`` dotenv file.

    Variables that already exist in the process environment are never
    overwritten (``os.environ.setdefault``), so real environment settings
    take precedence over the file.

    Parsing rules:

    * A missing path, or a path that is not a regular file, is ignored.
    * Blank lines, ``#`` comment lines, lines without ``=`` and lines whose
      key is empty are skipped.
    * A leading ``export`` keyword (shell style) is accepted and dropped.
    * Whitespace around key and value is trimmed, then one *matching* pair
      of surrounding single or double quotes is removed from the value.

    Args:
        dotenv_path: Location of the dotenv file to read.
    """
    # is_file() rather than exists(): a directory that happens to be named
    # ".env" must not make read_text() raise.
    if not dotenv_path.is_file():
        return

    # "utf-8-sig" decodes plain UTF-8 as well, but drops a leading BOM that
    # would otherwise become part of the first key.
    for raw_line in dotenv_path.read_text(encoding="utf-8-sig").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue

        # Tolerate files that are also meant to be sourced by a shell.
        if line.startswith(("export ", "export\t")):
            line = line[len("export"):].lstrip()

        key, separator, value = line.partition("=")
        key = key.strip()
        # An empty variable name is rejected by os.environ with an exception;
        # treat such a line as malformed instead of aborting the whole load.
        if not separator or not key:
            continue

        value = value.strip()
        # Remove exactly one matching pair of quotes so that quote characters
        # belonging to the value itself (e.g. "it's") are preserved.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
            value = value[1:-1]

        os.environ.setdefault(key, value)
|
|
|
|
def _get_env(name: str, default: str, aliases: tuple[str, ...] = ()) -> str:
    """Return the first non-empty environment value among *name* and *aliases*.

    The primary *name* is consulted first, then each alias in the order
    given. A variable that is unset or set to the empty string counts as
    absent. When no candidate yields a value, *default* is returned.
    """
    # Lazy generator: lookups stop at the first usable value.
    candidates = (os.getenv(variable) for variable in (name, *aliases))
    # os.getenv yields either a str or None, so truthiness filters out both
    # "unset" (None) and "set but empty" ("").
    return next((found for found in candidates if found), default)
|
|
|
|
def _to_int(value: str, default: int) -> int:
    """Convert *value* with ``int()``, returning *default* if that fails.

    A failure is a ``TypeError`` (e.g. ``None``) or a ``ValueError``
    (e.g. non-numeric text) raised by ``int()``.
    """
    try:
        parsed = int(value)
    except (TypeError, ValueError):
        parsed = default
    return parsed
|
|
|
|
def _to_float(value: str, default: float) -> float:
    """Convert *value* with ``float()``, returning *default* if that fails.

    A failure is a ``TypeError`` (e.g. ``None``) or a ``ValueError``
    (e.g. non-numeric text) raised by ``float()``.
    """
    try:
        parsed = float(value)
    except (TypeError, ValueError):
        parsed = default
    return parsed
|
|
|
|
def _to_bool(value: str, default: bool) -> bool:
    """Interpret a textual flag as a boolean.

    After trimming whitespace and lower-casing, ``1/true/yes/on`` map to
    ``True`` and ``0/false/no/off`` map to ``False``. ``None`` or any other
    text yields *default*.
    """
    if value is None:
        return default

    recognised = {
        "1": True,
        "true": True,
        "yes": True,
        "on": True,
        "0": False,
        "false": False,
        "no": False,
        "off": False,
    }
    return recognised.get(value.strip().lower(), default)
|
|
|
|
# Directory containing this module; the dotenv file is looked up here.
_BASE_DIR = Path(__file__).resolve().parent

# Load the dotenv file before the Settings class body below is executed:
# its field defaults read os.environ when the class is defined.
_load_dotenv(_BASE_DIR.joinpath(".env"))
|
|
|
|
@dataclass
class Settings:
    """Application configuration resolved from environment variables.

    Each field default is computed once, when this class body executes, by
    reading the environment through ``_get_env`` and, for non-string
    fields, coercing with ``_to_int`` / ``_to_float`` / ``_to_bool``.
    Values passed explicitly to the constructor override these defaults.
    """

    # --- Application / model identity -------------------------------------
    app_title: str = _get_env("APP_TITLE", "RAG API")
    model_id: str = _get_env(
        "MODEL_ID",
        "Qwen/Qwen2.5-1.5B-Instruct",
        aliases=("MODEL_NAME",),
    )
    embedding_model_id: str = _get_env(
        "EMBEDDING_MODEL_ID",
        "hkunlp/instructor-base",
        aliases=("EMBEDDING_MODEL",),
    )

    # --- Asset locations ---------------------------------------------------
    models_dir: str = _get_env("MODELS_DIR", "Models")
    vector_db_file: str = _get_env(
        "VECTOR_DB_FILE",
        "vector_db.index",
        aliases=("VECTOR_STORE_PATH",),
    )
    chunks_file: str = _get_env("CHUNKS_FILE", "chunks.pkl")
    hf_assets_repo_id: str = _get_env("HF_ASSETS_REPO_ID", "Pujan-Dev/faiss_emb")
    hf_assets_subdir: str = _get_env("HF_ASSETS_SUBDIR", "")
    allow_hf_assets_download: bool = _to_bool(
        _get_env("ALLOW_HF_ASSETS_DOWNLOAD", "true"),
        True,
    )

    # --- Retrieval prompt --------------------------------------------------
    retrieval_instruction: str = _get_env(
        "RETRIEVAL_INSTRUCTION",
        "Represent the question for retrieving relevant documents",
    )

    # --- Generation parameters ---------------------------------------------
    max_context_tokens: int = _to_int(_get_env("MAX_CONTEXT_TOKENS", "2048"), 2048)
    max_new_tokens: int = _to_int(_get_env("MAX_NEW_TOKENS", "220"), 220)
    max_chars_per_chunk: int = _to_int(_get_env("MAX_CHARS_PER_CHUNK", "1400"), 1400)
    do_sample: bool = _to_bool(_get_env("DO_SAMPLE", "false"), False)
    temperature: float = _to_float(_get_env("TEMPERATURE", "0.3"), 0.3)
    repetition_penalty: float = _to_float(_get_env("REPETITION_PENALTY", "1.3"), 1.3)

    # --- top-k settings (default and min/max values) -----------------------
    default_top_k: int = _to_int(_get_env("DEFAULT_TOP_K", "3"), 3)
    min_top_k: int = _to_int(_get_env("MIN_TOP_K", "1"), 1)
    max_top_k: int = _to_int(_get_env("MAX_TOP_K", "10"), 10)

    # --- Host / port -------------------------------------------------------
    host: str = _get_env("HOST", "0.0.0.0", aliases=("API_HOST",))
    port: int = _to_int(_get_env("PORT", "8000", aliases=("API_PORT",)), 8000)

    # Directory of this module unless overridden per instance.
    base_dir: Path = field(default_factory=lambda: _BASE_DIR)

    @property
    def data_search_roots(self) -> list[Path]:
        """Return candidate directories built from ``base_dir`` and ``models_dir``.

        Order: the models directory under ``base_dir``; ``base_dir`` itself;
        the models directory under the parent of ``base_dir``; the models
        directory under the sibling ``RAG_pipeline`` directory; that
        ``RAG_pipeline`` directory; and finally the parent of ``base_dir``.
        """
        models_subdir = Path(self.models_dir)
        here = self.base_dir
        parent = here.parent
        pipeline = parent / "RAG_pipeline"
        return [
            here / models_subdir,
            here,
            parent / models_subdir,
            pipeline / models_subdir,
            pipeline,
            parent,
        ]
|
|
|
|
# Module-level instance built from the defaults resolved above.
settings: Settings = Settings()
|
|