# Author: Babu Pallam
# Commit 6d697d8: Add configuration management for modular RAG pipeline
# ============================================================
# FILE: src/config.py
# ============================================================
# PURPOSE:
# Load all project configuration from environment variables.
#
# WHY THIS FILE EXISTS:
# In a production AI application, configuration should not be
# scattered across multiple files.
#
# This file centralizes:
# - cloud API settings
# - model settings
# - embedding settings
# - chunking settings
# - folder paths
# - RAG behavior flags
# ============================================================
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union

from dotenv import load_dotenv
def get_project_root() -> Path:
    """
    Return the directory treated as the project root.

    Resolution order:
    1. ``/app`` whenever that path exists (common in Docker images).
    2. The parent of the working directory when the working directory is
       named ``notebooks`` (running from a notebook folder).
    3. The working directory itself, in every other case.
    """
    working_dir = Path.cwd()
    docker_root = Path("/app")

    if docker_root.exists():
        return docker_root

    return working_dir.parent if working_dir.name == "notebooks" else working_dir
def _get_env_str(name: str, default: str = "") -> str:
    """
    Return environment variable ``name`` with surrounding whitespace removed.

    If the variable is unset, ``default`` is used instead (and is stripped
    too). A variable that is set but empty comes back as "" and is NOT
    replaced by the default.

    Example:
        CLOUD_API_PROVIDER=clod  ->  "clod" (a string, never None)
    """
    raw_value = os.environ.get(name)
    if raw_value is None:
        raw_value = default
    return raw_value.strip()
def _get_env_int(name: str, default: int) -> int:
    """
    Read an integer from environment variables.

    Unset or blank (whitespace-only) variables return ``default``.

    Example:
        CLOUD_TIMEOUT_SECONDS=60
    Output:
        60 (as an integer, not a string)

    Raises:
        ValueError: if the variable is set to something that is not an
            integer. The message names the variable, so a bad .env entry is
            easy to find (a bare int() error only shows the value).
    """
    value = os.getenv(name)
    if value is None or value.strip() == "":
        return default
    try:
        return int(value.strip())
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {value!r}."
        ) from err
def _get_env_float(name: str, default: float) -> float:
    """
    Read a float from environment variables.

    Unset or blank (whitespace-only) variables return ``default``.

    Example:
        CLOUD_TEMPERATURE=0.2
    Output:
        0.2 (as a float, not a string)

    Raises:
        ValueError: if the variable is set to something that is not a
            number. The message names the variable, so a bad .env entry is
            easy to find (a bare float() error only shows the value).
    """
    value = os.getenv(name)
    if value is None or value.strip() == "":
        return default
    try:
        return float(value.strip())
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name} must be a number, got {value!r}."
        ) from err
def _get_env_bool(name: str, default: bool) -> bool:
    """
    Read a boolean from environment variables.

    Accepts (case-insensitive, surrounding whitespace ignored):
        true, yes, 1, y, on   -> True
        false, no, 0, n, off  -> False

    Unset or blank (whitespace-only) variables return ``default``.

    Raises:
        ValueError: if the variable holds any other value. Failing fast here
            prevents a typo such as "ture" from silently becoming False.
    """
    value = os.getenv(name)
    if value is None or value.strip() == "":
        return default

    normalized = value.strip().lower()
    if normalized in {"true", "yes", "1", "y", "on"}:
        return True
    if normalized in {"false", "no", "0", "n", "off"}:
        return False

    raise ValueError(
        f"Environment variable {name} must be a boolean "
        f"(true/false, yes/no, 1/0, y/n, on/off), got {value!r}."
    )
