Slaiwala committed on
Commit
8afca39
·
verified ·
1 Parent(s): fc726e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -41
app.py CHANGED
@@ -412,7 +412,11 @@ _ANATOMY_OR_HISTORY = re.compile(
412
  re.I
413
  )
414
  _PAPERS_INTENT = re.compile(r"\b(key\s+papers|suggest\s+papers|landmark|seminal|important|top\s+papers)\b", re.I)
415
- CITE_TRIGGER = re.compile(r"\b(cite|citations?|references?)\b", re.I)
 
 
 
 
416
 
417
  # ================== PUBMED & RETRIEVAL ==================
418
  def fetch_pubmed_chunks(query_or_pmid: str, max_papers: int = 3) -> List[Dict[str, Any]]:
@@ -757,65 +761,110 @@ def direct_llm_fallback(question: str) -> str:
757
 
758
  # ================== PUBLIC API ==================
759
  def ask(question: str) -> str:
760
- q = question.strip()
 
 
761
  m = re.search(r"pmid[:\s]*(\d+)", q, re.IGNORECASE)
762
  if m:
763
  pmid = m.group(1)
764
  chunks = fetch_pubmed_chunks(pmid, max_papers=1)
765
  return "\n".join(c.get("text", "") for c in chunks) or "Sorry, no abstract found."
766
 
767
- if _PAPERS_INTENT.search(q) or CITE_TRIGGER.search(q):
768
- core_q = re.sub(CITE_TRIGGER, "", q, count=1, flags=re.I).strip().rstrip(".")
769
- core_q = re.sub(_PAPERS_INTENT, "", core_q, flags=re.I).strip()
770
- if not core_q:
771
- core_q = "CT/QCT structural rigidity femur hip finite element"
772
- compact = _compact_terms(core_q)
773
- pm_query = (
774
- f'(({compact}) AND (hip[TiAb] OR femur[TiAb] OR femoral[TiAb])) AND '
775
- '("Finite Element Analysis"[MeSH Terms] OR finite element[TiAb] OR QCT[TiAb] OR CT[TiAb] OR rigidity[TiAb]) '
776
- 'AND ("2000"[DP] : "2025"[DP])'
777
- )
778
- cits = fetch_pubmed_citations(pm_query, max_results=5)
779
- if not cits:
780
- lab = detect_lab(core_q)
781
- pm_query = build_lab_query(core_q, lab=lab)
782
- cits = fetch_pubmed_citations(pm_query, max_results=5)
783
- if not cits:
784
- cits = _fallback_cits_for("EA")
785
- # Provide a short explanation + citations
786
- explanation = _answer_from_chunks(retrieve_context(core_q, top_k=3), core_q) or direct_llm_fallback(core_q)
787
- explanation = _post_clean(explanation)
788
- if cits:
789
- explanation += "\n\nCitations:\n" + "\n".join(cits)
790
- return _ensure_min_answer(explanation)
791
-
792
- det = deterministic_definitions_text(q)
793
- if det:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
794
  dlog("ASK", "Deterministic definition/workflow fired")
795
- return det
796
 
 
797
  if not (_MSK_MUST.search(q) or _is_fe_override(q)):
798
- chunks = retrieve_context(q, top_k=3)
799
  if chunks:
800
  answer = _answer_from_chunks(chunks, q)
801
- # tiny safety to release VRAM between turns
802
- try:
803
- torch.cuda.empty_cache()
804
- except Exception:
805
- pass
806
  return _ensure_min_answer(_post_clean(answer)) or direct_llm_fallback(q)
807
  return direct_llm_fallback(q)
808
 
809
- chunks = retrieve_context(q, top_k=3)
 
810
  if not chunks:
811
  return direct_llm_fallback(q)
 
812
  answer = _answer_from_chunks(chunks, q)
813
- try:
814
- torch.cuda.empty_cache()
815
- except Exception:
816
- pass
817
  return _ensure_min_answer(_post_clean(answer)) or direct_llm_fallback(q)
818
 
 
 
819
  def deterministic_definitions_text(core_q: str) -> Optional[str]:
820
  q_lower = core_q.lower()
821
  if "define axial rigidity" in q_lower or "what is axial rigidity" in q_lower:
 
412
  re.I
413
  )
414
  _PAPERS_INTENT = re.compile(r"\b(key\s+papers|suggest\s+papers|landmark|seminal|important|top\s+papers)\b", re.I)
415
+ CITE_TRIGGER = re.compile(
416
+ r"\b(?:and\s+(?:cite|citations?|references?|studies?|papers?))\.?\s*$",
417
+ re.IGNORECASE,
418
+ )
419
+
420
 
421
  # ================== PUBMED & RETRIEVAL ==================
422
  def fetch_pubmed_chunks(query_or_pmid: str, max_papers: int = 3) -> List[Dict[str, Any]]:
 
761
 
762
  # ================== PUBLIC API ==================
763
  def ask(question: str) -> str:
764
+ q = (question or "").strip()
765
+
766
+ # --- PMID short-circuit ---------------------------------------------------
767
  m = re.search(r"pmid[:\s]*(\d+)", q, re.IGNORECASE)
768
  if m:
769
  pmid = m.group(1)
770
  chunks = fetch_pubmed_chunks(pmid, max_papers=1)
771
  return "\n".join(c.get("text", "") for c in chunks) or "Sorry, no abstract found."
