mlbench123 committed on
Commit
da9369b
·
verified ·
1 Parent(s): 143a440

Update rag_treatment_app.py

Browse files
Files changed (1) hide show
  1. rag_treatment_app.py +191 -72
rag_treatment_app.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
3
 
4
  import os
5
  import pickle
 
6
  import time
7
  from dataclasses import dataclass
8
  from typing import Dict, List, Optional, Tuple
@@ -14,6 +15,7 @@ from sentence_transformers import SentenceTransformer
14
  from sklearn.metrics.pairwise import cosine_similarity
15
 
16
  from llm_client import LocalLLMClient
 
17
 
18
 
19
  DEFAULT_EMBEDDING_MODEL = "sentence-transformers/static-similarity-mrl-multilingual-v1"
@@ -26,10 +28,6 @@ def _norm(x: str) -> str:
26
 
27
 
28
  def _norm_type_value(x: str) -> str:
29
- """
30
- Normalize DB type to {surgical, non-surgical, ""}.
31
- Handles many variants: Non surgical, non-surg, non-surgical, etc.
32
- """
33
  t = _norm(x).replace("_", "-").replace("–", "-").replace("—", "-")
34
  if ("non" in t and "surg" in t) or ("nonsurg" in t):
35
  return "non-surgical"
@@ -84,6 +82,28 @@ def _na_db(v: str) -> str:
84
  return v if v else "Not found in database."
85
 
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  # ---------------------------- data model ----------------------------
88
 
89
  @dataclass
@@ -121,14 +141,10 @@ class RetrievedCandidate:
121
 
122
  class RAGTreatmentSearchApp:
123
  """
124
- HF-ready local structured RAG (DB-based details).
125
 
126
- DB: database.xlsx (NEW schema)
127
- - Uses sheet_name default: "Procedures"
128
- - Reads procedure details from DB columns (no web calls)
129
-
130
- API is kept compatible with your existing gradio_new_rag_app.py:
131
- RAGTreatmentSearchApp(excel_path=..., embeddings_cache_path=...)
132
  """
133
 
134
  def __init__(
@@ -138,6 +154,7 @@ class RAGTreatmentSearchApp:
138
  embeddings_cache_path: str = "treatment_embeddings.pkl",
139
  embedding_model_name: str = DEFAULT_EMBEDDING_MODEL,
140
  llm: Optional[LocalLLMClient] = None,
 
141
  ):
142
  try:
143
  torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "2")))
@@ -155,15 +172,22 @@ class RAGTreatmentSearchApp:
155
  self.embeddings, self.texts = self._load_or_build_embeddings()
156
 
157
  self.llm = llm or LocalLLMClient()
 
158
 
159
- # hard gate: avoid returning output when issue is empty
160
  self.min_issue_chars = int(os.getenv("MIN_ISSUE_CHARS", "5"))
161
-
162
- # mismatch sensitivity (tuned)
163
  self.local_issue_min_sim = float(os.getenv("LOCAL_ISSUE_MIN_SIM", "0.42"))
164
  self.global_issue_min_sim = float(os.getenv("GLOBAL_ISSUE_MIN_SIM", "0.52"))
165
  self.global_local_delta = float(os.getenv("GLOBAL_LOCAL_DELTA", "0.10"))
166
 
 
 
 
 
 
 
 
 
167
  # ---------------- DB ----------------
168
 
169
  def _load_db(self) -> pd.DataFrame:
@@ -173,45 +197,32 @@ class RAGTreatmentSearchApp:
173
  return pd.read_excel(self.excel_path, sheet_name=self.sheet_name)
174
 
175
  def _normalize_columns(self) -> None:
176
- """
177
- Supports the NEW schema you described.
178
- We also create UI-friendly aliases: Region, Sub-Zone, Procedure, Type.
179
- """
180
- # Required minimal new schema keys (based on your DB update)
181
- required_any = [
182
- "procedure_title",
183
- "main_zone",
184
- "treatment_type",
185
- ]
186
  missing_any = [c for c in required_any if c not in self.df.columns]
