| import os |
| import re |
| import time |
| import traceback |
| import logging |
| from typing import List, Dict, TypedDict, Optional |
| from dataclasses import dataclass, field |
|
|
| import torch |
| import pandas as pd |
| import gradio as gr |
|
|
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
| from peft import PeftModel |
|
|
| from langchain_core.documents import Document |
| from langchain_huggingface import HuggingFaceEmbeddings |
| from langchain_community.vectorstores import FAISS |
| from langchain_openai import ChatOpenAI |
| from langgraph.graph import StateGraph, START, END |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
# Configure logging once at import time: INFO level, with each record rendered
# as "<timestamp> - <level> - <message>".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
# Module-level logger shared by the rest of this module.
logger = logging.getLogger(__name__)
|
|
|
|
| |
| |
| |
def _env_flag(name: str, default: str) -> bool:
    """Read an environment switch; only the value "true" (any case) enables it."""
    return os.getenv(name, default).lower() == "true"


@dataclass
class Config:
    """Runtime settings for the pipeline, each overridable by an environment variable.

    Defaults are read from the environment once, when the class body is
    evaluated. ``__post_init__`` derives ``vectorstore_dir`` and validates the
    settings that the module cannot run without.

    Raises:
        ValueError: if ``deepseek_api_key`` is empty.
        FileNotFoundError: if ``adapter_dir`` or ``data_csv`` does not exist.
    """

    # Model, adapter and data locations.
    base_model_path: str = os.getenv("BASE_MODEL_PATH", "meta-llama/Llama-3.1-8B-Instruct")
    adapter_dir: str = os.getenv("ADAPTER_DIR", "adapter_refined_v10")
    data_csv: str = os.getenv("DATA_CSV", "RAGmaterials/ECG_RAG_only_clean.csv")
    rag_dir: str = os.getenv("RAG_DIR", "RAGmaterials")
    # Derived from rag_dir in __post_init__; not a constructor argument.
    vectorstore_dir: str = field(init=False)

    # Credentials and remote LLM endpoint.
    hf_token: str = os.getenv("HF_TOKEN", "")
    deepseek_api_key: str = os.getenv("DEEPSEEK_API_KEY", "")
    deepseek_base_url: str = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
    deepseek_model: str = os.getenv("DEEPSEEK_MODEL", "deepseek-chat")

    # Remote LLM sampling settings.
    deepseek_temperature: float = float(os.getenv("DEEPSEEK_TEMPERATURE", "0.1"))
    deepseek_max_tokens: int = int(os.getenv("DEEPSEEK_MAX_TOKENS", "550"))

    # Embedding model used for the vector store.
    embed_model_name: str = os.getenv("EMBED_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2")

    # Retrieval sizes and context budget.
    similarity_k: int = int(os.getenv("SIMILARITY_K", "12"))
    top_k_final: int = int(os.getenv("TOP_K_FINAL", "4"))
    max_context_chars: int = int(os.getenv("MAX_CONTEXT_CHARS", "5200"))

    # Local model generation limits and chat history window.
    max_input_len: int = int(os.getenv("MAX_INPUT_LEN", "4096"))
    max_new_tokens_local: int = int(os.getenv("MAX_NEW_TOKENS_LOCAL", "180"))
    max_chat_history_turns: int = int(os.getenv("MAX_CHAT_HISTORY_TURNS", "6"))

    # Retrieval filtering thresholds.
    min_lexical_overlap: float = float(os.getenv("MIN_LEXICAL_OVERLAP", "0.08"))
    min_faiss_similarity: float = float(os.getenv("MIN_FAISS_SIMILARITY", "0.20"))
    strong_retrieval_threshold: float = float(os.getenv("STRONG_RETRIEVAL_THRESHOLD", "0.30"))
    strong_retrieval_min_docs: int = int(os.getenv("STRONG_RETRIEVAL_MIN_DOCS", "3"))

    # Feature switches.
    use_query_cache: bool = _env_flag("USE_QUERY_CACHE", "true")
    enable_query_expansion: bool = _env_flag("ENABLE_QUERY_EXPANSION", "true")
    enable_validator: bool = _env_flag("ENABLE_VALIDATOR", "true")
    enable_typewriter_stream: bool = _env_flag("ENABLE_TYPEWRITER_STREAM", "true")
    show_debug_panel: bool = _env_flag("SHOW_DEBUG_PANEL", "true")
    allow_rebuild_vectorstore: bool = _env_flag("ALLOW_REBUILD_VECTORSTORE", "false")

    # Quantization switch for the local model.
    use_4bit: bool = _env_flag("USE_4BIT", "true")

    # Server launch settings.
    launch_debug: bool = _env_flag("LAUNCH_DEBUG", "false")
    server_name: str = os.getenv("SERVER_NAME", "0.0.0.0")
    server_port: int = int(os.getenv("SERVER_PORT", "7860"))

    # UI timing values (floats read from BLINK_* variables).
    blink_stage_1: float = float(os.getenv("BLINK_STAGE_1", "0.40"))
    blink_stage_2: float = float(os.getenv("BLINK_STAGE_2", "0.55"))
    blink_stage_3: float = float(os.getenv("BLINK_STAGE_3", "0.50"))
    blink_before_answer: float = float(os.getenv("BLINK_BEFORE_ANSWER", "0.25"))

    def __post_init__(self):
        """Derive the vector store path, create ``rag_dir``, and validate required inputs."""
        self.vectorstore_dir = os.path.join(self.rag_dir, "faiss_store")
        os.makedirs(self.rag_dir, exist_ok=True)

        if not self.deepseek_api_key:
            raise ValueError("Missing DEEPSEEK_API_KEY. Add it in Hugging Face Space Secrets.")

        # Checked in this order so the adapter is reported before the CSV.
        required_paths = {
            "Adapter directory": self.adapter_dir,
            "CSV data": self.data_csv,
        }
        for label, location in required_paths.items():
            if not os.path.exists(location):
                raise FileNotFoundError(f"{label} not found at: {location}")
|
|
|
|
# Build the module-wide settings object at import time. Config.__post_init__
# raises if the DeepSeek API key is missing or the adapter/CSV paths do not
# exist, so a misconfigured environment fails here rather than later.
cfg = Config()
logger.info("Configuration loaded.")
|
|
|
|
| |
| |
| |
# System prompt for the local reasoning model (used by run_local_reasoner).
# It asks for a structured draft (KEY_FINDINGS / INTERPRETATION /
# SUPPORTED_POINTS / LIMITS) or the literal INSUFFICIENT_EVIDENCE.
# The text is sent to the model verbatim; edit with care.
LOCAL_REASONING_SYSTEM = """
You are a strict medical reasoning assistant specialized for ECG and cardiology reasoning.

You are NOT the final answer generator.
You must analyze ONLY the supplied evidence and produce a short structured reasoning draft.

Rules:
1) Use only the provided evidence.
2) Do not invent facts.
3) Focus only on the user's exact question.
4) Output exactly in this structure:

KEY_FINDINGS:
- ...
- ...

INTERPRETATION:
- ...
- ...

SUPPORTED_POINTS:
- [EVIDENCE_ID: X] ...
- [EVIDENCE_ID: Y] ...

LIMITS:
- ...

5) If evidence is insufficient, output exactly:
INSUFFICIENT_EVIDENCE
""".strip()
|
|
# System prompt for query expansion (used by run_query_expansion): the model
# must return only a reworded/expanded retrieval query, never an answer.
QUERY_EXPANSION_SYSTEM = """
You expand medical queries for retrieval.

Rules:
1) Preserve the user's intent.
2) Add close medical paraphrases and alternate wording.
3) Add likely medical synonyms, abbreviations, and alternate phrasing.
4) Do not answer the question.
5) Output only the expanded retrieval query.
""".strip()
|
|
# System prompt for the final answer (used by run_deepseek_summary). It fixes
# the four-section markdown layout of the answer and instructs the model to
# emit the literal INSUFFICIENT_EVIDENCE when the evidence cannot support one.
# The text is sent to the model verbatim; edit with care.
DEEPSEEK_SUMMARY_SYSTEM = """
You are an expert medical evidence summarizer.

Your job is to produce a clinically precise, well-structured answer grounded ONLY in:
1. the retrieved evidence
2. the local reasoning draft

You must be faithful to the provided material and answer the user's question directly, clearly, and conservatively.

PRIMARY OBJECTIVE
- Identify the user's main intent before writing:
definition, cause, symptoms, diagnosis, investigation, treatment, prognosis, or genetics.
- Prioritize that intent throughout the response.
- The first sentence of the Summary must directly answer the user's question in the most clinically relevant way.

GROUNDING RULES
- Use only information supported by the retrieved evidence and local reasoning draft.
- Do not add outside medical knowledge.
- Do not infer specific facts unless they are clearly supported.
- Do not invent treatments, diagnoses, risks, mechanisms, thresholds, statistics, timelines, monitoring plans, or prognosis details.
- If the evidence is incomplete, be explicit about what is missing.
- If the evidence is too weak to answer the question reliably, output exactly:
INSUFFICIENT_EVIDENCE

STYLE RULES
- Write in precise, professional clinical language.
- Be specific, not vague.
- Be concise, but fully informative.
- Avoid repetition, generic filler, and empty statements.
- Do not mention retrieval, prompts, system instructions, reasoning drafts, tools, pipelines, or internal processes.
- Do not include URLs or citations unless explicitly requested elsewhere.
- Do not overstate certainty.
- When appropriate, distinguish clearly between what is established, what is suggested, and what is not addressed by the evidence.

OUTPUT FORMAT

### Summary
- Write 4 to 7 full sentences.
- This is the most important section.
- The first sentence must directly answer the user's question.
- Focus primarily on the user's main intent.
- Include only background information that improves understanding of the requested topic.
- Make the summary clinically useful, specific, and evidence-faithful.

### Key Evidence Points
- Include 4 to 6 bullet points.
- Each bullet must state a concrete fact supported by the evidence.
- Prioritize clinically important facts over background detail.
- Avoid repeating the same idea in different words.

### Clinical Implications / Recommendations
- Include 2 to 4 bullet points only if supported by the evidence.
- Focus on practical interpretation, management implications, follow-up considerations, or next steps.
- If the evidence supports recognition or framing rather than action, say that clearly.
- Do not recommend interventions not supported by the evidence.

### Limitations of the Evidence
- State clearly what the evidence does not establish, does not cover, or leaves uncertain.
- Explicitly note when details are lacking on:
treatment, diagnosis, prognosis, genetics, monitoring, recurrence prevention, comparative effectiveness, or long-term outcomes.
- If the evidence is narrow, low-detail, or only partially aligned with the question, say so plainly.

SPECIAL INSTRUCTIONS BY QUESTION TYPE

For treatment questions:
- Focus primarily on treatment and management, not disease definition.
- Organize treatment information in this order whenever supported by the evidence:
1. supportive or conservative care
2. symptomatic drug therapy or procedural treatment
3. long-term prevention, follow-up, or recurrence prevention
- Distinguish treatment of active symptoms from prevention of recurrence or complications.
- If the condition is benign, self-limited, or often does not require treatment, state that clearly in the first sentence.

For diagnosis or investigation questions:
- Focus on how the condition is identified, evaluated, or differentiated.
- Prioritize diagnostic features, testing approach, and clinically useful distinctions.
- Do not drift into treatment unless the evidence clearly supports it and it helps answer the question.

For cause or risk questions:
- Focus on etiologies, risk factors, mechanisms, or associations supported by the evidence.
- Distinguish established causes from possible contributors if the evidence is less certain.

For prognosis questions:
- Focus on expected course, complications, recurrence, or outcome-related information supported by the evidence.
- Do not add prognostic claims not explicitly supported.

QUALITY CHECK BEFORE OUTPUT
Before finalizing, ensure that:
- the first sentence directly answers the question
- the response matches the user's primary intent
- every important claim is grounded in the provided material
- no unsupported medical detail has been added
- the Limitations section honestly reflects evidence gaps

If these conditions cannot be met, output exactly:
INSUFFICIENT_EVIDENCE
""".strip()
|
|
# System prompt for the answer validator (used by run_validator). The verdict
# prefixes listed here (SUPPORTED / PARTLY_UNSUPPORTED / INSUFFICIENT_EVIDENCE)
# are matched with str.startswith in validate_node and compute_confidence, so
# they must stay in sync with that code.
VALIDATOR_SYSTEM = """
You are a strict medical evidence validator.

Your job is to compare the ANSWER against the EVIDENCE.

Rules:
1) Mark SUPPORTED if the answer is well grounded in the evidence.
2) Mark PARTLY_UNSUPPORTED if some claims are supported but others go beyond the evidence.
3) Mark INSUFFICIENT_EVIDENCE if the answer is mostly unsupported or the evidence is too weak.
4) Output only one short verdict line beginning with exactly one of:
SUPPORTED:
PARTLY_UNSUPPORTED:
INSUFFICIENT_EVIDENCE:
""".strip()
|
|
|
|
| |
| |
| |
def clean_text(x: str) -> str:
    """Normalize a raw value into single-spaced text.

    The value is coerced to ``str``, NUL characters become spaces, and every
    run of whitespace collapses to one space with none left at either end.
    """
    without_nul = str(x).replace("\x00", " ")
    # split() with no argument drops leading/trailing whitespace and splits on
    # whitespace runs, so re-joining with one space performs the collapse.
    return " ".join(without_nul.split())
|
|
|
|
def strip_bad_sections(txt: str) -> str:
    """Remove trailing boilerplate sections and URLs from model output.

    Everything from the earliest occurrence (case-insensitive) of any cut
    marker onward is dropped, then ``http(s)://`` and ``www.`` URLs are
    removed from what remains.

    Args:
        txt: Raw text; coerced to ``str``.

    Returns:
        The cleaned text with surrounding whitespace stripped.
    """
    cleaned = str(txt).strip()
    cut_markers = (
        "References:",
        "Sources:",
        "Source:",
        "URLs:",
        "This response is based",
        "Please let me know",
        "Is there anything else",
    )

    # Search the original string case-insensitively instead of searching a
    # lowercased copy: str.lower() can change the string length (e.g. "İ"),
    # which would make offsets found in the copy point at the wrong place.
    marker_pattern = "|".join(re.escape(marker) for marker in cut_markers)
    first_marker = re.search(marker_pattern, cleaned, flags=re.IGNORECASE)
    if first_marker:
        cleaned = cleaned[:first_marker.start()].strip()

    return re.sub(r"https?://\S+|www\.\S+", "", cleaned).strip()
|
|
|
|
def infer_tags(question: str, answer: str) -> List[str]:
    """Assign coarse topic tags to a question/answer pair by keyword lookup.

    Keywords are matched case-insensitively at the start of a word, so
    "treat" still covers "treatment" and "treated". Keywords of three letters
    or fewer (``ct``, ``lab``, ``mri``, ``ecg``, ...) must match a whole word,
    optionally pluralized. Plain substring matching would otherwise tag nearly
    every text, e.g. "ct" inside "conduction" or "lab" inside "available".

    Args:
        question: Question text.
        answer: Answer text.

    Returns:
        Tag names in the fixed order of the keyword table; empty when nothing
        matches.
    """
    text = f"{question} {answer}".lower()

    keyword_map = {
        "treatment": ["treat", "therapy", "management", "drug", "surgery"],
        "diagnosis": ["diagnosis", "diagnose", "criteria"],
        "symptoms": ["symptom", "presentation", "sign", "feature"],
        "ecg": ["ecg", "ekg", "st elevation", "qrs", "p wave", "arrhythmia", "tachycardia", "bradycardia"],
        "investigation": ["test", "investigation", "mri", "ct", "lab", "imaging"],
        "prognosis": ["prognosis", "outcome", "survival", "risk"],
        "genetics": ["gene", "genetic", "mutation", "variant", "chromosome", "inherited", "inheritance"],
        "etiology": ["cause", "causes", "caused by", "associated with", "risk factor"],
    }

    def _mentions(keyword: str) -> bool:
        # Short abbreviations need a closing boundary too; longer keywords act
        # as word prefixes so inflected forms still match.
        closing = r"s?\b" if len(keyword) <= 3 else ""
        return re.search(rf"\b{re.escape(keyword)}{closing}", text) is not None

    return [
        tag
        for tag, words in keyword_map.items()
        if any(_mentions(word) for word in words)
    ]
|
|
|
|
def make_row_text(q: str, a: str) -> str:
    """Render a question/answer pair as the text that gets embedded.

    The layout is a ``QUESTION:`` section and an ``ANSWER:`` section separated
    by a blank line, with outer whitespace stripped.
    """
    sections = ("QUESTION:", f"{q}", "", "ANSWER:", f"{a}")
    return "\n".join(sections).strip()
|
|
|
|
def score_to_similarity(raw_score: float) -> float:
    """Map a raw retrieval score to a similarity via ``1 / (1 + score)``.

    Negative scores are clamped to 0 first, so they map to 1.0. Values that
    cannot be converted to ``float`` yield the sentinel ``-1.0``.
    """
    try:
        distance = float(raw_score)
    except Exception:
        return -1.0

    non_negative = max(distance, 0.0)
    return 1.0 / (1.0 + non_negative)
|
|
|
|
def lexical_overlap(query: str, text: str) -> float:
    """Return the fraction of distinct query words that also occur in ``text``.

    Words are ``\\w+`` runs compared case-insensitively. A query with no words
    scores 0.0.
    """
    word_pattern = re.compile(r"\w+")
    query_terms = set(word_pattern.findall(query.lower()))
    text_terms = set(word_pattern.findall(text.lower()))

    if not query_terms:
        return 0.0

    shared = query_terms.intersection(text_terms)
    return len(shared) / max(1, len(query_terms))
|
|
|
|
def rerank_docs(query: str, docs: List[Document], top_n: Optional[int] = None) -> List[Document]:
    """Re-order retrieved documents by a blended lexical and vector score.

    The score of a document is the sum of:
      * the fraction of query words found in its question, answer and tags,
      * +0.20 if any query word is a substring of the question,
      * +0.10 if any query word is a substring of the joined tags,
      * 0.35 times the ``sim_score`` stored in its metadata (0.0 if absent).

    Args:
        query: The user's query.
        docs: Candidate documents; only ``metadata`` is read.
        top_n: Number of documents to keep; defaults to ``cfg.top_k_final``.

    Returns:
        The highest-scoring documents, best first. Ties keep their input order.
    """
    limit = cfg.top_k_final if top_n is None else top_n
    query_terms = set(re.findall(r"\w+", query.lower()))
    term_count = max(1, len(query_terms))

    def _score(doc: Document) -> float:
        meta = doc.metadata
        question_text = meta.get("question", "")
        answer_text = meta.get("answer", "")
        tag_text = " ".join(meta.get("tags", []))

        haystack_terms = set(
            re.findall(r"\w+", f"{question_text} {answer_text} {tag_text}".lower())
        )
        overlap = len(query_terms & haystack_terms) / term_count
        in_question = any(term in question_text.lower() for term in query_terms)
        in_tags = any(term in tag_text.lower() for term in query_terms)
        question_boost = 0.20 if in_question else 0.0
        tag_boost = 0.10 if in_tags else 0.0
        similarity = float(meta.get("sim_score", 0.0))
        return overlap + question_boost + tag_boost + (0.35 * similarity)

    # sorted() is stable, also with reverse=True, so equal scores keep input order.
    ranked = sorted(docs, key=_score, reverse=True)
    return ranked[:limit]
|
|
|
|
def history_to_text(chat_history: List[Dict[str, str]], max_turns: Optional[int] = None) -> str:
    """Render the most recent chat messages as ``ROLE: content`` lines.

    Args:
        chat_history: Messages as dicts with ``role`` and ``content`` keys,
            oldest first.
        max_turns: Maximum number of trailing messages to include; defaults
            to ``cfg.max_chat_history_turns``.

    Returns:
        One line per message, or ``"[EMPTY]"`` when there is no history to
        show or ``max_turns`` is not positive.
    """
    if max_turns is None:
        max_turns = cfg.max_chat_history_turns

    # Guard non-positive limits explicitly: chat_history[-0:] is the whole
    # list and a negative limit would slice from the front, so both would
    # leak history that the caller asked to exclude.
    if max_turns <= 0 or not chat_history:
        return "[EMPTY]"

    recent = chat_history[-max_turns:]
    return "\n".join(f"{m['role'].upper()}: {m['content']}" for m in recent).strip()
|
|
|
|
def build_context_string(docs: List[Document], max_chars: Optional[int] = None) -> str:
    """Format ranked documents into the evidence context passed to the LLMs.

    Each document becomes a delimited block carrying its 1-based evidence id,
    source id, source question, tags, similarity and answer text. Blocks are
    added in order until the next one would exceed the character budget.

    Args:
        docs: Documents in rank order; only ``metadata`` is read.
        max_chars: Character budget; defaults to ``cfg.max_context_chars``.

    Returns:
        The blocks joined by blank lines. If the first block alone exceeds
        the budget it is truncated to the budget rather than dropped, so a
        successful retrieval never produces an empty context.
    """
    if max_chars is None:
        max_chars = cfg.max_context_chars

    rule = "=============================="
    blocks = []
    used = 0

    for evidence_id, doc in enumerate(docs, 1):
        meta = doc.metadata
        tags = ", ".join(meta.get("tags", [])) or "N/A"
        sim = meta.get("sim_score", None)

        block = "\n".join([
            rule,
            f"EVIDENCE_ID: {evidence_id}",
            f"SOURCE_ID: {meta.get('id')}",
            f"SOURCE_QUESTION: {meta.get('question', '')}",
            f"SOURCE_TAGS: {tags}",
            f"SIMILARITY: {sim if sim is not None else 'N/A'}",
            "EVIDENCE_TEXT:",
            f"{meta.get('answer', '')}",
            rule,
        ]).strip()

        if used + len(block) > max_chars:
            if not blocks:
                # The best hit is too long on its own: keep as much of it as
                # fits instead of returning no evidence at all.
                blocks.append(block[:max(max_chars, 0)].rstrip())
            # Stop at the first overflow so lower-ranked documents never
            # displace higher-ranked ones.
            break

        blocks.append(block)
        used += len(block) + 2  # +2 for the blank-line separator

    return "\n\n".join(blocks).strip()
|
|
|
|
def compute_confidence(result: Dict) -> float:
    """Derive a 0..1 confidence from the retrieval score and validator verdict.

    The ``best_score`` entry (default -1.0) is scaled by a factor chosen from
    the prefix of ``validation_status``: 1.0 for ``SUPPORTED``, 0.70 for
    ``PARTLY_UNSUPPORTED``, 0.40 for anything else. The result is clamped to
    the range [0.0, 1.0].
    """
    retrieval_score = result.get("best_score", -1.0)
    verdict = result.get("validation_status", "")

    verdict_weights = (
        ("SUPPORTED", 1.0),
        ("PARTLY_UNSUPPORTED", 0.70),
    )
    weight = next(
        (factor for prefix, factor in verdict_weights if verdict.startswith(prefix)),
        0.40,
    )

    capped = min(1.0, retrieval_score * weight)
    return max(0.0, capped)
|
|
|
|
def strong_retrieval(best_score: float, docs: List[Document]) -> bool:
    """Tell whether retrieval is strong enough to skip answer validation.

    True only when ``best_score`` reaches ``cfg.strong_retrieval_threshold``
    and at least ``cfg.strong_retrieval_min_docs`` documents were retrieved.
    """
    if not (best_score >= cfg.strong_retrieval_threshold):
        return False
    return len(docs) >= cfg.strong_retrieval_min_docs
|
|
|
|
def stream_text(text: str, step: int = 110):
    """Yield growing prefixes of ``text`` for a typewriter-style display.

    Each yielded string is ``step`` characters longer than the previous one;
    the last one is the full text. Empty text yields nothing.
    """
    for start in range(0, len(text), step):
        yield text[:start + step]
|
|
|
|
| |
| |
| |
# Instantiate the sentence-embedding model once at import time; the same
# object is used both to build and to load the FAISS store below.
logger.info("Loading embeddings...")
embeddings = HuggingFaceEmbeddings(model_name=cfg.embed_model_name)
|
|
|
|
def build_vectorstore():
    """Build the FAISS store from the Q&A CSV and save it to ``cfg.vectorstore_dir``.

    The CSV must provide ``instruction`` and ``response`` columns (header
    names are trimmed and lowercased before the check). Rows missing either
    value are dropped, both columns are normalized with ``clean_text``, and
    each remaining row becomes one document tagged by ``infer_tags``.

    Raises:
        ValueError: if a required column is missing.
    """
    logger.info(f"Reading CSV: {cfg.data_csv}")
    frame = pd.read_csv(cfg.data_csv)
    frame.columns = [name.strip().lower() for name in frame.columns]

    required = {"instruction", "response"}
    if not required.issubset(frame.columns):
        raise ValueError(f"CSV must contain columns {required}. Found: {frame.columns.tolist()}")

    frame = frame[["instruction", "response"]].dropna().reset_index(drop=True)
    for column in ("instruction", "response"):
        frame[column] = frame[column].map(clean_text)

    # The index was reset above, so the enumerate counter equals the row index.
    docs = [
        Document(
            page_content=make_row_text(question, answer),
            metadata={
                "id": int(row_id),
                "question": question,
                "answer": answer,
                "tags": infer_tags(question, answer),
            },
        )
        for row_id, (question, answer) in enumerate(
            zip(frame["instruction"], frame["response"])
        )
    ]

    store = FAISS.from_documents(docs, embeddings)
    store.save_local(cfg.vectorstore_dir)
    logger.info(f"Saved vectorstore with {len(docs)} docs to {cfg.vectorstore_dir}")
|
|
|
|
def load_vectorstore():
    """Load the FAISS store saved under ``cfg.vectorstore_dir``."""
    # SECURITY NOTE(review): allow_dangerous_deserialization=True lets the
    # loader unpickle the saved docstore, and unpickling can execute arbitrary
    # code. This is only safe while the store directory contains files written
    # by build_vectorstore() in this deployment - never point it at an
    # untrusted or user-supplied index.
    return FAISS.load_local(
        cfg.vectorstore_dir,
        embeddings,
        allow_dangerous_deserialization=True,
    )
