File size: 7,952 Bytes
aeb5696
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import os, re, requests, json
from typing import List, Dict, Any, Tuple
from bs4 import BeautifulSoup
import numpy as np
import faiss
import streamlit as st

from sentence_transformers import SentenceTransformer  # Local HF model use

# MedlinePlus web-service search endpoint (NLM "wsearch" API).
MEDLINE_WSEARCH = "https://wsearch.nlm.nih.gov/ws/query"
# Standing safety notice shown in the UI banner and appended to every answer.
DISCLAIMER = ("This assistant provides general health information and is not a substitute for professional medical advice, "
              "diagnosis, or treatment. For personal medical concerns, consult a qualified clinician or seek emergency care for urgent symptoms.")

# --- Red flag patterns for basic triage ---
# Case-insensitive matching is achieved by lowercasing the input before search.
RED_FLAGS = [
    r"\b(chest pain|pressure in chest)\b",
    r"\b(trouble breathing|shortness of breath|severe breathlessness)\b",
    r"\b(signs of stroke|face droop|arm weakness|speech trouble|sudden confusion)\b",
    r"\b(severe allergic reaction|anaphylaxis|swelling of face|swelling of tongue)\b",
    r"\b(black stools|vomiting blood|severe bleeding)\b",
    r"\b(severe dehydration|no urination|sunken eyes)\b",
    r"\b(high fever|stiff neck|severe headache)\b",
]

def has_red_flags(text: str) -> bool:
    """Return True if *text* mentions any emergency-level symptom pattern."""
    lowered = text.lower()
    for pattern in RED_FLAGS:
        if re.search(pattern, lowered):
            return True
    return False

# --- MedlinePlus search and fetch ---
def medline_search(term: str, retmax: int = 5, rettype: str = "brief") -> List[Dict[str, str]]:
    """Query the MedlinePlus wsearch service for health topics matching *term*.

    Returns a list of {"title", "url", "snippet"} dicts; entries missing a
    title or URL are skipped, and the snippet falls back to the topic's
    full summary when no short snippet is provided.
    """
    query = {"db": "healthTopics", "term": term, "retmax": str(retmax), "rettype": rettype}
    response = requests.get(MEDLINE_WSEARCH, params=query, timeout=10)
    response.raise_for_status()
    parsed = BeautifulSoup(response.text, "xml")
    hits: List[Dict[str, str]] = []
    for document in parsed.find_all("document"):
        title_node = document.find("content", {"name": "title"})
        url_node = document.find("content", {"name": "url"})
        snippet_node = document.find("content", {"name": "snippet"}) or document.find("content", {"name": "full-summary"})
        if not (title_node and url_node):
            continue
        hits.append({
            "title": title_node.text.strip(),
            "url": url_node.text.strip(),
            "snippet": snippet_node.text.strip() if snippet_node else "",
        })
    return hits

def fetch_page_text(url: str, max_chars: int = 12000) -> str:
    """Download *url* and return up to *max_chars* of its visible body text.

    Script/style/navigation boilerplate is stripped, and runs of blank
    lines are collapsed to single newlines before truncation.
    """
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    page = BeautifulSoup(response.text, "html.parser")
    for boilerplate in page(["script", "style", "nav", "footer", "header", "form", "aside"]):
        boilerplate.decompose()
    body = re.sub(r"\n{2,}", "\n", page.get_text(separator="\n"))
    return body[:max_chars].strip()

def chunk_text(text: str, approx_tokens: int = 220) -> List[str]:
    """Split *text* into word-window chunks of roughly *approx_tokens* words.

    Windows shorter than ~40 characters are dropped as too thin to embed.
    """
    words = text.split()
    windows = (
        " ".join(words[start:start + approx_tokens])
        for start in range(0, len(words), approx_tokens)
    )
    return [w for w in windows if len(w) > 40]

