Revanthr75 committed on
Commit
aeb5696
·
verified ·
1 Parent(s): 372cde3
Files changed (1) hide show
  1. med.ipynb +174 -0
med.ipynb ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, re, requests, json
2
+ from typing import List, Dict, Any, Tuple
3
+ from bs4 import BeautifulSoup
4
+ import numpy as np
5
+ import faiss
6
+ import streamlit as st
7
+
8
+ from sentence_transformers import SentenceTransformer # Local HF model use
9
+
10
+ MEDLINE_WSEARCH = "https://wsearch.nlm.nih.gov/ws/query"
11
+ DISCLAIMER = ("This assistant provides general health information and is not a substitute for professional medical advice, "
12
+ "diagnosis, or treatment. For personal medical concerns, consult a qualified clinician or seek emergency care for urgent symptoms.")
13
+
14
+ # --- Red flag patterns for basic triage ---
15
+ RED_FLAGS = [
16
+ r"\b(chest pain|pressure in chest)\b",
17
+ r"\b(trouble breathing|shortness of breath|severe breathlessness)\b",
18
+ r"\b(signs of stroke|face droop|arm weakness|speech trouble|sudden confusion)\b",
19
+ r"\b(severe allergic reaction|anaphylaxis|swelling of face|swelling of tongue)\b",
20
+ r"\b(black stools|vomiting blood|severe bleeding)\b",
21
+ r"\b(severe dehydration|no urination|sunken eyes)\b",
22
+ r"\b(high fever|stiff neck|severe headache)\b",
23
+ ]
24
+
25
+ def has_red_flags(text: str) -> bool:
26
+ t = text.lower()
27
+ return any(re.search(p, t) for p in RED_FLAGS)
28
+
29
+ # --- MedlinePlus search and fetch ---
30
+ def medline_search(term: str, retmax: int = 5, rettype: str = "brief") -> List[Dict[str, str]]:
31
+ params = {"db": "healthTopics", "term": term, "retmax": str(retmax), "rettype": rettype}
32
+ r = requests.get(MEDLINE_WSEARCH, params=params, timeout=10)
33
+ r.raise_for_status()
34
+ soup = BeautifulSoup(r.text, "xml")
35
+ results = []
36
+ for doc in soup.find_all("document"):
37
+ title = doc.find("content", {"name": "title"})
38
+ url = doc.find("content", {"name": "url"})
39
+ snippet = doc.find("content", {"name": "snippet"}) or doc.find("content", {"name": "full-summary"})
40
+ if title and url:
41
+ results.append({"title": title.text.strip(), "url": url.text.strip(), "snippet": (snippet.text.strip() if snippet else "")})
42
+ return results
43
+
44
+ def fetch_page_text(url: str, max_chars: int = 12000) -> str:
45
+ r = requests.get(url, timeout=10)
46
+ r.raise_for_status()
47
+ soup = BeautifulSoup(r.text, "html.parser")
48
+ for tag in soup(["script", "style", "nav", "footer", "header", "form", "aside"]):
49
+ tag.decompose()
50
+ text = soup.get_text(separator="\n")
51
+ text = re.sub(r"\n{2,}", "\n", text)
52
+ return text[:max_chars].strip()
53
+
54
+ def chunk_text(text: str, approx_tokens: int = 220) -> List[str]:
55
+ words = text.split()
56
+ chunks = []
57
+ for i in range(0, len(words), approx_tokens):
58
+ chunk = " ".join(words[i:i+approx_tokens])
59
+ if len(chunk) > 40:
60
+ chunks.append(chunk)
61
+ return chunks
62
+
63
+ # --- Embeddings via Hugging Face ---
64
+ @st.cache_resource
65
+ def load_local_embedder():
66
+ # Uses Hugging Face model from the Hub locally
67
+ return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
68
+
69
+ def hf_inference_embed(texts: List[str], hf_token: str) -> np.ndarray:
70
+ # Uses Hugging Face Inference API directly to get embeddings from the model repo
71
+ # Some providers return lists of vectors; normalize after
72
+ api_url = "https://api-inference.huggingface.co/models/sentence-transformers/all-MiniLM-L6-v2"
73
+ headers = {"Authorization": f"Bearer {hf_token}", "Content-Type": "application/json"}
74
+ # Batch once for simplicity; for large corpora, split into smaller requests
75
+ resp = requests.post(api_url, headers=headers, json={"inputs": texts}, timeout=30)
76
+ resp.raise_for_status()
77
+ data = resp.json()
78
+ # Handle potential {'error': ...} or streaming-like responses
79
+ if isinstance(data, dict) and "error" in data:
80
+ raise RuntimeError(data["error"])
81
+ # Expect a list of vectors
82
+ arr = np.array(data, dtype=np.float32)
83
+ # L2 normalize for cosine similarity
84
+ norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
85
+ return arr / norms
86
+
87
+ def build_faiss(embeddings: np.ndarray) -> faiss.IndexFlatIP:
88
+ dim = embeddings.shape[1]
89
+ index = faiss.IndexFlatIP(dim)
90
+ index.add(embeddings.astype(np.float32))
91
+ return index
92
+
93
+ def search_index(index: faiss.IndexFlatIP, query_emb: np.ndarray, k: int = 6) -> Tuple[np.ndarray, np.ndarray]:
94
+ D, I = index.search(query_emb.astype(np.float32), k)
95
+ return D, I
96
+
97
+ def format_answer(query: str, hits: List[int], docs: List[Dict[str, str]], urgent: bool) -> str:
98
+ grouped = {}
99
+ for idx in hits:
100
+ d = docs[idx]
101
+ key = (d["source_title"], d["source_url"])
102
+ grouped.setdefault(key, []).append(d["content"])
103
+ lines = []
104
+ if urgent:
105
+ lines.append("Potential urgent symptoms detected. Consider seeking immediate care before self-care steps.")
106
+ lines.append("What it is:\n- Below are excerpts from MedlinePlus topics related to the question.")
107
+ lines.append("Common symptoms:\n- See excerpts; symptom overlap is common, confirm with a clinician.")
108
+ lines.append("Self-care steps:\n- Follow patient-friendly guidance in the excerpts when appropriate.")
109
+ lines.append("When to seek care:\n- New, severe, or worsening symptoms, or red flags such as chest pain, trouble breathing, stroke signs, or severe allergic reaction.")
110
+ lines.append("Sources:")
111
+ for (title, url), chunks in grouped.items():
112
+ lines.append(f"- {title} — {url}")
113
+ for c in chunks[:2]:
114
+ snippet = (c[:360] + "…") if len(c) > 360 else c
115
+ lines.append(f" • {snippet}")
116
+ lines.append(DISCLAIMER)
117
+ return "\n\n".join(lines)
118
+
119
+ st.set_page_config(page_title="MedAssist (HF MiniLM + MedlinePlus)", page_icon="🩺")
120
+ st.title("MedAssist: Hugging Face MiniLM + MedlinePlus")
121
+ st.info(DISCLAIMER)
122
+
123
+ with st.sidebar:
124
+ st.header("Retriever settings")
125
+ use_hf_api = st.checkbox("Use Hugging Face Inference API (else local)", value=False)
126
+ hf_token = st.text_input("HF API Token (if API mode)", type="password")
127
+ topk_urls = st.slider("MedlinePlus URLs to fetch", 1, 8, 4)
128
+ chunks_per_url = st.slider("Chunks per URL", 2, 12, 6)
129
+ topk = st.slider("Top chunks to return", 2, 12, 6)
130
+ st.caption("MedlinePlus wsearch → fetch pages → MiniLM embeddings → FAISS semantic search")
131
+
132
+ query = st.text_input("Describe symptoms or enter a medical term")
133
+ if st.button("Search"):
134
+ urgent = has_red_flags(query)
135
+ try:
136
+ topics = medline_search(query, retmax=topk_urls, rettype="brief")
137
+ except Exception as e:
138
+ st.error(f"MedlinePlus search failed: {e}")
139
+ topics = []
140
+
141
+ docs = []
142
+ for t in topics:
143
+ try:
144
+ text = fetch_page_text(t["url"])
145
+ chunks = chunk_text(text)[:chunks_per_url]
146
+ for ch in chunks:
147
+ docs.append({"source_title": t["title"], "source_url": t["url"], "content": ch})
148
+ except Exception:
149
+ continue
150
+
151
+ if not docs:
152
+ st.warning("No relevant MedlinePlus content found. Try a different term or consult a clinician.")
153
+ else:
154
+ texts = [d["content"] for d in docs]
155
+ try:
156
+ if use_hf_api:
157
+ if not hf_token:
158
+ st.error("Provide a Hugging Face API token to use the Inference API.")
159
+ st.stop()
160
+ doc_emb = hf_inference_embed(texts, hf_token)
161
+ q_emb = hf_inference_embed([query], hf_token)
162
+ else:
163
+ model = load_local_embedder() # Downloads from Hugging Face Hub
164
+ doc_emb = model.encode(texts, normalize_embeddings=True, batch_size=32, show_progress_bar=False)
165
+ q_emb = model.encode([query], normalize_embeddings=True)
166
+ except Exception as e:
167
+ st.error(f"Embedding failed: {e}")
168
+ st.stop()
169
+
170
+ index = build_faiss(np.array(doc_emb, dtype=np.float32))
171
+ D, I = search_index(index, np.array(q_emb, dtype=np.float32), k=topk)
172
+ hit_ids = [int(i) for i in I[0] if i >= 0]
173
+ answer = format_answer(query, hit_ids, docs, urgent)
174
+ st.markdown(answer)