# hpmor/src/config.py
# deenaik -- Initial commit (6ef4823)
"""Configuration management for HPMOR Q&A System."""
import os
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, Field
from dotenv import load_dotenv
# Load environment variables (python-dotenv; typically from a .env file).
# This has to run before the Config class below is defined, because the
# class's field defaults call os.getenv() at class-definition time.
load_dotenv()
class Config(BaseModel):
    """Application configuration for the HPMOR Q&A system.

    Most field defaults are taken from environment variables, with a
    hard-coded fallback when the variable is unset. The ``os.getenv`` calls
    are evaluated once, when this class body is executed (i.e. at import
    time), not each time a ``Config`` instance is created.
    """

    # API keys
    groq_api_key: Optional[str] = Field(default=os.getenv("GROQ_API_KEY"))

    # Ollama settings
    ollama_host: str = Field(default=os.getenv("OLLAMA_HOST", "http://localhost:11434"))

    # Model names
    # NOTE(review): confirm that the default local model tags below exist in
    # the Ollama model library before relying on them.
    local_model_small: str = Field(default=os.getenv("LOCAL_MODEL_SMALL", "llama3.2:7b"))
    local_model_large: str = Field(default=os.getenv("LOCAL_MODEL_LARGE", "llama3.2:13b"))
    groq_model: str = Field(default=os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile"))

    # Embedding model
    embedding_model: str = Field(
        default=os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
    )

    # Processing parameters
    # int()/float() raise ValueError at import time if a variable is set to a
    # non-numeric string.
    chunk_size: int = Field(default=int(os.getenv("CHUNK_SIZE", "1000")))
    chunk_overlap: int = Field(default=int(os.getenv("CHUNK_OVERLAP", "200")))
    top_k_retrieval: int = Field(default=int(os.getenv("TOP_K_RETRIEVAL", "5")))

    # Model selection thresholds
    complexity_threshold: float = Field(
        default=float(os.getenv("COMPLEXITY_THRESHOLD", "0.7"))
    )
    max_local_context_size: int = Field(
        default=int(os.getenv("MAX_LOCAL_CONTEXT_SIZE", "4000"))
    )

    # ChromaDB settings
    chroma_persist_dir: Path = Field(
        default=Path(os.getenv("CHROMA_PERSIST_DIR", "./chroma_db"))
    )
    collection_name: str = Field(
        default=os.getenv("COLLECTION_NAME", "hpmor_collection")
    )

    # Gradio settings
    gradio_server_port: int = Field(
        default=int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    )
    # Only the (case-insensitive) string "true" enables sharing; any other
    # value, including "1" or "yes", leaves it disabled.
    gradio_share: bool = Field(
        default=os.getenv("GRADIO_SHARE", "False").lower() == "true"
    )

    # File paths (relative paths; not configurable through the environment)
    data_dir: Path = Field(default=Path("data"))
    raw_data_dir: Path = Field(default=Path("data/raw"))
    processed_data_dir: Path = Field(default=Path("data/processed"))
    hpmor_file: Path = Field(default=Path("data/raw/hpmor.html"))

    def validate_paths(self) -> None:
        """Create the data and ChromaDB directories if they don't exist.

        Missing parent directories are created as well, and directories that
        already exist are left untouched. ``hpmor_file`` is not checked.
        """
        for dir_path in [self.data_dir, self.raw_data_dir, self.processed_data_dir]:
            dir_path.mkdir(parents=True, exist_ok=True)
        self.chroma_persist_dir.mkdir(parents=True, exist_ok=True)

    def has_groq_api(self) -> bool:
        """Check if a Groq API key is configured.

        Returns:
            True when ``groq_api_key`` is a non-empty string other than the
            placeholder ``"your_groq_api_key_here"``; False otherwise.
        """
        key = self.groq_api_key
        # Coerce explicitly: a bare ``key and ...`` would hand back None or ""
        # for an unset key instead of the annotated bool.
        return bool(key) and key != "your_groq_api_key_here"

    def get_model_config(self, model_type: str) -> dict:
        """Get configuration for a specific model type.

        Args:
            model_type: One of ``"local_small"``, ``"local_large"`` or
                ``"groq"``.

        Returns:
            A dict with the keys ``model``, ``type``, ``max_tokens`` and
            ``temperature``; the ``"groq"`` entry additionally carries
            ``api_key``. An unrecognised ``model_type`` falls back to the
            ``"local_small"`` configuration rather than raising.
        """
        configs = {
            "local_small": {
                "model": self.local_model_small,
                "type": "ollama",
                "max_tokens": 2048,
                "temperature": 0.7,
            },
            "local_large": {
                "model": self.local_model_large,
                "type": "ollama",
                "max_tokens": 4096,
                "temperature": 0.7,
            },
            "groq": {
                "model": self.groq_model,
                "type": "groq",
                "api_key": self.groq_api_key,
                "max_tokens": 8192,
                "temperature": 0.7,
            },
        }
        return configs.get(model_type, configs["local_small"])
# Create global config instance
# NOTE: both statements run at import time, so importing this module creates
# the data directories and the ChromaDB persist directory on disk.
config = Config()
config.validate_paths()