amiraghhh committed on
Commit
f5d4374
·
verified ·
1 Parent(s): d332d62

Delete model.py

Browse files
Files changed (1) hide show
  1. model.py +0 -323
model.py DELETED
@@ -1,323 +0,0 @@
1
- """
2
- Model loading and RAG pipeline core functions.
3
- Handles vector store, embeddings, and answer generation.
4
- """
5
-
6
- import os
7
- import re
8
- import torch
9
- import traceback
10
- import random
11
- import numpy as np
12
- from pathlib import Path
13
-
14
- try:
15
- import chromadb
16
- except ImportError as e:
17
- print(f"Warning: chromadb import failed: {e}")
18
- chromadb = None
19
-
20
- try:
21
- from sentence_transformers import SentenceTransformer
22
- except ImportError as e:
23
- print(f"Error: sentence_transformers not available: {e}")
24
- raise
25
-
26
- try:
27
- from transformers import (
28
- AutoTokenizer,
29
- AutoModelForSeq2SeqLM,
30
- pipeline
31
- )
32
- except ImportError as e:
33
- print(f"Error: transformers not available: {e}")
34
- raise
35
-
36
-
37
# ===========================
# CONFIGURATION & SETUP
# ===========================

# Fix every RNG so retrieval/generation runs are reproducible.
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Paths (adjust these to match your HuggingFace Hub paths)
VECTOR_DB_PATH = "./MedQuAD_db"
FINE_TUNED_MODEL_ID = "amiraghhh/fine-tuned-flan-t5-small"  # Update with your HF model path

# Lazily populated singletons — each is filled in exactly once by the
# corresponding load_* function below.
embed_model = None
vector_store = None
flant5_tokenizer = None
flant5_model = None
finetuned_llm = None
rerank_tokenizer = None
rerank_model = None
rewriter_llm = None
61
-
62
-
63
- # ===========================
64
- # INITIALIZATION FUNCTIONS
65
- # ===========================
66
-
67
def load_embeddings_model():
    """Return the shared SentenceTransformer embedder, loading it on first call."""
    global embed_model
    if embed_model is not None:
        return embed_model
    print("Loading embedding model...")
    embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    return embed_model
74
-
75
-
76
def load_vector_store():
    """Return the cached ChromaDB collection, opening it from disk on first call."""
    global vector_store
    if vector_store is not None:
        return vector_store
    print("Loading vector store...")
    client = chromadb.PersistentClient(path=VECTOR_DB_PATH)
    try:
        vector_store = client.get_collection("medical_rag")
        print(f"Vector store loaded with {vector_store.count()} documents")
    except Exception as err:
        # Surface the failure but let the caller decide what to do with it.
        print(f"Error loading vector store: {err}")
        raise
    return vector_store
89
-
90
-
91
def load_flan_t5_models():
    """Return (tokenizer, model) for the baseline FLAN-T5, loaded exactly once."""
    global flant5_tokenizer, flant5_model
    if flant5_tokenizer is None:
        print("Loading FLAN-T5 models...")
        checkpoint = "google/flan-t5-small"
        flant5_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        flant5_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    return flant5_tokenizer, flant5_model
99
-
100
-
101
def load_rewriter_model():
    """Return the FLAN-T5-small text2text pipeline used for query rewriting."""
    global rewriter_llm
    if rewriter_llm is not None:
        return rewriter_llm
    print("Loading query rewriter...")
    # Deterministic decoding with mild anti-repetition settings.
    generation_opts = dict(
        max_length=64,
        do_sample=False,
        temperature=0.3,
        repetition_penalty=1.3,
        no_repeat_ngram_size=2,
    )
    rewriter_llm = pipeline(
        "text2text-generation",
        model="google/flan-t5-small",
        **generation_opts,
    )
    return rewriter_llm
