# zeta/src/config/settings.py
# Commit 9b457ed (rodrigo-moonray): Deploy zeta-only embeddings (NV-Embed-v2 + E5-small)
"""
Configuration management using Pydantic Settings.
This module provides type-safe access to environment variables and application settings.
"""
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional

from pydantic import Field, ValidationInfo, field_validator, validator
from pydantic_settings import BaseSettings, SettingsConfigDict
# Available LLM models configuration
LLM_MODELS: Dict[str, Dict[str, Any]] = {
"claude-sonnet": {
"name": "Claude Sonnet",
"model_id": "claude-sonnet-4-20250514",
"provider": "anthropic",
"description": "Fast, intelligent model for everyday tasks",
},
"claude-opus": {
"name": "Claude Opus",
"model_id": "claude-opus-4-20250514",
"provider": "anthropic",
"description": "Most capable model for complex reasoning",
},
"grok": {
"name": "Grok",
"model_id": "grok-beta",
"provider": "xai",
"description": "xAI's conversational model",
},
}
def get_llm_model_config(model_key: str) -> Dict[str, Any]:
"""Get configuration for an LLM model."""
if model_key not in LLM_MODELS:
raise ValueError(f"Unknown LLM model: {model_key}")
return LLM_MODELS[model_key]
def list_llm_models() -> List[Dict[str, Any]]:
"""List all available LLM models with their configurations."""
return [
{"id": model_key, **config}
for model_key, config in LLM_MODELS.items()
]
# Available embedding models configuration
EMBEDDING_MODELS: Dict[str, Dict[str, Any]] = {
"nvidia/NV-Embed-v2": {
"name": "NV-Embed-v2",
"dimensions": 4096,
"description": "NVIDIA's SOTA embedding model (72.31 MTEB, requires GPU, 17GB)",
"type": "nvembed",
"batch_size": 2,
"max_length": 4096,
},
"intfloat/e5-small-v2": {
"name": "E5-small",
"dimensions": 384,
"description": "CPU-optimized, fast (58 MTEB, 16ms latency, 130MB)",
"type": "sentence-transformers",
"batch_size": 64,
"max_length": 512,
},
}
def get_embedding_model_config(model_id: str) -> Dict[str, Any]:
"""Get configuration for an embedding model."""
if model_id not in EMBEDDING_MODELS:
raise ValueError(f"Unknown embedding model: {model_id}")
return EMBEDDING_MODELS[model_id]
def get_collection_name_for_model(model_id: str, base_name: str = "pdf_chunks") -> str:
"""Generate collection name based on embedding model."""
config = EMBEDDING_MODELS.get(model_id, {})
short_name = config.get("name", "unknown").lower().replace(" ", "_").replace("-", "_")
return f"{base_name}_{short_name}"
def list_embedding_models() -> List[Dict[str, Any]]:
"""List all available embedding models with their configurations."""
return [
{"id": model_id, **config}
for model_id, config in EMBEDDING_MODELS.items()
]
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
model_config = SettingsConfigDict(
env_file=".env" if Path(".env").exists() else None,
env_file_encoding="utf-8",
case_sensitive=False,
extra="ignore"
)
# LLM API Keys
anthropic_api_key: str = Field(..., description="Anthropic API key for Claude")
grok_api_key: Optional[str] = Field(None, description="Grok API key")
grok_api_base: str = Field(
default="https://api.x.ai/v1",
description="Grok API base URL (OpenAI-compatible)"
)
# Optional API Keys
cohere_api_key: Optional[str] = Field(None, description="Cohere API key for reranking")
google_search_api_key: Optional[str] = Field(None, description="Google Custom Search API key")
google_search_engine_id: Optional[str] = Field(None, description="Google Custom Search Engine ID")
# Web Search API Keys
tavily_api_key: Optional[str] = Field(None, description="Tavily API key for web search")
# Scientific Search API Keys
semantic_scholar_api_key: Optional[str] = Field(None, description="Semantic Scholar API key (optional, improves rate limits)")
# Search Mode Settings
default_search_mode: Literal["local", "web", "scientific", "hybrid"] = Field(
default="local",
description="Default search mode"
)
web_search_provider: Literal["duckduckgo", "tavily"] = Field(
default="duckduckgo",
description="Preferred web search provider"
)
# Hybrid Search Weights
hybrid_local_weight: float = Field(
default=0.4,
ge=0.0,
le=1.0,
description="Weight for local document results in hybrid search"
)
hybrid_web_weight: float = Field(
default=0.3,
ge=0.0,
le=1.0,
description="Weight for web search results in hybrid search"
)
hybrid_scientific_weight: float = Field(
default=0.3,
ge=0.0,
le=1.0,
description="Weight for scientific paper results in hybrid search"
)
# Application Settings
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(
default="INFO",
description="Logging level"
)
cache_ttl: int = Field(default=3600, description="Cache TTL in seconds")
max_upload_size: int = Field(default=100, description="Max upload size in MB")
# ChromaDB Settings
chroma_persist_dir: str = Field(
default="./data/vectordb",
description="ChromaDB persistence directory"
)
chroma_collection_name: str = Field(
default="pdf_chunks",
description="ChromaDB collection name for PDF chunks"
)
# Embedding Model Settings
embedding_model: str = Field(
default="nvidia/NV-Embed-v2",
description="Embedding model (NV-Embed-v2 is SOTA on MTEB retrieval benchmarks)"
)
embedding_device: Literal["cpu", "cuda", "mps"] = Field(
default="mps",
description="Device for embedding computation (mps for Apple Silicon)"
)
embedding_batch_size: int = Field(
default=2,
description="Batch size for embedding generation (small for large models like NV-Embed)"
)
# Chunking Settings
chunk_size: int = Field(
default=800,
ge=100,
le=2000,
description="Target chunk size in tokens"
)
chunk_overlap: int = Field(
default=150,
ge=0,
le=500,
description="Overlap between chunks in tokens"
)
parent_chunk_size: int = Field(
default=1000,
ge=500,
le=5000,
description="Parent chunk size for hierarchical chunking"
)
# Retrieval Settings
top_k_retrieval: int = Field(
default=50,
ge=1,
le=200,
description="Number of chunks to retrieve initially"
)
top_k_rerank: int = Field(
default=15,
ge=1,
le=50,
description="Number of chunks after reranking"
)
retrieval_score_threshold: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Minimum similarity score for retrieval"
)
# Context Settings
max_context_tokens: int = Field(
default=100_000,
ge=1000,
le=200_000,
description="Maximum tokens for context (Claude supports 200K)"
)
max_response_tokens: int = Field(
default=4_000,
ge=100,
le=8_000,
description="Maximum tokens for LLM response"
)
# LLM Settings
default_llm: Literal["claude", "grok"] = Field(
default="claude",
description="Default LLM to use"
)
claude_model: str = Field(
default="claude-3-5-sonnet-20241022",
description="Claude model identifier"
)
grok_model: str = Field(
default="grok-beta",
description="Grok model identifier"
)
temperature: float = Field(
default=0.7,
ge=0.0,
le=2.0,
description="LLM temperature for generation"
)
max_retries: int = Field(
default=3,
ge=0,
le=10,
description="Maximum retries for LLM API calls"
)
# Server Settings
host: str = Field(default="0.0.0.0", description="Server host")
port: int = Field(default=8000, ge=1, le=65535, description="Server port")
reload: bool = Field(default=False, description="Enable hot reload (development only)")
# Redis Cache Settings
redis_host: str = Field(default="localhost", description="Redis host")
redis_port: int = Field(default=6379, ge=1, le=65535, description="Redis port")
redis_db: int = Field(default=0, ge=0, le=15, description="Redis database number")
redis_password: Optional[str] = Field(None, description="Redis password")
# Development Settings
debug: bool = Field(default=False, description="Enable debug mode")
# HuggingFace Spaces Settings
hf_space: bool = Field(
default=False,
description="Flag indicating if running on HF Spaces (set via HF_SPACE env)"
)
shares_dir: str = Field(
default="./data/shares",
description="Directory for shared chat storage (use /tmp/shares on HF Spaces)"
)
# Site Protection
site_password: Optional[str] = Field(
default=None,
description="Password to protect the site (enables HTTP Basic Auth when set)"
)
site_username: str = Field(
default="zeta",
description="Username for HTTP Basic Auth (default: zeta)"
)
site_title: str = Field(
default="Zeta Researcher",
description="Site title displayed in UI (use 'Zeta Researcher Light' for HF Spaces)"
)
# Netlify Publishing (local version only)
netlify_auth_token: Optional[str] = Field(
None, description="Netlify personal access token for publishing"
)
netlify_site_id: Optional[str] = Field(
None, description="Netlify site ID for the research publication site"
)
publish_site_url: str = Field(
default="https://rodrigoetcheto.com",
description="Base URL of the published site"
)
@validator("chunk_overlap")
def validate_chunk_overlap(cls, v, values):
"""Ensure chunk overlap is less than chunk size."""
if "chunk_size" in values and v >= values["chunk_size"]:
raise ValueError("chunk_overlap must be less than chunk_size")
return v
@validator("top_k_rerank")
def validate_top_k_rerank(cls, v, values):
"""Ensure rerank top_k is less than or equal to retrieval top_k."""
if "top_k_retrieval" in values and v > values["top_k_retrieval"]:
raise ValueError("top_k_rerank must be <= top_k_retrieval")
return v
def get_llm_config(self, llm_type: Optional[str] = None) -> dict:
"""
Get LLM configuration for the specified type.
Args:
llm_type: Type of LLM ("claude" or "grok"). Uses default_llm if None.
Returns:
Dictionary with LLM configuration
"""
llm = llm_type or self.default_llm
if llm == "claude":
return {
"api_key": self.anthropic_api_key,
"model": self.claude_model,
"temperature": self.temperature,
"max_tokens": self.max_response_tokens,
"max_retries": self.max_retries,
}
elif llm == "grok":
return {
"api_key": self.grok_api_key,
"base_url": self.grok_api_base,
"model": self.grok_model,
"temperature": self.temperature,
"max_tokens": self.max_response_tokens,
"max_retries": self.max_retries,
}
else:
raise ValueError(f"Unknown LLM type: {llm}")
def get_chunking_config(self) -> dict:
"""Get chunking configuration."""
return {
"chunk_size": self.chunk_size,
"chunk_overlap": self.chunk_overlap,
"parent_chunk_size": self.parent_chunk_size,
}
def get_retrieval_config(self) -> dict:
"""Get retrieval configuration."""
return {
"top_k": self.top_k_retrieval,
"top_k_rerank": self.top_k_rerank,
"score_threshold": self.retrieval_score_threshold,
}
def get_redis_url(self) -> str:
"""Get Redis connection URL."""
if self.redis_password:
return f"redis://:{self.redis_password}@{self.redis_host}:{self.redis_port}/{self.redis_db}"
return f"redis://{self.redis_host}:{self.redis_port}/{self.redis_db}"
def get_embedding_config(self) -> Dict[str, Any]:
"""Get configuration for the current embedding model."""
return get_embedding_model_config(self.embedding_model)
def get_collection_name(self) -> str:
"""Get ChromaDB collection name for current embedding model."""
return get_collection_name_for_model(self.embedding_model, self.chroma_collection_name)
def get_search_config(self) -> dict:
"""Get search configuration."""
return {
"mode": self.default_search_mode,
"web_provider": self.web_search_provider,
"weights": {
"local": self.hybrid_local_weight,
"web": self.hybrid_web_weight,
"scientific": self.hybrid_scientific_weight,
}
}
# Global settings instance
_settings: Optional[Settings] = None
def get_settings() -> Settings:
"""
Get the global settings instance.
Returns:
Settings instance loaded from environment variables
"""
global _settings
if _settings is None:
_settings = Settings()
return _settings
def reload_settings() -> Settings:
"""
Reload settings from environment variables.
Useful for testing or when environment changes.
Returns:
New Settings instance
"""
global _settings
_settings = Settings()
return _settings