@dataclass
class AppConfig:
    """
    AppConfig stores all settings required by the RAG application.

    Using a dataclass makes configuration:
    - easy to read
    - easy to test
    - easy to pass between modules
    - easier to refactor later

    Security note:
    ``cloud_api_key`` is excluded from the generated ``repr()``, so printing
    or logging a config object does not leak the secret. It is still a normal
    attribute and still takes part in equality comparison.
    """

    # Project folders (load_config() builds these under project_root)
    project_root: Path
    data_folder: Path  # DATA_FOLDER, default "data/raw"
    vector_db_folder: Path  # VECTOR_DB_FOLDER, default "vector_db/chroma"
    outputs_folder: Path  # always <project_root>/outputs
    logs_folder: Path  # always <project_root>/logs

    # Cloud provider settings
    cloud_api_provider: str  # CLOUD_API_PROVIDER
    cloud_api_format: str  # CLOUD_API_FORMAT
    cloud_api_base_url: str  # CLOUD_API_BASE_URL, trailing "/" stripped by load_config()
    cloud_chat_completions_path: str  # CLOUD_CHAT_COMPLETIONS_PATH, leading "/" ensured
    cloud_chat_completions_url: str  # CLOUD_CHAT_COMPLETIONS_URL, else base URL + path
    # Secret: kept out of repr() so it never appears in logs or notebook output.
    cloud_api_key: str = field(repr=False)
    cloud_auth_header: str  # CLOUD_AUTH_HEADER, default "Authorization"
    cloud_auth_prefix: str  # CLOUD_AUTH_PREFIX, default "Bearer"

    # Model generation settings
    cloud_chat_model: str  # CLOUD_CHAT_MODEL
    cloud_temperature: float  # CLOUD_TEMPERATURE
    cloud_max_completion_tokens: int  # CLOUD_MAX_COMPLETION_TOKENS
    cloud_timeout_seconds: int  # CLOUD_TIMEOUT_SECONDS
    cloud_max_retries: int  # CLOUD_MAX_RETRIES
    cloud_retry_sleep_seconds: float  # CLOUD_RETRY_SLEEP_SECONDS

    # Embedding settings
    embedding_model_name: str  # EMBEDDING_MODEL_NAME
    embedding_device: str  # EMBEDDING_DEVICE, default "cpu"

    # Chunking and retrieval settings
    chunk_size: int  # CHUNK_SIZE
    chunk_overlap: int  # CHUNK_OVERLAP
    top_k: int  # TOP_K
    collection_name: str  # COLLECTION_NAME

    # RAG behavior
    require_context_for_answer: bool  # REQUIRE_CONTEXT_FOR_ANSWER
    prompt_template_version: str  # PROMPT_TEMPLATE_VERSION
def load_config(
    env_file: Optional[Union[str, os.PathLike]] = None,
    override: bool = True,
) -> AppConfig:
    """
    Load configuration from .env and operating system environment variables.

    Local development:
    - values usually come from .env
    Production:
    - values usually come from deployment environment variables
    - example: Hugging Face Secrets, Docker secrets, Kubernetes secrets

    Args:
        env_file: Path to a .env file, given as a ``Path`` or a plain string.
            Defaults to ``<project root>/.env``. A file that does not exist
            is skipped silently.
        override: Forwarded to ``load_dotenv``. When True (the default),
            values in the .env file replace variables already present in the
            process environment. This is useful during development because it
            reloads updated .env values inside notebooks or interactive
            sessions. Pass False to let deployment-provided variables win.

    Returns:
        A populated AppConfig. Values are parsed but not range-checked here;
        call validate_config() for that. A malformed numeric value raises
        ValueError from the _get_env_* helpers.
    """
    project_root = get_project_root()

    # Normalize to Path so `.exists()` also works when a str is passed.
    env_path = project_root / ".env" if env_file is None else Path(env_file)

    # Loading .env is optional.
    # In production, secrets may already exist as environment variables.
    if env_path.exists():
        load_dotenv(env_path, override=override)

    # Folder settings are interpreted relative to the project root.
    data_folder = project_root / _get_env_str("DATA_FOLDER", "data/raw")
    vector_db_folder = project_root / _get_env_str("VECTOR_DB_FOLDER", "vector_db/chroma")
    outputs_folder = project_root / "outputs"
    logs_folder = project_root / "logs"

    # Strip the trailing "/" from the base URL and ensure the path starts with
    # "/" so that base URL + path concatenate cleanly.
    cloud_api_base_url = _get_env_str("CLOUD_API_BASE_URL", "https://api.clod.io/v1").rstrip("/")
    cloud_chat_completions_path = _get_env_str("CLOUD_CHAT_COMPLETIONS_PATH", "/chat/completions")
    if not cloud_chat_completions_path.startswith("/"):
        cloud_chat_completions_path = "/" + cloud_chat_completions_path

    # An explicit full endpoint URL takes precedence over base URL + path.
    full_url_override = _get_env_str("CLOUD_CHAT_COMPLETIONS_URL", "")
    if full_url_override:
        cloud_chat_completions_url = full_url_override
    else:
        cloud_chat_completions_url = cloud_api_base_url + cloud_chat_completions_path

    return AppConfig(
        project_root=project_root,
        data_folder=data_folder,
        vector_db_folder=vector_db_folder,
        outputs_folder=outputs_folder,
        logs_folder=logs_folder,
        cloud_api_provider=_get_env_str("CLOUD_API_PROVIDER", "clod"),
        cloud_api_format=_get_env_str("CLOUD_API_FORMAT", "openai_chat_completions"),
        cloud_api_base_url=cloud_api_base_url,
        cloud_chat_completions_path=cloud_chat_completions_path,
        cloud_chat_completions_url=cloud_chat_completions_url,
        cloud_api_key=_get_env_str("CLOUD_API_KEY", ""),
        cloud_auth_header=_get_env_str("CLOUD_AUTH_HEADER", "Authorization"),
        cloud_auth_prefix=_get_env_str("CLOUD_AUTH_PREFIX", "Bearer"),
        cloud_chat_model=_get_env_str("CLOUD_CHAT_MODEL", "Gemma 4 31B IT"),
        cloud_temperature=_get_env_float("CLOUD_TEMPERATURE", 0.2),
        cloud_max_completion_tokens=_get_env_int("CLOUD_MAX_COMPLETION_TOKENS", 700),
        cloud_timeout_seconds=_get_env_int("CLOUD_TIMEOUT_SECONDS", 60),
        cloud_max_retries=_get_env_int("CLOUD_MAX_RETRIES", 3),
        cloud_retry_sleep_seconds=_get_env_float("CLOUD_RETRY_SLEEP_SECONDS", 2.0),
        embedding_model_name=_get_env_str(
            "EMBEDDING_MODEL_NAME",
            "sentence-transformers/all-MiniLM-L6-v2",
        ),
        embedding_device=_get_env_str("EMBEDDING_DEVICE", "cpu"),
        chunk_size=_get_env_int("CHUNK_SIZE", 900),
        chunk_overlap=_get_env_int("CHUNK_OVERLAP", 120),
        top_k=_get_env_int("TOP_K", 4),
        collection_name=_get_env_str("COLLECTION_NAME", "knowflow_ai_documents"),
        require_context_for_answer=_get_env_bool("REQUIRE_CONTEXT_FOR_ANSWER", True),
        prompt_template_version=_get_env_str("PROMPT_TEMPLATE_VERSION", "rag_v1.0"),
    )