772
 
773
+ # --- Normalize query & detect 'and cite' / 'key papers' intent ------------
774
+ wants_cite = bool(CITE_TRIGGER.search(q) or _PAPERS_INTENT.search(q))
775
+ core_q = CITE_TRIGGER.sub("", q).strip().rstrip(".")
776
+ core_q = _PAPERS_INTENT.sub("", core_q).strip()
777
+
778
+ # --- Deterministic definitions/workflows (used in both paths) -------------
779
+ det_text = deterministic_definitions_text(core_q)
780
+
781
+ # ==========================================================================
782
+ # Citation path
783
+ # ==========================================================================
784
+ if wants_cite:
785
+ if det_text:
786
+ explanation = det_text
787
+ lq = core_q.lower()
788
+ used_term = None
789
+ if ("torsion" in lq) or ("gj" in lq):
790
+ used_term = "GJ"
791
+ pm_query = (
792
+ '(torsion[TiAb] OR "polar moment"[TiAb] OR GJ[TiAb]) AND '
793
+ '("Bone and Bones"[MeSH] OR Femur[TiAb]) AND '
794
+ '("Finite Element Analysis"[MeSH] OR QCT[TiAb] OR CT[TiAb]) AND '
795
+ '("2000"[DP] : "2025"[DP])'
796
+ )
797
+ elif ("bending" in lq) or ("ei" in lq):
798
+ used_term = "EI"
799
+ pm_query = (
800
+ '(bending[TiAb] OR "second moment"[TiAb] OR EI[TiAb]) AND '
801
+ '("Bone and Bones"[MeSH] OR Femur[TiAb]) AND '
802
+ '("Finite Element Analysis"[MeSH] OR QCT[TiAb] OR CT[TiAb]) AND '
803
+ '("2000"[DP] : "2025"[DP])'
804
+ )
805
+ else:
806
+ used_term = "EA"
807
+ pm_query = (
808
+ '("axial rigidity"[TiAb] OR EA[TiAb] OR "axial stiffness"[TiAb]) AND '
809
+ '("Bone and Bones"[MeSH] OR Femur[TiAb]) AND '
810
+ '("Finite Element Analysis"[MeSH] OR QCT[TiAb] OR CT[TiAb]) AND '
811
+ '("2000"[DP] : "2025"[DP])'
812
+ )
813
+ citations = fetch_pubmed_citations(pm_query, max_results=5)
814
+ if not citations and used_term:
815
+ citations = _fallback_cits_for(used_term)
816
+ else:
817
+ # Grounded explanation + progressively broader citation search
818
+ explanation = _answer_from_chunks(retrieve_context(core_q, top_k=5), core_q)
819
+
820
+ # Broad, de-biased search first
821
+ compact = _compact_terms(core_q)
822
+ pm_query = (
823
+ f'({compact}) AND ("Bone and Bones"[MeSH] OR Femur[TiAb] OR Hip[TiAb] '
824
+ f'OR Rigidity[TiAb] OR "Tomography, X-Ray Computed"[MeSH] OR "Finite Element Analysis"[MeSH]) '
825
+ f'NOT (heart[TiAb] OR cardiac[TiAb] OR brain[TiAb] OR skull[TiAb] OR EGFR[TiAb]) '
826
+ f'AND ("2000"[DP] : "2025"[DP])'
827
+ )
828
+ citations = fetch_pubmed_citations(pm_query, max_results=5)
829
+
830
+ # If empty, bias toward lab-authored queries that are likely relevant
831
+ if not citations:
832
+ lab = detect_lab(core_q)
833
+ pm_query = build_lab_query(core_q, lab=lab)
834
+ citations = fetch_pubmed_citations(pm_query, max_results=5)
835
+
836
+ resp = _post_clean(explanation)
837
+ if citations:
838
+ resp += "\n\nCitations:\n" + "\n".join(citations)
839
+ else:
840
+ resp += f'\n\nSorry, no relevant citations found for “{core_q}.”'
841
+ return _ensure_min_answer(resp)
842
+
843
+ # ==========================================================================
844
+ # Non-citation path
845
+ # ==========================================================================
846
+ if det_text:
847
  dlog("ASK", "Deterministic definition/workflow fired")
848
+ return det_text
849
 
850
+ # If the query doesn't clearly hit MSK/FE tokens, try retrieval then fallback
851
  if not (_MSK_MUST.search(q) or _is_fe_override(q)):
852
+ chunks = retrieve_context(q, top_k=5)
853
  if chunks:
854
  answer = _answer_from_chunks(chunks, q)
 
 
 
 
 
855
  return _ensure_min_answer(_post_clean(answer)) or direct_llm_fallback(q)
856
  return direct_llm_fallback(q)
857
 
858
+ # Clear MSK/FE intent → do the normal grounded path
859
+ chunks = retrieve_context(q, top_k=5)
860
  if not chunks:
861
  return direct_llm_fallback(q)
862
+
863
  answer = _answer_from_chunks(chunks, q)
 
 
 
 
864
  return _ensure_min_answer(_post_clean(answer)) or direct_llm_fallback(q)
865
 
866
+
867
+
868
  def deterministic_definitions_text(core_q: str) -> Optional[str]:
869
  q_lower = core_q.lower()
870
  if "define axial rigidity" in q_lower or "what is axial rigidity" in q_lower: