Slaiwala commited on
Commit
7ff267a
·
verified ·
1 Parent(s): cec50ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +239 -249
app.py CHANGED
@@ -1,29 +1,19 @@
1
  #!/usr/bin/env python3
2
  from __future__ import annotations
3
 
 
4
  import os, re, json, time, sys, csv, uuid, datetime
5
  from typing import List, Dict, Any, Optional
6
  from functools import lru_cache
7
  from xml.etree import ElementTree as ET
8
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
9
  from transformers import AutoTokenizer, AutoModelForCausalLM
10
  try:
11
- from transformers import BitsAndBytesConfig # exists even if bitsandbytes isn't installed
12
  except Exception:
13
  BitsAndBytesConfig = None
14
 
15
- # Normalize QUANTIZE env
16
- QUANTIZE = os.environ.get("QUANTIZE", "none").strip().lower()
17
-
18
- # Detect bitsandbytes presence
19
- try:
20
- import bitsandbytes as _bnb # noqa: F401
21
- _BNB_AVAILABLE = True
22
- except Exception:
23
- _BNB_AVAILABLE = False
24
-
25
-
26
-
27
  import numpy as np
28
  import requests
29
  import gradio as gr
@@ -33,11 +23,13 @@ ASSETS_DIR = os.environ.get("ASSETS_DIR", "assets")
33
  FAISS_PATH = os.environ.get("FAISS_PATH", f"{ASSETS_DIR}/index.faiss")
34
  META_PATH = os.environ.get("META_PATH", f"{ASSETS_DIR}/index_meta.filtered.jsonl")
35
  REL_CONFIG_PATH = os.environ.get("REL_CONFIG_PATH", f"{ASSETS_DIR}/relevance_config.json")
36
- QUANTIZE = os.environ.get("QUANTIZE", "4bit") # "none" | "8bit" | "4bit"
37
- # --- Turn logging ---
38
- TRANSCRIPT_PATH = os.environ.get("TRANSCRIPT_PATH", "transcripts.jsonl")
39
- PUSH_TRANSCRIPTS = os.environ.get("PUSH_TRANSCRIPTS", "1") == "1" # set to "0" to disable
40
 
 
 
 
 
 
 
41
 
42
  # Models
