# MedAssist: MedlinePlus retrieval + MiniLM embeddings + FAISS semantic search (Streamlit app).
import os, re, requests, json
from typing import List, Dict, Any, Tuple
from bs4 import BeautifulSoup
import numpy as np
import faiss
import streamlit as st
from sentence_transformers import SentenceTransformer # Local HF model use
# MedlinePlus health-topic search endpoint (NLM "wsearch" web service).
MEDLINE_WSEARCH = "https://wsearch.nlm.nih.gov/ws/query"
# Safety disclaimer shown in the UI banner and appended to every generated answer.
DISCLAIMER = ("This assistant provides general health information and is not a substitute for professional medical advice, "
              "diagnosis, or treatment. For personal medical concerns, consult a qualified clinician or seek emergency care for urgent symptoms.")
# --- Red flag patterns for basic triage ---
# Regexes matched (via re.search) against the lowercased user query; any hit
# marks the query as potentially urgent (see has_red_flags).
RED_FLAGS = [
    r"\b(chest pain|pressure in chest)\b",
    r"\b(trouble breathing|shortness of breath|severe breathlessness)\b",
    r"\b(signs of stroke|face droop|arm weakness|speech trouble|sudden confusion)\b",
    r"\b(severe allergic reaction|anaphylaxis|swelling of face|swelling of tongue)\b",
    r"\b(black stools|vomiting blood|severe bleeding)\b",
    r"\b(severe dehydration|no urination|sunken eyes)\b",
    r"\b(high fever|stiff neck|severe headache)\b",
]
def has_red_flags(text: str) -> bool:
    """Return True if *text* mentions any emergency "red flag" symptom pattern."""
    lowered = text.lower()
    for pattern in RED_FLAGS:
        if re.search(pattern, lowered):
            return True
    return False
# --- MedlinePlus search and fetch ---
def medline_search(term: str, retmax: int = 5, rettype: str = "brief") -> List[Dict[str, str]]:
    """Query the MedlinePlus wsearch API for health topics matching *term*.

    Returns up to *retmax* dicts with "title", "url" and "snippet" keys;
    raises requests.HTTPError on a non-2xx response.
    """
    response = requests.get(
        MEDLINE_WSEARCH,
        params={"db": "healthTopics", "term": term, "retmax": str(retmax), "rettype": rettype},
        timeout=10,
    )
    response.raise_for_status()
    parsed = BeautifulSoup(response.text, "xml")
    results: List[Dict[str, str]] = []
    for doc in parsed.find_all("document"):
        title = doc.find("content", {"name": "title"})
        url = doc.find("content", {"name": "url"})
        # Prefer the short snippet; fall back to the full summary field.
        snippet = doc.find("content", {"name": "snippet"}) or doc.find("content", {"name": "full-summary"})
        if not (title and url):
            continue
        results.append({
            "title": title.text.strip(),
            "url": url.text.strip(),
            "snippet": snippet.text.strip() if snippet else "",
        })
    return results
def fetch_page_text(url: str, max_chars: int = 12000) -> str:
    """Download *url*, strip boilerplate markup, and return up to *max_chars* of plain text."""
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    page = BeautifulSoup(response.text, "html.parser")
    # Remove non-content elements before extracting visible text.
    for junk in page(["script", "style", "nav", "footer", "header", "form", "aside"]):
        junk.decompose()
    raw = page.get_text(separator="\n")
    collapsed = re.sub(r"\n{2,}", "\n", raw)
    return collapsed[:max_chars].strip()
def chunk_text(text: str, approx_tokens: int = 220) -> List[str]:
    """Split *text* into word-based chunks of roughly *approx_tokens* words each.

    Chunks of 40 characters or fewer are discarded as noise.
    """
    words = text.split()
    pieces = (
        " ".join(words[start:start + approx_tokens])
        for start in range(0, len(words), approx_tokens)
    )
    return [piece for piece in pieces if len(piece) > 40]
# --- Embeddings via Hugging Face ---
@st.cache_resource
def load_local_embedder():
    """Load the MiniLM sentence-embedding model from the Hugging Face Hub.

    st.cache_resource keeps a single model instance alive across Streamlit reruns.
    """
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    return model
def hf_inference_embed(texts: List[str], hf_token: str) -> np.ndarray:
    """Embed *texts* via the Hugging Face Inference API (MiniLM) and L2-normalize rows.

    Parameters
    ----------
    texts : strings to embed; sent to the API in a single batched request.
    hf_token : Hugging Face API token, used as a Bearer credential.

    Returns
    -------
    np.ndarray of shape (len(texts), dim) with L2-normalized rows, so that
    inner product equals cosine similarity in the FAISS index.

    Raises
    ------
    RuntimeError if the API reports an error or returns an unexpected shape;
    requests.HTTPError on a non-2xx HTTP response.
    """
    api_url = "https://api-inference.huggingface.co/models/sentence-transformers/all-MiniLM-L6-v2"
    headers = {"Authorization": f"Bearer {hf_token}", "Content-Type": "application/json"}
    # Batch once for simplicity; for large corpora, split into smaller requests.
    resp = requests.post(api_url, headers=headers, json={"inputs": texts}, timeout=30)
    resp.raise_for_status()
    data = resp.json()
    # The API can signal failures (e.g. model still loading) as {'error': ...} with HTTP 200.
    if isinstance(data, dict) and "error" in data:
        raise RuntimeError(data["error"])
    arr = np.array(data, dtype=np.float32)
    # Normalize response shape: some providers return (dim,) for a single input,
    # others return token-level (batch, tokens, dim) vectors that need mean pooling.
    if arr.ndim == 1:
        arr = arr[np.newaxis, :]
    elif arr.ndim == 3:
        arr = arr.mean(axis=1)
    if arr.ndim != 2:
        raise RuntimeError(f"Unexpected embedding response shape: {arr.shape}")
    # L2 normalize for cosine similarity; epsilon guards against a zero vector.
    norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
    return arr / norms
