mshabir committed on
Commit
0b0695e
·
verified ·
1 Parent(s): b9cdd5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -346
app.py CHANGED
@@ -1,371 +1,179 @@
1
  import streamlit as st
2
- from langchain_community.embeddings import HuggingFaceEmbeddings # Updated import
3
- from langchain_community.vectorstores import FAISS
4
- from langchain.chains import RetrievalQA
5
- from langchain_community.llms import HuggingFaceHub
6
- from langchain_community.document_loaders import TextLoader
7
- from langchain.text_splitter import RecursiveCharacterTextSplitter
8
- from dotenv import load_dotenv
9
  import os
10
- import pickle
11
- from pathlib import Path
12
 
13
- # Load environment variables
14
- load_dotenv()
15
-
16
- # Streamlit page configuration
17
  st.set_page_config(
18
- page_title="Medical QA Assistant",
19
  page_icon="πŸ₯",
20
  layout="wide"
21
  )
22
 
23
- # Check for required files
24
- @st.cache_resource
25
- def check_files():
26
- """Check if required files exist and provide guidance if not"""
27
- faiss_index_path = Path("medical_faiss_store/medical_faiss.faiss")
28
- faiss_pkl_path = Path("medical_faiss_store/medical_faiss.pkl")
29
-
30
- if not faiss_index_path.exists() or not faiss_pkl_path.exists():
31
- return False
32
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- # Initialize embeddings
35
- @st.cache_resource
36
- def load_embeddings():
37
- """Load the HuggingFace embeddings model"""
38
- try:
39
- model_name = "sentence-transformers/all-MiniLM-L6-v2"
40
- embeddings = HuggingFaceEmbeddings(model_name=model_name)
41
- return embeddings
42
- except Exception as e:
43
- st.error(f"Error loading embeddings: {e}")
44
- return None
 
45
 
46
- # Create FAISS index if it doesn't exist
47
- def create_faiss_index():
48
- """Create FAISS index from sample medical data"""
49
- try:
50
- # Create sample medical data
51
- sample_text = """
52
- Diabetes is a chronic disease that occurs when the pancreas does not produce enough insulin.
53
- Symptoms include increased thirst, frequent urination, and unexplained weight loss.
54
- Type 1 diabetes is usually diagnosed in children and requires insulin injections.
55
- Type 2 diabetes is more common in adults and can be managed with diet, exercise, and medication.
56
-
57
- Hypertension, or high blood pressure, is when blood pressure is consistently too high.
58
- Normal blood pressure is below 120/80 mmHg.
59
- Symptoms may include headaches, shortness of breath, and nosebleeds.
60
- Treatment includes lifestyle changes like reducing salt intake and medication.
61
-
62
- Asthma is a condition where airways narrow and swell, producing extra mucus.
63
- Symptoms include wheezing, coughing, chest tightness, and shortness of breath.
64
- Asthma attacks can be triggered by allergens, exercise, or cold air.
65
- Treatment involves inhalers (bronchodilators and corticosteroids).
66
-
67
- COVID-19 is a respiratory illness caused by the SARS-CoV-2 virus.
68
- Symptoms include fever, cough, fatigue, and loss of taste or smell.
69
- Prevention includes vaccination, wearing masks, and social distancing.
70
- Treatment depends on severity and may include antiviral medications.
71
-
72
- Heart attack (myocardial infarction) occurs when blood flow to the heart is blocked.
73
- Symptoms include chest pain, shortness of breath, nausea, and pain in arms or jaw.
74
- Immediate treatment is crucial and may include aspirin, nitroglycerin, or surgery.
75
- Risk factors include smoking, high cholesterol, and family history.
76
-
77
- Stroke occurs when blood supply to part of the brain is interrupted.
78
- Symptoms include sudden numbness, confusion, trouble speaking, and loss of balance.
79
- FAST is an acronym for Face drooping, Arm weakness, Speech difficulty, Time to call emergency.
80
- Treatment includes clot-busting drugs and rehabilitation.
81
-
82
- Cancer is a disease caused by uncontrolled cell growth.
83
- Common types include lung, breast, prostate, and colorectal cancer.
84
- Symptoms vary but may include lumps, unexplained weight loss, and persistent pain.
85
- Treatments include surgery, chemotherapy, radiation, and immunotherapy.
86
- """
87
-
88
- # Save sample text to a temporary file
89
- temp_file = "temp_medical_data.txt"
90
- with open(temp_file, "w") as f:
91
- f.write(sample_text)
92
-
93
- # Load and process documents
94
- loader = TextLoader(temp_file)
95
- documents = loader.load()
96
-
97
- text_splitter = RecursiveCharacterTextSplitter(
98
- chunk_size=500,
99
- chunk_overlap=50,
100
- length_function=len,
101
- separators=["\n\n", "\n", " ", ""]
102
- )
103
- texts = text_splitter.split_documents(documents)
104
-
105
- # Create embeddings
106
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
107
-
108
- # Create and save FAISS index
109
- db = FAISS.from_documents(texts, embeddings)
110
-
111
- # Create directory if it doesn't exist
112
- Path("medical_faiss_store").mkdir(exist_ok=True)
113
-
114
- # Save the FAISS index
115
- db.save_local("medical_faiss_store", index_name="medical_faiss")
116
-
117
- # Clean up temp file
118
- os.remove(temp_file)
119
-
120
- st.success("βœ… FAISS index created successfully with sample medical data!")
121
- return db
122
-
123
- except Exception as e:
124
- st.error(f"Error creating FAISS index: {e}")
125
- return None
126
 
127
- # Load FAISS database with error handling
128
- @st.cache_resource
129
- def load_faiss():
130
- """Load FAISS database or create if it doesn't exist"""
131
- try:
132
- embeddings = load_embeddings()
133
- if embeddings is None:
134
- return None
135
-
136
- # Check if files exist
137
- if not check_files():
138
- st.warning("FAISS index not found. Creating a new one with sample medical data...")
139
- return create_faiss_index()
140
-
141
- # Load existing FAISS database
142
- db = FAISS.load_local(
143
- "medical_faiss_store",
144
- embeddings,
145
- index_name="medical_faiss",
146
- allow_dangerous_deserialization=True
147
- )
148
- st.success("βœ… FAISS database loaded successfully!")
149
- return db
150
-
151
- except Exception as e:
152
- st.error(f"Error loading FAISS database: {e}")
153
- return None
154
 
155
- # Initialize LLM
156
- @st.cache_resource
157
- def load_llm():
158
- """Load the HuggingFace LLM"""
159
- try:
160
- # Check for API token
161
- api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
162
- if not api_token:
163
- st.error("HuggingFace API token not found in environment variables.")
164
- st.info("""
165
- Please add your token to the .env file as:
166
- HUGGINGFACEHUB_API_TOKEN=your_token_here
167
-
168
- You can get a free token from:
169
- https://huggingface.co/settings/tokens
170
- """)
171
- return None
172
-
173
- # Using a model that works well for QA
174
- llm = HuggingFaceHub(
175
- repo_id="google/flan-t5-large",
176
- model_kwargs={
177
- "temperature": 0.1,
178
- "max_length": 512,
179
- "min_length": 50
180
- },
181
- huggingfacehub_api_token=api_token
182
- )
183
- return llm
184
- except Exception as e:
185
- st.error(f"Error loading LLM: {e}")
186
- return None
187
 
188
- # Create QA chain
189
- @st.cache_resource
190
- def create_qa_chain(_db, _llm):
191
- """Create the QA chain"""
192
- if _db is None or _llm is None:
193
- return None
194
 
195
  try:
196
- retriever = _db.as_retriever(
197
- search_type="similarity",
198
- search_kwargs={"k": 3}
199
- )
200
- qa_chain = RetrievalQA.from_chain_type(
201
- llm=_llm,
202
- chain_type="stuff",
203
- retriever=retriever,
204
- return_source_documents=True,
205
- verbose=False
206
- )
207
- return qa_chain
208
  except Exception as e:
209
- st.error(f"Error creating QA chain: {e}")
210
- return None
211
 
212
- # Main app function
213
- def main():
214
- st.title("πŸ₯ Medical QA Assistant")
215
- st.markdown("Ask questions about medical information and get AI-powered answers.")
216
-
217
- # Initialize session state
218
- if 'chat_history' not in st.session_state:
219
- st.session_state.chat_history = []
220
- if 'initialized' not in st.session_state:
221
- st.session_state.initialized = False
222
 
