resberry commited on
Commit
eb76838
·
verified ·
1 Parent(s): 3364d8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +546 -1419
app.py CHANGED
@@ -1,13 +1,12 @@
1
  import os
2
  import re
3
-
4
- raw_omp = str(os.getenv("OMP_NUM_THREADS", "1")).strip()
5
- os.environ["OMP_NUM_THREADS"] = raw_omp if re.fullmatch(r"\d+", raw_omp) else "1"
6
-
7
  import time
8
- import traceback
 
9
  import logging
10
- from typing import List, Dict, TypedDict, Optional
 
 
11
  from dataclasses import dataclass, field
12
 
13
  import torch
@@ -21,50 +20,44 @@ from langchain_core.documents import Document
21
  from langchain_huggingface import HuggingFaceEmbeddings
22
  from langchain_community.vectorstores import FAISS
23
  from langchain_openai import ChatOpenAI
24
- from langgraph.graph import StateGraph, START, END
25
 
26
  # ============================================================
27
- # HUGGING FACE SPACES READY
28
- # Medical CSV RAG Chatbot + Normal Chat Mode
29
- # Modes:
30
- # 1) ECG RAG Mode -> retrieval -> local ECG reasoning -> grounded summary
31
- # 2) Normal Chat Mode -> standard chatbot response
32
- # Extra:
33
- # 3) Automatic ECG/Cardiology mode switching from user text
 
 
 
34
  # ============================================================
35
 
36
- # -------------------------------
 
 
 
 
37
  # LOGGING
38
- # -------------------------------
39
  logging.basicConfig(
40
  level=logging.INFO,
41
- format="%(asctime)s - %(levelname)s - %(message)s"
42
  )
43
- logger = logging.getLogger(__name__)
44
 
45
 
46
- # -------------------------------
47
  # CONFIG
48
- # -------------------------------
49
  @dataclass
50
  class Config:
51
- base_model_path: str = os.getenv(
52
- "BASE_MODEL_PATH",
53
- "meta-llama/Llama-3.1-8B-Instruct"
54
- )
55
-
56
- adapter_dir: str = os.getenv(
57
- "ADAPTER_DIR",
58
- "adapter_refined_v10"
59
- )
60
- data_csv: str = os.getenv(
61
- "DATA_CSV",
62
- "RAGmaterials/ECG_RAG_only_clean.csv"
63
- )
64
- rag_dir: str = os.getenv(
65
- "RAG_DIR",
66
- "RAGmaterials"
67
- )
68
  vectorstore_dir: str = field(init=False)
69
 
70
  hf_token: str = os.getenv("HF_TOKEN", "")
@@ -72,56 +65,41 @@ class Config:
72
  deepseek_base_url: str = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
73
  deepseek_model: str = os.getenv("DEEPSEEK_MODEL", "deepseek-chat")
74
 
75
- deepseek_temperature: float = float(os.getenv("DEEPSEEK_TEMPERATURE", "0.1"))
76
- deepseek_max_tokens: int = int(os.getenv("DEEPSEEK_MAX_TOKENS", "700"))
77
 
78
- embed_model_name: str = os.getenv(
79
- "EMBED_MODEL_NAME",
80
- "sentence-transformers/all-MiniLM-L6-v2"
81
- )
82
-
83
- similarity_k: int = int(os.getenv("SIMILARITY_K", "12"))
84
  top_k_final: int = int(os.getenv("TOP_K_FINAL", "4"))
85
- max_context_chars: int = int(os.getenv("MAX_CONTEXT_CHARS", "5200"))
86
 
87
  max_input_len: int = int(os.getenv("MAX_INPUT_LEN", "4096"))
88
- max_new_tokens_local: int = int(os.getenv("MAX_NEW_TOKENS_LOCAL", "180"))
89
  max_chat_history_turns: int = int(os.getenv("MAX_CHAT_HISTORY_TURNS", "6"))
90
 
91
- min_lexical_overlap: float = float(os.getenv("MIN_LEXICAL_OVERLAP", "0.08"))
92
- min_faiss_similarity: float = float(os.getenv("MIN_FAISS_SIMILARITY", "0.20"))
93
- strong_retrieval_threshold: float = float(os.getenv("STRONG_RETRIEVAL_THRESHOLD", "0.30"))
94
- strong_retrieval_min_docs: int = int(os.getenv("STRONG_RETRIEVAL_MIN_DOCS", "3"))
95
 
96
- use_query_cache: bool = os.getenv("USE_QUERY_CACHE", "true").lower() == "true"
 
 
 
97
  enable_query_expansion: bool = os.getenv("ENABLE_QUERY_EXPANSION", "true").lower() == "true"
98
- enable_validator: bool = os.getenv("ENABLE_VALIDATOR", "true").lower() == "true"
99
  enable_typewriter_stream: bool = os.getenv("ENABLE_TYPEWRITER_STREAM", "true").lower() == "true"
100
- show_debug_panel: bool = os.getenv("SHOW_DEBUG_PANEL", "true").lower() == "true"
101
  allow_rebuild_vectorstore: bool = os.getenv("ALLOW_REBUILD_VECTORSTORE", "false").lower() == "true"
102
 
103
- use_4bit: bool = os.getenv("USE_4BIT", "true").lower() == "true"
104
-
105
  launch_debug: bool = os.getenv("LAUNCH_DEBUG", "false").lower() == "true"
106
  server_name: str = os.getenv("SERVER_NAME", "0.0.0.0")
107
  server_port: int = int(os.getenv("SERVER_PORT", "7860"))
108
 
109
- blink_stage_1: float = float(os.getenv("BLINK_STAGE_1", "0.40"))
110
- blink_stage_2: float = float(os.getenv("BLINK_STAGE_2", "0.55"))
111
- blink_stage_3: float = float(os.getenv("BLINK_STAGE_3", "0.50"))
112
- blink_before_answer: float = float(os.getenv("BLINK_BEFORE_ANSWER", "0.25"))
113
-
114
  def __post_init__(self):
115
  self.vectorstore_dir = os.path.join(self.rag_dir, "faiss_store")
116
  os.makedirs(self.rag_dir, exist_ok=True)
117
 
118
  if not self.deepseek_api_key:
119
- raise ValueError("Missing DEEPSEEK_API_KEY. Add it in Hugging Face Space Secrets.")
120
 
121
  if not self.hf_token:
122
- raise ValueError(
123
- "Missing HF_TOKEN. Add a valid Hugging Face token with access to the gated base model."
124
- )
125
 
126
  for path, name in [
127
  (self.adapter_dir, "Adapter directory"),
@@ -135,20 +113,38 @@ cfg = Config()
135
  logger.info("Configuration loaded.")
136
 
137
 
138
- # -------------------------------
139
  # PROMPTS
140
- # -------------------------------
141
- LOCAL_REASONING_SYSTEM = """
142
- You are a strict medical reasoning assistant specialized for ECG and cardiology reasoning.
 
 
 
 
143
 
144
- You are NOT the final answer generator.
145
- You must analyze ONLY the supplied evidence and produce a short structured reasoning draft.
 
 
 
146
 
 
 
147
  Rules:
148
- 1) Use only the provided evidence.
149
- 2) Do not invent facts.
150
- 3) Focus only on the user's exact question.
151
- 4) Output exactly in this structure:
 
 
 
 
 
 
 
 
 
152
 
153
  KEY_FINDINGS:
154
  - ...
@@ -165,156 +161,102 @@ SUPPORTED_POINTS:
165
  LIMITS:
166
  - ...
167
 
168
- 5) If evidence is insufficient, output exactly:
169
  INSUFFICIENT_EVIDENCE
170
  """.strip()
171
 
172
- QUERY_EXPANSION_SYSTEM = """
173
- You expand medical queries for retrieval.
174
-
175
- Rules:
176
- 1) Preserve the user's intent.
177
- 2) Add close medical paraphrases and alternate wording.
178
- 3) Add likely medical synonyms, abbreviations, and alternate phrasing.
179
- 4) Do not answer the question.
180
- 5) Output only the expanded retrieval query.
181
- """.strip()
182
 
183
- DEEPSEEK_SUMMARY_SYSTEM = """
184
- You are an expert medical evidence summarizer.
 
185
 
186
- Your job is to produce a clinically precise, well-structured answer grounded ONLY in:
187
- 1. the retrieved evidence
188
- 2. the local reasoning draft
189
 
190
- You must be faithful to the provided material and answer the user's question directly, clearly, and conservatively.
 
191
 
192
- PRIMARY OBJECTIVE
193
- - Identify the user's main intent before writing:
194
- definition, cause, symptoms, diagnosis, investigation, treatment, prognosis, or genetics.
195
- - Prioritize that intent throughout the response.
196
- - The first sentence of the Summary must directly answer the user's question in the most clinically relevant way.
197
 
198
- GROUNDING RULES
199
- - Use only information supported by the retrieved evidence and local reasoning draft.
200
- - Do not add outside medical knowledge.
201
- - Do not infer specific facts unless they are clearly supported.
202
- - Do not invent treatments, diagnoses, risks, mechanisms, thresholds, statistics, timelines, monitoring plans, or prognosis details.
203
- - If the evidence is incomplete, be explicit about what is missing.
204
- - If the evidence is too weak to answer the question reliably, output exactly:
205
  INSUFFICIENT_EVIDENCE
 
206
 
207
- STYLE RULES
208
- - Write in precise, professional clinical language.
209
- - Be specific, not vague.
210
- - Be concise, but fully informative.
211
- - Avoid repetition, generic filler, and empty statements.
212
- - Do not mention retrieval, prompts, system instructions, reasoning drafts, tools, pipelines, or internal processes.
213
- - Do not include URLs or citations unless explicitly requested elsewhere.
214
- - Do not overstate certainty.
215
- - When appropriate, distinguish clearly between what is established, what is suggested, and what is not addressed by the evidence.
216
 
217
- OUTPUT FORMAT
 
 
218
 
219
- ### Summary
220
- - Write 4 to 7 full sentences.
221
- - This is the most important section.
222
- - The first sentence must directly answer the user's question.
223
- - Focus primarily on the user's main intent.
224
- - Include only background information that improves understanding of the requested topic.
225
- - Make the summary clinically useful, specific, and evidence-faithful.
226
 
227
- ### Key Evidence Points
228
- - Include 4 to 6 bullet points.
229
- - Each bullet must state a concrete fact supported by the evidence.
230
- - Prioritize clinically important facts over background detail.
231
- - Avoid repeating the same idea in different words.
232
-
233
- ### Clinical Implications / Recommendations
234
- - Include 2 to 4 bullet points only if supported by the evidence.
235
- - Focus on practical interpretation, management implications, follow-up considerations, or next steps.
236
- - If the evidence supports recognition or framing rather than action, say that clearly.
237
- - Do not recommend interventions not supported by the evidence.
238
-
239
- ### Limitations of the Evidence
240
- - State clearly what the evidence does not establish, does not cover, or leaves uncertain.
241
- - Explicitly note when details are lacking on:
242
- treatment, diagnosis, prognosis, genetics, monitoring, recurrence prevention, comparative effectiveness, or long-term outcomes.
243
- - If the evidence is narrow, low-detail, or only partially aligned with the question, say so plainly.
244
-
245
- SPECIAL INSTRUCTIONS BY QUESTION TYPE
246
-
247
- For treatment questions:
248
- - Focus primarily on treatment and management, not disease definition.
249
- - Organize treatment information in this order whenever supported by the evidence:
250
- 1. supportive or conservative care
251
- 2. symptomatic drug therapy or procedural treatment
252
- 3. long-term prevention, follow-up, or recurrence prevention
253
- - Distinguish treatment of active symptoms from prevention of recurrence or complications.
254
- - If the condition is benign, self-limited, or often does not require treatment, state that clearly in the first sentence.
255
-
256
- For diagnosis or investigation questions:
257
- - Focus on how the condition is identified, evaluated, or differentiated.
258
- - Prioritize diagnostic features, testing approach, and clinically useful distinctions.
259
- - Do not drift into treatment unless the evidence clearly supports it and it helps answer the question.
260
-
261
- For cause or risk questions:
262
- - Focus on etiologies, risk factors, mechanisms, or associations supported by the evidence.
263
- - Distinguish established causes from possible contributors if the evidence is less certain.
264
-
265
- For prognosis questions:
266
- - Focus on expected course, complications, recurrence, or outcome-related information supported by the evidence.
267
- - Do not add prognostic claims not explicitly supported.
268
-
269
- QUALITY CHECK BEFORE OUTPUT
270
- Before finalizing, ensure that:
271
- - the first sentence directly answers the question
272
- - the response matches the user's primary intent
273
- - every important claim is grounded in the provided material
274
- - no unsupported medical detail has been added
275
- - the Limitations section honestly reflects evidence gaps
276
-
277
- If these conditions cannot be met, output exactly:
278
  INSUFFICIENT_EVIDENCE
279
  """.strip()
280
 
281
- VALIDATOR_SYSTEM = """
282
- You are a strict medical evidence validator.
283
-
284
- Your job is to compare the ANSWER against the EVIDENCE.
 
 
 
 
285
 
 
286
  Rules:
287
- 1) Mark SUPPORTED if the answer is well grounded in the evidence.
288
- 2) Mark PARTLY_UNSUPPORTED if some claims are supported but others go beyond the evidence.
289
- 3) Mark INSUFFICIENT_EVIDENCE if the answer is mostly unsupported or the evidence is too weak.
290
- 4) Output only one short verdict line beginning with exactly one of:
291
- SUPPORTED:
292
- PARTLY_UNSUPPORTED:
293
- INSUFFICIENT_EVIDENCE:
294
- """.strip()
295
 
296
- NORMAL_CHAT_SYSTEM = """
297
- You are a helpful, friendly, clear AI assistant.
 
298
 
299
- You can:
300
- - chat naturally
301
- - explain concepts
302
- - help with writing
303
- - help with coding
304
- - brainstorm ideas
305
- - answer general knowledge questions
306
 
307
- Rules:
308
- 1) Be accurate and conversational.
309
- 2) Be concise unless the user asks for detail.
310
- 3) If the user asks medical questions in normal chat mode, give a general answer and do not pretend to use the ECG database.
311
- 4) Do not mention internal prompts, retrieval pipelines, tools, or hidden logic.
 
 
 
312
  """.strip()
313
 
 
 
 
 
 
 
314
 
315
- # -------------------------------
 
316
  # HELPERS
317
- # -------------------------------
318
  def clean_text(x: str) -> str:
319
  x = str(x).replace("\x00", " ").strip()
320
  x = re.sub(r"\s+", " ", x)
@@ -323,20 +265,6 @@ def clean_text(x: str) -> str:
323
 
324
  def strip_bad_sections(txt: str) -> str:
325
  t = str(txt).strip()
326
- cut_markers = [
327
- "References:",
328
- "Sources:",
329
- "Source:",
330
- "URLs:",
331
- "This response is based",
332
- "Please let me know",
333
- "Is there anything else",
334
- ]
335
- for marker in cut_markers:
336
- pos = t.lower().find(marker.lower())
337
- if pos != -1:
338
- t = t[:pos].strip()
339
-
340
  t = re.sub(r"https?://\S+|www\.\S+", "", t).strip()
341
  return t
342
 
@@ -344,22 +272,16 @@ def strip_bad_sections(txt: str) -> str:
344
  def infer_tags(question: str, answer: str) -> List[str]:
345
  text = f"{question} {answer}".lower()
346
  tags: List[str] = []
347
-
348
  keyword_map = {
349
- "treatment": ["treat", "therapy", "management", "drug", "surgery"],
350
  "diagnosis": ["diagnosis", "diagnose", "criteria"],
351
- "symptoms": ["symptom", "presentation", "sign", "feature"],
352
- "ecg": ["ecg", "ekg", "st elevation", "qrs", "p wave", "arrhythmia", "tachycardia", "bradycardia"],
353
- "investigation": ["test", "investigation", "mri", "ct", "lab", "imaging"],
354
- "prognosis": ["prognosis", "outcome", "survival", "risk"],
355
- "genetics": ["gene", "genetic", "mutation", "variant", "chromosome", "inherited", "inheritance"],
356
- "etiology": ["cause", "causes", "caused by", "associated with", "risk factor"],
357
  }
358
-
359
  for tag, words in keyword_map.items():
360
  if any(w in text for w in words):
361
  tags.append(tag)
362
-
363
  return tags
364
 
365
 
@@ -383,198 +305,92 @@ def lexical_overlap(query: str, text: str) -> float:
383
  return len(q_words & t_words) / max(1, len(q_words))
384
 
385
 
386
- def rerank_docs(query: str, docs: List[Document], top_n: Optional[int] = None) -> List[Document]:
387
- if top_n is None:
388
- top_n = cfg.top_k_final
389
-
390
- q_words = set(re.findall(r"\w+", query.lower()))
391
- scored = []
392
-
393
- for d in docs:
394
- question = d.metadata.get("question", "")
395
- answer = d.metadata.get("answer", "")
396
- tags = " ".join(d.metadata.get("tags", []))
397
- text = f"{question} {answer} {tags}".lower()
398
-
399
- t_words = set(re.findall(r"\w+", text))
400
- overlap = len(q_words & t_words) / max(1, len(q_words))
401
- question_boost = 0.20 if any(w in question.lower() for w in q_words) else 0.0
402
- tag_boost = 0.10 if any(w in tags.lower() for w in q_words) else 0.0
403
- sim_score = float(d.metadata.get("sim_score", 0.0))
404
-
405
- final_score = overlap + question_boost + tag_boost + (0.35 * sim_score)
406
- scored.append((d, final_score))
407
-
408
- scored.sort(key=lambda x: x[1], reverse=True)
409
- return [d for d, _ in scored[:top_n]]
410
-
411
-
412
  def history_to_text(chat_history: List[Dict[str, str]], max_turns: Optional[int] = None) -> str:
413
- if max_turns is None:
414
- max_turns = cfg.max_chat_history_turns
415
-
416
  items = chat_history[-max_turns:]
417
  if not items:
418
  return "[EMPTY]"
419
-
420
  return "\n".join([f"{m['role'].upper()}: {m['content']}" for m in items]).strip()
421
 
422
 
423
  def build_context_string(docs: List[Document], max_chars: Optional[int] = None) -> str:
424
- if max_chars is None:
425
- max_chars = cfg.max_context_chars
426
-
427
  blocks = []
428
  total = 0
429
-
430
  for i, d in enumerate(docs, 1):
431
  q = d.metadata.get("question", "")
432
  a = d.metadata.get("answer", "")
433
  tags = ", ".join(d.metadata.get("tags", [])) or "N/A"
434
- sim = d.metadata.get("sim_score", None)
435
-
436
  block = f"""
437
  ==============================
438
  EVIDENCE_ID: {i}
439
  SOURCE_ID: {d.metadata.get('id')}
440
  SOURCE_QUESTION: {q}
441
  SOURCE_TAGS: {tags}
442
- SIMILARITY: {sim if sim is not None else 'N/A'}
443
  EVIDENCE_TEXT:
444
  {a}
445
  ==============================
446
  """.strip()
447
-
448
  if total + len(block) > max_chars:
449
  break
450
-
451
  blocks.append(block)
452
  total += len(block) + 2
453
-
454
  return "\n\n".join(blocks).strip()
455
 
456
 
457
- def compute_confidence(result: Dict) -> float:
458
- best_score = result.get("best_score", -1.0)
459
- validation = result.get("validation_status", "")
460
-
461
- if validation.startswith("SUPPORTED"):
462
- conf = best_score
463
- elif validation.startswith("PARTLY_UNSUPPORTED"):
464
- conf = best_score * 0.70
465
- else:
466
- conf = best_score * 0.40
467
-
468
- return max(0.0, min(1.0, conf))
469
-
470
-
471
- def strong_retrieval(best_score: float, docs: List[Document]) -> bool:
472
- return (
473
- best_score >= cfg.strong_retrieval_threshold
474
- and len(docs) >= cfg.strong_retrieval_min_docs
475
- )
476
-
477
-
478
- def stream_text(text: str, step: int = 110):
479
  acc = ""
480
  for i in range(0, len(text), step):
481
  acc += text[i:i + step]
482
  yield acc
483
 
484
 
485
- # -------------------------------
486
- # AUTO MODE SWITCH DETECTION
487
- # -------------------------------
488
- ECG_MODE_PATTERNS = [
489
- r"\becg\b",
490
- r"\bekg\b",
491
- r"\bcardiology\b",
492
- r"\bcardio\b",
493
- r"\barrhythmia\b",
494
- r"\bheart rhythm\b",
495
- r"\becg mode\b",
496
- r"\bcardiology mode\b",
497
- r"\bmedical mode\b",
498
- ]
499
-
500
- ECG_SWITCH_PHRASES = [
501
- r"switch to ecg",
502
- r"switch into ecg",
503
- r"switch to cardiology",
504
- r"switch into cardiology",
505
- r"switch to ecg and cardiology",
506
- r"switch into ecg and cardiology",
507
- r"ecg and cardiology",
508
- r"medical ecg cardiology",
509
- r"i want to ask ecg",
510
- r"i want to ask ecr",
511
- r"i want ecg",
512
- r"ecg questions",
513
- r"cardiology questions",
514
- r"ecg only",
515
- r"cardiology only",
516
- r"activate ecg",
517
- r"activate cardiology",
518
- ]
519
-
520
- NORMAL_SWITCH_PHRASES = [
521
- r"switch to normal",
522
- r"normal chat",
523
- r"back to normal",
524
- r"exit ecg",
525
- r"leave ecg mode",
526
- r"turn off ecg mode",
527
- ]
528
-
529
-
530
- def normalize_user_text(text: str) -> str:
531
- text = str(text or "").lower().strip()
532
- text = re.sub(r"\s+", " ", text)
533
- return text
534
 
535
 
536
- def detect_mode_switch_request(user_message: str) -> Optional[str]:
537
- text = normalize_user_text(user_message)
 
 
 
538
 
539
- for pat in NORMAL_SWITCH_PHRASES:
540
- if re.search(pat, text):
541
- return "normal_chat"
542
 
543
- strong_switch = any(re.search(pat, text) for pat in ECG_SWITCH_PHRASES)
544
- ecg_present = any(re.search(pat, text) for pat in ECG_MODE_PATTERNS)
 
545
 
546
- if strong_switch or (
547
- ("switch" in text or "mode" in text or "questions" in text or "related" in text)
548
- and ecg_present
549
- ):
550
- return "ecg_rag"
551
 
552
- return None
 
 
 
 
 
 
 
 
553
 
554
 
555
- def mode_switch_message(mode_value: str) -> str:
556
- if mode_value == "ecg_rag":
557
- return (
558
- "❤️ **ECG & Cardiology Mode activated**\n\n"
559
- "UI updated successfully.\n"
560
- "Ready for **medical, ECG, and cardiology** questions."
561
- )
562
- return (
563
- "💬 **Normal Chat Mode activated**\n\n"
564
- "UI updated successfully.\n"
565
- "Ready for general conversation again."
566
- )
567
 
568
 
569
- # -------------------------------
570
  # EMBEDDINGS + VECTORSTORE
571
- # -------------------------------
572
  logger.info("Loading embeddings...")
573
  embeddings = HuggingFaceEmbeddings(
574
  model_name=cfg.embed_model_name,
575
  model_kwargs={
576
  "device": "cuda" if torch.cuda.is_available() else "cpu",
577
- "token": None,
578
  },
579
  encode_kwargs={"normalize_embeddings": True},
580
  )
@@ -605,7 +421,7 @@ def build_vectorstore():
605
  "question": q,
606
  "answer": a,
607
  "tags": infer_tags(q, a),
608
- }
609
  )
