"""
Centralized configuration for the RAG system.
"""
import os
from typing import Optional, Dict, Any
from dotenv import load_dotenv
# Load environment variables from a local .env file (if one exists) so the
# os.getenv() lookups below can pick up per-deployment overrides.
load_dotenv()
# Embedding model settings
EMBEDDING_MODEL_NAME = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
EMBEDDING_DIMENSION = int(os.getenv("EMBEDDING_DIMENSION", "384"))
USE_GPU = os.getenv("USE_GPU", "True").lower() in ("true", "1", "t")
# Document processing settings
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "1000"))
CHUNK_OVERLAP = int(os.getenv("CHUNK_OVERLAP", "200"))
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
# Vector database settings
VECTOR_DB_TYPE = os.getenv("VECTOR_DB_TYPE", "faiss") # Options: "faiss", "milvus", etc.
FAISS_INDEX_TYPE = os.getenv("FAISS_INDEX_TYPE", "Flat") # Options: "Flat", "IVF", "HNSW"
MONGODB_URI = os.getenv("MONGODB_URI", "mongodb://localhost:27017/")
DB_NAME = os.getenv("DB_NAME", "rag_db")
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "documents")
# Retrieval settings
TOP_K = int(os.getenv("TOP_K", "5"))
SEARCH_TYPE = os.getenv("SEARCH_TYPE", "hybrid") # Options: "semantic", "keyword", "hybrid"
SEMANTIC_SEARCH_WEIGHT = float(os.getenv("SEMANTIC_SEARCH_WEIGHT", "0.7"))
KEYWORD_SEARCH_WEIGHT = float(os.getenv("KEYWORD_SEARCH_WEIGHT", "0.3"))
# LLM settings
LLM_MODEL_NAME = os.getenv("LLM_MODEL", "gpt-3.5-turbo")
LLM_API_KEY = os.getenv("OPENAI_API_KEY")
LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0.2"))
LLM_MAX_TOKENS = int(os.getenv("LLM_MAX_TOKENS", "512"))
# Local LLM settings (optional)
LOCAL_LLM_MODEL_NAME = os.getenv("LOCAL_LLM_MODEL", "google/flan-t5-base")
USE_LOCAL_LLM = os.getenv("USE_LOCAL_LLM", "False").lower() in ("true", "1", "t")
# API settings
API_HOST = os.getenv("API_HOST", "0.0.0.0")
API_PORT = int(os.getenv("API_PORT", "8000"))
# Logging settings
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
LOG_FORMAT = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
# Default prompt template
DEFAULT_PROMPT_TEMPLATE = """
Answer the following question based ONLY on the provided context.
If you cannot answer the question based on the context, say "I don't have enough information to answer this question."
Context:
{context}
Question: {query}
Answer:
"""
def get_logging_config() -> Dict[str, Any]:
    """Build a ``logging.config.dictConfig``-compatible configuration.

    Wires a single stdout console handler onto the root logger, using the
    module-level LOG_LEVEL and LOG_FORMAT settings.
    """
    console_handler = {
        "class": "logging.StreamHandler",
        "level": LOG_LEVEL,
        "formatter": "standard",
        "stream": "ext://sys.stdout",
    }
    root_logger = {
        "handlers": ["console"],
        "level": LOG_LEVEL,
        "propagate": True,
    }
    return {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {"standard": {"format": LOG_FORMAT}},
        "handlers": {"console": console_handler},
        "loggers": {"": root_logger},
    }
def get_model_config(model_name: Optional[str] = None) -> Dict[str, Any]:
    """Return embedding-model parameters (dimension, max length, normalization).

    When ``model_name`` is None, the module-level EMBEDDING_MODEL_NAME is
    used. Models absent from the lookup table fall back to the configured
    EMBEDDING_DIMENSION / MAX_LENGTH defaults.
    """
    # Resolve the effective model: explicit argument wins over the configured default.
    effective_name = EMBEDDING_MODEL_NAME if model_name is None else model_name
    # Known configurations for popular sentence-transformers models.
    known_models: Dict[str, Dict[str, Any]] = {
        "sentence-transformers/all-MiniLM-L6-v2": {
            "dimension": 384,
            "max_length": 512,
            "normalize": True,
        },
        "sentence-transformers/all-mpnet-base-v2": {
            "dimension": 768,
            "max_length": 512,
            "normalize": True,
        },
        # Extend this table when onboarding additional embedding models.
    }
    fallback: Dict[str, Any] = {
        "dimension": EMBEDDING_DIMENSION,
        "max_length": MAX_LENGTH,
        "normalize": True,
    }
    return known_models.get(effective_name, fallback)