187
  if missing_any:
188
  raise ValueError(f"Database missing required columns: {missing_any}")
189
 
190
- # Build unified Region/Sub-Zone fields
191
- # Region -> main_zone
192
  self.df["Region"] = self.df["main_zone"].fillna("").astype(str).str.strip()
193
 
194
- # Sub-Zone: prefer face_subzone else body_subzone else any existing fallback
195
  if "face_subzone" in self.df.columns or "body_subzone" in self.df.columns:
196
- face = self.df["face_subzone"].fillna("").astype(str).str.strip() if "face_subzone" in self.df.columns else ""
197
- body = self.df["body_subzone"].fillna("").astype(str).str.strip() if "body_subzone" in self.df.columns else ""
198
- sub = face
199
- if isinstance(sub, str):
200
- # shouldn't happen, but keep safe
201
- sub = ""
202
- self.df["Sub-Zone"] = face
203
- mask_empty = self.df["Sub-Zone"].eq("") | self.df["Sub-Zone"].str.lower().eq("nan")
204
- if not isinstance(body, str):
205
- self.df.loc[mask_empty, "Sub-Zone"] = body.loc[mask_empty]
206
  else:
207
- # last fallback if DB already has something named Sub-Zone
208
  self.df["Sub-Zone"] = self.df.get("Sub-Zone", "").fillna("").astype(str).str.strip()
209
 
210
- # Procedure/Type
211
  self.df["Procedure"] = self.df["procedure_title"].fillna("").astype(str).str.strip()
212
  self.df["Type"] = self.df["treatment_type"].fillna("").astype(str).str.strip()
213
 
214
- # Normalize core columns
215
  for col in ["Type", "Region", "Sub-Zone", "Procedure"]:
216
  self.df[col] = self.df[col].astype(str).fillna("").str.strip()
217
 
@@ -233,22 +244,145 @@ class RAGTreatmentSearchApp:
233
  out.append(ss)
234
  return sorted(out)
235
 
236
- # ---------------- Embeddings ----------------
237
 
238
- def _row_to_text(self, row: pd.Series) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  """
240
- Build semantic text from DB fields (for embeddings).
241
- Keep it compact but informative so issue-only similarity works.
 
 
242
  """
243
- proc = _db_str(row.get("procedure_title", ""))
244
- reg = _db_str(row.get("main_zone", ""))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  sub = _db_str(row.get("Sub-Zone", ""))
246
- typ = _db_str(row.get("treatment_type", ""))
247
 
248
  short_desc = _first_present(row, ["short_description", "procedure_description", "description"])
249
  concerns = _first_present(row, ["concerns", "aesthetic_concerns", "Aesthetic Concerns"])
250
  techniques = _first_present(row, ["techniques_brands_variants", "Technique / Technology / Brand", "techniques"])
251
-
252
  expected = _first_present(row, ["expected_results", "expected_result"])
253
  sidefx = _first_present(row, ["potential_side_effects", "side_effects", "risks"])
254
 
@@ -321,6 +455,7 @@ class RAGTreatmentSearchApp:
321
  RetrievedCandidate(
322
  row_index=row_index,
323
  similarity=float(sims[pos]),
 
324
  procedure=_na_db(proc),
325
  region=_na_db(reg),
326
  sub_zone=_na_db(sub),
@@ -347,8 +482,10 @@ class RAGTreatmentSearchApp:
347
  average_cost_max_chf=_na_db(_first_present(row, ["average_cost_max_chf"])),
348
  )
349
  )
 
350
  if len(out) >= top_k:
351
  break
 
352
  return out
353
 
354
  def _global_semantic(self, issue_text: str, top_k: int = 15) -> List[RetrievedCandidate]:
