amiraghhh commited on
Commit
d332d62
·
verified ·
1 Parent(s): 557317b

Delete config.py

Browse files
Files changed (1) hide show
  1. config.py +0 -317
config.py DELETED
@@ -1,317 +0,0 @@
1
- """
2
- Configuration file for the Medical RAG System.
3
- Centralized settings for easy customization without modifying core files.
4
- """
5
-
6
- import os
7
- from pathlib import Path
8
-
9
- # ===========================
10
- # PATHS & DIRECTORIES
11
- # ===========================
12
-
13
- # Project root directory
14
- PROJECT_ROOT = Path(__file__).parent
15
-
16
- # Vector database location
17
- VECTOR_DB_PATH = os.getenv("VECTOR_DB_PATH", "./MedQuAD_db")
18
-
19
- # Model cache directory (HuggingFace cache)
20
- HF_CACHE_DIR = os.getenv("HF_HOME", "./models")
21
-
22
- # ===========================
23
- # MODEL CONFIGURATION
24
- # ===========================
25
-
26
- # Embedding Model (for query and document encoding)
27
- EMBEDDING_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
28
- EMBEDDING_MODEL_DEVICE = "cuda" # or "cpu"
29
-
30
- # Query Rewriter Model
31
- REWRITER_MODEL_ID = "google/flan-t5-small"
32
- REWRITER_MAX_LENGTH = 64
33
- REWRITER_TEMPERATURE = 0.3
34
- REWRITER_REPETITION_PENALTY = 1.3
35
-
36
- # Re-ranker Model (MonoT5)
37
- RERANKER_MODEL_ID = "castorini/monot5-base-msmarco"
38
- RERANKER_DEVICE = "cuda" # or "cpu"
39
-
40
- # Fine-tuned Generator Model
41
- FINETUNED_MODEL_ID = os.getenv(
42
- "FINETUNED_MODEL_ID",
43
- "amiraghhh/fine-tuned-flan-t5-small"
44
- )
45
-
46
- # Baseline FLAN-T5 (for prompt building)
47
- BASELINE_MODEL_ID = "google/flan-t5-small"
48
-
49
- # ===========================
50
- # RETRIEVAL CONFIGURATION
51
- # ===========================
52
-
53
- # Default number of context chunks to retrieve
54
- DEFAULT_TOP_K = 3
55
-
56
- # Maximum number of chunks to retrieve before ranking
57
- MAX_RETRIEVE = 10
58
-
59
- # Vector store collection name
60
- VECTOR_STORE_COLLECTION = "medical_rag"
61
-
62
- # Similarity threshold for filtering (0-1, lower is more strict)
63
- SIMILARITY_THRESHOLD = 0.1
64
-
65
- # ===========================
66
- # GENERATION CONFIGURATION
67
- # ===========================
68
-
69
- # Maximum tokens in generated answer
70
- GENERATION_MAX_TOKENS = 70
71
-
72
- # Number of beams for beam search
73
- GENERATION_NUM_BEAMS = 3
74
-
75
- # Repetition penalty (> 1.0 reduces repetition)
76
- GENERATION_REPETITION_PENALTY = 1.4
77
-
78
- # Do sampling (True) or greedy (False)
79
- GENERATION_DO_SAMPLE = False
80
-
81
- # Temperature for sampling (ignored if do_sample=False)
82
- GENERATION_TEMPERATURE = 0.7
83
-
84
- # ===========================
85
- # PROMPT CONFIGURATION
86
- # ===========================
87
-
88
- # Maximum tokens for the full prompt
89
- PROMPT_MAX_TOKENS = 512
90
-
91
- # Prompt template - can be customized
92
- PROMPT_INSTRUCTION = "Medical Context:\n"
93
- PROMPT_QUERY_FOOTER = "\nQ: {query}\nA:"
94
-
95
- # Emergency keywords that should trigger a warning
96
- EMERGENCY_KEYWORDS = [
97
- "emergency", "severe pain", "bleeding", "blind",
98
- "lose consciousness", "pass out", "call 911", "911",
99
- "critical", "life-threatening"
100
- ]
101
-
102
- EMERGENCY_RESPONSE = """I am an AI and cannot provide medical advice for emergencies.
103
- PLEASE CONTACT EMERGENCY SERVICES OR A MEDICAL PROFESSIONAL IMMEDIATELY."""
104
-
105
- # ===========================
106
- # CONFIDENCE SCORING
107
- # ===========================
108
-
109
- # Confidence thresholds
110
- CONFIDENCE_LOW_THRESHOLD = 40 # %
111
- CONFIDENCE_HIGH_THRESHOLD = 70 # %
112
-
113
- # How to calculate confidence (based on retrieval distances)
114
- # distance_range: 0 (identical) to 2 (very different)
115
- CONFIDENCE_FORMULA = "max(0, min(100, (1 - avg_distance) * 100))"
116
-
117
- # ===========================
118
- # WEB INTERFACE CONFIGURATION
119
- # ===========================
120
-
121
- # Gradio server settings
122
- GRADIO_SERVER_NAME = "0.0.0.0"
123
- GRADIO_SERVER_PORT = 7860
124
- GRADIO_SHARE = False
125
- GRADIO_DEBUG = False
126
- GRADIO_SHOW_ERROR = True
127
- GRADIO_SHOW_TIPS = True
128
-
129
- # Gradio theme
130
- GRADIO_THEME = "soft"
131
-
132
- # Page title
133
- PAGE_TITLE = "Medical Q&A System"
134
-
135
- # Example questions to display
136
- EXAMPLE_QUESTIONS = [
137
- "What are the symptoms of type 2 diabetes?",
138
- "How is hypertension treated?",
139
- "What causes migraines?",
140
- "What are the risk factors for heart disease?",
141
- "How do I manage chronic pain?",
142
- "What is asthma?",
143
- "When should I see a doctor for fever?",
144
- "What are the causes of back pain?"
145
- ]
146
-
147
- # ===========================
148
- # PERFORMANCE & OPTIMIZATION
149
- # ===========================
150
-
151
- # Batch size for embedding
152
- EMBEDDING_BATCH_SIZE = 64
153
-
154
- # Whether to normalize embeddings
155
- EMBEDDING_NORMALIZE = True
156
-
157
- # Convert embeddings to numpy (True) or keep as tensors (False)
158
- EMBEDDING_CONVERT_TO_NUMPY = True
159
-
160
- # Cache frequently used embeddings
161
- ENABLE_CACHE = True
162
- CACHE_SIZE = 1000 # number of queries to cache
163
-
164
- # ===========================
165
- # LOGGING CONFIGURATION
166
- # ===========================
167
-
168
- LOG_LEVEL = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
169
- LOG_FILE = "rag_system.log"
170
- LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
171
-
172
- # ===========================
173
- # RANDOM SEEDS (for reproducibility)
174
- # ===========================
175
-
176
- RANDOM_SEED = 1
177
- NUMPY_SEED = 1
178
- TORCH_SEED = 1
179
- CUDA_SEED = 1
180
-
181
- # ===========================
182
- # DATA PREPROCESSING
183
- # ===========================
184
-
185
- # Text cleaning options
186
- CLEAN_LOWERCASE = True
187
- CLEAN_REMOVE_URLS = True
188
- CLEAN_REMOVE_EMAILS = True
189
- CLEAN_REMOVE_PHONES = True
190
- CLEAN_REMOVE_ADDRESSES = True
191
-
192
- # Chunking options
193
- CHUNK_SIZE = 350 # tokens
194
- CHUNK_OVERLAP = 50 # tokens
195
-
196
- # ===========================
197
- # RATE LIMITING & SECURITY
198
- # ===========================
199
-
200
- # Enable rate limiting
201
- ENABLE_RATE_LIMIT = False
202
- MAX_REQUESTS_PER_MINUTE = 30
203
- MAX_REQUEST_LENGTH = 1000 # max question length in characters
204
-
205
- # ===========================
206
- # DEPLOYMENT SETTINGS
207
- # ===========================
208
-
209
- # Environment type
210
- ENVIRONMENT = os.getenv("ENVIRONMENT", "development") # development, staging, production
211
-
212
- # Enable analytics
213
- ENABLE_ANALYTICS = False
214
-
215
- # API key (for authentication if needed)
216
- API_KEY = os.getenv("API_KEY", None)
217
-
218
- # ===========================
219
- # VERSION & METADATA
220
- # ===========================
221
-
222
- APP_VERSION = "1.0.0"
223
- APP_NAME = "Medical RAG System"
224
- APP_DESCRIPTION = "Retrieval-Augmented Generation for medical Q&A"
225
- APP_AUTHOR = "Your Name"
226
- APP_LICENSE = "MIT"
227
-
228
- # ===========================
229
- # HELPER FUNCTIONS
230
- # ===========================
231
-
232
def get_model_config(model_type):
    """Get configuration for a specific model type.

    Args:
        model_type (str): One of 'embedding', 'rewriter', 'reranker'
            or 'generator'.

    Returns:
        dict: The settings bundle for that model role, or an empty
        dict when the type is unknown.
    """
    if model_type == 'embedding':
        return {
            'model_id': EMBEDDING_MODEL_ID,
            'device': EMBEDDING_MODEL_DEVICE,
            'batch_size': EMBEDDING_BATCH_SIZE,
            'normalize': EMBEDDING_NORMALIZE,
        }
    if model_type == 'rewriter':
        return {
            'model_id': REWRITER_MODEL_ID,
            'max_length': REWRITER_MAX_LENGTH,
            'temperature': REWRITER_TEMPERATURE,
            'repetition_penalty': REWRITER_REPETITION_PENALTY,
        }
    if model_type == 'reranker':
        return {
            'model_id': RERANKER_MODEL_ID,
            'device': RERANKER_DEVICE,
        }
    if model_type == 'generator':
        return {
            'model_id': FINETUNED_MODEL_ID,
            'max_tokens': GENERATION_MAX_TOKENS,
            'num_beams': GENERATION_NUM_BEAMS,
            'do_sample': GENERATION_DO_SAMPLE,
            'temperature': GENERATION_TEMPERATURE,
            'repetition_penalty': GENERATION_REPETITION_PENALTY,
        }
    # Unknown model type: mirror dict.get's fallback from the original.
    return {}
