# Document-Audit-RAG / api/config.py
# Default LLM to Hugging Face on Spaces; fix HF_TOKEN settings merge
# (commit 358882c, mayankchugh-learning)
import os
from functools import lru_cache
from typing import Any, Self
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """Application settings loaded from the environment and an optional `.env` file.

    Field names are matched against environment variables case-insensitively
    (e.g. `LLM_PROVIDER` populates `llm_provider`); unknown variables are
    ignored rather than rejected. Two `model_validator`s adapt the raw
    environment for Hugging Face Spaces deployments (see below).
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",  # silently drop unrelated env vars instead of erroring
        case_sensitive=False,
        populate_by_name=True,
    )

    @model_validator(mode="before")
    @classmethod
    def _map_max_upload_env_alias(cls, data: Any) -> Any:
        """Accept `MAX_UPLOAD_SIZE_MB` as a legacy alias for `max_file_size_mb`.

        Runs on the raw input mapping before field validation. If the canonical
        key is unset/empty and the alias carries a value, the alias value is
        moved onto the canonical key; otherwise a bare alias key is still
        popped so it never lingers in the input.
        """
        if not isinstance(data, dict):
            return data
        out = dict(data)
        # Canonical key missing or blank, alias has a real value: promote it.
        if out.get("max_file_size_mb") in (None, "") and out.get("max_upload_size_mb") not in (None, ""):
            out["max_file_size_mb"] = out.pop("max_upload_size_mb")
        elif "max_upload_size_mb" in out and "max_file_size_mb" not in out:
            # Alias present (possibly empty) with no canonical key at all:
            # move it anyway so the alias key is consumed.
            out["max_file_size_mb"] = out.pop("max_upload_size_mb")
        return out

    # --- Application metadata -----------------------------------------------
    app_name: str = Field(default="DocuAudit AI", description="FastAPI title and product name")
    app_version: str = Field(default="1.0.0", description="Application version")
    app_description: str = Field(
        default=(
            "Multi-document RAG API for high-stakes consulting environments. "
            "Every answer is grounded in source documents with full audit trails."
        ),
        description="OpenAPI /docs description",
    )

    # --- LLM / embedding providers ------------------------------------------
    # Fix: original description wrongly said "Embedding provider"; this field
    # selects the chat LLM backend (the validators below switch on it).
    llm_provider: str = Field(default="ollama", description="Chat LLM provider (e.g. ollama, huggingface)")
    openai_api_key: str | None = Field(default=None, description="OpenAI API key")
    openai_model: str = "gpt-4o"
    openai_embedding_model: str = "text-embedding-3-small"
    anthropic_api_key: str = ""
    anthropic_model: str = "claude-3-5-sonnet-20241022"
    huggingface_api_key: str = ""
    huggingface_model: str = Field(
        default="meta-llama/Meta-Llama-3-8B-Instruct",
        description=(
            "HF chat model id (use a repo your Hub account already has access to; Llama 3.1 needs the "
            "separate Llama 3.1 gate). Chat tries hf-inference then router auto when unset."
        ),
    )
    huggingface_embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    huggingface_inference_provider: str | None = Field(
        default=None,
        description=(
            "Optional huggingface_hub InferenceClient provider (e.g. hf-inference, together). "
            "Unset uses hf-inference in chat code; set to `auto` for router auto-routing."
        ),
    )
    ollama_base_url: str = Field(default="http://localhost:11434", description="Ollama base URL")
    ollama_chat_model: str = "llama3.1:8b"
    ollama_embedding_model: str = "nomic-embed-text"

    # --- Vector store --------------------------------------------------------
    # NOTE(review): two persistence-path fields with different defaults
    # (./data/chroma vs ./chroma) — confirm which one the Chroma client reads
    # and consider removing the unused one.
    chroma_persist_directory: str = Field(default="./data/chroma", description="Chroma persistence path")
    chroma_persist_dir: str = Field(default="./chroma", description="Chroma persistence path")
    chroma_collection_name: str = "docuaudit_docs"

    # --- Chunking / retrieval -------------------------------------------------
    chunk_size: int = Field(default=1000, ge=100, le=8000, description="Chunk size for splitting")
    chunk_overlap: int = Field(default=200, ge=0, le=2000, description="Chunk overlap for splitting")
    top_k_results: int = Field(default=5, ge=1, le=20, description="Default number of chunks to retrieve")

    # --- Storage paths & upload limits ---------------------------------------
    audit_db_path: str = "./audit.db"
    jobs_db_path: str = Field(default="./data/jobs.db", description="SQLite path for ingest job tracking")
    max_file_size_mb: int = Field(default=50, ge=1, le=200, description="Max upload file size (MB)")
    max_documents_per_batch: int = Field(default=100, ge=1, le=1000, description="Max documents per batch")

    @model_validator(mode="after")
    def _space_default_llm_provider(self) -> Self:
        """Hugging Face Spaces do not run Ollama locally; use Hub inference unless the user set LLM_PROVIDER."""
        # Only applies inside a Space (SPACE_ID is set by the Spaces runtime).
        if not (os.environ.get("SPACE_ID") or "").strip():
            return self
        # An explicit user choice always wins.
        if "LLM_PROVIDER" in os.environ:
            return self
        # Only override the implicit "ollama" default, never another value.
        if self.llm_provider.lower() != "ollama":
            return self
        self.llm_provider = "huggingface"
        return self

    @model_validator(mode="after")
    def _huggingface_token_from_hub_env(self) -> Self:
        """When using the Hugging Face inference stack, accept the Hub token from standard env names.

        Spaces often expose `HF_TOKEN` (read/write per Space secrets). Map it into `huggingface_api_key`
        when `HUGGINGFACE_API_KEY` is unset so embedder/chat clients receive a token.
        """
        if self.llm_provider.lower() != "huggingface":
            return self
        # Respect an already-configured key.
        if (self.huggingface_api_key or "").strip():
            return self
        # First non-empty standard env token wins.
        for key in ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"):
            token = (os.environ.get(key) or "").strip()
            if token:
                self.huggingface_api_key = token
                break
        return self
@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Return the process-wide Settings instance, constructed on first call and cached."""
    settings = Settings()
    return settings