# Configuration module for the RAG Agent.
# (Removed non-Python Hugging Face Hub upload-header residue that previously
# made this file unparseable.)
import os
from pathlib import Path
class RAGConfig:
    """Configuration settings for the RAG Agent.

    All tunables are class attributes so they can be overridden per instance
    (see :meth:`from_env`). Instantiating the class creates the data
    directories and runs basic validation as side effects.
    """

    # Model settings
    MODEL_NAME = "microsoft/DialoGPT-medium"  # Default model, can be changed
    EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

    # Agent settings
    MAX_ITERATIONS = 5
    TEMPERATURE = 0.7
    MAX_TOKENS = 2048

    # Retrieval settings
    CHUNK_SIZE = 512
    CHUNK_OVERLAP = 50
    TOP_K_RETRIEVAL = 5
    SIMILARITY_THRESHOLD = 0.7

    # Paths (all relative to the directory containing this file)
    BASE_DIR = Path(__file__).parent
    KNOWLEDGE_BASE_PATH = BASE_DIR / "knowledge_base"
    VECTOR_STORE_PATH = BASE_DIR / "vector_store"
    LOGS_PATH = BASE_DIR / "logs"

    # API keys (read from environment variables at class-definition time;
    # each may be None when the variable is unset)
    HF_TOKEN = os.getenv("HF_TOKEN")
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
    GROQ_API_KEY = os.getenv("GROQ_API_KEY")

    # Web search settings
    MAX_SEARCH_RESULTS = 5
    SEARCH_TIMEOUT = 10  # seconds

    # Vector database settings
    VECTOR_DB_TYPE = "faiss"  # Options: faiss, chroma, pinecone
    PERSIST_DIRECTORY = str(VECTOR_STORE_PATH)

    # Gradio settings
    GRADIO_SHARE = True
    GRADIO_PORT = 7860
    GRADIO_HOST = "0.0.0.0"

    # Supported file types for knowledge base ingestion
    SUPPORTED_EXTENSIONS = ['.txt', '.md', '.pdf', '.docx', '.json', '.csv']

    # Advanced settings
    USE_RERANKING = True
    RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-2-v2"

    # Logging
    LOG_LEVEL = "INFO"
    LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

    def __init__(self) -> None:
        """Create the data directories and validate the configuration."""
        self._create_directories()
        self._validate_config()

    def _create_directories(self) -> None:
        """Create the knowledge-base, vector-store and log directories.

        Idempotent: existing directories are left untouched.
        """
        for directory in (
            self.KNOWLEDGE_BASE_PATH,
            self.VECTOR_STORE_PATH,
            self.LOGS_PATH,
        ):
            directory.mkdir(parents=True, exist_ok=True)

    def _validate_config(self) -> None:
        """Validate configuration settings.

        Prints a warning (does not raise) when HF_TOKEN is missing, since
        only some features depend on it.

        Raises:
            ValueError: if the base directory does not exist.
        """
        if not self.HF_TOKEN:
            print("⚠️ Warning: HF_TOKEN not set. Some features may not work.")
        if not self.BASE_DIR.exists():
            raise ValueError(f"Base directory does not exist: {self.BASE_DIR}")

    def get_model_config(self) -> dict:
        """Return the language-model settings as a dictionary."""
        return {
            "model_name": self.MODEL_NAME,
            "temperature": self.TEMPERATURE,
            "max_tokens": self.MAX_TOKENS,
            "token": self.HF_TOKEN,
        }

    def get_retrieval_config(self) -> dict:
        """Return the retrieval/chunking settings as a dictionary.

        ``reranker_model`` is None when reranking is disabled.
        """
        return {
            "chunk_size": self.CHUNK_SIZE,
            "chunk_overlap": self.CHUNK_OVERLAP,
            "top_k": self.TOP_K_RETRIEVAL,
            "similarity_threshold": self.SIMILARITY_THRESHOLD,
            "embedding_model": self.EMBEDDING_MODEL,
            "reranker_model": self.RERANKER_MODEL if self.USE_RERANKING else None,
        }

    def get_gradio_config(self) -> dict:
        """Return the Gradio launch settings as a dictionary."""
        return {
            "share": self.GRADIO_SHARE,
            "server_port": self.GRADIO_PORT,
            "server_name": self.GRADIO_HOST,
        }

    @classmethod
    def from_env(cls) -> "RAGConfig":
        """Create a configuration, overriding selected settings from the
        environment.

        Recognized variables: RAG_MODEL_NAME, RAG_MAX_ITERATIONS,
        RAG_TEMPERATURE, RAG_CHUNK_SIZE, RAG_TOP_K. Numeric variables are
        converted with int()/float(); a malformed value raises ValueError.

        Returns:
            A RAGConfig instance with environment overrides applied.
        """
        instance = cls()
        env_mappings = {
            "RAG_MODEL_NAME": "MODEL_NAME",
            "RAG_MAX_ITERATIONS": "MAX_ITERATIONS",
            "RAG_TEMPERATURE": "TEMPERATURE",
            "RAG_CHUNK_SIZE": "CHUNK_SIZE",
            "RAG_TOP_K": "TOP_K_RETRIEVAL",
        }
        int_attrs = {"MAX_ITERATIONS", "CHUNK_SIZE", "TOP_K_RETRIEVAL"}
        float_attrs = {"TEMPERATURE"}
        for env_var, attr_name in env_mappings.items():
            env_value = os.getenv(env_var)
            # BUG FIX: the original used `if env_value:`, so a deliberate
            # override of "0" (e.g. RAG_TEMPERATURE=0) was silently ignored.
            # Skip only unset or empty variables.
            if env_value is None or env_value == "":
                continue
            if attr_name in int_attrs:
                setattr(instance, attr_name, int(env_value))
            elif attr_name in float_attrs:
                setattr(instance, attr_name, float(env_value))
            else:
                setattr(instance, attr_name, env_value)
        return instance
# Global config instance, created eagerly at import time. Reads the RAG_*
# environment variables and (via __init__) creates the knowledge_base/,
# vector_store/ and logs/ directories next to this file as a side effect.
config = RAGConfig.from_env()