|
|
|
|
# Build the store on first run only. The check is on the directory's
# existence, so an existing but incomplete directory is not rebuilt and will
# surface as an error from load_vectorstore() instead.
if not os.path.exists(cfg.vectorstore_dir):
    logger.info("Vectorstore not found. Building from CSV...")
    build_vectorstore()

# Module-level store queried by retrieve_docs_once().
vectorstore = load_vectorstore()
logger.info("Vectorstore ready.")
|
|
|
|
| |
| |
| |
# Load the tokenizer of the base model. The Hugging Face token is forwarded
# only when one is configured; an empty string is passed as None.
logger.info("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    cfg.base_model_path,
    use_fast=True,
    token=cfg.hf_token if cfg.hf_token else None
)

# Reuse the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
|
|
# Load the base causal LM. Strategy: try a 4-bit quantized load when enabled
# and CUDA is available; on any failure (or when not applicable) fall back to
# a regular load in fp16 on GPU or fp32 on CPU.
logger.info("Loading base model...")
has_cuda = torch.cuda.is_available()
base_model = None  # stays None until one of the load paths succeeds

if cfg.use_4bit and has_cuda:
    try:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            cfg.base_model_path,
            device_map="auto",
            quantization_config=bnb_config,
            torch_dtype=torch.float16,
            token=cfg.hf_token if cfg.hf_token else None,
        )
        logger.info("Loaded base model in 4-bit mode.")
    except Exception as e:
        # Best effort: a failed quantized load is logged and the regular
        # load below takes over.
        logger.warning(f"4-bit load failed: {e}")

if base_model is None:
    dtype = torch.float16 if has_cuda else torch.float32
    base_model = AutoModelForCausalLM.from_pretrained(
        cfg.base_model_path,
        device_map="auto" if has_cuda else None,
        torch_dtype=dtype,
        token=cfg.hf_token if cfg.hf_token else None,
    )
    if not has_cuda:
        # No device_map was given on CPU, so place the model explicitly.
        base_model = base_model.to("cpu")
    logger.info("Loaded base model without 4-bit.")

# Inference only: switch to eval mode.
base_model.eval()
|
|
# Wrap the base model with the PEFT adapter stored in cfg.adapter_dir; the
# resulting model is the one used by run_local_reasoner().
logger.info("Loading ECG reasoning adapter...")
reason_model = PeftModel.from_pretrained(base_model, cfg.adapter_dir)
reason_model.eval()
|
|
|
|
def get_primary_model_device(model) -> torch.device:
    """Return the device of the model's first parameter.

    When the model exposes no parameters, fall back to CUDA if it is
    available and to CPU otherwise.
    """
    first_param = next(model.parameters(), None)
    if first_param is not None:
        return first_param.device

    fallback = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(fallback)
