amiraghhh commited on
Commit
27e8632
·
verified ·
1 Parent(s): d691cdc

Delete retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +0 -354
retriever.py DELETED
@@ -1,354 +0,0 @@
1
- """
2
- Retrieval functions for the RAG pipeline.
3
- Handles query rewriting, retrieval, and context re-ranking.
4
- """
5
-
6
- import torch
7
- from model import (
8
- load_embeddings_model,
9
- load_vector_store,
10
- load_rewriter_model,
11
- load_reranker_model
12
- )
13
-
14
-
15
- # ===========================
16
- # QUERY REWRITING
17
- # ===========================
18
-
19
- def rewrite_query(user_query):
20
- """Rewrite user query to be more specific and medical-focused.
21
-
22
- Args:
23
- user_query (str): Original user question
24
-
25
- Returns:
26
- str: Rewritten query
27
- """
28
- rewriter_llm = load_rewriter_model()
29
-
30
- # Few-shot prompting for medical question reformulation
31
- prompt = f"""Rewrite the input into a clear medical question following these patterns:
32
- Input: my head hurts
33
- Output: What causes headaches?
34
-
35
- Input: i keep vomiting but feel ok afterwards
36
- Output: What causes cyclic vomiting?
37
-
38
- Input: chest pain when breathing
39
- Output: What causes chest pain during breathing?
40
-
41
- Input: {user_query}
42
- Output:
43
- """
44
-
45
- llm_output = rewriter_llm(prompt)
46
- rewritten_query = llm_output[0]['generated_text']
47
- rewritten_query = rewritten_query.replace("Output:", "").strip()
48
-
49
- return rewritten_query.strip()
50
-
51
-
52
- # ===========================
53
- # RERANKING
54
- # ===========================
55
-
56
- def get_monot5_scores(rewritten_query, chunks):
57
- """Calculate relevance scores for chunks using MonoT5 reranker.
58
-
59
- Args:
60
- rewritten_query (str): The rewritten query
61
- chunks (list): List of retrieved text chunks
62
-
63
- Returns:
64
- list: Relevance scores for each chunk
65
- """
66
- rerank_tokenizer, rerank_model = load_reranker_model()
67
-
68
- # Format inputs: "Query: Q Document: D Relevant:"
69
- inputs = [f"Query: {rewritten_query} Document: {c} Relevant:" for c in chunks]
70
-
71
- # Tokenize
72
- tokenized_inputs = rerank_tokenizer(
73
- inputs,
74
- padding=True,
75
- truncation=True,
76
- return_tensors='pt'
77
- )
78
-
79
- # Generate predictions
80
- with torch.no_grad():
81
- outputs = rerank_model.generate(
82
- input_ids=tokenized_inputs['input_ids'],
83
- attention_mask=tokenized_inputs['attention_mask'],
84
- max_new_tokens=1,
85
- return_dict_in_generate=True,
86
- output_scores=True
87
- )
88
-
89
- # Extract "true" token scores
90
- true_token_id = rerank_tokenizer.encode("true")[0]
91
- batch_scores = outputs.scores[0][:, true_token_id]
92
-
93
- return batch_scores.tolist()
94
-
95
-
96
- # ===========================
97
- # RETRIEVAL FUNCTIONS
98
- # ===========================
99
-
100
- def retriever_simple(q, top_k=3, detail=False):
101
- """Simple retrieval without query rewriting or reranking.
102
-
103
- Args:
104
- q (str or list): Query or list of queries
105
- top_k (int): Number of results to return
106
- detail (bool): Include metadata in results
107
-
108
- Returns:
109
- list: Retrieved context chunks
110
- """
111
- embed_model = load_embeddings_model()
112
- vector_store = load_vector_store()
113
-
114
- is_single_query = isinstance(q, str)
115
- queries = [q] if is_single_query else q
116
-
117
- # Encode queries
118
- q_embeddings = embed_model.encode(queries).tolist()
119
-
120
- # Search vector store
121
- search_results = vector_store.query(
122
- query_embeddings=q_embeddings,
123
- n_results=10
124
- )
125
-
126
- all_contexts = []
127
-
128
- for i in range(len(queries)):
129
- contexts_for_query = []
130
- if search_results['documents'][i]:
131
- for r in range(len(search_results['ids'][i])):
132
- item = {
133
- 'chunk_id': search_results['ids'][i][r],
134
- 'chunk_answer': search_results['documents'][i][r],
135
- 'question': search_results['metadatas'][i][r]['question']
136
- }
137
- if detail:
138
- item['focus_area'] = search_results['metadatas'][i][r].get('focus_area', 'Unknown')
139
- item['source'] = search_results['metadatas'][i][r].get('source', 'Unknown')
140
- item['chroma_distance'] = round(search_results['distances'][i][r], 3)
141
- contexts_for_query.append(item)
142
- all_contexts.append(contexts_for_query)
143
-
144
- if is_single_query:
145
- return all_contexts[0]
146
- else:
147
- return all_contexts
148
-
149
-
150
- def retriever_with_rewriter(q, top_k=3, detail=False):
151
- """Retrieve with query rewriting but without reranking.
152
-
153
- Args:
154
- q (str or list): Query or list of queries
155
- top_k (int): Number of results to return
156
- detail (bool): Include metadata in results
157
-
158
- Returns:
159
- list: Retrieved context chunks
160
- """
161
- embed_model = load_embeddings_model()
162
- vector_store = load_vector_store()
163
-
164
- is_single_query = isinstance(q, str)
165
- queries = [q] if is_single_query else q
166
-
167
- # Rewrite queries
168
- rewritten_queries = [rewrite_query(query) for query in queries]
169
-
170
- # Encode rewritten queries
171
- q_embeddings = embed_model.encode(rewritten_queries).tolist()
172
-
173
- # Search vector store
174
- search_results = vector_store.query(
175
- query_embeddings=q_embeddings,
176
- n_results=10
177
- )
178
-
179
- all_contexts = []
180
-
181
- for i in range(len(queries)):
182
- contexts_for_query = []
183
- if search_results['documents'][i]:
184
- for r in range(min(top_k, len(search_results['ids'][i]))):
185
- item = {
186
- 'chunk_id': search_results['ids'][i][r],
187
- 'chunk_answer': search_results['documents'][i][r],
188
- 'question': search_results['metadatas'][i][r]['question']
189
- }
190
- if detail:
191
- item['focus_area'] = search_results['metadatas'][i][r].get('focus_area', 'Unknown')
192
- item['source'] = search_results['metadatas'][i][r].get('source', 'Unknown')
193
- item['chroma_distance'] = round(search_results['distances'][i][r], 3)
194
- contexts_for_query.append(item)
195
- all_contexts.append(contexts_for_query)
196
-
197
- if is_single_query:
198
- return all_contexts[0]
199
- else:
200
- return all_contexts
201
-
202
-
203
- def retriever_with_reranker(q, top_k=3, detail=False):
204
- """Retrieve with reranking but without query rewriting.
205
-
206
- Args:
207
- q (str or list): Query or list of queries
208
- top_k (int): Number of results to return
209
- detail (bool): Include metadata in results
210
-
211
- Returns:
212
- list: Retrieved context chunks
213
- """
214
- embed_model = load_embeddings_model()
215
- vector_store = load_vector_store()
216
-
217
- is_single_query = isinstance(q, str)
218
- queries = [q] if is_single_query else q
219
-
220
- q_embeddings = embed_model.encode(queries).tolist()
221
-
222
- search_results = vector_store.query(
223
- query_embeddings=q_embeddings,
224
- n_results=10
225
- )
226
-
227
- all_contexts = []
228
-
229
- for i in range(len(queries)):
230
- contexts_for_query = []
231
- if search_results['documents'][i]:
232
- retrieved_chunks = search_results['documents'][i]
233
- retrieved_metas = search_results['metadatas'][i]
234
- retrieved_ids = search_results['ids'][i]
235
- retrieved_distances = search_results.get('distances', [[]])[i]
236
-
237
- # Rerank
238
- rerank_scores = get_monot5_scores(queries[i], retrieved_chunks)
239
-
240
- if len(rerank_scores) == len(retrieved_chunks):
241
- candidates = []
242
- for r in range(len(retrieved_chunks)):
243
- candidates.append({
244
- 'chunk_id': retrieved_ids[r],
245
- 'chunk_answer': retrieved_chunks[r],
246
- 'question': retrieved_metas[r]['question'],
247
- 'focus_area': retrieved_metas[r].get('focus_area', 'Unknown'),
248
- 'chroma_distance': retrieved_distances[r],
249
- 'rerank_score': rerank_scores[r]
250
- })
251
-
252
- # Sort by rerank score
253
- candidates.sort(key=lambda x: x['rerank_score'], reverse=True)
254
- sorted_results = candidates[:top_k]
255
-
256
- for item_dict in sorted_results:
257
- item = {
258
- 'chunk_id': item_dict['chunk_id'],
259
- 'chunk_answer': item_dict['chunk_answer'],
260
- 'question': item_dict['question']
261
- }
262
- if detail:
263
- item['focus_area'] = item_dict['focus_area']
264
- item['rerank_score'] = round(item_dict['rerank_score'], 3)
265
- item['chroma_distance'] = round(item_dict['chroma_distance'], 3)
266
- contexts_for_query.append(item)
267
-
268
- all_contexts.append(contexts_for_query)
269
-
270
- if is_single_query:
271
- return all_contexts[0]
272
- else:
273
- return all_contexts
274
-
275
-
276
- def retriever_full(q, top_k=3, detail=False):
277
- """Full retrieval with both query rewriting and reranking (recommended).
278
-
279
- Args:
280
- q (str or list): Query or list of queries
281
- top_k (int): Number of results to return
282
- detail (bool): Include metadata in results
283
-
284
- Returns:
285
- list: Retrieved context chunks
286
- """
287
- embed_model = load_embeddings_model()
288
- vector_store = load_vector_store()
289
-
290
- is_single_query = isinstance(q, str)
291
- queries = [q] if is_single_query else q
292
-
293
- all_contexts = []
294
-
295
- for query in queries:
296
- # Rewrite query
297
- q_rewritten = rewrite_query(query)
298
- q_embed = embed_model.encode([q_rewritten]).tolist()
299
-
300
- # Search
301
- search_result = vector_store.query(
302
- query_embeddings=q_embed,
303
- n_results=10
304
- )
305
-
306
- if not search_result['documents'][0]:
307
- all_contexts.append([])
308
- continue
309
-
310
- retrieved_chunks = search_result['documents'][0]
311
- retrieved_metas = search_result['metadatas'][0]
312
- retrieved_ids = search_result['ids'][0]
313
- retrieved_distances = search_result.get('distances', [[]])[0]
314
-
315
- # Rerank
316
- rerank_scores = get_monot5_scores(q_rewritten, retrieved_chunks)
317
-
318
- if len(rerank_scores) == len(retrieved_chunks):
319
- candidates = []
320
- for r in range(len(retrieved_chunks)):
321
- candidates.append({
322
- 'chunk_id': retrieved_ids[r],
323
- 'chunk_answer': retrieved_chunks[r],
324
- 'question': retrieved_metas[r]['question'],
325
- 'focus_area': retrieved_metas[r].get('focus_area', 'Unknown'),
326
- 'chroma_distance': retrieved_distances[r],
327
- 'rerank_score': rerank_scores[r]
328
- })
329
-
330
- # Sort by rerank score
331
- candidates.sort(key=lambda x: x['rerank_score'], reverse=True)
332
- sorted_results = candidates[:top_k]
333
-
334
- contexts = []
335
- for i in sorted_results:
336
- item = {
337
- 'chunk_id': i['chunk_id'],
338
- 'chunk_answer': i['chunk_answer'],
339
- 'question': i['question']
340
- }
341
- if detail:
342
- item['focus_area'] = i['focus_area']
343
- item['rerank_score'] = round(i['rerank_score'], 3)
344
- item['chroma_distance'] = round(i['chroma_distance'], 3)
345
- contexts.append(item)
346
-
347
- all_contexts.append(contexts)
348
- else:
349
- all_contexts.append([])
350
-
351
- if is_single_query:
352
- return all_contexts[0]
353
- else:
354
- return all_contexts