116
-
117
-
118
def load_reranker_model():
    """Return (tokenizer, model) for the MonoT5 reranker, loaded exactly once."""
    global rerank_tokenizer, rerank_model
    if rerank_tokenizer is None:
        print("Loading reranker model...")
        checkpoint = "castorini/monot5-base-msmarco"
        rerank_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        rerank_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        rerank_model.eval()  # inference only
    return rerank_tokenizer, rerank_model
127
-
128
-
129
def load_finetuned_model():
    """Return the answer-generation pipeline built from the fine-tuned FLAN-T5."""
    global finetuned_llm
    if finetuned_llm is not None:
        return finetuned_llm
    print("Loading fine-tuned model...")
    tok = AutoTokenizer.from_pretrained(FINE_TUNED_MODEL_ID)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(FINE_TUNED_MODEL_ID)
    finetuned_llm = pipeline(
        "text2text-generation",
        model=mdl,
        tokenizer=tok,
        # Mirrors the original setup: decoding starts from the pad token id.
        decoder_start_token_id=mdl.config.pad_token_id,
    )
    return finetuned_llm
144
-
145
-
146
def initialize_all():
    """Eagerly load every model and the vector store at application startup."""
    print("Initializing RAG pipeline...")
    loaders = (
        load_embeddings_model,
        load_vector_store,
        load_flan_t5_models,
        load_rewriter_model,
        load_reranker_model,
        load_finetuned_model,
    )
    for loader in loaders:
        loader()
    print("RAG pipeline initialized successfully!")
156
-
157
-
158
- # ===========================
159
- # PROMPT BUILDING
160
- # ===========================
161
-
162
def build_prompt(user_query, context, rewritten_query, max_tokens=512):
    """Assemble a context-grounded prompt that fits within ``max_tokens``.

    Args:
        user_query (str): Original user question (kept for interface parity;
            the prompt text itself uses ``rewritten_query``).
        context (list): Retrieved chunks, each a dict with 'question' and
            'chunk_answer' keys.
        rewritten_query (str): Query placed after the context block.
        max_tokens (int): Token budget for the complete prompt.

    Returns:
        str: Formatted prompt for the model.
    """
    tokenizer, _ = load_flan_t5_models()

    if not context:
        return f"""No relevant medical information found.
Q: {rewritten_query}
A: Information unavailable."""

    instruction_text = "Medical Context:\n"
    query_footer = f"\nQ: {rewritten_query}\nA:"

    def _tok_count(text):
        # Token count without special tokens, matching how chunks are billed.
        return len(tokenizer.encode(text, add_special_tokens=False))

    # Budget left for context after the fixed scaffolding plus a small margin.
    overhead = _tok_count(instruction_text) + _tok_count(query_footer) + 5
    budget = max(0, max_tokens - overhead)

    # Greedily take chunks in retrieval order until the budget is exhausted.
    picked = []
    used = 0
    for idx, chunk in enumerate(context, start=1):
        rendered = f"[C{idx}] {chunk['question']}\n{chunk['chunk_answer']}"
        cost = _tok_count(rendered)
        if used + cost > budget:
            break
        picked.append(rendered)
        used += cost

    return f"{instruction_text}{chr(10).join(picked)}{query_footer}"
