# MedAssist — Gradio healthcare Q&A demo (Hugging Face Space)
| import os | |
| import json | |
| import gradio as gr | |
| import numpy as np | |
| from typing import List, Dict, Tuple | |
| from pathlib import Path | |
| # Embeddings and generation | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
| # Optional FAISS for scalable retrieval | |
# FAISS is optional: if it is not installed (or fails to import), FAISS_OK
# stays False and retrieval falls back to brute-force numpy similarity.
try:
    import faiss
    FAISS_OK = True
except Exception:
    FAISS_OK = False
# Paths resolved relative to this file so the app works from any CWD.
APP_DIR = Path(__file__).parent
DATA_PATH = APP_DIR / "data" / "med_corpus.json"

# Safety disclaimer appended to every answer and shown in the UI accordion.
DISCLAIMER = (
    "This tool does not provide medical advice. It is for information only and "
    "is not a substitute for professional medical advice, diagnosis, or treatment. "
    "Call your local emergency number for urgent symptoms."
)

# Red-flag phrases: a case-insensitive substring match on any of these
# short-circuits the pipeline with an emergency message (see looks_emergent).
EMERGENCY_KEYWORDS = [
    "chest pain", "severe chest pain", "shortness of breath", "trouble breathing",
    "stroke", "face drooping", "slurred speech", "arm weakness", "uncontrolled bleeding",
    "fainting", "passed out", "loss of consciousness", "seizure", "head injury",
    "allergic reaction", "anaphylaxis", "suicidal", "homicidal", "poison", "overdose"
]
def load_corpus() -> List[Dict]:
    """Read and parse the bundled medical corpus JSON file."""
    raw = DATA_PATH.read_text(encoding="utf-8")
    return json.loads(raw)
def build_documents(corpus: List[Dict]) -> Tuple[List[str], List[str]]:
    """Flatten corpus entries into retrievable text blobs plus JSON metadata.

    Returns a pair of parallel lists: the embeddable document texts and, for
    each one, a JSON string holding its title and URL (parsed back by callers).
    """
    texts: List[str] = []
    meta: List[str] = []
    for entry in corpus:
        key_terms = ", ".join(entry.get("key_terms", []))
        texts.append(
            f"TITLE: {entry['title']}\nKEY_TERMS: {key_terms}\nCONTENT: {entry['summary']}"
        )
        meta.append(json.dumps({"title": entry["title"], "url": entry["url"]}))
    return texts, meta
# Initialize models lazily to speed cold start.
# Each global starts as None and is populated on first use by its getter
# (get_embedder / get_generator) or by prepare_index.
_EMBED = None        # SentenceTransformer instance
_GEN_PIPE = None     # transformers text2text-generation pipeline
_FAISS_INDEX = None  # faiss.IndexFlatIP over the corpus, or None without faiss
_DOC_TEXTS = None    # list[str] of embeddable document texts
_DOC_META = None     # list[str] of JSON metadata parallel to _DOC_TEXTS
def get_embedder():
    """Return the shared sentence-embedding model, loading it on first call."""
    global _EMBED
    if _EMBED is not None:
        return _EMBED
    # MiniLM: compact embedder that is practical on CPU-only hardware.
    _EMBED = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    return _EMBED
def get_generator():
    """Return the shared text-generation pipeline, loading it on first call."""
    global _GEN_PIPE
    if _GEN_PIPE is not None:
        return _GEN_PIPE
    model_name = "google/flan-t5-base"  # CPU-friendly
    _GEN_PIPE = pipeline("text2text-generation", model=model_name, tokenizer=model_name)
    return _GEN_PIPE
def prepare_index():
    """Load the corpus and, when faiss is available, build the vector index.

    Idempotent: documents are loaded once and the index is built once.
    Returns True when a FAISS index is ready, False when callers must fall
    back to brute-force similarity.
    """
    global _FAISS_INDEX, _DOC_TEXTS, _DOC_META
    if _DOC_TEXTS is None:
        _DOC_TEXTS, _DOC_META = build_documents(load_corpus())
    if FAISS_OK and _FAISS_INDEX is None:
        vectors = get_embedder().encode(
            _DOC_TEXTS,
            convert_to_numpy=True,
            show_progress_bar=False,
            normalize_embeddings=True,
        )
        # Normalized vectors + inner product == cosine similarity.
        index = faiss.IndexFlatIP(vectors.shape[1])
        index.add(vectors)
        _FAISS_INDEX = index
    return _FAISS_INDEX is not None
def retrieve(query: str, k: int = 4) -> List[Dict]:
    """Return the top-k corpus documents most similar to *query*.

    Uses the FAISS inner-product index when available; otherwise falls back
    to brute-force cosine similarity over the document embeddings.

    Args:
        query: Free-text user question.
        k: Maximum number of documents to return.

    Returns:
        List of dicts with "text", "title", and "url" keys, best match first.
    """
    prepare_index()
    emb = get_embedder()
    q = emb.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    if FAISS_OK and _FAISS_INDEX is not None:
        _, I = _FAISS_INDEX.search(q, min(k, len(_DOC_TEXTS)))
        idxs = I[0].tolist()
    else:
        # Fallback without faiss. Fix: cache the document embeddings on the
        # function object instead of re-encoding the whole corpus on every
        # query (the original paid the full encode cost per call). The
        # corpus is static for the process lifetime, so the cache is safe.
        if getattr(retrieve, "_doc_vecs", None) is None:
            retrieve._doc_vecs = emb.encode(
                _DOC_TEXTS, convert_to_numpy=True, normalize_embeddings=True
            )
        sims = retrieve._doc_vecs @ q[0]
        idxs = np.argsort(-sims)[:k].tolist()
    results = []
    for i in idxs:
        meta = json.loads(_DOC_META[i])
        results.append({
            "text": _DOC_TEXTS[i],
            "title": meta["title"],
            "url": meta["url"],
        })
    return results