610
  )
611
 
@@ -630,16 +446,15 @@ vectorstore = load_vectorstore()
630
  logger.info("Vectorstore ready.")
631
 
632
 
633
- # -------------------------------
634
- # LOCAL MODEL + ECG ADAPTER
635
- # -------------------------------
636
  logger.info("Loading tokenizer...")
637
  tokenizer = AutoTokenizer.from_pretrained(
638
  cfg.base_model_path,
639
  use_fast=True,
640
- token=cfg.hf_token if cfg.hf_token else None
641
  )
642
-
643
  if tokenizer.pad_token is None:
644
  tokenizer.pad_token = tokenizer.eos_token
645
 
@@ -680,10 +495,19 @@ if base_model is None:
680
 
681
  base_model.eval()
682
 
683
- logger.info("Loading ECG reasoning adapter...")
684
  reason_model = PeftModel.from_pretrained(base_model, cfg.adapter_dir)
685
  reason_model.eval()
686
 
 
 
 
 
 
 
 
 
 
687
 
688
  def get_primary_model_device(model) -> torch.device:
689
  try:
@@ -692,15 +516,50 @@ def get_primary_model_device(model) -> torch.device:
692
  return torch.device("cuda" if torch.cuda.is_available() else "cpu")
693
 
694
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
695
  @torch.inference_mode()
696
  def run_local_reasoner(user_query: str, context: str) -> str:
697
  try:
698
  messages = [
699
  {"role": "system", "content": LOCAL_REASONING_SYSTEM},
700
- {
701
- "role": "user",
702
- "content": f"QUESTION:\n{user_query}\n\nEVIDENCE:\n{context if context.strip() else '[EMPTY]'}"
703
- },
704
  ]
705
 