@@ -361,7 +498,6 @@ class RAGTreatmentSearchApp:
361
  out: List[RetrievedCandidate] = []
362
  for idx in order[: max(top_k, 1) * 20]:
363
  row = self.df.iloc[int(idx)]
364
- # Build minimal candidate (details not required for mismatch suggestion list)
365
  proc = _db_str(row.get("procedure_title", "")) or _db_str(row.get("Procedure", ""))
366
  reg = _db_str(row.get("main_zone", "")) or _db_str(row.get("Region", ""))
367
  sub = _db_str(row.get("Sub-Zone", "")) or _db_str(row.get("face_subzone", "")) or _db_str(row.get("body_subzone", ""))
@@ -371,6 +507,7 @@ class RAGTreatmentSearchApp:
371
  RetrievedCandidate(
372
  row_index=int(idx),
373
  similarity=float(sims[idx]),
 
374
  procedure=_na_db(proc),
375
  region=_na_db(reg),
376
  sub_zone=_na_db(sub),
@@ -399,12 +536,10 @@ class RAGTreatmentSearchApp:
399
  )
400
  if len(out) >= top_k:
401
  break
 
402
  return out
403
 
404
  def _local_issue_only_best_sim(self, region: str, sub_zone: str, type_choice: str, issue_text: str) -> float:
405
- """
406
- Compute issue-only similarity inside selected region/sub-zone to detect irrelevance.
407
- """
408
  issue_text = (issue_text or "").strip()
409
  if not issue_text:
410
  return 0.0
@@ -418,7 +553,6 @@ class RAGTreatmentSearchApp:
418
  idxs = self._candidate_indices(region, sub_zone, t)
419
 
420
  if idxs.size == 0:
421
- # region only
422
  if t == "both":
423
  idx_s = self._candidate_indices(region, "", "surgical")
424
  idx_n = self._candidate_indices(region, "", "non-surgical")
@@ -433,16 +567,8 @@ class RAGTreatmentSearchApp:
433
  sims = cosine_similarity(q_emb, self.embeddings[idxs])[0]
434
  return float(np.max(sims)) if sims.size else 0.0
435
 
436
- def semantic_search(
437
- self,
438
- region: str,
439
- sub_zone: str,
440
- type_choice: str,
441
- issue_text: str,
442
- top_k: int = 12,
443
- ) -> List[RetrievedCandidate]:
444
  type_norm = _norm_type_choice(type_choice)
445
-
446
  query = f"Region: {region} | Sub-Zone: {sub_zone} | Preference: {type_choice} | Issue: {issue_text}"
447
 
448
  if type_norm == "both":
@@ -451,7 +577,6 @@ class RAGTreatmentSearchApp:
451
  per = max(3, top_k // 2)
452
  res = self._semantic_over(idx_s, query, per) + self._semantic_over(idx_n, query, per)
453
  res.sort(key=lambda x: x.similarity, reverse=True)
454
- # de-dupe by row index
455
  seen = set()
456
  out = []
457
  for c in res:
@@ -499,7 +624,6 @@ Return ONLY a comma-separated list of procedure names (exactly as written).
499
  if len(out) >= top_k:
500
  break
501
 
502
- # fill remainder
503
  for c in candidates:
504
  if len(out) >= top_k:
505
  break
@@ -508,7 +632,7 @@ Return ONLY a comma-separated list of procedure names (exactly as written).
508
 
509
  return out
510
 
511
- # ---------------- Formatting (DB details) ----------------
512
 
513
  def _format_cost(self, mn: str, mx: str, unit: str) -> str:
514
  if mn == "Not found in database." and mx == "Not found in database.":
@@ -572,7 +696,6 @@ Return ONLY a comma-separated list of procedure names (exactly as written).
572
  sub_zone = (sub_zone or "").strip()
573
  issue_text = (issue_text or "").strip()
574
 
575
- # Hard gate: must provide issue text
576
  if not region or not sub_zone:
577
  return {
578
  "answer_md": "Please select **Region** and **Sub-Zone** before running the search.",
@@ -595,7 +718,7 @@ Return ONLY a comma-separated list of procedure names (exactly as written).
595
  "_debug": {"mismatch": False, "candidate_count": 0, "final_count": 0},
596
  }
597
 
598
- # ---------- mismatch detection ----------
599
  global_cands = self._global_semantic(issue_text, top_k=15)
600
  global_best = global_cands[0].similarity if global_cands else 0.0
601
  local_best = candidates[0].similarity if candidates else 0.0
@@ -625,7 +748,6 @@ Return ONLY a comma-separated list of procedure names (exactly as written).
625
  )
