Himanshu Gangwar commited on
Commit
9b69c13
·
1 Parent(s): 58ad204

Refactor: Simplify Neo4j connection management and remove unused code

Browse files
Files changed (1) hide show
  1. app.py +61 -471
app.py CHANGED
@@ -1,477 +1,67 @@
1
- import gradio as gr
2
- import faiss
3
- import json
4
- import numpy as np
5
- from sentence_transformers import SentenceTransformer
6
- from groq import Groq
7
- from neo4j import GraphDatabase
8
- from dotenv import load_dotenv
9
  import os
 
 
 
 
10
 
 
11
  load_dotenv()
12
 
13
- # Load credentials from environment or Hugging Face Spaces secrets
14
- GROQ_API_KEY = os.getenv("GROQ_API_KEY")
15
- # Use local Neo4j instance running directly (not Docker)
16
- NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
17
- NEO4J_USER = os.getenv("NEO4J_USERNAME", "neo4j")
18
- NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "neo4j")
19
- NEO4J_DATABASE = os.getenv("NEO4J_DATABASE", "neo4j")
20
- FAISS_INDEX_PATH = "db/medicine_embeddings.index"
21
- METADATA_PATH = "db/metadata.json"
22
-
23
- EMBED_MODEL = "BAAI/bge-large-en-v1.5"
24
- LLM_MODEL = "openai/gpt-oss-120b"
25
-
26
-
27
- # ---------------------------------------------------------
28
- # LOAD MODELS & DATABASES (ON STARTUP)
29
- # ---------------------------------------------------------
30
-
31
- def load_faiss():
32
- return faiss.read_index(FAISS_INDEX_PATH)
33
-
34
- def load_metadata():
35
- with open(METADATA_PATH, "r") as f:
36
- return json.load(f)
37
-
38
- def load_embedder():
39
- return SentenceTransformer(EMBED_MODEL)
40
-
41
- def load_llm():
42
- return Groq(api_key=GROQ_API_KEY)
43
-
44
- def load_neo4j():
45
- if not all([NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD]):
46
- raise ValueError("Neo4j credentials not configured")
47
-
48
- driver = GraphDatabase.driver(
49
- NEO4J_URI,
50
- auth=(NEO4J_USER, NEO4J_PASSWORD),
51
- max_connection_lifetime=3600,
52
- max_connection_pool_size=50,
53
- connection_acquisition_timeout=120
54
- )
55
- # Test the connection
56
- driver.verify_connectivity()
57
- return driver
58
-
59
-
60
- # Initialize resources
61
- print("Loading FAISS index...")
62
- faiss_index = load_faiss()
63
- print("Loading metadata...")
64
- metadata = load_metadata()
65
- print("Loading embedder model...")
66
- embedder = load_embedder()
67
- print("Loading Groq LLM client...")
68
- groq_client = load_llm()
69
-
70
- # Load Neo4j with error handling
71
- neo4j_status = ""
72
- neo4j_driver = None
73
- try:
74
- print("Connecting to Neo4j...")
75
- neo4j_driver = load_neo4j()
76
- neo4j_status = "✅ Connected to Neo4j"
77
- print(neo4j_status)
78
- except Exception as e:
79
- neo4j_status = f"❌ Neo4j Connection Failed: {str(e)}"
80
- print(neo4j_status)
81
- print("⚠️ App will continue with FAISS search only (Graph features disabled)")
82
-
83
-
84
- # ---------------------------------------------------------
85
- # GRAPH EXPANSION — FETCH RELATED NODES
86
- # ---------------------------------------------------------
87
-
88
- def get_graph_info(drug_name):
89
- if neo4j_driver is None:
90
- return {}
91
-
92
- # Use case-insensitive matching since metadata has lowercase names
93
- # but Neo4j has Title Case names
94
- query = """
95
- MATCH (m:Medicine)
96
- WHERE toLower(m.name) = toLower($name)
97
- OPTIONAL MATCH (m)-[r]->(n)
98
- WITH type(r) AS rel_type, n.name AS target_name
99
- WHERE rel_type IS NOT NULL
100
- RETURN rel_type AS relation, target_name AS value
101
- LIMIT 200
102
- """
103
- try:
104
- with neo4j_driver.session(database=NEO4J_DATABASE) as session:
105
- result = session.run(query, name=drug_name).data()
106
- except Exception as e:
107
- print(f"Graph query error: {e}")
108
- return {}
109
-
110
- graph_dict = {}
111
- for row in result:
112
- relation = row.get("relation")
113
- value = row.get("value")
114
- if relation and value:
115
- graph_dict.setdefault(relation, []).append(value)
116
-
117
- return graph_dict
118
-
119
-
120
- # ---------------------------------------------------------
121
- # SEMANTIC SEARCH (FAISS)
122
- # ---------------------------------------------------------
123
-
124
- def semantic_search(query, top_k=5):
125
- query_emb = embedder.encode(query).astype("float32")
126
-
127
- distances, indices = faiss_index.search(
128
- np.array([query_emb]), top_k
129
- )
130
-
131
- results = []
132
- for idx in indices[0]:
133
- results.append(metadata[idx])
134
- return results
135
-
136
-
137
- # ---------------------------------------------------------
138
- # DIRECT NEO4J SEARCH (Graph-based)
139
- # ---------------------------------------------------------
140
-
141
- def search_neo4j_directly(query, limit=10):
142
- """
143
- Search Neo4j directly for medicines, conditions, side effects, or ingredients
144
- based on the query keywords.
145
  """
146
- if neo4j_driver is None:
147
- return {"medicines": [], "conditions": [], "side_effects": [], "ingredients": []}
148
-
149
- results = {
150
- "medicines": [],
151
- "conditions": [],
152
- "side_effects": [],
153
- "ingredients": []
154
- }
155
-
156
- # Extract keywords from query (simple approach)
157
- query_lower = query.lower()
158
-
159
- try:
160
- with neo4j_driver.session(database=NEO4J_DATABASE) as session:
161
- # Search medicines by name or composition containing query terms
162
- med_query = """
163
- MATCH (m:Medicine)
164
- WHERE toLower(m.name) CONTAINS $query
165
- OR toLower(m.composition) CONTAINS $query
166
- OR toLower(m.uses_text) CONTAINS $query
167
- RETURN m.name AS name, m.composition AS composition,
168
- m.uses_text AS uses, m.side_effects_text AS side_effects,
169
- m.excellent_review_pct AS excellent_review
170
- ORDER BY m.excellent_review_pct DESC
171
- LIMIT $limit
172
- """
173
- med_results = session.run(med_query, query=query_lower, limit=limit).data()
174
- results["medicines"] = med_results
175
-
176
- # Search conditions that match query
177
- cond_query = """
178
- MATCH (c:Condition)<-[:TREATS]-(m:Medicine)
179
- WHERE toLower(c.name) CONTAINS $query
180
- RETURN c.name AS condition, collect(DISTINCT m.name)[0..5] AS treating_medicines
181
- LIMIT 5
182
- """
183
- cond_results = session.run(cond_query, query=query_lower).data()
184
- results["conditions"] = cond_results
185
-
186
- # Search side effects that match query
187
- se_query = """
188
- MATCH (s:SideEffect)<-[:HAS_SIDE_EFFECT]-(m:Medicine)
189
- WHERE toLower(s.name) CONTAINS $query
190
- RETURN s.name AS side_effect, collect(DISTINCT m.name)[0..5] AS medicines_with_effect
191
- LIMIT 5
192
- """
193
- se_results = session.run(se_query, query=query_lower).data()
194
- results["side_effects"] = se_results
195
-
196
- # Search ingredients that match query
197
- ing_query = """
198
- MATCH (i:ActiveIngredient)<-[:CONTAINS_INGREDIENT]-(m:Medicine)
199
- WHERE toLower(i.name) CONTAINS $query
200
- RETURN i.name AS ingredient, collect(DISTINCT m.name)[0..10] AS medicines_containing
201
- LIMIT 5
202
- """
203
- ing_results = session.run(ing_query, query=query_lower).data()
204
- results["ingredients"] = ing_results
205
-
206
- except Exception as e:
207
- print(f"Neo4j direct search error: {e}")
208
-
209
- return results
210
-
211
-
212
- # ---------------------------------------------------------
213
- # LLM ANSWER USING GROQ
214
- # ---------------------------------------------------------
215
-
216
- def answer_with_groq(query, faiss_results, graph_expansion, neo4j_direct_results):
217
- system_prompt = """
218
- You are a medical question answering assistant with access to TWO data sources:
219
-
220
- 1. **FAISS Vector Database**: Semantic similarity search results - good for finding medicines
221
- related to the query meaning, even if exact keywords don't match.
222
-
223
- 2. **Neo4j Graph Database**:
224
- - Direct search results: Exact matches for medicines, conditions, side effects, ingredients
225
- - Graph expansion: Relationships like TREATS, HAS_SIDE_EFFECT, CONTAINS_INGREDIENT, MANUFACTURED_BY
226
-
227
- Your task:
228
- - Analyze BOTH data sources
229
- - Decide which source is more relevant for the specific question
230
- - You can use BOTH sources if they provide complementary information
231
- - For specific medicine queries → prioritize Neo4j direct matches
232
- - For general symptom/condition queries → combine FAISS semantics + Neo4j graph relationships
233
- - For side effect queries → prioritize Neo4j graph data (HAS_SIDE_EFFECT relationships)
234
- - For ingredient queries → prioritize Neo4j graph data (CONTAINS_INGREDIENT relationships)
235
-
236
- Rules:
237
- - Never hallucinate facts - use ONLY the provided context
238
- - If data is conflicting, prefer Neo4j graph data (more structured)
239
- - Clearly cite which source provided the information when helpful
240
- - Be concise but comprehensive
241
  """
242
-
243
- # Build context from FAISS metadata
244
- faiss_text = "=== FAISS VECTOR SEARCH RESULTS ===\n"
245
- if faiss_results:
246
- for item in faiss_results:
247
- faiss_text += f"""
248
- Medicine: {item.get('name', 'N/A')}
249
- Uses: {item.get('uses', 'N/A')}
250
- Side Effects: {item.get('side_effects', 'N/A')}
251
- Manufacturer: {item.get('manufacturer', 'N/A')}
252
- ---
253
- """
254
- else:
255
- faiss_text += "No FAISS results found.\n"
256
-
257
- # Build graph expansion info
258
- graph_text = "\n=== NEO4J GRAPH EXPANSION (Relationships) ===\n"
259
- has_graph_data = False
260
- for medicine, relations in graph_expansion.items():
261
- if relations:
262
- has_graph_data = True
263
- graph_text += f"\n📊 Graph Data for '{medicine}':\n"
264
- for rel, vals in relations.items():
265
- rel_readable = rel.replace("_", " ").title()
266
- graph_text += f" • {rel_readable}: {', '.join(vals[:10])}\n"
267
- if not has_graph_data:
268
- graph_text += "No graph expansion data found.\n"
269
-
270
- # Build Neo4j direct search results
271
- neo4j_text = "\n=== NEO4J DIRECT SEARCH RESULTS ===\n"
272
- has_neo4j_data = False
273
-
274
- if neo4j_direct_results.get("medicines"):
275
- has_neo4j_data = True
276
- neo4j_text += "\n🔍 Matching Medicines:\n"
277
- for med in neo4j_direct_results["medicines"][:5]:
278
- neo4j_text += f" • {med.get('name', 'N/A')}\n"
279
- neo4j_text += f" Uses: {med.get('uses', 'N/A')[:200]}...\n" if med.get('uses') else ""
280
- neo4j_text += f" Side Effects: {med.get('side_effects', 'N/A')[:200]}...\n" if med.get('side_effects') else ""
281
-
282
- if neo4j_direct_results.get("conditions"):
283
- has_neo4j_data = True
284
- neo4j_text += "\n🏥 Matching Conditions:\n"
285
- for cond in neo4j_direct_results["conditions"]:
286
- neo4j_text += f" • {cond.get('condition', 'N/A')}\n"
287
- neo4j_text += f" Treating Medicines: {', '.join(cond.get('treating_medicines', []))}\n"
288
-
289
- if neo4j_direct_results.get("side_effects"):
290
- has_neo4j_data = True
291
- neo4j_text += "\n⚠️ Matching Side Effects:\n"
292
- for se in neo4j_direct_results["side_effects"]:
293
- neo4j_text += f" • {se.get('side_effect', 'N/A')}\n"
294
- neo4j_text += f" Found in: {', '.join(se.get('medicines_with_effect', []))}\n"
295
-
296
- if neo4j_direct_results.get("ingredients"):
297
- has_neo4j_data = True
298
- neo4j_text += "\n💊 Matching Ingredients:\n"
299
- for ing in neo4j_direct_results["ingredients"]:
300
- neo4j_text += f" • {ing.get('ingredient', 'N/A')}\n"
301
- neo4j_text += f" Found in: {', '.join(ing.get('medicines_containing', [])[:5])}\n"
302
-
303
- if not has_neo4j_data:
304
- neo4j_text += "No direct Neo4j matches found.\n"
305
-
306
- full_prompt = f"""
307
- {system_prompt}
308
-
309
- 📝 USER QUERY: {query}
310
-
311
- {faiss_text}
312
- {graph_text}
313
- {neo4j_text}
314
-
315
- Based on the above data sources, provide a comprehensive answer. Indicate which data source(s) you primarily used.
316
- """
317
-
318
- response = groq_client.chat.completions.create(
319
- model=LLM_MODEL,
320
- messages=[{"role": "user", "content": full_prompt}],
321
- temperature=0.2,
322
- )
323
-
324
- return response.choices[0].message.content
325
-
326
-
327
- # ---------------------------------------------------------
328
- # MAIN QUERY FUNCTION
329
- # ---------------------------------------------------------
330
-
331
- def process_query(query):
332
- """Main function to process user query and return results"""
333
- if not query.strip():
334
- return "⚠️ Please enter a query.", "", "", "", neo4j_status
335
-
336
- # Step 1: FAISS Semantic Search
337
- faiss_results = semantic_search(query)
338
-
339
- # Step 2: Neo4j Direct Search
340
- neo4j_direct_results = search_neo4j_directly(query)
341
-
342
- # Step 3: Graph expansion for FAISS results
343
- graph_expansion = {}
344
- for r in faiss_results:
345
- graph_expansion[r["name"]] = get_graph_info(r["name"])
346
-
347
- # Step 4: Format FAISS results for display
348
- faiss_text = "### 🔍 FAISS Vector Search Results\n\n"
349
- for r in faiss_results:
350
- faiss_text += f"**{r['name']}**\n"
351
- faiss_text += f"- Uses: {r.get('uses', 'N/A')[:150]}...\n"
352
- faiss_text += f"- Side Effects: {r.get('side_effects', 'N/A')[:100]}...\n\n"
353
-
354
- # Step 5: Format Neo4j results for display
355
- neo4j_text = "### 🧬 Neo4j Graph Database Results\n\n"
356
-
357
- # Direct matches
358
- if neo4j_direct_results.get("medicines"):
359
- neo4j_text += "**📋 Direct Medicine Matches:**\n"
360
- for med in neo4j_direct_results["medicines"][:5]:
361
- neo4j_text += f"- {med.get('name', 'N/A')}\n"
362
- neo4j_text += "\n"
363
-
364
- if neo4j_direct_results.get("conditions"):
365
- neo4j_text += "**🏥 Matching Conditions:**\n"
366
- for cond in neo4j_direct_results["conditions"]:
367
- neo4j_text += f"- {cond.get('condition', 'N/A')}: {', '.join(cond.get('treating_medicines', [])[:3])}\n"
368
- neo4j_text += "\n"
369
-
370
- if neo4j_direct_results.get("ingredients"):
371
- neo4j_text += "**💊 Matching Ingredients:**\n"
372
- for ing in neo4j_direct_results["ingredients"]:
373
- neo4j_text += f"- {ing.get('ingredient', 'N/A')}: {', '.join(ing.get('medicines_containing', [])[:3])}\n"
374
- neo4j_text += "\n"
375
-
376
- if neo4j_direct_results.get("side_effects"):
377
- neo4j_text += "**⚠️ Matching Side Effects:**\n"
378
- for se in neo4j_direct_results["side_effects"]:
379
- neo4j_text += f"- {se.get('side_effect', 'N/A')}: {', '.join(se.get('medicines_with_effect', [])[:3])}\n"
380
- neo4j_text += "\n"
381
-
382
- # Graph expansion
383
- neo4j_text += "**🔗 Graph Relationships:**\n```json\n"
384
- neo4j_text += json.dumps(graph_expansion, indent=2)[:2000]
385
- neo4j_text += "\n```"
386
-
387
- # Step 6: Generate LLM answer using all sources
388
- answer = answer_with_groq(query, faiss_results, graph_expansion, neo4j_direct_results)
389
-
390
- final_answer = "### 🩺 AI Answer (Using Both Databases)\n\n" + answer
391
-
392
- return faiss_text, neo4j_text, final_answer, neo4j_status
393
-
394
-
395
- # ---------------------------------------------------------
396
- # GRADIO UI
397
- # ---------------------------------------------------------
398
-
399
- def create_interface():
400
- with gr.Blocks(title="Medicine GraphRAG AI") as demo:
401
- gr.Markdown("# 💊 Medicine GraphRAG AI")
402
- gr.Markdown("**Dual Database Search: FAISS Vector DB + Neo4j Graph DB + LLM Reasoning**")
403
-
404
- with gr.Row():
405
- status_display = gr.Textbox(
406
- label="Database Status",
407
- value=neo4j_status,
408
- interactive=False,
409
- lines=1
410
- )
411
-
412
- with gr.Row():
413
- query_input = gr.Textbox(
414
- label="Enter your medical query",
415
- placeholder="e.g., What are the side effects of paracetamol?",
416
- lines=2
417
- )
418
-
419
- with gr.Row():
420
- search_btn = gr.Button("🔍 Search Both Databases", variant="primary", size="lg")
421
- clear_btn = gr.Button("Clear", variant="secondary")
422
-
423
- # Answer section FIRST (most important)
424
- with gr.Row():
425
- answer_output = gr.Markdown(
426
- label="AI Answer",
427
- value="",
428
- )
429
-
430
- # Database results in collapsible/scrollable sections
431
- with gr.Row():
432
- with gr.Column():
433
- with gr.Accordion("🔍 FAISS Vector Search Results", open=False):
434
- faiss_output = gr.Markdown(
435
- label="FAISS Results",
436
- value="",
437
- )
438
-
439
- with gr.Column():
440
- with gr.Accordion("🧬 Neo4j Graph Database Results", open=False):
441
- neo4j_output = gr.Markdown(
442
- label="Neo4j Results",
443
- value="",
444
- )
445
-
446
- # Event handlers
447
- search_btn.click(
448
- fn=process_query,
449
- inputs=[query_input],
450
- outputs=[faiss_output, neo4j_output, answer_output, status_display]
451
- )
452
-
453
- clear_btn.click(
454
- fn=lambda: ("", "", "", neo4j_status),
455
- inputs=[],
456
- outputs=[faiss_output, neo4j_output, answer_output, status_display]
457
- )
458
-
459
- # Examples
460
- gr.Examples(
461
- examples=[
462
- ["What is the best medicine for acidity?"],
463
- ["Show me medicines for headache"],
464
- ["What are the side effects of paracetamol?"],
465
- ["Suggest medicine for cold and fever"],
466
- ["Find medicines containing ibuprofen"],
467
- ["What treats hypertension?"]
468
- ],
469
- inputs=query_input
470
- )
471
-
472
- return demo
473
-
474
-
475
- if __name__ == "__main__":
476
- demo = create_interface()
477
- demo.launch()
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import streamlit as st
3
+ from neo4j import GraphDatabase
4
+ from langchain_community.graphs import Neo4jGraph
5
+ from dotenv import load_dotenv
6
 