def build_faiss(embeddings: np.ndarray) -> faiss.IndexFlatIP:
    """Create an inner-product (cosine on normalized vectors) FAISS index over *embeddings*."""
    vectors = embeddings.astype(np.float32)
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)
    return index
def search_index(index: faiss.IndexFlatIP, query_emb: np.ndarray, k: int = 6) -> Tuple[np.ndarray, np.ndarray]:
    """Return (scores, ids) of the *k* nearest index entries to *query_emb*."""
    return index.search(query_emb.astype(np.float32), k)
def format_answer(query: str, hits: List[int], docs: List[Dict[str, str]], urgent: bool) -> str:
    """Assemble the markdown answer from retrieved chunks, grouped by source page.

    *hits* are indices into *docs*; when *urgent* is True an urgent-care
    warning is prepended. Ends with the standard disclaimer.
    """
    by_source: Dict[Tuple[str, str], List[str]] = {}
    for hit in hits:
        doc = docs[hit]
        by_source.setdefault((doc["source_title"], doc["source_url"]), []).append(doc["content"])
    sections: List[str] = []
    if urgent:
        sections.append("Potential urgent symptoms detected. Consider seeking immediate care before self-care steps.")
    sections.append("What it is:\n- Below are excerpts from MedlinePlus topics related to the question.")
    sections.append("Common symptoms:\n- See excerpts; symptom overlap is common, confirm with a clinician.")
    sections.append("Self-care steps:\n- Follow patient-friendly guidance in the excerpts when appropriate.")
    sections.append("When to seek care:\n- New, severe, or worsening symptoms, or red flags such as chest pain, trouble breathing, stroke signs, or severe allergic reaction.")
    sections.append("Sources:")
    for (title, url), chunks in by_source.items():
        sections.append(f"- {title} — {url}")
        # Show at most two excerpts per source, trimmed to roughly 360 characters.
        for chunk in chunks[:2]:
            excerpt = chunk if len(chunk) <= 360 else chunk[:360] + "…"
            sections.append(f"  • {excerpt}")
    sections.append(DISCLAIMER)
    return "\n\n".join(sections)
# --- Streamlit UI: page chrome and retriever settings sidebar ---
st.set_page_config(page_title="MedAssist (HF MiniLM + MedlinePlus)", page_icon="🩺")
st.title("MedAssist: Hugging Face MiniLM + MedlinePlus")
st.info(DISCLAIMER)
with st.sidebar:
    st.header("Retriever settings")
    # Toggle between the remote HF Inference API and the locally cached model.
    use_hf_api = st.checkbox("Use Hugging Face Inference API (else local)", value=False)
    hf_token = st.text_input("HF API Token (if API mode)", type="password")
    topk_urls = st.slider("MedlinePlus URLs to fetch", 1, 8, 4)
    chunks_per_url = st.slider("Chunks per URL", 2, 12, 6)
    topk = st.slider("Top chunks to return", 2, 12, 6)
    # NOTE(review): indentation was lost in extraction; caption placement
    # (sidebar vs main area) reconstructed — confirm original layout.
    st.caption("MedlinePlus wsearch → fetch pages → MiniLM embeddings → FAISS semantic search")
query = st.text_input("Describe symptoms or enter a medical term")
# --- Main flow: triage, retrieve, chunk, embed, rank, render ---
if st.button("Search"):
    # Basic triage: flag emergency wording before presenting self-care info.
    urgent = has_red_flags(query)
    try:
        topics = medline_search(query, retmax=topk_urls, rettype="brief")
    except Exception as e:
        st.error(f"MedlinePlus search failed: {e}")
        topics = []
    docs = []
    # Fetch each topic page and split it into retrievable chunks.
    for t in topics:
        try:
            text = fetch_page_text(t["url"])
            chunks = chunk_text(text)[:chunks_per_url]
            for ch in chunks:
                docs.append({"source_title": t["title"], "source_url": t["url"], "content": ch})
        except Exception:
            # Best-effort: silently skip pages that fail to download or parse.
            continue
    if not docs:
        st.warning("No relevant MedlinePlus content found. Try a different term or consult a clinician.")
    else:
        texts = [d["content"] for d in docs]
        try:
            if use_hf_api:
                if not hf_token:
                    st.error("Provide a Hugging Face API token to use the Inference API.")
                    st.stop()
                doc_emb = hf_inference_embed(texts, hf_token)
                q_emb = hf_inference_embed([query], hf_token)
            else:
                model = load_local_embedder()  # Downloads from Hugging Face Hub
                doc_emb = model.encode(texts, normalize_embeddings=True, batch_size=32, show_progress_bar=False)
                q_emb = model.encode([query], normalize_embeddings=True)
        except Exception as e:
            st.error(f"Embedding failed: {e}")
            st.stop()
        # Rank chunks by inner product (cosine similarity on normalized vectors).
        index = build_faiss(np.array(doc_emb, dtype=np.float32))
        D, I = search_index(index, np.array(q_emb, dtype=np.float32), k=topk)
        # Drop FAISS's -1 padding ids returned when fewer than k hits exist.
        hit_ids = [int(i) for i in I[0] if i >= 0]
        answer = format_answer(query, hit_ids, docs, urgent)
        st.markdown(answer)