"""Environment-driven configuration.

Importing this module has two side effects:

1. An optional ``.env`` file located next to this module is loaded into
   ``os.environ`` (variables that are already set are never overridden).
2. A module-level ``settings`` instance of :class:`Settings` is created.
"""

import os
from dataclasses import dataclass, field
from pathlib import Path

# Spellings accepted by ``_to_bool`` (compared after strip + lower-casing).
_TRUE_VALUES = frozenset({"1", "true", "yes", "on"})
_FALSE_VALUES = frozenset({"0", "false", "no", "off"})


def _load_dotenv(dotenv_path: Path) -> None:
    """Populate ``os.environ`` from a simple ``KEY=VALUE`` file.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    ignored.  An optional leading ``export `` on the key is dropped, and one
    pair of matching single or double quotes around the value is removed.
    Variables that already exist in the environment are left untouched.

    Args:
        dotenv_path: Location of the file.  Anything that is not a regular
            file (missing path, directory, ...) makes this a no-op.
    """
    # ``is_file`` rather than ``exists``: a directory with this name would
    # otherwise make ``read_text`` raise.
    if not dotenv_path.is_file():
        return

    # "utf-8-sig" decodes plain UTF-8 as before but also drops a leading BOM,
    # which would otherwise become part of the first key.
    content = dotenv_path.read_text(encoding="utf-8-sig")

    for raw_line in content.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue

        key, _, value = line.partition("=")
        key = key.strip().removeprefix("export ").strip()
        if not key:
            # e.g. "=value": an empty name is rejected by ``os.environ``.
            continue

        value = value.strip()
        # Remove a single *matched* pair of surrounding quotes only, so that
        # values which merely start or end with a quote character survive.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
            value = value[1:-1]

        os.environ.setdefault(key, value)


def _get_env(name: str, default: str, aliases: tuple[str, ...] = ()) -> str:
    """Return the first non-empty environment value among ``name`` and ``aliases``.

    Args:
        name: Primary variable name; checked first.
        default: Returned when every candidate is unset or an empty string.
        aliases: Alternative variable names, checked in order after ``name``.
    """
    for key in (name, *aliases):
        value = os.getenv(key)
        if value is not None and value != "":
            return value
    return default


def _to_int(value: str, default: int) -> int:
    """Convert ``value`` with ``int()``; return ``default`` if that fails."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def _to_float(value: str, default: float) -> float:
    """Convert ``value`` with ``float()``; return ``default`` if that fails."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _to_bool(value: str, default: bool) -> bool:
    """Interpret ``value`` as a boolean flag.

    ``1/true/yes/on`` map to ``True`` and ``0/false/no/off`` map to ``False``
    (case-insensitive, surrounding whitespace ignored).  ``None`` and any
    unrecognised text yield ``default``.
    """
    if value is None:
        return default

    normalized = value.strip().lower()
    if normalized in _TRUE_VALUES:
        return True
    if normalized in _FALSE_VALUES:
        return False
    return default


# Directory containing this module; the ``.env`` file is looked up here.
_BASE_DIR = Path(__file__).resolve().parent
_load_dotenv(_BASE_DIR / ".env")


@dataclass
class Settings:
    """Application settings with environment-variable backed defaults.

    Every default below (except ``base_dir``) is computed once, when this
    class body is executed at import time.  Changing the environment
    afterwards does not affect instances created later; pass explicit
    keyword arguments to override a value.
    """

    app_title: str = _get_env("APP_TITLE", "RAG API")

    # Model identifiers (each also readable through a legacy alias variable).
    model_id: str = _get_env(
        "MODEL_ID",
        "Qwen/Qwen2.5-1.5B-Instruct",
        aliases=("MODEL_NAME",),
    )
    embedding_model_id: str = _get_env(
        "EMBEDDING_MODEL_ID",
        "hkunlp/instructor-base",
        aliases=("EMBEDDING_MODEL",),
    )

    # Asset locations.
    models_dir: str = _get_env("MODELS_DIR", "Models")
    vector_db_file: str = _get_env(
        "VECTOR_DB_FILE",
        "vector_db.index",
        aliases=("VECTOR_STORE_PATH",),
    )
    chunks_file: str = _get_env("CHUNKS_FILE", "chunks.pkl")
    hf_assets_repo_id: str = _get_env("HF_ASSETS_REPO_ID", "Pujan-Dev/faiss_emb")
    hf_assets_subdir: str = _get_env("HF_ASSETS_SUBDIR", "")
    allow_hf_assets_download: bool = _to_bool(
        _get_env("ALLOW_HF_ASSETS_DOWNLOAD", "true"), True
    )

    retrieval_instruction: str = _get_env(
        "RETRIEVAL_INSTRUCTION",
        "Represent the question for retrieving relevant documents",
    )

    # Size limits and generation parameters.
    max_context_tokens: int = _to_int(_get_env("MAX_CONTEXT_TOKENS", "2048"), 2048)
    max_new_tokens: int = _to_int(_get_env("MAX_NEW_TOKENS", "220"), 220)
    max_chars_per_chunk: int = _to_int(_get_env("MAX_CHARS_PER_CHUNK", "1400"), 1400)
    do_sample: bool = _to_bool(_get_env("DO_SAMPLE", "false"), False)
    temperature: float = _to_float(_get_env("TEMPERATURE", "0.3"), 0.3)
    repetition_penalty: float = _to_float(_get_env("REPETITION_PENALTY", "1.3"), 1.3)

    # Top-k values.  NOTE(review): nothing here enforces
    # min_top_k <= default_top_k <= max_top_k; consumers must validate.
    default_top_k: int = _to_int(_get_env("DEFAULT_TOP_K", "3"), 3)
    min_top_k: int = _to_int(_get_env("MIN_TOP_K", "1"), 1)
    max_top_k: int = _to_int(_get_env("MAX_TOP_K", "10"), 10)

    # Host and port (also readable through the API_HOST / API_PORT aliases).
    host: str = _get_env("HOST", "0.0.0.0", aliases=("API_HOST",))
    port: int = _to_int(_get_env("PORT", "8000", aliases=("API_PORT",)), 8000)

    # Defaults to the directory containing this module.
    base_dir: Path = field(default_factory=lambda: _BASE_DIR)

    @property
    def data_search_roots(self) -> list[Path]:
        """Return candidate directories derived from ``base_dir``, in order.

        The list is built from ``base_dir``, its parent, a sibling directory
        named ``RAG_pipeline`` and ``models_dir``.  Paths are only computed
        here; their existence is not checked.
        """
        models_path = Path(self.models_dir)
        parent = self.base_dir.parent
        return [
            self.base_dir / models_path,
            self.base_dir,
            parent / models_path,
            parent / "RAG_pipeline" / models_path,
            parent / "RAG_pipeline",
            parent,
        ]


settings = Settings()