mlbench123 committed on
Commit
3afc71b
·
verified ·
1 Parent(s): 514c41a

Update rag_treatment_app.py

Browse files
Files changed (1) hide show
  1. rag_treatment_app.py +93 -14
rag_treatment_app.py CHANGED
@@ -103,9 +103,17 @@ class RAGTreatmentSearchApp:
103
  self.web_max_docs = int(os.getenv("WEB_MAX_DOCS", "6"))
104
  self.web_max_chars = int(os.getenv("WEB_MAX_CHARS", "1200"))
105
 
106
- # NEW: hard gate to prevent "empty issue" generic recommendations
107
  self.min_issue_chars = int(os.getenv("MIN_ISSUE_CHARS", "5"))
108
 
 
 
 
 
 
 
 
 
109
  # ---------------- DB ----------------
110
  def _load_db(self) -> pd.DataFrame:
111
  xl = pd.ExcelFile(self.excel_path)
@@ -250,6 +258,42 @@ class RAGTreatmentSearchApp:
250
  break
251
  return out
252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  def semantic_search(self, region: str, sub_zone: str, type_choice: str, issue_text: str, top_k: int = 12) -> List[RetrievedCandidate]:
254
  type_norm = _norm_type_choice(type_choice)
255
  query = f"Region: {region} | Sub-Zone: {sub_zone} | Preference: {type_choice} | Issue: {issue_text}"
@@ -303,8 +347,12 @@ class RAGTreatmentSearchApp:
303
  return docs
304
 
305
  def _web_enrich_fallback(self, procedure: str) -> List[WebDoc]:
 
 
 
306
  if not self.web_enabled:
307
  return []
 
308
  queries = [
309
  f"{procedure} recovery swelling bruising downtime",
310
  f"{procedure} procedure time how long does it take",
@@ -385,12 +433,12 @@ Extract patient-facing procedure details for: {procedure}
385
  Use ONLY the evidence below. If not present, write "Not found in evidence."
386
 
387
  Return STRICT JSON with these keys:
388
- - invasiveness
389
- - duration
390
- - downtime
391
- - longevity
392
- - risks
393
- - best_suited_for
394
 
395
  Evidence:
396
  {evidence}
@@ -471,7 +519,7 @@ Evidence:
471
  retrieval_k: int = 12,
472
  final_k: int = 5,
473
  ) -> Dict[str, object]:
474
- # -------- NEW: Hard input gating to prevent empty issue searches --------
475
  region = (region or "").strip()
476
  sub_zone = (sub_zone or "").strip()
477
  issue_text = (issue_text or "").strip()
@@ -493,7 +541,7 @@ Evidence:
493
 
494
  candidates = self.semantic_search(region, sub_zone, type_choice, issue_text, top_k=int(retrieval_k))
495
 
496
- # If no candidates after filtering, return friendly message
497
  if not candidates:
498
  return {
499
  "answer_md": "No matching procedures found for your selected Region/Sub-Zone and issue. Please revise your inputs.",
@@ -501,13 +549,14 @@ Evidence:
501
  "_debug": {"mismatch": False, "candidate_count": 0, "final_count": 0},
502
  }
503
 
504
- # Mismatch check using global signal (kept conservative)
505
  global_cands = self._global_semantic(issue_text, top_k=15)
506
  global_best = global_cands[0].similarity if global_cands else 0.0
507
  local_best = candidates[0].similarity if candidates else 0.0
508
 
509
  selected_region_norm = _norm(region)
510
  selected_sub_norm = _norm(sub_zone)
 