def looks_emergent(text: str) -> bool:
    """Heuristic check: does the text mention any red-flag emergency phrase?"""
    lowered = text.lower()
    for keyword in EMERGENCY_KEYWORDS:
        if keyword in lowered:
            return True
    return False
def make_prompt(question: str, contexts: List[Dict], reading_level: str = "plain") -> str:
    """Assemble the generation prompt from the question and retrieved sources.

    Args:
        question: The user's health question.
        contexts: Retrieved documents, each with "text" and "url" keys.
        reading_level: "plain" or "detailed". Fix: unknown values now fall
            back to "plain" — the original indexed the dict directly and
            raised KeyError for anything else.

    Returns:
        The stripped prompt string fed to the text2text pipeline.
    """
    level_instructions = {
        "plain": "Explain like a nurse to a patient in clear, simple language.",
        "detailed": "Explain in clear lay language with a bit more detail and definitions.",
    }
    level_instruction = level_instructions.get(reading_level, level_instructions["plain"])
    context_strs = []
    for i, c in enumerate(contexts, 1):
        context_strs.append(f"[Source {i}] {c['text']} (URL: {c['url']})")
    context_block = "\n\n".join(context_strs)
    prompt = f"""
You are MedAssist, a cautious health explainer. Use only the information in the sources below.
Write a concise answer in 120-220 words. Use bullet points where helpful.
Cite sources as [1], [2], etc. Then append the URLs at the end as "Sources: ...".
Include a "Next steps" mini-checklist. Always end with this disclaimer verbatim:
"{DISCLAIMER}"
User question: {question}
Reading level: {level_instruction}
Sources:
{context_block}
"""
    return prompt.strip()
def answer_question(message: str, reading_level: str):
    """Produce (answer_text, source_listing) for a user question.

    Emergency-sounding messages short-circuit to a call-emergency-services
    reply with no retrieval or generation; source_listing is None then.
    """
    # Hard stop for emergencies
    if looks_emergent(message):
        return (
            "Possible emergency detected based on your description. Call your local emergency number now. "
            "If available, seek the nearest emergency department.\n\n"
            + DISCLAIMER,
            None,
        )
    # Retrieve supporting documents, then generate a grounded answer.
    docs = retrieve(message, k=4)
    prompt = make_prompt(message, docs, reading_level)
    generated = get_generator()(prompt, max_new_tokens=320, do_sample=False)
    answer_text = generated[0]["generated_text"]
    # Human-readable source list shown alongside the answer.
    source_lines = [f"[{idx}] {doc['title']} — {doc['url']}" for idx, doc in enumerate(docs, 1)]
    return answer_text, "\n".join(source_lines)
# Custom CSS injected into the Gradio Blocks app (element ids set below).
CSS = """
#disclaimer {font-size: 12px; color: #444;}
#title {font-weight: 700; font-size: 22px;}
"""
def ui():
    """Build and return the Gradio Blocks interface for MedAssist."""
    with gr.Blocks(css=CSS, title="MedAssist: AI Healthcare Q&A") as demo:
        gr.Markdown("# MedAssist: AI Healthcare Q&A")
        gr.Markdown(
            "Ask about conditions, symptoms, or treatments. The system retrieves from a small curated corpus derived from reputable consumer-health sources."
        )
        with gr.Row():
            level_radio = gr.Radio(choices=["plain", "detailed"], value="plain", label="Reading level")
        question_box = gr.Textbox(label="Your question", placeholder="e.g., What are the common symptoms of a urinary tract infection?")
        submit_btn = gr.Button("Get answer")
        answer_md = gr.Markdown(label="Answer")
        sources_md = gr.Markdown(label="Retrieved sources")
        with gr.Accordion("Disclaimer", open=False):
            gr.Markdown(DISCLAIMER, elem_id="disclaimer")
        # Canned questions users can click to try the demo.
        gr.Examples(
            examples=[
                "I have burning when I pee and need to go often. Could it be a UTI?",
                "What are the warning signs of a stroke?",
                "How is type 2 diabetes managed?",
            ],
            inputs=question_box
        )

        def _on_click(q, level):
            # Normalize the sources slot to a string: answer_question
            # returns None for it on the emergency path.
            text, srcs = answer_question(q, level)
            return text, (srcs or "")

        submit_btn.click(_on_click, inputs=[question_box, level_radio], outputs=[answer_md, sources_md])
    return demo
# Launch the Gradio app when executed as a script (e.g. `python app.py`).
if __name__ == "__main__":
    ui().launch()