706
  prompt = tokenizer.apply_chat_template(
@@ -732,76 +591,31 @@ def run_local_reasoner(user_query: str, context: str) -> str:
732
 
733
  gen_ids = out[0, inputs["input_ids"].shape[1]:]
734
  text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
735
- text = strip_bad_sections(text)
736
-
737
- return text if text else "INSUFFICIENT_EVIDENCE"
738
-
739
  except Exception as e:
740
  logger.error(f"Local reasoner error: {e}")
741
  traceback.print_exc()
742
  return "INSUFFICIENT_EVIDENCE"
743
 
744
 
745
- # -------------------------------
746
- # REMOTE LLM (DEEPSEEK)
747
- # -------------------------------
748
- deepseek_llm = ChatOpenAI(
749
- model=cfg.deepseek_model,
750
- api_key=cfg.deepseek_api_key,
751
- base_url=cfg.deepseek_base_url,
752
- temperature=cfg.deepseek_temperature,
753
- max_tokens=cfg.deepseek_max_tokens,
754
- )
755
-
756
- _query_expansion_cache: Dict[str, str] = {}
757
-
758
-
759
- def llm_text(system_prompt: str, user_prompt: str, fallback: str = "INSUFFICIENT_EVIDENCE") -> str:
760
- try:
761
- resp = deepseek_llm.invoke([
762
- {"role": "system", "content": system_prompt},
763
- {"role": "user", "content": user_prompt},
764
- ])
765
- text = resp.content if hasattr(resp, "content") else str(resp)
766
- text = strip_bad_sections(text)
767
- return text if text.strip() else fallback
768
- except Exception as e:
769
- logger.error(f"DeepSeek error: {e}")
770
- traceback.print_exc()
771
- return fallback
772
-
773
-
774
- def run_query_expansion(user_query: str) -> str:
775
- if not cfg.enable_query_expansion:
776
- return user_query
777
-
778
- if cfg.use_query_cache and user_query in _query_expansion_cache:
779
- logger.info(f"Using cached expansion for: {user_query[:80]}")
780
- return _query_expansion_cache[user_query]
781
-
782
  prompt = f"""
783
- USER_QUERY:
784
- {user_query}
785
-
786
- Expand this for retrieval with close medical phrasing, synonyms, and alternate wording.
787
- Do not answer the question.
788
- """.strip()
789
 
790
- expanded = llm_text(QUERY_EXPANSION_SYSTEM, prompt, fallback=user_query)
791
- expanded = expanded.strip() if expanded else user_query
792
 
793
- if cfg.use_query_cache:
794
- _query_expansion_cache[user_query] = expanded
795
 
796
- return expanded
 
 
 
797
 
798
 
799
- def run_deepseek_summary(
800
- user_query: str,
801
- context: str,
802
- reasoning_draft: str,
803
- chat_history: List[Dict[str, str]],
804
- ) -> str:
805
  prompt = f"""
806
  CHAT_HISTORY:
807
  {history_to_text(chat_history)}
@@ -814,30 +628,28 @@ RETRIEVED_EVIDENCE:
814
 
815
  LOCAL_REASONING_DRAFT:
816
  {reasoning_draft if reasoning_draft.strip() else '[EMPTY]'}
817
-
818
- Write a grounded final summary answer using only the evidence and reasoning draft.
819
  """.strip()
 
820
 
821
- return llm_text(
822
- DEEPSEEK_SUMMARY_SYSTEM,
823
- prompt,
824
- fallback="I could not generate a grounded summary from the retrieved evidence."
825
- )
826
-
827
-
828
- def run_validator(context: str, answer: str) -> str:
829
- if not cfg.enable_validator:
830
- return "SUPPORTED (validator disabled)"
831
 
 
832
  prompt = f"""
833
- EVIDENCE:
 
 
 
834
  {context if context.strip() else '[EMPTY]'}
835
 
836
- ANSWER:
837
- {answer if answer.strip() else '[EMPTY]'}
838
- """.strip()
839
 
840
- return llm_text(VALIDATOR_SYSTEM, prompt, fallback="PARTLY_UNSUPPORTED: validator unavailable")
 
 
 
 
 
 
841
 
842
 
843
  def run_normal_chat(user_query: str, chat_history: List[Dict[str, str]]) -> str:
@@ -847,20 +659,13 @@ CHAT_HISTORY:
847
 
848
  USER_MESSAGE:
849
  {user_query}
850
-
851
- Respond as a normal helpful chatbot.
852
  """.strip()
853
-
854
- return llm_text(
855
- NORMAL_CHAT_SYSTEM,
856
- prompt,
857
- fallback="Sorry, I could not generate a response."
858
- )
859
 
860
 
861
- # -------------------------------
862
  # WARMUP
863
- # -------------------------------
864
  def warmup_models():
865
  logger.info("Warming up local reasoner...")
866
  try:
@@ -871,6 +676,7 @@ def warmup_models():
871
  EVIDENCE_ID: 1
872
  SOURCE_QUESTION: What are ECG findings in hyperkalemia?
873
  SOURCE_TAGS: ecg
 
874
  EVIDENCE_TEXT:
875
  Hyperkalemia may cause peaked T waves, PR prolongation, QRS widening, and severe conduction abnormalities.
876
  ==============================
@@ -881,38 +687,58 @@ Hyperkalemia may cause peaked T waves, PR prolongation, QRS widening, and severe
881
  logger.warning(f"Warmup failed: {e}")
882
 
883
 
884
- warmup_models()
 
885
 
886
 
887
- # -------------------------------
888
  # STATE
889
- # -------------------------------
890
- class ChatState(TypedDict, total=False):
891
  user_query: str
892
- expanded_query: str
893
  chat_history: List[Dict[str, str]]
894
 
 
 
 
895
  retrieved_docs: List[Document]
896
  best_score: float
897
- used_context: bool
898
  context: str
899
- retrieval_attempts: int
900
- retrieval_mode: str
901
 
902
- reasoning_draft: str
 
 
903
  final_answer: str
904
- validation_status: str
905
 
906
 
907
- # -------------------------------
908
  # RETRIEVAL
909
- # -------------------------------
910
- def retrieve_docs_once(query_for_search: str, original_query: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
911
  try:
912
- scored = vectorstore.similarity_search_with_score(
913
- query_for_search,
914
- k=cfg.similarity_k,
915
- )
916
  except Exception as e:
917
  logger.error(f"Retriever error: {e}")
918
  traceback.print_exc()
@@ -922,1068 +748,377 @@ def retrieve_docs_once(query_for_search: str, original_query: str):
922
  return [], -1.0
923
 
924
  filtered_docs = []
925
- best_score = -1.0
926
-
927
  for doc, raw_score in scored:
928
  sim = score_to_similarity(raw_score)
929
- best_score = max(best_score, sim)
930
-
931
  q = doc.metadata.get("question", "")
932
  a = doc.metadata.get("answer", "")
933
  ov = lexical_overlap(original_query, f"{q} {a}")
934
 
935
- if ov >= cfg.min_lexical_overlap and sim >= cfg.min_faiss_similarity:
936
  new_doc = Document(page_content=doc.page_content, metadata=dict(doc.metadata))
937
  new_doc.metadata["sim_score"] = sim
938
  new_doc.metadata["lexical_overlap"] = ov
939
  filtered_docs.append(new_doc)
940
 
941
  reranked = rerank_docs(original_query, filtered_docs, top_n=cfg.top_k_final)
 
942
  return reranked, best_score
943
 
944
 
945
- # -------------------------------
946
- # LANGGRAPH NODES
947
- # -------------------------------
948
- def retrieve_node(state: ChatState) -> ChatState:
949
- query = state.get("expanded_query") or state["user_query"]
950
- retrieval_attempts = int(state.get("retrieval_attempts", 0)) + 1
951
- retrieval_mode = "expanded" if state.get("expanded_query") else "original"
952
 
953
- docs, best_score = retrieve_docs_once(
954
- query_for_search=query,
955
- original_query=state["user_query"],
956
- )
957
 
958
- if not docs:
959
- return {
960
- "retrieved_docs": [],
961
- "best_score": best_score,
962
- "used_context": False,
963
- "context": "",
964
- "retrieval_attempts": retrieval_attempts,
965
- "retrieval_mode": retrieval_mode,
966
- }
967
 
968
- return {
969
- "retrieved_docs": docs,
970
- "best_score": best_score,
971
- "used_context": True,
972
- "context": build_context_string(docs, max_chars=cfg.max_context_chars),
973
- "retrieval_attempts": retrieval_attempts,
974
- "retrieval_mode": retrieval_mode,
975
- }
976
 
977
 
978
- def should_retry_retrieval(state: ChatState) -> str:
979
- used_context = state.get("used_context", False)
980
- best_score = state.get("best_score", -1.0)
981
- attempts = int(state.get("retrieval_attempts", 0))
982
-
983
- if used_context and best_score >= cfg.min_faiss_similarity:
984
- return "local_reasoning"
985
-
986
- if not cfg.enable_query_expansion:
987
- return "local_reasoning"
988
-
989
- if attempts >= 2:
990
- return "local_reasoning"
991
-
992
- return "expand_query"
993
-
994
-
995
- def expand_query_node(state: ChatState) -> ChatState:
996
- expanded = run_query_expansion(state["user_query"])
997
- if not expanded.strip():
998
- expanded = state["user_query"]
999
- return {"expanded_query": expanded}
1000
-
1001
-
1002
- def local_reasoning_node(state: ChatState) -> ChatState:
1003
- context = state.get("context", "").strip()
1004
- if not context:
1005
- return {"reasoning_draft": "INSUFFICIENT_EVIDENCE"}
1006
-
1007
- reasoning = run_local_reasoner(state["user_query"], context)
1008
- return {"reasoning_draft": reasoning}
1009
-
1010
-
1011
- def generate_node(state: ChatState) -> ChatState:
1012
- context = state.get("context", "").strip()
1013
- reasoning = state.get("reasoning_draft", "INSUFFICIENT_EVIDENCE")
1014
- history = state.get("chat_history", [])
1015
-
1016
- if not context:
1017
- return {"final_answer": "I could not find sufficiently relevant evidence in the RAG database for this question."}
1018
-
1019
- answer = run_deepseek_summary(
1020
- user_query=state["user_query"],
1021
- context=context,
1022
- reasoning_draft=reasoning,
1023
- chat_history=history,
1024
- )
1025
- return {"final_answer": answer}
1026
-
1027
-
1028
- def validate_node(state: ChatState) -> ChatState:
1029
- context = state.get("context", "").strip()
1030
- answer = state.get("final_answer", "").strip()
1031
- best_score = state.get("best_score", -1.0)
1032
- docs = state.get("retrieved_docs", [])
1033
 
1034
- if not context or not answer:
1035
- return {"validation_status": "INSUFFICIENT_EVIDENCE: missing context or answer"}
1036
 
1037
- if strong_retrieval(best_score, docs):
1038
- return {"validation_status": "SUPPORTED (validator skipped due to strong retrieval)"}
 
1039
 
1040
- verdict = run_validator(context, answer)
 
1041
 
1042
- if verdict.startswith("SUPPORTED"):
1043
- return {"validation_status": verdict}
1044
 
1045
- if verdict.startswith("PARTLY_UNSUPPORTED"):
1046
- return {
1047
- "validation_status": verdict,
1048
- "final_answer": answer + "\n\nEvidence limits: some parts may not be fully supported by the retrieved evidence."
1049
- }
1050
 
1051
- if verdict.startswith("INSUFFICIENT_EVIDENCE"):
1052
- return {
1053
- "validation_status": verdict,
1054
- "final_answer": answer + "\n\nEvidence limits: the retrieved evidence was weak or only partially relevant."
 
 
 
 
 
 
 
 
 
1055
  }
 
 
 
1056
 
1057
- return {"validation_status": verdict}
1058
-
1059
-
1060
- def finalize_node(state: ChatState) -> ChatState:
1061
- answer = strip_bad_sections(state.get("final_answer", ""))
1062
- if not answer:
1063
- answer = "I could not generate an answer."
1064
- return {"final_answer": answer}
1065
-
1066
-
1067
- # -------------------------------
1068
- # GRAPH
1069
- # -------------------------------
1070
- builder = StateGraph(ChatState)
1071
- builder.add_node("retrieve", retrieve_node)
1072
- builder.add_node("expand_query", expand_query_node)
1073
- builder.add_node("local_reasoning", local_reasoning_node)
1074
- builder.add_node("generate", generate_node)
1075
- builder.add_node("validate", validate_node)
1076
- builder.add_node("finalize", finalize_node)
1077
-
1078
- builder.add_edge(START, "retrieve")
1079
- builder.add_conditional_edges(
1080
- "retrieve",
1081
- should_retry_retrieval,
1082
- {
1083
- "expand_query": "expand_query",
1084
- "local_reasoning": "local_reasoning",
1085
- }
1086
- )
1087
- builder.add_edge("expand_query", "retrieve")
1088
- builder.add_edge("local_reasoning", "generate")
1089
- builder.add_edge("generate", "validate")
1090
- builder.add_edge("validate", "finalize")
1091
- builder.add_edge("finalize", END)
1092
-
1093
- graph = builder.compile()
1094
- logger.info("LangGraph compiled.")
1095
-
1096
-
1097
- # -------------------------------
1098
- # FORMATTING HELPERS
1099
- # -------------------------------
1100
- def format_sources_minimal(result: Optional[Dict], chat_mode: str = "ecg_rag") -> str:
1101
- if chat_mode == "normal_chat":
1102
- return "## Retrieved Sources\n\nNormal chat mode is active. No ECG evidence retrieval used."
1103
-
1104
- if not result:
1105
- return "## Retrieved Sources\n\nNo sources yet."
1106
-
1107
- docs = result.get("retrieved_docs", [])
1108
- best_score = result.get("best_score", -1.0)
1109
 
1110
- if not docs:
1111
- return (
1112
- "## Retrieved Sources\n\n"
1113
- "No sufficiently relevant evidence retrieved.\n\n"
1114
- f"**Best score:** `{best_score:.3f}`"
1115
- )
1116
 
1117
- lines = [
1118
- "## Retrieved Sources",
1119
- f"**Best score:** `{best_score:.3f}`",
1120
- "",
1121
- ]
1122
-
1123
- for i, d in enumerate(docs, 1):
1124
- question = d.metadata.get("question", "")
1125
- answer = d.metadata.get("answer", "")
1126
- similarity = d.metadata.get("sim_score", "N/A")
1127
- preview = answer[:210].strip()
1128
- if len(answer) > 210:
1129
- preview += "..."
1130
-
1131
- lines.extend([
1132
- f"### Evidence {i}",
1133
- f"- **Question:** {question}",
1134
- f"- **Similarity:** `{similarity}`",
1135
- f"- **Preview:** {preview}",
1136
- "",
1137
- ])
1138
 
1139
- return "\n".join(lines)
 
1140
 
 
 
1141
 
1142
- def format_debug_text(result: Optional[Dict], chat_mode: str = "ecg_rag") -> str:
1143
- if chat_mode == "normal_chat":
1144
- return "MODE: normal_chat\nNo retrieval/debug evidence used."
1145
 
1146
- if not result:
1147
- return "No debug result yet."
1148
 
1149
- return f"""
1150
- BEST SCORE: {result.get('best_score', -1.0)}
1151
- USED CONTEXT: {result.get('used_context', False)}
1152
- RETRIEVAL ATTEMPTS: {result.get('retrieval_attempts', 0)}
1153
- RETRIEVAL MODE: {result.get('retrieval_mode', 'N/A')}
1154
- VALIDATION STATUS: {result.get('validation_status', 'N/A')}
 
 
 
 
 
 
1155
 
1156
- ----- CONTEXT -----
1157
- {result.get('context', '')}
 
 
 
1158
 
1159
- ----- LOCAL REASONING DRAFT -----
1160
- {result.get('reasoning_draft', '')}
1161
- """.strip()
1162
 
1163
 
1164
- # -------------------------------
1165
  # UI HELPERS
1166
- # -------------------------------
1167
  CUSTOM_CSS = """
1168
- :root {
1169
- --bg-main: #07111f;
1170
- --bg-soft: #0b1728;
1171
- --card: rgba(10, 19, 35, 0.86);
1172
- --card-2: rgba(14, 25, 43, 0.94);
1173
- --border: rgba(148, 163, 184, 0.16);
1174
- --text: #e5eefb;
1175
- --muted: #94a3b8;
1176
- --primary: #7c3aed;
1177
- --primary-2: #2563eb;
1178
- --success: #10b981;
1179
- }
1180
-
1181
  html, body, .gradio-container {
1182
  margin: 0 !important;
1183
  padding: 0 !important;
1184
- min-height: 100%;
1185
- background:
1186
- radial-gradient(circle at top left, rgba(124,58,237,0.22), transparent 28%),
1187
- radial-gradient(circle at top right, rgba(37,99,235,0.18), transparent 24%),
1188
- linear-gradient(180deg, #050b16 0%, #091321 100%);
1189
- color: var(--text);
1190
  }
1191
-
1192
  .gradio-container {
1193
- max-width: 100% !important;
1194
- padding: 12px !important;
1195
- }
1196
-
1197
- footer {
1198
- visibility: hidden;
1199
  }
1200
-
1201
- .top-card {
1202
- border: 1px solid var(--border);
1203
- background: linear-gradient(135deg, rgba(11,23,40,0.95), rgba(18,31,56,0.92));
1204
- border-radius: 22px;
1205
  padding: 16px;
1206
  margin-bottom: 12px;
1207
- box-shadow: 0 14px 40px rgba(0,0,0,0.20);
1208
  }
1209
-
1210
- .hero-title {
1211
- font-size: 1.6rem;
1212
  font-weight: 800;
1213
- color: #f8fbff;
1214
  margin-bottom: 6px;
1215
- line-height: 1.15;
1216
  }
1217
-
1218
- .hero-subtitle {
1219
- color: #cbd5e1;
1220
  font-size: 0.95rem;
1221
- line-height: 1.5;
1222
- }
1223
-
1224
- .badges {
1225
- display: flex;
1226
- gap: 8px;
1227
- flex-wrap: wrap;
1228
- margin-top: 12px;
1229
- }
1230
-
1231
- .badge {
1232
- display: inline-flex;
1233
- align-items: center;
1234
- gap: 6px;
1235
- padding: 6px 10px;
1236
- border-radius: 999px;
1237
- font-size: 11px;
1238
- color: #e6eefc;
1239
- border: 1px solid rgba(255,255,255,0.12);
1240
- background: rgba(255,255,255,0.06);
1241
- }
1242
-
1243
- .panel-wrap {
1244
- border: 1px solid var(--border);
1245
- background: linear-gradient(180deg, rgba(10,19,35,0.96), rgba(7,14,26,0.94));
1246
- border-radius: 20px;
1247
- padding: 12px;
1248
- box-shadow: 0 16px 45px rgba(0,0,0,0.22);
1249
  }
1250
-
1251
  #chatbot {
1252
- height: min(62vh, 640px) !important;
1253
- min-height: 360px !important;
1254
  border-radius: 18px !important;
1255
- border: 1px solid var(--border) !important;
1256
- overflow: hidden !important;
1257
- box-shadow: 0 14px 40px rgba(0,0,0,0.26) !important;
1258
  }
1259
-
1260
- .status-card {
1261
- padding: 12px 14px;
1262
  border-radius: 16px;
1263
- background: linear-gradient(135deg, #0f172a 0%, #172554 100%);
1264
- color: #f9fafb;
1265
- font-size: 14px;
1266
- border: 1px solid rgba(255,255,255,0.12);
1267
- box-shadow: 0 10px 30px rgba(0,0,0,0.2);
1268
- }
1269
-
1270
- .muted {
1271
- color: #a5b4fc;
1272
- font-size: 12px;
1273
  }
1274
-
1275
- .blink-dots {
1276
- font-size: 22px;
1277
- font-weight: 800;
1278
  letter-spacing: 4px;
 
1279
  animation: blinkDots 1s steps(1, end) infinite;
1280
- display: inline-block;
1281
- padding: 2px 0;
1282
  }
1283
-
1284
  @keyframes blinkDots {
1285
  0% { opacity: 1; }
1286
- 50% { opacity: 0.15; }
1287
  100% { opacity: 1; }
1288
  }
1289
-
1290
  textarea, .gr-textbox textarea {
1291
- border-radius: 16px !important;
1292
- font-size: 15px !important;
1293
- }
1294
-
1295
- .gr-textbox label, .gr-markdown, .gr-button {
1296
- font-size: 14px !important;
1297
  }
1298
-
1299
  button {
1300
  border-radius: 14px !important;
1301
  min-height: 44px !important;
1302
  font-weight: 600 !important;
1303
  }
1304
-
1305
- .mobile-stack {
1306
- display: flex;
1307
- flex-direction: column;
1308
- gap: 12px;
1309
- }
1310
-
1311
- .mobile-scroll {
1312
- max-height: 34vh;
1313
- overflow-y: auto;
1314
- }
1315
-
1316
- .command-note {
1317
- color: #cbd5e1;
1318
- font-size: 0.88rem;
1319
- line-height: 1.45;
1320
- }
1321
-
1322
- .mode-note {
1323
- color: #cbd5e1;
1324
- font-size: 0.88rem;
1325
- margin-top: 6px;
1326
- }
1327
-
1328
- @media (max-width: 1024px) {
1329
- .gradio-container { padding: 10px !important; }
1330
- .hero-title { font-size: 1.45rem; }
1331
- .hero-subtitle { font-size: 0.92rem; }
1332
- #chatbot { height: 56vh !important; }
1333
- }
1334
-
1335
- @media (max-width: 768px) {
1336
- .gradio-container { padding: 8px !important; }
1337
- .top-card { padding: 14px; border-radius: 18px; }
1338
- .hero-title { font-size: 1.28rem; }
1339
- .hero-subtitle { font-size: 0.88rem; line-height: 1.45; }
1340
- .badge { font-size: 10px; padding: 5px 8px; }
1341
- .panel-wrap { padding: 10px; border-radius: 16px; }
1342
- #chatbot {
1343
- height: 52vh !important;
1344
- min-height: 320px !important;
1345
- border-radius: 16px !important;
1346
- }
1347
- button { width: 100% !important; }
1348
- .mobile-scroll { max-height: 240px; }
1349
- }
1350
-
1351
- @media (max-width: 480px) {
1352
- .hero-title { font-size: 1.15rem; }
1353
- .hero-subtitle { font-size: 0.83rem; }
1354
- #chatbot {
1355
- height: 50vh !important;
1356
- min-height: 300px !important;
1357
- }
1358
- textarea, .gr-textbox textarea { font-size: 14px !important; }
1359
- }
1360
  """
1361
 
1362
 
1363
- def hero_html() -> str:
1364
  return """