|
|
|
|
@torch.inference_mode()
def run_local_reasoner(user_query: str, context: str) -> str:
    """Produce a structured reasoning draft with the local adapter model.

    The question and evidence are wrapped in the chat template together with
    ``LOCAL_REASONING_SYSTEM``, generated greedily, and only the newly
    generated tokens are decoded and passed through ``strip_bad_sections``.

    Args:
        user_query: The user's question.
        context: Evidence text; blank evidence is sent as ``[EMPTY]``.

    Returns:
        The draft, or ``"INSUFFICIENT_EVIDENCE"`` when the output is empty or
        any step raises (the error is logged and a traceback printed).
    """
    try:
        evidence = context if context.strip() else "[EMPTY]"
        chat = [
            {"role": "system", "content": LOCAL_REASONING_SYSTEM},
            {"role": "user", "content": f"QUESTION:\n{user_query}\n\nEVIDENCE:\n{evidence}"},
        ]
        prompt = tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True,
        )

        encoded = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=cfg.max_input_len,
        )
        target_device = get_primary_model_device(reason_model)
        encoded = {name: tensor.to(target_device) for name, tensor in encoded.items()}
        prompt_length = encoded["input_ids"].shape[1]

        generated = reason_model.generate(
            **encoded,
            max_new_tokens=cfg.max_new_tokens_local,
            do_sample=False,
            use_cache=True,
            repetition_penalty=1.08,
            no_repeat_ngram_size=3,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        # generate() returns prompt + continuation; keep only the continuation.
        new_tokens = generated[0, prompt_length:]
        decoded = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        draft = strip_bad_sections(decoded)
        return draft or "INSUFFICIENT_EVIDENCE"

    except Exception as e:
        logger.error(f"Local reasoner error: {e}")
        traceback.print_exc()
        return "INSUFFICIENT_EVIDENCE"
|
|
|
|
| |
| |
| |
# Chat client for the DeepSeek endpoint, reached through the ChatOpenAI class
# with the base URL, model and sampling limits taken from cfg.
deepseek_llm = ChatOpenAI(
    model=cfg.deepseek_model,
    api_key=cfg.deepseek_api_key,
    base_url=cfg.deepseek_base_url,
    temperature=cfg.deepseek_temperature,
    max_tokens=cfg.deepseek_max_tokens,
)

# In-process cache: original user query -> expanded query. It is never
# evicted, so it grows with the number of distinct queries seen.
_query_expansion_cache: Dict[str, str] = {}
|
|
|
|
def llm_text(system_prompt: str, user_prompt: str, fallback: str = "INSUFFICIENT_EVIDENCE") -> str:
    """Send a system/user prompt pair to DeepSeek and return the cleaned reply.

    Args:
        system_prompt: System message content.
        user_prompt: User message content.
        fallback: Value returned when the reply is blank or the call raises.

    Returns:
        The reply after ``strip_bad_sections``, or ``fallback``. Errors are
        logged and a traceback is printed; they are never propagated.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    try:
        reply = deepseek_llm.invoke(conversation)
        if hasattr(reply, "content"):
            raw = reply.content
        else:
            raw = str(reply)
        cleaned = strip_bad_sections(raw)
        if cleaned.strip():
            return cleaned
        return fallback
    except Exception as e:
        logger.error(f"DeepSeek error: {e}")
        traceback.print_exc()
        return fallback
|
|
|
|
def run_query_expansion(user_query: str) -> str:
    """Expand a query with medical synonyms and paraphrases for retrieval.

    Args:
        user_query: The original user query.

    Returns:
        The expanded query, or ``user_query`` when expansion is disabled or
        the LLM produced nothing usable.

    Successful expansions are cached per exact query string when
    ``cfg.use_query_cache`` is on. A result equal to the stripped query is
    what ``llm_text`` returns as its fallback after a failed call, so it is
    not cached: a transient outage must not disable expansion for that query
    for the rest of the process.
    """
    if not cfg.enable_query_expansion:
        return user_query

    if cfg.use_query_cache:
        cached = _query_expansion_cache.get(user_query)
        if cached is not None:
            logger.info(f"Using cached expansion for: {user_query[:80]}")
            return cached

    prompt = f"""
USER_QUERY:
{user_query}

Expand this for retrieval with close medical phrasing, synonyms, and alternate wording.
Do not answer the question.
""".strip()

    expanded = llm_text(QUERY_EXPANSION_SYSTEM, prompt, fallback=user_query)
    expanded = expanded.strip() if expanded else user_query

    # Only remember real expansions (see docstring).
    if cfg.use_query_cache and expanded != user_query.strip():
        _query_expansion_cache[user_query] = expanded

    return expanded
|
|
|
|
def run_deepseek_summary(
    user_query: str,
    context: str,
    reasoning_draft: str,
    chat_history: List[Dict[str, str]],
) -> str:
    """Ask DeepSeek for the final grounded answer.

    The prompt carries the recent chat history, the question, the retrieved
    evidence and the local reasoning draft; blank evidence or draft is sent
    as ``[EMPTY]``.

    Returns:
        The model's answer, or a fixed apology sentence when the call fails
        or returns nothing.
    """
    history_text = history_to_text(chat_history)
    evidence_text = context if context.strip() else "[EMPTY]"
    draft_text = reasoning_draft if reasoning_draft.strip() else "[EMPTY]"

    prompt_lines = [
        "CHAT_HISTORY:",
        f"{history_text}",
        "",
        "USER_QUESTION:",
        f"{user_query}",
        "",
        "RETRIEVED_EVIDENCE:",
        f"{evidence_text}",
        "",
        "LOCAL_REASONING_DRAFT:",
        f"{draft_text}",
        "",
        "Write a grounded final summary answer using only the evidence and reasoning draft.",
    ]
    prompt = "\n".join(prompt_lines).strip()

    return llm_text(
        DEEPSEEK_SUMMARY_SYSTEM,
        prompt,
        fallback="I could not generate a grounded summary from the retrieved evidence.",
    )
|
|
|
|
def run_validator(context: str, answer: str) -> str:
    """Ask DeepSeek to judge whether ``answer`` is supported by ``context``.

    Returns:
        The verdict line from the model. When the validator is disabled the
        result is ``"SUPPORTED (validator disabled)"``; when the call fails
        it is ``"PARTLY_UNSUPPORTED: validator unavailable"``.
    """
    if not cfg.enable_validator:
        return "SUPPORTED (validator disabled)"

    evidence_text = context if context.strip() else "[EMPTY]"
    answer_text = answer if answer.strip() else "[EMPTY]"
    prompt = "\n".join([
        "EVIDENCE:",
        f"{evidence_text}",
        "",
        "ANSWER:",
        f"{answer_text}",
    ]).strip()

    return llm_text(VALIDATOR_SYSTEM, prompt, fallback="PARTLY_UNSUPPORTED: validator unavailable")
|
|
|
|
| |
| |
| |
def warmup_models():
    """Run one throwaway generation through the local reasoner.

    A fixed hyperkalemia question with a single hand-written evidence block is
    used as input and the output is discarded. Failures are logged as
    warnings and never propagated.
    """
    logger.info("Warming up local reasoner...")

    sample_question = "What are ECG findings in hyperkalemia?"
    sample_evidence = "\n".join([
        "==============================",
        "EVIDENCE_ID: 1",
        "SOURCE_QUESTION: What are ECG findings in hyperkalemia?",
        "SOURCE_TAGS: ecg",
        "EVIDENCE_TEXT:",
        "Hyperkalemia may cause peaked T waves, PR prolongation, QRS widening, and severe conduction abnormalities.",
        "==============================",
    ]).strip()

    try:
        run_local_reasoner(sample_question, sample_evidence)
        logger.info("Warmup completed.")
    except Exception as e:
        logger.warning(f"Warmup failed: {e}")
|
|
|
|
# Warm the local model up once at import time, before any request is served.
warmup_models()
|
|
|
|
| |
| |
| |
class ChatState(TypedDict, total=False):
    """State passed between the LangGraph nodes; every key is optional."""

    # Input: the raw question, its optional expansion, and prior messages.
    user_query: str
    expanded_query: str
    chat_history: List[Dict[str, str]]

    # Written by retrieve_node.
    retrieved_docs: List[Document]
    best_score: float        # best similarity seen, or -1.0 when nothing was scored
    used_context: bool       # True when at least one document survived filtering
    context: str             # formatted evidence blocks
    retrieval_attempts: int  # incremented on every pass through retrieve_node
    retrieval_mode: str      # "original" or "expanded"

    # Written by the reasoning, generation and validation nodes.
    reasoning_draft: str
    final_answer: str
    validation_status: str
|
|
|
|
| |
| |
| |
def retrieve_docs_once(query_for_search: str, original_query: str):
    """Run one vector search, filter the hits, and rerank the survivors.

    Args:
        query_for_search: Text sent to the vector store (may be an expansion).
        original_query: The user's own wording, used for lexical filtering
            and reranking.

    Returns:
        ``(docs, best_score)``. ``docs`` are copies of the hits that meet both
        ``cfg.min_lexical_overlap`` and ``cfg.min_faiss_similarity``, annotated
        with ``sim_score`` and ``lexical_overlap`` and cut to
        ``cfg.top_k_final``. ``best_score`` is the highest similarity among
        all hits, including filtered-out ones, or ``-1.0`` when the search
        fails or returns nothing.
    """
    try:
        hits = vectorstore.similarity_search_with_score(
            query_for_search,
            k=cfg.similarity_k,
        )
    except Exception as e:
        logger.error(f"Retriever error: {e}")
        traceback.print_exc()
        return [], -1.0

    if not hits:
        return [], -1.0

    survivors = []
    top_similarity = -1.0

    for hit, raw_score in hits:
        similarity = score_to_similarity(raw_score)
        top_similarity = max(top_similarity, similarity)

        source_question = hit.metadata.get("question", "")
        source_answer = hit.metadata.get("answer", "")
        overlap = lexical_overlap(original_query, f"{source_question} {source_answer}")

        passes = overlap >= cfg.min_lexical_overlap and similarity >= cfg.min_faiss_similarity
        if not passes:
            continue

        # Copy the metadata so the scores are not written into the store's documents.
        annotated = {**hit.metadata, "sim_score": similarity, "lexical_overlap": overlap}
        survivors.append(Document(page_content=hit.page_content, metadata=annotated))

    return rerank_docs(original_query, survivors, top_n=cfg.top_k_final), top_similarity
|
|
|
|
| |
| |
| |
def retrieve_node(state: ChatState) -> ChatState:
    """Graph node: retrieve evidence for the current (possibly expanded) query.

    The expanded query is searched when present, otherwise the user query;
    filtering and reranking always use the user query. Returns the retrieval
    fields of the state, with ``retrieval_attempts`` incremented by one.
    """
    expanded_query = state.get("expanded_query")
    search_query = expanded_query or state["user_query"]
    attempt_number = int(state.get("retrieval_attempts", 0)) + 1
    mode = "expanded" if expanded_query else "original"

    docs, best_score = retrieve_docs_once(
        query_for_search=search_query,
        original_query=state["user_query"],
    )

    found_any = bool(docs)
    if found_any:
        context = build_context_string(docs, max_chars=cfg.max_context_chars)
    else:
        context = ""

    return {
        "retrieved_docs": docs if found_any else [],
        "best_score": best_score,
        "used_context": found_any,
        "context": context,
        "retrieval_attempts": attempt_number,
        "retrieval_mode": mode,
    }
|
|
|
|
def should_retry_retrieval(state: ChatState) -> str:
    """Routing function after retrieval: continue, or expand the query and retry.

    Returns ``"local_reasoning"`` when retrieval produced context with a best
    score of at least ``cfg.min_faiss_similarity``, when query expansion is
    disabled, or when two retrieval attempts have already been made;
    otherwise ``"expand_query"``.
    """
    has_context = state.get("used_context", False)
    top_score = state.get("best_score", -1.0)
    attempts_made = int(state.get("retrieval_attempts", 0))

    retrieval_ok = has_context and top_score >= cfg.min_faiss_similarity
    cannot_retry = (not cfg.enable_query_expansion) or attempts_made >= 2

    if retrieval_ok or cannot_retry:
        return "local_reasoning"
    return "expand_query"
|
|
|
|
def expand_query_node(state: ChatState) -> ChatState:
    """Graph node: store an expanded form of the user query in the state.

    A blank expansion is replaced by the original query.
    """
    original = state["user_query"]
    candidate = run_query_expansion(original)
    return {"expanded_query": candidate if candidate.strip() else original}
|
|
|
|
def local_reasoning_node(state: ChatState) -> ChatState:
    """Graph node: draft structured reasoning over the retrieved context.

    With no context, the draft is ``INSUFFICIENT_EVIDENCE`` and the local
    model is not called.
    """
    evidence = state.get("context", "").strip()
    if evidence:
        draft = run_local_reasoner(state["user_query"], evidence)
    else:
        draft = "INSUFFICIENT_EVIDENCE"
    return {"reasoning_draft": draft}
|
|
|
|
def generate_node(state: ChatState) -> ChatState:
    """Graph node: generate the final answer from context and reasoning draft.

    With no context, a fixed "no relevant evidence" message is returned and
    the remote model is not called.
    """
    evidence = state.get("context", "").strip()
    draft = state.get("reasoning_draft", "INSUFFICIENT_EVIDENCE")
    prior_messages = state.get("chat_history", [])

    if evidence:
        text = run_deepseek_summary(
            user_query=state["user_query"],
            context=evidence,
            reasoning_draft=draft,
            chat_history=prior_messages,
        )
    else:
        text = "I could not find sufficiently relevant evidence in the RAG database for this question."

    return {"final_answer": text}
|
|
|
|
def validate_node(state: ChatState) -> ChatState:
    """Graph node: validate the answer against the evidence.

    Outcomes, in order:
      * missing context or answer -> ``INSUFFICIENT_EVIDENCE`` status;
      * strong retrieval -> ``SUPPORTED`` status without calling the validator;
      * otherwise the validator verdict becomes the status, and a
        ``PARTLY_UNSUPPORTED`` or ``INSUFFICIENT_EVIDENCE`` verdict also
        appends an "Evidence limits" caveat to the answer.
    """
    evidence = state.get("context", "").strip()
    answer = state.get("final_answer", "").strip()
    top_score = state.get("best_score", -1.0)
    evidence_docs = state.get("retrieved_docs", [])

    if not (evidence and answer):
        return {"validation_status": "INSUFFICIENT_EVIDENCE: missing context or answer"}

    if strong_retrieval(top_score, evidence_docs):
        return {"validation_status": "SUPPORTED (validator skipped due to strong retrieval)"}

    verdict = run_validator(evidence, answer)
    update = {"validation_status": verdict}

    caveats = {
        "PARTLY_UNSUPPORTED": "\n\nEvidence limits: some parts may not be fully supported by the retrieved evidence.",
        "INSUFFICIENT_EVIDENCE": "\n\nEvidence limits: the retrieved evidence was weak or only partially relevant.",
    }
    if not verdict.startswith("SUPPORTED"):
        for prefix, caveat in caveats.items():
            if verdict.startswith(prefix):
                update["final_answer"] = answer + caveat
                break

    return update
|
|
|
|
def finalize_node(state: ChatState) -> ChatState:
    """Graph node: give the answer a last cleanup pass.

    The answer goes through ``strip_bad_sections``; if nothing remains, a
    fixed placeholder sentence is used instead.
    """
    cleaned = strip_bad_sections(state.get("final_answer", ""))
    return {"final_answer": cleaned or "I could not generate an answer."}
|
|
|
|
| |
| |
| |
# Assemble the pipeline graph:
#   START -> retrieve -> (expand_query -> retrieve)? -> local_reasoning
#         -> generate -> validate -> finalize -> END
# The optional loop through expand_query is decided by should_retry_retrieval.
builder = StateGraph(ChatState)
builder.add_node("retrieve", retrieve_node)
builder.add_node("expand_query", expand_query_node)
builder.add_node("local_reasoning", local_reasoning_node)
builder.add_node("generate", generate_node)
builder.add_node("validate", validate_node)
builder.add_node("finalize", finalize_node)

builder.add_edge(START, "retrieve")
# The router's return value is mapped to the node of the same name.
builder.add_conditional_edges(
    "retrieve",
    should_retry_retrieval,
    {
        "expand_query": "expand_query",
        "local_reasoning": "local_reasoning",
    }
)
builder.add_edge("expand_query", "retrieve")
builder.add_edge("local_reasoning", "generate")
builder.add_edge("generate", "validate")
builder.add_edge("validate", "finalize")
builder.add_edge("finalize", END)

graph = builder.compile()
logger.info("LangGraph compiled.")
|
|
|
|
| |
| |
| |
def format_sources_minimal(result: Optional[Dict]) -> str:
    """Render the retrieved evidence as a short markdown panel.

    Args:
        result: Final pipeline state, or ``None``/empty before the first run.

    Returns:
        Markdown with the best score and, per document, its question, stored
        similarity and a preview of at most 210 answer characters (followed by
        ``...`` when the answer is longer).
    """
    heading = "## Retrieved Sources"
    if not result:
        return f"{heading}\n\nNo sources yet."

    evidence_docs = result.get("retrieved_docs", [])
    top_score = result.get("best_score", -1.0)
    score_line = f"**Best score:** `{top_score:.3f}`"

    if not evidence_docs:
        return f"{heading}\n\nNo sufficiently relevant evidence retrieved.\n\n{score_line}"

    out = [heading, score_line, ""]
    for number, doc in enumerate(evidence_docs, 1):
        meta = doc.metadata
        source_question = meta.get("question", "")
        full_answer = meta.get("answer", "")
        stored_similarity = meta.get("sim_score", "N/A")

        snippet = full_answer[:210].strip()
        if len(full_answer) > 210:
            snippet += "..."

        out += [
            f"### Evidence {number}",
            f"- **Question:** {source_question}",
            f"- **Similarity:** `{stored_similarity}`",
            f"- **Preview:** {snippet}",
            "",
        ]

    return "\n".join(out)
|
|
|
|
def format_debug_text(result: Optional[Dict]) -> str:
    """Render the pipeline state as plain text for the debug panel.

    Shows the retrieval and validation summary fields followed by the full
    context and the local reasoning draft. Missing keys fall back to neutral
    defaults.
    """
    if not result:
        return "No debug result yet."

    summary_fields = [
        ("BEST SCORE", result.get("best_score", -1.0)),
        ("USED CONTEXT", result.get("used_context", False)),
        ("RETRIEVAL ATTEMPTS", result.get("retrieval_attempts", 0)),
        ("RETRIEVAL MODE", result.get("retrieval_mode", "N/A")),
        ("VALIDATION STATUS", result.get("validation_status", "N/A")),
    ]
    lines = [f"{label}: {value}" for label, value in summary_fields]
    lines += [
        "",
        "----- CONTEXT -----",
        f"{result.get('context', '')}",
        "",
        "----- LOCAL REASONING DRAFT -----",
        f"{result.get('reasoning_draft', '')}",
    ]
    return "\n".join(lines).strip()
|
|
|
|
| |
| |
| |
# Stylesheet handed to gr.Blocks(css=...): dark theme variables, card/panel
# styling, the blinking "thinking" indicator, and responsive breakpoints at
# 1024px, 768px and 480px.
CUSTOM_CSS = """
:root {
--bg-main: #07111f;
--bg-soft: #0b1728;
--card: rgba(10, 19, 35, 0.86);
--card-2: rgba(14, 25, 43, 0.94);
--border: rgba(148, 163, 184, 0.16);
--text: #e5eefb;
--muted: #94a3b8;
--primary: #7c3aed;
--primary-2: #2563eb;
--success: #10b981;
}

html, body, .gradio-container {
margin: 0 !important;
padding: 0 !important;
min-height: 100%;
background:
radial-gradient(circle at top left, rgba(124,58,237,0.22), transparent 28%),
radial-gradient(circle at top right, rgba(37,99,235,0.18), transparent 24%),
linear-gradient(180deg, #050b16 0%, #091321 100%);
color: var(--text);
}

.gradio-container {
max-width: 100% !important;
padding: 12px !important;
}

footer {
visibility: hidden;
}

.top-card {
border: 1px solid var(--border);
background: linear-gradient(135deg, rgba(11,23,40,0.95), rgba(18,31,56,0.92));
border-radius: 22px;
padding: 16px;
margin-bottom: 12px;
box-shadow: 0 14px 40px rgba(0,0,0,0.20);
}

.hero-title {
font-size: 1.6rem;
font-weight: 800;
color: #f8fbff;
margin-bottom: 6px;
line-height: 1.15;
}

.hero-subtitle {
color: #cbd5e1;
font-size: 0.95rem;
line-height: 1.5;
}

.badges {
display: flex;
gap: 8px;
flex-wrap: wrap;
margin-top: 12px;
}

.badge {
display: inline-flex;
align-items: center;
gap: 6px;
padding: 6px 10px;
border-radius: 999px;
font-size: 11px;
color: #e6eefc;
border: 1px solid rgba(255,255,255,0.12);
background: rgba(255,255,255,0.06);
}

.panel-wrap {
border: 1px solid var(--border);
background: linear-gradient(180deg, rgba(10,19,35,0.96), rgba(7,14,26,0.94));
border-radius: 20px;
padding: 12px;
box-shadow: 0 16px 45px rgba(0,0,0,0.22);
}

#chatbot {
height: min(62vh, 640px) !important;
min-height: 360px !important;
border-radius: 18px !important;
border: 1px solid var(--border) !important;
overflow: hidden !important;
box-shadow: 0 14px 40px rgba(0,0,0,0.26) !important;
}

.status-card {
padding: 12px 14px;
border-radius: 16px;
background: linear-gradient(135deg, #0f172a 0%, #172554 100%);
color: #f9fafb;
font-size: 14px;
border: 1px solid rgba(255,255,255,0.12);
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
}

.muted {
color: #a5b4fc;
font-size: 12px;
}

.blink-dots {
font-size: 22px;
font-weight: 800;
letter-spacing: 4px;
animation: blinkDots 1s steps(1, end) infinite;
display: inline-block;
padding: 2px 0;
}

@keyframes blinkDots {
0% { opacity: 1; }
50% { opacity: 0.15; }
100% { opacity: 1; }
}

textarea, .gr-textbox textarea {
border-radius: 16px !important;
font-size: 15px !important;
}

.gr-textbox label, .gr-markdown, .gr-button {
font-size: 14px !important;
}

button {
border-radius: 14px !important;
min-height: 44px !important;
font-weight: 600 !important;
}

.mobile-stack {
display: flex;
flex-direction: column;
gap: 12px;
}

.mobile-scroll {
max-height: 34vh;
overflow-y: auto;
}

.command-note {
color: #cbd5e1;
font-size: 0.88rem;
line-height: 1.45;
}

@media (max-width: 1024px) {
.gradio-container {
padding: 10px !important;
}

.hero-title {
font-size: 1.45rem;
}

.hero-subtitle {
font-size: 0.92rem;
}

#chatbot {
height: 56vh !important;
}
}

@media (max-width: 768px) {
.gradio-container {
padding: 8px !important;
}

.top-card {
padding: 14px;
border-radius: 18px;
}

.hero-title {
font-size: 1.28rem;
}

.hero-subtitle {
font-size: 0.88rem;
line-height: 1.45;
}

.badge {
font-size: 10px;
padding: 5px 8px;
}

.panel-wrap {
padding: 10px;
border-radius: 16px;
}

#chatbot {
height: 52vh !important;
min-height: 320px !important;
border-radius: 16px !important;
}

button {
width: 100% !important;
}

.mobile-scroll {
max-height: 240px;
}
}

@media (max-width: 480px) {
.hero-title {
font-size: 1.15rem;
}

.hero-subtitle {
font-size: 0.83rem;
}

#chatbot {
height: 50vh !important;
min-height: 300px !important;
}

textarea, .gr-textbox textarea {
font-size: 14px !important;
}
}
"""
|
|
|
|
def hero_html() -> str:
    """Return the static HTML header card (title, subtitle, feature badges)."""
    badge_labels = (
        "ECG Reasoning",
        "FAISS Retrieval",
        "LoRA Adapter",
        "Validated Output",
    )
    # One badge <div> per label, one per line, in declaration order.
    badge_markup = "\n".join(
        f'<div class="badge">{label}</div>' for label in badge_labels
    )
    return f"""
<div class="top-card">
<div class="hero-title">🫀 Mr Cardio</div>
<div class="hero-subtitle">
ECG-focused clinical chatbot with RAG retrieval, local ECG reasoning,
and grounded evidence summaries. Mobile-friendly layout included.
</div>
<div class="badges">
{badge_markup}
</div>
</div>
"""
|
|
|
|
def thinking_html(stage: str) -> str:
    """Return the HTML status card shown while a request is in progress.

    Args:
        stage: Label of the current processing stage, rendered in bold.

    Returns:
        HTML markup for a status card with the stage label, a fixed pipeline
        caption and the blinking-dots indicator.
    """
    card_template = """
<div class="status-card">
<div style="display:flex;align-items:center;gap:12px;">
<div style="font-size:19px;">⏳</div>
<div>
<div style="font-weight:700;">{stage}</div>
<div class="muted">Retrieval → reasoning → grounded answer</div>
<div class="blink-dots">...</div>
</div>
</div>
</div>
"""
    return card_template.format(stage=stage)
|
|
|
|
def initialize_session():
    """Create a fresh per-session state: empty chat history, no last result."""
    return dict(chat_history=[], last_result=None)
|
|
|
|
def add_assistant_placeholder(history, text="..."):
    """Append a "Thinking" assistant message to the chat history.

    Args:
        history: Message list to extend in place; a new list is started when
            it is ``None`` or empty.
        text: Placeholder content for the assistant bubble.

    Returns:
        The message list including the appended placeholder.
    """
    messages = history if history else []
    placeholder = {
        "role": "assistant",
        "content": text,
        "metadata": {"title": "Thinking"},
    }
    messages.append(placeholder)
    return messages
|
|
|
|
def update_last_assistant_message(history, text, title=None):
    """Replace the trailing assistant message, or append one if absent.

    Args:
        history: Message list (may be ``None``); modified in place when it is
            a non-empty list.
        text: New content for the assistant message.
        title: Optional bubble title stored under ``metadata["title"]``.

    Returns:
        The message list whose last entry is the new assistant message.
    """
    messages = history or []

    replacement = {"role": "assistant", "content": text}
    if title:
        replacement["metadata"] = {"title": title}

    if messages and messages[-1]["role"] == "assistant":
        # Overwrite the previous assistant bubble (e.g. a placeholder).
        messages[-1] = replacement
    else:
        messages.append(replacement)
    return messages
|
|
|
|
def user_submit(user_message, chat_ui_history):
    """Move the textbox content into the chat history as a user message.

    Args:
        user_message: Raw textbox value (may be ``None``).
        chat_ui_history: Current chat message list (may be ``None``).

    Returns:
        A ``("", history)`` pair: the empty string clears the textbox, and the
        history gains a user message unless the input was blank.
    """
    messages = chat_ui_history or []
    cleaned = (user_message or "").strip()

    if cleaned:
        messages.append({"role": "user", "content": cleaned})
    return "", messages
|
|
|
|
| |
| |
| |
def run_chat_turn(user_message: str, memory_state: Dict) -> Dict:
    """Run one question through the compiled graph and update session memory.

    Args:
        user_message: The user's question for this turn.
        memory_state: Session dict with ``chat_history`` and ``last_result``;
            a fresh session is created when ``None``. Updated in place.

    Returns:
        A dict with ``answer`` (answer text plus a confidence/score footer),
        ``memory_state`` (the updated session), ``sources_markdown`` and
        ``debug_text``.
    """
    if memory_state is None:
        memory_state = initialize_session()

    state_in = {
        "user_query": user_message,
        "chat_history": memory_state["chat_history"],
        "retrieval_attempts": 0,
    }

    try:
        result = graph.invoke(state_in)
    except Exception as e:
        # UI boundary: log with traceback and degrade to an error-shaped
        # result so the chat shows a message instead of failing the request.
        logger.exception("Graph invocation error: %s", e)
        result = {
            "final_answer": f"I hit a runtime error while processing the request: {e}",
            "best_score": -1.0,
            "used_context": False,
            "validation_status": "ERROR",
            "retrieved_docs": [],
            "context": "",
            "reasoning_draft": "",
            "retrieval_attempts": 0,
            "retrieval_mode": "error",
        }

    # A node may leave these keys set to None; .strip() and ":.3f" would
    # raise on None, so normalise before formatting.
    answer = str(result.get("final_answer") or "").strip() or "I could not generate an answer."
    best_score = result.get("best_score")
    if best_score is None:
        best_score = -1.0
    validation_status = result.get("validation_status", "N/A")
    confidence = compute_confidence(result)

    answer_with_footer = (
        f"{answer}\n\n---\n"
        f"📊 confidence={confidence:.2f} | best_score={best_score:.3f} | validation={validation_status}"
    )

    # Memory stores the bare answer (no footer) and keeps the last 12 messages.
    memory_state["chat_history"].append({"role": "user", "content": user_message})
    memory_state["chat_history"].append({"role": "assistant", "content": answer})
    memory_state["chat_history"] = memory_state["chat_history"][-12:]
    memory_state["last_result"] = result

    return {
        "answer": answer_with_footer,
        "memory_state": memory_state,
        "sources_markdown": format_sources_minimal(result),
        "debug_text": format_debug_text(result),
    }
|
|
|
|
def bot_respond_stream(chat_ui_history, session_state):
    """Stream UI updates for the newest user message in the chat.

    Handles the ``/sources``, ``/debug`` and ``/rebuild`` commands; any other
    message is answered through ``run_chat_turn`` with staged status updates.

    Args:
        chat_ui_history: Chat message dicts; the last entry is expected to be
            the user's message added by ``user_submit``.
        session_state: Session dict (``chat_history``, ``last_result``); a
            fresh session is created when ``None``.

    Yields:
        5-tuples ``(chat history, session state, sources markdown, debug
        text, status HTML)`` matching the wired Gradio outputs.
    """
    global vectorstore

    if session_state is None:
        session_state = initialize_session()

    if not chat_ui_history:
        yield (
            chat_ui_history,
            session_state,
            "## Retrieved Sources\n\nNo sources yet.",
            "No debug result yet.",
            "",
        )
        return

    last_message = chat_ui_history[-1]
    last_result = session_state.get("last_result")

    # user_submit() leaves the history untouched for blank input, yet this
    # handler is still chained after it. Without this guard the previous
    # assistant reply would be re-submitted as if the user had typed it.
    if last_message.get("role") != "user":
        yield (
            chat_ui_history,
            session_state,
            format_sources_minimal(last_result),
            format_debug_text(last_result),
            "",
        )
        return

    user_message = str(last_message["content"]).strip()

    if user_message == "/sources":
        chat_ui_history.append({
            "role": "assistant",
            "content": format_sources_minimal(last_result),
            "metadata": {"title": "Sources"}
        })
        yield (
            chat_ui_history,
            session_state,
            format_sources_minimal(last_result),
            format_debug_text(last_result),
            "",
        )
        return

    if user_message == "/debug":
        chat_ui_history.append({
            "role": "assistant",
            "content": format_debug_text(last_result),
            "metadata": {"title": "Debug"}
        })
        yield (
            chat_ui_history,
            session_state,
            format_sources_minimal(last_result),
            format_debug_text(last_result),
            "",
        )
        return

    if user_message == "/rebuild":
        if not cfg.allow_rebuild_vectorstore:
            chat_ui_history.append({
                "role": "assistant",
                "content": "Vector store rebuild is disabled on this Space.",
                "metadata": {"title": "Restricted"}
            })
            yield (
                chat_ui_history,
                session_state,
                format_sources_minimal(last_result),
                format_debug_text(last_result),
                "",
            )
            return

        chat_ui_history = add_assistant_placeholder(chat_ui_history)
        yield (
            chat_ui_history,
            session_state,
            "",
            "",
            thinking_html("Rebuilding vector store"),
        )

        time.sleep(cfg.blink_stage_1)

        chat_ui_history = update_last_assistant_message(
            chat_ui_history,
            "Rebuilding vector store and reloading embeddings...",
            title="Maintenance"
        )
        yield (
            chat_ui_history,
            session_state,
            "",
            "",
            thinking_html("Rebuilding vector store"),
        )

        # Report failures in the chat; otherwise the status card would stay
        # on "Rebuilding vector store" after an exception.
        try:
            build_vectorstore()
            vectorstore = load_vectorstore()
        except Exception as exc:
            logger.exception("Vector store rebuild failed")
            outcome_text = f"❌ Vector store rebuild failed: {exc}"
            outcome_title = "Error"
        else:
            outcome_text = "✅ Vector store rebuilt and reloaded."
            outcome_title = "Done"

        chat_ui_history = update_last_assistant_message(
            chat_ui_history,
            outcome_text,
            title=outcome_title
        )
        yield (
            chat_ui_history,
            session_state,
            format_sources_minimal(last_result),
            format_debug_text(last_result),
            "",
        )
        return

    # Regular question: show a placeholder bubble and staged status cards.
    chat_ui_history = add_assistant_placeholder(chat_ui_history, text="...")
    yield (
        chat_ui_history,
        session_state,
        "",
        "",
        thinking_html("Starting"),
    )
    time.sleep(cfg.blink_stage_1)

    yield (
        chat_ui_history,
        session_state,
        "",
        "",
        thinking_html("Retrieving evidence"),
    )
    time.sleep(cfg.blink_stage_2)

    yield (
        chat_ui_history,
        session_state,
        "",
        "",
        thinking_html("Running ECG adapter reasoning"),
    )
    time.sleep(cfg.blink_stage_3)

    # The whole pipeline runs in this single call; the stages above are
    # display-only pacing.
    out = run_chat_turn(user_message, session_state)

    yield (
        chat_ui_history,
        session_state,
        out["sources_markdown"],
        out["debug_text"],
        thinking_html("Generating grounded summary"),
    )
    time.sleep(cfg.blink_before_answer)

    if cfg.enable_typewriter_stream:
        for partial in stream_text(out["answer"], step=120):
            chat_ui_history = update_last_assistant_message(
                chat_ui_history,
                partial,
                title="Answer"
            )
            yield (
                chat_ui_history,
                session_state,
                out["sources_markdown"],
                out["debug_text"],
                "",
            )

    # Always finish with the complete answer, whether or not it was streamed.
    chat_ui_history = update_last_assistant_message(
        chat_ui_history,
        out["answer"],
        title="Answer"
    )

    yield (
        chat_ui_history,
        out["memory_state"],
        out["sources_markdown"],
        out["debug_text"],
        "",
    )
|
|
|
|
def clear_chat():
    """Reset the UI: empty chat, fresh session, default panels, no status."""
    fresh_session = initialize_session()
    return (
        [],
        fresh_session,
        "## Retrieved Sources\n\nNo sources yet.",
        "No debug result yet.",
        "",
    )
|
|
|
|
def rebuild_from_button(session_state, chatbot_history):
    """Rebuild and reload the vector store from the "Rebuild Store" button.

    Args:
        session_state: Session dict (``chat_history``, ``last_result``); a
            fresh session is created when ``None``.
        chatbot_history: Chat message list (may be ``None``); receives one
            assistant message describing the outcome.

    Returns:
        A 5-tuple ``(chat history, session state, sources markdown, debug
        text, status HTML)`` matching the wired Gradio outputs.
    """
    global vectorstore

    # Same None handling as bot_respond_stream, so .get() below cannot fail.
    if session_state is None:
        session_state = initialize_session()
    chatbot_history = chatbot_history or []

    if not cfg.allow_rebuild_vectorstore:
        outcome_text = "Vector store rebuild is disabled on this Space."
        outcome_title = "Restricted"
    else:
        # Report failures in the chat instead of raising out of the handler;
        # the module-level store is only rebound when both steps succeed.
        try:
            build_vectorstore()
            vectorstore = load_vectorstore()
        except Exception as exc:
            logger.exception("Vector store rebuild failed")
            outcome_text = f"❌ Vector store rebuild failed: {exc}"
            outcome_title = "Error"
        else:
            outcome_text = "✅ Vector store rebuilt and reloaded."
            outcome_title = "Done"

    chatbot_history.append({
        "role": "assistant",
        "content": outcome_text,
        "metadata": {"title": outcome_title}
    })

    last_result = session_state.get("last_result")
    return (
        chatbot_history,
        session_state,
        format_sources_minimal(last_result),
        format_debug_text(last_result),
        "",
    )
|
|
|
|
| |
| |
| |
# Gradio UI: header card, chat area, input row, collapsible source/debug
# panels, and the event wiring for send / clear / rebuild.
# NOTE(review): the nesting of components under Column/Group was reconstructed
# (the source's indentation was not available) - confirm against the intended
# layout.
app_theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size="lg",
    text_size="md",
)

with gr.Blocks(
    title="Medical CSV RAG Chatbot",
    css=CUSTOM_CSS,
    theme=app_theme,
) as demo:

    gr.HTML(hero_html())

    session_state = gr.State(initialize_session())

    with gr.Column(elem_classes=["mobile-stack"]):
        with gr.Group(elem_classes=["panel-wrap"]):
            chatbot = gr.Chatbot(
                label="Clinical Chat",
                height=640,
                elem_id="chatbot",
                type="messages",
                show_copy_button=True,
                bubble_full_width=False,
                avatar_images=(None, None),
            )

        user_box = gr.Textbox(
            label="Ask a medical question",
            placeholder="e.g. What are the ECG findings in hyperkalemia?",
            lines=2,
            autofocus=True,
        )

        # Receives the "thinking" status card while a request is running.
        status_html = gr.HTML("")

        with gr.Row():
            send_btn = gr.Button("Send", variant="primary")
            clear_btn = gr.Button("Clear")
            rebuild_btn = gr.Button("Rebuild Store")

        gr.HTML(
            """