223
- # Sidebar
224
- with st.sidebar:
225
- st.header("Configuration")
226
- st.markdown("---")
227
-
228
- # Display file status
229
- files_exist = check_files()
230
- if files_exist:
231
- st.success("βœ… FAISS index files found")
232
- else:
233
- st.warning("⚠️ FAISS index will be created on first run")
234
-
235
- st.markdown("---")
236
-
237
- # Model info
238
- st.subheader("Model Information")
239
- st.markdown("""
240
- - **Embeddings**: sentence-transformers/all-MiniLM-L6-v2
241
- - **LLM**: google/flan-t5-large
242
- - **Retrieval**: FAISS with 3 similar chunks
243
- """)
244
-
245
- # API token status
246
- api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
247
- if api_token:
248
- st.success("βœ… HuggingFace API token found")
249
- else:
250
- st.error("❌ HuggingFace API token missing")
251
-
252
- # Clear chat button
253
- if st.button("Clear Chat History"):
254
- st.session_state.chat_history = []
255
- st.rerun()
256
-
257
- # Recreate index button
258
- if st.button("Recreate FAISS Index"):
259
- with st.spinner("Creating new index..."):
260
- db = create_faiss_index()
261
- if db:
262
- st.success("Index recreated successfully!")
263
- st.rerun()
264
-
265
- # Debug info
266
- with st.expander("Debug Information"):
267
- st.write(f"Python version: {os.sys.version}")
268
- st.write(f"Working directory: {os.getcwd()}")
269
- st.write(f"Files in directory: {os.listdir('.')}")
270
 
271
- # Main content area
272
- col1, col2 = st.columns([3, 1])
 
 
 
 
 
273
 
274
- with col1:
275
- # Initialize components
276
- if not st.session_state.initialized:
277
- with st.spinner("Initializing system... This may take a minute."):
 
 
278
  try:
279
- # Try to load or create FAISS index
280
- db = load_faiss()
281
-
282
- if db is None:
283
- st.error("Failed to create or load FAISS index.")
284
- return
285
-
286
- # Try to load LLM
287
- llm = load_llm()
288
-
289
- if llm is None:
290
- st.error("Failed to load LLM. Please check your HuggingFace API token.")
291
- return
292
-
293
- # Create QA chain
294
- qa_chain = create_qa_chain(db, llm)
295
-
296
- if qa_chain is None:
297
- st.error("Failed to create QA chain.")
298
- return
299
-
300
- # Store in session state
301
- st.session_state.db = db
302
- st.session_state.llm = llm
303
- st.session_state.qa_chain = qa_chain
304
- st.session_state.initialized = True
305
-
306
- st.success("βœ… System initialized successfully!")
307
-
308
  except Exception as e:
309
- st.error(f"Initialization error: {e}")
310
- return
311
-
312
- # Check if system is initialized
313
- if not st.session_state.initialized:
314
- st.error("System not initialized. Please check the error messages above.")
315
- return
316
-
317
- # Chat input
318
- query = st.text_input(
319
- "πŸ’¬ Ask a medical question:",
320
- placeholder="e.g., What are the symptoms of diabetes?",
321
- key="query_input"
322
- )
323
-
324
- # Submit button
325
- col_submit1, col_submit2 = st.columns([1, 5])
326
- with col_submit1:
327
- submit_button = st.button("Submit", type="primary", disabled=(not query))
328
-
329
- # Process query
330
- if submit_button and query:
331
- with st.spinner("Searching for relevant information..."):
332
- try:
333
- # Get response from QA chain
334
- result = st.session_state.qa_chain({"query": query})
335
-
336
- # Display answer
337
- st.markdown("### Answer:")
338
- st.write(result['result'])
339
-
340
- # Display source documents
341
- with st.expander("πŸ“š View source information"):
342
- for i, doc in enumerate(result['source_documents']):
343
- st.markdown(f"**Source {i+1}:**")
344
- st.write(doc.page_content[:300] + ("..." if len(doc.page_content) > 300 else ""))
345
- st.markdown("---")
 
 
 
 
 
 
 
346
 
347
- # Save to chat history
348
- st.session_state.chat_history.append({
349
- "question": query,
350
- "answer": result['result'],
351
- "sources": result['source_documents']
352
- })
 
 
 
 
 
 
 
353
 
354
- except Exception as e:
355
- st.error(f"Error getting response: {e}")
356
- st.info("Please try rephrasing your question.")
357
-
358
- with col2:
359
- # Chat history
360
- st.subheader("πŸ“ Chat History")
361
- if st.session_state.chat_history:
362
- for i, chat in enumerate(st.session_state.chat_history[-5:][::-1]): # Show last 5, newest first
363
- with st.expander(f"Q: {chat['question'][:50]}..."):
364
- st.write(f"**Q:** {chat['question']}")
365
- st.write(f"**A:** {chat['answer'][:150]}...")
366
- else:
367
- st.info("No questions asked yet.")
368
 