1365
- <div class="top-card">
1366
- <div class="hero-title">🫀 Mr Cardio</div>
1367
- <div class="hero-subtitle">
1368
- ECG and cardiology specialist chatbot with automatic mode switching,
1369
- evidence retrieval, local ECG reasoning, grounded summaries, and normal chat mode.
1370
- </div>
1371
- <div class="badges">
1372
- <div class="badge">ECG RAG</div>
1373
- <div class="badge">Normal Chat</div>
1374
- <div class="badge">FAISS Retrieval</div>
1375
- <div class="badge">LoRA Adapter</div>
1376
- <div class="badge">Validated Output</div>
1377
  </div>
1378
  </div>
1379
  """
1380
 
1381
 
1382
  def thinking_html(stage: str) -> str:
1383
- icon = "⏳"
1384
- subtitle = "Retrieval → reasoning → grounded answer"
1385
-
1386
- if "switch" in stage.lower() or "activating" in stage.lower() or "updating ui" in stage.lower():
1387
- icon = "⚡"
1388
- subtitle = "Updating mode and interface"
1389
-
1390
  return f"""
1391
- <div class="status-card">
1392
- <div style="display:flex;align-items:center;gap:12px;">
1393
- <div style="font-size:19px;">{icon}</div>
1394
- <div>
1395
- <div style="font-weight:700;">{stage}</div>
1396
- <div class="muted">{subtitle}</div>
1397
- <div class="blink-dots">...</div>
1398
- </div>
1399
- </div>
1400
  </div>
