amiraghhh commited on
Commit
f733dad
·
verified ·
1 Parent(s): dc394f7

Update retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +185 -185
retriever.py CHANGED
@@ -1,185 +1,185 @@
1
- import torch
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
- from utils import get_embed_model
4
-
5
-
6
- # ============================================================================
7
- # LAZY-LOAD RERANKER AND QUERY REWRITER (load on first use)
8
- # ============================================================================
9
-
10
- _rerank_tokenizer = None
11
- _rerank_model = None
12
- _rewritter_llm = None
13
-
14
- def get_reranker():
15
- """Load reranker model (lazy-loaded on first use)"""
16
- global _rerank_tokenizer, _rerank_model
17
- if _rerank_tokenizer is None:
18
- _rerank_tokenizer = AutoTokenizer.from_pretrained("castorini/monot5-base-msmarco")
19
- _rerank_model = AutoModelForSeq2SeqLM.from_pretrained("castorini/monot5-base-msmarco")
20
- _rerank_model.eval()
21
- return _rerank_tokenizer, _rerank_model
22
-
23
- def get_query_rewriter():
24
- """Load query rewriter (lazy-loaded on first use)"""
25
- global _rewritter_llm
26
- if _rewritter_llm is None:
27
- _rewritter_llm = pipeline(
28
- "text2text-generation",
29
- model="google/flan-t5-small",
30
- max_length=64,
31
- do_sample=False,
32
- temperature=0.3,
33
- repetition_penalty=1.3,
34
- no_repeat_ngram_size=2
35
- )
36
- return _rewritter_llm
37
-
38
-
39
- def rewrite_query(user_query):
40
- """Rewrite user query to be more specific and medical-focused.
41
- Returns: str(rewritten_query)"""
42
-
43
- rewritter_llm = get_query_rewriter()
44
-
45
- prompt = f"""Rewrite the input into a clear medical question following these patterns
46
- Input: my head hurts
47
- Output: What causes headaches?
48
-
49
- Input: i keep vomiting but feel ok afterwards
50
- Output: What causes cyclic vomiting?
51
-
52
- Input: chest pain when breathing
53
- Output: What causes chest pain during breathing?
54
-
55
- Input: {user_query}
56
- Output:
57
- """
58
-
59
- llm_output = rewritter_llm(prompt)
60
- rewritten_query = llm_output[0]['generated_text']
61
- rewritten_query = rewritten_query.replace("Output:", "").strip()
62
-
63
- return rewritten_query.strip()
64
-
65
-
66
- def get_monot5_scores(rewritten_query, chunks):
67
- """Calculate relevance scores for chunks against query using MonoT5 reranker.
68
- Returns: list(scores)"""
69
-
70
- rerank_tokenizer, rerank_model = get_reranker()
71
-
72
- # Format input as "Query: ABC? Document: XYZ... Relevant:"
73
- inputs = [f"Query: {rewritten_query} Document: {c} Relevant:" for c in chunks]
74
-
75
- # Tokenize
76
- tokenized_inputs = rerank_tokenizer(
77
- inputs,
78
- padding=True,
79
- truncation=True,
80
- return_tensors='pt'
81
- )
82
-
83
- # Generate predictions
84
- with torch.no_grad():
85
- outputs = rerank_model.generate(
86
- input_ids=tokenized_inputs['input_ids'],
87
- attention_mask=tokenized_inputs['attention_mask'],
88
- max_new_tokens=1,
89
- return_dict_in_generate=True,
90
- output_scores=True
91
- )
92
-
93
- # Extract "true" token probability scores
94
- true_token_id = rerank_tokenizer.encode("true")[0]
95
- batch_scores = outputs.scores[0][:, true_token_id]
96
-
97
- return batch_scores.tolist()
98
-
99
-
100
- def retrieve(vector_store, query, top_k=3, use_reranking=True, detail=False):
101
- """Retrieve relevant context chunks for a query with optional reranking.
102
- Returns: list(contexts)"""
103
-
104
- embed_model = get_embed_model()
105
-
106
- is_single_query = isinstance(query, str)
107
- queries = [query] if is_single_query else query
108
-
109
- # Rewrite queries for better matching
110
- rewritten_queries = [rewrite_query(q) for q in queries]
111
-
112
- # Embed rewritten queries
113
- q_embeddings = embed_model.encode(rewritten_queries).tolist()
114
-
115
- # Vector search
116
- search_results = vector_store.query(
117
- query_embeddings=q_embeddings,
118
- n_results=10
119
- )
120
-
121
- all_contexts = []
122
-
123
- for i in range(len(queries)):
124
- contexts_for_query = []
125
-
126
- if not search_results['documents'][i]:
127
- all_contexts.append([])
128
- continue
129
-
130
- retrieved_chunks = search_results['documents'][i]
131
- retrieved_metas = search_results['metadatas'][i]
132
- retrieved_ids = search_results['ids'][i]
133
- retrieved_distances = search_results.get('distances', [[]])[i]
134
-
135
- # Rerank if enabled
136
- if use_reranking:
137
- rerank_scores = get_monot5_scores(rewritten_queries[i], retrieved_chunks)
138
-
139
- candidates = []
140
- for r in range(len(retrieved_chunks)):
141
- candidates.append({
142
- 'chunk_id': retrieved_ids[r],
143
- 'chunk_answer': retrieved_chunks[r],
144
- 'question': retrieved_metas[r]['question'],
145
- 'focus_area': retrieved_metas[r].get('focus_area', 'Unknown'),
146
- 'chroma_dist': retrieved_distances[r],
147
- 'rerank_score': rerank_scores[r]
148
- })
149
-
150
- # Sort by rerank score
151
- candidates.sort(key=lambda s: s['rerank_score'], reverse=True)
152
- sorted_results = candidates[:top_k]
153
- else:
154
- # Just use top-k from vector search
155
- sorted_results = [
156
- {
157
- 'chunk_id': retrieved_ids[r],
158
- 'chunk_answer': retrieved_chunks[r],
159
- 'question': retrieved_metas[r]['question'],
160
- 'focus_area': retrieved_metas[r].get('focus_area', 'Unknown'),
161
- 'chroma_dist': retrieved_distances[r]
162
- }
163
- for r in range(min(top_k, len(retrieved_chunks)))
164
- ]
165
-
166
- # Format output
167
- for item_dict in sorted_results:
168
- item = {
169
- 'chunk_id': item_dict['chunk_id'],
170
- 'chunk_answer': item_dict['chunk_answer'],
171
- 'question': item_dict['question']
172
- }
173
- if detail:
174
- item['focus_area'] = item_dict['focus_area']
175
- item['chroma_distance'] = round(item_dict['chroma_dist'], 3)
176
- if 'rerank_score' in item_dict:
177
- item['rerank_score'] = round(item_dict['rerank_score'], 3)
178
- contexts_for_query.append(item)
179
-
180
- all_contexts.append(contexts_for_query)
181
-
182
- if is_single_query:
183
- return all_contexts[0]
184
- else:
185
- return all_contexts
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
+ from utils import get_embed_model
4
+
5
+
6
# ============================================================================
# LAZY-LOAD RERANKER AND QUERY REWRITER (load on first use)
# ============================================================================

# Module-level singleton caches. All three start as None and are populated
# on first use by get_reranker() / get_query_rewriter(), so importing this
# module stays cheap and model weights are downloaded only when needed.
_rerank_tokenizer = None
_rerank_model = None
_rewritter_llm = None
13
+
14
def get_reranker():
    """Return the MonoT5 reranker as a (tokenizer, model) pair.

    Both objects are built on the first call and cached in module globals,
    so subsequent calls return immediately. The model is switched to eval
    mode since it is only ever used for inference.
    """
    global _rerank_tokenizer, _rerank_model

    # Fast path: already initialised.
    if _rerank_tokenizer is not None:
        return _rerank_tokenizer, _rerank_model

    checkpoint = "castorini/monot5-base-msmarco"
    _rerank_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    _rerank_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    _rerank_model.eval()
    return _rerank_tokenizer, _rerank_model
22
+
23
def get_query_rewriter():
    """Return the FLAN-T5 query-rewriting pipeline (lazy-loaded, cached).

    Decoding is greedy (``do_sample=False``). Fix: the original also passed
    ``temperature=0.3``, a sampling-only parameter that is ignored under
    greedy decoding and makes recent transformers versions emit a warning,
    so it is no longer passed. Generated output is unchanged.

    Returns:
        transformers.Pipeline: a "text2text-generation" pipeline.
    """
    global _rewritter_llm
    if _rewritter_llm is None:
        _rewritter_llm = pipeline(
            "text2text-generation",
            model="google/flan-t5-small",
            max_length=64,
            do_sample=False,          # deterministic greedy decoding
            repetition_penalty=1.3,   # discourage echoing the prompt
            no_repeat_ngram_size=2,   # avoid degenerate repeated bigrams
        )
    return _rewritter_llm
