ZedLow committed on
Commit
8872cf1
·
verified ·
1 Parent(s): ab45f59

Update rag/pipeline.py

Browse files
Files changed (1) hide show
  1. rag/pipeline.py +67 -67
rag/pipeline.py CHANGED
@@ -2,7 +2,6 @@ import torch
2
  import torch.nn.functional as F
3
  from PIL import Image
4
  from typing import List, Dict, Any, Callable, Tuple
5
-
6
  from rag.prompting import build_messages
7
  from rag.config import Settings
8
  from rag.logging_utils import get_logger
@@ -11,36 +10,54 @@ from qwen_vl_utils import process_vision_info
11
 
12
  logger = get_logger(__name__)
13
 
14
-
15
  def _route_companies(
16
  query: str,
17
  router_model,
18
  settings: Settings,
19
  ) -> Tuple[List[str], str | None]:
20
 
 
 
 
 
 
 
 
21
  labels = list(settings.router_labels)
22
  entities = router_model.predict_entities(query, labels, threshold=settings.router_threshold)
23
-
24
- detected_companies: List[str] = []
 
 
25
  for e in entities:
26
- name = (e.get("text") or "").lower()
27
-
28
 
29
- if "microsoft" in name or "msft" in name:
30
- detected_companies.append("Microsoft")
31
- elif "apple" in name or "aapl" in name:
32
- detected_companies.append("Apple")
33
- else:
34
- # Hard reject anything outside the allowlist to keep retrieval constrained.
35
- return [], (
36
- f"⚠️ I detected a request for '{e.get('text')}', "
37
- "but I only have access to Microsoft and Apple data."
38
- )
 
39
 
 
 
 
 
 
40
 
41
- detected_companies = list(set(detected_companies))
42
- return detected_companies, None
 
 
 
43
 
 
44
 