269
-
270
-
271
def is_production():
    """Return True iff the app is running in the production environment."""
    return "production" == ENVIRONMENT
274
-
275
-
276
def is_emergency_query(query):
    """Return True if *query* contains any emergency keyword.

    Matching is case-insensitive substring containment against
    EMERGENCY_KEYWORDS.
    """
    lowered = query.lower()
    for keyword in EMERGENCY_KEYWORDS:
        if keyword in lowered:
            return True
    return False
280
-
281
-
282
- # ===========================
283
- # VALIDATE CONFIGURATION
284
- # ===========================
285
-
286
def validate_config():
    """Validate critical configuration settings.

    Checks that the vector database exists on disk, that a generator
    model is configured, and that the confidence thresholds are valid
    percentages in a consistent order. Any problems found are printed
    to stdout.

    Returns:
        bool: True when no problems were found, False otherwise.
    """
    errors = []

    # Check paths
    if not os.path.exists(VECTOR_DB_PATH):
        errors.append(f"Vector database path not found: {VECTOR_DB_PATH}")

    # Check model IDs
    if not FINETUNED_MODEL_ID:
        errors.append("FINETUNED_MODEL_ID not set")

    # Check thresholds. Fix: the original only range-checked the low
    # threshold; the high threshold and their ordering went unvalidated.
    if not (0 <= CONFIDENCE_LOW_THRESHOLD <= 100):
        errors.append("CONFIDENCE_LOW_THRESHOLD must be between 0 and 100")
    if not (0 <= CONFIDENCE_HIGH_THRESHOLD <= 100):
        errors.append("CONFIDENCE_HIGH_THRESHOLD must be between 0 and 100")
    if CONFIDENCE_LOW_THRESHOLD > CONFIDENCE_HIGH_THRESHOLD:
        errors.append("CONFIDENCE_LOW_THRESHOLD must not exceed CONFIDENCE_HIGH_THRESHOLD")

    if errors:
        print("Configuration validation errors:")
        for error in errors:
            print(f" - {error}")
        return False

    return True
309
-
310
-
311
if __name__ == "__main__":
    # Smoke test: echo the key settings, then run validation.
    print("Configuration loaded successfully!")
    print(f"Environment: {ENVIRONMENT}")
    print(f"Vector DB: {VECTOR_DB_PATH}")
    print(f"Fine-tuned Model: {FINETUNED_MODEL_ID}")
    status = "PASSED" if validate_config() else "FAILED"
    print(f"Validation: {status}")