369
- # Run the app
370
- if __name__ == "__main__":
371
- main()
 
 
 
 
 
 
1
  import streamlit as st
2
+ import google.generativeai as genai
 
 
 
 
 
 
3
  import os
4
+ from medical_rag_system import MedicalRAGSystem
 
5
 
 
 
 
 
6
# Page-level settings: browser tab title, tab icon and wide layout.
st.set_page_config(
    page_title="Medical RAG Assistant",
    page_icon="🏥",
    layout="wide",
)

# Inject the CSS classes used further down the page:
#   .main-header - large centered page title
#   .info-box    - light grey panel for sidebar help/status text
#   .source-box  - light blue panel with a left accent bar for retrieved notes
st.markdown("""
<style>
.main-header {
    font-size: 2.5rem;
    color: #1f77b4;
    text-align: center;
    margin-bottom: 2rem;
}
.info-box {
    background-color: #f0f2f6;
    padding: 1rem;
    border-radius: 0.5rem;
    margin: 1rem 0;
}
.source-box {
    background-color: #e8f4fd;
    padding: 0.5rem;
    border-radius: 0.3rem;
    margin: 0.5rem 0;
    border-left: 4px solid #1f77b4;
}
</style>
""", unsafe_allow_html=True)
35
 
36
def _format_context(context_chunks):
    """Join retrieved chunks into one numbered, specialty-labelled context string."""
    sections = []
    for i, chunk in enumerate(context_chunks):
        # Tolerate chunks whose metadata is missing or incomplete instead of
        # raising KeyError before the caller's error handling is reached.
        metadata = chunk.get("metadata") or {}
        specialty = metadata.get("medical_specialty", "Unknown")
        sections.append(
            f"--- MEDICAL NOTE {i+1} (Specialty: {specialty}) ---\n{chunk['content']}"
        )
    return "\n\n".join(sections)


def generate_medical_answer(query, context_chunks, api_key):
    """Generate an answer with Gemini, grounded in the retrieved context.

    Args:
        query: The user's question, inserted verbatim into the prompt.
        context_chunks: Sequence of dicts, each with a ``content`` string and,
            optionally, ``metadata`` containing ``medical_specialty``.
        api_key: Google AI Studio API key passed to ``genai.configure``.

    Returns:
        The model's answer text. If ``context_chunks`` is empty or ``None`` a
        fixed "no information" message is returned without calling the model.
        Any failure while building the prompt or calling the model is reported
        as a string starting with ``"Error generating answer: "`` rather than
        raised, so the UI can always display the result.
    """
    if not context_chunks:
        return "I couldn't find relevant medical information to answer this question in the available records."

    try:
        # Prepare context from retrieved chunks. This is inside the try block
        # so that a malformed chunk yields an error message, not a crash.
        context_text = _format_context(context_chunks)

        prompt = f"""You are a medical research assistant. Answer the question based ONLY on the provided medical context from clinical notes.

MEDICAL CONTEXT:
{context_text}

QUESTION: {query}

IMPORTANT INSTRUCTIONS:
- Answer using ONLY the information from the medical context above
- If the context doesn't contain relevant information, say "I cannot find specific information about this in the available medical records"
- Be precise and medically accurate
- Do not make up or hallucinate information
- Mention which medical specialty the information comes from when relevant
- Keep answers concise but informative

ANSWER:"""

        genai.configure(api_key=api_key)
        model = genai.GenerativeModel("models/gemini-2.0-flash")
        response = model.generate_content(prompt)
        return response.text
    except Exception as e:
        # Top-level boundary for the UI: surface the problem as text.
        return f"Error generating answer: {str(e)}"
 
71
 
72
# Main app
import html  # local to this script section: used to escape model output below

st.markdown('<div class="main-header">🏥 Medical RAG Assistant</div>', unsafe_allow_html=True)
st.markdown("**Ask medical questions based on 3,898 clinical transcriptions across 39 medical specialties**")

