| """ |
| Configuration management using Pydantic Settings. |
| |
| This module provides type-safe access to environment variables and application settings. |
| """ |
|
|
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional

from pydantic import Field, ValidationInfo, field_validator, validator
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
|
|
| |
| LLM_MODELS: Dict[str, Dict[str, Any]] = { |
| "claude-sonnet": { |
| "name": "Claude Sonnet", |
| "model_id": "claude-sonnet-4-20250514", |
| "provider": "anthropic", |
| "description": "Fast, intelligent model for everyday tasks", |
| }, |
| "claude-opus": { |
| "name": "Claude Opus", |
| "model_id": "claude-opus-4-20250514", |
| "provider": "anthropic", |
| "description": "Most capable model for complex reasoning", |
| }, |
| "grok": { |
| "name": "Grok", |
| "model_id": "grok-beta", |
| "provider": "xai", |
| "description": "xAI's conversational model", |
| }, |
| } |
|
|
|
|
| def get_llm_model_config(model_key: str) -> Dict[str, Any]: |
| """Get configuration for an LLM model.""" |
| if model_key not in LLM_MODELS: |
| raise ValueError(f"Unknown LLM model: {model_key}") |
| return LLM_MODELS[model_key] |
|
|
|
|
| def list_llm_models() -> List[Dict[str, Any]]: |
| """List all available LLM models with their configurations.""" |
| return [ |
| {"id": model_key, **config} |
| for model_key, config in LLM_MODELS.items() |
| ] |
|
|
|
|
| |
# Registry of supported embedding models, keyed by Hugging Face model id.
EMBEDDING_MODELS: Dict[str, Dict[str, Any]] = {
    "nvidia/NV-Embed-v2": {
        "name": "NV-Embed-v2",
        "dimensions": 4096,
        "description": "NVIDIA's SOTA embedding model (72.31 MTEB, requires GPU, 17GB)",
        "type": "nvembed",
        "batch_size": 2,
        "max_length": 4096,
    },
    "intfloat/e5-small-v2": {
        "name": "E5-small",
        "dimensions": 384,
        "description": "CPU-optimized, fast (58 MTEB, 16ms latency, 130MB)",
        "type": "sentence-transformers",
        "batch_size": 64,
        "max_length": 512,
    },
}


def get_embedding_model_config(model_id: str) -> Dict[str, Any]:
    """Return the registry entry for the given embedding model id.

    Raises:
        ValueError: If ``model_id`` is not a registered embedding model.
    """
    try:
        return EMBEDDING_MODELS[model_id]
    except KeyError:
        raise ValueError(f"Unknown embedding model: {model_id}") from None


def get_collection_name_for_model(model_id: str, base_name: str = "pdf_chunks") -> str:
    """Build a per-model vector-store collection name.

    The model's display name is lower-cased with spaces and hyphens turned
    into underscores, then appended to ``base_name``. Unregistered models
    deliberately fall back to the literal ``"unknown"`` instead of raising.
    """
    display = EMBEDDING_MODELS.get(model_id, {}).get("name", "unknown")
    slug = display.lower()
    for separator in (" ", "-"):
        slug = slug.replace(separator, "_")
    return f"{base_name}_{slug}"


def list_embedding_models() -> List[Dict[str, Any]]:
    """Return every registered embedding model as a dict with ``id`` folded in."""
    catalog: List[Dict[str, Any]] = []
    for model_id, config in EMBEDDING_MODELS.items():
        record: Dict[str, Any] = {"id": model_id}
        record.update(config)
        catalog.append(record)
    return catalog
