"""MedAssist: a Streamlit app answering health questions from MedlinePlus content.

Pipeline: MedlinePlus wsearch -> fetch topic pages -> strip boilerplate ->
chunk text -> MiniLM sentence embeddings (local SentenceTransformer or the
Hugging Face Inference API) -> FAISS inner-product search over L2-normalized
vectors (i.e. cosine similarity) -> formatted, source-cited answer.
"""

import os
import re
import json
from typing import List, Dict, Any, Tuple

import numpy as np
import requests
import faiss
import streamlit as st
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer  # Local HF model use

# MedlinePlus web-service search endpoint (NLM).
MEDLINE_WSEARCH = "https://wsearch.nlm.nih.gov/ws/query"

DISCLAIMER = ("This assistant provides general health information and is not a substitute for professional medical advice, "
              "diagnosis, or treatment. For personal medical concerns, consult a qualified clinician or seek emergency care for urgent symptoms.")

# --- Red flag patterns for basic triage ---
# Case-insensitive match is achieved by lowercasing the input in has_red_flags.
RED_FLAGS = [
    r"\b(chest pain|pressure in chest)\b",
    r"\b(trouble breathing|shortness of breath|severe breathlessness)\b",
    r"\b(signs of stroke|face droop|arm weakness|speech trouble|sudden confusion)\b",
    r"\b(severe allergic reaction|anaphylaxis|swelling of face|swelling of tongue)\b",
    r"\b(black stools|vomiting blood|severe bleeding)\b",
    r"\b(severe dehydration|no urination|sunken eyes)\b",
    r"\b(high fever|stiff neck|severe headache)\b",
]


def has_red_flags(text: str) -> bool:
    """Return True if *text* matches any urgent-symptom pattern in RED_FLAGS."""
    t = text.lower()
    return any(re.search(p, t) for p in RED_FLAGS)


# --- MedlinePlus search and fetch ---
def medline_search(term: str, retmax: int = 5, rettype: str = "brief") -> List[Dict[str, str]]:
    """Query the MedlinePlus wsearch service for health topics matching *term*.

    Args:
        term: Free-text search term.
        retmax: Maximum number of documents to request.
        rettype: wsearch result detail level (e.g. "brief").

    Returns:
        A list of dicts with "title", "url", and "snippet" keys; entries
        missing a title or url are skipped.

    Raises:
        requests.HTTPError: If the service responds with an error status.
    """
    params = {"db": "healthTopics", "term": term, "retmax": str(retmax), "rettype": rettype}
    r = requests.get(MEDLINE_WSEARCH, params=params, timeout=10)
    r.raise_for_status()
    soup = BeautifulSoup(r.text, "xml")
    results = []
    for doc in soup.find_all("document"):
        title = doc.find("content", {"name": "title"})
        url = doc.find("content", {"name": "url"})
        # Prefer the short snippet; fall back to the full summary when absent.
        snippet = doc.find("content", {"name": "snippet"}) or doc.find("content", {"name": "full-summary"})
        if title and url:
            results.append({"title": title.text.strip(),
                            "url": url.text.strip(),
                            "snippet": (snippet.text.strip() if snippet else "")})
    return results


def fetch_page_text(url: str, max_chars: int = 12000) -> str:
    """Download *url* and return its visible text, truncated to *max_chars*.

    Scripts, styles, and navigation/boilerplate tags are removed before
    extracting text; runs of blank lines are collapsed.
    """
    r = requests.get(url, timeout=10)
    r.raise_for_status()
    soup = BeautifulSoup(r.text, "html.parser")
    # Drop non-content chrome so embeddings see only the article body.
    for tag in soup(["script", "style", "nav", "footer", "header", "form", "aside"]):
        tag.decompose()
    text = soup.get_text(separator="\n")
    text = re.sub(r"\n{2,}", "\n", text)
    return text[:max_chars].strip()


def chunk_text(text: str, approx_tokens: int = 220) -> List[str]:
    """Split *text* into word-count chunks of roughly *approx_tokens* words.

    Chunks of 40 characters or fewer are discarded as too short to be useful.
    """
    words = text.split()
    chunks = []
    for i in range(0, len(words), approx_tokens):
        chunk = " ".join(words[i:i + approx_tokens])
        if len(chunk) > 40:
            chunks.append(chunk)
    return chunks


# --- Embeddings via Hugging Face ---
@st.cache_resource
def load_local_embedder():
    """Load (and cache) the MiniLM sentence-embedding model from the HF Hub."""
    # Uses Hugging Face model from the Hub locally
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")


def hf_inference_embed(texts: List[str], hf_token: str) -> np.ndarray:
    """Embed *texts* via the Hugging Face Inference API; return L2-normalized vectors.

    Args:
        texts: Sentences/chunks to embed (sent as one batch).
        hf_token: Hugging Face API token for the Authorization header.

    Returns:
        A float32 array of shape (len(texts), dim), row-normalized for
        cosine similarity via inner product.

    Raises:
        RuntimeError: If the API returns an {'error': ...} payload.
        requests.HTTPError: On a non-2xx HTTP response.
    """
    # Uses Hugging Face Inference API directly to get embeddings from the model repo
    # Some providers return lists of vectors; normalize after
    api_url = "https://api-inference.huggingface.co/models/sentence-transformers/all-MiniLM-L6-v2"
    headers = {"Authorization": f"Bearer {hf_token}", "Content-Type": "application/json"}
    # Batch once for simplicity; for large corpora, split into smaller requests
    resp = requests.post(api_url, headers=headers, json={"inputs": texts}, timeout=30)
    resp.raise_for_status()
    data = resp.json()
    # Handle potential {'error': ...} or streaming-like responses
    if isinstance(data, dict) and "error" in data:
        raise RuntimeError(data["error"])
    arr = np.array(data, dtype=np.float32)
    # Some feature-extraction deployments return token-level embeddings of
    # shape (batch, tokens, dim) rather than sentence vectors; mean-pool in
    # that case so downstream code always sees (batch, dim).
    # NOTE(review): presumably matches the provider's output format — confirm.
    if arr.ndim == 3:
        arr = arr.mean(axis=1)
    # L2 normalize for cosine similarity
    norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
    return arr / norms


def build_faiss(embeddings: np.ndarray) -> faiss.IndexFlatIP:
    """Build a flat inner-product FAISS index over *embeddings* (n, dim)."""
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings.astype(np.float32))
    return index


def search_index(index: faiss.IndexFlatIP, query_emb: np.ndarray, k: int = 6) -> Tuple[np.ndarray, np.ndarray]:
    """Return (distances, indices) for the top-*k* neighbors of *query_emb*."""
    D, I = index.search(query_emb.astype(np.float32), k)
    return D, I


def format_answer(query: str, hits: List[int], docs: List[Dict[str, str]], urgent: bool) -> str:
    """Render retrieved chunks as a structured, source-cited markdown answer.

    Args:
        query: The user's original question (unused in the body; kept for
            interface stability).
        hits: Indices into *docs* for the retrieved chunks, in rank order.
        docs: Chunk records with "source_title", "source_url", "content".
        urgent: Whether red-flag triage matched; prepends an urgency notice.

    Returns:
        A markdown string with fixed sections, per-source excerpts (at most
        two per source, truncated to 360 chars), and the disclaimer.
    """
    # Group retrieved chunks by their originating page so each source is
    # cited once with its excerpts beneath it.
    grouped = {}
    for idx in hits:
        d = docs[idx]
        key = (d["source_title"], d["source_url"])
        grouped.setdefault(key, []).append(d["content"])
    lines = []
    if urgent:
        lines.append("Potential urgent symptoms detected. Consider seeking immediate care before self-care steps.")
    lines.append("What it is:\n- Below are excerpts from MedlinePlus topics related to the question.")
    lines.append("Common symptoms:\n- See excerpts; symptom overlap is common, confirm with a clinician.")
    lines.append("Self-care steps:\n- Follow patient-friendly guidance in the excerpts when appropriate.")
    lines.append("When to seek care:\n- New, severe, or worsening symptoms, or red flags such as chest pain, trouble breathing, stroke signs, or severe allergic reaction.")
    lines.append("Sources:")
    for (title, url), chunks in grouped.items():
        lines.append(f"- {title} — {url}")
        for c in chunks[:2]:
            snippet = (c[:360] + "…") if len(c) > 360 else c
            lines.append(f"  • {snippet}")
    lines.append(DISCLAIMER)
    return "\n\n".join(lines)


# --- Streamlit UI ---
st.set_page_config(page_title="MedAssist (HF MiniLM + MedlinePlus)", page_icon="🩺")
st.title("MedAssist: Hugging Face MiniLM + MedlinePlus")
st.info(DISCLAIMER)

with st.sidebar:
    st.header("Retriever settings")
    use_hf_api = st.checkbox("Use Hugging Face Inference API (else local)", value=False)
    hf_token = st.text_input("HF API Token (if API mode)", type="password")
    topk_urls = st.slider("MedlinePlus URLs to fetch", 1, 8, 4)
    chunks_per_url = st.slider("Chunks per URL", 2, 12, 6)
    topk = st.slider("Top chunks to return", 2, 12, 6)
    st.caption("MedlinePlus wsearch → fetch pages → MiniLM embeddings → FAISS semantic search")

query = st.text_input("Describe symptoms or enter a medical term")

if st.button("Search"):
    # Guard against an empty query: the original flow would issue a pointless
    # MedlinePlus request and end in a confusing "no content" warning.
    if not query.strip():
        st.warning("Please enter a symptom description or medical term.")
        st.stop()
    urgent = has_red_flags(query)
    try:
        topics = medline_search(query, retmax=topk_urls, rettype="brief")
    except Exception as e:
        st.error(f"MedlinePlus search failed: {e}")
        topics = []

    # Fetch each topic page and collect chunked passages; skip pages that
    # fail to download or parse (best-effort retrieval).
    docs = []
    for t in topics:
        try:
            text = fetch_page_text(t["url"])
            chunks = chunk_text(text)[:chunks_per_url]
            for ch in chunks:
                docs.append({"source_title": t["title"], "source_url": t["url"], "content": ch})
        except Exception:
            continue

    if not docs:
        st.warning("No relevant MedlinePlus content found. Try a different term or consult a clinician.")
    else:
        texts = [d["content"] for d in docs]
        try:
            if use_hf_api:
                if not hf_token:
                    st.error("Provide a Hugging Face API token to use the Inference API.")
                    st.stop()
                doc_emb = hf_inference_embed(texts, hf_token)
                q_emb = hf_inference_embed([query], hf_token)
            else:
                model = load_local_embedder()  # Downloads from Hugging Face Hub
                doc_emb = model.encode(texts, normalize_embeddings=True, batch_size=32, show_progress_bar=False)
                q_emb = model.encode([query], normalize_embeddings=True)
        except Exception as e:
            st.error(f"Embedding failed: {e}")
            st.stop()

        index = build_faiss(np.array(doc_emb, dtype=np.float32))
        D, I = search_index(index, np.array(q_emb, dtype=np.float32), k=topk)
        # FAISS pads with -1 when fewer than k results exist; drop those.
        hit_ids = [int(i) for i in I[0] if i >= 0]
        answer = format_answer(query, hit_ids, docs, urgent)
        st.markdown(answer)