43
  BASE_MODEL = os.environ.get("BASE_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
@@ -53,12 +45,11 @@ NCBI_TOOL = os.environ.get("NCBI_TOOL", "askstein")
53
  NCBI_APIKEY = os.environ.get("NCBI_APIKEY", "")
54
 
55
  # Feedback logging
56
- FEEDBACK_PATH = os.environ.get("FEEDBACK_PATH", "feedback.csv")
57
- PUSH_FEEDBACK = os.environ.get("PUSH_FEEDBACK", "0") == "1" # set to "1" to enable Hub upload
58
- HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN", os.environ.get("HF_TOKEN", ""))
59
- HF_WRITE_TOKEN = os.environ.get("HF_WRITE_TOKEN", HF_READ_TOKEN)
60
- SPACE_REPO_ID = os.environ.get("SPACE_REPO_ID", "")
61
-
62
 
63
  # Generation / toggles
64
  ALLOW_WIKIPEDIA = False
@@ -72,7 +63,6 @@ AUTO_CONTINUE = True
72
  AUTO_CONT_MAX_STEPS = 2 # continue up to 2 extra chunks
73
  AUTO_CONT_NEW_TOKENS = 256 # tokens per continuation step
74
 
75
-
76
  def dlog(tag, msg):
77
  if DEBUG: print(f"[{tag}] {msg}")
78
 
@@ -80,12 +70,14 @@ def dlog(tag, msg):
80
  import faiss
81
  from sentence_transformers import SentenceTransformer
82
  import torch
83
- from transformers import AutoTokenizer, AutoModelForCausalLM
84
  from peft import PeftModel
85
  import wikipedia
86
  from wikipedia.exceptions import DisambiguationError, PageError
87
  from huggingface_hub import login, snapshot_download, HfApi
88
 
 
 
 
89
  # ================== GPU CHECK ==================
90
  if not torch.cuda.is_available():
91
  with gr.Blocks() as demo:
@@ -96,6 +88,8 @@ if not torch.cuda.is_available():
96
  device = "cuda"
97
  dtype = torch.float16
98
  torch.manual_seed(42)
 
 
99
 
100
  # ================== RELEVANCE CONFIG ==================
101
  DEFAULT_REL_CONFIG = {
@@ -206,7 +200,7 @@ except Exception as e:
206
 
207
  _IS_IP = isinstance(index, faiss.IndexFlatIP) or "IndexFlatIP" in type(index).__name__
208
 
209
- # ================== LOAD LLM (BASE + LORA) ==================
210
  if HF_READ_TOKEN:
211
  try:
212
  login(token=HF_READ_TOKEN)
@@ -214,17 +208,24 @@ if HF_READ_TOKEN:
214
  except Exception as e:
215
  dlog("HF", f"Login issue: {e}")
216
 
217
-
218
  if ADAPTER_REPO:
219
  ADAPTER_PATH = snapshot_download(repo_id=ADAPTER_REPO, allow_patterns=["*"])
220
 
221
- # --- LLM load (quantized optional) ---
 
 
 
 
 
 
 
222
  dlog("LLM", f"Loading base model: {BASE_MODEL}")
223
  tokenizer_lm = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
224
 
225
  use_bnb = QUANTIZE in {"8bit", "4bit"} and BitsAndBytesConfig is not None and _BNB_AVAILABLE
226
 
227
  if use_bnb:
 
228
  bnb_config = BitsAndBytesConfig(
229
  load_in_8bit=(QUANTIZE == "8bit"),
230
  load_in_4bit=(QUANTIZE == "4bit"),
@@ -238,23 +239,28 @@ if use_bnb:
238
  quantization_config=bnb_config,
239
  )
240
  else:
241
- # Default / fallback: fp16 (no bitsandbytes required)
242
- base_model = AutoModelForCausalLM.from_pretrained(
243
- BASE_MODEL,
244
- torch_dtype=dtype,
245
- device_map="auto",
246
- )
247
-
248
-
249
-
 
 
 
 
 
250
 
251
  dlog("LLM", f"Loading LoRA adapter from: {ADAPTER_PATH}")
252
  model_lm = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
253
  model_lm.eval()
254
 
255
-
256
  GEN_ARGS_GROUNDED = dict(
257
  max_new_tokens=MAX_NEW_TOKENS_GROUNDED,
 
258
  do_sample=False,
259
  num_beams=1,
260
  no_repeat_ngram_size=3,
@@ -263,6 +269,7 @@ GEN_ARGS_GROUNDED = dict(
263
  )
264
  GEN_ARGS_FALLBACK = dict(
265
  max_new_tokens=MAX_NEW_TOKENS_FALLBACK,
 
266
  do_sample=False,
267
  num_beams=1,
268
  no_repeat_ngram_size=3,
@@ -270,38 +277,6 @@ GEN_ARGS_FALLBACK = dict(
270
  eos_token_id=tokenizer_lm.eos_token_id,
271
  )
272
 
273
- def _generate(inputs, grounded: bool):
274
- args = GEN_ARGS_GROUNDED if grounded else GEN_ARGS_FALLBACK
275
- in_len = inputs["input_ids"].shape[-1]
276
- with torch.inference_mode():
277
- out = model_lm.generate(**inputs, **args)
278
-
279
- if not AUTO_CONTINUE:
280
- return out
281
-
282
- steps = 0
283
- while steps < AUTO_CONT_MAX_STEPS:
284
- seq = out[0]
285
- ended_with_eos = (seq[-1].item() == tokenizer_lm.eos_token_id)
286
- hit_cap = (seq.shape[0] - in_len) >= args["max_new_tokens"]
287
- if ended_with_eos or not hit_cap:
288
- break
289
-
290
- # continue generation from the current sequence
291
- cont_inputs = {
292
- "input_ids": seq.unsqueeze(0),
293
- "attention_mask": torch.ones_like(seq).unsqueeze(0),
294
- }
295
- cont_inputs = {k: v.to(device) for k, v in cont_inputs.items()}
296
- cont_args = dict(args)
297
- cont_args["max_new_tokens"] = AUTO_CONT_NEW_TOKENS
298
-
299
- out = model_lm.generate(**cont_inputs, **cont_args)
300
- steps += 1
301
-
302
- return out
303
-
304
-
305
  # ================== UTILITIES ==================
306
  _SANITIZE = re.compile(r"```.*?```|<\s*script[^>]*>.*?<\s*/\s*script\s*>", re.DOTALL|re.IGNORECASE)
307
  def _to_text(rec: Any) -> str:
@@ -437,7 +412,9 @@ _ANATOMY_OR_HISTORY = re.compile(
437
  re.I
438
  )
439
  _PAPERS_INTENT = re.compile(r"\b(key\s+papers|suggest\s+papers|landmark|seminal|important|top\s+papers)\b", re.I)
 
440
 
 
441
  def fetch_pubmed_chunks(query_or_pmid: str, max_papers: int = 3) -> List[Dict[str, Any]]:
442
  retries = 1
443
  chunks: List[Dict[str, Any]] = []
@@ -617,7 +594,6 @@ def retrieve_context(query: str, top_k: int = 10) -> List[Dict[str, Any]]:
617
  if results:
618
  dlog("PUBMED", "PubMed search hit")
619
  return results
620
- # Wikipedia fallback (unconditional after PubMed miss)
621
  wiki = wiki_summary_allow(q, sentences=3)
622
  if wiki:
623
  dlog("WIKI", "Wikipedia fallback hit")
@@ -625,7 +601,6 @@ def retrieve_context(query: str, top_k: int = 10) -> List[Dict[str, Any]]:
625
  dlog("RETRIEVAL", "No results found")
626
  return []
627
 
628
-
629
  # FAISS path
630
  q_emb = embed_model.encode([q], convert_to_numpy=True).astype("float32")
631
  if _IS_IP:
@@ -666,7 +641,6 @@ def retrieve_context(query: str, top_k: int = 10) -> List[Dict[str, Any]]:
666
  dlog("PUBMED", "PubMed search hit")
667
  return results
668
 
669
- # Wikipedia fallback (unconditional after PubMed miss)
670
  wiki = wiki_summary_allow(q, sentences=3)
671
  if wiki:
672
  dlog("WIKI", "Wikipedia fallback hit")
@@ -675,7 +649,6 @@ def retrieve_context(query: str, top_k: int = 10) -> List[Dict[str, Any]]:
675
  dlog("RETRIEVAL", "No results at all")
676
  return []
677
 
678
-
679
  def build_prompt(chunks: List[Dict[str, Any]], question: str) -> str:
680
  header = (
681
  "You are Askstein (orthopedic biomechanics). Use ONLY the [Context] to answer. "
@@ -684,7 +657,7 @@ def build_prompt(chunks: List[Dict[str, Any]], question: str) -> str:
684
  "Do not discuss cardiology, neurology, or unrelated domains."
685
  )
686
  cleaned = []
687
- per_chunk_chars = 1600
688
  for c in chunks:
689
  t = _to_text(c)
690
  if t: cleaned.append(t[:per_chunk_chars])
@@ -695,22 +668,64 @@ def _decode_generated(out_ids, in_len: int) -> str:
695
  gen = out_ids[0][in_len:]
696
  return tokenizer_lm.decode(gen, skip_special_tokens=True).lstrip(". \n").strip()
697
 
698
- @lru_cache(maxsize=None)
699
- def direct_llm_fallback(question: str) -> str:
700
- sys_prompt = (
701
- "You are Askstein (orthopedic biomechanics). If you lack enough domain context, say you don’t know. "
702
- "Avoid discussing non-musculoskeletal systems (cardiology, neurology). Do NOT invent references."
703
- )
704
- llm_prompt = f"{sys_prompt}\n\nQuestion: {question}\nAnswer:"
705
- inputs = tokenizer_lm(llm_prompt, return_tensors="pt").to(device)
706
- out = _generate(inputs, grounded=False)
707
  in_len = inputs["input_ids"].shape[-1]
708
- ans = _post_clean(_decode_generated(out, in_len))
709
- # Strip any made-up reference sections the model might add
710
- ans = re.sub(r"(?is)(^|\n)\s*references?:.*$", "", ans).strip()
711
- return "[LLM fallback — ungrounded]\n\n" + ans
712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
713
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
714
  def _synthesize_answer(chunks: List[Dict[str, Any]], question: str) -> str:
715
  prompt = build_prompt(chunks, question)
716
  inputs = tokenizer_lm(prompt, return_tensors="pt").to(device)
@@ -726,35 +741,21 @@ def _answer_from_chunks(chunks: List[Dict[str, Any]], question: str) -> str:
726
  return _synthesize_answer(chunks, question)
727
  return _synthesize_answer(chunks, question)
728
 
729
- def deterministic_definitions_text(core_q: str) -> Optional[str]:
730
- q_lower = core_q.lower()
731
- if "define axial rigidity" in q_lower or "what is axial rigidity" in q_lower:
732
- return ("Axial rigidity (EA) is Σ(Eᵢ·dAᵢ) across a CT slice; units: N. "
733
- "Modulus E per voxel comes from a density–modulus calibration; areas dAᵢ are voxel areas.")
734
- if "define bending rigidity" in q_lower or "what is bending rigidity" in q_lower:
735
- return ("Bending rigidity (EI) is Σ(Eᵢ·dAᵢ·yᵢ²) about a given axis; units: N·mm². "
736
- "yᵢ is distance to the neutral axis; computed slice-by-slice from QCT.")
737
- if ("define torsional rigidity" in q_lower) or ("what is torsional rigidity" in q_lower) or ("define gj" in q_lower):
738
- return ("Torsional rigidity (GJ) = shear modulus G times polar moment J. "
739
- "In QCT, J ≈ Σ(dAᵢ·rᵢ²) about the centroid; G ≈ E/(2(1+ν)).")
740
- if "qct" in q_lower and ("torsional" in q_lower or "gj" in q_lower):
741
- return ("From QCT, torsional rigidity is estimated as GJ, where J ≈ Σ(dAᵢ·rᵢ²) about the slice centroid and "
742
- "G = E/(2(1+ν)) from the voxel E map (ν≈0.3). Compute per-slice and report the minimum.")
743
- if re.search(r"\b(outline|steps|workflow|protocol)\b.*\b(ct|qct).*(rigidity|ea|ei|gj)", q_lower):
744
- return (
745
- "CT-based structural rigidity (CTRA/QCT) workflow:\n"
746
- "1) Acquire QCT (≤1 mm; density phantom).\n"
747
- "2) Preprocess & segment bone.\n"
748
- "3) HU→ρ; ρ→E calibration.\n"
749
- "4) Cross-sections along neck axis.\n"
750
- "5) EA, EI_x/EI_y, GJ (G≈E/(2(1+ν))).\n"
751
- "6) Extract minima & validate vs FEA/mech tests."
752
- )
753
- if re.search(r"\b(modulus)\b.*\brigidity\b|\bdefine\s+modulus\b", q_lower):
754
- return ("Elastic modulus (E) is a material property (Pa). "
755
- "Rigidity is structural (EA, EI, GJ). Modulus ≠ rigidity.")
756
- return None
757
 
 
758
  def ask(question: str) -> str:
759
  q = question.strip()
760
  m = re.search(r"pmid[:\s]*(\d+)", q, re.IGNORECASE)
@@ -763,8 +764,11 @@ def ask(question: str) -> str:
763
  chunks = fetch_pubmed_chunks(pmid, max_papers=1)
764
  return "\n".join(c.get("text", "") for c in chunks) or "Sorry, no abstract found."
765
 
766
- if _PAPERS_INTENT.search(q):
767
- core_q = re.sub(_PAPERS_INTENT, "", q, flags=re.I).strip() or "CT/QCT structural rigidity femur hip finite element"
 
 
 
768
  compact = _compact_terms(core_q)
769
  pm_query = (
770
  f'(({compact}) AND (hip[TiAb] OR femur[TiAb] OR femoral[TiAb])) AND '
@@ -772,82 +776,75 @@ def ask(question: str) -> str:
772
  'AND ("2000"[DP] : "2025"[DP])'
773
  )
774
  cits = fetch_pubmed_citations(pm_query, max_results=5)
775
- return "Recommended papers:\n" + "\n".join(f"- {c}" for c in cits) if cits else "Sorry, no good matches."
776
-
777
- comp = re.match(r"(.+?)\s+and\s+(?:cite|references?|studies?|papers?)", q, flags=re.IGNORECASE)
778
- if comp:
779
- core_q = comp.group(1).strip()
780
- det_text = deterministic_definitions_text(core_q)
781
- used_term = None
782
- if det_text:
783
- explanation = det_text
784
- lq = core_q.lower()
785
- if ("torsional" in lq) or ("gj" in lq):
786
- used_term = "GJ"
787
- pm_query = ('(torsion[TiAb] OR "polar moment"[TiAb] OR GJ[TiAb]) AND '
788
- '("Bone and Bones"[MeSH] OR Femur[TiAb]) AND '
789
- '("Finite Element Analysis"[MeSH] OR QCT[TiAb] OR CT[TiAb]) AND '
790
- '("2000"[DP] : "2025"[DP])')
791
- elif ("bending" in lq) or ("ei" in lq):
792
- used_term = "EI"
793
- pm_query = ('(bending[TiAb] OR "second moment"[TiAb] OR EI[TiAb]) AND '
794
- '("Bone and Bones"[MeSH] OR Femur[TiAb]) AND '
795
- '("Finite Element Analysis"[MeSH] OR QCT[TiAb] OR CT[TiAb]) AND '
796
- '("2000"[DP] : "2025"[DP])')
797
- else:
798
- used_term = "EA"
799
- pm_query = ('("axial rigidity"[TiAb] OR EA[TiAb] OR "axial stiffness"[TiAb]) AND '
800
- '("Bone and Bones"[MeSH] OR Femur[TiAb]) AND '
801
- '("Finite Element Analysis"[MeSH] OR QCT[TiAb] OR CT[TiAb]) AND '
802
- '("2000"[DP] : "2025"[DP])')
803
- citations = fetch_pubmed_citations(pm_query, max_results=5)
804
- if not citations and used_term:
805
- dlog("CITE", f"PubMed empty → fallback {used_term}")
806
- citations = _fallback_cits_for(used_term)
807
- else:
808
- explanation = _answer_from_chunks(retrieve_context(core_q, top_k=5), core_q)
809
- pm_query = f'"{core_q}"[Title/Abstract]'
810
- citations = fetch_pubmed_citations(pm_query, max_results=5)
811
- if not citations:
812
- lab = detect_lab(core_q)
813
- pm_query = build_lab_query(core_q, lab=lab)
814
- citations = fetch_pubmed_citations(pm_query, max_results=5)
815
- if not citations:
816
- compact = _compact_terms(core_q)
817
- pm_query = (
818
- f'({compact}) AND ("Bone and Bones"[MeSH] OR Femur[TiAb] OR Hip[TiAb] '
819
- f'OR Rigidity[TiAb] OR "Tomography, X-Ray Computed"[MeSH] OR "Finite Element Analysis"[MeSH]) '
820
- f'NOT (heart[TiAb] OR cardiac[TiAb] OR brain[TiAb] OR skull[TiAb] OR EGFR[TiAb]) '
821
- f'AND ("2000"[DP] : "2025"[DP])'
822
- )
823
- citations = fetch_pubmed_citations(pm_query, max_results=5)
824
- resp = explanation
825
- if citations:
826
- resp += "\n\nCitations:\n" + "\n".join(citations)
827
- else:
828
- resp += f"\n\nSorry, no relevant citations found for “{core_q}.”"
829
- return _ensure_min_answer(_post_clean(resp))
830
-
831
- det_answer = deterministic_definitions_text(q)
832
- if det_answer:
833
  dlog("ASK", "Deterministic definition/workflow fired")
834
- return det_answer
835
 
836
  if not (_MSK_MUST.search(q) or _is_fe_override(q)):
837
- chunks = retrieve_context(q, top_k=5)
838
  if chunks:
839
- dlog("CLEAN", "Post-clean applied")
840
  answer = _answer_from_chunks(chunks, q)
 
 
 
 
 
841
  return _ensure_min_answer(_post_clean(answer)) or direct_llm_fallback(q)
842
  return direct_llm_fallback(q)
843
 
844
- chunks = retrieve_context(q, top_k=5)
845
  if not chunks:
846
  return direct_llm_fallback(q)
847
- dlog("CLEAN", "Post-clean applied")
848
  answer = _answer_from_chunks(chunks, q)
 
 
 
 
849
  return _ensure_min_answer(_post_clean(answer)) or direct_llm_fallback(q)
850
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
851
  # ================== UI: NAME GATE + PER-ANSWER FEEDBACK ==================
852
  def _now_iso():
853
  return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).isoformat()
@@ -870,63 +867,7 @@ def enter_app(first_name, last_name, state):
870
  state["last_name"] = last_name
871
  return gr.update(visible=False), gr.update(visible=True), state, f"Welcome, {first_name}! You can start chatting."
872
 
873
- def _log_turn(state: Dict[str, Any], question: str, answer: str):
874
- rec = {
875
- "timestamp_utc": _now_iso(),
876
- "session_id": state.get("session_id", ""),
877
- "first_name": state.get("first_name", ""),
878
- "last_name": state.get("last_name", ""),
879
- "question": question,
880
- "answer": answer,
881
- }
882
- with open(TRANSCRIPT_PATH, "a", encoding="utf-8") as f:
883
- f.write(json.dumps(rec, ensure_ascii=False) + "\n")
884
-
885
- if PUSH_TRANSCRIPTS:
886
- _push_file_to_hub(TRANSCRIPT_PATH, "analytics/transcripts.jsonl")
887
-
888
-
889
- def predict(message, chat_history, state):
890
- msg = (message or "").strip()
891
- if not msg:
892
- # No input → don't show feedback, just return current state
893
- return chat_history, "", gr.update(visible=False), None, "", state
894
-
895
- try:
896
- answer = ask(msg)
897
- except Exception as e:
898
- answer = f"Sorry — something went wrong: {e!r}"
899
-
900
- chat_history = (chat_history or []) + [(msg, answer)]
901
- state["last_q"] = msg
902
- state["last_a"] = answer
903
-
904
- # Log every turn (safe if _log_turn isn't defined)
905
- try:
906
- _log_turn(state, msg, answer)
907
- except Exception:
908
- pass
909
-
910
- return (
911
- chat_history,
912
- "", # clear input
913
- gr.update(visible=True), # show feedback pane
914
- gr.update(value=None), # reset rating
915
- gr.update(value=""), # reset comment
916
- state
917
- )
918
-
919
-
920
- # --- Hub upload helper --------------------------------------------------------
921
  def _push_file_to_hub(local_path: str, repo_path: str) -> None:
922
- """
923
- Upload a local file to your Space repo.
924
-
925
- Requires:
926
- - PUSH_FEEDBACK=1
927
- - HF_WRITE_TOKEN (write access to the Space)
928
- - SPACE_REPO_ID (e.g., "username/YourSpace")
929
- """
930
  if not PUSH_FEEDBACK:
931
  return
932
  if not os.path.exists(local_path):
@@ -938,7 +879,6 @@ def _push_file_to_hub(local_path: str, repo_path: str) -> None:
938
  if not HF_WRITE_TOKEN:
939
  dlog("UPLOAD", "Skip: HF_WRITE_TOKEN not set")
940
  return
941
-
942
  try:
943
  api = HfApi(token=HF_WRITE_TOKEN)
944
  api.upload_file(
@@ -952,13 +892,28 @@ def _push_file_to_hub(local_path: str, repo_path: str) -> None:
952
  except Exception as e:
953
  dlog("UPLOAD", f"Upload failed: {e}")
954
 
955
- # --- Feedback uploader --------------------------------------------------------
956
  def _push_feedback_to_hub() -> None:
957
- """Upload feedback.csv to analytics/feedback.csv in this Space repo (if enabled)."""
958
  _push_file_to_hub(FEEDBACK_PATH, "analytics/feedback.csv")
959
 
960
-
961
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
962
 
963
  def save_feedback(rating, comment, state):
964
  if rating is None:
@@ -986,6 +941,41 @@ def save_feedback(rating, comment, state):
986
  except Exception as e:
987
  return f"Failed to save feedback: {e}", gr.update(visible=True)
988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
989
  with gr.Blocks(theme="soft") as demo:
990
  gr.Markdown("# Askstein — Orthopedic Biomechanics Chat (CT/QCT Rigidity, FE)")
991
  gr.Markdown("Grounded answers (FAISS + PubMed). Please enter your name to continue.")
@@ -997,9 +987,9 @@ with gr.Blocks(theme="soft") as demo:
997
  with gate:
998
  with gr.Row():
999
  first_tb = gr.Textbox(label="First name", placeholder="e.g., Shubh", scale=1)
1000
- last_tb = gr.Textbox(label="Last name", placeholder="e.g., Laiwala", scale=1)
1001
  enter_btn = gr.Button("Enter", variant="primary")
1002
- gate_msg = gr.Markdown("", elem_classes=["text-sm"])
1003
 
1004
  # ---- App (hidden until gate passes) ----
1005
  app = gr.Group(visible=False)
@@ -1013,12 +1003,12 @@ with gr.Blocks(theme="soft") as demo:
1013
  feedback_grp = gr.Group(visible=False)
1014
  with feedback_grp:
1015
  gr.Markdown("### How helpful was this answer?")
1016
- rating = gr.Radio(choices=[1, 2, 3, 4, 5], label="Rating (1=poor, 5=great)")
1017
  comment = gr.Textbox(label="Optional comment", placeholder="What was good or missing?")
1018
  submit_fb = gr.Button("Submit feedback")
1019
  fb_status = gr.Markdown("")
1020
 
1021
- # ---- Wiring (MUST stay inside the Blocks context) ----
1022
  enter_btn.click(
1023
  fn=enter_app,
1024
  inputs=[first_tb, last_tb, state],
@@ -1053,6 +1043,6 @@ with gr.Blocks(theme="soft") as demo:
1053
  concurrency_limit=4,
1054
  )
1055
 
1056
- # Queue & launch (outside the Blocks)
1057
- demo.queue(max_size=64)
1058
  demo.launch(max_threads=8)
 
1
  #!/usr/bin/env python3
2
  from __future__ import annotations
3
 
4
+ # ================== STD / CORE ==================
5
  import os, re, json, time, sys, csv, uuid, datetime
6
  from typing import List, Dict, Any, Optional
7
  from functools import lru_cache
8
  from xml.etree import ElementTree as ET
9
+
10
+ # ================== TRANSFORMERS / TORCH ==================
11
  from transformers import AutoTokenizer, AutoModelForCausalLM
12
  try:
13
+ from transformers import BitsAndBytesConfig # may exist even if bnb isn't installed
14
  except Exception:
15
  BitsAndBytesConfig = None
16
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  import numpy as np
18
  import requests
19
  import gradio as gr
 
23
  FAISS_PATH = os.environ.get("FAISS_PATH", f"{ASSETS_DIR}/index.faiss")
24
  META_PATH = os.environ.get("META_PATH", f"{ASSETS_DIR}/index_meta.filtered.jsonl")
25
  REL_CONFIG_PATH = os.environ.get("REL_CONFIG_PATH", f"{ASSETS_DIR}/relevance_config.json")
 
 
 
 
26
 
27
+ # Normalize QUANTIZE env (default: no quantization)
28
+ QUANTIZE = os.environ.get("QUANTIZE", "none").strip().lower() # "none" | "8bit" | "4bit"
29
+
30
+ # Turn logging
31
+ TRANSCRIPT_PATH = os.environ.get("TRANSCRIPT_PATH", "transcripts.jsonl")
32
+ PUSH_TRANSCRIPTS = os.environ.get("PUSH_TRANSCRIPTS", "1") == "1" # set "0" to disable
33
 
34
  # Models
35
  BASE_MODEL = os.environ.get("BASE_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
 
45
  NCBI_APIKEY = os.environ.get("NCBI_APIKEY", "")
46
 
47
  # Feedback logging
48
+ FEEDBACK_PATH = os.environ.get("FEEDBACK_PATH", "feedback.csv")
49
+ PUSH_FEEDBACK = os.environ.get("PUSH_FEEDBACK", "0") == "1" # set "1" to enable Hub upload
50
+ HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN", os.environ.get("HF_TOKEN", ""))
51
+ HF_WRITE_TOKEN = os.environ.get("HF_WRITE_TOKEN", HF_READ_TOKEN)
52
+ SPACE_REPO_ID = os.environ.get("SPACE_REPO_ID", "")
 
53
 
54
  # Generation / toggles
55
  ALLOW_WIKIPEDIA = False
 
63
  AUTO_CONT_MAX_STEPS = 2 # continue up to 2 extra chunks
64
  AUTO_CONT_NEW_TOKENS = 256 # tokens per continuation step
65
 
 
66
  def dlog(tag, msg):
67
  if DEBUG: print(f"[{tag}] {msg}")
68
 
 
70
  import faiss
71
  from sentence_transformers import SentenceTransformer
72
  import torch
 
73
  from peft import PeftModel
74
  import wikipedia
75
  from wikipedia.exceptions import DisambiguationError, PageError
76
  from huggingface_hub import login, snapshot_download, HfApi
77
 
78
+ # ================== LOW-VRAM RUNTIME KNOBS ==================
79
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128,expandable_segments:True")
80
+
81
  # ================== GPU CHECK ==================
82
  if not torch.cuda.is_available():
83
  with gr.Blocks() as demo:
 
88
  device = "cuda"
89
  dtype = torch.float16
90
  torch.manual_seed(42)
91
+ torch.backends.cuda.matmul.allow_tf32 = True
92
+ torch.backends.cudnn.allow_tf32 = True
93
 
94
  # ================== RELEVANCE CONFIG ==================
95
  DEFAULT_REL_CONFIG = {
 
200
 
201
  _IS_IP = isinstance(index, faiss.IndexFlatIP) or "IndexFlatIP" in type(index).__name__
202
 
203
+ # ================== HUGGING FACE LOGIN & ADAPTER PATH ==================
204
  if HF_READ_TOKEN:
205
  try:
206
  login(token=HF_READ_TOKEN)
 
208
  except Exception as e:
209
  dlog("HF", f"Login issue: {e}")
210
 
 
211
  if ADAPTER_REPO:
212
  ADAPTER_PATH = snapshot_download(repo_id=ADAPTER_REPO, allow_patterns=["*"])
213
 
214
+ # ================== QUANTIZATION AVAILABILITY ==================
215
+ try:
216
+ import bitsandbytes as _bnb # noqa: F401
217
+ _BNB_AVAILABLE = True
218
+ except Exception:
219
+ _BNB_AVAILABLE = False
220
+
221
+ # ================== LLM (BASE + LORA) ==================
222
  dlog("LLM", f"Loading base model: {BASE_MODEL}")
223
  tokenizer_lm = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
224
 
225
  use_bnb = QUANTIZE in {"8bit", "4bit"} and BitsAndBytesConfig is not None and _BNB_AVAILABLE
226
 
227
  if use_bnb:
228
+ # Quantized path (only if explicitly requested and bnb is installed)
229
  bnb_config = BitsAndBytesConfig(
230
  load_in_8bit=(QUANTIZE == "8bit"),
231
  load_in_4bit=(QUANTIZE == "4bit"),
 
239
  quantization_config=bnb_config,
240
  )
241
  else:
242
+ # fp16 path with SDPA attention (lower VRAM). Fallback if not supported.
243
+ try:
244
+ base_model = AutoModelForCausalLM.from_pretrained(
245
+ BASE_MODEL,
246
+ torch_dtype=dtype,
247
+ device_map="auto",
248
+ attn_implementation="sdpa",
249
+ )
250
+ except TypeError:
251
+ base_model = AutoModelForCausalLM.from_pretrained(
252
+ BASE_MODEL,
253
+ torch_dtype=dtype,
254
+ device_map="auto",
255
+ )
256
 
257
  dlog("LLM", f"Loading LoRA adapter from: {ADAPTER_PATH}")
258
  model_lm = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
259
  model_lm.eval()
260
 
 
261
  GEN_ARGS_GROUNDED = dict(
262
  max_new_tokens=MAX_NEW_TOKENS_GROUNDED,
263
+ min_new_tokens=220,
264
  do_sample=False,
265
  num_beams=1,
266
  no_repeat_ngram_size=3,
 
269
  )
270
  GEN_ARGS_FALLBACK = dict(
271
  max_new_tokens=MAX_NEW_TOKENS_FALLBACK,
272
+ min_new_tokens=120,
273
  do_sample=False,
274
  num_beams=1,
275
  no_repeat_ngram_size=3,
 
277
  eos_token_id=tokenizer_lm.eos_token_id,
278
  )
279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  # ================== UTILITIES ==================
281
  _SANITIZE = re.compile(r"```.*?```|<\s*script[^>]*>.*?<\s*/\s*script\s*>", re.DOTALL|re.IGNORECASE)
282
  def _to_text(rec: Any) -> str:
 
412
  re.I
413
  )
414
  _PAPERS_INTENT = re.compile(r"\b(key\s+papers|suggest\s+papers|landmark|seminal|important|top\s+papers)\b", re.I)
415
+ CITE_TRIGGER = re.compile(r"\b(cite|citations?|references?)\b", re.I)
416
 
417
+ # ================== PUBMED & RETRIEVAL ==================
418
  def fetch_pubmed_chunks(query_or_pmid: str, max_papers: int = 3) -> List[Dict[str, Any]]:
419
  retries = 1
420
  chunks: List[Dict[str, Any]] = []
 
594
  if results:
595
  dlog("PUBMED", "PubMed search hit")
596
  return results
 
597
  wiki = wiki_summary_allow(q, sentences=3)
598
  if wiki:
599
  dlog("WIKI", "Wikipedia fallback hit")
 
601
  dlog("RETRIEVAL", "No results found")
602
  return []
603
 
 
604
  # FAISS path
605
  q_emb = embed_model.encode([q], convert_to_numpy=True).astype("float32")
606
  if _IS_IP:
 
641
  dlog("PUBMED", "PubMed search hit")
642
  return results
643
 
 
644
  wiki = wiki_summary_allow(q, sentences=3)
645
  if wiki:
646
  dlog("WIKI", "Wikipedia fallback hit")
 
649
  dlog("RETRIEVAL", "No results at all")
650
  return []
651
 
 
652
  def build_prompt(chunks: List[Dict[str, Any]], question: str) -> str:
653
  header = (
654
  "You are Askstein (orthopedic biomechanics). Use ONLY the [Context] to answer. "
 
657
  "Do not discuss cardiology, neurology, or unrelated domains."
658
  )
659
  cleaned = []
660
+ per_chunk_chars = 900 # lower prompt length = lower KV memory
661
  for c in chunks:
662
  t = _to_text(c)
663
  if t: cleaned.append(t[:per_chunk_chars])
 
668
  gen = out_ids[0][in_len:]
669
  return tokenizer_lm.decode(gen, skip_special_tokens=True).lstrip(". \n").strip()
670
 
671
+ def _gen_once(inputs, args) -> Any:
672
+ with torch.inference_mode():
673
+ return model_lm.generate(**inputs, **args, use_cache=True)
674
+
675
+ def _generate(inputs, grounded: bool):
676
+ args = GEN_ARGS_GROUNDED if grounded else GEN_ARGS_FALLBACK
 
 
 
677
  in_len = inputs["input_ids"].shape[-1]
 
 
 
 
678
 
679
+ # First attempt
680
+ try:
681
+ out = _gen_once(inputs, args)
682
+ except torch.cuda.OutOfMemoryError:
683
+ try:
684
+ torch.cuda.empty_cache()
685
+ except Exception:
686
+ pass
687
+ small_args = dict(args)
688
+ small_args["max_new_tokens"] = min(256, args.get("max_new_tokens", 256))
689
+ # disable cache to save VRAM
690
+ with torch.inference_mode():
691
+ out = model_lm.generate(**inputs, **small_args, use_cache=False)
692
+
693
+ if not AUTO_CONTINUE:
694
+ return out
695
 
696
+ # Auto-continue if we hit cap without EOS
697
+ steps = 0
698
+ while steps < AUTO_CONT_MAX_STEPS:
699
+ seq = out[0]
700
+ ended_with_eos = (seq[-1].item() == tokenizer_lm.eos_token_id)
701
+ hit_cap = (seq.shape[0] - in_len) >= args["max_new_tokens"]
702
+ if ended_with_eos or not hit_cap:
703
+ break
704
+
705
+ cont_inputs = {
706
+ "input_ids": seq.unsqueeze(0),
707
+ "attention_mask": torch.ones_like(seq).unsqueeze(0),
708
+ }
709
+ cont_inputs = {k: v.to(device) for k, v in cont_inputs.items()}
710
+ cont_args = dict(args)
711
+ cont_args["max_new_tokens"] = AUTO_CONT_NEW_TOKENS
712
+
713
+ try:
714
+ with torch.inference_mode():
715
+ out = model_lm.generate(**cont_inputs, **cont_args, use_cache=True)
716
+ except torch.cuda.OutOfMemoryError:
717
+ try:
718
+ torch.cuda.empty_cache()
719
+ except Exception:
720
+ pass
721
+ with torch.inference_mode():
722
+ out = model_lm.generate(**cont_inputs, **cont_args, use_cache=False)
723
+
724
+ steps += 1
725
+
726
+ return out
727
+
728
+ # ================== ANSWER SYNTHESIS ==================
729
  def _synthesize_answer(chunks: List[Dict[str, Any]], question: str) -> str:
730
  prompt = build_prompt(chunks, question)
731
  inputs = tokenizer_lm(prompt, return_tensors="pt").to(device)
 
741
  return _synthesize_answer(chunks, question)
742
  return _synthesize_answer(chunks, question)
743
 
744
+ @lru_cache(maxsize=None)
745
+ def direct_llm_fallback(question: str) -> str:
746
+ sys_prompt = (
747
+ "You are Askstein (orthopedic biomechanics). If you lack enough domain context, say you don’t know. "
748
+ "Avoid discussing non-musculoskeletal systems (cardiology, neurology). Do NOT invent references."
749
+ )
750
+ llm_prompt = f"{sys_prompt}\n\nQuestion: {question}\nAnswer:"
751
+ inputs = tokenizer_lm(llm_prompt, return_tensors="pt").to(device)
752
+ out = _generate(inputs, grounded=False)
753
+ in_len = inputs["input_ids"].shape[-1]
754
+ ans = _post_clean(_decode_generated(out, in_len))
755
+ ans = re.sub(r"(?is)(^|\n)\s*references?:.*$", "", ans).strip()
756
+ return "[LLM fallback ungrounded]\n\n" + ans
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
757
 
758
+ # ================== PUBLIC API ==================
759
  def ask(question: str) -> str:
760
  q = question.strip()
761
  m = re.search(r"pmid[:\s]*(\d+)", q, re.IGNORECASE)
 
764
  chunks = fetch_pubmed_chunks(pmid, max_papers=1)
765
  return "\n".join(c.get("text", "") for c in chunks) or "Sorry, no abstract found."
766
 
767
+ if _PAPERS_INTENT.search(q) or CITE_TRIGGER.search(q):
768
+ core_q = re.sub(CITE_TRIGGER, "", q, count=1, flags=re.I).strip().rstrip(".")
769
+ core_q = re.sub(_PAPERS_INTENT, "", core_q, flags=re.I).strip()
770
+ if not core_q:
771
+ core_q = "CT/QCT structural rigidity femur hip finite element"
772
  compact = _compact_terms(core_q)
773
  pm_query = (
774
  f'(({compact}) AND (hip[TiAb] OR femur[TiAb] OR femoral[TiAb])) AND '
 
776
  'AND ("2000"[DP] : "2025"[DP])'
777
  )
778
  cits = fetch_pubmed_citations(pm_query, max_results=5)
779
+ if not cits:
780
+ lab = detect_lab(core_q)
781
+ pm_query = build_lab_query(core_q, lab=lab)
782
+ cits = fetch_pubmed_citations(pm_query, max_results=5)
783
+ if not cits:
784
+ cits = _fallback_cits_for("EA")
785
+ # Provide a short explanation + citations
786
+ explanation = _answer_from_chunks(retrieve_context(core_q, top_k=3), core_q) or direct_llm_fallback(core_q)
787
+ explanation = _post_clean(explanation)
788
+ if cits:
789
+ explanation += "\n\nCitations:\n" + "\n".join(cits)
790
+ return _ensure_min_answer(explanation)
791
+
792
+ det = deterministic_definitions_text(q)
793
+ if det:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
794
  dlog("ASK", "Deterministic definition/workflow fired")
795
+ return det
796
 
797
  if not (_MSK_MUST.search(q) or _is_fe_override(q)):
798
+ chunks = retrieve_context(q, top_k=3)
799
  if chunks:
 
800
  answer = _answer_from_chunks(chunks, q)
801
+ # tiny safety to release VRAM between turns
802
+ try:
803
+ torch.cuda.empty_cache()
804
+ except Exception:
805
+ pass
806
  return _ensure_min_answer(_post_clean(answer)) or direct_llm_fallback(q)
807
  return direct_llm_fallback(q)
808
 
809
+ chunks = retrieve_context(q, top_k=3)
810
  if not chunks:
811
  return direct_llm_fallback(q)
 
812
  answer = _answer_from_chunks(chunks, q)
813
+ try:
814
+ torch.cuda.empty_cache()
815
+ except Exception:
816
+ pass
817
  return _ensure_min_answer(_post_clean(answer)) or direct_llm_fallback(q)
818
 
819
def deterministic_definitions_text(core_q: str) -> Optional[str]:
    """Return a canned definition/workflow answer for common rigidity questions.

    The checks run in a fixed order so that the explicit "define ..." phrasings
    take precedence over the looser QCT/torsional and workflow patterns.
    Returns None when no deterministic answer applies.
    """
    q = core_q.lower()

    def mentions(*phrases: str) -> bool:
        # True when any exact phrase occurs in the lowercased question.
        return any(p in q for p in phrases)

    if mentions("define axial rigidity", "what is axial rigidity"):
        return ("Axial rigidity (EA) is Σ(Eᵢ·dAᵢ) across a CT slice; units: N. "
                "Modulus E per voxel comes from a density–modulus calibration; areas dAᵢ are voxel areas.")

    if mentions("define bending rigidity", "what is bending rigidity"):
        return ("Bending rigidity (EI) is Σ(Eᵢ·dAᵢ·yᵢ²) about a given axis; units: N·mm². "
                "yᵢ is distance to the neutral axis; computed slice-by-slice from QCT.")

    if mentions("define torsional rigidity", "what is torsional rigidity", "define gj"):
        return ("Torsional rigidity (GJ) = shear modulus G times polar moment J. "
                "In QCT, J ≈ Σ(dAᵢ·rᵢ²) about the centroid; G ≈ E/(2(1+ν)).")

    if "qct" in q and mentions("torsional", "gj"):
        return ("From QCT, torsional rigidity is estimated as GJ, where J ≈ Σ(dAᵢ·rᵢ²) about the slice centroid and "
                "G = E/(2(1+ν)) from the voxel E map (ν≈0.3). Compute per-slice and report the minimum.")

    if re.search(r"\b(outline|steps|workflow|protocol)\b.*\b(ct|qct).*(rigidity|ea|ei|gj)", q):
        return (
            "CT-based structural rigidity (CTRA/QCT) workflow:\n"
            "1) Acquire QCT (≤1 mm; density phantom).\n"
            "2) Preprocess & segment bone.\n"
            "3) HU→ρ; ρ→E calibration.\n"
            "4) Cross-sections along neck axis.\n"
            "5) EA, EI_x/EI_y, GJ (G≈E/(2(1+ν))).\n"
            "6) Extract minima & validate vs FEA/mech tests."
        )

    if re.search(r"\b(modulus)\b.*\brigidity\b|\bdefine\s+modulus\b", q):
        return ("Elastic modulus (E) is a material property (Pa). "
                "Rigidity is structural (EA, EI, GJ). Modulus ≠ rigidity.")

    return None
847
+
848
  # ================== UI: NAME GATE + PER-ANSWER FEEDBACK ==================
849
  def _now_iso():
850
  return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).isoformat()
 
867
  state["last_name"] = last_name
868
  return gr.update(visible=False), gr.update(visible=True), state, f"Welcome, {first_name}! You can start chatting."
869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
870
  def _push_file_to_hub(local_path: str, repo_path: str) -> None:
 
 
 
 
 
 
 
 
871
  if not PUSH_FEEDBACK:
872
  return
873
  if not os.path.exists(local_path):
 
879
  if not HF_WRITE_TOKEN:
880
  dlog("UPLOAD", "Skip: HF_WRITE_TOKEN not set")
881
  return
 
882
  try:
883
  api = HfApi(token=HF_WRITE_TOKEN)
884
  api.upload_file(
 
892
  except Exception as e:
893
  dlog("UPLOAD", f"Upload failed: {e}")
894
 
 
895
def _push_feedback_to_hub() -> None:
    """Best-effort upload of the local feedback CSV to the Hub at analytics/feedback.csv.

    Thin wrapper around _push_file_to_hub; gating (PUSH_FEEDBACK, token presence,
    file existence) is handled by that helper, so this call is always safe.
    """
    _push_file_to_hub(FEEDBACK_PATH, "analytics/feedback.csv")
897
 
898
def _log_turn(state: Dict[str, Any], question: str, answer: str):
    """Append one Q/A turn to the local JSONL transcript and mirror it to the Hub.

    Both the local write and the Hub upload are best-effort: logging must never
    break the chat turn, so all failures are swallowed.
    """
    entry = {
        "timestamp_utc": _now_iso(),
        "session_id": state.get("session_id", ""),
        "first_name": state.get("first_name", ""),
        "last_name": state.get("last_name", ""),
        "question": question,
        "answer": answer,
    }
    serialized = json.dumps(entry, ensure_ascii=False)
    try:
        with open(TRANSCRIPT_PATH, "a", encoding="utf-8") as fh:
            fh.write(serialized + "\n")
    except Exception:
        pass  # local transcript write is best-effort
    try:
        if PUSH_TRANSCRIPTS:
            _push_file_to_hub(TRANSCRIPT_PATH, "analytics/transcripts.jsonl")
    except Exception:
        pass  # Hub mirroring is best-effort
917
 
918
  def save_feedback(rating, comment, state):
919
  if rating is None:
 
941
  except Exception as e:
942
  return f"Failed to save feedback: {e}", gr.update(visible=True)
943
 
944
def predict(message, chat_history, state):
    """Handle one chat turn: answer the message, log it, and refresh the feedback UI.

    Returns the 6-tuple the Gradio wiring expects:
    (chat history, cleared input, feedback-pane visibility, rating reset,
    comment reset, session state).
    """
    text = (message or "").strip()
    if not text:
        # Empty submit: leave history alone and keep the feedback pane hidden.
        return chat_history, "", gr.update(visible=False), None, "", state

    try:
        reply = ask(text)
    except Exception as e:
        reply = f"Sorry — something went wrong: {e!r}"

    history = (chat_history or []) + [(text, reply)]
    state["last_q"] = text
    state["last_a"] = reply

    try:
        _log_turn(state, text, reply)
    except Exception:
        pass  # transcript logging is best-effort

    # free a bit of VRAM between turns
    try:
        torch.cuda.empty_cache()
    except Exception:
        pass

    return (
        history,
        "",                        # clear input
        gr.update(visible=True),   # show feedback pane
        gr.update(value=None),     # reset rating
        gr.update(value=""),       # reset comment
        state,
    )
977
+
978
+ # ================== UI ==================
979
  with gr.Blocks(theme="soft") as demo:
980
  gr.Markdown("# Askstein — Orthopedic Biomechanics Chat (CT/QCT Rigidity, FE)")
981
  gr.Markdown("Grounded answers (FAISS + PubMed). Please enter your name to continue.")
 
987
  with gate:
988
  with gr.Row():
989
  first_tb = gr.Textbox(label="First name", placeholder="e.g., Shubh", scale=1)
990
+ last_tb = gr.Textbox(label="Last name", placeholder="e.g., Laiwala", scale=1)
991
  enter_btn = gr.Button("Enter", variant="primary")
992
+ gate_msg = gr.Markdown("", elem_classes=["text-sm"])
993
 
994
  # ---- App (hidden until gate passes) ----
995
  app = gr.Group(visible=False)
 
1003
  feedback_grp = gr.Group(visible=False)
1004
  with feedback_grp:
1005
  gr.Markdown("### How helpful was this answer?")
1006
+ rating = gr.Radio(choices=[1,2,3,4,5], label="Rating (1=poor, 5=great)")
1007
  comment = gr.Textbox(label="Optional comment", placeholder="What was good or missing?")
1008
  submit_fb = gr.Button("Submit feedback")
1009
  fb_status = gr.Markdown("")
1010
 
1011
+ # ---- Wiring (must stay inside Blocks) ----
1012
  enter_btn.click(
1013
  fn=enter_app,
1014
  inputs=[first_tb, last_tb, state],
 
1043
  concurrency_limit=4,
1044
  )
1045
 
1046
# ================== QUEUE & LAUNCH ==================
# Bounded request queue; the deprecated concurrency_count argument is deliberately not passed.
demo.queue(max_size=64)  # no deprecated concurrency_count
demo.launch(max_threads=8)