|
|
|
|
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Values are read case-insensitively from the process environment and,
    when the file exists, from a local ``.env`` file; unrecognized
    variables are ignored.
    """

    model_config = SettingsConfigDict(
        # Point at .env only when it actually exists on disk.
        env_file=".env" if Path(".env").exists() else None,
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore"
    )

    # --- LLM provider credentials ---
    anthropic_api_key: str = Field(..., description="Anthropic API key for Claude")
    grok_api_key: Optional[str] = Field(None, description="Grok API key")
    grok_api_base: str = Field(
        default="https://api.x.ai/v1",
        description="Grok API base URL (OpenAI-compatible)"
    )

    # --- Optional third-party service keys ---
    cohere_api_key: Optional[str] = Field(None, description="Cohere API key for reranking")
    google_search_api_key: Optional[str] = Field(None, description="Google Custom Search API key")
    google_search_engine_id: Optional[str] = Field(None, description="Google Custom Search Engine ID")

    tavily_api_key: Optional[str] = Field(None, description="Tavily API key for web search")

    semantic_scholar_api_key: Optional[str] = Field(None, description="Semantic Scholar API key (optional, improves rate limits)")

    # --- Search behavior ---
    default_search_mode: Literal["local", "web", "scientific", "hybrid"] = Field(
        default="local",
        description="Default search mode"
    )
    web_search_provider: Literal["duckduckgo", "tavily"] = Field(
        default="duckduckgo",
        description="Preferred web search provider"
    )

    # --- Hybrid-search blending weights ---
    # NOTE(review): defaults total 1.0 but nothing validates that the three
    # weights sum to 1 — confirm whether downstream code normalizes them.
    hybrid_local_weight: float = Field(
        default=0.4,
        ge=0.0,
        le=1.0,
        description="Weight for local document results in hybrid search"
    )
    hybrid_web_weight: float = Field(
        default=0.3,
        ge=0.0,
        le=1.0,
        description="Weight for web search results in hybrid search"
    )
    hybrid_scientific_weight: float = Field(
        default=0.3,
        ge=0.0,
        le=1.0,
        description="Weight for scientific paper results in hybrid search"
    )

    # --- Logging / limits ---
    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(
        default="INFO",
        description="Logging level"
    )
    cache_ttl: int = Field(default=3600, description="Cache TTL in seconds")
    max_upload_size: int = Field(default=100, description="Max upload size in MB")

    # --- Vector store ---
    chroma_persist_dir: str = Field(
        default="./data/vectordb",
        description="ChromaDB persistence directory"
    )
    chroma_collection_name: str = Field(
        default="pdf_chunks",
        description="ChromaDB collection name for PDF chunks"
    )

    # --- Embeddings ---
    embedding_model: str = Field(
        default="nvidia/NV-Embed-v2",
        description="Embedding model (NV-Embed-v2 is SOTA on MTEB retrieval benchmarks)"
    )
    embedding_device: Literal["cpu", "cuda", "mps"] = Field(
        default="mps",
        description="Device for embedding computation (mps for Apple Silicon)"
    )
    embedding_batch_size: int = Field(
        default=2,
        description="Batch size for embedding generation (small for large models like NV-Embed)"
    )

    # --- Chunking ---
    chunk_size: int = Field(
        default=800,
        ge=100,
        le=2000,
        description="Target chunk size in tokens"
    )
    chunk_overlap: int = Field(
        default=150,
        ge=0,
        le=500,
        description="Overlap between chunks in tokens"
    )
    parent_chunk_size: int = Field(
        default=1000,
        ge=500,
        le=5000,
        description="Parent chunk size for hierarchical chunking"
    )

    # --- Retrieval ---
    top_k_retrieval: int = Field(
        default=50,
        ge=1,
        le=200,
        description="Number of chunks to retrieve initially"
    )
    top_k_rerank: int = Field(
        default=15,
        ge=1,
        le=50,
        description="Number of chunks after reranking"
    )
    retrieval_score_threshold: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Minimum similarity score for retrieval"
    )

    # --- Context / generation budgets ---
    max_context_tokens: int = Field(
        default=100_000,
        ge=1000,
        le=200_000,
        description="Maximum tokens for context (Claude supports 200K)"
    )
    max_response_tokens: int = Field(
        default=4_000,
        ge=100,
        le=8_000,
        description="Maximum tokens for LLM response"
    )

    # --- LLM selection ---
    default_llm: Literal["claude", "grok"] = Field(
        default="claude",
        description="Default LLM to use"
    )
    # NOTE(review): this default differs from the claude-sonnet entry in
    # LLM_MODELS ("claude-sonnet-4-20250514") — confirm which is intended.
    claude_model: str = Field(
        default="claude-3-5-sonnet-20241022",
        description="Claude model identifier"
    )
    grok_model: str = Field(
        default="grok-beta",
        description="Grok model identifier"
    )
    temperature: float = Field(
        default=0.7,
        ge=0.0,
        le=2.0,
        description="LLM temperature for generation"
    )
    max_retries: int = Field(
        default=3,
        ge=0,
        le=10,
        description="Maximum retries for LLM API calls"
    )

    # --- HTTP server ---
    host: str = Field(default="0.0.0.0", description="Server host")
    port: int = Field(default=8000, ge=1, le=65535, description="Server port")
    reload: bool = Field(default=False, description="Enable hot reload (development only)")

    # --- Redis ---
    redis_host: str = Field(default="localhost", description="Redis host")
    redis_port: int = Field(default=6379, ge=1, le=65535, description="Redis port")
    redis_db: int = Field(default=0, ge=0, le=15, description="Redis database number")
    redis_password: Optional[str] = Field(None, description="Redis password")

    # --- Misc / deployment flags ---
    debug: bool = Field(default=False, description="Enable debug mode")

    hf_space: bool = Field(
        default=False,
        description="Flag indicating if running on HF Spaces (set via HF_SPACE env)"
    )
    shares_dir: str = Field(
        default="./data/shares",
        description="Directory for shared chat storage (use /tmp/shares on HF Spaces)"
    )

    # --- Site access / branding ---
    site_password: Optional[str] = Field(
        default=None,
        description="Password to protect the site (enables HTTP Basic Auth when set)"
    )
    site_username: str = Field(
        default="zeta",
        description="Username for HTTP Basic Auth (default: zeta)"
    )
    site_title: str = Field(
        default="Zeta Researcher",
        description="Site title displayed in UI (use 'Zeta Researcher Light' for HF Spaces)"
    )

    # --- Publishing ---
    netlify_auth_token: Optional[str] = Field(
        None, description="Netlify personal access token for publishing"
    )
    netlify_site_id: Optional[str] = Field(
        None, description="Netlify site ID for the research publication site"
    )
    publish_site_url: str = Field(
        default="https://rodrigoetcheto.com",
        description="Base URL of the published site"
    )

    # Pydantic v2 field validators. The v1-style `@validator(..., values=...)`
    # API is deprecated under pydantic v2 (required by pydantic_settings) and
    # emits PydanticDeprecatedSince20 warnings, so these use field_validator.
    @field_validator("chunk_overlap")
    @classmethod
    def validate_chunk_overlap(cls, v: int, info: ValidationInfo) -> int:
        """Ensure chunk overlap is less than chunk size."""
        # chunk_size is declared before chunk_overlap, so its validated value
        # is available in info.data when it parsed successfully.
        if "chunk_size" in info.data and v >= info.data["chunk_size"]:
            raise ValueError("chunk_overlap must be less than chunk_size")
        return v

    @field_validator("top_k_rerank")
    @classmethod
    def validate_top_k_rerank(cls, v: int, info: ValidationInfo) -> int:
        """Ensure rerank top_k is less than or equal to retrieval top_k."""
        if "top_k_retrieval" in info.data and v > info.data["top_k_retrieval"]:
            raise ValueError("top_k_rerank must be <= top_k_retrieval")
        return v

    def get_llm_config(self, llm_type: Optional[str] = None) -> dict:
        """
        Get LLM configuration for the specified type.

        Args:
            llm_type: Type of LLM ("claude" or "grok"). Uses default_llm if None.

        Returns:
            Dictionary with LLM configuration

        Raises:
            ValueError: If the resolved LLM type is neither "claude" nor "grok".
        """
        llm = llm_type or self.default_llm

        if llm == "claude":
            return {
                "api_key": self.anthropic_api_key,
                "model": self.claude_model,
                "temperature": self.temperature,
                "max_tokens": self.max_response_tokens,
                "max_retries": self.max_retries,
            }
        elif llm == "grok":
            # grok_api_key may be None; callers are expected to handle a
            # missing key when Grok is selected.
            return {
                "api_key": self.grok_api_key,
                "base_url": self.grok_api_base,
                "model": self.grok_model,
                "temperature": self.temperature,
                "max_tokens": self.max_response_tokens,
                "max_retries": self.max_retries,
            }
        else:
            raise ValueError(f"Unknown LLM type: {llm}")

    def get_chunking_config(self) -> dict:
        """Get chunking configuration (sizes/overlap in tokens)."""
        return {
            "chunk_size": self.chunk_size,
            "chunk_overlap": self.chunk_overlap,
            "parent_chunk_size": self.parent_chunk_size,
        }

    def get_retrieval_config(self) -> dict:
        """Get retrieval configuration (initial/rerank top-k and score floor)."""
        return {
            "top_k": self.top_k_retrieval,
            "top_k_rerank": self.top_k_rerank,
            "score_threshold": self.retrieval_score_threshold,
        }

    def get_redis_url(self) -> str:
        """Get Redis connection URL, embedding the password when one is set."""
        if self.redis_password:
            return f"redis://:{self.redis_password}@{self.redis_host}:{self.redis_port}/{self.redis_db}"
        return f"redis://{self.redis_host}:{self.redis_port}/{self.redis_db}"

    def get_embedding_config(self) -> Dict[str, Any]:
        """Get configuration for the current embedding model.

        Raises:
            ValueError: If embedding_model is not in EMBEDDING_MODELS.
        """
        return get_embedding_model_config(self.embedding_model)

    def get_collection_name(self) -> str:
        """Get ChromaDB collection name derived from the current embedding model."""
        return get_collection_name_for_model(self.embedding_model, self.chroma_collection_name)

    def get_search_config(self) -> dict:
        """Get search configuration (mode, web provider, hybrid weights)."""
        return {
            "mode": self.default_search_mode,
            "web_provider": self.web_search_provider,
            "weights": {
                "local": self.hybrid_local_weight,
                "web": self.hybrid_web_weight,
                "scientific": self.hybrid_scientific_weight,
            }
        }
|
|
|
|
| |
| _settings: Optional[Settings] = None |
|
|
|
|
def get_settings() -> Settings:
    """
    Get the global settings instance.

    The instance is created lazily on first call and cached in the
    module-level ``_settings`` global for subsequent calls.

    Returns:
        Settings instance loaded from environment variables
    """
    global _settings
    if _settings is not None:
        return _settings
    _settings = Settings()
    return _settings
|
|
|
|
def reload_settings() -> Settings:
    """
    Reload settings from environment variables.

    Useful for testing or when environment changes. Replaces the cached
    module-level instance so later get_settings() calls see the new values.

    Returns:
        New Settings instance
    """
    global _settings
    fresh = Settings()
    _settings = fresh
    return fresh
|
|