File size: 4,538 Bytes
0ef94af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
from pathlib import Path

class RAGConfig:
    """Configuration settings for the RAG Agent.

    Class attributes hold the defaults.  Instantiating the class creates
    the knowledge-base / vector-store / log directories as a side effect;
    ``from_env`` additionally applies ``RAG_*`` environment overrides.
    """

    # Model settings
    MODEL_NAME = "microsoft/DialoGPT-medium"  # Default model, can be changed
    EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

    # Agent settings
    MAX_ITERATIONS = 5
    TEMPERATURE = 0.7
    MAX_TOKENS = 2048

    # Retrieval settings
    CHUNK_SIZE = 512
    CHUNK_OVERLAP = 50
    TOP_K_RETRIEVAL = 5
    SIMILARITY_THRESHOLD = 0.7

    # Paths (resolved relative to this file, so the package is relocatable)
    BASE_DIR = Path(__file__).parent
    KNOWLEDGE_BASE_PATH = BASE_DIR / "knowledge_base"
    VECTOR_STORE_PATH = BASE_DIR / "vector_store"
    LOGS_PATH = BASE_DIR / "logs"

    # API Keys (set as environment variables; None when unset)
    HF_TOKEN = os.getenv("HF_TOKEN")
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
    GROQ_API_KEY = os.getenv("GROQ_API_KEY")

    # Web search settings
    MAX_SEARCH_RESULTS = 5
    SEARCH_TIMEOUT = 10  # presumably seconds per request — TODO confirm against caller

    # Vector database settings
    VECTOR_DB_TYPE = "faiss"  # Options: faiss, chroma, pinecone
    PERSIST_DIRECTORY = str(VECTOR_STORE_PATH)

    # Gradio settings
    GRADIO_SHARE = True
    GRADIO_PORT = 7860
    GRADIO_HOST = "0.0.0.0"

    # Supported file types for knowledge base
    SUPPORTED_EXTENSIONS = ['.txt', '.md', '.pdf', '.docx', '.json', '.csv']

    # Advanced settings
    USE_RERANKING = True
    RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-2-v2"

    # Logging
    LOG_LEVEL = "INFO"
    LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

    # Environment-override table used by from_env():
    #   env var -> (attribute name, converter).
    # Pairing each attribute with its converter replaces the fragile
    # name-list type dispatch and makes adding an override a one-line change.
    _ENV_OVERRIDES = {
        "RAG_MODEL_NAME": ("MODEL_NAME", str),
        "RAG_MAX_ITERATIONS": ("MAX_ITERATIONS", int),
        "RAG_TEMPERATURE": ("TEMPERATURE", float),
        "RAG_CHUNK_SIZE": ("CHUNK_SIZE", int),
        "RAG_TOP_K": ("TOP_K_RETRIEVAL", int),
    }

    def __init__(self):
        """Create required directories and sanity-check the settings."""
        self._create_directories()
        self._validate_config()

    def _create_directories(self):
        """Create the knowledge-base, vector-store and log directories."""
        for directory in (
            self.KNOWLEDGE_BASE_PATH,
            self.VECTOR_STORE_PATH,
            self.LOGS_PATH,
        ):
            directory.mkdir(parents=True, exist_ok=True)

    def _validate_config(self):
        """Validate configuration settings.

        A missing HF_TOKEN is only a warning (other providers may still
        work); a missing base directory is fatal.

        Raises:
            ValueError: if BASE_DIR does not exist.
        """
        if not self.HF_TOKEN:
            print("⚠️ Warning: HF_TOKEN not set. Some features may not work.")

        if not self.BASE_DIR.exists():
            raise ValueError(f"Base directory does not exist: {self.BASE_DIR}")

    def get_model_config(self) -> dict:
        """Return the language-model settings as a dictionary."""
        return {
            "model_name": self.MODEL_NAME,
            "temperature": self.TEMPERATURE,
            "max_tokens": self.MAX_TOKENS,
            "token": self.HF_TOKEN,
        }

    def get_retrieval_config(self) -> dict:
        """Return the retrieval / chunking settings as a dictionary."""
        return {
            "chunk_size": self.CHUNK_SIZE,
            "chunk_overlap": self.CHUNK_OVERLAP,
            "top_k": self.TOP_K_RETRIEVAL,
            "similarity_threshold": self.SIMILARITY_THRESHOLD,
            "embedding_model": self.EMBEDDING_MODEL,
            # Reranking is optional; None signals "disabled" to consumers.
            "reranker_model": self.RERANKER_MODEL if self.USE_RERANKING else None,
        }

    def get_gradio_config(self) -> dict:
        """Return the Gradio launch settings as a dictionary."""
        return {
            "share": self.GRADIO_SHARE,
            "server_port": self.GRADIO_PORT,
            "server_name": self.GRADIO_HOST,
        }

    @classmethod
    def from_env(cls):
        """Create a configuration instance with RAG_* environment overrides.

        Fix: numeric overrides were previously converted with int()/float()
        without error handling, so a malformed value (e.g. RAG_CHUNK_SIZE=abc)
        raised ValueError at import time (this module builds a global instance
        on import). Invalid values are now reported and the default is kept.
        Unset or empty variables are ignored, as before.

        Returns:
            RAGConfig: a fresh instance with any valid overrides applied.
        """
        instance = cls()

        for env_var, (attr_name, converter) in cls._ENV_OVERRIDES.items():
            env_value = os.getenv(env_var)
            if not env_value:  # unset or empty string — keep the default
                continue
            try:
                setattr(instance, attr_name, converter(env_value))
            except ValueError:
                print(f"⚠️ Warning: invalid value for {env_var}: {env_value!r} (default kept)")

        return instance

# Global config instance, built at import time. Side effects of this line:
# __init__ creates the knowledge_base/vector_store/logs directories, and
# from_env reads the RAG_* environment variables for overrides.
config = RAGConfig.from_env()