rag / config.py
pratyoos's picture
Upload 6 files
79a28b6 verified
from dataclasses import dataclass, field
from pathlib import Path
import os
def _load_dotenv(dotenv_path: Path) -> None:
    """Populate ``os.environ`` from a simple ``KEY=VALUE`` file, if present.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    ignored. Each remaining line is split on its first ``=``; the key and
    value are stripped of surrounding whitespace, and the value additionally
    loses any leading/trailing double- and single-quote characters.
    Variables that already exist in the environment are never overwritten.

    Args:
        dotenv_path: Location of the dotenv file. A missing file is a no-op.
    """
    if not dotenv_path.exists():
        return

    # "utf-8-sig" decodes plain UTF-8 as before but also drops a leading BOM,
    # which would otherwise become part of the first variable name.
    for raw_line in dotenv_path.read_text(encoding="utf-8-sig").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        key = key.strip()
        if not key:
            # A line such as "=value" has no variable name. Assigning an
            # empty name into os.environ raises ValueError, so skip the line
            # instead of aborting the whole load.
            continue
        value = value.strip().strip('"').strip("'")
        # setdefault: values already set in the real environment win.
        os.environ.setdefault(key, value)
def _get_env(name: str, default: str, aliases: tuple[str, ...] = ()) -> str:
    """Return the first non-empty environment value among *name* and *aliases*.

    The primary ``name`` is consulted first, then each alias in the order
    given. Variables that are unset or set to the empty string are skipped.

    Args:
        name: Preferred environment variable name.
        default: Value returned when no candidate holds a non-empty value.
        aliases: Alternative variable names, checked after ``name``.

    Returns:
        The first non-empty value found, otherwise ``default``.
    """
    lookups = (os.getenv(candidate) for candidate in (name, *aliases))
    # os.getenv yields None or a str, so truthiness is exactly
    # "not None and not empty".
    return next((found for found in lookups if found), default)
def _to_int(value: str, default: int) -> int:
    """Parse *value* with ``int()``, returning *default* if that fails.

    Args:
        value: Text to convert.
        default: Result used when ``int(value)`` raises ``TypeError`` or
            ``ValueError`` (for example ``None`` or non-numeric text).

    Returns:
        The parsed integer, or ``default``.
    """
    try:
        parsed = int(value)
    except (TypeError, ValueError):
        parsed = default
    return parsed
def _to_float(value: str, default: float) -> float:
    """Parse *value* with ``float()``, returning *default* if that fails.

    Args:
        value: Text to convert.
        default: Result used when ``float(value)`` raises ``TypeError`` or
            ``ValueError`` (for example ``None`` or non-numeric text).

    Returns:
        The parsed float, or ``default``.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    else:
        return number
def _to_bool(value: str, default: bool) -> bool:
    """Interpret a textual flag, falling back to *default* when unrecognised.

    Matching ignores case and surrounding whitespace. ``1``, ``true``,
    ``yes`` and ``on`` map to ``True``; ``0``, ``false``, ``no`` and ``off``
    map to ``False``.

    Args:
        value: Text to interpret; ``None`` yields ``default``.
        default: Result for ``None`` or any unrecognised text.

    Returns:
        The parsed boolean, or ``default``.
    """
    if value is None:
        return default
    token = value.strip().lower()
    recognised = {
        **dict.fromkeys(("1", "true", "yes", "on"), True),
        **dict.fromkeys(("0", "false", "no", "off"), False),
    }
    return recognised.get(token, default)
# Absolute directory containing this module; used to locate the .env file
# and as the default for Settings.base_dir.
_BASE_DIR = Path(__file__).resolve().parent
# Must run before the Settings class body below: its field defaults read
# os.environ while the class is being defined.
_load_dotenv(_BASE_DIR / ".env")
@dataclass
class Settings:
    """Application settings with defaults taken from environment variables.

    Every field default below is an expression evaluated once, while this
    class body executes (i.e. at import time). Changing the environment
    afterwards does not affect the defaults of later ``Settings()``
    instances; pass explicit keyword arguments to override a value.

    Each field reads the environment variable named in its ``_get_env`` call
    (plus any listed aliases). Unset or empty variables fall back to the
    literal default, and values that fail int/float/bool parsing fall back to
    the second argument of the corresponding ``_to_*`` helper.
    """

    app_title: str = _get_env("APP_TITLE", "RAG API")
    model_id: str = _get_env("MODEL_ID", "Qwen/Qwen2.5-1.5B-Instruct", aliases=("MODEL_NAME",))
    embedding_model_id: str = _get_env(
        "EMBEDDING_MODEL_ID",
        "hkunlp/instructor-base",
        aliases=("EMBEDDING_MODEL",),
    )

    models_dir: str = _get_env("MODELS_DIR", "Models")
    vector_db_file: str = _get_env("VECTOR_DB_FILE", "vector_db.index", aliases=("VECTOR_STORE_PATH",))
    chunks_file: str = _get_env("CHUNKS_FILE", "chunks.pkl")
    hf_assets_repo_id: str = _get_env("HF_ASSETS_REPO_ID", "Pujan-Dev/faiss_emb")
    hf_assets_subdir: str = _get_env("HF_ASSETS_SUBDIR", "")
    allow_hf_assets_download: bool = _to_bool(_get_env("ALLOW_HF_ASSETS_DOWNLOAD", "true"), True)

    retrieval_instruction: str = _get_env(
        "RETRIEVAL_INSTRUCTION",
        "Represent the question for retrieving relevant documents",
    )

    max_context_tokens: int = _to_int(_get_env("MAX_CONTEXT_TOKENS", "2048"), 2048)
    max_new_tokens: int = _to_int(_get_env("MAX_NEW_TOKENS", "220"), 220)
    max_chars_per_chunk: int = _to_int(_get_env("MAX_CHARS_PER_CHUNK", "1400"), 1400)
    do_sample: bool = _to_bool(_get_env("DO_SAMPLE", "false"), False)
    temperature: float = _to_float(_get_env("TEMPERATURE", "0.3"), 0.3)
    repetition_penalty: float = _to_float(_get_env("REPETITION_PENALTY", "1.3"), 1.3)

    default_top_k: int = _to_int(_get_env("DEFAULT_TOP_K", "3"), 3)
    min_top_k: int = _to_int(_get_env("MIN_TOP_K", "1"), 1)
    max_top_k: int = _to_int(_get_env("MAX_TOP_K", "10"), 10)

    host: str = _get_env("HOST", "0.0.0.0", aliases=("API_HOST",))
    port: int = _to_int(_get_env("PORT", "8000", aliases=("API_PORT",)), 8000)

    # Directory of this module; a factory is used so the default is supplied
    # per instance rather than stored as a class attribute.
    base_dir: Path = field(default_factory=lambda: _BASE_DIR)

    @property
    def data_search_roots(self) -> list[Path]:
        """Return candidate root directories in priority order, de-duplicated.

        The candidates are built from ``base_dir``, its parent, a sibling
        ``RAG_pipeline`` directory and ``models_dir`` joined onto each.
        NOTE(review): presumably callers probe these in order for the index
        and chunk files; confirm against the call sites.

        Returns:
            The distinct candidate paths, first occurrence kept, original
            order preserved.
        """
        models_path = Path(self.models_dir)
        parent_dir = self.base_dir.parent
        candidates = [
            self.base_dir / models_path,
            self.base_dir,
            parent_dir / models_path,
            parent_dir / "RAG_pipeline" / models_path,
            parent_dir / "RAG_pipeline",
            parent_dir,
        ]
        # Entries collapse onto each other when models_dir is absolute (the
        # join discards the left side), when it is "." or empty, or when
        # base_dir itself is named "RAG_pipeline". dict.fromkeys drops the
        # repeats while keeping first-seen order.
        return list(dict.fromkeys(candidates))
# Module-level Settings instance, created when this module is imported.
settings = Settings()