37
+
38
+
39
def rewrite_query(user_query):
    """Rewrite a raw user query into a focused medical question.

    A few-shot prompt steers FLAN-T5 toward "What causes ...?"-style
    questions; any echoed "Output:" label is removed before returning.

    Args:
        user_query (str): free-form user text.

    Returns:
        str: the rewritten query, stripped of surrounding whitespace.
    """
    llm = get_query_rewriter()

    prompt = f"""Rewrite the input into a clear medical question following these patterns
Input: my head hurts
Output: What causes headaches?

Input: i keep vomiting but feel ok afterwards
Output: What causes cyclic vomiting?

Input: chest pain when breathing
Output: What causes chest pain during breathing?

Input: {user_query}
Output:
"""

    generated = llm(prompt)[0]['generated_text']
    return generated.replace("Output:", "").strip()
64
+
65
+
66
def get_monot5_scores(rewritten_query, chunks):
    """Score chunks for relevance to the query with the MonoT5 reranker.

    Follows the monoT5 scoring scheme (Nogueira et al.): the model generates
    one token per (query, document) pair and the relevance score is the
    log-softmax probability of "true" computed over the {"false", "true"}
    logit pair. Fix: the original returned the raw "true" logit, which is
    not calibrated across inputs and can rank differently from the
    published method. Also fixed: an empty `chunks` list now returns []
    instead of crashing inside `generate` on an empty batch.

    Args:
        rewritten_query (str): the (rewritten) user query.
        chunks (list[str]): candidate document chunks to score.

    Returns:
        list[float]: one relevance score per chunk (higher = more relevant).
    """
    if not chunks:
        return []

    rerank_tokenizer, rerank_model = get_reranker()

    # MonoT5 input template: "Query: ABC? Document: XYZ... Relevant:"
    inputs = [f"Query: {rewritten_query} Document: {c} Relevant:" for c in chunks]

    tokenized_inputs = rerank_tokenizer(
        inputs,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )

    # Generate exactly one token per pair; we only need its logits.
    with torch.no_grad():
        outputs = rerank_model.generate(
            input_ids=tokenized_inputs['input_ids'],
            attention_mask=tokenized_inputs['attention_mask'],
            max_new_tokens=1,
            return_dict_in_generate=True,
            output_scores=True
        )

    # encode("true")/encode("false") yield [token_id, eos_id] for T5's
    # sentencepiece vocab; index 0 is the word token itself.
    true_token_id = rerank_tokenizer.encode("true")[0]
    false_token_id = rerank_tokenizer.encode("false")[0]

    # Normalise over the two-way {"false", "true"} decision.
    pair_logits = outputs.scores[0][:, [false_token_id, true_token_id]]
    log_probs = torch.nn.functional.log_softmax(pair_logits, dim=1)

    return log_probs[:, 1].tolist()
98
+
99
+
100
def retrieve(vector_store, query, top_k=3, use_reranking=True, detail=False,
             n_candidates=10):
    """Retrieve relevant context chunks for one or more queries.

    Pipeline: rewrite each query -> embed -> vector search for
    ``n_candidates`` chunks -> optionally rerank with MonoT5 -> keep the
    ``top_k`` best per query.

    Fixes:
    - ``search_results.get('distances', [[]])[i]`` raised IndexError for
      every query after the first whenever the store omitted distances;
      now falls back to one empty list per query, and a missing per-chunk
      distance is reported as None instead of raising.
    - The candidate pool size (previously hard-coded to 10) is exposed as
      the backward-compatible ``n_candidates`` keyword.

    Args:
        vector_store: Chroma-style collection exposing ``.query(...)``.
        query (str | list[str]): one query, or a batch of queries.
        top_k (int): number of contexts returned per query.
        use_reranking (bool): rerank vector-search hits with MonoT5.
        detail (bool): include distance / rerank-score diagnostics.
        n_candidates (int): candidate pool fetched before reranking.

    Returns:
        list: contexts for a single ``str`` query, or a list of such
        lists for a batch (one entry per input query).
    """
    embed_model = get_embed_model()

    is_single_query = isinstance(query, str)
    queries = [query] if is_single_query else query

    # Rewrite queries for better matching
    rewritten_queries = [rewrite_query(q) for q in queries]

    # Embed rewritten queries
    q_embeddings = embed_model.encode(rewritten_queries).tolist()

    # Vector search
    search_results = vector_store.query(
        query_embeddings=q_embeddings,
        n_results=n_candidates
    )

    # 'distances' may be absent depending on the store's include settings;
    # default to one empty list per query so indexing stays safe.
    all_distances = search_results.get('distances') or [[] for _ in queries]

    all_contexts = []

    for i in range(len(queries)):
        contexts_for_query = []

        if not search_results['documents'][i]:
            all_contexts.append([])
            continue

        retrieved_chunks = search_results['documents'][i]
        retrieved_metas = search_results['metadatas'][i]
        retrieved_ids = search_results['ids'][i]
        retrieved_distances = all_distances[i]

        def _candidate(r):
            """Build one candidate record; distance may be unavailable."""
            return {
                'chunk_id': retrieved_ids[r],
                'chunk_answer': retrieved_chunks[r],
                'question': retrieved_metas[r]['question'],
                'focus_area': retrieved_metas[r].get('focus_area', 'Unknown'),
                'chroma_dist': (retrieved_distances[r]
                                if r < len(retrieved_distances) else None),
            }

        # Rerank if enabled
        if use_reranking:
            rerank_scores = get_monot5_scores(rewritten_queries[i], retrieved_chunks)

            candidates = [
                dict(_candidate(r), rerank_score=rerank_scores[r])
                for r in range(len(retrieved_chunks))
            ]

            # Sort by rerank score, best first
            candidates.sort(key=lambda s: s['rerank_score'], reverse=True)
            sorted_results = candidates[:top_k]
        else:
            # Just use top-k from vector search (already distance-ordered)
            sorted_results = [
                _candidate(r) for r in range(min(top_k, len(retrieved_chunks)))
            ]

        # Format output
        for item_dict in sorted_results:
            item = {
                'chunk_id': item_dict['chunk_id'],
                'chunk_answer': item_dict['chunk_answer'],
                'question': item_dict['question']
            }
            if detail:
                item['focus_area'] = item_dict['focus_area']
                dist = item_dict['chroma_dist']
                item['chroma_distance'] = round(dist, 3) if dist is not None else None
                if 'rerank_score' in item_dict:
                    item['rerank_score'] = round(item_dict['rerank_score'], 3)
            contexts_for_query.append(item)

        all_contexts.append(contexts_for_query)

    # A single str query unwraps to its own context list.
    return all_contexts[0] if is_single_query else all_contexts