# Sidebar configuration
with st.sidebar:
    st.header("⚙️ Configuration")

    api_key = st.text_input(
        "Google AI Studio API Key",
        type="password",
        help="Get free API key from https://aistudio.google.com/"
    )

    st.markdown('<div class="info-box">', unsafe_allow_html=True)
    st.write("**How to get API Key:**")
    st.write("1. Go to [Google AI Studio](https://aistudio.google.com/)")
    st.write("2. Sign in with Google account")
    st.write("3. Click 'Get API Key' and create new key")
    st.write("4. Paste the key here")
    st.markdown('</div>', unsafe_allow_html=True)

    # Initialize button: builds the retrieval system once and keeps it in
    # session state so it survives Streamlit's script reruns.
    if st.button("🚀 Initialize Medical RAG System", use_container_width=True):
        if not api_key:
            st.error("Please enter your Google AI Studio API key first")
        else:
            with st.spinner("Loading medical database..."):
                try:
                    rag_system = MedicalRAGSystem()
                    st.session_state.rag_system = rag_system
                    st.session_state.api_key = api_key
                    st.success("✅ Medical RAG system initialized successfully!")
                    st.info(f"📊 System contains {len(rag_system.chunks)} medical chunks")
                except Exception as e:
                    st.error(f"Failed to initialize: {str(e)}")

    # System info
    if 'rag_system' in st.session_state:
        st.markdown('<div class="info-box">', unsafe_allow_html=True)
        st.write("**System Status:** ✅ Active")
        st.write(f"**Chunks loaded:** {len(st.session_state.rag_system.chunks):,}")
        st.write("**Ready for queries**")
        st.markdown('</div>', unsafe_allow_html=True)

# Main query interface
st.divider()
st.subheader("🔍 Ask Medical Questions")

query = st.text_area(
    "Enter your medical question:",
    placeholder="e.g., What are the symptoms of allergic rhinitis?",
    height=100
)

num_chunks = st.slider("Number of medical chunks to retrieve:", 1, 5, 3)

if st.button("🔎 Search Medical Database", type="primary", use_container_width=True):
    if 'rag_system' not in st.session_state:
        st.error("Please initialize the RAG system first using the sidebar button")
    elif not query.strip():
        st.error("Please enter a question")
    else:
        with st.spinner("Searching medical database..."):
            # Retrieve relevant chunks
            retrieved_chunks = st.session_state.rag_system.retrieve_similar_chunks(
                query,
                k=num_chunks
            )

        if not retrieved_chunks:
            st.warning("No relevant medical information found for this query.")
        else:
            # Display retrieved chunks
            st.subheader("📋 Retrieved Medical Information")

            for i, chunk in enumerate(retrieved_chunks):
                specialty = chunk['metadata']['medical_specialty']
                score = chunk['similarity_score']

                # Show at most the first 500 characters of each note.
                content = chunk['content']
                preview = content[:500] + "..." if len(content) > 500 else content

                with st.expander(f"Source {i+1}: {specialty} (Relevance: {score:.3f})"):
                    st.markdown('<div class="source-box">', unsafe_allow_html=True)
                    st.write(preview)
                    st.markdown('</div>', unsafe_allow_html=True)

            # Generate answer
            st.subheader("💡 AI-Generated Answer")

            # Prefer the key currently in the sidebar so a corrected key takes
            # effect without re-initializing; fall back to the one stored at
            # initialization time if the field has been cleared.
            active_key = api_key or st.session_state.get("api_key")

            with st.spinner("Generating medical answer..."):
                answer = generate_medical_answer(
                    query,
                    retrieved_chunks,
                    active_key
                )

            # The answer is model output and is rendered with
            # unsafe_allow_html=True, so escape it to prevent HTML/script
            # injection. Newlines become <br> so the whole answer stays inside
            # the styled <div> (a blank line would otherwise end the HTML block).
            safe_answer = html.escape(answer).replace("\n", "<br>")

            st.markdown(f"""
<div style="background-color: #f8f9fa; padding: 20px; border-radius: 10px; border-left: 5px solid #1f77b4;">
{safe_answer}
</div>
""", unsafe_allow_html=True)

# Footer
st.divider()
st.markdown("""
<div style="text-align: center; color: #666; font-size: 0.9em;">
<p>Medical RAG Assistant | Powered by Google Gemini & FAISS</p>
<p>⚠️ <strong>Disclaimer:</strong> This tool provides information from clinical notes and should not be used for medical diagnosis or treatment decisions.</p>
</div>
""", unsafe_allow_html=True)