626
 
627
  if mismatch_strict or mismatch_delta:
628
- # suggest correct region/sub-zones based on issue text
629
  suggestions = []
630
  seen = set()
631
  for c in global_cands:
@@ -665,11 +787,8 @@ Please select one of the suggested **Region/Sub-Zones** and run the search again
665
  "candidate_count": len(candidates),
666
  },
667
  }
668
- # ---------------------------------------
669
 
670
  best = self._llm_rerank(issue_text, candidates, top_k=int(final_k))
671
-
672
- # Ensure exactly final_k if possible
673
  if len(best) < int(final_k):
674
  for c in candidates:
675
  if c not in best:
@@ -682,7 +801,7 @@ Please select one of the suggested **Region/Sub-Zones** and run the search again
682
 
683
  return {
684
  "answer_md": answer_md,
685
- "sources": [], # DB-only mode
686
  "_debug": {
687
  "mismatch": False,
688
  "candidate_count": len(candidates),
 
3
 
4
  import os
5
  import pickle
6
+ import re
7
  import time
8
  from dataclasses import dataclass
9
  from typing import Dict, List, Optional, Tuple
 
15
  from sklearn.metrics.pairwise import cosine_similarity
16
 
17
  from llm_client import LocalLLMClient
18
+ from web_retriever import WebRetriever, WebDoc
19
 
20
 
21
  DEFAULT_EMBEDDING_MODEL = "sentence-transformers/static-similarity-mrl-multilingual-v1"
 
28
 
29
 
30
  def _norm_type_value(x: str) -> str:
 
 
 
 
31
  t = _norm(x).replace("_", "-").replace("–", "-").replace("—", "-")
32
  if ("non" in t and "surg" in t) or ("nonsurg" in t):
33
  return "non-surgical"
 
82
  return v if v else "Not found in database."
83
 
84
 
85
+ def _split_concerns(text: str) -> List[str]:
86
+ """
87
+ Split a concerns cell into candidate concern phrases.
88
+ Handles ; , | newlines and bullet-ish formats.
89
+ """
90
+ t = (text or "").strip()
91
+ if not t:
92
+ return []
93
+ t = t.replace("•", "\n").replace("·", "\n")
94
+ parts = re.split(r"[;\n\|]+", t)
95
+ out = []
96
+ for p in parts:
97
+ p = p.strip(" -\t\r")
98
+ if not p:
99
+ continue
100
+ if len(p) > 120:
101
+ # keep short fragments only
102
+ continue
103
+ out.append(p)
104
+ return out
105
+
106
+
107
  # ---------------------------- data model ----------------------------
108
 
109
  @dataclass
 
141
 
142
  class RAGTreatmentSearchApp:
143
  """
144
+ DB-driven structured RAG + Common Concerns (internet -> fallback DB).
145
 
146
+ - Core recommendations: semantic retrieval + LLM rerank + formatting from DB columns.
147
+ - Common concerns: fetch short common issues for Region/Sub-Zone to help the user fill the issue box.
 
 
 
 
148
  """
149
 
150
  def __init__(
 
154
  embeddings_cache_path: str = "treatment_embeddings.pkl",
155
  embedding_model_name: str = DEFAULT_EMBEDDING_MODEL,
156
  llm: Optional[LocalLLMClient] = None,
157
+ web: Optional[WebRetriever] = None,
158
  ):
159
  try:
160
  torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "2")))
 
172
  self.embeddings, self.texts = self._load_or_build_embeddings()
173
 
174
  self.llm = llm or LocalLLMClient()
175
+ self.web = web or WebRetriever()
176
 
177
+ # gates + mismatch knobs
178
  self.min_issue_chars = int(os.getenv("MIN_ISSUE_CHARS", "5"))
 
 
179
  self.local_issue_min_sim = float(os.getenv("LOCAL_ISSUE_MIN_SIM", "0.42"))
180
  self.global_issue_min_sim = float(os.getenv("GLOBAL_ISSUE_MIN_SIM", "0.52"))
181
  self.global_local_delta = float(os.getenv("GLOBAL_LOCAL_DELTA", "0.10"))
182
 
183
+ # common concerns config
184
+ self.common_web_enabled = os.getenv("COMMON_CONCERNS_WEB_ENABLED", "1").strip() != "0"
185
+ self.common_max_docs = int(os.getenv("COMMON_CONCERNS_MAX_DOCS", "4"))
186
+ self.common_max_chars = int(os.getenv("COMMON_CONCERNS_MAX_CHARS", "900"))
187
+ self.common_top_n = int(os.getenv("COMMON_CONCERNS_TOP_N", "4"))
188
+
189
+ self._common_cache: Dict[Tuple[str, str], List[str]] = {}
190
+
191
  # ---------------- DB ----------------
192
 
193
  def _load_db(self) -> pd.DataFrame:
 
197
  return pd.read_excel(self.excel_path, sheet_name=self.sheet_name)
198
 
199
  def _normalize_columns(self) -> None:
200
+ required_any = ["procedure_title", "main_zone", "treatment_type"]
 
 
 
 
 
 
 
 
 
201
  missing_any = [c for c in required_any if c not in self.df.columns]
202
  if missing_any:
203
  raise ValueError(f"Database missing required columns: {missing_any}")
204
 
205
+ # Region
 
206
  self.df["Region"] = self.df["main_zone"].fillna("").astype(str).str.strip()
207
 
208
+ # Sub-Zone (prefer face_subzone, else body_subzone, else existing Sub-Zone)
209
  if "face_subzone" in self.df.columns or "body_subzone" in self.df.columns:
210
+ face = self.df["face_subzone"].fillna("").astype(str).str.strip() if "face_subzone" in self.df.columns else None
211
+ body = self.df["body_subzone"].fillna("").astype(str).str.strip() if "body_subzone" in self.df.columns else None
212
+ if face is None:
213
+ self.df["Sub-Zone"] = body
214
+ else:
215
+ self.df["Sub-Zone"] = face
216
+ mask_empty = self.df["Sub-Zone"].eq("") | self.df["Sub-Zone"].str.lower().eq("nan")
217
+ if body is not None:
218
+ self.df.loc[mask_empty, "Sub-Zone"] = body.loc[mask_empty]
 
219
  else:
 
220
  self.df["Sub-Zone"] = self.df.get("Sub-Zone", "").fillna("").astype(str).str.strip()
221
 
222
+ # Procedure / Type aliases
223
  self.df["Procedure"] = self.df["procedure_title"].fillna("").astype(str).str.strip()
224
  self.df["Type"] = self.df["treatment_type"].fillna("").astype(str).str.strip()
225
 
 
226
  for col in ["Type", "Region", "Sub-Zone", "Procedure"]:
227
  self.df[col] = self.df[col].astype(str).fillna("").str.strip()
228
 
 
244
  out.append(ss)
245
  return sorted(out)
246
 
247
+ # ---------------- Common concerns ----------------
248
 
249
def _db_common_concerns(self, region: str, sub_zone: str, n: int = 4) -> List[str]:
    """
    Fallback: most frequent short concerns among DB rows of a Region/Sub-Zone.

    Rows are selected where ``_region_norm`` equals the normalized region and,
    when a sub-zone is given, where ``_subzone_norm`` equals or contains the
    normalized sub-zone.

    Args:
        region: Region label; normalized with ``_norm`` before matching.
        sub_zone: Sub-zone label; an empty value disables sub-zone filtering.
        n: Maximum number of concerns to return (values below 1 are treated as 1).

    Returns:
        Concern phrases ordered by descending frequency, ties broken
        alphabetically (case-insensitive); ``[]`` when no row matches.
    """
    r = _norm(region)
    sz = _norm(sub_zone)

    mask = self.df["_region_norm"].eq(r)
    if sz:
        subzones = self.df["_subzone_norm"]
        # regex=False: the sub-zone is a literal label, so characters such as
        # "(" or "+" must not be interpreted as regular-expression syntax.
        mask = mask & (subzones.eq(sz) | subzones.str.contains(sz, na=False, regex=False))

    selected = self.df[mask]
    if selected.empty:
        return []

    counts: Dict[str, int] = {}
    for _, row in selected.iterrows():
        cell = _first_present(row, ["concerns", "Aesthetic Concerns", "aesthetic_concerns"])
        for fragment in _split_concerns(cell):
            key = fragment.strip()
            # Fragments shorter than four characters are skipped.
            if len(key) < 4:
                continue
            counts[key] = counts.get(key, 0) + 1

    ranked = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0].lower()))
    return [name for name, _ in ranked[: max(1, n)]]
275
+
276
+ def _web_common_concerns(self, region: str, sub_zone: str, n: int = 4) -> List[str]:
277
+ """
278
+ Internet-based: get common concerns for Region/Sub-Zone; extract with LLM as short phrases.
279
+
280
+ If web is blocked/rate-limited on HF, this naturally falls back to DB list.
281
+ """
282
+ if not self.common_web_enabled:
283
+ return []
284
+
285
+ region = (region or "").strip()
286
+ sub_zone = (sub_zone or "").strip()
287
+ if not region or not sub_zone:
288
+ return []
289
+
290
+ queries = [
291
+ f"common aesthetic concerns {region} {sub_zone}",
292
+ f"most common problems {sub_zone} aesthetic treatment",
293
+ f"{sub_zone} cosmetic concerns dark circles wrinkles pigmentation",
294
+ ]
295
+
296
+ docs = self.web.search_and_fetch(
297
+ queries=queries,
298
+ max_results_per_query=2,
299
+ max_docs=self.common_max_docs,
300
+ max_chars_per_doc=self.common_max_chars,
301
+ )
302
+
303
+ if not docs:
304
+ return []
305
+
306
+ def compact(s: str, limit: int = 650) -> str:
307
+ s = re.sub(r"\s+", " ", (s or "").strip())
308
+ return (s[:limit] + "…") if len(s) > limit else s
309
+
310
+ ev = []
311
+ for i, d in enumerate(docs[:4], start=1):
312
+ ev.append(f"[Doc {i}] {d.title}\n{compact(d.snippet)}")
313
+ evidence = "\n\n".join(ev)
314
+
315
+ prompt = f"""
316
+ You are extracting ONLY common patient concerns (issues) for:
317
+ Region: {region}
318
+ Sub-Zone: {sub_zone}
319
+
320
+ From the evidence, output STRICT JSON:
321
+ {{"concerns": ["...","..."]}}
322
+
323
+ Rules:
324
+ - return 1 to {n} short concern phrases (3-8 words each)
325
+ - no treatment names, only issues/concerns
326
+ - deduplicate similar items
327
+ - if unclear, return fewer items
328
+
329
+ Evidence:
330
+ {evidence}
331
+ """.strip()
332
+
333
+ raw = (self.llm.generate(prompt, temperature=0.2, max_tokens=160) or "").strip()
334
+ data = self.llm.safe_json_loads(raw)
335
+ arr = data.get("concerns", [])
336
+
337
+ out: List[str] = []
338
+ if isinstance(arr, list):
339
+ for x in arr:
340
+ s = str(x).strip()
341
+ if not s:
342
+ continue
343
+ if len(s) > 80:
344
+ continue
345
+ if s.lower() in {z.lower() for z in out}:
346
+ continue
347
+ out.append(s)
348
+
349
+ return out[:n]
350
+
351
def get_common_concerns(self, region: str, sub_zone: str, n: Optional[int] = None) -> List[str]:
    """
    Public API for UI: common concerns for a Region/Sub-Zone.

    - first try internet extraction
    - if it fails or yields nothing, use DB-derived concerns
    - cached per normalized (region, sub_zone)

    Args:
        region: Selected region.
        sub_zone: Selected sub-zone.
        n: Maximum number of concerns; falls back to ``self.common_top_n``
            when ``None`` or 0.

    Returns:
        A new list of concern phrases (safe for the caller to mutate). A
        cached entry is truncated to ``n``; it is not recomputed when a later
        call asks for more items than were cached.
    """
    import logging  # local import: only needed on the web-failure path

    n = int(n or self.common_top_n)
    key = (_norm(region), _norm(sub_zone))
    cached = self._common_cache.get(key)
    if cached is not None:
        # Return a copy so callers cannot mutate the cached list.
        return list(cached[:n])

    concerns: List[str] = []
    try:
        concerns = self._web_common_concerns(region, sub_zone, n=n)
    except Exception as exc:
        # Best effort: any web/LLM failure falls through to the DB list,
        # but leave a trace so the failure is diagnosable.
        logging.getLogger(__name__).warning("Web common-concerns lookup failed: %s", exc)
        concerns = []

    if not concerns:
        concerns = self._db_common_concerns(region, sub_zone, n=n)

    # Store a private copy; the caller receives its own list.
    self._common_cache[key] = list(concerns)
    return concerns
