Spaces:
Sleeping
Sleeping
Update rag_treatment_app.py
Browse files- rag_treatment_app.py +191 -72
rag_treatment_app.py
CHANGED
|
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
| 3 |
|
| 4 |
import os
|
| 5 |
import pickle
|
|
|
|
| 6 |
import time
|
| 7 |
from dataclasses import dataclass
|
| 8 |
from typing import Dict, List, Optional, Tuple
|
|
@@ -14,6 +15,7 @@ from sentence_transformers import SentenceTransformer
|
|
| 14 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 15 |
|
| 16 |
from llm_client import LocalLLMClient
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/static-similarity-mrl-multilingual-v1"
|
|
@@ -26,10 +28,6 @@ def _norm(x: str) -> str:
|
|
| 26 |
|
| 27 |
|
| 28 |
def _norm_type_value(x: str) -> str:
|
| 29 |
-
"""
|
| 30 |
-
Normalize DB type to {surgical, non-surgical, ""}.
|
| 31 |
-
Handles many variants: Non surgical, non-surg, non-surgical, etc.
|
| 32 |
-
"""
|
| 33 |
t = _norm(x).replace("_", "-").replace("–", "-").replace("—", "-")
|
| 34 |
if ("non" in t and "surg" in t) or ("nonsurg" in t):
|
| 35 |
return "non-surgical"
|
|
@@ -84,6 +82,28 @@ def _na_db(v: str) -> str:
|
|
| 84 |
return v if v else "Not found in database."
|
| 85 |
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
# ---------------------------- data model ----------------------------
|
| 88 |
|
| 89 |
@dataclass
|
|
@@ -121,14 +141,10 @@ class RetrievedCandidate:
|
|
| 121 |
|
| 122 |
class RAGTreatmentSearchApp:
|
| 123 |
"""
|
| 124 |
-
|
| 125 |
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
- Reads procedure details from DB columns (no web calls)
|
| 129 |
-
|
| 130 |
-
API is kept compatible with your existing gradio_new_rag_app.py:
|
| 131 |
-
RAGTreatmentSearchApp(excel_path=..., embeddings_cache_path=...)
|
| 132 |
"""
|
| 133 |
|
| 134 |
def __init__(
|
|
@@ -138,6 +154,7 @@ class RAGTreatmentSearchApp:
|
|
| 138 |
embeddings_cache_path: str = "treatment_embeddings.pkl",
|
| 139 |
embedding_model_name: str = DEFAULT_EMBEDDING_MODEL,
|
| 140 |
llm: Optional[LocalLLMClient] = None,
|
|
|
|
| 141 |
):
|
| 142 |
try:
|
| 143 |
torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "2")))
|
|
@@ -155,15 +172,22 @@ class RAGTreatmentSearchApp:
|
|
| 155 |
self.embeddings, self.texts = self._load_or_build_embeddings()
|
| 156 |
|
| 157 |
self.llm = llm or LocalLLMClient()
|
|
|
|
| 158 |
|
| 159 |
-
#
|
| 160 |
self.min_issue_chars = int(os.getenv("MIN_ISSUE_CHARS", "5"))
|
| 161 |
-
|
| 162 |
-
# mismatch sensitivity (tuned)
|
| 163 |
self.local_issue_min_sim = float(os.getenv("LOCAL_ISSUE_MIN_SIM", "0.42"))
|
| 164 |
self.global_issue_min_sim = float(os.getenv("GLOBAL_ISSUE_MIN_SIM", "0.52"))
|
| 165 |
self.global_local_delta = float(os.getenv("GLOBAL_LOCAL_DELTA", "0.10"))
|
| 166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
# ---------------- DB ----------------
|
| 168 |
|
| 169 |
def _load_db(self) -> pd.DataFrame:
|
|
@@ -173,45 +197,32 @@ class RAGTreatmentSearchApp:
|
|
| 173 |
return pd.read_excel(self.excel_path, sheet_name=self.sheet_name)
|
| 174 |
|
| 175 |
def _normalize_columns(self) -> None:
|
| 176 |
-
"""
|
| 177 |
-
Supports the NEW schema you described.
|
| 178 |
-
We also create UI-friendly aliases: Region, Sub-Zone, Procedure, Type.
|
| 179 |
-
"""
|
| 180 |
-
# Required minimal new schema keys (based on your DB update)
|
| 181 |
-
required_any = [
|
| 182 |
-
"procedure_title",
|
| 183 |
-
"main_zone",
|
| 184 |
-
"treatment_type",
|
| 185 |
-
]
|
| 186 |
missing_any = [c for c in required_any if c not in self.df.columns]
|
| 187 |
if missing_any:
|
| 188 |
raise ValueError(f"Database missing required columns: {missing_any}")
|
| 189 |
|
| 190 |
-
#
|
| 191 |
-
# Region -> main_zone
|
| 192 |
self.df["Region"] = self.df["main_zone"].fillna("").astype(str).str.strip()
|
| 193 |
|
| 194 |
-
# Sub-Zone
|
| 195 |
if "face_subzone" in self.df.columns or "body_subzone" in self.df.columns:
|
| 196 |
-
face = self.df["face_subzone"].fillna("").astype(str).str.strip() if "face_subzone" in self.df.columns else
|
| 197 |
-
body = self.df["body_subzone"].fillna("").astype(str).str.strip() if "body_subzone" in self.df.columns else
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
self.df.loc[mask_empty, "Sub-Zone"] = body.loc[mask_empty]
|
| 206 |
else:
|
| 207 |
-
# last fallback if DB already has something named Sub-Zone
|
| 208 |
self.df["Sub-Zone"] = self.df.get("Sub-Zone", "").fillna("").astype(str).str.strip()
|
| 209 |
|
| 210 |
-
# Procedure/Type
|
| 211 |
self.df["Procedure"] = self.df["procedure_title"].fillna("").astype(str).str.strip()
|
| 212 |
self.df["Type"] = self.df["treatment_type"].fillna("").astype(str).str.strip()
|
| 213 |
|
| 214 |
-
# Normalize core columns
|
| 215 |
for col in ["Type", "Region", "Sub-Zone", "Procedure"]:
|
| 216 |
self.df[col] = self.df[col].astype(str).fillna("").str.strip()
|
| 217 |
|
|
@@ -233,22 +244,145 @@ class RAGTreatmentSearchApp:
|
|
| 233 |
out.append(ss)
|
| 234 |
return sorted(out)
|
| 235 |
|
| 236 |
-
# ----------------
|
| 237 |
|
| 238 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
"""
|
| 240 |
-
|
| 241 |
-
|
|
|
|
|
|
|
| 242 |
"""
|
| 243 |
-
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
sub = _db_str(row.get("Sub-Zone", ""))
|
| 246 |
-
typ = _db_str(row.get("treatment_type", ""))
|
| 247 |
|
| 248 |
short_desc = _first_present(row, ["short_description", "procedure_description", "description"])
|
| 249 |
concerns = _first_present(row, ["concerns", "aesthetic_concerns", "Aesthetic Concerns"])
|
| 250 |
techniques = _first_present(row, ["techniques_brands_variants", "Technique / Technology / Brand", "techniques"])
|
| 251 |
-
|
| 252 |
expected = _first_present(row, ["expected_results", "expected_result"])
|
| 253 |
sidefx = _first_present(row, ["potential_side_effects", "side_effects", "risks"])
|
| 254 |
|
|
@@ -321,6 +455,7 @@ class RAGTreatmentSearchApp:
|
|
| 321 |
RetrievedCandidate(
|
| 322 |
row_index=row_index,
|
| 323 |
similarity=float(sims[pos]),
|
|
|
|
| 324 |
procedure=_na_db(proc),
|
| 325 |
region=_na_db(reg),
|
| 326 |
sub_zone=_na_db(sub),
|
|
@@ -347,8 +482,10 @@ class RAGTreatmentSearchApp:
|
|
| 347 |
average_cost_max_chf=_na_db(_first_present(row, ["average_cost_max_chf"])),
|
| 348 |
)
|
| 349 |
)
|
|
|
|
| 350 |
if len(out) >= top_k:
|
| 351 |
break
|
|
|
|
| 352 |
return out
|
| 353 |
|
| 354 |
def _global_semantic(self, issue_text: str, top_k: int = 15) -> List[RetrievedCandidate]:
|
|
@@ -361,7 +498,6 @@ class RAGTreatmentSearchApp:
|
|
| 361 |
out: List[RetrievedCandidate] = []
|
| 362 |
for idx in order[: max(top_k, 1) * 20]:
|
| 363 |
row = self.df.iloc[int(idx)]
|
| 364 |
-
# Build minimal candidate (details not required for mismatch suggestion list)
|
| 365 |
proc = _db_str(row.get("procedure_title", "")) or _db_str(row.get("Procedure", ""))
|
| 366 |
reg = _db_str(row.get("main_zone", "")) or _db_str(row.get("Region", ""))
|
| 367 |
sub = _db_str(row.get("Sub-Zone", "")) or _db_str(row.get("face_subzone", "")) or _db_str(row.get("body_subzone", ""))
|
|
@@ -371,6 +507,7 @@ class RAGTreatmentSearchApp:
|
|
| 371 |
RetrievedCandidate(
|
| 372 |
row_index=int(idx),
|
| 373 |
similarity=float(sims[idx]),
|
|
|
|
| 374 |
procedure=_na_db(proc),
|
| 375 |
region=_na_db(reg),
|
| 376 |
sub_zone=_na_db(sub),
|
|
@@ -399,12 +536,10 @@ class RAGTreatmentSearchApp:
|
|
| 399 |
)
|
| 400 |
if len(out) >= top_k:
|
| 401 |
break
|
|
|
|
| 402 |
return out
|
| 403 |
|
| 404 |
def _local_issue_only_best_sim(self, region: str, sub_zone: str, type_choice: str, issue_text: str) -> float:
|
| 405 |
-
"""
|
| 406 |
-
Compute issue-only similarity inside selected region/sub-zone to detect irrelevance.
|
| 407 |
-
"""
|
| 408 |
issue_text = (issue_text or "").strip()
|
| 409 |
if not issue_text:
|
| 410 |
return 0.0
|
|
@@ -418,7 +553,6 @@ class RAGTreatmentSearchApp:
|
|
| 418 |
idxs = self._candidate_indices(region, sub_zone, t)
|
| 419 |
|
| 420 |
if idxs.size == 0:
|
| 421 |
-
# region only
|
| 422 |
if t == "both":
|
| 423 |
idx_s = self._candidate_indices(region, "", "surgical")
|
| 424 |
idx_n = self._candidate_indices(region, "", "non-surgical")
|
|
@@ -433,16 +567,8 @@ class RAGTreatmentSearchApp:
|
|
| 433 |
sims = cosine_similarity(q_emb, self.embeddings[idxs])[0]
|
| 434 |
return float(np.max(sims)) if sims.size else 0.0
|
| 435 |
|
| 436 |
-
def semantic_search(
|
| 437 |
-
self,
|
| 438 |
-
region: str,
|
| 439 |
-
sub_zone: str,
|
| 440 |
-
type_choice: str,
|
| 441 |
-
issue_text: str,
|
| 442 |
-
top_k: int = 12,
|
| 443 |
-
) -> List[RetrievedCandidate]:
|
| 444 |
type_norm = _norm_type_choice(type_choice)
|
| 445 |
-
|
| 446 |
query = f"Region: {region} | Sub-Zone: {sub_zone} | Preference: {type_choice} | Issue: {issue_text}"
|
| 447 |
|
| 448 |
if type_norm == "both":
|
|
@@ -451,7 +577,6 @@ class RAGTreatmentSearchApp:
|
|
| 451 |
per = max(3, top_k // 2)
|
| 452 |
res = self._semantic_over(idx_s, query, per) + self._semantic_over(idx_n, query, per)
|
| 453 |
res.sort(key=lambda x: x.similarity, reverse=True)
|
| 454 |
-
# de-dupe by row index
|
| 455 |
seen = set()
|
| 456 |
out = []
|
| 457 |
for c in res:
|
|
@@ -499,7 +624,6 @@ Return ONLY a comma-separated list of procedure names (exactly as written).
|
|
| 499 |
if len(out) >= top_k:
|
| 500 |
break
|
| 501 |
|
| 502 |
-
# fill remainder
|
| 503 |
for c in candidates:
|
| 504 |
if len(out) >= top_k:
|
| 505 |
break
|
|
@@ -508,7 +632,7 @@ Return ONLY a comma-separated list of procedure names (exactly as written).
|
|
| 508 |
|
| 509 |
return out
|
| 510 |
|
| 511 |
-
# ---------------- Formatting
|
| 512 |
|
| 513 |
def _format_cost(self, mn: str, mx: str, unit: str) -> str:
|
| 514 |
if mn == "Not found in database." and mx == "Not found in database.":
|
|
@@ -572,7 +696,6 @@ Return ONLY a comma-separated list of procedure names (exactly as written).
|
|
| 572 |
sub_zone = (sub_zone or "").strip()
|
| 573 |
issue_text = (issue_text or "").strip()
|
| 574 |
|
| 575 |
-
# Hard gate: must provide issue text
|
| 576 |
if not region or not sub_zone:
|
| 577 |
return {
|
| 578 |
"answer_md": "Please select **Region** and **Sub-Zone** before running the search.",
|
|
@@ -595,7 +718,7 @@ Return ONLY a comma-separated list of procedure names (exactly as written).
|
|
| 595 |
"_debug": {"mismatch": False, "candidate_count": 0, "final_count": 0},
|
| 596 |
}
|
| 597 |
|
| 598 |
-
#
|
| 599 |
global_cands = self._global_semantic(issue_text, top_k=15)
|
| 600 |
global_best = global_cands[0].similarity if global_cands else 0.0
|
| 601 |
local_best = candidates[0].similarity if candidates else 0.0
|
|
@@ -625,7 +748,6 @@ Return ONLY a comma-separated list of procedure names (exactly as written).
|
|
| 625 |
)
|
| 626 |
|
| 627 |
if mismatch_strict or mismatch_delta:
|
| 628 |
-
# suggest correct region/sub-zones based on issue text
|
| 629 |
suggestions = []
|
| 630 |
seen = set()
|
| 631 |
for c in global_cands:
|
|
@@ -665,11 +787,8 @@ Please select one of the suggested **Region/Sub-Zones** and run the search again
|
|
| 665 |
"candidate_count": len(candidates),
|
| 666 |
},
|
| 667 |
}
|
| 668 |
-
# ---------------------------------------
|
| 669 |
|
| 670 |
best = self._llm_rerank(issue_text, candidates, top_k=int(final_k))
|
| 671 |
-
|
| 672 |
-
# Ensure exactly final_k if possible
|
| 673 |
if len(best) < int(final_k):
|
| 674 |
for c in candidates:
|
| 675 |
if c not in best:
|
|
@@ -682,7 +801,7 @@ Please select one of the suggested **Region/Sub-Zones** and run the search again
|
|
| 682 |
|
| 683 |
return {
|
| 684 |
"answer_md": answer_md,
|
| 685 |
-
"sources": [],
|
| 686 |
"_debug": {
|
| 687 |
"mismatch": False,
|
| 688 |
"candidate_count": len(candidates),
|
|
|
|
| 3 |
|
| 4 |
import os
|
| 5 |
import pickle
|
| 6 |
+
import re
|
| 7 |
import time
|
| 8 |
from dataclasses import dataclass
|
| 9 |
from typing import Dict, List, Optional, Tuple
|
|
|
|
| 15 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 16 |
|
| 17 |
from llm_client import LocalLLMClient
|
| 18 |
+
from web_retriever import WebRetriever, WebDoc
|
| 19 |
|
| 20 |
|
| 21 |
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/static-similarity-mrl-multilingual-v1"
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
def _norm_type_value(x: str) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
t = _norm(x).replace("_", "-").replace("–", "-").replace("—", "-")
|
| 32 |
if ("non" in t and "surg" in t) or ("nonsurg" in t):
|
| 33 |
return "non-surgical"
|
|
|
|
| 82 |
return v if v else "Not found in database."
|
| 83 |
|
| 84 |
|
| 85 |
+
def _split_concerns(text: str) -> List[str]:
|
| 86 |
+
"""
|
| 87 |
+
Split a concerns cell into candidate concern phrases.
|
| 88 |
+
Handles ; , | newlines and bullet-ish formats.
|
| 89 |
+
"""
|
| 90 |
+
t = (text or "").strip()
|
| 91 |
+
if not t:
|
| 92 |
+
return []
|
| 93 |
+
t = t.replace("•", "\n").replace("·", "\n")
|
| 94 |
+
parts = re.split(r"[;\n\|]+", t)
|
| 95 |
+
out = []
|
| 96 |
+
for p in parts:
|
| 97 |
+
p = p.strip(" -\t\r")
|
| 98 |
+
if not p:
|
| 99 |
+
continue
|
| 100 |
+
if len(p) > 120:
|
| 101 |
+
# keep short fragments only
|
| 102 |
+
continue
|
| 103 |
+
out.append(p)
|
| 104 |
+
return out
|
| 105 |
+
|
| 106 |
+
|
| 107 |
# ---------------------------- data model ----------------------------
|
| 108 |
|
| 109 |
@dataclass
|
|
|
|
| 141 |
|
| 142 |
class RAGTreatmentSearchApp:
|
| 143 |
"""
|
| 144 |
+
DB-driven structured RAG + Common Concerns (internet -> fallback DB).
|
| 145 |
|
| 146 |
+
- Core recommendations: semantic retrieval + LLM rerank + formatting from DB columns.
|
| 147 |
+
- Common concerns: fetch short common issues for Region/Sub-Zone to help the user fill the issue box.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
"""
|
| 149 |
|
| 150 |
def __init__(
|
|
|
|
| 154 |
embeddings_cache_path: str = "treatment_embeddings.pkl",
|
| 155 |
embedding_model_name: str = DEFAULT_EMBEDDING_MODEL,
|
| 156 |
llm: Optional[LocalLLMClient] = None,
|
| 157 |
+
web: Optional[WebRetriever] = None,
|
| 158 |
):
|
| 159 |
try:
|
| 160 |
torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "2")))
|
|
|
|
| 172 |
self.embeddings, self.texts = self._load_or_build_embeddings()
|
| 173 |
|
| 174 |
self.llm = llm or LocalLLMClient()
|
| 175 |
+
self.web = web or WebRetriever()
|
| 176 |
|
| 177 |
+
# gates + mismatch knobs
|
| 178 |
self.min_issue_chars = int(os.getenv("MIN_ISSUE_CHARS", "5"))
|
|
|
|
|
|
|
| 179 |
self.local_issue_min_sim = float(os.getenv("LOCAL_ISSUE_MIN_SIM", "0.42"))
|
| 180 |
self.global_issue_min_sim = float(os.getenv("GLOBAL_ISSUE_MIN_SIM", "0.52"))
|
| 181 |
self.global_local_delta = float(os.getenv("GLOBAL_LOCAL_DELTA", "0.10"))
|
| 182 |
|
| 183 |
+
# common concerns config
|
| 184 |
+
self.common_web_enabled = os.getenv("COMMON_CONCERNS_WEB_ENABLED", "1").strip() != "0"
|
| 185 |
+
self.common_max_docs = int(os.getenv("COMMON_CONCERNS_MAX_DOCS", "4"))
|
| 186 |
+
self.common_max_chars = int(os.getenv("COMMON_CONCERNS_MAX_CHARS", "900"))
|
| 187 |
+
self.common_top_n = int(os.getenv("COMMON_CONCERNS_TOP_N", "4"))
|
| 188 |
+
|
| 189 |
+
self._common_cache: Dict[Tuple[str, str], List[str]] = {}
|
| 190 |
+
|
| 191 |
# ---------------- DB ----------------
|
| 192 |
|
| 193 |
def _load_db(self) -> pd.DataFrame:
|
|
|
|
| 197 |
return pd.read_excel(self.excel_path, sheet_name=self.sheet_name)
|
| 198 |
|
| 199 |
def _normalize_columns(self) -> None:
|
| 200 |
+
required_any = ["procedure_title", "main_zone", "treatment_type"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
missing_any = [c for c in required_any if c not in self.df.columns]
|
| 202 |
if missing_any:
|
| 203 |
raise ValueError(f"Database missing required columns: {missing_any}")
|
| 204 |
|
| 205 |
+
# Region
|
|
|
|
| 206 |
self.df["Region"] = self.df["main_zone"].fillna("").astype(str).str.strip()
|
| 207 |
|
| 208 |
+
# Sub-Zone (prefer face_subzone, else body_subzone, else existing Sub-Zone)
|
| 209 |
if "face_subzone" in self.df.columns or "body_subzone" in self.df.columns:
|
| 210 |
+
face = self.df["face_subzone"].fillna("").astype(str).str.strip() if "face_subzone" in self.df.columns else None
|
| 211 |
+
body = self.df["body_subzone"].fillna("").astype(str).str.strip() if "body_subzone" in self.df.columns else None
|
| 212 |
+
if face is None:
|
| 213 |
+
self.df["Sub-Zone"] = body
|
| 214 |
+
else:
|
| 215 |
+
self.df["Sub-Zone"] = face
|
| 216 |
+
mask_empty = self.df["Sub-Zone"].eq("") | self.df["Sub-Zone"].str.lower().eq("nan")
|
| 217 |
+
if body is not None:
|
| 218 |
+
self.df.loc[mask_empty, "Sub-Zone"] = body.loc[mask_empty]
|
|
|
|
| 219 |
else:
|
|
|
|
| 220 |
self.df["Sub-Zone"] = self.df.get("Sub-Zone", "").fillna("").astype(str).str.strip()
|
| 221 |
|
| 222 |
+
# Procedure / Type aliases
|
| 223 |
self.df["Procedure"] = self.df["procedure_title"].fillna("").astype(str).str.strip()
|
| 224 |
self.df["Type"] = self.df["treatment_type"].fillna("").astype(str).str.strip()
|
| 225 |
|
|
|
|
| 226 |
for col in ["Type", "Region", "Sub-Zone", "Procedure"]:
|
| 227 |
self.df[col] = self.df[col].astype(str).fillna("").str.strip()
|
| 228 |
|
|
|
|
| 244 |
out.append(ss)
|
| 245 |
return sorted(out)
|
| 246 |
|
| 247 |
+
# ---------------- Common concerns ----------------
|
| 248 |
|
| 249 |
+
def _db_common_concerns(self, region: str, sub_zone: str, n: int = 4) -> List[str]:
|
| 250 |
+
"""
|
| 251 |
+
Fallback: extract most frequent short concerns from DB rows in selected Region/Sub-Zone.
|
| 252 |
+
"""
|
| 253 |
+
r = _norm(region)
|
| 254 |
+
sz = _norm(sub_zone)
|
| 255 |
+
|
| 256 |
+
m = self.df["_region_norm"].eq(r)
|
| 257 |
+
if sz:
|
| 258 |
+
m = m & (self.df["_subzone_norm"].eq(sz) | self.df["_subzone_norm"].str.contains(sz, na=False))
|
| 259 |
+
|
| 260 |
+
df2 = self.df[m]
|
| 261 |
+
if df2.empty:
|
| 262 |
+
return []
|
| 263 |
+
|
| 264 |
+
counts: Dict[str, int] = {}
|
| 265 |
+
for _, row in df2.iterrows():
|
| 266 |
+
concerns = _first_present(row, ["concerns", "Aesthetic Concerns", "aesthetic_concerns"])
|
| 267 |
+
for c in _split_concerns(concerns):
|
| 268 |
+
key = c.strip()
|
| 269 |
+
if len(key) < 4:
|
| 270 |
+
continue
|
| 271 |
+
counts[key] = counts.get(key, 0) + 1
|
| 272 |
+
|
| 273 |
+
ranked = sorted(counts.items(), key=lambda x: (-x[1], x[0].lower()))
|
| 274 |
+
return [k for (k, _) in ranked[: max(1, n)]]
|
| 275 |
+
|
| 276 |
+
def _web_common_concerns(self, region: str, sub_zone: str, n: int = 4) -> List[str]:
|
| 277 |
+
"""
|
| 278 |
+
Internet-based: get common concerns for Region/Sub-Zone; extract with LLM as short phrases.
|
| 279 |
+
|
| 280 |
+
If web is blocked/rate-limited on HF, this naturally falls back to DB list.
|
| 281 |
+
"""
|
| 282 |
+
if not self.common_web_enabled:
|
| 283 |
+
return []
|
| 284 |
+
|
| 285 |
+
region = (region or "").strip()
|
| 286 |
+
sub_zone = (sub_zone or "").strip()
|
| 287 |
+
if not region or not sub_zone:
|
| 288 |
+
return []
|
| 289 |
+
|
| 290 |
+
queries = [
|
| 291 |
+
f"common aesthetic concerns {region} {sub_zone}",
|
| 292 |
+
f"most common problems {sub_zone} aesthetic treatment",
|
| 293 |
+
f"{sub_zone} cosmetic concerns dark circles wrinkles pigmentation",
|
| 294 |
+
]
|
| 295 |
+
|
| 296 |
+
docs = self.web.search_and_fetch(
|
| 297 |
+
queries=queries,
|
| 298 |
+
max_results_per_query=2,
|
| 299 |
+
max_docs=self.common_max_docs,
|
| 300 |
+
max_chars_per_doc=self.common_max_chars,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
if not docs:
|
| 304 |
+
return []
|
| 305 |
+
|
| 306 |
+
def compact(s: str, limit: int = 650) -> str:
|
| 307 |
+
s = re.sub(r"\s+", " ", (s or "").strip())
|
| 308 |
+
return (s[:limit] + "…") if len(s) > limit else s
|
| 309 |
+
|
| 310 |
+
ev = []
|
| 311 |
+
for i, d in enumerate(docs[:4], start=1):
|
| 312 |
+
ev.append(f"[Doc {i}] {d.title}\n{compact(d.snippet)}")
|
| 313 |
+
evidence = "\n\n".join(ev)
|
| 314 |
+
|
| 315 |
+
prompt = f"""
|
| 316 |
+
You are extracting ONLY common patient concerns (issues) for:
|
| 317 |
+
Region: {region}
|
| 318 |
+
Sub-Zone: {sub_zone}
|
| 319 |
+
|
| 320 |
+
From the evidence, output STRICT JSON:
|
| 321 |
+
{{"concerns": ["...","..."]}}
|
| 322 |
+
|
| 323 |
+
Rules:
|
| 324 |
+
- return 1 to {n} short concern phrases (3-8 words each)
|
| 325 |
+
- no treatment names, only issues/concerns
|
| 326 |
+
- deduplicate similar items
|
| 327 |
+
- if unclear, return fewer items
|
| 328 |
+
|
| 329 |
+
Evidence:
|
| 330 |
+
{evidence}
|
| 331 |
+
""".strip()
|
| 332 |
+
|
| 333 |
+
raw = (self.llm.generate(prompt, temperature=0.2, max_tokens=160) or "").strip()
|
| 334 |
+
data = self.llm.safe_json_loads(raw)
|
| 335 |
+
arr = data.get("concerns", [])
|
| 336 |
+
|
| 337 |
+
out: List[str] = []
|
| 338 |
+
if isinstance(arr, list):
|
| 339 |
+
for x in arr:
|
| 340 |
+
s = str(x).strip()
|
| 341 |
+
if not s:
|
| 342 |
+
continue
|
| 343 |
+
if len(s) > 80:
|
| 344 |
+
continue
|
| 345 |
+
if s.lower() in {z.lower() for z in out}:
|
| 346 |
+
continue
|
| 347 |
+
out.append(s)
|
| 348 |
+
|
| 349 |
+
return out[:n]
|
| 350 |
+
|
| 351 |
+
def get_common_concerns(self, region: str, sub_zone: str, n: Optional[int] = None) -> List[str]:
|
| 352 |
"""
|
| 353 |
+
Public API for UI:
|
| 354 |
+
- first try internet extraction
|
| 355 |
+
- if it fails, use DB-derived concerns
|
| 356 |
+
- cached per (region, sub_zone)
|
| 357 |
"""
|
| 358 |
+
n = int(n or self.common_top_n)
|
| 359 |
+
key = (_norm(region), _norm(sub_zone))
|
| 360 |
+
if key in self._common_cache:
|
| 361 |
+
return self._common_cache[key]
|
| 362 |
+
|
| 363 |
+
concerns: List[str] = []
|
| 364 |
+
try:
|
| 365 |
+
concerns = self._web_common_concerns(region, sub_zone, n=n)
|
| 366 |
+
except Exception:
|
| 367 |
+
concerns = []
|
| 368 |
+
|
| 369 |
+
if not concerns:
|
| 370 |
+
concerns = self._db_common_concerns(region, sub_zone, n=n)
|
| 371 |
+
|
| 372 |
+
self._common_cache[key] = concerns
|
| 373 |
+
return concerns
|
| 374 |
+
|
| 375 |
+
# ---------------- Embeddings ----------------
|
| 376 |
+
|
| 377 |
+
def _row_to_text(self, row: pd.Series) -> str:
|
| 378 |
+
proc = _db_str(row.get("procedure_title", "")) or _db_str(row.get("Procedure", ""))
|
| 379 |
+
reg = _db_str(row.get("main_zone", "")) or _db_str(row.get("Region", ""))
|
| 380 |
sub = _db_str(row.get("Sub-Zone", ""))
|
| 381 |
+
typ = _db_str(row.get("treatment_type", "")) or _db_str(row.get("Type", ""))
|
| 382 |
|
| 383 |
short_desc = _first_present(row, ["short_description", "procedure_description", "description"])
|
| 384 |
concerns = _first_present(row, ["concerns", "aesthetic_concerns", "Aesthetic Concerns"])
|
| 385 |
techniques = _first_present(row, ["techniques_brands_variants", "Technique / Technology / Brand", "techniques"])
|
|
|
|
| 386 |
expected = _first_present(row, ["expected_results", "expected_result"])
|
| 387 |
sidefx = _first_present(row, ["potential_side_effects", "side_effects", "risks"])
|
| 388 |
|
|
|
|
| 455 |
RetrievedCandidate(
|
| 456 |
row_index=row_index,
|
| 457 |
similarity=float(sims[pos]),
|
| 458 |
+
|
| 459 |
procedure=_na_db(proc),
|
| 460 |
region=_na_db(reg),
|
| 461 |
sub_zone=_na_db(sub),
|
|
|
|
| 482 |
average_cost_max_chf=_na_db(_first_present(row, ["average_cost_max_chf"])),
|
| 483 |
)
|
| 484 |
)
|
| 485 |
+
|
| 486 |
if len(out) >= top_k:
|
| 487 |
break
|
| 488 |
+
|
| 489 |
return out
|
| 490 |
|
| 491 |
def _global_semantic(self, issue_text: str, top_k: int = 15) -> List[RetrievedCandidate]:
|
|
|
|
| 498 |
out: List[RetrievedCandidate] = []
|
| 499 |
for idx in order[: max(top_k, 1) * 20]:
|
| 500 |
row = self.df.iloc[int(idx)]
|
|
|
|
| 501 |
proc = _db_str(row.get("procedure_title", "")) or _db_str(row.get("Procedure", ""))
|
| 502 |
reg = _db_str(row.get("main_zone", "")) or _db_str(row.get("Region", ""))
|
| 503 |
sub = _db_str(row.get("Sub-Zone", "")) or _db_str(row.get("face_subzone", "")) or _db_str(row.get("body_subzone", ""))
|
|
|
|
| 507 |
RetrievedCandidate(
|
| 508 |
row_index=int(idx),
|
| 509 |
similarity=float(sims[idx]),
|
| 510 |
+
|
| 511 |
procedure=_na_db(proc),
|
| 512 |
region=_na_db(reg),
|
| 513 |
sub_zone=_na_db(sub),
|
|
|
|
| 536 |
)
|
| 537 |
if len(out) >= top_k:
|
| 538 |
break
|
| 539 |
+
|
| 540 |
return out
|
| 541 |
|
| 542 |
def _local_issue_only_best_sim(self, region: str, sub_zone: str, type_choice: str, issue_text: str) -> float:
|
|
|
|
|
|
|
|
|
|
| 543 |
issue_text = (issue_text or "").strip()
|
| 544 |
if not issue_text:
|
| 545 |
return 0.0
|
|
|
|
| 553 |
idxs = self._candidate_indices(region, sub_zone, t)
|
| 554 |
|
| 555 |
if idxs.size == 0:
|
|
|
|
| 556 |
if t == "both":
|
| 557 |
idx_s = self._candidate_indices(region, "", "surgical")
|
| 558 |
idx_n = self._candidate_indices(region, "", "non-surgical")
|
|
|
|
| 567 |
sims = cosine_similarity(q_emb, self.embeddings[idxs])[0]
|
| 568 |
return float(np.max(sims)) if sims.size else 0.0
|
| 569 |
|
| 570 |
+
def semantic_search(self, region: str, sub_zone: str, type_choice: str, issue_text: str, top_k: int = 12) -> List[RetrievedCandidate]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 571 |
type_norm = _norm_type_choice(type_choice)
|
|
|
|
| 572 |
query = f"Region: {region} | Sub-Zone: {sub_zone} | Preference: {type_choice} | Issue: {issue_text}"
|
| 573 |
|
| 574 |
if type_norm == "both":
|
|
|
|
| 577 |
per = max(3, top_k // 2)
|
| 578 |
res = self._semantic_over(idx_s, query, per) + self._semantic_over(idx_n, query, per)
|
| 579 |
res.sort(key=lambda x: x.similarity, reverse=True)
|
|
|
|
| 580 |
seen = set()
|
| 581 |
out = []
|
| 582 |
for c in res:
|
|
|
|
| 624 |
if len(out) >= top_k:
|
| 625 |
break
|
| 626 |
|
|
|
|
| 627 |
for c in candidates:
|
| 628 |
if len(out) >= top_k:
|
| 629 |
break
|
|
|
|
| 632 |
|
| 633 |
return out
|
| 634 |
|
| 635 |
+
# ---------------- Formatting ----------------
|
| 636 |
|
| 637 |
def _format_cost(self, mn: str, mx: str, unit: str) -> str:
|
| 638 |
if mn == "Not found in database." and mx == "Not found in database.":
|
|
|
|
| 696 |
sub_zone = (sub_zone or "").strip()
|
| 697 |
issue_text = (issue_text or "").strip()
|
| 698 |
|
|
|
|
| 699 |
if not region or not sub_zone:
|
| 700 |
return {
|
| 701 |
"answer_md": "Please select **Region** and **Sub-Zone** before running the search.",
|
|
|
|
| 718 |
"_debug": {"mismatch": False, "candidate_count": 0, "final_count": 0},
|
| 719 |
}
|
| 720 |
|
| 721 |
+
# mismatch detection
|
| 722 |
global_cands = self._global_semantic(issue_text, top_k=15)
|
| 723 |
global_best = global_cands[0].similarity if global_cands else 0.0
|
| 724 |
local_best = candidates[0].similarity if candidates else 0.0
|
|
|
|
| 748 |
)
|
| 749 |
|
| 750 |
if mismatch_strict or mismatch_delta:
|
|
|
|
| 751 |
suggestions = []
|
| 752 |
seen = set()
|
| 753 |
for c in global_cands:
|
|
|
|
| 787 |
"candidate_count": len(candidates),
|
| 788 |
},
|
| 789 |
}
|
|
|
|
| 790 |
|
| 791 |
best = self._llm_rerank(issue_text, candidates, top_k=int(final_k))
|
|
|
|
|
|
|
| 792 |
if len(best) < int(final_k):
|
| 793 |
for c in candidates:
|
| 794 |
if c not in best:
|
|
|
|
| 801 |
|
| 802 |
return {
|
| 803 |
"answer_md": answer_md,
|
| 804 |
+
"sources": [],
|
| 805 |
"_debug": {
|
| 806 |
"mismatch": False,
|
| 807 |
"candidate_count": len(candidates),
|