# multimodal-rag/src/utils/config.py — initial deployment (commit a809248, author: itachi)
"""
Configuration module for Multimodal RAG System.
Centralized settings using Pydantic for validation.
"""
from pathlib import Path
from typing import Literal, Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class PathConfig(BaseSettings):
    """File and directory paths.

    Defaults resolve relative to the repository root (three parents above
    this file: <root>/src/utils/config.py). The data, models and logs
    directories are created on instantiation so downstream code can assume
    they exist.
    """

    # Repository root, computed once at class-definition time.
    base_dir: Path = Path(__file__).parent.parent.parent
    data_dir: Path = Field(default_factory=lambda: Path(__file__).parent.parent.parent / "data")
    models_dir: Path = Field(default_factory=lambda: Path(__file__).parent.parent.parent / "artifacts" / "models")
    logs_dir: Path = Field(default_factory=lambda: Path(__file__).parent.parent.parent / "logs")

    def model_post_init(self, __context) -> None:
        """Create all managed directories (idempotent).

        pydantic-v2 idiom: ``model_post_init`` runs after validation and is
        the supported hook for post-construction side effects, instead of
        overriding ``__init__`` on a ``BaseSettings`` subclass.
        """
        for directory in (self.data_dir, self.models_dir, self.logs_dir):
            directory.mkdir(parents=True, exist_ok=True)
class EmbeddingConfig(BaseSettings):
    """Embedding model configuration."""
    # Hugging Face model id for the sentence-transformer encoder.
    model_name: str = "sentence-transformers/all-mpnet-base-v2"
    # Native output dimensionality of the model above (768 for mpnet-base).
    embedding_dim: int = 768
    # Target size after dimensionality reduction — presumably applied by the
    # embedding pipeline; TODO confirm where this is consumed.
    reduced_dim: int = 256
    # Number of texts encoded per forward pass.
    batch_size: int = 32
    # Inputs longer than this are truncated by the tokenizer.
    max_seq_length: int = 512
    # Half-precision inference to reduce GPU memory usage.
    use_fp16: bool = True
    # L2-normalize embeddings (enables cosine similarity via dot product).
    normalize_embeddings: bool = True
    device: str = "cuda"  # "cuda" or "cpu"
class ChunkingConfig(BaseSettings):
    """Text chunking configuration."""
    # Target chunk length; units (chars vs. tokens) depend on the chunker —
    # TODO confirm against the splitter implementation.
    chunk_size: int = 500
    # Overlap carried between consecutive chunks to preserve context.
    chunk_overlap: int = 50
    # Chunks below/above these bounds are presumably merged/split — verify.
    min_chunk_size: int = 100
    max_chunk_size: int = 800
    # Preferred split boundary (paragraph break).
    separator: str = "\n\n"
class RetrievalConfig(BaseSettings):
    """Retrieval configuration."""
    # Candidates fetched per retriever before fusion/re-ranking.
    top_k: int = 20
    # Documents returned to the generator after re-ranking.
    final_k: int = 5
    # Hybrid-search mixing weights; dense + sparse = 1.0 here.
    dense_weight: float = 0.7
    sparse_weight: float = 0.3
    # Smoothing constant k for Reciprocal Rank Fusion (1 / (k + rank)).
    rrf_k: int = 60
    # Minimum similarity for a candidate to be kept.
    similarity_threshold: float = 0.5
    # BM25 parameters
    bm25_k1: float = 1.5   # term-frequency saturation
    bm25_b: float = 0.75   # document-length normalization
class LLMConfig(BaseSettings):
    """LLM configuration."""

    # Hugging Face model id of the default generator.
    model_name: str = "mistralai/Mistral-7B-Instruct-v0.2"
    # Maximum tokens generated per answer.
    max_new_tokens: int = 512
    # Sampling parameters (low temperature keeps answers grounded).
    temperature: float = 0.3
    top_p: float = 0.9
    do_sample: bool = True
    # Total context length budget (prompt + generation).
    context_window: int = 4096

    # Alternative models.
    # Typed list[str] (the bare `list` annotation gave no element
    # validation) with a default_factory so each instance gets its own
    # validated copy.
    available_models: list[str] = Field(
        default_factory=lambda: [
            "mistralai/Mistral-7B-Instruct-v0.2",
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "meta-llama/Llama-3.1-8B-Instruct",
            "microsoft/Phi-3-medium-4k-instruct",
        ]
    )
class DatabaseConfig(BaseSettings):
    """Database configuration."""
    # PostgreSQL connection settings.
    # NOTE(review): default credentials are placeholders — BaseSettings
    # fields can be overridden via environment variables; never ship the
    # default password to a real deployment.
    pg_host: str = "localhost"
    pg_port: int = 5432
    pg_database: str = "rag_db"
    pg_user: str = "postgres"
    pg_password: str = "postgres"
    # pgvector index settings.
    index_type: Literal["ivfflat", "hnsw"] = "ivfflat"
    num_lists: int = 100   # IVFFlat clusters (lists)
    probes: int = 10       # lists scanned per query (recall/speed trade-off)
    # FAISS (alternative in-process vector index).
    use_faiss: bool = False
    faiss_index_type: Literal["flat", "ivf", "hnsw"] = "ivf"
    faiss_nlist: int = 100
    faiss_nprobe: int = 10
class EvaluationConfig(BaseSettings):
    """Evaluation configuration."""

    # Retrieval metrics: cutoffs for metrics such as recall@k / nDCG@k.
    # Typed list[int]/list[str] (the bare `list` annotations gave no
    # element validation) with default_factory for per-instance copies.
    retrieval_k_values: list[int] = Field(
        default_factory=lambda: [1, 3, 5, 10, 20]
    )
    # Generation metrics
    rouge_types: list[str] = Field(
        default_factory=lambda: ["rouge1", "rouge2", "rougeL"]
    )
    bertscore_model: str = "microsoft/deberta-xlarge-mnli"
    # Hallucination detection: entailment score below this threshold flags
    # a generated claim as unsupported.
    hallucination_threshold: float = 0.7
    entailment_model: str = "microsoft/deberta-large-mnli"
class MLflowConfig(BaseSettings):
    """MLflow configuration."""
    # Local directory (or remote URI) for MLflow run tracking.
    tracking_uri: str = "mlruns"
    experiment_name: str = "multimodal-rag"
    # None lets MLflow choose the default artifact store.
    artifact_location: Optional[str] = None
class Config(BaseSettings):
    """Main configuration class.

    Aggregates all sub-configurations; each is built via default_factory so
    every Config instance gets fresh, independently-validated sections.
    """
    paths: PathConfig = Field(default_factory=PathConfig)
    embedding: EmbeddingConfig = Field(default_factory=EmbeddingConfig)
    chunking: ChunkingConfig = Field(default_factory=ChunkingConfig)
    retrieval: RetrievalConfig = Field(default_factory=RetrievalConfig)
    llm: LLMConfig = Field(default_factory=LLMConfig)
    database: DatabaseConfig = Field(default_factory=DatabaseConfig)
    evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
    mlflow: MLflowConfig = Field(default_factory=MLflowConfig)
    # General
    seed: int = 42     # RNG seed for reproducibility
    debug: bool = False
# Global config instance.
# Built eagerly at import time: BaseSettings reads environment variables at
# construction, so env overrides must be in place before this module is
# first imported. Note that constructing Config also triggers
# PathConfig's directory creation as a side effect.
config = Config()
def get_config() -> Config:
    """Get the global configuration instance."""
    return config