7
+ # Load environment variables from .env file for local development
8
  load_dotenv()
9
 
10
+ class Neo4jConnection:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  """
12
+ A class to manage the connection to a Neo4j database.
13
+ It uses the credentials sourced from Streamlit secrets or a local .env file.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  """
15
+ def __init__(self):
16
+ # Prioritize Streamlit secrets, fall back to .env for local dev
17
+ if hasattr(st, 'secrets') and "NEO4J_URI" in st.secrets:
18
+ uri = st.secrets["NEO4J_URI"]
19
+ user = st.secrets["NEO4J_USER"]
20
+ password = st.secrets["NEO4J_PASSWORD"]
21
+ print("Connecting to Neo4j using Streamlit secrets.")
22
+ else:
23
+ uri = os.getenv("NEO4J_URI")
24
+ user = os.getenv("NEO4J_USER")
25
+ password = os.getenv("NEO4J_PASSWORD")
26
+ print("Connecting to Neo4j using local .env file.")
27
+
28
+ self._driver = GraphDatabase.driver(uri, auth=(user, password))
29
+ try:
30
+ # Verify connection
31
+ self._driver.verify_connectivity()
32
+ print("Connected to Neo4j")
33
+ except Exception as e:
34
+ print(f"Neo4j connection failed: {e}")
35
+
36
+ def close(self):
37
+ if self._driver is not None:
38
+ self._driver.close()
39
+
40
+ def query(self, query, parameters=None, db=None):
41
+ """Runs a Cypher query and returns the results."""
42
+ assert self._driver is not None, "Driver not initialized!"
43
+ session = None
44
+ response = None
45
+ try:
46
+ session = self._driver.session(database=db) if db is not None else self._driver.session()
47
+ response = list(session.run(query, parameters))
48
+ except Exception as e:
49
+ print("Query failed:", e)
50
+ finally:
51
+ if session is not None:
52
+ session.close()
53
+ return response
54
+
55
+
56
+ graph = Neo4jGraph(
57
+ url=st.secrets["NEO4J_URI"],
58
+ username=st.secrets["NEO4J_USER"],
59
+ password=st.secrets["NEO4J_PASSWORD"]
60
+ )
61
+
62
+ # Refresh schema information for the LangChain graph object
63
+ # This helps the LLM generate more accurate Cypher queries
64
+ try:
65
+ graph.refresh_schema()
66
+ except Exception as e:
67
+ print(f"Warning: Could not refresh graph schema. The LLM might generate less accurate queries. Error: {e}")