1401
  """
1402
 
1403
 
1404
- def initialize_session():
1405
- return {
1406
- "chat_history": [],
1407
- "last_result": None,
1408
- "chat_mode": "ecg_rag",
1409
- }
1410
-
1411
-
1412
- def add_assistant_placeholder(history, text="..."):
1413
  history = history or []
1414
- history.append({
1415
- "role": "assistant",
1416
- "content": text,
1417
- "metadata": {"title": "Thinking"}
1418
- })
1419
  return history
1420
 
1421
 
1422
- def update_last_assistant_message(history, text, title=None):
1423
  history = history or []
1424
  if not history or history[-1]["role"] != "assistant":
1425
- msg = {"role": "assistant", "content": text}
1426
- if title:
1427
- msg["metadata"] = {"title": title}
1428
- history.append(msg)
1429
  return history
1430
-
1431
- history[-1] = {"role": "assistant", "content": text}
1432
- if title:
1433
- history[-1]["metadata"] = {"title": title}
1434
  return history
1435
 
1436
 
1437
- def user_submit(user_message, chat_ui_history):
1438
- chat_ui_history = chat_ui_history or []
1439
  user_message = (user_message or "").strip()
1440
-
1441
  if not user_message:
1442
- return "", chat_ui_history
1443
-
1444
- chat_ui_history.append({"role": "user", "content": user_message})
1445
- return "", chat_ui_history
1446
 
1447
 
1448
- def set_chat_mode(mode_value: str, session_state: Dict):
1449
- if session_state is None:
1450
- session_state = initialize_session()
1451
- session_state["chat_mode"] = mode_value
1452
- return session_state
1453
-
1454
-
1455
- def get_mode_label(session_state: Dict) -> str:
1456
- mode = (session_state or {}).get("chat_mode", "ecg_rag")
1457
-
1458
- if mode == "normal_chat":
1459
- return """
1460
- <div class="mode-note">
1461
- <b>Mode:</b> Normal Chat
1462
- </div>
1463
- """
1464
-
1465
- return """
1466
- <div class="mode-note">
1467
- <b>Mode:</b> ECG &amp; Cardiology
1468
- <br>
1469
- <span style="color:#93c5fd;">Medical / ECG / Cardiology specialist mode active</span>
1470
- </div>
1471
- """
1472
-
1473
-
1474
- # -------------------------------
1475
- # CORE CHAT
1476
- # -------------------------------
1477
- def run_chat_turn(user_message: str, memory_state: Dict) -> Dict:
1478
- if memory_state is None:
1479
- memory_state = initialize_session()
1480
-
1481
- chat_mode = memory_state.get("chat_mode", "ecg_rag")
1482
-
1483
- if chat_mode == "normal_chat":
1484
- answer = run_normal_chat(
1485
- user_query=user_message,
1486
- chat_history=memory_state["chat_history"]
1487
- )
1488
-
1489
- result = {
1490
- "final_answer": answer,
1491
- "best_score": -1.0,
1492
- "used_context": False,
1493
- "validation_status": "NORMAL_CHAT_MODE",
1494
- "retrieved_docs": [],
1495
- "context": "",
1496
- "reasoning_draft": "",
1497
- "retrieval_attempts": 0,
1498
- "retrieval_mode": "none",
1499
- }
1500
- else:
1501
- state_in = {
1502
- "user_query": user_message,
1503
- "chat_history": memory_state["chat_history"],
1504
- "retrieval_attempts": 0,
1505
- }
1506
-
1507
- try:
1508
- result = graph.invoke(state_in)
1509
- except Exception as e:
1510
- logger.error(f"Graph invocation error: {e}")
1511
- traceback.print_exc()
1512
- result = {
1513
- "final_answer": f"I hit a runtime error while processing the request: {e}",
1514
- "best_score": -1.0,
1515
- "used_context": False,
1516
- "validation_status": "ERROR",
1517
- "retrieved_docs": [],
1518
- "context": "",
1519
- "reasoning_draft": "",
1520
- "retrieval_attempts": 0,
1521
- "retrieval_mode": "error",
1522
- }
1523
 
1524
- answer = result.get("final_answer", "").strip() or "I could not generate an answer."
1525
- best_score = result.get("best_score", -1.0)
1526
- validation_status = result.get("validation_status", "N/A")
1527
- confidence = compute_confidence(result) if chat_mode == "ecg_rag" else 1.0
1528
 
1529
- answer_with_footer = (
1530
- f"{answer}\n\n---\n"
1531
- f"📊 mode={chat_mode} | confidence={confidence:.2f} | best_score={best_score:.3f} | validation={validation_status}"
1532
- )
1533
 
1534
- memory_state["chat_history"].append({"role": "user", "content": user_message})
1535
- memory_state["chat_history"].append({"role": "assistant", "content": answer})
1536
- memory_state["chat_history"] = memory_state["chat_history"][-12:]
1537
- memory_state["last_result"] = result
1538
 
1539
- return {
1540
- "answer": answer_with_footer,
1541
- "memory_state": memory_state,
1542
- "sources_markdown": format_sources_minimal(result, chat_mode=chat_mode),
1543
- "debug_text": format_debug_text(result, chat_mode=chat_mode),
1544
- }
1545
 
 
 
 
 
 
1546
 
1547
- def bot_respond_stream(chat_ui_history, session_state):
1548
- global vectorstore
1549
 
 
 
 
 
1550
  if session_state is None:
1551
  session_state = initialize_session()
1552
 
1553
- if not chat_ui_history:
1554
- yield (
1555
- chat_ui_history,
1556
- session_state,
1557
- "## Retrieved Sources\n\nNo sources yet.",
1558
- "No debug result yet.",
1559
- "",
1560
- get_mode_label(session_state),
1561
- session_state.get("chat_mode", "ecg_rag"),
1562
- )
1563
- return
1564
-
1565
- user_message = str(chat_ui_history[-1]["content"]).strip()
1566
- chat_mode = session_state.get("chat_mode", "ecg_rag")
1567
-
1568
- # ---------------------------------
1569
- # AUTO MODE SWITCH
1570
- # ---------------------------------
1571
- requested_mode = detect_mode_switch_request(user_message)
1572
-
1573
- if requested_mode and requested_mode != chat_mode:
1574
- session_state["chat_mode"] = requested_mode
1575
-
1576
- chat_ui_history = add_assistant_placeholder(chat_ui_history, text="...")
1577
- yield (
1578
- chat_ui_history,
1579
- session_state,
1580
- format_sources_minimal(session_state.get("last_result"), chat_mode=requested_mode),
1581
- format_debug_text(session_state.get("last_result"), chat_mode=requested_mode),
1582
- thinking_html(
1583
- f"Switching to {'ECG & Cardiology Mode' if requested_mode == 'ecg_rag' else 'Normal Chat Mode'}"
1584
- ),
1585
- get_mode_label(session_state),
1586
- requested_mode,
1587
- )
1588
- time.sleep(cfg.blink_stage_1)
1589
-
1590
- yield (
1591
- chat_ui_history,
1592
- session_state,
1593
- format_sources_minimal(session_state.get("last_result"), chat_mode=requested_mode),
1594
- format_debug_text(session_state.get("last_result"), chat_mode=requested_mode),
1595
- thinking_html("Updating UI"),
1596
- get_mode_label(session_state),
1597
- requested_mode,
1598
- )
1599
- time.sleep(cfg.blink_stage_2)
1600
-
1601
- final_switch_text = mode_switch_message(requested_mode)
1602
-
1603
- if cfg.enable_typewriter_stream:
1604
- for partial in stream_text(final_switch_text, step=90):
1605
- chat_ui_history = update_last_assistant_message(
1606
- chat_ui_history,
1607
- partial,
1608
- title="Mode Update"
1609
- )
1610
- yield (
1611
- chat_ui_history,
1612
- session_state,
1613
- format_sources_minimal(session_state.get("last_result"), chat_mode=requested_mode),
1614
- format_debug_text(session_state.get("last_result"), chat_mode=requested_mode),
1615
- "",
1616
- get_mode_label(session_state),
1617
- requested_mode,
1618
- )
1619
-
1620
- chat_ui_history = update_last_assistant_message(
1621
- chat_ui_history,
1622
- final_switch_text,
1623
- title="Mode Update"
1624
- )
1625
-
1626
- session_state["chat_history"].append({"role": "user", "content": user_message})
1627
- session_state["chat_history"].append({"role": "assistant", "content": final_switch_text})
1628
- session_state["chat_history"] = session_state["chat_history"][-12:]
1629
-
1630
- yield (
1631
- chat_ui_history,
1632
- session_state,
1633
- format_sources_minimal(session_state.get("last_result"), chat_mode=requested_mode),
1634
- format_debug_text(session_state.get("last_result"), chat_mode=requested_mode),
1635
- "",
1636
- get_mode_label(session_state),
1637
- requested_mode,
1638
- )
1639
  return
1640
 
1641
- if user_message == "/sources":
1642
- result = session_state.get("last_result")
1643
- chat_ui_history.append({
1644
- "role": "assistant",
1645
- "content": format_sources_minimal(result, chat_mode=chat_mode),
1646
- "metadata": {"title": "Sources"}
1647
- })
1648
- yield (
1649
- chat_ui_history,
1650
- session_state,
1651
- format_sources_minimal(result, chat_mode=chat_mode),
1652
- format_debug_text(result, chat_mode=chat_mode),
1653
- "",
1654
- get_mode_label(session_state),
1655
- session_state.get("chat_mode", "ecg_rag"),
1656
- )
1657
- return
1658
-
1659
- if user_message == "/debug":
1660
- result = session_state.get("last_result")
1661
- chat_ui_history.append({
1662
- "role": "assistant",
1663
- "content": format_debug_text(result, chat_mode=chat_mode),
1664
- "metadata": {"title": "Debug"}
1665
- })
1666
- yield (
1667
- chat_ui_history,
1668
- session_state,
1669
- format_sources_minimal(result, chat_mode=chat_mode),
1670
- format_debug_text(result, chat_mode=chat_mode),
1671
- "",
1672
- get_mode_label(session_state),
1673
- session_state.get("chat_mode", "ecg_rag"),
1674
- )
1675
- return
1676
-
1677
- if user_message == "/rebuild":
1678
- if not cfg.allow_rebuild_vectorstore:
1679
- chat_ui_history.append({
1680
- "role": "assistant",
1681
- "content": "Vector store rebuild is disabled on this Space.",
1682
- "metadata": {"title": "Restricted"}
1683
- })
1684
- yield (
1685
- chat_ui_history,
1686
- session_state,
1687
- format_sources_minimal(session_state.get("last_result"), chat_mode=chat_mode),
1688
- format_debug_text(session_state.get("last_result"), chat_mode=chat_mode),
1689
- "",
1690
- get_mode_label(session_state),
1691
- session_state.get("chat_mode", "ecg_rag"),
1692
- )
1693
- return
1694
-
1695
- chat_ui_history = add_assistant_placeholder(chat_ui_history)
1696
- yield (
1697
- chat_ui_history,
1698
- session_state,
1699
- "",
1700
- "",
1701
- thinking_html("Rebuilding vector store"),
1702
- get_mode_label(session_state),
1703
- session_state.get("chat_mode", "ecg_rag"),
1704
- )
1705
-
1706
- time.sleep(cfg.blink_stage_1)
1707
-
1708
- chat_ui_history = update_last_assistant_message(
1709
- chat_ui_history,
1710
- "Rebuilding vector store and reloading embeddings...",
1711
- title="Maintenance"
1712
- )
1713
- yield (
1714
- chat_ui_history,
1715
- session_state,
1716
- "",
1717
- "",
1718
- thinking_html("Rebuilding vector store"),
1719
- get_mode_label(session_state),
1720
- session_state.get("chat_mode", "ecg_rag"),
1721
- )
1722
-
1723
- build_vectorstore()
1724
- vectorstore = load_vectorstore()
1725
 
1726
- chat_ui_history = update_last_assistant_message(
1727
- chat_ui_history,
1728
- "✅ Vector store rebuilt and reloaded.",
1729
- title="Done"
1730
- )
1731
- yield (
1732
- chat_ui_history,
1733
- session_state,
1734
- format_sources_minimal(session_state.get("last_result"), chat_mode=chat_mode),
1735
- format_debug_text(session_state.get("last_result"), chat_mode=chat_mode),
1736
- "",
1737
- get_mode_label(session_state),
1738
- session_state.get("chat_mode", "ecg_rag"),
1739
- )
1740
- return
1741
 
1742
- chat_ui_history = add_assistant_placeholder(chat_ui_history, text="...")
1743
- yield (
1744
- chat_ui_history,
1745
- session_state,
1746
- "",
1747
- "",
1748
- thinking_html("Starting"),
1749
- get_mode_label(session_state),
1750
- session_state.get("chat_mode", "ecg_rag"),
1751
- )
1752
- time.sleep(cfg.blink_stage_1)
1753
 
1754
- if chat_mode == "normal_chat":
1755
- yield (
1756
- chat_ui_history,
1757
- session_state,
1758
- "",
1759
- "",
1760
- thinking_html("Generating normal chat reply"),
1761
- get_mode_label(session_state),
1762
- session_state.get("chat_mode", "ecg_rag"),
1763
- )
1764
- time.sleep(cfg.blink_stage_2)
1765
  else:
1766
- yield (
1767
- chat_ui_history,
1768
- session_state,
1769
- "",
1770
- "",
1771
- thinking_html("Retrieving evidence"),
1772
- get_mode_label(session_state),
1773
- session_state.get("chat_mode", "ecg_rag"),
1774
- )
1775
- time.sleep(cfg.blink_stage_2)
1776
-
1777
- yield (
1778
- chat_ui_history,
1779
- session_state,
1780
- "",
1781
- "",
1782
- thinking_html("Running ECG adapter reasoning"),
1783
- get_mode_label(session_state),
1784
- session_state.get("chat_mode", "ecg_rag"),
1785
- )
1786
- time.sleep(cfg.blink_stage_3)
1787
-
1788
- out = run_chat_turn(user_message, session_state)
1789
-
1790
- yield (
1791
- chat_ui_history,
1792
- session_state,
1793
- out["sources_markdown"],
1794
- out["debug_text"],
1795
- thinking_html("Generating grounded summary" if chat_mode == "ecg_rag" else "Finishing reply"),
1796
- get_mode_label(session_state),
1797
- session_state.get("chat_mode", "ecg_rag"),
1798
- )
1799
- time.sleep(cfg.blink_before_answer)
1800
 
1801
  if cfg.enable_typewriter_stream:
1802
- for partial in stream_text(out["answer"], step=120):
1803
- chat_ui_history = update_last_assistant_message(
1804
- chat_ui_history,
1805
- partial,
1806
- title="Answer"
1807
- )
1808
- yield (
1809
- chat_ui_history,
1810
- session_state,
1811
- out["sources_markdown"],
1812
- out["debug_text"],
1813
- "",
1814
- get_mode_label(session_state),
1815
- session_state.get("chat_mode", "ecg_rag"),
1816
- )
1817
-
1818
- chat_ui_history = update_last_assistant_message(
1819
- chat_ui_history,
1820
- out["answer"],
1821
- title="Answer"
1822
- )
1823
-
1824
- yield (
1825
- chat_ui_history,
1826
- out["memory_state"],
1827
- out["sources_markdown"],
1828
- out["debug_text"],
1829
- "",
1830
- get_mode_label(out["memory_state"]),
1831
- out["memory_state"].get("chat_mode", "ecg_rag"),
1832
- )
1833
-
1834
-
1835
- def clear_chat():
1836
- st = initialize_session()
1837
- return (
1838
- [],
1839
- st,
1840
- "## Retrieved Sources\n\nNo sources yet.",
1841
- "No debug result yet.",
1842
- "",
1843
- get_mode_label(st),
1844
- st.get("chat_mode", "ecg_rag"),
1845
- )
1846
-
1847
-
1848
- def rebuild_from_button(session_state, chatbot_history):
1849
- global vectorstore
1850
 
1851
- if session_state is None:
1852
- session_state = initialize_session()
1853
 
1854
- chat_mode = session_state.get("chat_mode", "ecg_rag")
1855
 
1856
- if not cfg.allow_rebuild_vectorstore:
1857
- chatbot_history = chatbot_history or []
1858
- chatbot_history.append({
1859
- "role": "assistant",
1860
- "content": "Vector store rebuild is disabled on this Space.",
1861
- "metadata": {"title": "Restricted"}
1862
- })
1863
- return (
1864
- chatbot_history,
1865
- session_state,
1866
- format_sources_minimal(session_state.get("last_result"), chat_mode=chat_mode),
1867
- format_debug_text(session_state.get("last_result"), chat_mode=chat_mode),
1868
- "",
1869
- get_mode_label(session_state),
1870
- session_state.get("chat_mode", "ecg_rag"),
1871
- )
1872
-
1873
- build_vectorstore()
1874
- vectorstore = load_vectorstore()
1875
-
1876
- chatbot_history = chatbot_history or []
1877
- chatbot_history.append({
1878
- "role": "assistant",
1879
- "content": "✅ Vector store rebuilt and reloaded.",
1880
- "metadata": {"title": "Done"}
1881
- })
1882
-
1883
- return (
1884
- chatbot_history,
1885
- session_state,
1886
- format_sources_minimal(session_state.get("last_result"), chat_mode=chat_mode),
1887
- format_debug_text(session_state.get("last_result"), chat_mode=chat_mode),
1888
- "",
1889
- get_mode_label(session_state),
1890
- session_state.get("chat_mode", "ecg_rag"),
1891
- )
1892
-
1893
-
1894
- # -------------------------------
1895
  # APP
1896
- # -------------------------------
1897
- with gr.Blocks(
1898
- title="Medical CSV RAG Chatbot",
1899
- css=CUSTOM_CSS,
1900
- theme=gr.themes.Soft(
1901
- primary_hue="indigo",
1902
- secondary_hue="blue",
1903
- neutral_hue="slate",
1904
- radius_size="lg",
1905
- text_size="md",
1906
- ),
1907
- ) as demo:
1908
-
1909
- gr.HTML(hero_html())
1910
 
1911
  session_state = gr.State(initialize_session())
1912
 
1913
- with gr.Column(elem_classes=["mobile-stack"]):
1914
- with gr.Group(elem_classes=["panel-wrap"]):
1915
- mode_selector = gr.Radio(
1916
- choices=[
1917
- ("ECG RAG Mode", "ecg_rag"),
1918
- ("Normal Chat Mode", "normal_chat"),
1919
- ],
1920
- value="ecg_rag",
1921
- label="Chat Mode",
1922
- interactive=True,
1923
- )
1924
-
1925
- mode_status = gr.HTML(get_mode_label(initialize_session()))
1926
-
1927
- chatbot = gr.Chatbot(
1928
- label="Clinical Chat",
1929
- height=640,
1930
- elem_id="chatbot",
1931
- type="messages",
1932
- show_copy_button=True,
1933
- bubble_full_width=False,
1934
- avatar_images=(None, None),
1935
- )
1936
-
1937
- user_box = gr.Textbox(
1938
- label="Ask a question",
1939
- placeholder="e.g. What are the ECG findings in hyperkalemia? or type 'switch to ECG mode'",
1940
- lines=2,
1941
- autofocus=True,
1942
- )
1943
 
1944
- status_html = gr.HTML("")
 
 
 
 
 
1945
 
1946
- with gr.Row():
1947
- send_btn = gr.Button("Send", variant="primary")
1948
- clear_btn = gr.Button("Clear")
1949
- rebuild_btn = gr.Button("Rebuild Store")
1950
 
1951
- gr.HTML(
1952
- """
1953
- <div class="command-note">
1954
- Commands: <code>/sources</code>, <code>/debug</code>, <code>/rebuild</code>
1955
- </div>
1956
- """
1957
- )
1958
 
1959
- with gr.Accordion("Retrieved Sources", open=False):
1960
- with gr.Group(elem_classes=["panel-wrap", "mobile-scroll"]):
1961
- sources_panel = gr.Markdown("## Retrieved Sources\n\nNo sources yet.")
1962
-
1963
- if cfg.show_debug_panel:
1964
- with gr.Accordion("Debug Panel", open=False):
1965
- with gr.Group(elem_classes=["panel-wrap", "mobile-scroll"]):
1966
- debug_panel = gr.Textbox(
1967
- label="Debug",
1968
- value="No debug result yet.",
1969
- lines=18,
1970
- max_lines=28,
1971
- interactive=False,
1972
- )
1973
- else:
1974
- debug_panel = gr.Textbox(visible=False, value="")
1975
 
1976
- mode_selector.change(
1977
- fn=set_chat_mode,
1978
- inputs=[mode_selector, session_state],
1979
- outputs=[session_state],
1980
- queue=False,
1981
- ).then(
1982
- fn=get_mode_label,
1983
- inputs=[session_state],
1984
- outputs=[mode_status],
1985
- queue=False,
1986
- )
1987
 
1988
  submit_event = user_box.submit(
1989
  fn=user_submit,
@@ -1991,41 +1126,33 @@ with gr.Blocks(
1991
  outputs=[user_box, chatbot],
1992
  queue=True,
1993
  )
1994
-
1995
  submit_event.then(
1996
  fn=bot_respond_stream,
1997
  inputs=[chatbot, session_state],
1998
- outputs=[chatbot, session_state, sources_panel, debug_panel, status_html, mode_status, mode_selector],
1999
  queue=True,
2000
  )
2001
 
2002
- send_click = send_btn.click(
2003
  fn=user_submit,
2004
  inputs=[user_box, chatbot],
2005
  outputs=[user_box, chatbot],
2006
  queue=True,
2007
  )
2008
-
2009
- send_click.then(
2010
  fn=bot_respond_stream,
2011
  inputs=[chatbot, session_state],
2012
- outputs=[chatbot, session_state, sources_panel, debug_panel, status_html, mode_status, mode_selector],
2013
  queue=True,
2014
  )
2015
 
2016
  clear_btn.click(
2017
  fn=clear_chat,
2018
  inputs=[],
2019
- outputs=[chatbot, session_state, sources_panel, debug_panel, status_html, mode_status, mode_selector],
2020
  queue=False,
2021
  )
2022
 
2023
- rebuild_btn.click(
2024
- fn=rebuild_from_button,
2025
- inputs=[session_state, chatbot],
2026
- outputs=[chatbot, session_state, sources_panel, debug_panel, status_html, mode_status, mode_selector],
2027
- queue=True,
2028
- )
2029
 
2030
  demo.queue(default_concurrency_limit=1)
2031
 
@@ -2034,4 +1161,4 @@ if __name__ == "__main__":
2034
  debug=cfg.launch_debug,
2035
  server_name=cfg.server_name,
2036
  server_port=cfg.server_port,
2037
- )
 
1
  import os
2
  import re
 
 
 
 
3
  import time
4
+ import json
5
+ import queue
6
  import logging
7
+ import threading
8
+ import traceback
9
+ from typing import List, Dict, TypedDict, Optional, Tuple
10
  from dataclasses import dataclass, field
11
 
12
  import torch
 
20
  from langchain_huggingface import HuggingFaceEmbeddings
21
  from langchain_community.vectorstores import FAISS
22
  from langchain_openai import ChatOpenAI
23
+
24
 
25
  # ============================================================
26
+ # AGENTIC ECG CHATBOT
27
+ # - Starts as normal chatbot
28
+ # - Detects ECG / cardiology intent automatically
29
+ # - Retrieves from CSV RAG store only for ECG questions
30
+ # - Runs local ECG adapter reasoning
31
+ # - Runs remote evidence summarizer
32
+ # - Runs remote clinical-composer agent
33
+ # - Merges both into a final long answer
34
+ # - Simple UI with Send / Clear
35
+ # - Visible thinking status + progress logs
36
  # ============================================================
37
 
38
+ raw_omp = str(os.getenv("OMP_NUM_THREADS", "1")).strip()
39
+ os.environ["OMP_NUM_THREADS"] = raw_omp if re.fullmatch(r"\d+", raw_omp) else "1"
40
+
41
+
42
+ # ============================================================
43
  # LOGGING
44
+ # ============================================================
45
  logging.basicConfig(
46
  level=logging.INFO,
47
+ format="%(asctime)s | %(levelname)s | %(message)s"
48
  )
49
+ logger = logging.getLogger("agentic_ecg_chatbot")
50
 
51
 
52
+ # ============================================================
53
  # CONFIG
54
+ # ============================================================
55
  @dataclass
56
  class Config:
57
+ base_model_path: str = os.getenv("BASE_MODEL_PATH", "meta-llama/Llama-3.1-8B-Instruct")
58
+ adapter_dir: str = os.getenv("ADAPTER_DIR", "adapter_refined_v10")
59
+ data_csv: str = os.getenv("DATA_CSV", "RAGmaterials/ECG_RAG_only_clean.csv")
60
+ rag_dir: str = os.getenv("RAG_DIR", "RAGmaterials")
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  vectorstore_dir: str = field(init=False)
62
 
63
  hf_token: str = os.getenv("HF_TOKEN", "")
 
65
  deepseek_base_url: str = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
66
  deepseek_model: str = os.getenv("DEEPSEEK_MODEL", "deepseek-chat")
67
 
68
+ embed_model_name: str = os.getenv("EMBED_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2")
 
69
 
70
+ similarity_k: int = int(os.getenv("SIMILARITY_K", "10"))
 
 
 
 
 
71
  top_k_final: int = int(os.getenv("TOP_K_FINAL", "4"))
72
+ max_context_chars: int = int(os.getenv("MAX_CONTEXT_CHARS", "5500"))
73
 
74
  max_input_len: int = int(os.getenv("MAX_INPUT_LEN", "4096"))
75
+ max_new_tokens_local: int = int(os.getenv("MAX_NEW_TOKENS_LOCAL", "220"))
76
  max_chat_history_turns: int = int(os.getenv("MAX_CHAT_HISTORY_TURNS", "6"))
77
 
78
+ min_lexical_overlap: float = float(os.getenv("MIN_LEXICAL_OVERLAP", "0.06"))
79
+ min_faiss_similarity: float = float(os.getenv("MIN_FAISS_SIMILARITY", "0.18"))
 
 
80
 
81
+ deepseek_temperature: float = float(os.getenv("DEEPSEEK_TEMPERATURE", "0.15"))
82
+ deepseek_max_tokens: int = int(os.getenv("DEEPSEEK_MAX_TOKENS", "900"))
83
+
84
+ use_4bit: bool = os.getenv("USE_4BIT", "true").lower() == "true"
85
  enable_query_expansion: bool = os.getenv("ENABLE_QUERY_EXPANSION", "true").lower() == "true"
 
86
  enable_typewriter_stream: bool = os.getenv("ENABLE_TYPEWRITER_STREAM", "true").lower() == "true"
87
+ enable_warmup: bool = os.getenv("ENABLE_WARMUP", "true").lower() == "true"
88
  allow_rebuild_vectorstore: bool = os.getenv("ALLOW_REBUILD_VECTORSTORE", "false").lower() == "true"
89
 
 
 
90
  launch_debug: bool = os.getenv("LAUNCH_DEBUG", "false").lower() == "true"
91
  server_name: str = os.getenv("SERVER_NAME", "0.0.0.0")
92
  server_port: int = int(os.getenv("SERVER_PORT", "7860"))
93
 
 
 
 
 
 
94
  def __post_init__(self):
95
  self.vectorstore_dir = os.path.join(self.rag_dir, "faiss_store")
96
  os.makedirs(self.rag_dir, exist_ok=True)
97
 
98
  if not self.deepseek_api_key:
99
+ raise ValueError("Missing DEEPSEEK_API_KEY in environment / Space secrets.")
100
 
101
  if not self.hf_token:
102
+ raise ValueError("Missing HF_TOKEN in environment / Space secrets.")
 
 
103
 
104
  for path, name in [
105
  (self.adapter_dir, "Adapter directory"),
 
113
  logger.info("Configuration loaded.")
114
 
115
 
116
+ # ============================================================
117
  # PROMPTS
118
+ # ============================================================
119
+ INTENT_CLASSIFIER_SYSTEM = """
120
+ You classify user messages.
121
+
122
+ Return only one label:
123
+ - ECG_RAG
124
+ - NORMAL_CHAT
125
 