45
  def _filter_docs(
46
  dataset: List[Dict[str, Any]],
@@ -50,19 +67,17 @@ def _filter_docs(
50
  valid_docs = []
51
  for i, doc in enumerate(dataset):
52
  doc_name = doc.get("doc_name", "Doc")
53
-
54
 
55
  if detected_companies:
56
  if not any(company in doc_name for company in detected_companies):
57
  continue
58
-
59
  text = (doc.get("text") or "").strip()
60
  if text:
61
  valid_docs.append({"text": text, "original_index": i, "doc_name": doc_name})
62
-
63
  return valid_docs
64
 
65
-
66
  def _prepare_images(
67
  dataset: List[Dict[str, Any]],
68
  valid_docs: List[Dict[str, Any]],
@@ -70,39 +85,34 @@ def _prepare_images(
70
  r_scores,
71
  top_k_indices_local: List[int],
72
  ):
73
-
74
  images_content = []
75
  gallery_preview = []
76
  meta_info = ""
77
-
78
  for idx_local in top_k_indices_local:
79
- # idx_local is in reranker score space (over retrieved candidates).
80
  idx_in_valid = top_k_indices[idx_local]
81
  final_doc_idx = valid_docs[idx_in_valid]["original_index"]
82
-
83
  doc = dataset[final_doc_idx]
84
  image_path = doc["image_path"]
85
  score = r_scores[idx_local].item()
86
  doc_name = doc.get("doc_name", "Unknown")
87
-
88
  try:
89
  img = Image.open(image_path)
90
-
91
-
92
  header_text = f"SOURCE DOCUMENT: {doc_name} (Confidence: {score:.2f})\n"
 
93
  images_content.append({"type": "text", "text": header_text})
94
  images_content.append({"type": "image", "image": img})
95
-
96
  gallery_preview.append((img, doc_name))
97
  meta_info += f"- **{doc_name}** (Score: {score:.2f})\n"
98
-
99
  except Exception as e:
100
  logger.warning("Failed to open image %s: %s", image_path, e)
101
  continue
102
-
103
  return images_content, gallery_preview, meta_info
104
 
105
-
106
  def make_retrieve_and_answer(
107
  dataset: List[Dict[str, Any]],
108
  models,
@@ -112,35 +122,32 @@ def make_retrieve_and_answer(
112
  if settings is None:
113
  settings = models.settings if hasattr(models, "settings") else Settings()
114
 
115
- # Hugging Face Spaces: ensure the handler runs on GPU when deployed.
116
  import spaces
117
 
118
  @spaces.GPU
119
  def retrieve_and_answer(query: str):
120
  logger.info("User question: %s", query)
121
-
122
  if not dataset:
123
  return [], "Empty corpus", "No documents loaded."
124
 
125
- # Step 1 — Entity routing (CPU): determine allowed company scope (or reject).
126
  detected_companies, blocked_msg = _route_companies(query, models.router_model, settings)
 
127
  if blocked_msg is not None:
128
  return [], "", blocked_msg
129
-
130
  logger.info("Router detected companies: %s", detected_companies)
131
 
132
- # Step 2 — Filter corpus by company scope (document-level access control).
133
  valid_docs = _filter_docs(dataset, detected_companies)
 
134
  if not valid_docs:
135
- return [], "", "No relevant documents found based on routed scope."
136
 
137
- # Step 3 — Dense retrieval: compute query embedding and score against doc embeddings.
138
- # Note: embeddings are computed on-the-fly for a small demo corpus (no persistent index).
139
  query_text = (
140
  "Instruct: Given a user query, retrieve relevant passages that answer the query.\n"
141
  f"Query: {query}"
142
  )
143
-
144
  with torch.no_grad():
145
  q_inputs = models.embed_tokenizer(
146
  [query_text],
@@ -149,14 +156,14 @@ def make_retrieve_and_answer(
149
  truncation=True,
150
  return_tensors="pt",
151
  ).to(models.embed_model.device)
152
-
153
  q_outputs = models.embed_model(**q_inputs)
154
  q_emb = last_token_pool(q_outputs.last_hidden_state, q_inputs["attention_mask"])
155
  q_emb = F.normalize(q_emb, p=2, dim=1)
156
-
157
  d_embeddings_list = []
158
  doc_texts = [d["text"] for d in valid_docs]
159
-
160
  for i in range(0, len(doc_texts), 1):
161
  d_inputs = models.embed_tokenizer(
162
  doc_texts[i:i + 1],
@@ -165,21 +172,20 @@ def make_retrieve_and_answer(
165
  truncation=True,
166
  return_tensors="pt",
167
  ).to(models.embed_model.device)
168
-
169
  d_outputs = models.embed_model(**d_inputs)
170
  batch_emb = last_token_pool(d_outputs.last_hidden_state, d_inputs["attention_mask"])
171
  batch_emb = F.normalize(batch_emb, p=2, dim=1)
172
  d_embeddings_list.append(batch_emb)
173
-
174
  d_emb_final = torch.cat(d_embeddings_list, dim=0)
175
  scores = (q_emb @ d_emb_final.T).squeeze(0)
176
-
177
  k_val = min(settings.embed_top_k, len(scores))
178
  top_k_indices = torch.topk(scores, k=k_val).indices.tolist()
179
 
180
- # Step 4 — Cross-encoder reranking over retrieved candidates.
181
  pairs = [[query, valid_docs[idx]["text"]] for idx in top_k_indices]
182
-
183
  with torch.no_grad():
184
  r_inputs = models.rerank_tokenizer(
185
  pairs,
@@ -188,48 +194,42 @@ def make_retrieve_and_answer(
188
  return_tensors="pt",
189
  max_length=settings.rerank_max_length,
190
  ).to(models.rerank_model.device)
191
-
192
  r_scores = models.rerank_model(**r_inputs, return_dict=True).logits.view(-1).float()
193
-
194
  k_rerank = min(settings.rerank_top_k, len(r_scores))
195
  top_k_indices_local = torch.topk(r_scores, k=k_rerank).indices.tolist()
196
 
197
- # Step 5 Build multimodal context + provenance metadata for inspection.
198
- meta_info = ""
199
- if detected_companies:
200
- meta_info += f"**AI Router Focus:** {', '.join(detected_companies)}\n\n"
201
- else:
202
- meta_info += "**AI Router Mode:** Broad Search (No specific company detected)\n\n"
203
-
204
  images_content, gallery_preview, meta_sources = _prepare_images(
205
  dataset, valid_docs, top_k_indices, r_scores, top_k_indices_local
206
  )
207
  meta_info += meta_sources
208
-
209
  if not images_content:
210
  return [], "", "No images found for the retrieved passages."
211
 
212
- # Step 6 — Vision-native generation: answer only from provided visual evidence.
213
  messages = build_messages(query, images_content)
214
-
215
  text_input = models.gen_processor.apply_chat_template(
216
  messages, tokenize=False, add_generation_prompt=True
217
  )
218
  image_inputs, _video_inputs = process_vision_info(messages)
219
-
220
  inputs = models.gen_processor(
221
  text=[text_input],
222
  images=image_inputs,
223
  padding=True,
224
  return_tensors="pt",
225
  ).to(models.gen_model.device)
226
-
227
  generated_ids = models.gen_model.generate(**inputs, max_new_tokens=settings.max_new_tokens)
228
-
229
- # Remove the prompt tokens to keep only the generated answer.
230
  generated_ids_trimmed = [
231
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
232
  ]
 
233
  response = models.gen_processor.batch_decode(
234
  generated_ids_trimmed,
235
  skip_special_tokens=True,
@@ -238,4 +238,4 @@ def make_retrieve_and_answer(
238
 
239
  return gallery_preview, meta_info, response
240
 
241
- return retrieve_and_answer
 
2
  import torch.nn.functional as F
3
  from PIL import Image
4
  from typing import List, Dict, Any, Callable, Tuple
 
5
  from rag.prompting import build_messages
6
  from rag.config import Settings
7
  from rag.logging_utils import get_logger
 
10
 
11
  logger = get_logger(__name__)
12
 
 
13
  def _route_companies(
14
  query: str,
15
  router_model,
16
  settings: Settings,
17
  ) -> Tuple[List[str], str | None]:
18
 
19
+ allowed_companies = {
20
+ "apple": "Apple",
21
+ "aapl": "Apple",
22
+ "microsoft": "Microsoft",
23
+ "msft": "Microsoft"
24
+ }
25
+
26
  labels = list(settings.router_labels)
27
  entities = router_model.predict_entities(query, labels, threshold=settings.router_threshold)
28
+
29
+ detected_targets = []
30
+ unsupported_targets = []
31
+
32
  for e in entities:
33
+ name_clean = (e.get("text") or "").lower().strip()
34
+ found_match = False
35
 
36
+ for key, canonical_name in allowed_companies.items():
37
+ if key in name_clean:
38
+ detected_targets.append(canonical_name)
39
+ found_match = True
40
+ break
41
+
42
+ if not found_match:
43
+ unsupported_targets.append(e.get("text"))
44
+
45
+ detected_targets = list(set(detected_targets))
46
+ unsupported_targets = list(set(unsupported_targets))
47
 
48
+ if unsupported_targets:
49
+ return [], (
50
+ f"⛔ **Out of Scope:** I detected a request for **{', '.join(unsupported_targets)}**. "
51
+ "This system only has access to **Microsoft** and **Apple** data."
52
+ )
53
 
54
+ if not detected_targets:
55
+ return [], (
56
+ "❓ **Ambiguous Query:** I could not identify a specific company (Apple or Microsoft). "
57
+ "Please name the company you want to analyze."
58
+ )
59
 
60
+ return detected_targets, None
61
 
62
  def _filter_docs(
63
  dataset: List[Dict[str, Any]],
 
67
  valid_docs = []
68
  for i, doc in enumerate(dataset):
69
  doc_name = doc.get("doc_name", "Doc")
 
70
 
71
  if detected_companies:
72
  if not any(company in doc_name for company in detected_companies):
73
  continue
74
+
75
  text = (doc.get("text") or "").strip()
76
  if text:
77
  valid_docs.append({"text": text, "original_index": i, "doc_name": doc_name})
78
+
79
  return valid_docs
80
 
 
81
  def _prepare_images(
82
  dataset: List[Dict[str, Any]],
83
  valid_docs: List[Dict[str, Any]],
 
85
  r_scores,
86
  top_k_indices_local: List[int],
87
  ):
 
88
  images_content = []
89
  gallery_preview = []
90
  meta_info = ""
91
+
92
  for idx_local in top_k_indices_local:
 
93
  idx_in_valid = top_k_indices[idx_local]
94
  final_doc_idx = valid_docs[idx_in_valid]["original_index"]
95
+
96
  doc = dataset[final_doc_idx]
97
  image_path = doc["image_path"]
98
  score = r_scores[idx_local].item()
99
  doc_name = doc.get("doc_name", "Unknown")
100
+
101
  try:
102
  img = Image.open(image_path)
 
 
103
  header_text = f"SOURCE DOCUMENT: {doc_name} (Confidence: {score:.2f})\n"
104
+
105
  images_content.append({"type": "text", "text": header_text})
106
  images_content.append({"type": "image", "image": img})
107
+
108
  gallery_preview.append((img, doc_name))
109
  meta_info += f"- **{doc_name}** (Score: {score:.2f})\n"
 
110
  except Exception as e:
111
  logger.warning("Failed to open image %s: %s", image_path, e)
112
  continue
113
+
114
  return images_content, gallery_preview, meta_info
115
 
 
116
  def make_retrieve_and_answer(
117
  dataset: List[Dict[str, Any]],
118
  models,
 
122
  if settings is None:
123
  settings = models.settings if hasattr(models, "settings") else Settings()
124
 
 
125
  import spaces
126
 
127
  @spaces.GPU
128
  def retrieve_and_answer(query: str):
129
  logger.info("User question: %s", query)
130
+
131
  if not dataset:
132
  return [], "Empty corpus", "No documents loaded."
133
 
 
134
  detected_companies, blocked_msg = _route_companies(query, models.router_model, settings)
135
+
136
  if blocked_msg is not None:
137
  return [], "", blocked_msg
138
+
139
  logger.info("Router detected companies: %s", detected_companies)
140
 
 
141
  valid_docs = _filter_docs(dataset, detected_companies)
142
+
143
  if not valid_docs:
144
+ return [], "", "System Error: Valid targets detected but no matching documents found."
145
 
 
 
146
  query_text = (
147
  "Instruct: Given a user query, retrieve relevant passages that answer the query.\n"
148
  f"Query: {query}"
149
  )
150
+
151
  with torch.no_grad():
152
  q_inputs = models.embed_tokenizer(
153
  [query_text],
 
156
  truncation=True,
157
  return_tensors="pt",
158
  ).to(models.embed_model.device)
159
+
160
  q_outputs = models.embed_model(**q_inputs)
161
  q_emb = last_token_pool(q_outputs.last_hidden_state, q_inputs["attention_mask"])
162
  q_emb = F.normalize(q_emb, p=2, dim=1)
163
+
164
  d_embeddings_list = []
165
  doc_texts = [d["text"] for d in valid_docs]
166
+
167
  for i in range(0, len(doc_texts), 1):
168
  d_inputs = models.embed_tokenizer(
169
  doc_texts[i:i + 1],
 
172
  truncation=True,
173
  return_tensors="pt",
174
  ).to(models.embed_model.device)
175
+
176
  d_outputs = models.embed_model(**d_inputs)
177
  batch_emb = last_token_pool(d_outputs.last_hidden_state, d_inputs["attention_mask"])
178
  batch_emb = F.normalize(batch_emb, p=2, dim=1)
179
  d_embeddings_list.append(batch_emb)
180
+
181
  d_emb_final = torch.cat(d_embeddings_list, dim=0)
182
  scores = (q_emb @ d_emb_final.T).squeeze(0)
183
+
184
  k_val = min(settings.embed_top_k, len(scores))
185
  top_k_indices = torch.topk(scores, k=k_val).indices.tolist()
186
 
 
187
  pairs = [[query, valid_docs[idx]["text"]] for idx in top_k_indices]
188
+
189
  with torch.no_grad():
190
  r_inputs = models.rerank_tokenizer(
191
  pairs,
 
194
  return_tensors="pt",
195
  max_length=settings.rerank_max_length,
196
  ).to(models.rerank_model.device)
197
+
198
  r_scores = models.rerank_model(**r_inputs, return_dict=True).logits.view(-1).float()
199
+
200
  k_rerank = min(settings.rerank_top_k, len(r_scores))
201
  top_k_indices_local = torch.topk(r_scores, k=k_rerank).indices.tolist()
202
 
203
+ meta_info = f"**AI Router Focus:** {', '.join(detected_companies)}\n\n"
204
+
 
 
 
 
 
205
  images_content, gallery_preview, meta_sources = _prepare_images(
206
  dataset, valid_docs, top_k_indices, r_scores, top_k_indices_local
207
  )
208
  meta_info += meta_sources
209
+
210
  if not images_content:
211
  return [], "", "No images found for the retrieved passages."
212
 
 
213
  messages = build_messages(query, images_content)
214
+
215
  text_input = models.gen_processor.apply_chat_template(
216
  messages, tokenize=False, add_generation_prompt=True
217
  )
218
  image_inputs, _video_inputs = process_vision_info(messages)
219
+
220
  inputs = models.gen_processor(
221
  text=[text_input],
222
  images=image_inputs,
223
  padding=True,
224
  return_tensors="pt",
225
  ).to(models.gen_model.device)
226
+
227
  generated_ids = models.gen_model.generate(**inputs, max_new_tokens=settings.max_new_tokens)
228
+
 
229
  generated_ids_trimmed = [
230
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
231
  ]
232
+
233
  response = models.gen_processor.batch_decode(
234
  generated_ids_trimmed,
235
  skip_special_tokens=True,
 
238
 
239
  return gallery_preview, meta_info, response
240
 
241
+ return retrieve_and_answer