def validate_config(config: AppConfig, require_api_key: bool = True) -> None:
    """
    Validate important configuration values.

    AI ENGINEER PRODUCTION TIP:
    Fail fast.
    A clear error at application startup is better than a confusing error
    after the user uploads documents or asks a question.

    Args:
        config: The configuration to check (as produced by load_config()).
        require_api_key: When True (the default), also require a non-empty,
            non-placeholder CLOUD_API_KEY.

    Raises:
        ValueError: on the first invalid setting found; the message names
            the corresponding environment variable.
    """
    # Chunking: sizes must be positive, and the overlap must leave room for
    # new content in every chunk.
    if config.chunk_size <= 0:
        raise ValueError("CHUNK_SIZE must be greater than 0.")
    if config.chunk_overlap < 0:
        raise ValueError("CHUNK_OVERLAP must not be negative.")
    if config.chunk_overlap >= config.chunk_size:
        raise ValueError("CHUNK_OVERLAP must be smaller than CHUNK_SIZE.")
    if config.top_k <= 0:
        raise ValueError("TOP_K must be greater than 0.")

    # Request limits.
    if config.cloud_max_completion_tokens <= 0:
        raise ValueError("CLOUD_MAX_COMPLETION_TOKENS must be greater than 0.")
    if config.cloud_timeout_seconds <= 0:
        raise ValueError("CLOUD_TIMEOUT_SECONDS must be greater than 0.")
    if config.cloud_max_retries < 0:
        raise ValueError("CLOUD_MAX_RETRIES must not be negative.")
    if config.cloud_retry_sleep_seconds < 0:
        raise ValueError("CLOUD_RETRY_SLEEP_SECONDS must not be negative.")

    # Require a full scheme prefix; startswith("http") alone would also accept
    # malformed values such as "httpfoo" or "http:/typo".
    if not config.cloud_chat_completions_url.startswith(("http://", "https://")):
        raise ValueError("Cloud API URL must start with http:// or https://.")
    if config.cloud_api_format != "openai_chat_completions":
        raise ValueError(
            "This Phase 2 implementation supports only OpenAI-compatible chat completions."
        )

    if require_api_key:
        if not config.cloud_api_key:
            raise ValueError("CLOUD_API_KEY is missing.")
        # Catches template values copied from an example .env file.
        if config.cloud_api_key.startswith("replace_with"):
            raise ValueError("CLOUD_API_KEY still contains a placeholder value.")
def create_required_folders(config: AppConfig) -> None:
    """
    Make sure every runtime folder named in ``config`` exists on disk.

    Missing parent directories are created as well, and folders that already
    exist are left as they are. These folders are meant for generated files,
    which are intentionally not tracked by Git.
    """
    runtime_folders = (
        config.data_folder,
        config.vector_db_folder,
        config.outputs_folder,
        config.logs_folder,
    )
    for folder in runtime_folders:
        folder.mkdir(parents=True, exist_ok=True)