209
-
210
-
211
- # ===========================
212
- # RESPONSE REFINEMENT
213
- # ===========================
214
-
215
def refine_response(answer):
    """Clean and format generated response text.

    Normalizes spacing around sentence-ending periods, drops a trailing
    incomplete sentence, and uppercases the first letter of each sentence.

    Args:
        answer (str): Raw generated text.

    Returns:
        str: Cleaned and formatted response.
    """
    # Collapse runs of 2+ spaces after a period; the next sub restores a
    # single space.
    answer = re.sub(r'\. {2,}', '.', answer)
    # Ensure exactly one space after each period.
    # NOTE(review): this also splits decimals ("3.5" -> "3. 5") — confirm
    # whether numeric answers can occur here.
    answer = re.sub(r'\.([^\s])', r'. \1', answer)

    # If generation was cut off mid-sentence, truncate at the last full stop.
    if not answer.strip().endswith(('.', '!', '?')):
        last_punc_pos = max(answer.rfind('.'), answer.rfind('!'), answer.rfind('?'))
        if last_punc_pos != -1:
            answer = answer[:last_punc_pos + 1]

    # Uppercase only the FIRST character of each sentence. (The previous
    # str.capitalize() also lowercased the rest, mangling acronyms like
    # "NIH" into "nih".)
    sentences = re.split(r'([.!?]\s*)', answer)
    refined_sentences = []
    for i in range(0, len(sentences), 2):
        sentence_part = sentences[i].strip()
        if sentence_part:
            refined_sentences.append(sentence_part[0].upper() + sentence_part[1:])
        if i + 1 < len(sentences):
            refined_sentences.append(sentences[i + 1])

    return ''.join(refined_sentences).strip()
246
-
247
-
248
- # ===========================
249
- # RAG PIPELINE
250
- # ===========================
251
-
252
def rag_pipeline(user_query, top_k=3, detail=False):
    """Main RAG pipeline: retrieve context and generate answer.

    Args:
        user_query (str): User's medical question
        top_k (int): Number of context chunks to retrieve
        detail (bool): Whether to show detailed context information

    Returns:
        str: Generated answer, optionally annotated with a retrieval
            confidence note, or an error/safety message.
    """
    try:
        # Import retriever here to avoid circular imports
        from retriever import retriever_simple

        # Safety gate: refuse queries containing emergency language.
        emergency_keywords = [
            "emergency", "severe pain", "bleeding", "blind",
            "lose consciousness", "pass out", "call 911"
        ]
        if any(keyword in user_query.lower() for keyword in emergency_keywords):
            return """I am an AI and cannot provide medical advice for emergencies.
PLEASE CONTACT EMERGENCY SERVICES OR A MEDICAL PROFESSIONAL IMMEDIATELY."""

        # 1. Retrieve context
        print(f"[RAG] Retrieving context for query: {user_query[:50]}...")
        contexts = retriever_simple(user_query, top_k=top_k, detail=detail)

        if not contexts:
            return "I couldn't find relevant medical information to answer your question. Please try rephrasing."

        # 2. Build prompt. No rewriting is applied at this point: the raw
        # query is passed as the "rewritten" query too.
        print(f"[RAG] Building prompt...")
        prompt = build_prompt(user_query, contexts, user_query)

        # 3. Generate response with deterministic beam search (no sampling).
        print(f"[RAG] Generating response...")
        llm = load_finetuned_model()
        result = llm(
            prompt,
            max_new_tokens=70,
            num_beams=3,
            early_stopping=True,
            do_sample=False,
            repetition_penalty=1.4,
            eos_token_id=llm.tokenizer.eos_token_id
        )

        answer = result[0]['generated_text'].strip()
        answer = refine_response(answer)

        # 4. Confidence heuristic from retrieval distances: smaller average
        # distance -> higher confidence, clamped to [0, 100]. Chunks missing
        # a 'chroma_distance' key count as distance 1.0 (zero confidence).
        if contexts and len(contexts) > 0:
            avg_distance = sum(
                c.get('chroma_distance', 1.0) for c in contexts
            ) / len(contexts)
            confidence_score = max(0, min(100, (1 - avg_distance) * 100))

            if confidence_score < 40:
                final_answer = f"⚠️ Low Confidence ({confidence_score:.1f}%)\n\n{answer}"
            else:
                final_answer = f"{answer}\n\n✓ Confidence: {confidence_score:.1f}%"
        else:
            final_answer = answer

        return final_answer

    except Exception as e:
        # Top-level boundary: log the full traceback and return the error
        # text to the caller instead of raising.
        error_msg = f"Error in RAG pipeline: {str(e)}"
        print(error_msg)
        traceback.print_exc()
        return error_msg