# --- Embeddings via Hugging Face ---
@st.cache_resource
def load_local_embedder():
    """Load and cache the MiniLM sentence-embedding model for this session.

    Decorated with st.cache_resource so the model is constructed once per
    Streamlit server process; the first call downloads the weights from the
    Hugging Face Hub, later calls reuse the cached instance.
    """
    # Uses Hugging Face model from the Hub locally
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

def hf_inference_embed(texts: List[str], hf_token: str) -> np.ndarray:
    """Embed *texts* via the Hugging Face Inference API (MiniLM feature extraction).

    Args:
        texts: Sentences to embed; sent as a single batch. Split large
            corpora into smaller requests yourself.
        hf_token: Hugging Face API token, sent as a Bearer credential.

    Returns:
        float32 array of shape (len(texts), dim), L2-normalized so inner
        product equals cosine similarity.

    Raises:
        RuntimeError: If the API returns an error payload or an output
            shape that cannot be reduced to one vector per input.
        requests.HTTPError: On a non-2xx HTTP status.
    """
    api_url = "https://api-inference.huggingface.co/models/sentence-transformers/all-MiniLM-L6-v2"
    headers = {"Authorization": f"Bearer {hf_token}", "Content-Type": "application/json"}
    resp = requests.post(api_url, headers=headers, json={"inputs": texts}, timeout=30)
    resp.raise_for_status()
    data = resp.json()
    # The API signals problems (model loading, bad token) as {"error": ...}.
    if isinstance(data, dict) and "error" in data:
        raise RuntimeError(data["error"])
    arr = np.array(data, dtype=np.float32)
    # Some deployments return token-level vectors (batch, tokens, dim) instead
    # of pooled sentence vectors; mean-pool over tokens so the previous
    # axis=1 normalization below stays correct instead of silently normalizing
    # the wrong axis.
    if arr.ndim == 3:
        arr = arr.mean(axis=1)
    if arr.ndim != 2:
        raise RuntimeError(f"Unexpected embedding shape from Inference API: {arr.shape}")
    # L2-normalize so FAISS inner-product search behaves as cosine similarity.
    norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
    return arr / norms

def build_faiss(embeddings: np.ndarray) -> faiss.IndexFlatIP:
    """Build a flat inner-product FAISS index over the given embedding matrix."""
    vectors = embeddings.astype(np.float32)
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)
    return index

def search_index(index: faiss.IndexFlatIP, query_emb: np.ndarray, k: int = 6) -> Tuple[np.ndarray, np.ndarray]:
    """Return (scores, ids) for the top-*k* index entries nearest *query_emb*."""
    scores, ids = index.search(query_emb.astype(np.float32), k)
    return scores, ids

def format_answer(query: str, hits: List[int], docs: List[Dict[str, str]], urgent: bool) -> str:
    """Compose a sectioned, source-cited answer from retrieved chunks.

    Args:
        query: The user's original question (kept for interface stability;
            not interpolated into the output).
        hits: Indices into *docs* for the top-ranked chunks.
        docs: Chunk records with "source_title", "source_url", "content".
        urgent: Whether red-flag triage matched; prepends an urgency notice.

    Returns:
        A text answer with fixed guidance sections, grouped per-source
        excerpts (max two per source, truncated to ~360 chars), and the
        standing medical disclaimer.
    """
    # Group chunk texts by (title, url) so each source page is cited once.
    grouped: Dict[Tuple[str, str], List[str]] = {}
    for idx in hits:
        d = docs[idx]
        key = (d["source_title"], d["source_url"])
        grouped.setdefault(key, []).append(d["content"])
    lines = []
    if urgent:
        lines.append("Potential urgent symptoms detected. Consider seeking immediate care before self-care steps.")
    lines.append("What it is:\n- Below are excerpts from MedlinePlus topics related to the question.")
    lines.append("Common symptoms:\n- See excerpts; symptom overlap is common, confirm with a clinician.")
    lines.append("Self-care steps:\n- Follow patient-friendly guidance in the excerpts when appropriate.")
    lines.append("When to seek care:\n- New, severe, or worsening symptoms, or red flags such as chest pain, trouble breathing, stroke signs, or severe allergic reaction.")
    lines.append("Sources:")
    for (title, url), chunks in grouped.items():
        # Bug fix: title and URL were previously concatenated with no separator
        # ("FluThttps://..."); separate them so the citation is readable.
        lines.append(f"- {title} — {url}")
        for c in chunks[:2]:
            snippet = (c[:360] + "…") if len(c) > 360 else c
            lines.append(f"  • {snippet}")
    lines.append(DISCLAIMER)
    return "\n\n".join(lines)