126
+ Choose ECG_RAG if the message is about ECG, EKG, cardiology, arrhythmia, heart rhythm, cardiac conduction,
127
+ ST changes, QRS, PR, QT, tachycardia, bradycardia, atrial fibrillation, flutter, bundle branch block,
128
+ heart block, hyperkalemia ECG changes, or similar cardiology interpretation.
129
+ Otherwise return NORMAL_CHAT.
130
+ """.strip()
131
 
132
+ QUERY_EXPANSION_SYSTEM = """
133
+ You expand ECG and cardiology retrieval queries.
134
  Rules:
135
+ 1. Preserve the exact user intent.
136
+ 2. Add close cardiology / ECG synonyms and alternate wording.
137
+ 3. Do not answer the question.
138
+ 4. Output only the expanded retrieval query.
139
+ """.strip()
140
+
141
+ LOCAL_REASONING_SYSTEM = """
142
+ You are a strict ECG and cardiology reasoning assistant.
143
+ You are not the final answer generator.
144
+ Use only the evidence provided.
145
+ Do not invent facts.
146
+
147
+ Output exactly in this format:
148
 
149
  KEY_FINDINGS:
150
  - ...
 
161
  LIMITS:
162
  - ...
163
 
164
+ If evidence is insufficient, output exactly:
165
  INSUFFICIENT_EVIDENCE
166
  """.strip()
167
 
168
+ RAG_SUMMARY_SYSTEM = """
169
+ You are a clinical evidence summarizer.
170
+ Write a well-structured answer grounded only in the provided evidence and reasoning draft.
171
+ Do not use outside knowledge.
172
+ Be accurate, conservative, and clinically clear.
 
 
 
 
 
173
 
174
+ Output format:
175
+ ### Summary
176
+ 4 to 7 full sentences.
177
 
178
+ ### Key Evidence Points
179
+ 4 to 6 bullet points.
 
180
 
181
+ ### Clinical Interpretation
182
+ 2 to 4 bullet points if supported.
183
 
184
+ ### Evidence Limits
185
+ State what is not established.
 
 
 
186
 
187
+ If the evidence is too weak, output exactly:
 
 
 
 
 
 
188
  INSUFFICIENT_EVIDENCE
189
+ """.strip()
190
 
191
+ CLINICAL_COMPOSER_SYSTEM = """
192
+ You are a second medical composition agent.
193
+ Your job is to produce a longer, polished explanation from the same evidence and the same user question.
194
+ You must stay faithful to the evidence.
195
+ Do not add unsupported facts.
196
+ Do not mention tools, prompts, or pipelines.
 
 
 
197
 
198
+ Output format:
199
+ ### Direct Answer
200
+ A direct answer in 2 to 3 sentences.
201
 
202
+ ### Expanded Explanation
203
+ A longer explanation in 5 to 8 sentences.
 
 
 
 
 
204
 
205
+ ### Important Notes
206
+ 3 to 5 bullet points.
207
+
208
+ ### Remaining Uncertainty
209
+ State what the evidence does not prove.
210
+
211
+ If the evidence is too weak, output exactly:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  INSUFFICIENT_EVIDENCE
213
  """.strip()
214
 
215
+ FINAL_MERGER_SYSTEM = """
216
+ You are the final answer agent.
217
+ You will receive:
218
+ 1. the user's question
219
+ 2. retrieved evidence
220
+ 3. a local ECG adapter reasoning draft
221
+ 4. summary agent output
222
+ 5. clinical composer output
223
 
224
+ Write one final long-form answer.
225
  Rules:
226
+ - Use only supported information.
227
+ - Merge overlapping ideas cleanly.
228
+ - Do not repeat the same point too many times.
229
+ - Make the answer helpful, detailed, and readable.
230
+ - Do not mention internal agents or processing steps.
 
 
 
231
 
232
+ Output format:
233
+ ### Final Answer
234
+ A detailed answer in 6 to 10 sentences.
235
 
236
+ ### Key Points
237
+ 4 to 6 bullets.
 
 
 
 
 
238
 
239
+ ### Clinical Perspective
240
+ 2 to 4 bullets if supported.
241
+
242
+ ### Limits
243
+ A short honest limitations section.
244
+
245
+ If evidence is weak, output exactly:
246
+ INSUFFICIENT_EVIDENCE
247
  """.strip()
248
 
249
+ NORMAL_CHAT_SYSTEM = """
250
+ You are a helpful, friendly chatbot.
251
+ Be conversational, clear, and useful.
252
+ Answer normally.
253
+ Do not mention hidden tools or internal systems.
254
+ """.strip()
255
 
256
+
257
+ # ============================================================
258
  # HELPERS
259
+ # ============================================================
260
  def clean_text(x: str) -> str:
261
  x = str(x).replace("\x00", " ").strip()
262
  x = re.sub(r"\s+", " ", x)
 
265
 
266
def strip_bad_sections(txt: str) -> str:
    """Remove URLs from model output and trim surrounding whitespace.

    The input is coerced to ``str`` first, so non-string values are accepted.
    Anything starting with ``http://``, ``https://`` or ``www.`` up to the
    next whitespace is deleted; whitespace left in the middle is kept.
    """
    trimmed = str(txt).strip()
    without_links = re.sub(r"https?://\S+|www\.\S+", "", trimmed)
    return without_links.strip()
270
 
 
272
def infer_tags(question: str, answer: str) -> List[str]:
    """Infer coarse topic tags for a question/answer pair by keyword lookup.

    The question and answer are joined and lower-cased, then each tag's
    keyword list is checked against the text. Tags are returned in the
    order they are declared in the keyword map.

    Keywords are anchored at the start of a word so that "treat" still
    matches "treatment" but no longer matches "retreat". One- and two-letter
    abbreviations ("pr", "qt") must additionally be whole words; with plain
    substring matching "pr" fired on "prevent", "problem", "improve", etc.,
    which tagged almost every row as ECG. "qtc" is listed explicitly so the
    corrected QT interval is still recognised.

    Args:
        question: Source question text.
        answer: Source answer text.

    Returns:
        List of matching tag names (possibly empty).
    """
    text = f"{question} {answer}".lower()
    keyword_map = {
        "ecg": [
            "ecg", "ekg", "qrs", "pr", "qt", "qtc", "st elevation", "t wave",
            "arrhythmia", "tachycardia", "bradycardia",
        ],
        "diagnosis": ["diagnosis", "diagnose", "criteria"],
        "treatment": ["treat", "therapy", "management", "drug"],
        "symptoms": ["symptom", "sign", "presentation"],
        "etiology": ["cause", "caused by", "associated with", "risk factor"],
    }

    def mentions(keyword: str) -> bool:
        # Very short abbreviations need a closing boundary as well, otherwise
        # they behave like prefixes of unrelated words.
        closing = r"\b" if len(keyword) <= 2 else ""
        return re.search(rf"\b{re.escape(keyword)}{closing}", text) is not None

    return [
        tag
        for tag, words in keyword_map.items()
        if any(mentions(w) for w in words)
    ]
286
 
287
 
 
305
  return len(q_words & t_words) / max(1, len(q_words))
306
 
307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
def history_to_text(chat_history: List[Dict[str, str]], max_turns: Optional[int] = None) -> str:
    """Render the most recent chat messages as ``ROLE: content`` lines.

    Args:
        chat_history: Messages, each a dict with ``role`` and ``content``.
        max_turns: Number of trailing messages to keep; a falsy value falls
            back to ``cfg.max_chat_history_turns``.

    Returns:
        Newline-joined transcript, or ``"[EMPTY]"`` when there is nothing
        to render.
    """
    limit = max_turns if max_turns else cfg.max_chat_history_turns
    recent = chat_history[-limit:]
    if not recent:
        return "[EMPTY]"

    rendered = (f"{msg['role'].upper()}: {msg['content']}" for msg in recent)
    return "\n".join(rendered).strip()
314
 
315
 
316
def build_context_string(docs: List[Document], max_chars: Optional[int] = None) -> str:
    """Format retrieved documents into delimited evidence blocks for prompting.

    Blocks are emitted in the order given and numbered from 1. Formatting
    stops at the first block that would push the running total past the
    character budget, so later (lower-ranked) documents are dropped.

    Args:
        docs: Documents whose ``metadata`` carries ``id``, ``question``,
            ``answer``, ``tags`` and optionally ``sim_score``.
        max_chars: Character budget; a falsy value falls back to
            ``cfg.max_context_chars``.

    Returns:
        Evidence blocks separated by blank lines, or ``""`` if none fit.
    """
    budget = max_chars if max_chars else cfg.max_context_chars
    separator = "=============================="
    rendered: List[str] = []
    used = 0

    for evidence_id, doc in enumerate(docs, start=1):
        meta = doc.metadata
        tag_text = ", ".join(meta.get("tags", [])) or "N/A"
        block = "\n".join([
            separator,
            f"EVIDENCE_ID: {evidence_id}",
            f"SOURCE_ID: {meta.get('id')}",
            f"SOURCE_QUESTION: {meta.get('question', '')}",
            f"SOURCE_TAGS: {tag_text}",
            f"SIMILARITY: {meta.get('sim_score', 'N/A')}",
            "EVIDENCE_TEXT:",
            f"{meta.get('answer', '')}",
            separator,
        ]).strip()

        if used + len(block) > budget:
            break

        rendered.append(block)
        # +2 accounts for the blank-line separator between blocks.
        used += len(block) + 2

    return "\n\n".join(rendered).strip()
341
 
342
 
343
def stream_text(text: str, step: int = 120):
    """Yield progressively longer prefixes of ``text`` for a typewriter effect.

    Each yielded value extends the previous one by up to ``step`` characters;
    the final value is the full text. Nothing is yielded for empty text.
    """
    for start in range(0, len(text), step):
        yield text[:start + step]
348
 
349
 
350
+ # ============================================================
351
+ # PROGRESS / LOGGING
352
+ # ============================================================
353
def new_progress_state() -> Dict:
    """Create an empty progress log: a dict holding a fresh ``lines`` list."""
    return dict(lines=[])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
 
356
 
357
def add_progress(progress_state: Dict, msg: str):
    """Log ``msg`` and append it, time-stamped, to the progress log.

    The log is capped to its 80 most recent lines. ``progress_state`` is
    modified in place; nothing is returned.
    """
    stamp = time.strftime('%H:%M:%S')
    logger.info(msg)
    history = progress_state["lines"]
    history.append(f"[{stamp}] {msg}")
    progress_state["lines"] = history[-80:]
362
 
 
 
 
363
 
364
def progress_text(progress_state: Dict) -> str:
    """Return the progress log as newline-joined text, or a placeholder if empty."""
    entries = progress_state.get("lines", [])
    if not entries:
        return "No progress yet."
    return "\n".join(entries)
367
 
 
 
 
 
 
368
 
369
+ # ============================================================
370
+ # ECG QUERY DETECTION
371
+ # ============================================================
372
# Word-bounded patterns whose presence marks a message as ECG/cardiology
# related; they are matched against lower-cased text.
ECG_REGEXES = [
    r"\becg\b", r"\bekg\b", r"\bcardiology\b", r"\barrhythmia\b", r"\bheart rhythm\b",
    r"\batrial fibrillation\b", r"\bafib\b", r"\bflutter\b", r"\bqrs\b", r"\bpr interval\b",
    r"\bqt\b", r"\bst elevation\b", r"\bst depression\b", r"\bt wave\b", r"\bbradycardia\b",
    r"\btachycardia\b", r"\bheart block\b", r"\bbundle branch block\b", r"\bhyperkalemia\b",
]


def detect_ecg_by_rules(text: str) -> bool:
    """Report whether ``text`` mentions any term listed in ``ECG_REGEXES``.

    ``None`` and other falsy inputs are treated as the empty string.
    """
    lowered = str(text or "").lower().strip()
    for pattern in ECG_REGEXES:
        if re.search(pattern, lowered):
            return True
    return False
 
 
 
 
 
 
 
 
 
383
 
384
 
385
+ # ============================================================
386
  # EMBEDDINGS + VECTORSTORE
387
+ # ============================================================
388
  logger.info("Loading embeddings...")
389
  embeddings = HuggingFaceEmbeddings(
390
  model_name=cfg.embed_model_name,
391
  model_kwargs={
392
  "device": "cuda" if torch.cuda.is_available() else "cpu",
393
+ "token": cfg.hf_token if cfg.hf_token else None,
394
  },
395
  encode_kwargs={"normalize_embeddings": True},
396
  )
 
421
  "question": q,
422
  "answer": a,
423
  "tags": infer_tags(q, a),
424
+ },
425
  )
426
  )
427
 
 
446
  logger.info("Vectorstore ready.")
447
 
448
 
449
+ # ============================================================
450
+ # MODEL LOADING
451
+ # ============================================================
452
  logger.info("Loading tokenizer...")
453
  tokenizer = AutoTokenizer.from_pretrained(
454
  cfg.base_model_path,
455
  use_fast=True,
456
+ token=cfg.hf_token if cfg.hf_token else None,
457
  )
 
458
  if tokenizer.pad_token is None:
459
  tokenizer.pad_token = tokenizer.eos_token
460
 
 
495
 
496
  base_model.eval()
497
 
498
+ logger.info("Loading ECG adapter...")
499
  reason_model = PeftModel.from_pretrained(base_model, cfg.adapter_dir)
500
  reason_model.eval()
501
 
502
+ logger.info("Loading remote LLM client...")
503
+ remote_llm = ChatOpenAI(
504
+ model=cfg.deepseek_model,
505
+ api_key=cfg.deepseek_api_key,
506
+ base_url=cfg.deepseek_base_url,
507
+ temperature=cfg.deepseek_temperature,
508
+ max_tokens=cfg.deepseek_max_tokens,
509
+ )
510
+
511
 
512
  def get_primary_model_device(model) -> torch.device:
513
  try:
 
516
  return torch.device("cuda" if torch.cuda.is_available() else "cpu")
517
 
518
 
519
+ # ============================================================
520
+ # LLM CALLS
521
+ # ============================================================
522
def llm_text(system_prompt: str, user_prompt: str, fallback: str = "INSUFFICIENT_EVIDENCE") -> str:
    """Send a system/user prompt pair to the remote LLM and return cleaned text.

    Args:
        system_prompt: Instructions for the model.
        user_prompt: The user-side message.
        fallback: Value returned when the call fails or the cleaned reply
            is blank.

    Returns:
        The reply after ``strip_bad_sections``, or ``fallback``.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    try:
        reply = remote_llm.invoke(conversation)
        raw = reply.content if hasattr(reply, "content") else str(reply)
        cleaned = strip_bad_sections(raw)
        if cleaned.strip():
            return cleaned
        return fallback
    except Exception as e:
        # Best-effort boundary: any failure degrades to the fallback text.
        logger.error(f"Remote LLM error: {e}")
        traceback.print_exc()
        return fallback
535
+
536
+
537
def classify_intent(user_query: str) -> str:
    """Classify a message as ``"ECG_RAG"`` or ``"NORMAL_CHAT"``.

    The keyword rules are tried first; only when they do not match is the
    remote LLM classifier consulted. Any LLM reply that does not contain
    ``ECG_RAG`` (including the failure fallback) maps to ``"NORMAL_CHAT"``.
    """
    if detect_ecg_by_rules(user_query):
        return "ECG_RAG"

    verdict = llm_text(
        INTENT_CLASSIFIER_SYSTEM,
        f"USER_MESSAGE:\n{user_query}",
        fallback="NORMAL_CHAT",
    )
    label = verdict.strip().upper()
    if "ECG_RAG" in label:
        return "ECG_RAG"
    return "NORMAL_CHAT"
547
+
548
+
549
def run_query_expansion(user_query: str) -> str:
    """Ask the remote LLM for a synonym-expanded retrieval query.

    Returns ``user_query`` unchanged when expansion is disabled in ``cfg``
    or when the LLM yields nothing; otherwise the stripped expansion.
    """
    if cfg.enable_query_expansion:
        prompt = f"USER_QUERY:\n{user_query}\n\nExpand this for ECG/cardiology retrieval."
        expanded = llm_text(QUERY_EXPANSION_SYSTEM, prompt, fallback=user_query)
        if expanded:
            return expanded.strip()
    return user_query
555
+
556
+
557
  @torch.inference_mode()
558
  def run_local_reasoner(user_query: str, context: str) -> str:
559
  try:
560
  messages = [
561
  {"role": "system", "content": LOCAL_REASONING_SYSTEM},
562
+ {"role": "user", "content": f"QUESTION:\n{user_query}\n\nEVIDENCE:\n{context or '[EMPTY]'}"},
 
 
 
563
  ]
564
 
565
  prompt = tokenizer.apply_chat_template(
 
591
 
592
  gen_ids = out[0, inputs["input_ids"].shape[1]:]
593
  text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
594
+ return strip_bad_sections(text) or "INSUFFICIENT_EVIDENCE"
 
 
 
595
  except Exception as e:
596
  logger.error(f"Local reasoner error: {e}")
597
  traceback.print_exc()
598
  return "INSUFFICIENT_EVIDENCE"
599
 
600
 
601
def run_rag_summary(user_query: str, context: str, reasoning_draft: str, chat_history: List[Dict[str, str]]) -> str:
    """Run the evidence-summarizer agent over the retrieved context.

    Builds a prompt from the chat history, the question, the retrieved
    evidence and the local reasoning draft (blank sections are shown as
    ``[EMPTY]``) and sends it with ``RAG_SUMMARY_SYSTEM``.

    Returns:
        The summarizer's text, or ``"INSUFFICIENT_EVIDENCE"`` on failure.
    """
    evidence = context if context.strip() else '[EMPTY]'
    draft = reasoning_draft if reasoning_draft.strip() else '[EMPTY]'
    sections = [
        f"CHAT_HISTORY:\n{history_to_text(chat_history)}",
        f"USER_QUESTION:\n{user_query}",
        f"RETRIEVED_EVIDENCE:\n{evidence}",
        f"LOCAL_REASONING_DRAFT:\n{draft}",
    ]
    prompt = "\n\n".join(sections).strip()
    return llm_text(RAG_SUMMARY_SYSTEM, prompt, fallback="INSUFFICIENT_EVIDENCE")
616
 
617
 
618
+ def run_clinical_composer(user_query: str, context: str, reasoning_draft: str, chat_history: List[Dict[str, str]]) -> str:
 
 
 
 
 
619
  prompt = f"""
620
  CHAT_HISTORY:
621
  {history_to_text(chat_history)}
 
628
 
629
  LOCAL_REASONING_DRAFT:
630
  {reasoning_draft if reasoning_draft.strip() else '[EMPTY]'}
 
 
631
  """.strip()
632
+ return llm_text(CLINICAL_COMPOSER_SYSTEM, prompt, fallback="INSUFFICIENT_EVIDENCE")
633
 
 
 
 
 
 
 
 
 
 
 
634
 
635
def run_final_merger(user_query: str, context: str, reasoning_draft: str, summary_a: str, summary_b: str) -> str:
    """Run the final-answer agent that merges both intermediate write-ups.

    The prompt carries the question, the retrieved evidence, the local
    reasoning draft, the summary-agent output and the clinical-composer
    output; blank sections are shown as ``[EMPTY]``.

    Returns:
        The merged answer, or ``"INSUFFICIENT_EVIDENCE"`` on failure.
    """
    def or_empty(value: str) -> str:
        return value if value.strip() else '[EMPTY]'

    sections = [
        f"USER_QUESTION:\n{user_query}",
        f"RETRIEVED_EVIDENCE:\n{or_empty(context)}",
        f"LOCAL_ECG_REASONING:\n{or_empty(reasoning_draft)}",
        f"SUMMARY_AGENT_OUTPUT:\n{or_empty(summary_a)}",
        f"CLINICAL_COMPOSER_OUTPUT:\n{or_empty(summary_b)}",
    ]
    prompt = "\n\n".join(sections).strip()
    return llm_text(FINAL_MERGER_SYSTEM, prompt, fallback="INSUFFICIENT_EVIDENCE")
653
 
654
 
655
  def run_normal_chat(user_query: str, chat_history: List[Dict[str, str]]) -> str:
 
659
 
660
  USER_MESSAGE:
661
  {user_query}
 
 
662
  """.strip()
663
+ return llm_text(NORMAL_CHAT_SYSTEM, prompt, fallback="Sorry, I could not generate a response.")
 
 
 
 
 
664
 
665
 
666
+ # ============================================================
667
  # WARMUP
668
+ # ============================================================
669
  def warmup_models():
670
  logger.info("Warming up local reasoner...")
671
  try:
 
676
  EVIDENCE_ID: 1
677
  SOURCE_QUESTION: What are ECG findings in hyperkalemia?
678
  SOURCE_TAGS: ecg
679
+ SIMILARITY: 0.9
680
  EVIDENCE_TEXT:
681
  Hyperkalemia may cause peaked T waves, PR prolongation, QRS widening, and severe conduction abnormalities.
682
  ==============================
 
687
  logger.warning(f"Warmup failed: {e}")
688
 
689
 
690
+ if cfg.enable_warmup:
691
+ warmup_models()
692
 
693
 
694
+ # ============================================================
695
  # STATE
696
+ # ============================================================
697
class AgentState(TypedDict, total=False):
    """Typed shape of one agentic turn; every key is optional (``total=False``).

    NOTE(review): the field names mirror the keys produced during a turn
    (query, retrieval, intermediate agent outputs, final answer) — confirm
    which callers still consume this type.
    """

    # Input for the turn.
    user_query: str
    chat_history: List[Dict[str, str]]

    # Routing / query rewriting.
    detected_mode: str
    expanded_query: str

    # Retrieval outputs.
    retrieved_docs: List[Document]
    best_score: float
    context: str

    # Intermediate and final generation outputs.
    local_reasoning: str
    summary_agent: str
    composer_agent: str
    final_answer: str
 
712
 
713
 
714
+ # ============================================================
715
  # RETRIEVAL
716
+ # ============================================================
717
def rerank_docs(query: str, docs: List[Document], top_n: Optional[int] = None) -> List[Document]:
    """Re-order candidate documents by a blended lexical/vector score.

    The score is the fraction of query words found in the document's
    question/answer/tags, plus 0.20 if any query word occurs in the question
    text, plus 0.10 if any occurs in the tags, plus 0.35 times the stored
    ``sim_score`` metadata (0.0 when absent).

    Args:
        query: The user's query.
        docs: Candidate documents.
        top_n: How many to keep; a falsy value falls back to
            ``cfg.top_k_final``.

    Returns:
        The highest-scoring documents, best first; ties keep input order.
    """
    limit = top_n if top_n else cfg.top_k_final
    query_terms = set(re.findall(r"\w+", query.lower()))

    def blended_score(doc: Document) -> float:
        question = doc.metadata.get("question", "")
        answer = doc.metadata.get("answer", "")
        tags = " ".join(doc.metadata.get("tags", []))
        doc_terms = set(re.findall(r"\w+", f"{question} {answer} {tags}".lower()))

        total = len(query_terms & doc_terms) / max(1, len(query_terms))
        question_lower = question.lower()
        if any(term in question_lower for term in query_terms):
            total += 0.20
        tags_lower = tags.lower()
        if any(term in tags_lower for term in query_terms):
            total += 0.10
        total += 0.35 * float(doc.metadata.get("sim_score", 0.0))
        return total

    ranked = sorted(docs, key=blended_score, reverse=True)
    return ranked[:limit]
737
+
738
+
739
+ def retrieve_docs_once(query_for_search: str, original_query: str) -> Tuple[List[Document], float]:
740
  try:
741
+ scored = vectorstore.similarity_search_with_score(query_for_search, k=cfg.similarity_k)
 
 
 
742
  except Exception as e:
743
  logger.error(f"Retriever error: {e}")
744
  traceback.print_exc()
 
748
  return [], -1.0
749
 
750
  filtered_docs = []
 
 
751
  for doc, raw_score in scored:
752
  sim = score_to_similarity(raw_score)
 
 
753
  q = doc.metadata.get("question", "")
754
  a = doc.metadata.get("answer", "")
755
  ov = lexical_overlap(original_query, f"{q} {a}")
756
 
757
+ if sim >= 0.45 or (ov >= cfg.min_lexical_overlap and sim >= cfg.min_faiss_similarity):
758
  new_doc = Document(page_content=doc.page_content, metadata=dict(doc.metadata))
759
  new_doc.metadata["sim_score"] = sim
760
  new_doc.metadata["lexical_overlap"] = ov
761
  filtered_docs.append(new_doc)
762
 
763
  reranked = rerank_docs(original_query, filtered_docs, top_n=cfg.top_k_final)
764
+ best_score = max((float(d.metadata.get("sim_score", -1.0)) for d in reranked), default=-1.0)
765
  return reranked, best_score
766
 
767
 
768
def retrieve_docs(query: str) -> Tuple[List[Document], float, str]:
    """Retrieve evidence for ``query``, optionally widening with an expanded query.

    The raw query is always searched. When query expansion is enabled, the
    expanded query is searched too and both result sets are merged,
    de-duplicated and re-ranked against the original query.

    Args:
        query: The user's question.

    Returns:
        ``(documents, best_score, query_used)`` where ``query_used`` is the
        expanded query, or the original query when no expansion took place.
    """
    docs_a, score_a = retrieve_docs_once(query, query)
    if not cfg.enable_query_expansion:
        return docs_a, score_a, query

    expanded = run_query_expansion(query)
    if expanded == query:
        # Expansion fell back to the raw query (e.g. the remote call failed);
        # a second identical search would only repeat the first one.
        return docs_a, score_a, query

    docs_b, score_b = retrieve_docs_once(expanded, query)

    merged = []
    seen_keys = set()
    for d in docs_a + docs_b:
        doc_id = d.metadata.get("id")
        # Documents without an id must not all collapse onto the single key
        # ``None``; fall back to their text to tell them apart.
        key = ("id", doc_id) if doc_id is not None else ("content", d.page_content)
        if key not in seen_keys:
            seen_keys.add(key)
            merged.append(d)

    merged = rerank_docs(query, merged, top_n=cfg.top_k_final)
    best_score = max(score_a, score_b)
    return merged, best_score, expanded
 
 
 
 
 
787
 
788
 
789
+ # ============================================================
790
+ # CORE AGENTIC PIPELINE
791
+ # ============================================================
792
def initialize_session() -> Dict:
    """Create a fresh per-user session: empty history, no result, empty progress log."""
    session: Dict = {}
    session["chat_history"] = []
    session["last_result"] = None
    session["progress"] = new_progress_state()
    return session
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
798
 
 
 
799
 
800
def _first_usable_answer(*candidates: str) -> str:
    """Return the first candidate that is neither blank nor the INSUFFICIENT_EVIDENCE sentinel."""
    for candidate in candidates:
        text = (candidate or "").strip()
        if text and text != "INSUFFICIENT_EVIDENCE":
            return candidate
    return "INSUFFICIENT_EVIDENCE"


def _turn_result(
    mode: str,
    final_answer: str,
    progress: Dict,
    docs: Optional[List[Document]] = None,
    best_score: float = -1.0,
    context: str = "",
    local_reasoning: str = "",
    summary_agent: str = "",
    composer_agent: str = "",
) -> Dict:
    """Assemble the result dict for one turn, snapshotting the progress log as text."""
    return {
        "mode": mode,
        "final_answer": final_answer,
        "retrieved_docs": docs if docs is not None else [],
        "best_score": best_score,
        "context": context,
        "local_reasoning": local_reasoning,
        "summary_agent": summary_agent,
        "composer_agent": composer_agent,
        "progress_text": progress_text(progress),
    }


def run_agentic_turn(user_query: str, session_state: Dict) -> Dict:
    """Process one user message end to end and update the session.

    The message is classified first. Normal-chat messages get a direct
    reply. ECG messages go through retrieval, the local adapter reasoner,
    the summary agent, the clinical composer and the final merger; if the
    merger produces nothing usable, the summary and then the composer
    output are tried before giving up with ``INSUFFICIENT_EVIDENCE``.

    Args:
        user_query: The user's message.
        session_state: Session dict (see ``initialize_session``); ``None``
            starts a new session. It is updated in place.

    Returns:
        ``{"result": <turn result dict>, "session_state": <session dict>}``.
    """
    if session_state is None:
        session_state = initialize_session()

    progress = new_progress_state()
    add_progress(progress, "User message received")

    # setdefault keeps the read below and the append at the end on the same
    # list even when the incoming state lacks the key.
    chat_history = session_state.setdefault("chat_history", [])

    add_progress(progress, "Detecting query type")
    mode = classify_intent(user_query)
    add_progress(progress, f"Detected mode: {mode}")

    if mode == "NORMAL_CHAT":
        add_progress(progress, "Running normal chat response")
        answer = run_normal_chat(user_query, chat_history)
        result = _turn_result("normal_chat", answer, progress)
    else:
        add_progress(progress, "Running ECG retrieval")
        docs, best_score, expanded_query = retrieve_docs(user_query)

        add_progress(progress, f"Retrieved {len(docs)} document(s)")
        add_progress(progress, f"Best score: {best_score:.3f}")
        add_progress(progress, f"Expanded query: {expanded_query}")

        context = build_context_string(docs)

        if not context.strip():
            add_progress(progress, "No strong ECG evidence found")
            answer = "I could not find sufficiently relevant ECG evidence in the CSV knowledge base for this question."
            result = _turn_result(
                "ecg_rag", answer, progress,
                docs=docs, best_score=best_score, context=context,
            )
        else:
            add_progress(progress, "Running local ECG adapter reasoning")
            local_reasoning = run_local_reasoner(user_query, context)

            add_progress(progress, "Running summary agent")
            summary_agent = run_rag_summary(user_query, context, local_reasoning, chat_history)

            add_progress(progress, "Running clinical composer agent")
            composer_agent = run_clinical_composer(user_query, context, local_reasoning, chat_history)

            add_progress(progress, "Running final merger agent")
            final_answer = run_final_merger(user_query, context, local_reasoning, summary_agent, composer_agent)

            # Prefer the merged answer, then the summary, then the composer.
            final_answer = _first_usable_answer(final_answer, summary_agent, composer_agent)

            add_progress(progress, "Final answer ready")
            result = _turn_result(
                "ecg_rag", final_answer, progress,
                docs=docs, best_score=best_score, context=context,
                local_reasoning=local_reasoning,
                summary_agent=summary_agent,
                composer_agent=composer_agent,
            )

    chat_history.append({"role": "user", "content": user_query})
    chat_history.append({"role": "assistant", "content": result["final_answer"]})
    # Keep only the 12 most recent messages in the session.
    session_state["chat_history"] = chat_history[-12:]
    session_state["last_result"] = result
    session_state["progress"] = progress

    return {"result": result, "session_state": session_state}
 
 
887
 
888
 
889
+ # ============================================================
890
  # UI HELPERS
891
+ # ============================================================
892
  CUSTOM_CSS = """
 
 
 
 
 
 
 
 
 
 
 
 
 
893
  html, body, .gradio-container {
894
  margin: 0 !important;
895
  padding: 0 !important;
896
+ background: #0b1220;
897
+ color: #e5e7eb;
 
 
 
 
898
  }
 
899
  .gradio-container {
900
+ max-width: 900px !important;
901
+ margin: 0 auto !important;
902
+ padding: 16px !important;
 
 
 
903
  }
904
+ .simple-card {
905
+ border: 1px solid rgba(255,255,255,0.08);
906
+ background: #111827;
907
+ border-radius: 18px;
 
908
  padding: 16px;
909
  margin-bottom: 12px;
 
910
  }
911
+ .app-title {
912
+ font-size: 1.4rem;
 
913
  font-weight: 800;
914
+ color: #f9fafb;
915
  margin-bottom: 6px;
 
916
  }
917
+ .app-subtitle {
 
 
918
  font-size: 0.95rem;
919
+ color: #cbd5e1;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
920
  }
 
921
  #chatbot {
922
+ min-height: 60vh !important;
 
923
  border-radius: 18px !important;
 
 
 
924
  }
925
+ .status-box {
926
+ border: 1px solid rgba(255,255,255,0.08);
927
+ background: linear-gradient(180deg, #111827 0%, #172033 100%);
928
  border-radius: 16px;
929
+ padding: 12px 14px;
930
+ color: #f3f4f6;
 
 
 
 
 
 
 
 
931
  }
932
+ .thinking-dots {
933
+ display: inline-block;
 
 
934
  letter-spacing: 4px;
935
+ font-weight: 800;
936
  animation: blinkDots 1s steps(1, end) infinite;
 
 
937
  }
 
938
  @keyframes blinkDots {
939
  0% { opacity: 1; }
940
+ 50% { opacity: 0.2; }
941
  100% { opacity: 1; }
942
  }
 
943
  textarea, .gr-textbox textarea {
944
+ border-radius: 14px !important;
 
 
 
 
 
945
  }
 
946
  button {
947
  border-radius: 14px !important;
948
  min-height: 44px !important;
949
  font-weight: 600 !important;
950
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
951
  """
952
 
953
 
954
def header_html() -> str:
    """Return the static HTML card (title and subtitle) shown at the top of the app."""
    return """
    <div class="simple-card">
        <div class="app-title">🫀 Agentic ECG Chatbot</div>
        <div class="app-subtitle">
            Starts as normal chat. If the question is ECG/cardiology-related, it automatically switches into ECG evidence mode,
            retrieves from your CSV knowledge base, runs local ECG adapter reasoning, builds two summaries, and merges them into one long final answer.
        </div>
    </div>
    """
964
 
965
 
966
def thinking_html(stage: str) -> str:
    """Return the HTML status box announcing the current pipeline stage.

    ``stage`` is interpolated into the markup as-is (no HTML escaping), so
    it should be a trusted, application-generated label.
    """
    return f"""
    <div class="status-box">
        <b>{stage}</b><br>
        Model is thinking <span class="thinking-dots">...</span>
    </div>
    """
973
 
974
 
975
def add_assistant_placeholder(history, text="Thinking..."):
    """Append a temporary assistant message titled "Thinking" and return the history.

    A non-empty ``history`` list is extended in place; ``None`` or an empty
    list yields a new list.
    """
    messages = history or []
    placeholder = {"role": "assistant", "content": text, "metadata": {"title": "Thinking"}}
    messages.append(placeholder)
    return messages
979
 
980
 
981
def update_last_assistant_message(history, text, title="Answer"):
    """Replace the trailing assistant message, or append one if there is none.

    Used to overwrite the "Thinking" placeholder with real content. The
    history list is returned after modification.
    """
    messages = history or []
    entry = {"role": "assistant", "content": text, "metadata": {"title": title}}
    if messages and messages[-1]["role"] == "assistant":
        messages[-1] = entry
    else:
        messages.append(entry)
    return messages
988
 
989
 
990
def user_submit(user_message, chat_history):
    """Move the textbox content into the chat history.

    Returns ``("", history)``: the empty string clears the textbox, and the
    history gains a user message unless the input was blank.
    """
    history = chat_history or []
    message = (user_message or "").strip()
    if message:
        history.append({"role": "user", "content": message})
    return "", history
 
997
 
998
 
999
def _source_field_text(value) -> str:
    """Return a metadata value as display text; ``None`` and NaN become ``""``."""
    if value is None:
        return ""
    if isinstance(value, float) and value != value:  # NaN is the only float != itself
        return ""
    return str(value)


def format_sources(result: Optional[Dict]) -> str:
    """Render the retrieved evidence of a turn as text for the sources panel.

    Args:
        result: Result dict of the last turn, or ``None``. Reads the keys
            ``retrieved_docs`` (objects with a ``metadata`` mapping) and
            ``best_score``.

    Returns:
        ``"No sources yet."`` for an empty/missing result,
        ``"No ECG retrieval used for the last answer."`` when no documents
        were retrieved, otherwise a "Best score" header followed by one
        "Evidence N" section per document with its question, similarity
        and an answer preview of at most 220 characters.
    """
    if not result:
        return "No sources yet."

    # `or []` also covers an explicit None stored under the key.
    docs = result.get("retrieved_docs") or []
    if not docs:
        return "No ECG retrieval used for the last answer."

    # A missing key keeps the -1.0 sentinel; a None or non-numeric score
    # shows "N/A" instead of raising in the format spec.
    try:
        best_label = f"{float(result.get('best_score', -1.0)):.3f}"
    except (TypeError, ValueError):
        best_label = "N/A"

    lines = [f"Best score: {best_label}", ""]
    for i, d in enumerate(docs, 1):
        metadata = getattr(d, "metadata", None) or {}
        question = _source_field_text(metadata.get("question", ""))
        # Coerce to text first: slicing/len would fail on None or NaN values.
        answer = _source_field_text(metadata.get("answer", ""))
        sim = metadata.get("sim_score", "N/A")
        preview = answer[:220] + ("..." if len(answer) > 220 else "")
        lines += [
            f"Evidence {i}",
            f"- Question: {question}",
            f"- Similarity: {sim}",
            f"- Preview: {preview}",
            "",
        ]
    return "\n".join(lines).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1019
 
 
 
 
 
1020
 
1021
def clear_chat():
    """Reset the chat UI to its initial state.

    Returns:
        A 5-tuple matching the Clear button outputs: an empty history, a
        fresh session from ``initialize_session()``, an empty status string,
        and the default progress and sources texts.
    """
    fresh_session = initialize_session()
    empty_history = []
    return empty_history, fresh_session, "", "No progress yet.", "No sources yet."
 
1024
 
 
 
 
 
1025
 
1026
def rebuild_store(session_state, chat_history):
    """Rebuild and reload the vector store, reporting the outcome in chat.

    When ``cfg.allow_rebuild_vectorstore`` is falsy nothing is rebuilt and a
    "disabled" notice is appended instead. Otherwise ``build_vectorstore()``
    runs and the module-level ``vectorstore`` is replaced with the result of
    ``load_vectorstore()``.

    Args:
        session_state: Session dict; its ``progress`` and ``last_result``
            entries are read to refresh the progress and sources panels.
        chat_history: List of message dicts; a falsy value is replaced by a
            new empty list.

    Returns:
        ``(chat_history, session_state, status, progress_text, sources_text)``.
    """
    global vectorstore

    def _reply(content, title):
        # Append the status message, then assemble the five UI outputs.
        messages = chat_history or []
        messages.append({"role": "assistant", "content": content, "metadata": {"title": title}})
        progress = progress_text(session_state.get("progress", new_progress_state()))
        sources = format_sources(session_state.get("last_result"))
        return messages, session_state, "", progress, sources

    if not cfg.allow_rebuild_vectorstore:
        return _reply("Vector store rebuild is disabled.", "Restricted")

    build_vectorstore()
    vectorstore = load_vectorstore()
    return _reply("✅ Vector store rebuilt.", "Done")
1038
 
 
 
1039
 
1040
+ # ============================================================
1041
+ # STREAMING RESPONSE
1042
+ # ============================================================
1043
def bot_respond_stream(chat_history, session_state):
    """Generate the assistant reply for the latest user message, streaming UI updates.

    Each yielded value is a 5-tuple
    ``(chat_history, session_state, status_html, progress_text, sources_text)``
    matching the outputs wired to this handler.

    Flow: append a "Thinking..." placeholder, show staged status messages
    (with short pauses), classify the message with ``classify_intent``, run
    ``run_agentic_turn``, then replace the placeholder with the answer
    (optionally revealed in chunks when ``cfg.enable_typewriter_stream`` is set).

    Args:
        chat_history: List of message dicts (``role`` / ``content``).
        session_state: Session dict, or ``None`` to start a fresh session.

    Yields:
        UI update tuples as described above.
    """
    if session_state is None:
        session_state = initialize_session()

    if not chat_history:
        yield chat_history, session_state, "", "No progress yet.", "No sources yet."
        return

    # An empty submit leaves the history untouched, but this handler is still
    # chained after it. Without this guard the previous assistant reply would
    # be picked up as the "user message" and answered again.
    if chat_history[-1].get("role") != "user":
        # Same panel refresh as rebuild_store; assumes the session stores
        # "progress" / "last_result" (falls back to defaults otherwise).
        yield (
            chat_history,
            session_state,
            "",
            progress_text(session_state.get("progress", new_progress_state())),
            format_sources(session_state.get("last_result")),
        )
        return

    user_message = str(chat_history[-1]["content"]).strip()

    chat_history = add_assistant_placeholder(chat_history, "Thinking...")
    yield chat_history, session_state, thinking_html("Understanding your message"), "Starting...", ""
    time.sleep(0.4)

    yield chat_history, session_state, thinking_html("Detecting whether this is normal chat or ECG reasoning"), "Detecting intent...", ""
    time.sleep(0.4)

    try:
        detected = classify_intent(user_message)
        if detected == "NORMAL_CHAT":
            yield chat_history, session_state, thinking_html("Normal chatbot mode active"), "Running normal chat...", ""
            time.sleep(0.4)
        else:
            yield chat_history, session_state, thinking_html("ECG mode detected: retrieving evidence"), "Retrieving ECG evidence...", ""
            time.sleep(0.45)
            yield chat_history, session_state, thinking_html("Running local ECG adapter reasoning"), "Running local reasoning...", ""
            time.sleep(0.45)
            yield chat_history, session_state, thinking_html("Generating multiple summaries and composing final answer"), "Generating final answer...", ""
            time.sleep(0.45)

        out = run_agentic_turn(user_message, session_state)
        result = out["result"]
        updated_session = out["session_state"]
    except Exception:
        # Top-level UI boundary: log the failure and replace the placeholder
        # so the chat is not left showing "Thinking..." forever.
        logging.getLogger(__name__).exception("Failed to generate a reply")
        chat_history = update_last_assistant_message(
            chat_history,
            "Sorry, something went wrong while generating the answer. Please try again.",
            title="Error",
        )
        yield chat_history, session_state, "", "Failed while generating the answer.", ""
        return

    answer = result.get("final_answer", "I could not generate an answer.")
    sources = format_sources(result)
    prog = result.get("progress_text", "No progress yet.")

    if cfg.enable_typewriter_stream:
        for partial in stream_text(answer, step=140):
            chat_history = update_last_assistant_message(chat_history, partial, title="Answer")
            yield chat_history, updated_session, "", prog, sources

    # Always finish with the complete answer, whether or not it was streamed.
    chat_history = update_last_assistant_message(chat_history, answer, title="Answer")
    yield chat_history, updated_session, "", prog, sources
1086
 
 
1087
 
1088
+ # ============================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1089
  # APP
1090
+ # ============================================================
1091
with gr.Blocks(title="Agentic ECG Chatbot", css=CUSTOM_CSS) as demo:
    # Banner at the top of the page.
    gr.HTML(header_html())

    # Session object passed into and returned from the handlers.
    session_state = gr.State(initialize_session())

    chatbot = gr.Chatbot(
        label="Chat",
        elem_id="chatbot",
        type="messages",
        show_copy_button=True,
        bubble_full_width=False,
    )

    user_box = gr.Textbox(
        label="Message",
        placeholder="Ask anything. ECG / cardiology questions are detected automatically.",
        lines=2,
        autofocus=True,
    )

    # Status box filled by bot_respond_stream while a reply is in progress.
    status_html = gr.HTML("")

    with gr.Row():
        send_btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.Button("Clear")

    with gr.Accordion("Progress Log", open=False):
        progress_panel = gr.Textbox(value="No progress yet.", lines=16, interactive=False)

    with gr.Accordion("Retrieved ECG Sources", open=False):
        sources_panel = gr.Textbox(value="No sources yet.", lines=16, interactive=False)

    # Component lists shared by the Enter-key and Submit-button wiring.
    submit_io = [user_box, chatbot]
    stream_inputs = [chatbot, session_state]
    panel_outputs = [chatbot, session_state, status_html, progress_panel, sources_panel]

    # Enter in the textbox: record the user message, then stream the reply.
    submit_event = user_box.submit(
        fn=user_submit,
        inputs=submit_io,
        outputs=submit_io,
        queue=True,
    )
    submit_event.then(
        fn=bot_respond_stream,
        inputs=stream_inputs,
        outputs=panel_outputs,
        queue=True,
    )

    # Submit button: same two-step chain as the textbox.
    send_event = send_btn.click(
        fn=user_submit,
        inputs=submit_io,
        outputs=submit_io,
        queue=True,
    )
    send_event.then(
        fn=bot_respond_stream,
        inputs=stream_inputs,
        outputs=panel_outputs,
        queue=True,
    )

    # Clear button: reset history, session and all panels.
    clear_btn.click(
        fn=clear_chat,
        inputs=[],
        outputs=panel_outputs,
        queue=False,
    )

demo.queue(default_concurrency_limit=1)
1158
 
 
1161
  debug=cfg.launch_debug,
1162
  server_name=cfg.server_name,
1163
  server_port=cfg.server_port,
1164
+ )