374
+
375
+ # ---------------- Embeddings ----------------
376
+
377
+ def _row_to_text(self, row: pd.Series) -> str:
378
+ proc = _db_str(row.get("procedure_title", "")) or _db_str(row.get("Procedure", ""))
379
+ reg = _db_str(row.get("main_zone", "")) or _db_str(row.get("Region", ""))
380
  sub = _db_str(row.get("Sub-Zone", ""))
381
+ typ = _db_str(row.get("treatment_type", "")) or _db_str(row.get("Type", ""))
382
 
383
  short_desc = _first_present(row, ["short_description", "procedure_description", "description"])
384
  concerns = _first_present(row, ["concerns", "aesthetic_concerns", "Aesthetic Concerns"])
385
  techniques = _first_present(row, ["techniques_brands_variants", "Technique / Technology / Brand", "techniques"])
 
386
  expected = _first_present(row, ["expected_results", "expected_result"])
387
  sidefx = _first_present(row, ["potential_side_effects", "side_effects", "risks"])
388
 
 
455
  RetrievedCandidate(
456
  row_index=row_index,
457
  similarity=float(sims[pos]),
458
+
459
  procedure=_na_db(proc),
460
  region=_na_db(reg),
461
  sub_zone=_na_db(sub),
 
482
  average_cost_max_chf=_na_db(_first_present(row, ["average_cost_max_chf"])),
483
  )
484
  )
485
+
486
  if len(out) >= top_k:
487
  break
488
+
489
  return out
490
 
491
  def _global_semantic(self, issue_text: str, top_k: int = 15) -> List[RetrievedCandidate]:
 
498
  out: List[RetrievedCandidate] = []
499
  for idx in order[: max(top_k, 1) * 20]:
500
  row = self.df.iloc[int(idx)]
 
501
  proc = _db_str(row.get("procedure_title", "")) or _db_str(row.get("Procedure", ""))
502
  reg = _db_str(row.get("main_zone", "")) or _db_str(row.get("Region", ""))
503
  sub = _db_str(row.get("Sub-Zone", "")) or _db_str(row.get("face_subzone", "")) or _db_str(row.get("body_subzone", ""))
 
507
  RetrievedCandidate(
508
  row_index=int(idx),
509
  similarity=float(sims[idx]),
510
+
511
  procedure=_na_db(proc),
512
  region=_na_db(reg),
513
  sub_zone=_na_db(sub),
 
536
  )
537
  if len(out) >= top_k:
538
  break
539
+
540
  return out
541
 
542
  def _local_issue_only_best_sim(self, region: str, sub_zone: str, type_choice: str, issue_text: str) -> float:
 
 
 
543
  issue_text = (issue_text or "").strip()
544
  if not issue_text:
545
  return 0.0
 
553
  idxs = self._candidate_indices(region, sub_zone, t)
554
 
555
  if idxs.size == 0:
 
556
  if t == "both":
557
  idx_s = self._candidate_indices(region, "", "surgical")
558
  idx_n = self._candidate_indices(region, "", "non-surgical")
 
567
  sims = cosine_similarity(q_emb, self.embeddings[idxs])[0]
568
  return float(np.max(sims)) if sims.size else 0.0
569
 
570
+ def semantic_search(self, region: str, sub_zone: str, type_choice: str, issue_text: str, top_k: int = 12) -> List[RetrievedCandidate]:
 
 
 
 
 
 
 
571
  type_norm = _norm_type_choice(type_choice)
 
572
  query = f"Region: {region} | Sub-Zone: {sub_zone} | Preference: {type_choice} | Issue: {issue_text}"
573
 
574
  if type_norm == "both":
 
577
  per = max(3, top_k // 2)
578
  res = self._semantic_over(idx_s, query, per) + self._semantic_over(idx_n, query, per)
579
  res.sort(key=lambda x: x.similarity, reverse=True)
 
580
  seen = set()
581
  out = []
582
  for c in res:
 
624
  if len(out) >= top_k:
625
  break
626
 
 
627
  for c in candidates:
628
  if len(out) >= top_k:
629
  break
 
632
 
633
  return out
634
 
635
+ # ---------------- Formatting ----------------
636
 
637
  def _format_cost(self, mn: str, mx: str, unit: str) -> str:
638
  if mn == "Not found in database." and mx == "Not found in database.":
 
696
  sub_zone = (sub_zone or "").strip()
697
  issue_text = (issue_text or "").strip()
698
 
 
699
  if not region or not sub_zone:
700
  return {
701
  "answer_md": "Please select **Region** and **Sub-Zone** before running the search.",
 
718
  "_debug": {"mismatch": False, "candidate_count": 0, "final_count": 0},
719
  }
720
 
721
+ # mismatch detection
722
  global_cands = self._global_semantic(issue_text, top_k=15)
723
  global_best = global_cands[0].similarity if global_cands else 0.0
724
  local_best = candidates[0].similarity if candidates else 0.0
 
748
  )
749
 
750
  if mismatch_strict or mismatch_delta:
 
751
  suggestions = []
752
  seen = set()
753
  for c in global_cands:
 
787
  "candidate_count": len(candidates),
788
  },
789
  }
 
790
 
791
  best = self._llm_rerank(issue_text, candidates, top_k=int(final_k))
 
 
792
  if len(best) < int(final_k):
793
  for c in candidates:
794
  if c not in best:
 
801
 
802
  return {
803
  "answer_md": answer_md,
804
+ "sources": [],
805
  "_debug": {
806
  "mismatch": False,
807
  "candidate_count": len(candidates),