511
  selected_in_global = any(
512
  _norm(c.region) == selected_region_norm and (
513
  selected_sub_norm in _norm(c.sub_zone) or _norm(c.sub_zone) in selected_sub_norm
@@ -515,7 +564,27 @@ Evidence:
515
  for c in global_cands[:10]
516
  )
517
 
518
- if (global_best >= 0.50 and (global_best - local_best) >= 0.12 and not selected_in_global):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
519
  suggestions = []
520
  seen = set()
521
  for c in global_cands:
@@ -531,13 +600,24 @@ Evidence:
531
  answer_md = f"""## Notice
532
  Sorry for inconvenience. Your selected body region/sub-zone is not appropriate as per your defined problem.
533
 
534
- ## Suggested Region/Sub-Zones
535
  {sug_lines}
536
 
537
  ## Next step
538
  Please select one of the suggested Region/Sub-Zones and run the search again.
539
  """.strip()
540
- return {"answer_md": answer_md, "sources": [], "_debug": {"mismatch": True}}
 
 
 
 
 
 
 
 
 
 
 
541
 
542
  best = self._llm_rerank(issue_text, candidates, top_k=int(final_k))
543
  if len(best) < int(final_k):
@@ -559,7 +639,6 @@ Please select one of the suggested Region/Sub-Zones and run the search again.
559
 
560
  answer_md = self._format_final_answer(best, web_by_proc)
561
 
562
- # de-dupe urls
563
  seen_u = set()
564
  dedup = []
565
  for u in urls:
 
103
  self.web_max_docs = int(os.getenv("WEB_MAX_DOCS", "6"))
104
  self.web_max_chars = int(os.getenv("WEB_MAX_CHARS", "1200"))
105
 
106
+ # NEW: hard gate to prevent empty/generic recommendation runs
107
  self.min_issue_chars = int(os.getenv("MIN_ISSUE_CHARS", "5"))
108
 
109
+ # NEW: mismatch sensitivity knobs
110
+ # If issue-only similarity within selected region is below this => likely irrelevant
111
+ self.local_issue_min_sim = float(os.getenv("LOCAL_ISSUE_MIN_SIM", "0.42"))
112
+ # If global best similarity is above this => the issue clearly maps somewhere else
113
+ self.global_issue_min_sim = float(os.getenv("GLOBAL_ISSUE_MIN_SIM", "0.52"))
114
+ # Additional delta guard (keeps your old behavior but improved)
115
+ self.global_local_delta = float(os.getenv("GLOBAL_LOCAL_DELTA", "0.10"))
116
+
117
  # ---------------- DB ----------------
118
  def _load_db(self) -> pd.DataFrame:
119
  xl = pd.ExcelFile(self.excel_path)
 
258
  break
259
  return out
260
 
261
+ def _local_issue_only_best_sim(self, region: str, sub_zone: str, type_choice: str, issue_text: str) -> float:
262
+ """
263
+ NEW: Compute issue-only similarity *within selected region/subzone*.
264
+ This avoids the bias from including region/subzone words in the query.
265
+ """
266
+ issue_text = (issue_text or "").strip()
267
+ if not issue_text:
268
+ return 0.0
269
+
270
+ # consider both types if user chose both
271
+ t = _norm_type_choice(type_choice)
272
+ if t == "both":
273
+ idx_s = self._candidate_indices(region, sub_zone, "surgical")
274
+ idx_n = self._candidate_indices(region, sub_zone, "non-surgical")
275
+ idxs = np.unique(np.concatenate([idx_s, idx_n])) if (idx_s.size or idx_n.size) else np.array([], dtype=int)
276
+ else:
277
+ idxs = self._candidate_indices(region, sub_zone, t)
278
+
279
+ if idxs.size == 0:
280
+ # fall back to region only, still issue-only
281
+ if t == "both":
282
+ idx_s = self._candidate_indices(region, "", "surgical")
283
+ idx_n = self._candidate_indices(region, "", "non-surgical")
284
+ idxs = np.unique(np.concatenate([idx_s, idx_n])) if (idx_s.size or idx_n.size) else np.array([], dtype=int)
285
+ else:
286
+ idxs = self._candidate_indices(region, "", t)
287
+
288
+ if idxs.size == 0:
289
+ return 0.0
290
+
291
+ q_emb = self.model.encode([issue_text], convert_to_numpy=True).astype(np.float32)
292
+ sims = cosine_similarity(q_emb, self.embeddings[idxs])[0]
293
+ if sims.size == 0:
294
+ return 0.0
295
+ return float(np.max(sims))
296
+
297
  def semantic_search(self, region: str, sub_zone: str, type_choice: str, issue_text: str, top_k: int = 12) -> List[RetrievedCandidate]:
298
  type_norm = _norm_type_choice(type_choice)
299
  query = f"Region: {region} | Sub-Zone: {sub_zone} | Preference: {type_choice} | Issue: {issue_text}"
 
347
  return docs
348
 
349
  def _web_enrich_fallback(self, procedure: str) -> List[WebDoc]:
350
+ """
351
+ Second-pass retrieval only if extraction is failing.
352
+ """
353
  if not self.web_enabled:
354
  return []
355
+
356
  queries = [
357
  f"{procedure} recovery swelling bruising downtime",
358
  f"{procedure} procedure time how long does it take",
 
433
  Use ONLY the evidence below. If not present, write "Not found in evidence."
434
 
435
  Return STRICT JSON with these keys:
436
+ - invasiveness (Non-invasive / Minimally invasive / Surgical, or evidence-based wording)
437
+ - duration (typical treatment/procedure time; include units if present)
438
+ - downtime (recovery/downtime; typical range)
439
+ - longevity (how long results last; typical range)
440
+ - risks (common risks/side effects; concise)
441
+ - best_suited_for (who it is for; concise)
442
 
443
  Evidence:
444
  {evidence}
 
519
  retrieval_k: int = 12,
520
  final_k: int = 5,
521
  ) -> Dict[str, object]:
522
+ # --- Hard gate (must have region, sub-zone, and meaningful issue text) ---
523
  region = (region or "").strip()
524
  sub_zone = (sub_zone or "").strip()
525
  issue_text = (issue_text or "").strip()
 
541
 
542
  candidates = self.semantic_search(region, sub_zone, type_choice, issue_text, top_k=int(retrieval_k))
543
 
544
+ # If nothing returned locally, show friendly message
545
  if not candidates:
546
  return {
547
  "answer_md": "No matching procedures found for your selected Region/Sub-Zone and issue. Please revise your inputs.",
 
549
  "_debug": {"mismatch": False, "candidate_count": 0, "final_count": 0},
550
  }
551
 
552
+ # ---------------- Improved mismatch detection ----------------
553
  global_cands = self._global_semantic(issue_text, top_k=15)
554
  global_best = global_cands[0].similarity if global_cands else 0.0
555
  local_best = candidates[0].similarity if candidates else 0.0
556
 
557
  selected_region_norm = _norm(region)
558
  selected_sub_norm = _norm(sub_zone)
559
+
560
  selected_in_global = any(
561
  _norm(c.region) == selected_region_norm and (
562
  selected_sub_norm in _norm(c.sub_zone) or _norm(c.sub_zone) in selected_sub_norm
 
564
  for c in global_cands[:10]
565
  )
566
 
567
+ # NEW: issue-only best similarity inside selected region/subzone
568
+ local_issue_best = self._local_issue_only_best_sim(region, sub_zone, type_choice, issue_text)
569
+
570
+ # Trigger mismatch if:
571
+ # - issue-only relevance to selected region is low
572
+ # - but global mapping is strong
573
+ # - and global top does not align with selected region/subzone
574
+ mismatch_strict = (
575
+ (local_issue_best > 0.0 and local_issue_best < self.local_issue_min_sim) and
576
+ (global_best >= self.global_issue_min_sim) and
577
+ (not selected_in_global)
578
+ )
579
+
580
+ # Keep your older delta-based signal (still useful)
581
+ mismatch_delta = (
582
+ (global_best >= self.global_issue_min_sim) and
583
+ ((global_best - local_best) >= self.global_local_delta) and
584
+ (not selected_in_global)
585
+ )
586
+
587
+ if mismatch_strict or mismatch_delta:
588
  suggestions = []
589
  seen = set()
590
  for c in global_cands:
 
600
  answer_md = f"""## Notice
601
  Sorry for inconvenience. Your selected body region/sub-zone is not appropriate as per your defined problem.
602
 
603
+ ## Suggested Region/Sub-Zones (based on your issue text)
604
  {sug_lines}
605
 
606
  ## Next step
607
  Please select one of the suggested Region/Sub-Zones and run the search again.
608
  """.strip()
609
+
610
+ return {
611
+ "answer_md": answer_md,
612
+ "sources": [],
613
+ "_debug": {
614
+ "mismatch": True,
615
+ "global_best": round(global_best, 4),
616
+ "local_best": round(local_best, 4),
617
+ "local_issue_best": round(local_issue_best, 4),
618
+ },
619
+ }
620
+ # -------------------------------------------------------------
621
 
622
  best = self._llm_rerank(issue_text, candidates, top_k=int(final_k))
623
  if len(best) < int(final_k):
 
639
 
640
  answer_md = self._format_final_answer(best, web_by_proc)
641
 
 
642
  seen_u = set()
643
  dedup = []
644
  for u in urls: