resberry commited on
Commit
9c86989
·
verified ·
1 Parent(s): c04022e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2022 -54
app.py CHANGED
@@ -1,69 +1,2037 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
-
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
14
- """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
- """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
 
19
- messages = [{"role": "system", "content": system_message}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- messages.extend(history)
 
22
 
23
- messages.append({"role": "user", "content": message})
 
 
 
 
24
 
25
- response = ""
 
 
 
 
 
 
 
 
 
 
26
 
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
31
- temperature=temperature,
32
- top_p=top_p,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
38
 
39
- response += token
40
- yield response
41
 
42
 
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- with gr.Blocks() as demo:
63
- with gr.Sidebar():
64
- gr.LoginButton()
65
- chatbot.render()
 
 
66
 
 
67
 
68
  if __name__ == "__main__":
69
- demo.launch()
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ raw_omp = str(os.getenv("OMP_NUM_THREADS", "1")).strip()
5
+ os.environ["OMP_NUM_THREADS"] = raw_omp if re.fullmatch(r"\d+", raw_omp) else "1"
6
+
7
+ import time
8
+ import traceback
9
+ import logging
10
+ from typing import List, Dict, TypedDict, Optional
11
+ from dataclasses import dataclass, field
12
+
13
+ import torch
14
+ import pandas as pd
15
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
18
+ from peft import PeftModel
19
+
20
+ from langchain_core.documents import Document
21
+ from langchain_huggingface import HuggingFaceEmbeddings
22
+ from langchain_community.vectorstores import FAISS
23
+ from langchain_openai import ChatOpenAIa
24
+ from langgraph.graph import StateGraph, START, END
25
+
26
+ # ============================================================
27
+ # HUGGING FACE SPACES READY
28
+ # Medical CSV RAG Chatbot + Normal Chat Mode
29
+ # Modes:
30
+ # 1) ECG RAG Mode -> retrieval -> local ECG reasoning -> grounded summary
31
+ # 2) Normal Chat Mode -> standard chatbot response
32
+ # Extra:
33
+ # 3) Automatic ECG/Cardiology mode switching from user text
34
+ # ============================================================
35
+
36
+ # -------------------------------
37
+ # LOGGING
38
+ # -------------------------------
39
+ logging.basicConfig(
40
+ level=logging.INFO,
41
+ format="%(asctime)s - %(levelname)s - %(message)s"
42
+ )
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
+ # -------------------------------
47
+ # CONFIG
48
+ # -------------------------------
49
+ @dataclass
50
+ class Config:
51
+ base_model_path: str = os.getenv(
52
+ "BASE_MODEL_PATH",
53
+ "meta-llama/Llama-3.1-8B-Instruct"
54
+ )
55
+
56
+ adapter_dir: str = os.getenv(
57
+ "ADAPTER_DIR",
58
+ "adapter_refined_v10"
59
+ )
60
+ data_csv: str = os.getenv(
61
+ "DATA_CSV",
62
+ "RAGmaterials/ECG_RAG_only_clean.csv"
63
+ )
64
+ rag_dir: str = os.getenv(
65
+ "RAG_DIR",
66
+ "RAGmaterials"
67
+ )
68
+ vectorstore_dir: str = field(init=False)
69
+
70
+ hf_token: str = os.getenv("HF_TOKEN", "")
71
+ deepseek_api_key: str = os.getenv("DEEPSEEK_API_KEY", "")
72
+ deepseek_base_url: str = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
73
+ deepseek_model: str = os.getenv("DEEPSEEK_MODEL", "deepseek-chat")
74
+
75
+ deepseek_temperature: float = float(os.getenv("DEEPSEEK_TEMPERATURE", "0.1"))
76
+ deepseek_max_tokens: int = int(os.getenv("DEEPSEEK_MAX_TOKENS", "700"))
77
+
78
+ embed_model_name: str = os.getenv(
79
+ "EMBED_MODEL_NAME",
80
+ "sentence-transformers/all-MiniLM-L6-v2"
81
+ )
82
+
83
+ similarity_k: int = int(os.getenv("SIMILARITY_K", "12"))
84
+ top_k_final: int = int(os.getenv("TOP_K_FINAL", "4"))
85
+ max_context_chars: int = int(os.getenv("MAX_CONTEXT_CHARS", "5200"))
86
+
87
+ max_input_len: int = int(os.getenv("MAX_INPUT_LEN", "4096"))
88
+ max_new_tokens_local: int = int(os.getenv("MAX_NEW_TOKENS_LOCAL", "180"))
89
+ max_chat_history_turns: int = int(os.getenv("MAX_CHAT_HISTORY_TURNS", "6"))
90
+
91
+ min_lexical_overlap: float = float(os.getenv("MIN_LEXICAL_OVERLAP", "0.08"))
92
+ min_faiss_similarity: float = float(os.getenv("MIN_FAISS_SIMILARITY", "0.20"))
93
+ strong_retrieval_threshold: float = float(os.getenv("STRONG_RETRIEVAL_THRESHOLD", "0.30"))
94
+ strong_retrieval_min_docs: int = int(os.getenv("STRONG_RETRIEVAL_MIN_DOCS", "3"))
95
+
96
+ use_query_cache: bool = os.getenv("USE_QUERY_CACHE", "true").lower() == "true"
97
+ enable_query_expansion: bool = os.getenv("ENABLE_QUERY_EXPANSION", "true").lower() == "true"
98
+ enable_validator: bool = os.getenv("ENABLE_VALIDATOR", "true").lower() == "true"
99
+ enable_typewriter_stream: bool = os.getenv("ENABLE_TYPEWRITER_STREAM", "true").lower() == "true"
100
+ show_debug_panel: bool = os.getenv("SHOW_DEBUG_PANEL", "true").lower() == "true"
101
+ allow_rebuild_vectorstore: bool = os.getenv("ALLOW_REBUILD_VECTORSTORE", "false").lower() == "true"
102
+
103
+ use_4bit: bool = os.getenv("USE_4BIT", "true").lower() == "true"
104
+
105
+ launch_debug: bool = os.getenv("LAUNCH_DEBUG", "false").lower() == "true"
106
+ server_name: str = os.getenv("SERVER_NAME", "0.0.0.0")
107
+ server_port: int = int(os.getenv("SERVER_PORT", "7860"))
108
+
109
+ blink_stage_1: float = float(os.getenv("BLINK_STAGE_1", "0.40"))
110
+ blink_stage_2: float = float(os.getenv("BLINK_STAGE_2", "0.55"))
111
+ blink_stage_3: float = float(os.getenv("BLINK_STAGE_3", "0.50"))
112
+ blink_before_answer: float = float(os.getenv("BLINK_BEFORE_ANSWER", "0.25"))
113
+
114
+ def __post_init__(self):
115
+ self.vectorstore_dir = os.path.join(self.rag_dir, "faiss_store")
116
+ os.makedirs(self.rag_dir, exist_ok=True)
117
+
118
+ if not self.deepseek_api_key:
119
+ raise ValueError("Missing DEEPSEEK_API_KEY. Add it in Hugging Face Space Secrets.")
120
+
121
+ if not self.hf_token:
122
+ raise ValueError(
123
+ "Missing HF_TOKEN. Add a valid Hugging Face token with access to the gated base model."
124
+ )
125
+
126
+ for path, name in [
127
+ (self.adapter_dir, "Adapter directory"),
128
+ (self.data_csv, "CSV data"),
129
+ ]:
130
+ if not os.path.exists(path):
131
+ raise FileNotFoundError(f"{name} not found at: {path}")
132
+
133
+
134
+ cfg = Config()
135
+ logger.info("Configuration loaded.")
136
+
137
+
138
+ # -------------------------------
139
+ # PROMPTS
140
+ # -------------------------------
141
+ LOCAL_REASONING_SYSTEM = """
142
+ You are a strict medical reasoning assistant specialized for ECG and cardiology reasoning.
143
+
144
+ You are NOT the final answer generator.
145
+ You must analyze ONLY the supplied evidence and produce a short structured reasoning draft.
146
+
147
+ Rules:
148
+ 1) Use only the provided evidence.
149
+ 2) Do not invent facts.
150
+ 3) Focus only on the user's exact question.
151
+ 4) Output exactly in this structure:
152
+
153
+ KEY_FINDINGS:
154
+ - ...
155
+ - ...
156
+
157
+ INTERPRETATION:
158
+ - ...
159
+ - ...
160
+
161
+ SUPPORTED_POINTS:
162
+ - [EVIDENCE_ID: X] ...
163
+ - [EVIDENCE_ID: Y] ...
164
+
165
+ LIMITS:
166
+ - ...
167
+
168
+ 5) If evidence is insufficient, output exactly:
169
+ INSUFFICIENT_EVIDENCE
170
+ """.strip()
171
+
172
+ QUERY_EXPANSION_SYSTEM = """
173
+ You expand medical queries for retrieval.
174
+
175
+ Rules:
176
+ 1) Preserve the user's intent.
177
+ 2) Add close medical paraphrases and alternate wording.
178
+ 3) Add likely medical synonyms, abbreviations, and alternate phrasing.
179
+ 4) Do not answer the question.
180
+ 5) Output only the expanded retrieval query.
181
+ """.strip()
182
+
183
+ DEEPSEEK_SUMMARY_SYSTEM = """
184
+ You are an expert medical evidence summarizer.
185
+
186
+ Your job is to produce a clinically precise, well-structured answer grounded ONLY in:
187
+ 1. the retrieved evidence
188
+ 2. the local reasoning draft
189
+
190
+ You must be faithful to the provided material and answer the user's question directly, clearly, and conservatively.
191
+
192
+ PRIMARY OBJECTIVE
193
+ - Identify the user's main intent before writing:
194
+ definition, cause, symptoms, diagnosis, investigation, treatment, prognosis, or genetics.
195
+ - Prioritize that intent throughout the response.
196
+ - The first sentence of the Summary must directly answer the user's question in the most clinically relevant way.
197
+
198
+ GROUNDING RULES
199
+ - Use only information supported by the retrieved evidence and local reasoning draft.
200
+ - Do not add outside medical knowledge.
201
+ - Do not infer specific facts unless they are clearly supported.
202
+ - Do not invent treatments, diagnoses, risks, mechanisms, thresholds, statistics, timelines, monitoring plans, or prognosis details.
203
+ - If the evidence is incomplete, be explicit about what is missing.
204
+ - If the evidence is too weak to answer the question reliably, output exactly:
205
+ INSUFFICIENT_EVIDENCE
206
+
207
+ STYLE RULES
208
+ - Write in precise, professional clinical language.
209
+ - Be specific, not vague.
210
+ - Be concise, but fully informative.
211
+ - Avoid repetition, generic filler, and empty statements.
212
+ - Do not mention retrieval, prompts, system instructions, reasoning drafts, tools, pipelines, or internal processes.
213
+ - Do not include URLs or citations unless explicitly requested elsewhere.
214
+ - Do not overstate certainty.
215
+ - When appropriate, distinguish clearly between what is established, what is suggested, and what is not addressed by the evidence.
216
+
217
+ OUTPUT FORMAT
218
+
219
+ ### Summary
220
+ - Write 4 to 7 full sentences.
221
+ - This is the most important section.
222
+ - The first sentence must directly answer the user's question.
223
+ - Focus primarily on the user's main intent.
224
+ - Include only background information that improves understanding of the requested topic.
225
+ - Make the summary clinically useful, specific, and evidence-faithful.
226
+
227
+ ### Key Evidence Points
228
+ - Include 4 to 6 bullet points.
229
+ - Each bullet must state a concrete fact supported by the evidence.
230
+ - Prioritize clinically important facts over background detail.
231
+ - Avoid repeating the same idea in different words.
232
+
233
+ ### Clinical Implications / Recommendations
234
+ - Include 2 to 4 bullet points only if supported by the evidence.
235
+ - Focus on practical interpretation, management implications, follow-up considerations, or next steps.
236
+ - If the evidence supports recognition or framing rather than action, say that clearly.
237
+ - Do not recommend interventions not supported by the evidence.
238
+
239
+ ### Limitations of the Evidence
240
+ - State clearly what the evidence does not establish, does not cover, or leaves uncertain.
241
+ - Explicitly note when details are lacking on:
242
+ treatment, diagnosis, prognosis, genetics, monitoring, recurrence prevention, comparative effectiveness, or long-term outcomes.
243
+ - If the evidence is narrow, low-detail, or only partially aligned with the question, say so plainly.
244
+
245
+ SPECIAL INSTRUCTIONS BY QUESTION TYPE
246
+
247
+ For treatment questions:
248
+ - Focus primarily on treatment and management, not disease definition.
249
+ - Organize treatment information in this order whenever supported by the evidence:
250
+ 1. supportive or conservative care
251
+ 2. symptomatic drug therapy or procedural treatment
252
+ 3. long-term prevention, follow-up, or recurrence prevention
253
+ - Distinguish treatment of active symptoms from prevention of recurrence or complications.
254
+ - If the condition is benign, self-limited, or often does not require treatment, state that clearly in the first sentence.
255
+
256
+ For diagnosis or investigation questions:
257
+ - Focus on how the condition is identified, evaluated, or differentiated.
258
+ - Prioritize diagnostic features, testing approach, and clinically useful distinctions.
259
+ - Do not drift into treatment unless the evidence clearly supports it and it helps answer the question.
260
+
261
+ For cause or risk questions:
262
+ - Focus on etiologies, risk factors, mechanisms, or associations supported by the evidence.
263
+ - Distinguish established causes from possible contributors if the evidence is less certain.
264
+
265
+ For prognosis questions:
266
+ - Focus on expected course, complications, recurrence, or outcome-related information supported by the evidence.
267
+ - Do not add prognostic claims not explicitly supported.
268
+
269
+ QUALITY CHECK BEFORE OUTPUT
270
+ Before finalizing, ensure that:
271
+ - the first sentence directly answers the question
272
+ - the response matches the user's primary intent
273
+ - every important claim is grounded in the provided material
274
+ - no unsupported medical detail has been added
275
+ - the Limitations section honestly reflects evidence gaps
276
+
277
+ If these conditions cannot be met, output exactly:
278
+ INSUFFICIENT_EVIDENCE
279
+ """.strip()
280
+
281
+ VALIDATOR_SYSTEM = """
282
+ You are a strict medical evidence validator.
283
+
284
+ Your job is to compare the ANSWER against the EVIDENCE.
285
+
286
+ Rules:
287
+ 1) Mark SUPPORTED if the answer is well grounded in the evidence.
288
+ 2) Mark PARTLY_UNSUPPORTED if some claims are supported but others go beyond the evidence.
289
+ 3) Mark INSUFFICIENT_EVIDENCE if the answer is mostly unsupported or the evidence is too weak.
290
+ 4) Output only one short verdict line beginning with exactly one of:
291
+ SUPPORTED:
292
+ PARTLY_UNSUPPORTED:
293
+ INSUFFICIENT_EVIDENCE:
294
+ """.strip()
295
+
296
+ NORMAL_CHAT_SYSTEM = """
297
+ You are a helpful, friendly, clear AI assistant.
298
+
299
+ You can:
300
+ - chat naturally
301
+ - explain concepts
302
+ - help with writing
303
+ - help with coding
304
+ - brainstorm ideas
305
+ - answer general knowledge questions
306
+
307
+ Rules:
308
+ 1) Be accurate and conversational.
309
+ 2) Be concise unless the user asks for detail.
310
+ 3) If the user asks medical questions in normal chat mode, give a general answer and do not pretend to use the ECG database.
311
+ 4) Do not mention internal prompts, retrieval pipelines, tools, or hidden logic.
312
+ """.strip()
313
+
314
+
315
+ # -------------------------------
316
+ # HELPERS
317
+ # -------------------------------
318
+ def clean_text(x: str) -> str:
319
+ x = str(x).replace("\x00", " ").strip()
320
+ x = re.sub(r"\s+", " ", x)
321
+ return x
322
+
323
+
324
+ def strip_bad_sections(txt: str) -> str:
325
+ t = str(txt).strip()
326
+ cut_markers = [
327
+ "References:",
328
+ "Sources:",
329
+ "Source:",
330
+ "URLs:",
331
+ "This response is based",
332
+ "Please let me know",
333
+ "Is there anything else",
334
+ ]
335
+ for marker in cut_markers:
336
+ pos = t.lower().find(marker.lower())
337
+ if pos != -1:
338
+ t = t[:pos].strip()
339
+
340
+ t = re.sub(r"https?://\S+|www\.\S+", "", t).strip()
341
+ return t
342
+
343
+
344
+ def infer_tags(question: str, answer: str) -> List[str]:
345
+ text = f"{question} {answer}".lower()
346
+ tags: List[str] = []
347
+
348
+ keyword_map = {
349
+ "treatment": ["treat", "therapy", "management", "drug", "surgery"],
350
+ "diagnosis": ["diagnosis", "diagnose", "criteria"],
351
+ "symptoms": ["symptom", "presentation", "sign", "feature"],
352
+ "ecg": ["ecg", "ekg", "st elevation", "qrs", "p wave", "arrhythmia", "tachycardia", "bradycardia"],
353
+ "investigation": ["test", "investigation", "mri", "ct", "lab", "imaging"],
354
+ "prognosis": ["prognosis", "outcome", "survival", "risk"],
355
+ "genetics": ["gene", "genetic", "mutation", "variant", "chromosome", "inherited", "inheritance"],
356
+ "etiology": ["cause", "causes", "caused by", "associated with", "risk factor"],
357
+ }
358
+
359
+ for tag, words in keyword_map.items():
360
+ if any(w in text for w in words):
361
+ tags.append(tag)
362
+
363
+ return tags
364
+
365
+
366
+ def make_row_text(q: str, a: str) -> str:
367
+ return f"QUESTION:\n{q}\n\nANSWER:\n{a}".strip()
368
+
369
+
370
+ def score_to_similarity(raw_score: float) -> float:
371
+ try:
372
+ raw_score = float(raw_score)
373
+ except Exception:
374
+ return -1.0
375
+ return 1.0 / (1.0 + max(raw_score, 0.0))
376
+
377
+
378
+ def lexical_overlap(query: str, text: str) -> float:
379
+ q_words = set(re.findall(r"\w+", query.lower()))
380
+ t_words = set(re.findall(r"\w+", text.lower()))
381
+ if not q_words:
382
+ return 0.0
383
+ return len(q_words & t_words) / max(1, len(q_words))
384
+
385
+
386
+ def rerank_docs(query: str, docs: List[Document], top_n: Optional[int] = None) -> List[Document]:
387
+ if top_n is None:
388
+ top_n = cfg.top_k_final
389
+
390
+ q_words = set(re.findall(r"\w+", query.lower()))
391
+ scored = []
392
+
393
+ for d in docs:
394
+ question = d.metadata.get("question", "")
395
+ answer = d.metadata.get("answer", "")
396
+ tags = " ".join(d.metadata.get("tags", []))
397
+ text = f"{question} {answer} {tags}".lower()
398
+
399
+ t_words = set(re.findall(r"\w+", text))
400
+ overlap = len(q_words & t_words) / max(1, len(q_words))
401
+ question_boost = 0.20 if any(w in question.lower() for w in q_words) else 0.0
402
+ tag_boost = 0.10 if any(w in tags.lower() for w in q_words) else 0.0
403
+ sim_score = float(d.metadata.get("sim_score", 0.0))
404
+
405
+ final_score = overlap + question_boost + tag_boost + (0.35 * sim_score)
406
+ scored.append((d, final_score))
407
+
408
+ scored.sort(key=lambda x: x[1], reverse=True)
409
+ return [d for d, _ in scored[:top_n]]
410
+
411
+
412
+ def history_to_text(chat_history: List[Dict[str, str]], max_turns: Optional[int] = None) -> str:
413
+ if max_turns is None:
414
+ max_turns = cfg.max_chat_history_turns
415
+
416
+ items = chat_history[-max_turns:]
417
+ if not items:
418
+ return "[EMPTY]"
419
+
420
+ return "\n".join([f"{m['role'].upper()}: {m['content']}" for m in items]).strip()
421
+
422
+
423
+ def build_context_string(docs: List[Document], max_chars: Optional[int] = None) -> str:
424
+ if max_chars is None:
425
+ max_chars = cfg.max_context_chars
426
 
427
+ blocks = []
428
+ total = 0
429
 
430
+ for i, d in enumerate(docs, 1):
431
+ q = d.metadata.get("question", "")
432
+ a = d.metadata.get("answer", "")
433
+ tags = ", ".join(d.metadata.get("tags", [])) or "N/A"
434
+ sim = d.metadata.get("sim_score", None)
435
 
436
+ block = f"""
437
+ ==============================
438
+ EVIDENCE_ID: {i}
439
+ SOURCE_ID: {d.metadata.get('id')}
440
+ SOURCE_QUESTION: {q}
441
+ SOURCE_TAGS: {tags}
442
+ SIMILARITY: {sim if sim is not None else 'N/A'}
443
+ EVIDENCE_TEXT:
444
+ {a}
445
+ ==============================
446
+ """.strip()
447
 
448
+ if total + len(block) > max_chars:
449
+ break
450
+
451
+ blocks.append(block)
452
+ total += len(block) + 2
453
+
454
+ return "\n\n".join(blocks).strip()
455
+
456
+
457
+ def compute_confidence(result: Dict) -> float:
458
+ best_score = result.get("best_score", -1.0)
459
+ validation = result.get("validation_status", "")
460
+
461
+ if validation.startswith("SUPPORTED"):
462
+ conf = best_score
463
+ elif validation.startswith("PARTLY_UNSUPPORTED"):
464
+ conf = best_score * 0.70
465
+ else:
466
+ conf = best_score * 0.40
467
+
468
+ return max(0.0, min(1.0, conf))
469
+
470
+
471
+ def strong_retrieval(best_score: float, docs: List[Document]) -> bool:
472
+ return (
473
+ best_score >= cfg.strong_retrieval_threshold
474
+ and len(docs) >= cfg.strong_retrieval_min_docs
475
+ )
476
+
477
+
478
+ def stream_text(text: str, step: int = 110):
479
+ acc = ""
480
+ for i in range(0, len(text), step):
481
+ acc += text[i:i + step]
482
+ yield acc
483
+
484
+
485
+ # -------------------------------
486
+ # AUTO MODE SWITCH DETECTION
487
+ # -------------------------------
488
+ ECG_MODE_PATTERNS = [
489
+ r"\becg\b",
490
+ r"\bekg\b",
491
+ r"\bcardiology\b",
492
+ r"\bcardio\b",
493
+ r"\barrhythmia\b",
494
+ r"\bheart rhythm\b",
495
+ r"\becg mode\b",
496
+ r"\bcardiology mode\b",
497
+ r"\bmedical mode\b",
498
+ ]
499
+
500
+ ECG_SWITCH_PHRASES = [
501
+ r"switch to ecg",
502
+ r"switch into ecg",
503
+ r"switch to cardiology",
504
+ r"switch into cardiology",
505
+ r"switch to ecg and cardiology",
506
+ r"switch into ecg and cardiology",
507
+ r"ecg and cardiology",
508
+ r"medical ecg cardiology",
509
+ r"i want to ask ecg",
510
+ r"i want to ask ecr",
511
+ r"i want ecg",
512
+ r"ecg questions",
513
+ r"cardiology questions",
514
+ r"ecg only",
515
+ r"cardiology only",
516
+ r"activate ecg",
517
+ r"activate cardiology",
518
+ ]
519
+
520
+ NORMAL_SWITCH_PHRASES = [
521
+ r"switch to normal",
522
+ r"normal chat",
523
+ r"back to normal",
524
+ r"exit ecg",
525
+ r"leave ecg mode",
526
+ r"turn off ecg mode",
527
+ ]
528
+
529
+
530
+ def normalize_user_text(text: str) -> str:
531
+ text = str(text or "").lower().strip()
532
+ text = re.sub(r"\s+", " ", text)
533
+ return text
534
+
535
+
536
+ def detect_mode_switch_request(user_message: str) -> Optional[str]:
537
+ text = normalize_user_text(user_message)
538
+
539
+ for pat in NORMAL_SWITCH_PHRASES:
540
+ if re.search(pat, text):
541
+ return "normal_chat"
542
+
543
+ strong_switch = any(re.search(pat, text) for pat in ECG_SWITCH_PHRASES)
544
+ ecg_present = any(re.search(pat, text) for pat in ECG_MODE_PATTERNS)
545
+
546
+ if strong_switch or (
547
+ ("switch" in text or "mode" in text or "questions" in text or "related" in text)
548
+ and ecg_present
549
  ):
550
+ return "ecg_rag"
 
 
 
551
 
552
+ return None
 
553
 
554
 
555
+ def mode_switch_message(mode_value: str) -> str:
556
+ if mode_value == "ecg_rag":
557
+ return (
558
+ "❤️ **ECG & Cardiology Mode activated**\n\n"
559
+ "UI updated successfully.\n"
560
+ "Ready for **medical, ECG, and cardiology** questions."
561
+ )
562
+ return (
563
+ "💬 **Normal Chat Mode activated**\n\n"
564
+ "UI updated successfully.\n"
565
+ "Ready for general conversation again."
566
+ )
567
+
568
+
569
+ # -------------------------------
570
+ # EMBEDDINGS + VECTORSTORE
571
+ # -------------------------------
572
+ logger.info("Loading embeddings...")
573
+ embeddings = HuggingFaceEmbeddings(
574
+ model_name=cfg.embed_model_name,
575
+ model_kwargs={
576
+ "device": "cuda" if torch.cuda.is_available() else "cpu",
577
+ "token": None,
578
+ },
579
+ encode_kwargs={"normalize_embeddings": True},
580
+ )
581
+
582
+
583
+ def build_vectorstore():
584
+ logger.info(f"Reading CSV: {cfg.data_csv}")
585
+ df = pd.read_csv(cfg.data_csv)
586
+ df.columns = [c.strip().lower() for c in df.columns]
587
+
588
+ required = {"instruction", "response"}
589
+ if not required.issubset(df.columns):
590
+ raise ValueError(f"CSV must contain columns {required}. Found: {df.columns.tolist()}")
591
+
592
+ df = df[["instruction", "response"]].dropna().reset_index(drop=True)
593
+ df["instruction"] = df["instruction"].map(clean_text)
594
+ df["response"] = df["response"].map(clean_text)
595
+
596
+ docs = []
597
+ for i, row in df.iterrows():
598
+ q = row["instruction"]
599
+ a = row["response"]
600
+ docs.append(
601
+ Document(
602
+ page_content=make_row_text(q, a),
603
+ metadata={
604
+ "id": int(i),
605
+ "question": q,
606
+ "answer": a,
607
+ "tags": infer_tags(q, a),
608
+ }
609
+ )
610
+ )
611
+
612
+ vectorstore_local = FAISS.from_documents(docs, embeddings)
613
+ vectorstore_local.save_local(cfg.vectorstore_dir)
614
+ logger.info(f"Saved vectorstore with {len(docs)} docs to {cfg.vectorstore_dir}")
615
+
616
+
617
+ def load_vectorstore():
618
+ return FAISS.load_local(
619
+ cfg.vectorstore_dir,
620
+ embeddings,
621
+ allow_dangerous_deserialization=True,
622
+ )
623
+
624
+
625
+ if not os.path.exists(cfg.vectorstore_dir):
626
+ logger.info("Vectorstore not found. Building from CSV...")
627
+ build_vectorstore()
628
+
629
+ vectorstore = load_vectorstore()
630
+ logger.info("Vectorstore ready.")
631
+
632
+
633
+ # -------------------------------
634
+ # LOCAL MODEL + ECG ADAPTER
635
+ # -------------------------------
636
+ logger.info("Loading tokenizer...")
637
+ tokenizer = AutoTokenizer.from_pretrained(
638
+ cfg.base_model_path,
639
+ use_fast=True,
640
+ token=cfg.hf_token if cfg.hf_token else None
641
+ )
642
+
643
+ if tokenizer.pad_token is None:
644
+ tokenizer.pad_token = tokenizer.eos_token
645
+
646
+ logger.info("Loading base model...")
647
+ has_cuda = torch.cuda.is_available()
648
+ base_model = None
649
+
650
+ if cfg.use_4bit and has_cuda:
651
+ try:
652
+ bnb_config = BitsAndBytesConfig(
653
+ load_in_4bit=True,
654
+ bnb_4bit_compute_dtype=torch.float16,
655
+ bnb_4bit_quant_type="nf4",
656
+ bnb_4bit_use_double_quant=True,
657
+ )
658
+ base_model = AutoModelForCausalLM.from_pretrained(
659
+ cfg.base_model_path,
660
+ device_map="auto",
661
+ quantization_config=bnb_config,
662
+ torch_dtype=torch.float16,
663
+ token=cfg.hf_token if cfg.hf_token else None,
664
+ )
665
+ logger.info("Loaded base model in 4-bit mode.")
666
+ except Exception as e:
667
+ logger.warning(f"4-bit load failed: {e}")
668
+
669
+ if base_model is None:
670
+ dtype = torch.float16 if has_cuda else torch.float32
671
+ base_model = AutoModelForCausalLM.from_pretrained(
672
+ cfg.base_model_path,
673
+ device_map="auto" if has_cuda else None,
674
+ torch_dtype=dtype,
675
+ token=cfg.hf_token if cfg.hf_token else None,
676
+ )
677
+ if not has_cuda:
678
+ base_model = base_model.to("cpu")
679
+ logger.info("Loaded base model without 4-bit.")
680
+
681
+ base_model.eval()
682
+
683
+ logger.info("Loading ECG reasoning adapter...")
684
+ reason_model = PeftModel.from_pretrained(base_model, cfg.adapter_dir)
685
+ reason_model.eval()
686
+
687
+
688
+ def get_primary_model_device(model) -> torch.device:
689
+ try:
690
+ return next(model.parameters()).device
691
+ except StopIteration:
692
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
693
+
694
+
695
+ @torch.inference_mode()
696
+ def run_local_reasoner(user_query: str, context: str) -> str:
697
+ try:
698
+ messages = [
699
+ {"role": "system", "content": LOCAL_REASONING_SYSTEM},
700
+ {
701
+ "role": "user",
702
+ "content": f"QUESTION:\n{user_query}\n\nEVIDENCE:\n{context if context.strip() else '[EMPTY]'}"
703
+ },
704
+ ]
705
+
706
+ prompt = tokenizer.apply_chat_template(
707
+ messages,
708
+ tokenize=False,
709
+ add_generation_prompt=True,
710
+ )
711
+
712
+ inputs = tokenizer(
713
+ prompt,
714
+ return_tensors="pt",
715
+ truncation=True,
716
+ max_length=cfg.max_input_len,
717
+ )
718
+
719
+ model_device = get_primary_model_device(reason_model)
720
+ inputs = {k: v.to(model_device) for k, v in inputs.items()}
721
+
722
+ out = reason_model.generate(
723
+ **inputs,
724
+ max_new_tokens=cfg.max_new_tokens_local,
725
+ do_sample=False,
726
+ use_cache=True,
727
+ repetition_penalty=1.08,
728
+ no_repeat_ngram_size=3,
729
+ pad_token_id=tokenizer.eos_token_id,
730
+ eos_token_id=tokenizer.eos_token_id,
731
+ )
732
+
733
+ gen_ids = out[0, inputs["input_ids"].shape[1]:]
734
+ text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
735
+ text = strip_bad_sections(text)
736
+
737
+ return text if text else "INSUFFICIENT_EVIDENCE"
738
+
739
+ except Exception as e:
740
+ logger.error(f"Local reasoner error: {e}")
741
+ traceback.print_exc()
742
+ return "INSUFFICIENT_EVIDENCE"
743
+
744
+
745
+ # -------------------------------
746
+ # REMOTE LLM (DEEPSEEK)
747
+ # -------------------------------
748
+ deepseek_llm = ChatOpenAI(
749
+ model=cfg.deepseek_model,
750
+ api_key=cfg.deepseek_api_key,
751
+ base_url=cfg.deepseek_base_url,
752
+ temperature=cfg.deepseek_temperature,
753
+ max_tokens=cfg.deepseek_max_tokens,
754
+ )
755
+
756
+ _query_expansion_cache: Dict[str, str] = {}
757
+
758
+
759
+ def llm_text(system_prompt: str, user_prompt: str, fallback: str = "INSUFFICIENT_EVIDENCE") -> str:
760
+ try:
761
+ resp = deepseek_llm.invoke([
762
+ {"role": "system", "content": system_prompt},
763
+ {"role": "user", "content": user_prompt},
764
+ ])
765
+ text = resp.content if hasattr(resp, "content") else str(resp)
766
+ text = strip_bad_sections(text)
767
+ return text if text.strip() else fallback
768
+ except Exception as e:
769
+ logger.error(f"DeepSeek error: {e}")
770
+ traceback.print_exc()
771
+ return fallback
772
+
773
+
774
+ def run_query_expansion(user_query: str) -> str:
775
+ if not cfg.enable_query_expansion:
776
+ return user_query
777
+
778
+ if cfg.use_query_cache and user_query in _query_expansion_cache:
779
+ logger.info(f"Using cached expansion for: {user_query[:80]}")
780
+ return _query_expansion_cache[user_query]
781
+
782
+ prompt = f"""
783
+ USER_QUERY:
784
+ {user_query}
785
+
786
+ Expand this for retrieval with close medical phrasing, synonyms, and alternate wording.
787
+ Do not answer the question.
788
+ """.strip()
789
+
790
+ expanded = llm_text(QUERY_EXPANSION_SYSTEM, prompt, fallback=user_query)
791
+ expanded = expanded.strip() if expanded else user_query
792
+
793
+ if cfg.use_query_cache:
794
+ _query_expansion_cache[user_query] = expanded
795
+
796
+ return expanded
797
+
798
+
799
+ def run_deepseek_summary(
800
+ user_query: str,
801
+ context: str,
802
+ reasoning_draft: str,
803
+ chat_history: List[Dict[str, str]],
804
+ ) -> str:
805
+ prompt = f"""
806
+ CHAT_HISTORY:
807
+ {history_to_text(chat_history)}
808
+
809
+ USER_QUESTION:
810
+ {user_query}
811
+
812
+ RETRIEVED_EVIDENCE:
813
+ {context if context.strip() else '[EMPTY]'}
814
+
815
+ LOCAL_REASONING_DRAFT:
816
+ {reasoning_draft if reasoning_draft.strip() else '[EMPTY]'}
817
+
818
+ Write a grounded final summary answer using only the evidence and reasoning draft.
819
+ """.strip()
820
+
821
+ return llm_text(
822
+ DEEPSEEK_SUMMARY_SYSTEM,
823
+ prompt,
824
+ fallback="I could not generate a grounded summary from the retrieved evidence."
825
+ )
826
+
827
+
828
+ def run_validator(context: str, answer: str) -> str:
829
+ if not cfg.enable_validator:
830
+ return "SUPPORTED (validator disabled)"
831
+
832
+ prompt = f"""
833
+ EVIDENCE:
834
+ {context if context.strip() else '[EMPTY]'}
835
+
836
+ ANSWER:
837
+ {answer if answer.strip() else '[EMPTY]'}
838
+ """.strip()
839
+
840
+ return llm_text(VALIDATOR_SYSTEM, prompt, fallback="PARTLY_UNSUPPORTED: validator unavailable")
841
+
842
+
843
+ def run_normal_chat(user_query: str, chat_history: List[Dict[str, str]]) -> str:
844
+ prompt = f"""
845
+ CHAT_HISTORY:
846
+ {history_to_text(chat_history)}
847
+
848
+ USER_MESSAGE:
849
+ {user_query}
850
+
851
+ Respond as a normal helpful chatbot.
852
+ """.strip()
853
+
854
+ return llm_text(
855
+ NORMAL_CHAT_SYSTEM,
856
+ prompt,
857
+ fallback="Sorry, I could not generate a response."
858
+ )
859
+
860
+
861
+ # -------------------------------
862
+ # WARMUP
863
+ # -------------------------------
864
+ def warmup_models():
865
+ logger.info("Warming up local reasoner...")
866
+ try:
867
+ _ = run_local_reasoner(
868
+ "What are ECG findings in hyperkalemia?",
869
+ """
870
+ ==============================
871
+ EVIDENCE_ID: 1
872
+ SOURCE_QUESTION: What are ECG findings in hyperkalemia?
873
+ SOURCE_TAGS: ecg
874
+ EVIDENCE_TEXT:
875
+ Hyperkalemia may cause peaked T waves, PR prolongation, QRS widening, and severe conduction abnormalities.
876
+ ==============================
877
+ """.strip(),
878
+ )
879
+ logger.info("Warmup completed.")
880
+ except Exception as e:
881
+ logger.warning(f"Warmup failed: {e}")
882
+
883
+
884
+ warmup_models()
885
+
886
+
887
+ # -------------------------------
888
+ # STATE
889
+ # -------------------------------
890
+ class ChatState(TypedDict, total=False):
891
+ user_query: str
892
+ expanded_query: str
893
+ chat_history: List[Dict[str, str]]
894
+
895
+ retrieved_docs: List[Document]
896
+ best_score: float
897
+ used_context: bool
898
+ context: str
899
+ retrieval_attempts: int
900
+ retrieval_mode: str
901
+
902
+ reasoning_draft: str
903
+ final_answer: str
904
+ validation_status: str
905
+
906
+
907
+ # -------------------------------
908
+ # RETRIEVAL
909
+ # -------------------------------
910
+ def retrieve_docs_once(query_for_search: str, original_query: str):
911
+ try:
912
+ scored = vectorstore.similarity_search_with_score(
913
+ query_for_search,
914
+ k=cfg.similarity_k,
915
+ )
916
+ except Exception as e:
917
+ logger.error(f"Retriever error: {e}")
918
+ traceback.print_exc()
919
+ return [], -1.0
920
+
921
+ if not scored:
922
+ return [], -1.0
923
+
924
+ filtered_docs = []
925
+ best_score = -1.0
926
+
927
+ for doc, raw_score in scored:
928
+ sim = score_to_similarity(raw_score)
929
+ best_score = max(best_score, sim)
930
+
931
+ q = doc.metadata.get("question", "")
932
+ a = doc.metadata.get("answer", "")
933
+ ov = lexical_overlap(original_query, f"{q} {a}")
934
+
935
+ if ov >= cfg.min_lexical_overlap and sim >= cfg.min_faiss_similarity:
936
+ new_doc = Document(page_content=doc.page_content, metadata=dict(doc.metadata))
937
+ new_doc.metadata["sim_score"] = sim
938
+ new_doc.metadata["lexical_overlap"] = ov
939
+ filtered_docs.append(new_doc)
940
+
941
+ reranked = rerank_docs(original_query, filtered_docs, top_n=cfg.top_k_final)
942
+ return reranked, best_score
943
+
944
+
945
+ # -------------------------------
946
+ # LANGGRAPH NODES
947
+ # -------------------------------
948
+ def retrieve_node(state: ChatState) -> ChatState:
949
+ query = state.get("expanded_query") or state["user_query"]
950
+ retrieval_attempts = int(state.get("retrieval_attempts", 0)) + 1
951
+ retrieval_mode = "expanded" if state.get("expanded_query") else "original"
952
+
953
+ docs, best_score = retrieve_docs_once(
954
+ query_for_search=query,
955
+ original_query=state["user_query"],
956
+ )
957
+
958
+ if not docs:
959
+ return {
960
+ "retrieved_docs": [],
961
+ "best_score": best_score,
962
+ "used_context": False,
963
+ "context": "",
964
+ "retrieval_attempts": retrieval_attempts,
965
+ "retrieval_mode": retrieval_mode,
966
+ }
967
+
968
+ return {
969
+ "retrieved_docs": docs,
970
+ "best_score": best_score,
971
+ "used_context": True,
972
+ "context": build_context_string(docs, max_chars=cfg.max_context_chars),
973
+ "retrieval_attempts": retrieval_attempts,
974
+ "retrieval_mode": retrieval_mode,
975
+ }
976
+
977
+
978
+ def should_retry_retrieval(state: ChatState) -> str:
979
+ used_context = state.get("used_context", False)
980
+ best_score = state.get("best_score", -1.0)
981
+ attempts = int(state.get("retrieval_attempts", 0))
982
+
983
+ if used_context and best_score >= cfg.min_faiss_similarity:
984
+ return "local_reasoning"
985
+
986
+ if not cfg.enable_query_expansion:
987
+ return "local_reasoning"
988
+
989
+ if attempts >= 2:
990
+ return "local_reasoning"
991
+
992
+ return "expand_query"
993
+
994
+
995
+ def expand_query_node(state: ChatState) -> ChatState:
996
+ expanded = run_query_expansion(state["user_query"])
997
+ if not expanded.strip():
998
+ expanded = state["user_query"]
999
+ return {"expanded_query": expanded}
1000
+
1001
+
1002
+ def local_reasoning_node(state: ChatState) -> ChatState:
1003
+ context = state.get("context", "").strip()
1004
+ if not context:
1005
+ return {"reasoning_draft": "INSUFFICIENT_EVIDENCE"}
1006
+
1007
+ reasoning = run_local_reasoner(state["user_query"], context)
1008
+ return {"reasoning_draft": reasoning}
1009
+
1010
+
1011
+ def generate_node(state: ChatState) -> ChatState:
1012
+ context = state.get("context", "").strip()
1013
+ reasoning = state.get("reasoning_draft", "INSUFFICIENT_EVIDENCE")
1014
+ history = state.get("chat_history", [])
1015
+
1016
+ if not context:
1017
+ return {"final_answer": "I could not find sufficiently relevant evidence in the RAG database for this question."}
1018
+
1019
+ answer = run_deepseek_summary(
1020
+ user_query=state["user_query"],
1021
+ context=context,
1022
+ reasoning_draft=reasoning,
1023
+ chat_history=history,
1024
+ )
1025
+ return {"final_answer": answer}
1026
+
1027
+
1028
+ def validate_node(state: ChatState) -> ChatState:
1029
+ context = state.get("context", "").strip()
1030
+ answer = state.get("final_answer", "").strip()
1031
+ best_score = state.get("best_score", -1.0)
1032
+ docs = state.get("retrieved_docs", [])
1033
+
1034
+ if not context or not answer:
1035
+ return {"validation_status": "INSUFFICIENT_EVIDENCE: missing context or answer"}
1036
+
1037
+ if strong_retrieval(best_score, docs):
1038
+ return {"validation_status": "SUPPORTED (validator skipped due to strong retrieval)"}
1039
+
1040
+ verdict = run_validator(context, answer)
1041
+
1042
+ if verdict.startswith("SUPPORTED"):
1043
+ return {"validation_status": verdict}
1044
+
1045
+ if verdict.startswith("PARTLY_UNSUPPORTED"):
1046
+ return {
1047
+ "validation_status": verdict,
1048
+ "final_answer": answer + "\n\nEvidence limits: some parts may not be fully supported by the retrieved evidence."
1049
+ }
1050
+
1051
+ if verdict.startswith("INSUFFICIENT_EVIDENCE"):
1052
+ return {
1053
+ "validation_status": verdict,
1054
+ "final_answer": answer + "\n\nEvidence limits: the retrieved evidence was weak or only partially relevant."
1055
+ }
1056
+
1057
+ return {"validation_status": verdict}
1058
+
1059
+
1060
+ def finalize_node(state: ChatState) -> ChatState:
1061
+ answer = strip_bad_sections(state.get("final_answer", ""))
1062
+ if not answer:
1063
+ answer = "I could not generate an answer."
1064
+ return {"final_answer": answer}
1065
+
1066
+
1067
+ # -------------------------------
1068
+ # GRAPH
1069
+ # -------------------------------
1070
+ builder = StateGraph(ChatState)
1071
+ builder.add_node("retrieve", retrieve_node)
1072
+ builder.add_node("expand_query", expand_query_node)
1073
+ builder.add_node("local_reasoning", local_reasoning_node)
1074
+ builder.add_node("generate", generate_node)
1075
+ builder.add_node("validate", validate_node)
1076
+ builder.add_node("finalize", finalize_node)
1077
+
1078
+ builder.add_edge(START, "retrieve")
1079
+ builder.add_conditional_edges(
1080
+ "retrieve",
1081
+ should_retry_retrieval,
1082
+ {
1083
+ "expand_query": "expand_query",
1084
+ "local_reasoning": "local_reasoning",
1085
+ }
1086
  )
1087
+ builder.add_edge("expand_query", "retrieve")
1088
+ builder.add_edge("local_reasoning", "generate")
1089
+ builder.add_edge("generate", "validate")
1090
+ builder.add_edge("validate", "finalize")
1091
+ builder.add_edge("finalize", END)
1092
+
1093
+ graph = builder.compile()
1094
+ logger.info("LangGraph compiled.")
1095
+
1096
+
1097
+ # -------------------------------
1098
+ # FORMATTING HELPERS
1099
+ # -------------------------------
1100
+ def format_sources_minimal(result: Optional[Dict], chat_mode: str = "ecg_rag") -> str:
1101
+ if chat_mode == "normal_chat":
1102
+ return "## Retrieved Sources\n\nNormal chat mode is active. No ECG evidence retrieval used."
1103
+
1104
+ if not result:
1105
+ return "## Retrieved Sources\n\nNo sources yet."
1106
+
1107
+ docs = result.get("retrieved_docs", [])
1108
+ best_score = result.get("best_score", -1.0)
1109
+
1110
+ if not docs:
1111
+ return (
1112
+ "## Retrieved Sources\n\n"
1113
+ "No sufficiently relevant evidence retrieved.\n\n"
1114
+ f"**Best score:** `{best_score:.3f}`"
1115
+ )
1116
+
1117
+ lines = [
1118
+ "## Retrieved Sources",
1119
+ f"**Best score:** `{best_score:.3f}`",
1120
+ "",
1121
+ ]
1122
+
1123
+ for i, d in enumerate(docs, 1):
1124
+ question = d.metadata.get("question", "")
1125
+ answer = d.metadata.get("answer", "")
1126
+ similarity = d.metadata.get("sim_score", "N/A")
1127
+ preview = answer[:210].strip()
1128
+ if len(answer) > 210:
1129
+ preview += "..."
1130
+
1131
+ lines.extend([
1132
+ f"### Evidence {i}",
1133
+ f"- **Question:** {question}",
1134
+ f"- **Similarity:** `{similarity}`",
1135
+ f"- **Preview:** {preview}",
1136
+ "",
1137
+ ])
1138
+
1139
+ return "\n".join(lines)
1140
+
1141
+
1142
+ def format_debug_text(result: Optional[Dict], chat_mode: str = "ecg_rag") -> str:
1143
+ if chat_mode == "normal_chat":
1144
+ return "MODE: normal_chat\nNo retrieval/debug evidence used."
1145
+
1146
+ if not result:
1147
+ return "No debug result yet."
1148
+
1149
+ return f"""
1150
+ BEST SCORE: {result.get('best_score', -1.0)}
1151
+ USED CONTEXT: {result.get('used_context', False)}
1152
+ RETRIEVAL ATTEMPTS: {result.get('retrieval_attempts', 0)}
1153
+ RETRIEVAL MODE: {result.get('retrieval_mode', 'N/A')}
1154
+ VALIDATION STATUS: {result.get('validation_status', 'N/A')}
1155
+
1156
+ ----- CONTEXT -----
1157
+ {result.get('context', '')}
1158
+
1159
+ ----- LOCAL REASONING DRAFT -----
1160
+ {result.get('reasoning_draft', '')}
1161
+ """.strip()
1162
+
1163
+
1164
+ # -------------------------------
1165
+ # UI HELPERS
1166
+ # -------------------------------
1167
+ CUSTOM_CSS = """
1168
+ :root {
1169
+ --bg-main: #07111f;
1170
+ --bg-soft: #0b1728;
1171
+ --card: rgba(10, 19, 35, 0.86);
1172
+ --card-2: rgba(14, 25, 43, 0.94);
1173
+ --border: rgba(148, 163, 184, 0.16);
1174
+ --text: #e5eefb;
1175
+ --muted: #94a3b8;
1176
+ --primary: #7c3aed;
1177
+ --primary-2: #2563eb;
1178
+ --success: #10b981;
1179
+ }
1180
+
1181
+ html, body, .gradio-container {
1182
+ margin: 0 !important;
1183
+ padding: 0 !important;
1184
+ min-height: 100%;
1185
+ background:
1186
+ radial-gradient(circle at top left, rgba(124,58,237,0.22), transparent 28%),
1187
+ radial-gradient(circle at top right, rgba(37,99,235,0.18), transparent 24%),
1188
+ linear-gradient(180deg, #050b16 0%, #091321 100%);
1189
+ color: var(--text);
1190
+ }
1191
+
1192
+ .gradio-container {
1193
+ max-width: 100% !important;
1194
+ padding: 12px !important;
1195
+ }
1196
+
1197
+ footer {
1198
+ visibility: hidden;
1199
+ }
1200
+
1201
+ .top-card {
1202
+ border: 1px solid var(--border);
1203
+ background: linear-gradient(135deg, rgba(11,23,40,0.95), rgba(18,31,56,0.92));
1204
+ border-radius: 22px;
1205
+ padding: 16px;
1206
+ margin-bottom: 12px;
1207
+ box-shadow: 0 14px 40px rgba(0,0,0,0.20);
1208
+ }
1209
+
1210
+ .hero-title {
1211
+ font-size: 1.6rem;
1212
+ font-weight: 800;
1213
+ color: #f8fbff;
1214
+ margin-bottom: 6px;
1215
+ line-height: 1.15;
1216
+ }
1217
+
1218
+ .hero-subtitle {
1219
+ color: #cbd5e1;
1220
+ font-size: 0.95rem;
1221
+ line-height: 1.5;
1222
+ }
1223
+
1224
+ .badges {
1225
+ display: flex;
1226
+ gap: 8px;
1227
+ flex-wrap: wrap;
1228
+ margin-top: 12px;
1229
+ }
1230
+
1231
+ .badge {
1232
+ display: inline-flex;
1233
+ align-items: center;
1234
+ gap: 6px;
1235
+ padding: 6px 10px;
1236
+ border-radius: 999px;
1237
+ font-size: 11px;
1238
+ color: #e6eefc;
1239
+ border: 1px solid rgba(255,255,255,0.12);
1240
+ background: rgba(255,255,255,0.06);
1241
+ }
1242
+
1243
+ .panel-wrap {
1244
+ border: 1px solid var(--border);
1245
+ background: linear-gradient(180deg, rgba(10,19,35,0.96), rgba(7,14,26,0.94));
1246
+ border-radius: 20px;
1247
+ padding: 12px;
1248
+ box-shadow: 0 16px 45px rgba(0,0,0,0.22);
1249
+ }
1250
+
1251
+ #chatbot {
1252
+ height: min(62vh, 640px) !important;
1253
+ min-height: 360px !important;
1254
+ border-radius: 18px !important;
1255
+ border: 1px solid var(--border) !important;
1256
+ overflow: hidden !important;
1257
+ box-shadow: 0 14px 40px rgba(0,0,0,0.26) !important;
1258
+ }
1259
+
1260
+ .status-card {
1261
+ padding: 12px 14px;
1262
+ border-radius: 16px;
1263
+ background: linear-gradient(135deg, #0f172a 0%, #172554 100%);
1264
+ color: #f9fafb;
1265
+ font-size: 14px;
1266
+ border: 1px solid rgba(255,255,255,0.12);
1267
+ box-shadow: 0 10px 30px rgba(0,0,0,0.2);
1268
+ }
1269
+
1270
+ .muted {
1271
+ color: #a5b4fc;
1272
+ font-size: 12px;
1273
+ }
1274
+
1275
+ .blink-dots {
1276
+ font-size: 22px;
1277
+ font-weight: 800;
1278
+ letter-spacing: 4px;
1279
+ animation: blinkDots 1s steps(1, end) infinite;
1280
+ display: inline-block;
1281
+ padding: 2px 0;
1282
+ }
1283
+
1284
+ @keyframes blinkDots {
1285
+ 0% { opacity: 1; }
1286
+ 50% { opacity: 0.15; }
1287
+ 100% { opacity: 1; }
1288
+ }
1289
+
1290
+ textarea, .gr-textbox textarea {
1291
+ border-radius: 16px !important;
1292
+ font-size: 15px !important;
1293
+ }
1294
+
1295
+ .gr-textbox label, .gr-markdown, .gr-button {
1296
+ font-size: 14px !important;
1297
+ }
1298
+
1299
+ button {
1300
+ border-radius: 14px !important;
1301
+ min-height: 44px !important;
1302
+ font-weight: 600 !important;
1303
+ }
1304
+
1305
+ .mobile-stack {
1306
+ display: flex;
1307
+ flex-direction: column;
1308
+ gap: 12px;
1309
+ }
1310
+
1311
+ .mobile-scroll {
1312
+ max-height: 34vh;
1313
+ overflow-y: auto;
1314
+ }
1315
+
1316
+ .command-note {
1317
+ color: #cbd5e1;
1318
+ font-size: 0.88rem;
1319
+ line-height: 1.45;
1320
+ }
1321
+
1322
+ .mode-note {
1323
+ color: #cbd5e1;
1324
+ font-size: 0.88rem;
1325
+ margin-top: 6px;
1326
+ }
1327
+
1328
+ @media (max-width: 1024px) {
1329
+ .gradio-container { padding: 10px !important; }
1330
+ .hero-title { font-size: 1.45rem; }
1331
+ .hero-subtitle { font-size: 0.92rem; }
1332
+ #chatbot { height: 56vh !important; }
1333
+ }
1334
+
1335
+ @media (max-width: 768px) {
1336
+ .gradio-container { padding: 8px !important; }
1337
+ .top-card { padding: 14px; border-radius: 18px; }
1338
+ .hero-title { font-size: 1.28rem; }
1339
+ .hero-subtitle { font-size: 0.88rem; line-height: 1.45; }
1340
+ .badge { font-size: 10px; padding: 5px 8px; }
1341
+ .panel-wrap { padding: 10px; border-radius: 16px; }
1342
+ #chatbot {
1343
+ height: 52vh !important;
1344
+ min-height: 320px !important;
1345
+ border-radius: 16px !important;
1346
+ }
1347
+ button { width: 100% !important; }
1348
+ .mobile-scroll { max-height: 240px; }
1349
+ }
1350
+
1351
+ @media (max-width: 480px) {
1352
+ .hero-title { font-size: 1.15rem; }
1353
+ .hero-subtitle { font-size: 0.83rem; }
1354
+ #chatbot {
1355
+ height: 50vh !important;
1356
+ min-height: 300px !important;
1357
+ }
1358
+ textarea, .gr-textbox textarea { font-size: 14px !important; }
1359
+ }
1360
+ """
1361
+
1362
+
1363
+ def hero_html() -> str:
1364
+ return """
1365
+ <div class="top-card">
1366
+ <div class="hero-title">🫀 Mr Cardio</div>
1367
+ <div class="hero-subtitle">
1368
+ ECG and cardiology specialist chatbot with automatic mode switching,
1369
+ evidence retrieval, local ECG reasoning, grounded summaries, and normal chat mode.
1370
+ </div>
1371
+ <div class="badges">
1372
+ <div class="badge">ECG RAG</div>
1373
+ <div class="badge">Normal Chat</div>
1374
+ <div class="badge">FAISS Retrieval</div>
1375
+ <div class="badge">LoRA Adapter</div>
1376
+ <div class="badge">Validated Output</div>
1377
+ </div>
1378
+ </div>
1379
+ """
1380
+
1381
+
1382
+ def thinking_html(stage: str) -> str:
1383
+ icon = "⏳"
1384
+ subtitle = "Retrieval → reasoning → grounded answer"
1385
+
1386
+ if "switch" in stage.lower() or "activating" in stage.lower() or "updating ui" in stage.lower():
1387
+ icon = "⚡"
1388
+ subtitle = "Updating mode and interface"
1389
+
1390
+ return f"""
1391
+ <div class="status-card">
1392
+ <div style="display:flex;align-items:center;gap:12px;">
1393
+ <div style="font-size:19px;">{icon}</div>
1394
+ <div>
1395
+ <div style="font-weight:700;">{stage}</div>
1396
+ <div class="muted">{subtitle}</div>
1397
+ <div class="blink-dots">...</div>
1398
+ </div>
1399
+ </div>
1400
+ </div>
1401
+ """
1402
+
1403
+
1404
+ def initialize_session():
1405
+ return {
1406
+ "chat_history": [],
1407
+ "last_result": None,
1408
+ "chat_mode": "ecg_rag",
1409
+ }
1410
+
1411
+
1412
+ def add_assistant_placeholder(history, text="..."):
1413
+ history = history or []
1414
+ history.append({
1415
+ "role": "assistant",
1416
+ "content": text,
1417
+ "metadata": {"title": "Thinking"}
1418
+ })
1419
+ return history
1420
+
1421
+
1422
+ def update_last_assistant_message(history, text, title=None):
1423
+ history = history or []
1424
+ if not history or history[-1]["role"] != "assistant":
1425
+ msg = {"role": "assistant", "content": text}
1426
+ if title:
1427
+ msg["metadata"] = {"title": title}
1428
+ history.append(msg)
1429
+ return history
1430
+
1431
+ history[-1] = {"role": "assistant", "content": text}
1432
+ if title:
1433
+ history[-1]["metadata"] = {"title": title}
1434
+ return history
1435
+
1436
+
1437
+ def user_submit(user_message, chat_ui_history):
1438
+ chat_ui_history = chat_ui_history or []
1439
+ user_message = (user_message or "").strip()
1440
+
1441
+ if not user_message:
1442
+ return "", chat_ui_history
1443
+
1444
+ chat_ui_history.append({"role": "user", "content": user_message})
1445
+ return "", chat_ui_history
1446
+
1447
+
1448
+ def set_chat_mode(mode_value: str, session_state: Dict):
1449
+ if session_state is None:
1450
+ session_state = initialize_session()
1451
+ session_state["chat_mode"] = mode_value
1452
+ return session_state
1453
+
1454
+
1455
+ def get_mode_label(session_state: Dict) -> str:
1456
+ mode = (session_state or {}).get("chat_mode", "ecg_rag")
1457
+
1458
+ if mode == "normal_chat":
1459
+ return """
1460
+ <div class="mode-note">
1461
+ <b>Mode:</b> Normal Chat
1462
+ </div>
1463
+ """
1464
+
1465
+ return """
1466
+ <div class="mode-note">
1467
+ <b>Mode:</b> ECG &amp; Cardiology
1468
+ <br>
1469
+ <span style="color:#93c5fd;">Medical / ECG / Cardiology specialist mode active</span>
1470
+ </div>
1471
+ """
1472
+
1473
+
1474
+ # -------------------------------
1475
+ # CORE CHAT
1476
+ # -------------------------------
1477
+ def run_chat_turn(user_message: str, memory_state: Dict) -> Dict:
1478
+ if memory_state is None:
1479
+ memory_state = initialize_session()
1480
+
1481
+ chat_mode = memory_state.get("chat_mode", "ecg_rag")
1482
+
1483
+ if chat_mode == "normal_chat":
1484
+ answer = run_normal_chat(
1485
+ user_query=user_message,
1486
+ chat_history=memory_state["chat_history"]
1487
+ )
1488
+
1489
+ result = {
1490
+ "final_answer": answer,
1491
+ "best_score": -1.0,
1492
+ "used_context": False,
1493
+ "validation_status": "NORMAL_CHAT_MODE",
1494
+ "retrieved_docs": [],
1495
+ "context": "",
1496
+ "reasoning_draft": "",
1497
+ "retrieval_attempts": 0,
1498
+ "retrieval_mode": "none",
1499
+ }
1500
+ else:
1501
+ state_in = {
1502
+ "user_query": user_message,
1503
+ "chat_history": memory_state["chat_history"],
1504
+ "retrieval_attempts": 0,
1505
+ }
1506
+
1507
+ try:
1508
+ result = graph.invoke(state_in)
1509
+ except Exception as e:
1510
+ logger.error(f"Graph invocation error: {e}")
1511
+ traceback.print_exc()
1512
+ result = {
1513
+ "final_answer": f"I hit a runtime error while processing the request: {e}",
1514
+ "best_score": -1.0,
1515
+ "used_context": False,
1516
+ "validation_status": "ERROR",
1517
+ "retrieved_docs": [],
1518
+ "context": "",
1519
+ "reasoning_draft": "",
1520
+ "retrieval_attempts": 0,
1521
+ "retrieval_mode": "error",
1522
+ }
1523
+
1524
+ answer = result.get("final_answer", "").strip() or "I could not generate an answer."
1525
+ best_score = result.get("best_score", -1.0)
1526
+ validation_status = result.get("validation_status", "N/A")
1527
+ confidence = compute_confidence(result) if chat_mode == "ecg_rag" else 1.0
1528
+
1529
+ answer_with_footer = (
1530
+ f"{answer}\n\n---\n"
1531
+ f"📊 mode={chat_mode} | confidence={confidence:.2f} | best_score={best_score:.3f} | validation={validation_status}"
1532
+ )
1533
+
1534
+ memory_state["chat_history"].append({"role": "user", "content": user_message})
1535
+ memory_state["chat_history"].append({"role": "assistant", "content": answer})
1536
+ memory_state["chat_history"] = memory_state["chat_history"][-12:]
1537
+ memory_state["last_result"] = result
1538
+
1539
+ return {
1540
+ "answer": answer_with_footer,
1541
+ "memory_state": memory_state,
1542
+ "sources_markdown": format_sources_minimal(result, chat_mode=chat_mode),
1543
+ "debug_text": format_debug_text(result, chat_mode=chat_mode),
1544
+ }
1545
+
1546
+
1547
+ def bot_respond_stream(chat_ui_history, session_state):
1548
+ global vectorstore
1549
+
1550
+ if session_state is None:
1551
+ session_state = initialize_session()
1552
+
1553
+ if not chat_ui_history:
1554
+ yield (
1555
+ chat_ui_history,
1556
+ session_state,
1557
+ "## Retrieved Sources\n\nNo sources yet.",
1558
+ "No debug result yet.",
1559
+ "",
1560
+ get_mode_label(session_state),
1561
+ session_state.get("chat_mode", "ecg_rag"),
1562
+ )
1563
+ return
1564
+
1565
+ user_message = str(chat_ui_history[-1]["content"]).strip()
1566
+ chat_mode = session_state.get("chat_mode", "ecg_rag")
1567
+
1568
+ # ---------------------------------
1569
+ # AUTO MODE SWITCH
1570
+ # ---------------------------------
1571
+ requested_mode = detect_mode_switch_request(user_message)
1572
+
1573
+ if requested_mode and requested_mode != chat_mode:
1574
+ session_state["chat_mode"] = requested_mode
1575
+
1576
+ chat_ui_history = add_assistant_placeholder(chat_ui_history, text="...")
1577
+ yield (
1578
+ chat_ui_history,
1579
+ session_state,
1580
+ format_sources_minimal(session_state.get("last_result"), chat_mode=requested_mode),
1581
+ format_debug_text(session_state.get("last_result"), chat_mode=requested_mode),
1582
+ thinking_html(
1583
+ f"Switching to {'ECG & Cardiology Mode' if requested_mode == 'ecg_rag' else 'Normal Chat Mode'}"
1584
+ ),
1585
+ get_mode_label(session_state),
1586
+ requested_mode,
1587
+ )
1588
+ time.sleep(cfg.blink_stage_1)
1589
+
1590
+ yield (
1591
+ chat_ui_history,
1592
+ session_state,
1593
+ format_sources_minimal(session_state.get("last_result"), chat_mode=requested_mode),
1594
+ format_debug_text(session_state.get("last_result"), chat_mode=requested_mode),
1595
+ thinking_html("Updating UI"),
1596
+ get_mode_label(session_state),
1597
+ requested_mode,
1598
+ )
1599
+ time.sleep(cfg.blink_stage_2)
1600
+
1601
+ final_switch_text = mode_switch_message(requested_mode)
1602
+
1603
+ if cfg.enable_typewriter_stream:
1604
+ for partial in stream_text(final_switch_text, step=90):
1605
+ chat_ui_history = update_last_assistant_message(
1606
+ chat_ui_history,
1607
+ partial,
1608
+ title="Mode Update"
1609
+ )
1610
+ yield (
1611
+ chat_ui_history,
1612
+ session_state,
1613
+ format_sources_minimal(session_state.get("last_result"), chat_mode=requested_mode),
1614
+ format_debug_text(session_state.get("last_result"), chat_mode=requested_mode),
1615
+ "",
1616
+ get_mode_label(session_state),
1617
+ requested_mode,
1618
+ )
1619
+
1620
+ chat_ui_history = update_last_assistant_message(
1621
+ chat_ui_history,
1622
+ final_switch_text,
1623
+ title="Mode Update"
1624
+ )
1625
+
1626
+ session_state["chat_history"].append({"role": "user", "content": user_message})
1627
+ session_state["chat_history"].append({"role": "assistant", "content": final_switch_text})
1628
+ session_state["chat_history"] = session_state["chat_history"][-12:]
1629
+
1630
+ yield (
1631
+ chat_ui_history,
1632
+ session_state,
1633
+ format_sources_minimal(session_state.get("last_result"), chat_mode=requested_mode),
1634
+ format_debug_text(session_state.get("last_result"), chat_mode=requested_mode),
1635
+ "",
1636
+ get_mode_label(session_state),
1637
+ requested_mode,
1638
+ )
1639
+ return
1640
+
1641
+ if user_message == "/sources":
1642
+ result = session_state.get("last_result")
1643
+ chat_ui_history.append({
1644
+ "role": "assistant",
1645
+ "content": format_sources_minimal(result, chat_mode=chat_mode),
1646
+ "metadata": {"title": "Sources"}
1647
+ })
1648
+ yield (
1649
+ chat_ui_history,
1650
+ session_state,
1651
+ format_sources_minimal(result, chat_mode=chat_mode),
1652
+ format_debug_text(result, chat_mode=chat_mode),
1653
+ "",
1654
+ get_mode_label(session_state),
1655
+ session_state.get("chat_mode", "ecg_rag"),
1656
+ )
1657
+ return
1658
+
1659
+ if user_message == "/debug":
1660
+ result = session_state.get("last_result")
1661
+ chat_ui_history.append({
1662
+ "role": "assistant",
1663
+ "content": format_debug_text(result, chat_mode=chat_mode),
1664
+ "metadata": {"title": "Debug"}
1665
+ })
1666
+ yield (
1667
+ chat_ui_history,
1668
+ session_state,
1669
+ format_sources_minimal(result, chat_mode=chat_mode),
1670
+ format_debug_text(result, chat_mode=chat_mode),
1671
+ "",
1672
+ get_mode_label(session_state),
1673
+ session_state.get("chat_mode", "ecg_rag"),
1674
+ )
1675
+ return
1676
+
1677
+ if user_message == "/rebuild":
1678
+ if not cfg.allow_rebuild_vectorstore:
1679
+ chat_ui_history.append({
1680
+ "role": "assistant",
1681
+ "content": "Vector store rebuild is disabled on this Space.",
1682
+ "metadata": {"title": "Restricted"}
1683
+ })
1684
+ yield (
1685
+ chat_ui_history,
1686
+ session_state,
1687
+ format_sources_minimal(session_state.get("last_result"), chat_mode=chat_mode),
1688
+ format_debug_text(session_state.get("last_result"), chat_mode=chat_mode),
1689
+ "",
1690
+ get_mode_label(session_state),
1691
+ session_state.get("chat_mode", "ecg_rag"),
1692
+ )
1693
+ return
1694
+
1695
+ chat_ui_history = add_assistant_placeholder(chat_ui_history)
1696
+ yield (
1697
+ chat_ui_history,
1698
+ session_state,
1699
+ "",
1700
+ "",
1701
+ thinking_html("Rebuilding vector store"),
1702
+ get_mode_label(session_state),
1703
+ session_state.get("chat_mode", "ecg_rag"),
1704
+ )
1705
+
1706
+ time.sleep(cfg.blink_stage_1)
1707
+
1708
+ chat_ui_history = update_last_assistant_message(
1709
+ chat_ui_history,
1710
+ "Rebuilding vector store and reloading embeddings...",
1711
+ title="Maintenance"
1712
+ )
1713
+ yield (
1714
+ chat_ui_history,
1715
+ session_state,
1716
+ "",
1717
+ "",
1718
+ thinking_html("Rebuilding vector store"),
1719
+ get_mode_label(session_state),
1720
+ session_state.get("chat_mode", "ecg_rag"),
1721
+ )
1722
+
1723
+ build_vectorstore()
1724
+ vectorstore = load_vectorstore()
1725
+
1726
+ chat_ui_history = update_last_assistant_message(
1727
+ chat_ui_history,
1728
+ "✅ Vector store rebuilt and reloaded.",
1729
+ title="Done"
1730
+ )
1731
+ yield (
1732
+ chat_ui_history,
1733
+ session_state,
1734
+ format_sources_minimal(session_state.get("last_result"), chat_mode=chat_mode),
1735
+ format_debug_text(session_state.get("last_result"), chat_mode=chat_mode),
1736
+ "",
1737
+ get_mode_label(session_state),
1738
+ session_state.get("chat_mode", "ecg_rag"),
1739
+ )
1740
+ return
1741
+
1742
+ chat_ui_history = add_assistant_placeholder(chat_ui_history, text="...")
1743
+ yield (
1744
+ chat_ui_history,
1745
+ session_state,
1746
+ "",
1747
+ "",
1748
+ thinking_html("Starting"),
1749
+ get_mode_label(session_state),
1750
+ session_state.get("chat_mode", "ecg_rag"),
1751
+ )
1752
+ time.sleep(cfg.blink_stage_1)
1753
+
1754
+ if chat_mode == "normal_chat":
1755
+ yield (
1756
+ chat_ui_history,
1757
+ session_state,
1758
+ "",
1759
+ "",
1760
+ thinking_html("Generating normal chat reply"),
1761
+ get_mode_label(session_state),
1762
+ session_state.get("chat_mode", "ecg_rag"),
1763
+ )
1764
+ time.sleep(cfg.blink_stage_2)
1765
+ else:
1766
+ yield (
1767
+ chat_ui_history,
1768
+ session_state,
1769
+ "",
1770
+ "",
1771
+ thinking_html("Retrieving evidence"),
1772
+ get_mode_label(session_state),
1773
+ session_state.get("chat_mode", "ecg_rag"),
1774
+ )
1775
+ time.sleep(cfg.blink_stage_2)
1776
+
1777
+ yield (
1778
+ chat_ui_history,
1779
+ session_state,
1780
+ "",
1781
+ "",
1782
+ thinking_html("Running ECG adapter reasoning"),
1783
+ get_mode_label(session_state),
1784
+ session_state.get("chat_mode", "ecg_rag"),
1785
+ )
1786
+ time.sleep(cfg.blink_stage_3)
1787
+
1788
+ out = run_chat_turn(user_message, session_state)
1789
+
1790
+ yield (
1791
+ chat_ui_history,
1792
+ session_state,
1793
+ out["sources_markdown"],
1794
+ out["debug_text"],
1795
+ thinking_html("Generating grounded summary" if chat_mode == "ecg_rag" else "Finishing reply"),
1796
+ get_mode_label(session_state),
1797
+ session_state.get("chat_mode", "ecg_rag"),
1798
+ )
1799
+ time.sleep(cfg.blink_before_answer)
1800
+
1801
+ if cfg.enable_typewriter_stream:
1802
+ for partial in stream_text(out["answer"], step=120):
1803
+ chat_ui_history = update_last_assistant_message(
1804
+ chat_ui_history,
1805
+ partial,
1806
+ title="Answer"
1807
+ )
1808
+ yield (
1809
+ chat_ui_history,
1810
+ session_state,
1811
+ out["sources_markdown"],
1812
+ out["debug_text"],
1813
+ "",
1814
+ get_mode_label(session_state),
1815
+ session_state.get("chat_mode", "ecg_rag"),
1816
+ )
1817
+
1818
+ chat_ui_history = update_last_assistant_message(
1819
+ chat_ui_history,
1820
+ out["answer"],
1821
+ title="Answer"
1822
+ )
1823
+
1824
+ yield (
1825
+ chat_ui_history,
1826
+ out["memory_state"],
1827
+ out["sources_markdown"],
1828
+ out["debug_text"],
1829
+ "",
1830
+ get_mode_label(out["memory_state"]),
1831
+ out["memory_state"].get("chat_mode", "ecg_rag"),
1832
+ )
1833
+
1834
+
1835
+ def clear_chat():
1836
+ st = initialize_session()
1837
+ return (
1838
+ [],
1839
+ st,
1840
+ "## Retrieved Sources\n\nNo sources yet.",
1841
+ "No debug result yet.",
1842
+ "",
1843
+ get_mode_label(st),
1844
+ st.get("chat_mode", "ecg_rag"),
1845
+ )
1846
+
1847
+
1848
+ def rebuild_from_button(session_state, chatbot_history):
1849
+ global vectorstore
1850
+
1851
+ if session_state is None:
1852
+ session_state = initialize_session()
1853
+
1854
+ chat_mode = session_state.get("chat_mode", "ecg_rag")
1855
+
1856
+ if not cfg.allow_rebuild_vectorstore:
1857
+ chatbot_history = chatbot_history or []
1858
+ chatbot_history.append({
1859
+ "role": "assistant",
1860
+ "content": "Vector store rebuild is disabled on this Space.",
1861
+ "metadata": {"title": "Restricted"}
1862
+ })
1863
+ return (
1864
+ chatbot_history,
1865
+ session_state,
1866
+ format_sources_minimal(session_state.get("last_result"), chat_mode=chat_mode),
1867
+ format_debug_text(session_state.get("last_result"), chat_mode=chat_mode),
1868
+ "",
1869
+ get_mode_label(session_state),
1870
+ session_state.get("chat_mode", "ecg_rag"),
1871
+ )
1872
+
1873
+ build_vectorstore()
1874
+ vectorstore = load_vectorstore()
1875
+
1876
+ chatbot_history = chatbot_history or []
1877
+ chatbot_history.append({
1878
+ "role": "assistant",
1879
+ "content": "✅ Vector store rebuilt and reloaded.",
1880
+ "metadata": {"title": "Done"}
1881
+ })
1882
+
1883
+ return (
1884
+ chatbot_history,
1885
+ session_state,
1886
+ format_sources_minimal(session_state.get("last_result"), chat_mode=chat_mode),
1887
+ format_debug_text(session_state.get("last_result"), chat_mode=chat_mode),
1888
+ "",
1889
+ get_mode_label(session_state),
1890
+ session_state.get("chat_mode", "ecg_rag"),
1891
+ )
1892
+
1893
+
1894
+ # -------------------------------
1895
+ # APP
1896
+ # -------------------------------
1897
+ with gr.Blocks(
1898
+ title="Medical CSV RAG Chatbot",
1899
+ css=CUSTOM_CSS,
1900
+ theme=gr.themes.Soft(
1901
+ primary_hue="indigo",
1902
+ secondary_hue="blue",
1903
+ neutral_hue="slate",
1904
+ radius_size="lg",
1905
+ text_size="md",
1906
+ ),
1907
+ ) as demo:
1908
+
1909
+ gr.HTML(hero_html())
1910
+
1911
+ session_state = gr.State(initialize_session())
1912
+
1913
+ with gr.Column(elem_classes=["mobile-stack"]):
1914
+ with gr.Group(elem_classes=["panel-wrap"]):
1915
+ mode_selector = gr.Radio(
1916
+ choices=[
1917
+ ("ECG RAG Mode", "ecg_rag"),
1918
+ ("Normal Chat Mode", "normal_chat"),
1919
+ ],
1920
+ value="ecg_rag",
1921
+ label="Chat Mode",
1922
+ interactive=True,
1923
+ )
1924
+
1925
+ mode_status = gr.HTML(get_mode_label(initialize_session()))
1926
+
1927
+ chatbot = gr.Chatbot(
1928
+ label="Clinical Chat",
1929
+ height=640,
1930
+ elem_id="chatbot",
1931
+ type="messages",
1932
+ show_copy_button=True,
1933
+ bubble_full_width=False,
1934
+ avatar_images=(None, None),
1935
+ )
1936
+
1937
+ user_box = gr.Textbox(
1938
+ label="Ask a question",
1939
+ placeholder="e.g. What are the ECG findings in hyperkalemia? or type 'switch to ECG mode'",
1940
+ lines=2,
1941
+ autofocus=True,
1942
+ )
1943
+
1944
+ status_html = gr.HTML("")
1945
+
1946
+ with gr.Row():
1947
+ send_btn = gr.Button("Send", variant="primary")
1948
+ clear_btn = gr.Button("Clear")
1949
+ rebuild_btn = gr.Button("Rebuild Store")
1950
+
1951
+ gr.HTML(
1952
+ """
1953
+ <div class="command-note">
1954
+ Commands: <code>/sources</code>, <code>/debug</code>, <code>/rebuild</code>
1955
+ </div>
1956
+ """
1957
+ )
1958
+
1959
+ with gr.Accordion("Retrieved Sources", open=False):
1960
+ with gr.Group(elem_classes=["panel-wrap", "mobile-scroll"]):
1961
+ sources_panel = gr.Markdown("## Retrieved Sources\n\nNo sources yet.")
1962
+
1963
+ if cfg.show_debug_panel:
1964
+ with gr.Accordion("Debug Panel", open=False):
1965
+ with gr.Group(elem_classes=["panel-wrap", "mobile-scroll"]):
1966
+ debug_panel = gr.Textbox(
1967
+ label="Debug",
1968
+ value="No debug result yet.",
1969
+ lines=18,
1970
+ max_lines=28,
1971
+ interactive=False,
1972
+ )
1973
+ else:
1974
+ debug_panel = gr.Textbox(visible=False, value="")
1975
+
1976
+ mode_selector.change(
1977
+ fn=set_chat_mode,
1978
+ inputs=[mode_selector, session_state],
1979
+ outputs=[session_state],
1980
+ queue=False,
1981
+ ).then(
1982
+ fn=get_mode_label,
1983
+ inputs=[session_state],
1984
+ outputs=[mode_status],
1985
+ queue=False,
1986
+ )
1987
+
1988
+ submit_event = user_box.submit(
1989
+ fn=user_submit,
1990
+ inputs=[user_box, chatbot],
1991
+ outputs=[user_box, chatbot],
1992
+ queue=True,
1993
+ )
1994
+
1995
+ submit_event.then(
1996
+ fn=bot_respond_stream,
1997
+ inputs=[chatbot, session_state],
1998
+ outputs=[chatbot, session_state, sources_panel, debug_panel, status_html, mode_status, mode_selector],
1999
+ queue=True,
2000
+ )
2001
+
2002
+ send_click = send_btn.click(
2003
+ fn=user_submit,
2004
+ inputs=[user_box, chatbot],
2005
+ outputs=[user_box, chatbot],
2006
+ queue=True,
2007
+ )
2008
+
2009
+ send_click.then(
2010
+ fn=bot_respond_stream,
2011
+ inputs=[chatbot, session_state],
2012
+ outputs=[chatbot, session_state, sources_panel, debug_panel, status_html, mode_status, mode_selector],
2013
+ queue=True,
2014
+ )
2015
+
2016
+ clear_btn.click(
2017
+ fn=clear_chat,
2018
+ inputs=[],
2019
+ outputs=[chatbot, session_state, sources_panel, debug_panel, status_html, mode_status, mode_selector],
2020
+ queue=False,
2021
+ )
2022
 
2023
+ rebuild_btn.click(
2024
+ fn=rebuild_from_button,
2025
+ inputs=[session_state, chatbot],
2026
+ outputs=[chatbot, session_state, sources_panel, debug_panel, status_html, mode_status, mode_selector],
2027
+ queue=True,
2028
+ )
2029
 
2030
+ demo.queue(default_concurrency_limit=1)
2031
 
2032
  if __name__ == "__main__":
2033
+ demo.launch(
2034
+ debug=cfg.launch_debug,
2035
+ server_name=cfg.server_name,
2036
+ server_port=cfg.server_port,
2037
+ )