Spaces:
Sleeping
Sleeping
Update rag_treatment_app.py
Browse files- rag_treatment_app.py +93 -14
rag_treatment_app.py
CHANGED
|
@@ -103,9 +103,17 @@ class RAGTreatmentSearchApp:
|
|
| 103 |
self.web_max_docs = int(os.getenv("WEB_MAX_DOCS", "6"))
|
| 104 |
self.web_max_chars = int(os.getenv("WEB_MAX_CHARS", "1200"))
|
| 105 |
|
| 106 |
-
# NEW: hard gate to prevent
|
| 107 |
self.min_issue_chars = int(os.getenv("MIN_ISSUE_CHARS", "5"))
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
# ---------------- DB ----------------
|
| 110 |
def _load_db(self) -> pd.DataFrame:
|
| 111 |
xl = pd.ExcelFile(self.excel_path)
|
|
@@ -250,6 +258,42 @@ class RAGTreatmentSearchApp:
|
|
| 250 |
break
|
| 251 |
return out
|
| 252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
def semantic_search(self, region: str, sub_zone: str, type_choice: str, issue_text: str, top_k: int = 12) -> List[RetrievedCandidate]:
|
| 254 |
type_norm = _norm_type_choice(type_choice)
|
| 255 |
query = f"Region: {region} | Sub-Zone: {sub_zone} | Preference: {type_choice} | Issue: {issue_text}"
|
|
@@ -303,8 +347,12 @@ class RAGTreatmentSearchApp:
|
|
| 303 |
return docs
|
| 304 |
|
| 305 |
def _web_enrich_fallback(self, procedure: str) -> List[WebDoc]:
|
|
|
|
|
|
|
|
|
|
| 306 |
if not self.web_enabled:
|
| 307 |
return []
|
|
|
|
| 308 |
queries = [
|
| 309 |
f"{procedure} recovery swelling bruising downtime",
|
| 310 |
f"{procedure} procedure time how long does it take",
|
|
@@ -385,12 +433,12 @@ Extract patient-facing procedure details for: {procedure}
|
|
| 385 |
Use ONLY the evidence below. If not present, write "Not found in evidence."
|
| 386 |
|
| 387 |
Return STRICT JSON with these keys:
|
| 388 |
-
- invasiveness
|
| 389 |
-
- duration
|
| 390 |
-
- downtime
|
| 391 |
-
- longevity
|
| 392 |
-
- risks
|
| 393 |
-
- best_suited_for
|
| 394 |
|
| 395 |
Evidence:
|
| 396 |
{evidence}
|
|
@@ -471,7 +519,7 @@ Evidence:
|
|
| 471 |
retrieval_k: int = 12,
|
| 472 |
final_k: int = 5,
|
| 473 |
) -> Dict[str, object]:
|
| 474 |
-
# ---
|
| 475 |
region = (region or "").strip()
|
| 476 |
sub_zone = (sub_zone or "").strip()
|
| 477 |
issue_text = (issue_text or "").strip()
|
|
@@ -493,7 +541,7 @@ Evidence:
|
|
| 493 |
|
| 494 |
candidates = self.semantic_search(region, sub_zone, type_choice, issue_text, top_k=int(retrieval_k))
|
| 495 |
|
| 496 |
-
# If
|
| 497 |
if not candidates:
|
| 498 |
return {
|
| 499 |
"answer_md": "No matching procedures found for your selected Region/Sub-Zone and issue. Please revise your inputs.",
|
|
@@ -501,13 +549,14 @@ Evidence:
|
|
| 501 |
"_debug": {"mismatch": False, "candidate_count": 0, "final_count": 0},
|
| 502 |
}
|
| 503 |
|
| 504 |
-
#
|
| 505 |
global_cands = self._global_semantic(issue_text, top_k=15)
|
| 506 |
global_best = global_cands[0].similarity if global_cands else 0.0
|
| 507 |
local_best = candidates[0].similarity if candidates else 0.0
|
| 508 |
|
| 509 |
selected_region_norm = _norm(region)
|
| 510 |
selected_sub_norm = _norm(sub_zone)
|
|
|
|
| 511 |
selected_in_global = any(
|
| 512 |
_norm(c.region) == selected_region_norm and (
|
| 513 |
selected_sub_norm in _norm(c.sub_zone) or _norm(c.sub_zone) in selected_sub_norm
|
|
@@ -515,7 +564,27 @@ Evidence:
|
|
| 515 |
for c in global_cands[:10]
|
| 516 |
)
|
| 517 |
|
| 518 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
suggestions = []
|
| 520 |
seen = set()
|
| 521 |
for c in global_cands:
|
|
@@ -531,13 +600,24 @@ Evidence:
|
|
| 531 |
answer_md = f"""## Notice
|
| 532 |
Sorry for inconvenience. Your selected body region/sub-zone is not appropriate as per your defined problem.
|
| 533 |
|
| 534 |
-
## Suggested Region/Sub-Zones
|
| 535 |
{sug_lines}
|
| 536 |
|
| 537 |
## Next step
|
| 538 |
Please select one of the suggested Region/Sub-Zones and run the search again.
|
| 539 |
""".strip()
|
| 540 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 541 |
|
| 542 |
best = self._llm_rerank(issue_text, candidates, top_k=int(final_k))
|
| 543 |
if len(best) < int(final_k):
|
|
@@ -559,7 +639,6 @@ Please select one of the suggested Region/Sub-Zones and run the search again.
|
|
| 559 |
|
| 560 |
answer_md = self._format_final_answer(best, web_by_proc)
|
| 561 |
|
| 562 |
-
# de-dupe urls
|
| 563 |
seen_u = set()
|
| 564 |
dedup = []
|
| 565 |
for u in urls:
|
|
|
|
| 103 |
self.web_max_docs = int(os.getenv("WEB_MAX_DOCS", "6"))
|
| 104 |
self.web_max_chars = int(os.getenv("WEB_MAX_CHARS", "1200"))
|
| 105 |
|
| 106 |
+
# NEW: hard gate to prevent empty/generic recommendation runs
|
| 107 |
self.min_issue_chars = int(os.getenv("MIN_ISSUE_CHARS", "5"))
|
| 108 |
|
| 109 |
+
# NEW: mismatch sensitivity knobs
|
| 110 |
+
# If issue-only similarity within selected region is below this => likely irrelevant
|
| 111 |
+
self.local_issue_min_sim = float(os.getenv("LOCAL_ISSUE_MIN_SIM", "0.42"))
|
| 112 |
+
# If global best similarity is above this => the issue clearly maps somewhere else
|
| 113 |
+
self.global_issue_min_sim = float(os.getenv("GLOBAL_ISSUE_MIN_SIM", "0.52"))
|
| 114 |
+
# Additional delta guard (keeps your old behavior but improved)
|
| 115 |
+
self.global_local_delta = float(os.getenv("GLOBAL_LOCAL_DELTA", "0.10"))
|
| 116 |
+
|
| 117 |
# ---------------- DB ----------------
|
| 118 |
def _load_db(self) -> pd.DataFrame:
|
| 119 |
xl = pd.ExcelFile(self.excel_path)
|
|
|
|
| 258 |
break
|
| 259 |
return out
|
| 260 |
|
| 261 |
+
def _local_issue_only_best_sim(self, region: str, sub_zone: str, type_choice: str, issue_text: str) -> float:
|
| 262 |
+
"""
|
| 263 |
+
NEW: Compute issue-only similarity *within selected region/subzone*.
|
| 264 |
+
This avoids the bias from including region/subzone words in the query.
|
| 265 |
+
"""
|
| 266 |
+
issue_text = (issue_text or "").strip()
|
| 267 |
+
if not issue_text:
|
| 268 |
+
return 0.0
|
| 269 |
+
|
| 270 |
+
# consider both types if user chose both
|
| 271 |
+
t = _norm_type_choice(type_choice)
|
| 272 |
+
if t == "both":
|
| 273 |
+
idx_s = self._candidate_indices(region, sub_zone, "surgical")
|
| 274 |
+
idx_n = self._candidate_indices(region, sub_zone, "non-surgical")
|
| 275 |
+
idxs = np.unique(np.concatenate([idx_s, idx_n])) if (idx_s.size or idx_n.size) else np.array([], dtype=int)
|
| 276 |
+
else:
|
| 277 |
+
idxs = self._candidate_indices(region, sub_zone, t)
|
| 278 |
+
|
| 279 |
+
if idxs.size == 0:
|
| 280 |
+
# fall back to region only, still issue-only
|
| 281 |
+
if t == "both":
|
| 282 |
+
idx_s = self._candidate_indices(region, "", "surgical")
|
| 283 |
+
idx_n = self._candidate_indices(region, "", "non-surgical")
|
| 284 |
+
idxs = np.unique(np.concatenate([idx_s, idx_n])) if (idx_s.size or idx_n.size) else np.array([], dtype=int)
|
| 285 |
+
else:
|
| 286 |
+
idxs = self._candidate_indices(region, "", t)
|
| 287 |
+
|
| 288 |
+
if idxs.size == 0:
|
| 289 |
+
return 0.0
|
| 290 |
+
|
| 291 |
+
q_emb = self.model.encode([issue_text], convert_to_numpy=True).astype(np.float32)
|
| 292 |
+
sims = cosine_similarity(q_emb, self.embeddings[idxs])[0]
|
| 293 |
+
if sims.size == 0:
|
| 294 |
+
return 0.0
|
| 295 |
+
return float(np.max(sims))
|
| 296 |
+
|
| 297 |
def semantic_search(self, region: str, sub_zone: str, type_choice: str, issue_text: str, top_k: int = 12) -> List[RetrievedCandidate]:
|
| 298 |
type_norm = _norm_type_choice(type_choice)
|
| 299 |
query = f"Region: {region} | Sub-Zone: {sub_zone} | Preference: {type_choice} | Issue: {issue_text}"
|
|
|
|
| 347 |
return docs
|
| 348 |
|
| 349 |
def _web_enrich_fallback(self, procedure: str) -> List[WebDoc]:
|
| 350 |
+
"""
|
| 351 |
+
Second-pass retrieval only if extraction is failing.
|
| 352 |
+
"""
|
| 353 |
if not self.web_enabled:
|
| 354 |
return []
|
| 355 |
+
|
| 356 |
queries = [
|
| 357 |
f"{procedure} recovery swelling bruising downtime",
|
| 358 |
f"{procedure} procedure time how long does it take",
|
|
|
|
| 433 |
Use ONLY the evidence below. If not present, write "Not found in evidence."
|
| 434 |
|
| 435 |
Return STRICT JSON with these keys:
|
| 436 |
+
- invasiveness (Non-invasive / Minimally invasive / Surgical, or evidence-based wording)
|
| 437 |
+
- duration (typical treatment/procedure time; include units if present)
|
| 438 |
+
- downtime (recovery/downtime; typical range)
|
| 439 |
+
- longevity (how long results last; typical range)
|
| 440 |
+
- risks (common risks/side effects; concise)
|
| 441 |
+
- best_suited_for (who it is for; concise)
|
| 442 |
|
| 443 |
Evidence:
|
| 444 |
{evidence}
|
|
|
|
| 519 |
retrieval_k: int = 12,
|
| 520 |
final_k: int = 5,
|
| 521 |
) -> Dict[str, object]:
|
| 522 |
+
# --- Hard gate (must have region, sub-zone, and meaningful issue text) ---
|
| 523 |
region = (region or "").strip()
|
| 524 |
sub_zone = (sub_zone or "").strip()
|
| 525 |
issue_text = (issue_text or "").strip()
|
|
|
|
| 541 |
|
| 542 |
candidates = self.semantic_search(region, sub_zone, type_choice, issue_text, top_k=int(retrieval_k))
|
| 543 |
|
| 544 |
+
# If nothing returned locally, show friendly message
|
| 545 |
if not candidates:
|
| 546 |
return {
|
| 547 |
"answer_md": "No matching procedures found for your selected Region/Sub-Zone and issue. Please revise your inputs.",
|
|
|
|
| 549 |
"_debug": {"mismatch": False, "candidate_count": 0, "final_count": 0},
|
| 550 |
}
|
| 551 |
|
| 552 |
+
# ---------------- Improved mismatch detection ----------------
|
| 553 |
global_cands = self._global_semantic(issue_text, top_k=15)
|
| 554 |
global_best = global_cands[0].similarity if global_cands else 0.0
|
| 555 |
local_best = candidates[0].similarity if candidates else 0.0
|
| 556 |
|
| 557 |
selected_region_norm = _norm(region)
|
| 558 |
selected_sub_norm = _norm(sub_zone)
|
| 559 |
+
|
| 560 |
selected_in_global = any(
|
| 561 |
_norm(c.region) == selected_region_norm and (
|
| 562 |
selected_sub_norm in _norm(c.sub_zone) or _norm(c.sub_zone) in selected_sub_norm
|
|
|
|
| 564 |
for c in global_cands[:10]
|
| 565 |
)
|
| 566 |
|
| 567 |
+
# NEW: issue-only best similarity inside selected region/subzone
|
| 568 |
+
local_issue_best = self._local_issue_only_best_sim(region, sub_zone, type_choice, issue_text)
|
| 569 |
+
|
| 570 |
+
# Trigger mismatch if:
|
| 571 |
+
# - issue-only relevance to selected region is low
|
| 572 |
+
# - but global mapping is strong
|
| 573 |
+
# - and global top does not align with selected region/subzone
|
| 574 |
+
mismatch_strict = (
|
| 575 |
+
(local_issue_best > 0.0 and local_issue_best < self.local_issue_min_sim) and
|
| 576 |
+
(global_best >= self.global_issue_min_sim) and
|
| 577 |
+
(not selected_in_global)
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
# Keep your older delta-based signal (still useful)
|
| 581 |
+
mismatch_delta = (
|
| 582 |
+
(global_best >= self.global_issue_min_sim) and
|
| 583 |
+
((global_best - local_best) >= self.global_local_delta) and
|
| 584 |
+
(not selected_in_global)
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
if mismatch_strict or mismatch_delta:
|
| 588 |
suggestions = []
|
| 589 |
seen = set()
|
| 590 |
for c in global_cands:
|
|
|
|
| 600 |
answer_md = f"""## Notice
|
| 601 |
Sorry for inconvenience. Your selected body region/sub-zone is not appropriate as per your defined problem.
|
| 602 |
|
| 603 |
+
## Suggested Region/Sub-Zones (based on your issue text)
|
| 604 |
{sug_lines}
|
| 605 |
|
| 606 |
## Next step
|
| 607 |
Please select one of the suggested Region/Sub-Zones and run the search again.
|
| 608 |
""".strip()
|
| 609 |
+
|
| 610 |
+
return {
|
| 611 |
+
"answer_md": answer_md,
|
| 612 |
+
"sources": [],
|
| 613 |
+
"_debug": {
|
| 614 |
+
"mismatch": True,
|
| 615 |
+
"global_best": round(global_best, 4),
|
| 616 |
+
"local_best": round(local_best, 4),
|
| 617 |
+
"local_issue_best": round(local_issue_best, 4),
|
| 618 |
+
},
|
| 619 |
+
}
|
| 620 |
+
# -------------------------------------------------------------
|
| 621 |
|
| 622 |
best = self._llm_rerank(issue_text, candidates, top_k=int(final_k))
|
| 623 |
if len(best) < int(final_k):
|
|
|
|
| 639 |
|
| 640 |
answer_md = self._format_final_answer(best, web_by_proc)
|
| 641 |
|
|
|
|
| 642 |
seen_u = set()
|
| 643 |
dedup = []
|
| 644 |
for u in urls:
|