Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,29 +1,19 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
from __future__ import annotations
|
| 3 |
|
|
|
|
| 4 |
import os, re, json, time, sys, csv, uuid, datetime
|
| 5 |
from typing import List, Dict, Any, Optional
|
| 6 |
from functools import lru_cache
|
| 7 |
from xml.etree import ElementTree as ET
|
| 8 |
-
|
|
|
|
| 9 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 10 |
try:
|
| 11 |
-
from transformers import BitsAndBytesConfig #
|
| 12 |
except Exception:
|
| 13 |
BitsAndBytesConfig = None
|
| 14 |
|
| 15 |
-
# Normalize QUANTIZE env
|
| 16 |
-
QUANTIZE = os.environ.get("QUANTIZE", "none").strip().lower()
|
| 17 |
-
|
| 18 |
-
# Detect bitsandbytes presence
|
| 19 |
-
try:
|
| 20 |
-
import bitsandbytes as _bnb # noqa: F401
|
| 21 |
-
_BNB_AVAILABLE = True
|
| 22 |
-
except Exception:
|
| 23 |
-
_BNB_AVAILABLE = False
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
import numpy as np
|
| 28 |
import requests
|
| 29 |
import gradio as gr
|
|
@@ -33,11 +23,13 @@ ASSETS_DIR = os.environ.get("ASSETS_DIR", "assets")
|
|
| 33 |
FAISS_PATH = os.environ.get("FAISS_PATH", f"{ASSETS_DIR}/index.faiss")
|
| 34 |
META_PATH = os.environ.get("META_PATH", f"{ASSETS_DIR}/index_meta.filtered.jsonl")
|
| 35 |
REL_CONFIG_PATH = os.environ.get("REL_CONFIG_PATH", f"{ASSETS_DIR}/relevance_config.json")
|
| 36 |
-
QUANTIZE = os.environ.get("QUANTIZE", "4bit") # "none" | "8bit" | "4bit"
|
| 37 |
-
# --- Turn logging ---
|
| 38 |
-
TRANSCRIPT_PATH = os.environ.get("TRANSCRIPT_PATH", "transcripts.jsonl")
|
| 39 |
-
PUSH_TRANSCRIPTS = os.environ.get("PUSH_TRANSCRIPTS", "1") == "1" # set to "0" to disable
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
# Models
|
| 43 |
BASE_MODEL = os.environ.get("BASE_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
|
|
@@ -53,12 +45,11 @@ NCBI_TOOL = os.environ.get("NCBI_TOOL", "askstein")
|
|
| 53 |
NCBI_APIKEY = os.environ.get("NCBI_APIKEY", "")
|
| 54 |
|
| 55 |
# Feedback logging
|
| 56 |
-
FEEDBACK_PATH
|
| 57 |
-
PUSH_FEEDBACK
|
| 58 |
-
HF_READ_TOKEN
|
| 59 |
-
HF_WRITE_TOKEN
|
| 60 |
-
SPACE_REPO_ID
|
| 61 |
-
|
| 62 |
|
| 63 |
# Generation / toggles
|
| 64 |
ALLOW_WIKIPEDIA = False
|
|
@@ -72,7 +63,6 @@ AUTO_CONTINUE = True
|
|
| 72 |
AUTO_CONT_MAX_STEPS = 2 # continue up to 2 extra chunks
|
| 73 |
AUTO_CONT_NEW_TOKENS = 256 # tokens per continuation step
|
| 74 |
|
| 75 |
-
|
| 76 |
def dlog(tag, msg):
|
| 77 |
if DEBUG: print(f"[{tag}] {msg}")
|
| 78 |
|
|
@@ -80,12 +70,14 @@ def dlog(tag, msg):
|
|
| 80 |
import faiss
|
| 81 |
from sentence_transformers import SentenceTransformer
|
| 82 |
import torch
|
| 83 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 84 |
from peft import PeftModel
|
| 85 |
import wikipedia
|
| 86 |
from wikipedia.exceptions import DisambiguationError, PageError
|
| 87 |
from huggingface_hub import login, snapshot_download, HfApi
|
| 88 |
|
|
|
|
|
|
|
|
|
|
| 89 |
# ================== GPU CHECK ==================
|
| 90 |
if not torch.cuda.is_available():
|
| 91 |
with gr.Blocks() as demo:
|
|
@@ -96,6 +88,8 @@ if not torch.cuda.is_available():
|
|
| 96 |
device = "cuda"
|
| 97 |
dtype = torch.float16
|
| 98 |
torch.manual_seed(42)
|
|
|
|
|
|
|
| 99 |
|
| 100 |
# ================== RELEVANCE CONFIG ==================
|
| 101 |
DEFAULT_REL_CONFIG = {
|
|
@@ -206,7 +200,7 @@ except Exception as e:
|
|
| 206 |
|
| 207 |
_IS_IP = isinstance(index, faiss.IndexFlatIP) or "IndexFlatIP" in type(index).__name__
|
| 208 |
|
| 209 |
-
# ==================
|
| 210 |
if HF_READ_TOKEN:
|
| 211 |
try:
|
| 212 |
login(token=HF_READ_TOKEN)
|
|
@@ -214,17 +208,24 @@ if HF_READ_TOKEN:
|
|
| 214 |
except Exception as e:
|
| 215 |
dlog("HF", f"Login issue: {e}")
|
| 216 |
|
| 217 |
-
|
| 218 |
if ADAPTER_REPO:
|
| 219 |
ADAPTER_PATH = snapshot_download(repo_id=ADAPTER_REPO, allow_patterns=["*"])
|
| 220 |
|
| 221 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
dlog("LLM", f"Loading base model: {BASE_MODEL}")
|
| 223 |
tokenizer_lm = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
|
| 224 |
|
| 225 |
use_bnb = QUANTIZE in {"8bit", "4bit"} and BitsAndBytesConfig is not None and _BNB_AVAILABLE
|
| 226 |
|
| 227 |
if use_bnb:
|
|
|
|
| 228 |
bnb_config = BitsAndBytesConfig(
|
| 229 |
load_in_8bit=(QUANTIZE == "8bit"),
|
| 230 |
load_in_4bit=(QUANTIZE == "4bit"),
|
|
@@ -238,23 +239,28 @@ if use_bnb:
|
|
| 238 |
quantization_config=bnb_config,
|
| 239 |
)
|
| 240 |
else:
|
| 241 |
-
#
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
dlog("LLM", f"Loading LoRA adapter from: {ADAPTER_PATH}")
|
| 252 |
model_lm = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
|
| 253 |
model_lm.eval()
|
| 254 |
|
| 255 |
-
|
| 256 |
GEN_ARGS_GROUNDED = dict(
|
| 257 |
max_new_tokens=MAX_NEW_TOKENS_GROUNDED,
|
|
|
|
| 258 |
do_sample=False,
|
| 259 |
num_beams=1,
|
| 260 |
no_repeat_ngram_size=3,
|
|
@@ -263,6 +269,7 @@ GEN_ARGS_GROUNDED = dict(
|
|
| 263 |
)
|
| 264 |
GEN_ARGS_FALLBACK = dict(
|
| 265 |
max_new_tokens=MAX_NEW_TOKENS_FALLBACK,
|
|
|
|
| 266 |
do_sample=False,
|
| 267 |
num_beams=1,
|
| 268 |
no_repeat_ngram_size=3,
|
|
@@ -270,38 +277,6 @@ GEN_ARGS_FALLBACK = dict(
|
|
| 270 |
eos_token_id=tokenizer_lm.eos_token_id,
|
| 271 |
)
|
| 272 |
|
| 273 |
-
def _generate(inputs, grounded: bool):
|
| 274 |
-
args = GEN_ARGS_GROUNDED if grounded else GEN_ARGS_FALLBACK
|
| 275 |
-
in_len = inputs["input_ids"].shape[-1]
|
| 276 |
-
with torch.inference_mode():
|
| 277 |
-
out = model_lm.generate(**inputs, **args)
|
| 278 |
-
|
| 279 |
-
if not AUTO_CONTINUE:
|
| 280 |
-
return out
|
| 281 |
-
|
| 282 |
-
steps = 0
|
| 283 |
-
while steps < AUTO_CONT_MAX_STEPS:
|
| 284 |
-
seq = out[0]
|
| 285 |
-
ended_with_eos = (seq[-1].item() == tokenizer_lm.eos_token_id)
|
| 286 |
-
hit_cap = (seq.shape[0] - in_len) >= args["max_new_tokens"]
|
| 287 |
-
if ended_with_eos or not hit_cap:
|
| 288 |
-
break
|
| 289 |
-
|
| 290 |
-
# continue generation from the current sequence
|
| 291 |
-
cont_inputs = {
|
| 292 |
-
"input_ids": seq.unsqueeze(0),
|
| 293 |
-
"attention_mask": torch.ones_like(seq).unsqueeze(0),
|
| 294 |
-
}
|
| 295 |
-
cont_inputs = {k: v.to(device) for k, v in cont_inputs.items()}
|
| 296 |
-
cont_args = dict(args)
|
| 297 |
-
cont_args["max_new_tokens"] = AUTO_CONT_NEW_TOKENS
|
| 298 |
-
|
| 299 |
-
out = model_lm.generate(**cont_inputs, **cont_args)
|
| 300 |
-
steps += 1
|
| 301 |
-
|
| 302 |
-
return out
|
| 303 |
-
|
| 304 |
-
|
| 305 |
# ================== UTILITIES ==================
|
| 306 |
_SANITIZE = re.compile(r"```.*?```|<\s*script[^>]*>.*?<\s*/\s*script\s*>", re.DOTALL|re.IGNORECASE)
|
| 307 |
def _to_text(rec: Any) -> str:
|
|
@@ -437,7 +412,9 @@ _ANATOMY_OR_HISTORY = re.compile(
|
|
| 437 |
re.I
|
| 438 |
)
|
| 439 |
_PAPERS_INTENT = re.compile(r"\b(key\s+papers|suggest\s+papers|landmark|seminal|important|top\s+papers)\b", re.I)
|
|
|
|
| 440 |
|
|
|
|
| 441 |
def fetch_pubmed_chunks(query_or_pmid: str, max_papers: int = 3) -> List[Dict[str, Any]]:
|
| 442 |
retries = 1
|
| 443 |
chunks: List[Dict[str, Any]] = []
|
|
@@ -617,7 +594,6 @@ def retrieve_context(query: str, top_k: int = 10) -> List[Dict[str, Any]]:
|
|
| 617 |
if results:
|
| 618 |
dlog("PUBMED", "PubMed search hit")
|
| 619 |
return results
|
| 620 |
-
# Wikipedia fallback (unconditional after PubMed miss)
|
| 621 |
wiki = wiki_summary_allow(q, sentences=3)
|
| 622 |
if wiki:
|
| 623 |
dlog("WIKI", "Wikipedia fallback hit")
|
|
@@ -625,7 +601,6 @@ def retrieve_context(query: str, top_k: int = 10) -> List[Dict[str, Any]]:
|
|
| 625 |
dlog("RETRIEVAL", "No results found")
|
| 626 |
return []
|
| 627 |
|
| 628 |
-
|
| 629 |
# FAISS path
|
| 630 |
q_emb = embed_model.encode([q], convert_to_numpy=True).astype("float32")
|
| 631 |
if _IS_IP:
|
|
@@ -666,7 +641,6 @@ def retrieve_context(query: str, top_k: int = 10) -> List[Dict[str, Any]]:
|
|
| 666 |
dlog("PUBMED", "PubMed search hit")
|
| 667 |
return results
|
| 668 |
|
| 669 |
-
# Wikipedia fallback (unconditional after PubMed miss)
|
| 670 |
wiki = wiki_summary_allow(q, sentences=3)
|
| 671 |
if wiki:
|
| 672 |
dlog("WIKI", "Wikipedia fallback hit")
|
|
@@ -675,7 +649,6 @@ def retrieve_context(query: str, top_k: int = 10) -> List[Dict[str, Any]]:
|
|
| 675 |
dlog("RETRIEVAL", "No results at all")
|
| 676 |
return []
|
| 677 |
|
| 678 |
-
|
| 679 |
def build_prompt(chunks: List[Dict[str, Any]], question: str) -> str:
|
| 680 |
header = (
|
| 681 |
"You are Askstein (orthopedic biomechanics). Use ONLY the [Context] to answer. "
|
|
@@ -684,7 +657,7 @@ def build_prompt(chunks: List[Dict[str, Any]], question: str) -> str:
|
|
| 684 |
"Do not discuss cardiology, neurology, or unrelated domains."
|
| 685 |
)
|
| 686 |
cleaned = []
|
| 687 |
-
per_chunk_chars =
|
| 688 |
for c in chunks:
|
| 689 |
t = _to_text(c)
|
| 690 |
if t: cleaned.append(t[:per_chunk_chars])
|
|
@@ -695,22 +668,64 @@ def _decode_generated(out_ids, in_len: int) -> str:
|
|
| 695 |
gen = out_ids[0][in_len:]
|
| 696 |
return tokenizer_lm.decode(gen, skip_special_tokens=True).lstrip(". \n").strip()
|
| 697 |
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
llm_prompt = f"{sys_prompt}\n\nQuestion: {question}\nAnswer:"
|
| 705 |
-
inputs = tokenizer_lm(llm_prompt, return_tensors="pt").to(device)
|
| 706 |
-
out = _generate(inputs, grounded=False)
|
| 707 |
in_len = inputs["input_ids"].shape[-1]
|
| 708 |
-
ans = _post_clean(_decode_generated(out, in_len))
|
| 709 |
-
# Strip any made-up reference sections the model might add
|
| 710 |
-
ans = re.sub(r"(?is)(^|\n)\s*references?:.*$", "", ans).strip()
|
| 711 |
-
return "[LLM fallback — ungrounded]\n\n" + ans
|
| 712 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 713 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 714 |
def _synthesize_answer(chunks: List[Dict[str, Any]], question: str) -> str:
|
| 715 |
prompt = build_prompt(chunks, question)
|
| 716 |
inputs = tokenizer_lm(prompt, return_tensors="pt").to(device)
|
|
@@ -726,35 +741,21 @@ def _answer_from_chunks(chunks: List[Dict[str, Any]], question: str) -> str:
|
|
| 726 |
return _synthesize_answer(chunks, question)
|
| 727 |
return _synthesize_answer(chunks, question)
|
| 728 |
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
"G = E/(2(1+ν)) from the voxel E map (ν≈0.3). Compute per-slice and report the minimum.")
|
| 743 |
-
if re.search(r"\b(outline|steps|workflow|protocol)\b.*\b(ct|qct).*(rigidity|ea|ei|gj)", q_lower):
|
| 744 |
-
return (
|
| 745 |
-
"CT-based structural rigidity (CTRA/QCT) workflow:\n"
|
| 746 |
-
"1) Acquire QCT (≤1 mm; density phantom).\n"
|
| 747 |
-
"2) Preprocess & segment bone.\n"
|
| 748 |
-
"3) HU→ρ; ρ→E calibration.\n"
|
| 749 |
-
"4) Cross-sections along neck axis.\n"
|
| 750 |
-
"5) EA, EI_x/EI_y, GJ (G≈E/(2(1+ν))).\n"
|
| 751 |
-
"6) Extract minima & validate vs FEA/mech tests."
|
| 752 |
-
)
|
| 753 |
-
if re.search(r"\b(modulus)\b.*\brigidity\b|\bdefine\s+modulus\b", q_lower):
|
| 754 |
-
return ("Elastic modulus (E) is a material property (Pa). "
|
| 755 |
-
"Rigidity is structural (EA, EI, GJ). Modulus ≠ rigidity.")
|
| 756 |
-
return None
|
| 757 |
|
|
|
|
| 758 |
def ask(question: str) -> str:
|
| 759 |
q = question.strip()
|
| 760 |
m = re.search(r"pmid[:\s]*(\d+)", q, re.IGNORECASE)
|
|
@@ -763,8 +764,11 @@ def ask(question: str) -> str:
|
|
| 763 |
chunks = fetch_pubmed_chunks(pmid, max_papers=1)
|
| 764 |
return "\n".join(c.get("text", "") for c in chunks) or "Sorry, no abstract found."
|
| 765 |
|
| 766 |
-
if _PAPERS_INTENT.search(q):
|
| 767 |
-
core_q = re.sub(
|
|
|
|
|
|
|
|
|
|
| 768 |
compact = _compact_terms(core_q)
|
| 769 |
pm_query = (
|
| 770 |
f'(({compact}) AND (hip[TiAb] OR femur[TiAb] OR femoral[TiAb])) AND '
|
|
@@ -772,82 +776,75 @@ def ask(question: str) -> str:
|
|
| 772 |
'AND ("2000"[DP] : "2025"[DP])'
|
| 773 |
)
|
| 774 |
cits = fetch_pubmed_citations(pm_query, max_results=5)
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
|
| 785 |
-
|
| 786 |
-
|
| 787 |
-
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
'("2000"[DP] : "2025"[DP])')
|
| 791 |
-
elif ("bending" in lq) or ("ei" in lq):
|
| 792 |
-
used_term = "EI"
|
| 793 |
-
pm_query = ('(bending[TiAb] OR "second moment"[TiAb] OR EI[TiAb]) AND '
|
| 794 |
-
'("Bone and Bones"[MeSH] OR Femur[TiAb]) AND '
|
| 795 |
-
'("Finite Element Analysis"[MeSH] OR QCT[TiAb] OR CT[TiAb]) AND '
|
| 796 |
-
'("2000"[DP] : "2025"[DP])')
|
| 797 |
-
else:
|
| 798 |
-
used_term = "EA"
|
| 799 |
-
pm_query = ('("axial rigidity"[TiAb] OR EA[TiAb] OR "axial stiffness"[TiAb]) AND '
|
| 800 |
-
'("Bone and Bones"[MeSH] OR Femur[TiAb]) AND '
|
| 801 |
-
'("Finite Element Analysis"[MeSH] OR QCT[TiAb] OR CT[TiAb]) AND '
|
| 802 |
-
'("2000"[DP] : "2025"[DP])')
|
| 803 |
-
citations = fetch_pubmed_citations(pm_query, max_results=5)
|
| 804 |
-
if not citations and used_term:
|
| 805 |
-
dlog("CITE", f"PubMed empty → fallback {used_term}")
|
| 806 |
-
citations = _fallback_cits_for(used_term)
|
| 807 |
-
else:
|
| 808 |
-
explanation = _answer_from_chunks(retrieve_context(core_q, top_k=5), core_q)
|
| 809 |
-
pm_query = f'"{core_q}"[Title/Abstract]'
|
| 810 |
-
citations = fetch_pubmed_citations(pm_query, max_results=5)
|
| 811 |
-
if not citations:
|
| 812 |
-
lab = detect_lab(core_q)
|
| 813 |
-
pm_query = build_lab_query(core_q, lab=lab)
|
| 814 |
-
citations = fetch_pubmed_citations(pm_query, max_results=5)
|
| 815 |
-
if not citations:
|
| 816 |
-
compact = _compact_terms(core_q)
|
| 817 |
-
pm_query = (
|
| 818 |
-
f'({compact}) AND ("Bone and Bones"[MeSH] OR Femur[TiAb] OR Hip[TiAb] '
|
| 819 |
-
f'OR Rigidity[TiAb] OR "Tomography, X-Ray Computed"[MeSH] OR "Finite Element Analysis"[MeSH]) '
|
| 820 |
-
f'NOT (heart[TiAb] OR cardiac[TiAb] OR brain[TiAb] OR skull[TiAb] OR EGFR[TiAb]) '
|
| 821 |
-
f'AND ("2000"[DP] : "2025"[DP])'
|
| 822 |
-
)
|
| 823 |
-
citations = fetch_pubmed_citations(pm_query, max_results=5)
|
| 824 |
-
resp = explanation
|
| 825 |
-
if citations:
|
| 826 |
-
resp += "\n\nCitations:\n" + "\n".join(citations)
|
| 827 |
-
else:
|
| 828 |
-
resp += f"\n\nSorry, no relevant citations found for “{core_q}.”"
|
| 829 |
-
return _ensure_min_answer(_post_clean(resp))
|
| 830 |
-
|
| 831 |
-
det_answer = deterministic_definitions_text(q)
|
| 832 |
-
if det_answer:
|
| 833 |
dlog("ASK", "Deterministic definition/workflow fired")
|
| 834 |
-
return
|
| 835 |
|
| 836 |
if not (_MSK_MUST.search(q) or _is_fe_override(q)):
|
| 837 |
-
chunks = retrieve_context(q, top_k=
|
| 838 |
if chunks:
|
| 839 |
-
dlog("CLEAN", "Post-clean applied")
|
| 840 |
answer = _answer_from_chunks(chunks, q)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 841 |
return _ensure_min_answer(_post_clean(answer)) or direct_llm_fallback(q)
|
| 842 |
return direct_llm_fallback(q)
|
| 843 |
|
| 844 |
-
chunks = retrieve_context(q, top_k=
|
| 845 |
if not chunks:
|
| 846 |
return direct_llm_fallback(q)
|
| 847 |
-
dlog("CLEAN", "Post-clean applied")
|
| 848 |
answer = _answer_from_chunks(chunks, q)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 849 |
return _ensure_min_answer(_post_clean(answer)) or direct_llm_fallback(q)
|
| 850 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 851 |
# ================== UI: NAME GATE + PER-ANSWER FEEDBACK ==================
|
| 852 |
def _now_iso():
|
| 853 |
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).isoformat()
|
|
@@ -870,63 +867,7 @@ def enter_app(first_name, last_name, state):
|
|
| 870 |
state["last_name"] = last_name
|
| 871 |
return gr.update(visible=False), gr.update(visible=True), state, f"Welcome, {first_name}! You can start chatting."
|
| 872 |
|
| 873 |
-
def _log_turn(state: Dict[str, Any], question: str, answer: str):
|
| 874 |
-
rec = {
|
| 875 |
-
"timestamp_utc": _now_iso(),
|
| 876 |
-
"session_id": state.get("session_id", ""),
|
| 877 |
-
"first_name": state.get("first_name", ""),
|
| 878 |
-
"last_name": state.get("last_name", ""),
|
| 879 |
-
"question": question,
|
| 880 |
-
"answer": answer,
|
| 881 |
-
}
|
| 882 |
-
with open(TRANSCRIPT_PATH, "a", encoding="utf-8") as f:
|
| 883 |
-
f.write(json.dumps(rec, ensure_ascii=False) + "\n")
|
| 884 |
-
|
| 885 |
-
if PUSH_TRANSCRIPTS:
|
| 886 |
-
_push_file_to_hub(TRANSCRIPT_PATH, "analytics/transcripts.jsonl")
|
| 887 |
-
|
| 888 |
-
|
| 889 |
-
def predict(message, chat_history, state):
|
| 890 |
-
msg = (message or "").strip()
|
| 891 |
-
if not msg:
|
| 892 |
-
# No input → don't show feedback, just return current state
|
| 893 |
-
return chat_history, "", gr.update(visible=False), None, "", state
|
| 894 |
-
|
| 895 |
-
try:
|
| 896 |
-
answer = ask(msg)
|
| 897 |
-
except Exception as e:
|
| 898 |
-
answer = f"Sorry — something went wrong: {e!r}"
|
| 899 |
-
|
| 900 |
-
chat_history = (chat_history or []) + [(msg, answer)]
|
| 901 |
-
state["last_q"] = msg
|
| 902 |
-
state["last_a"] = answer
|
| 903 |
-
|
| 904 |
-
# Log every turn (safe if _log_turn isn't defined)
|
| 905 |
-
try:
|
| 906 |
-
_log_turn(state, msg, answer)
|
| 907 |
-
except Exception:
|
| 908 |
-
pass
|
| 909 |
-
|
| 910 |
-
return (
|
| 911 |
-
chat_history,
|
| 912 |
-
"", # clear input
|
| 913 |
-
gr.update(visible=True), # show feedback pane
|
| 914 |
-
gr.update(value=None), # reset rating
|
| 915 |
-
gr.update(value=""), # reset comment
|
| 916 |
-
state
|
| 917 |
-
)
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
# --- Hub upload helper --------------------------------------------------------
|
| 921 |
def _push_file_to_hub(local_path: str, repo_path: str) -> None:
|
| 922 |
-
"""
|
| 923 |
-
Upload a local file to your Space repo.
|
| 924 |
-
|
| 925 |
-
Requires:
|
| 926 |
-
- PUSH_FEEDBACK=1
|
| 927 |
-
- HF_WRITE_TOKEN (write access to the Space)
|
| 928 |
-
- SPACE_REPO_ID (e.g., "username/YourSpace")
|
| 929 |
-
"""
|
| 930 |
if not PUSH_FEEDBACK:
|
| 931 |
return
|
| 932 |
if not os.path.exists(local_path):
|
|
@@ -938,7 +879,6 @@ def _push_file_to_hub(local_path: str, repo_path: str) -> None:
|
|
| 938 |
if not HF_WRITE_TOKEN:
|
| 939 |
dlog("UPLOAD", "Skip: HF_WRITE_TOKEN not set")
|
| 940 |
return
|
| 941 |
-
|
| 942 |
try:
|
| 943 |
api = HfApi(token=HF_WRITE_TOKEN)
|
| 944 |
api.upload_file(
|
|
@@ -952,13 +892,28 @@ def _push_file_to_hub(local_path: str, repo_path: str) -> None:
|
|
| 952 |
except Exception as e:
|
| 953 |
dlog("UPLOAD", f"Upload failed: {e}")
|
| 954 |
|
| 955 |
-
# --- Feedback uploader --------------------------------------------------------
|
| 956 |
def _push_feedback_to_hub() -> None:
|
| 957 |
-
"""Upload feedback.csv to analytics/feedback.csv in this Space repo (if enabled)."""
|
| 958 |
_push_file_to_hub(FEEDBACK_PATH, "analytics/feedback.csv")
|
| 959 |
|
| 960 |
-
|
| 961 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 962 |
|
| 963 |
def save_feedback(rating, comment, state):
|
| 964 |
if rating is None:
|
|
@@ -986,6 +941,41 @@ def save_feedback(rating, comment, state):
|
|
| 986 |
except Exception as e:
|
| 987 |
return f"Failed to save feedback: {e}", gr.update(visible=True)
|
| 988 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 989 |
with gr.Blocks(theme="soft") as demo:
|
| 990 |
gr.Markdown("# Askstein — Orthopedic Biomechanics Chat (CT/QCT Rigidity, FE)")
|
| 991 |
gr.Markdown("Grounded answers (FAISS + PubMed). Please enter your name to continue.")
|
|
@@ -997,9 +987,9 @@ with gr.Blocks(theme="soft") as demo:
|
|
| 997 |
with gate:
|
| 998 |
with gr.Row():
|
| 999 |
first_tb = gr.Textbox(label="First name", placeholder="e.g., Shubh", scale=1)
|
| 1000 |
-
last_tb
|
| 1001 |
enter_btn = gr.Button("Enter", variant="primary")
|
| 1002 |
-
gate_msg
|
| 1003 |
|
| 1004 |
# ---- App (hidden until gate passes) ----
|
| 1005 |
app = gr.Group(visible=False)
|
|
@@ -1013,12 +1003,12 @@ with gr.Blocks(theme="soft") as demo:
|
|
| 1013 |
feedback_grp = gr.Group(visible=False)
|
| 1014 |
with feedback_grp:
|
| 1015 |
gr.Markdown("### How helpful was this answer?")
|
| 1016 |
-
rating = gr.Radio(choices=[1,
|
| 1017 |
comment = gr.Textbox(label="Optional comment", placeholder="What was good or missing?")
|
| 1018 |
submit_fb = gr.Button("Submit feedback")
|
| 1019 |
fb_status = gr.Markdown("")
|
| 1020 |
|
| 1021 |
-
# ---- Wiring (
|
| 1022 |
enter_btn.click(
|
| 1023 |
fn=enter_app,
|
| 1024 |
inputs=[first_tb, last_tb, state],
|
|
@@ -1053,6 +1043,6 @@ with gr.Blocks(theme="soft") as demo:
|
|
| 1053 |
concurrency_limit=4,
|
| 1054 |
)
|
| 1055 |
|
| 1056 |
-
#
|
| 1057 |
-
demo.queue(max_size=64)
|
| 1058 |
demo.launch(max_threads=8)
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
+
# ================== STD / CORE ==================
|
| 5 |
import os, re, json, time, sys, csv, uuid, datetime
|
| 6 |
from typing import List, Dict, Any, Optional
|
| 7 |
from functools import lru_cache
|
| 8 |
from xml.etree import ElementTree as ET
|
| 9 |
+
|
| 10 |
+
# ================== TRANSFORMERS / TORCH ==================
|
| 11 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 12 |
try:
|
| 13 |
+
from transformers import BitsAndBytesConfig # may exist even if bnb isn't installed
|
| 14 |
except Exception:
|
| 15 |
BitsAndBytesConfig = None
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
import numpy as np
|
| 18 |
import requests
|
| 19 |
import gradio as gr
|
|
|
|
| 23 |
FAISS_PATH = os.environ.get("FAISS_PATH", f"{ASSETS_DIR}/index.faiss")
|
| 24 |
META_PATH = os.environ.get("META_PATH", f"{ASSETS_DIR}/index_meta.filtered.jsonl")
|
| 25 |
REL_CONFIG_PATH = os.environ.get("REL_CONFIG_PATH", f"{ASSETS_DIR}/relevance_config.json")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
+
# Normalize QUANTIZE env (default: no quantization)
|
| 28 |
+
QUANTIZE = os.environ.get("QUANTIZE", "none").strip().lower() # "none" | "8bit" | "4bit"
|
| 29 |
+
|
| 30 |
+
# Turn logging
|
| 31 |
+
TRANSCRIPT_PATH = os.environ.get("TRANSCRIPT_PATH", "transcripts.jsonl")
|
| 32 |
+
PUSH_TRANSCRIPTS = os.environ.get("PUSH_TRANSCRIPTS", "1") == "1" # set "0" to disable
|
| 33 |
|
| 34 |
# Models
|
| 35 |
BASE_MODEL = os.environ.get("BASE_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
|
|
|
|
| 45 |
NCBI_APIKEY = os.environ.get("NCBI_APIKEY", "")
|
| 46 |
|
| 47 |
# Feedback logging
|
| 48 |
+
FEEDBACK_PATH = os.environ.get("FEEDBACK_PATH", "feedback.csv")
|
| 49 |
+
PUSH_FEEDBACK = os.environ.get("PUSH_FEEDBACK", "0") == "1" # set "1" to enable Hub upload
|
| 50 |
+
HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN", os.environ.get("HF_TOKEN", ""))
|
| 51 |
+
HF_WRITE_TOKEN = os.environ.get("HF_WRITE_TOKEN", HF_READ_TOKEN)
|
| 52 |
+
SPACE_REPO_ID = os.environ.get("SPACE_REPO_ID", "")
|
|
|
|
| 53 |
|
| 54 |
# Generation / toggles
|
| 55 |
ALLOW_WIKIPEDIA = False
|
|
|
|
| 63 |
AUTO_CONT_MAX_STEPS = 2 # continue up to 2 extra chunks
|
| 64 |
AUTO_CONT_NEW_TOKENS = 256 # tokens per continuation step
|
| 65 |
|
|
|
|
| 66 |
def dlog(tag, msg):
|
| 67 |
if DEBUG: print(f"[{tag}] {msg}")
|
| 68 |
|
|
|
|
| 70 |
import faiss
|
| 71 |
from sentence_transformers import SentenceTransformer
|
| 72 |
import torch
|
|
|
|
| 73 |
from peft import PeftModel
|
| 74 |
import wikipedia
|
| 75 |
from wikipedia.exceptions import DisambiguationError, PageError
|
| 76 |
from huggingface_hub import login, snapshot_download, HfApi
|
| 77 |
|
| 78 |
+
# ================== LOW-VRAM RUNTIME KNOBS ==================
|
| 79 |
+
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128,expandable_segments:True")
|
| 80 |
+
|
| 81 |
# ================== GPU CHECK ==================
|
| 82 |
if not torch.cuda.is_available():
|
| 83 |
with gr.Blocks() as demo:
|
|
|
|
| 88 |
device = "cuda"
|
| 89 |
dtype = torch.float16
|
| 90 |
torch.manual_seed(42)
|
| 91 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 92 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 93 |
|
| 94 |
# ================== RELEVANCE CONFIG ==================
|
| 95 |
DEFAULT_REL_CONFIG = {
|
|
|
|
| 200 |
|
| 201 |
_IS_IP = isinstance(index, faiss.IndexFlatIP) or "IndexFlatIP" in type(index).__name__
|
| 202 |
|
| 203 |
+
# ================== HUGGING FACE LOGIN & ADAPTER PATH ==================
|
| 204 |
if HF_READ_TOKEN:
|
| 205 |
try:
|
| 206 |
login(token=HF_READ_TOKEN)
|
|
|
|
| 208 |
except Exception as e:
|
| 209 |
dlog("HF", f"Login issue: {e}")
|
| 210 |
|
|
|
|
| 211 |
if ADAPTER_REPO:
|
| 212 |
ADAPTER_PATH = snapshot_download(repo_id=ADAPTER_REPO, allow_patterns=["*"])
|
| 213 |
|
| 214 |
+
# ================== QUANTIZATION AVAILABILITY ==================
|
| 215 |
+
try:
|
| 216 |
+
import bitsandbytes as _bnb # noqa: F401
|
| 217 |
+
_BNB_AVAILABLE = True
|
| 218 |
+
except Exception:
|
| 219 |
+
_BNB_AVAILABLE = False
|
| 220 |
+
|
| 221 |
+
# ================== LLM (BASE + LORA) ==================
|
| 222 |
dlog("LLM", f"Loading base model: {BASE_MODEL}")
|
| 223 |
tokenizer_lm = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
|
| 224 |
|
| 225 |
use_bnb = QUANTIZE in {"8bit", "4bit"} and BitsAndBytesConfig is not None and _BNB_AVAILABLE
|
| 226 |
|
| 227 |
if use_bnb:
|
| 228 |
+
# Quantized path (only if explicitly requested and bnb is installed)
|
| 229 |
bnb_config = BitsAndBytesConfig(
|
| 230 |
load_in_8bit=(QUANTIZE == "8bit"),
|
| 231 |
load_in_4bit=(QUANTIZE == "4bit"),
|
|
|
|
| 239 |
quantization_config=bnb_config,
|
| 240 |
)
|
| 241 |
else:
|
| 242 |
+
# fp16 path with SDPA attention (lower VRAM). Fallback if not supported.
|
| 243 |
+
try:
|
| 244 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 245 |
+
BASE_MODEL,
|
| 246 |
+
torch_dtype=dtype,
|
| 247 |
+
device_map="auto",
|
| 248 |
+
attn_implementation="sdpa",
|
| 249 |
+
)
|
| 250 |
+
except TypeError:
|
| 251 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 252 |
+
BASE_MODEL,
|
| 253 |
+
torch_dtype=dtype,
|
| 254 |
+
device_map="auto",
|
| 255 |
+
)
|
| 256 |
|
| 257 |
dlog("LLM", f"Loading LoRA adapter from: {ADAPTER_PATH}")
|
| 258 |
model_lm = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
|
| 259 |
model_lm.eval()
|
| 260 |
|
|
|
|
| 261 |
GEN_ARGS_GROUNDED = dict(
|
| 262 |
max_new_tokens=MAX_NEW_TOKENS_GROUNDED,
|
| 263 |
+
min_new_tokens=220,
|
| 264 |
do_sample=False,
|
| 265 |
num_beams=1,
|
| 266 |
no_repeat_ngram_size=3,
|
|
|
|
| 269 |
)
|
| 270 |
GEN_ARGS_FALLBACK = dict(
|
| 271 |
max_new_tokens=MAX_NEW_TOKENS_FALLBACK,
|
| 272 |
+
min_new_tokens=120,
|
| 273 |
do_sample=False,
|
| 274 |
num_beams=1,
|
| 275 |
no_repeat_ngram_size=3,
|
|
|
|
| 277 |
eos_token_id=tokenizer_lm.eos_token_id,
|
| 278 |
)
|
| 279 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
# ================== UTILITIES ==================
|
| 281 |
_SANITIZE = re.compile(r"```.*?```|<\s*script[^>]*>.*?<\s*/\s*script\s*>", re.DOTALL|re.IGNORECASE)
|
| 282 |
def _to_text(rec: Any) -> str:
|
|
|
|
| 412 |
re.I
|
| 413 |
)
|
| 414 |
_PAPERS_INTENT = re.compile(r"\b(key\s+papers|suggest\s+papers|landmark|seminal|important|top\s+papers)\b", re.I)
|
| 415 |
+
CITE_TRIGGER = re.compile(r"\b(cite|citations?|references?)\b", re.I)
|
| 416 |
|
| 417 |
+
# ================== PUBMED & RETRIEVAL ==================
|
| 418 |
def fetch_pubmed_chunks(query_or_pmid: str, max_papers: int = 3) -> List[Dict[str, Any]]:
|
| 419 |
retries = 1
|
| 420 |
chunks: List[Dict[str, Any]] = []
|
|
|
|
| 594 |
if results:
|
| 595 |
dlog("PUBMED", "PubMed search hit")
|
| 596 |
return results
|
|
|
|
| 597 |
wiki = wiki_summary_allow(q, sentences=3)
|
| 598 |
if wiki:
|
| 599 |
dlog("WIKI", "Wikipedia fallback hit")
|
|
|
|
| 601 |
dlog("RETRIEVAL", "No results found")
|
| 602 |
return []
|
| 603 |
|
|
|
|
| 604 |
# FAISS path
|
| 605 |
q_emb = embed_model.encode([q], convert_to_numpy=True).astype("float32")
|
| 606 |
if _IS_IP:
|
|
|
|
| 641 |
dlog("PUBMED", "PubMed search hit")
|
| 642 |
return results
|
| 643 |
|
|
|
|
| 644 |
wiki = wiki_summary_allow(q, sentences=3)
|
| 645 |
if wiki:
|
| 646 |
dlog("WIKI", "Wikipedia fallback hit")
|
|
|
|
| 649 |
dlog("RETRIEVAL", "No results at all")
|
| 650 |
return []
|
| 651 |
|
|
|
|
| 652 |
def build_prompt(chunks: List[Dict[str, Any]], question: str) -> str:
|
| 653 |
header = (
|
| 654 |
"You are Askstein (orthopedic biomechanics). Use ONLY the [Context] to answer. "
|
|
|
|
| 657 |
"Do not discuss cardiology, neurology, or unrelated domains."
|
| 658 |
)
|
| 659 |
cleaned = []
|
| 660 |
+
per_chunk_chars = 900 # lower prompt length = lower KV memory
|
| 661 |
for c in chunks:
|
| 662 |
t = _to_text(c)
|
| 663 |
if t: cleaned.append(t[:per_chunk_chars])
|
|
|
|
| 668 |
gen = out_ids[0][in_len:]
|
| 669 |
return tokenizer_lm.decode(gen, skip_special_tokens=True).lstrip(". \n").strip()
|
| 670 |
|
| 671 |
+
def _gen_once(inputs, args) -> Any:
|
| 672 |
+
with torch.inference_mode():
|
| 673 |
+
return model_lm.generate(**inputs, **args, use_cache=True)
|
| 674 |
+
|
| 675 |
+
def _generate(inputs, grounded: bool):
|
| 676 |
+
args = GEN_ARGS_GROUNDED if grounded else GEN_ARGS_FALLBACK
|
|
|
|
|
|
|
|
|
|
| 677 |
in_len = inputs["input_ids"].shape[-1]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 678 |
|
| 679 |
+
# First attempt
|
| 680 |
+
try:
|
| 681 |
+
out = _gen_once(inputs, args)
|
| 682 |
+
except torch.cuda.OutOfMemoryError:
|
| 683 |
+
try:
|
| 684 |
+
torch.cuda.empty_cache()
|
| 685 |
+
except Exception:
|
| 686 |
+
pass
|
| 687 |
+
small_args = dict(args)
|
| 688 |
+
small_args["max_new_tokens"] = min(256, args.get("max_new_tokens", 256))
|
| 689 |
+
# disable cache to save VRAM
|
| 690 |
+
with torch.inference_mode():
|
| 691 |
+
out = model_lm.generate(**inputs, **small_args, use_cache=False)
|
| 692 |
+
|
| 693 |
+
if not AUTO_CONTINUE:
|
| 694 |
+
return out
|
| 695 |
|
| 696 |
+
# Auto-continue if we hit cap without EOS
|
| 697 |
+
steps = 0
|
| 698 |
+
while steps < AUTO_CONT_MAX_STEPS:
|
| 699 |
+
seq = out[0]
|
| 700 |
+
ended_with_eos = (seq[-1].item() == tokenizer_lm.eos_token_id)
|
| 701 |
+
hit_cap = (seq.shape[0] - in_len) >= args["max_new_tokens"]
|
| 702 |
+
if ended_with_eos or not hit_cap:
|
| 703 |
+
break
|
| 704 |
+
|
| 705 |
+
cont_inputs = {
|
| 706 |
+
"input_ids": seq.unsqueeze(0),
|
| 707 |
+
"attention_mask": torch.ones_like(seq).unsqueeze(0),
|
| 708 |
+
}
|
| 709 |
+
cont_inputs = {k: v.to(device) for k, v in cont_inputs.items()}
|
| 710 |
+
cont_args = dict(args)
|
| 711 |
+
cont_args["max_new_tokens"] = AUTO_CONT_NEW_TOKENS
|
| 712 |
+
|
| 713 |
+
try:
|
| 714 |
+
with torch.inference_mode():
|
| 715 |
+
out = model_lm.generate(**cont_inputs, **cont_args, use_cache=True)
|
| 716 |
+
except torch.cuda.OutOfMemoryError:
|
| 717 |
+
try:
|
| 718 |
+
torch.cuda.empty_cache()
|
| 719 |
+
except Exception:
|
| 720 |
+
pass
|
| 721 |
+
with torch.inference_mode():
|
| 722 |
+
out = model_lm.generate(**cont_inputs, **cont_args, use_cache=False)
|
| 723 |
+
|
| 724 |
+
steps += 1
|
| 725 |
+
|
| 726 |
+
return out
|
| 727 |
+
|
| 728 |
+
# ================== ANSWER SYNTHESIS ==================
|
| 729 |
def _synthesize_answer(chunks: List[Dict[str, Any]], question: str) -> str:
|
| 730 |
prompt = build_prompt(chunks, question)
|
| 731 |
inputs = tokenizer_lm(prompt, return_tensors="pt").to(device)
|
|
|
|
| 741 |
return _synthesize_answer(chunks, question)
|
| 742 |
return _synthesize_answer(chunks, question)
|
| 743 |
|
| 744 |
+
@lru_cache(maxsize=128)  # bounded: keys are arbitrary user questions, so an unbounded cache grows forever
def direct_llm_fallback(question: str) -> str:
    """Answer *question* directly with the LLM, without retrieved context.

    Used when retrieval yields nothing usable. The reply is prefixed with an
    explicit "[LLM fallback — ungrounded]" banner so the user knows it is not
    grounded in the corpus, and any model-invented "References:" section is
    stripped to avoid fabricated citations.

    Args:
        question: The raw user question.

    Returns:
        The cleaned, banner-prefixed model answer.
    """
    sys_prompt = (
        "You are Askstein (orthopedic biomechanics). If you lack enough domain context, say you don’t know. "
        "Avoid discussing non-musculoskeletal systems (cardiology, neurology). Do NOT invent references."
    )
    llm_prompt = f"{sys_prompt}\n\nQuestion: {question}\nAnswer:"
    inputs = tokenizer_lm(llm_prompt, return_tensors="pt").to(device)
    out = _generate(inputs, grounded=False)
    # Decode only the newly generated tokens (everything past the prompt).
    in_len = inputs["input_ids"].shape[-1]
    ans = _post_clean(_decode_generated(out, in_len))
    # Drop any trailing "References:" block the model hallucinated.
    ans = re.sub(r"(?is)(^|\n)\s*references?:.*$", "", ans).strip()
    return "[LLM fallback — ungrounded]\n\n" + ans
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 757 |
|
| 758 |
+
# ================== PUBLIC API ==================
|
| 759 |
def ask(question: str) -> str:
|
| 760 |
q = question.strip()
|
| 761 |
m = re.search(r"pmid[:\s]*(\d+)", q, re.IGNORECASE)
|
|
|
|
| 764 |
chunks = fetch_pubmed_chunks(pmid, max_papers=1)
|
| 765 |
return "\n".join(c.get("text", "") for c in chunks) or "Sorry, no abstract found."
|
| 766 |
|
| 767 |
+
if _PAPERS_INTENT.search(q) or CITE_TRIGGER.search(q):
|
| 768 |
+
core_q = re.sub(CITE_TRIGGER, "", q, count=1, flags=re.I).strip().rstrip(".")
|
| 769 |
+
core_q = re.sub(_PAPERS_INTENT, "", core_q, flags=re.I).strip()
|
| 770 |
+
if not core_q:
|
| 771 |
+
core_q = "CT/QCT structural rigidity femur hip finite element"
|
| 772 |
compact = _compact_terms(core_q)
|
| 773 |
pm_query = (
|
| 774 |
f'(({compact}) AND (hip[TiAb] OR femur[TiAb] OR femoral[TiAb])) AND '
|
|
|
|
| 776 |
'AND ("2000"[DP] : "2025"[DP])'
|
| 777 |
)
|
| 778 |
cits = fetch_pubmed_citations(pm_query, max_results=5)
|
| 779 |
+
if not cits:
|
| 780 |
+
lab = detect_lab(core_q)
|
| 781 |
+
pm_query = build_lab_query(core_q, lab=lab)
|
| 782 |
+
cits = fetch_pubmed_citations(pm_query, max_results=5)
|
| 783 |
+
if not cits:
|
| 784 |
+
cits = _fallback_cits_for("EA")
|
| 785 |
+
# Provide a short explanation + citations
|
| 786 |
+
explanation = _answer_from_chunks(retrieve_context(core_q, top_k=3), core_q) or direct_llm_fallback(core_q)
|
| 787 |
+
explanation = _post_clean(explanation)
|
| 788 |
+
if cits:
|
| 789 |
+
explanation += "\n\nCitations:\n" + "\n".join(cits)
|
| 790 |
+
return _ensure_min_answer(explanation)
|
| 791 |
+
|
| 792 |
+
det = deterministic_definitions_text(q)
|
| 793 |
+
if det:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 794 |
dlog("ASK", "Deterministic definition/workflow fired")
|
| 795 |
+
return det
|
| 796 |
|
| 797 |
if not (_MSK_MUST.search(q) or _is_fe_override(q)):
|
| 798 |
+
chunks = retrieve_context(q, top_k=3)
|
| 799 |
if chunks:
|
|
|
|
| 800 |
answer = _answer_from_chunks(chunks, q)
|
| 801 |
+
# tiny safety to release VRAM between turns
|
| 802 |
+
try:
|
| 803 |
+
torch.cuda.empty_cache()
|
| 804 |
+
except Exception:
|
| 805 |
+
pass
|
| 806 |
return _ensure_min_answer(_post_clean(answer)) or direct_llm_fallback(q)
|
| 807 |
return direct_llm_fallback(q)
|
| 808 |
|
| 809 |
+
chunks = retrieve_context(q, top_k=3)
|
| 810 |
if not chunks:
|
| 811 |
return direct_llm_fallback(q)
|
|
|
|
| 812 |
answer = _answer_from_chunks(chunks, q)
|
| 813 |
+
try:
|
| 814 |
+
torch.cuda.empty_cache()
|
| 815 |
+
except Exception:
|
| 816 |
+
pass
|
| 817 |
return _ensure_min_answer(_post_clean(answer)) or direct_llm_fallback(q)
|
| 818 |
|
| 819 |
+
def deterministic_definitions_text(core_q: str) -> Optional[str]:
    """Return a canned, deterministic answer for common definition questions.

    Matches *core_q* (case-insensitively) against a fixed, priority-ordered set
    of definition/workflow triggers and returns the canned text on the first
    hit. Returns None when the question is not one of the known definitions,
    letting the normal retrieval/LLM path handle it.
    """
    ql = core_q.lower()

    if any(t in ql for t in ("define axial rigidity", "what is axial rigidity")):
        return ("Axial rigidity (EA) is Σ(Eᵢ·dAᵢ) across a CT slice; units: N. "
                "Modulus E per voxel comes from a density–modulus calibration; areas dAᵢ are voxel areas.")

    if any(t in ql for t in ("define bending rigidity", "what is bending rigidity")):
        return ("Bending rigidity (EI) is Σ(Eᵢ·dAᵢ·yᵢ²) about a given axis; units: N·mm². "
                "yᵢ is distance to the neutral axis; computed slice-by-slice from QCT.")

    if any(t in ql for t in ("define torsional rigidity", "what is torsional rigidity", "define gj")):
        return ("Torsional rigidity (GJ) = shear modulus G times polar moment J. "
                "In QCT, J ≈ Σ(dAᵢ·rᵢ²) about the centroid; G ≈ E/(2(1+ν)).")

    # QCT-specific torsional question (only reached if no "define ..." trigger hit).
    if "qct" in ql and any(t in ql for t in ("torsional", "gj")):
        return ("From QCT, torsional rigidity is estimated as GJ, where J ≈ Σ(dAᵢ·rᵢ²) about the slice centroid and "
                "G = E/(2(1+ν)) from the voxel E map (ν≈0.3). Compute per-slice and report the minimum.")

    if re.search(r"\b(outline|steps|workflow|protocol)\b.*\b(ct|qct).*(rigidity|ea|ei|gj)", ql):
        return (
            "CT-based structural rigidity (CTRA/QCT) workflow:\n"
            "1) Acquire QCT (≤1 mm; density phantom).\n"
            "2) Preprocess & segment bone.\n"
            "3) HU→ρ; ρ→E calibration.\n"
            "4) Cross-sections along neck axis.\n"
            "5) EA, EI_x/EI_y, GJ (G≈E/(2(1+ν))).\n"
            "6) Extract minima & validate vs FEA/mech tests."
        )

    if re.search(r"\b(modulus)\b.*\brigidity\b|\bdefine\s+modulus\b", ql):
        return ("Elastic modulus (E) is a material property (Pa). "
                "Rigidity is structural (EA, EI, GJ). Modulus ≠ rigidity.")

    return None
|
| 847 |
+
|
| 848 |
# ================== UI: NAME GATE + PER-ANSWER FEEDBACK ==================
|
| 849 |
def _now_iso():
|
| 850 |
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).isoformat()
|
|
|
|
| 867 |
state["last_name"] = last_name
|
| 868 |
return gr.update(visible=False), gr.update(visible=True), state, f"Welcome, {first_name}! You can start chatting."
|
| 869 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 870 |
def _push_file_to_hub(local_path: str, repo_path: str) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 871 |
if not PUSH_FEEDBACK:
|
| 872 |
return
|
| 873 |
if not os.path.exists(local_path):
|
|
|
|
| 879 |
if not HF_WRITE_TOKEN:
|
| 880 |
dlog("UPLOAD", "Skip: HF_WRITE_TOKEN not set")
|
| 881 |
return
|
|
|
|
| 882 |
try:
|
| 883 |
api = HfApi(token=HF_WRITE_TOKEN)
|
| 884 |
api.upload_file(
|
|
|
|
| 892 |
except Exception as e:
|
| 893 |
dlog("UPLOAD", f"Upload failed: {e}")
|
| 894 |
|
|
|
|
| 895 |
def _push_feedback_to_hub() -> None:
    """Upload the local feedback CSV to the Hub under analytics/ (best-effort;
    skipping/error handling lives in _push_file_to_hub)."""
    _push_file_to_hub(FEEDBACK_PATH, "analytics/feedback.csv")
|
| 897 |
|
| 898 |
+
def _log_turn(state: Dict[str, Any], question: str, answer: str):
    """Append one Q/A turn to the local transcript and mirror it to the Hub.

    Strictly best-effort: neither a failed disk write nor a failed upload may
    break the chat loop, so both steps swallow their exceptions.
    """
    record = {
        "timestamp_utc": _now_iso(),
        "session_id": state.get("session_id", ""),
        "first_name": state.get("first_name", ""),
        "last_name": state.get("last_name", ""),
        "question": question,
        "answer": answer,
    }
    try:
        with open(TRANSCRIPT_PATH, "a", encoding="utf-8") as fh:
            fh.write(json.dumps(record, ensure_ascii=False) + "\n")
    except Exception:
        pass  # best-effort: transcript loss is acceptable, a crash is not
    if PUSH_TRANSCRIPTS:
        try:
            _push_file_to_hub(TRANSCRIPT_PATH, "analytics/transcripts.jsonl")
        except Exception:
            pass  # Hub mirroring is optional telemetry
|
| 917 |
|
| 918 |
def save_feedback(rating, comment, state):
|
| 919 |
if rating is None:
|
|
|
|
| 941 |
except Exception as e:
|
| 942 |
return f"Failed to save feedback: {e}", gr.update(visible=True)
|
| 943 |
|
| 944 |
+
def predict(message, chat_history, state):
    """Handle one chat turn: answer the message, record it, refresh feedback UI.

    Returns the 6-tuple Gradio expects: (history, cleared input, feedback-pane
    visibility, reset rating, reset comment, state).
    """
    text = (message or "").strip()
    if not text:
        # Nothing to answer: keep history and keep the feedback pane hidden.
        return chat_history, "", gr.update(visible=False), None, "", state

    try:
        reply = ask(text)
    except Exception as e:  # top-level boundary: surface failures as a chat message
        reply = f"Sorry — something went wrong: {e!r}"

    history = list(chat_history or [])
    history.append((text, reply))
    state["last_q"] = text
    state["last_a"] = reply

    try:
        _log_turn(state, text, reply)
    except Exception:
        pass  # transcript logging is best-effort

    # Release a bit of VRAM between turns (no-op / harmless on CPU-only hosts).
    try:
        torch.cuda.empty_cache()
    except Exception:
        pass

    return (
        history,
        "",                        # clear the input box
        gr.update(visible=True),   # reveal the feedback pane
        gr.update(value=None),     # reset the rating radio
        gr.update(value=""),       # reset the comment box
        state,
    )
|
| 977 |
+
|
| 978 |
+
# ================== UI ==================
|
| 979 |
with gr.Blocks(theme="soft") as demo:
|
| 980 |
gr.Markdown("# Askstein — Orthopedic Biomechanics Chat (CT/QCT Rigidity, FE)")
|
| 981 |
gr.Markdown("Grounded answers (FAISS + PubMed). Please enter your name to continue.")
|
|
|
|
| 987 |
with gate:
|
| 988 |
with gr.Row():
|
| 989 |
first_tb = gr.Textbox(label="First name", placeholder="e.g., Shubh", scale=1)
|
| 990 |
+
last_tb = gr.Textbox(label="Last name", placeholder="e.g., Laiwala", scale=1)
|
| 991 |
enter_btn = gr.Button("Enter", variant="primary")
|
| 992 |
+
gate_msg = gr.Markdown("", elem_classes=["text-sm"])
|
| 993 |
|
| 994 |
# ---- App (hidden until gate passes) ----
|
| 995 |
app = gr.Group(visible=False)
|
|
|
|
| 1003 |
feedback_grp = gr.Group(visible=False)
|
| 1004 |
with feedback_grp:
|
| 1005 |
gr.Markdown("### How helpful was this answer?")
|
| 1006 |
+
rating = gr.Radio(choices=[1,2,3,4,5], label="Rating (1=poor, 5=great)")
|
| 1007 |
comment = gr.Textbox(label="Optional comment", placeholder="What was good or missing?")
|
| 1008 |
submit_fb = gr.Button("Submit feedback")
|
| 1009 |
fb_status = gr.Markdown("")
|
| 1010 |
|
| 1011 |
+
# ---- Wiring (must stay inside Blocks) ----
|
| 1012 |
enter_btn.click(
|
| 1013 |
fn=enter_app,
|
| 1014 |
inputs=[first_tb, last_tb, state],
|
|
|
|
| 1043 |
concurrency_limit=4,
|
| 1044 |
)
|
| 1045 |
|
| 1046 |
+
# ================== QUEUE & LAUNCH ==================
|
| 1047 |
+
demo.queue(max_size=64) # no deprecated concurrency_count
|
| 1048 |
demo.launch(max_threads=8)
|