# --- Streamlit UI: the script below reruns top-to-bottom on every interaction ---
st.set_page_config(page_title="MedAssist (HF MiniLM + MedlinePlus)", page_icon="🩺")
st.title("MedAssist: Hugging Face MiniLM + MedlinePlus")
st.info(DISCLAIMER)

# Sidebar: retrieval knobs. Widget values are re-read on each rerun.
with st.sidebar:
    st.header("Retriever settings")
    use_hf_api = st.checkbox("Use Hugging Face Inference API (else local)", value=False)
    hf_token = st.text_input("HF API Token (if API mode)", type="password")
    topk_urls = st.slider("MedlinePlus URLs to fetch", 1, 8, 4)
    chunks_per_url = st.slider("Chunks per URL", 2, 12, 6)
    topk = st.slider("Top chunks to return", 2, 12, 6)
    st.caption("MedlinePlus wsearch → fetch pages → MiniLM embeddings → FAISS semantic search")

query = st.text_input("Describe symptoms or enter a medical term")
if st.button("Search"):
    # Triage first: red-flag matching is independent of retrieval success.
    urgent = has_red_flags(query)
    try:
        topics = medline_search(query, retmax=topk_urls, rettype="brief")
    except Exception as e:
        # Search failure is surfaced to the user but does not crash the app.
        st.error(f"MedlinePlus search failed: {e}")
        topics = []

    # Fetch each topic page and split it into embeddable chunks; pages that
    # fail to download are silently skipped (best-effort retrieval).
    docs = []
    for t in topics:
        try:
            text = fetch_page_text(t["url"])
            chunks = chunk_text(text)[:chunks_per_url]
            for ch in chunks:
                docs.append({"source_title": t["title"], "source_url": t["url"], "content": ch})
        except Exception:
            continue

    if not docs:
        st.warning("No relevant MedlinePlus content found. Try a different term or consult a clinician.")
    else:
        texts = [d["content"] for d in docs]
        try:
            # Two embedding paths: remote Inference API (needs a token) or
            # the locally cached SentenceTransformer model.
            if use_hf_api:
                if not hf_token:
                    st.error("Provide a Hugging Face API token to use the Inference API.")
                    st.stop()
                doc_emb = hf_inference_embed(texts, hf_token)
                q_emb = hf_inference_embed([query], hf_token)
            else:
                model = load_local_embedder()  # Downloads from Hugging Face Hub
                # normalize_embeddings=True matches the API path's L2 norm, so
                # inner-product search equals cosine similarity either way.
                doc_emb = model.encode(texts, normalize_embeddings=True, batch_size=32, show_progress_bar=False)
                q_emb = model.encode([query], normalize_embeddings=True)
        except Exception as e:
            st.error(f"Embedding failed: {e}")
            st.stop()

        # Index the chunk embeddings, retrieve the top-k for the query
        # (dropping FAISS's -1 padding ids), and render the cited answer.
        index = build_faiss(np.array(doc_emb, dtype=np.float32))
        D, I = search_index(index, np.array(q_emb, dtype=np.float32), k=topk)
        hit_ids = [int(i) for i in I[0] if i >= 0]
        answer = format_answer(query, hit_ids, docs, urgent)
        st.markdown(answer)