<div class="command-note">
Commands: <code>/sources</code>, <code>/debug</code>, <code>/rebuild</code>
</div>
"""
        )

        with gr.Accordion("Retrieved Sources", open=False):
            with gr.Group(elem_classes=["panel-wrap", "mobile-scroll"]):
                sources_panel = gr.Markdown("## Retrieved Sources\n\nNo sources yet.")

        if cfg.show_debug_panel:
            with gr.Accordion("Debug Panel", open=False):
                with gr.Group(elem_classes=["panel-wrap", "mobile-scroll"]):
                    debug_panel = gr.Textbox(
                        label="Debug",
                        value="No debug result yet.",
                        lines=18,
                        max_lines=28,
                        interactive=False,
                    )
        else:
            # Hidden stand-in so the handlers' output list keeps its shape.
            debug_panel = gr.Textbox(visible=False, value="")

    # Shared component lists for the handlers below.
    submit_io = [user_box, chatbot]
    respond_inputs = [chatbot, session_state]
    panel_outputs = [chatbot, session_state, sources_panel, debug_panel, status_html]

    # Enter in the textbox: record the user message, then stream the reply.
    submit_event = user_box.submit(
        fn=user_submit,
        inputs=submit_io,
        outputs=submit_io,
        queue=True,
    )
    submit_event.then(
        fn=bot_respond_stream,
        inputs=respond_inputs,
        outputs=panel_outputs,
        queue=True,
    )

    # Send button: same two-step chain as pressing Enter.
    send_click = send_btn.click(
        fn=user_submit,
        inputs=submit_io,
        outputs=submit_io,
        queue=True,
    )
    send_click.then(
        fn=bot_respond_stream,
        inputs=respond_inputs,
        outputs=panel_outputs,
        queue=True,
    )

    clear_btn.click(
        fn=clear_chat,
        inputs=[],
        outputs=panel_outputs,
        queue=False,
    )

    rebuild_btn.click(
        fn=rebuild_from_button,
        inputs=[session_state, chatbot],
        outputs=panel_outputs,
        queue=True,
    )
|
|
# Enable the request queue with a default concurrency limit of 1 per event.
demo.queue(default_concurrency_limit=1)


if __name__ == "__main__":
    # Launch settings come from the shared configuration object.
    launch_options = {
        "debug": cfg.launch_debug,
        "server_name": cfg.server_name,
        "server_port": cfg.server_port,
    }
    demo.launch(**launch_options)