|
|
import gradio as gr |
|
|
import faiss |
|
|
import json |
|
|
import numpy as np |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from groq import Groq |
|
|
from neo4j import GraphDatabase |
|
|
from dotenv import load_dotenv |
|
|
import os |
|
|
|
|
|
# Load environment variables from a local .env file (API keys, DB credentials).
load_dotenv()

# --- External service configuration (all read from the environment) ---
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USER = os.getenv("NEO4J_USERNAME")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
NEO4J_DATABASE = os.getenv("NEO4J_DATABASE", "neo4j")  # "neo4j" is the server default DB

# --- Local artifact paths (prebuilt offline; see db/) ---
FAISS_INDEX_PATH = "db/medicine_embeddings.index"  # FAISS vector index of medicine embeddings
METADATA_PATH = "db/metadata.json"                 # per-vector medicine metadata, parallel to index

# --- Model identifiers ---
EMBED_MODEL = "BAAI/bge-large-en-v1.5"  # sentence-transformers embedding model
LLM_MODEL = "openai/gpt-oss-120b"       # Groq-hosted chat model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_faiss():
    """Read the prebuilt FAISS index from disk and return it."""
    index = faiss.read_index(FAISS_INDEX_PATH)
    return index
|
|
|
|
|
def load_metadata():
    """Load the medicine metadata that parallels the FAISS index rows.

    Returns:
        The decoded JSON object (a sequence of per-medicine dicts, one per
        index vector).
    """
    # Explicit encoding: the default is platform-dependent (e.g. cp1252 on
    # Windows) and would corrupt non-ASCII medicine names.
    with open(METADATA_PATH, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
def load_embedder():
    """Instantiate and return the sentence-transformers embedding model."""
    model = SentenceTransformer(EMBED_MODEL)
    return model
|
|
|
|
|
def load_llm():
    """Create and return a Groq API client using the configured key."""
    client = Groq(api_key=GROQ_API_KEY)
    return client
|
|
|
|
|
def load_neo4j():
    """Create, verify, and return a Neo4j driver from env-configured credentials.

    Raises:
        ValueError: when any of URI / user / password is missing from the
            environment.
        neo4j.exceptions.*: propagated from verify_connectivity() when the
            server is unreachable.
    """
    # Guard clause: refuse to build a driver with partial credentials.
    if not (NEO4J_URI and NEO4J_USER and NEO4J_PASSWORD):
        raise ValueError("Neo4j credentials not configured")

    driver = GraphDatabase.driver(
        NEO4J_URI,
        auth=(NEO4J_USER, NEO4J_PASSWORD),
        max_connection_lifetime=3600,        # recycle pooled connections hourly
        max_connection_pool_size=50,
        connection_acquisition_timeout=120,
    )
    # Fail fast at startup rather than on the first query.
    driver.verify_connectivity()
    return driver
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Startup: eagerly load all heavy resources once at import time so every
# Gradio request can reuse them.
# ---------------------------------------------------------------------------
print("Loading FAISS index...")
faiss_index = load_faiss()

print("Loading metadata...")
metadata = load_metadata()

print("Loading embedder model...")
embedder = load_embedder()

print("Loading Groq LLM client...")
groq_client = load_llm()

# Neo4j is optional: on connection failure the app degrades gracefully to
# FAISS-only semantic search (graph enrichment returns empty results).
neo4j_status = ""
neo4j_driver = None
try:
    print("Connecting to Neo4j...")
    neo4j_driver = load_neo4j()
    # Note: this literal was previously mojibake-garbled (and split across
    # lines by a stray NEL byte); restored to a single valid string.
    neo4j_status = "✅ Connected to Neo4j"
    print(neo4j_status)
except Exception as e:
    neo4j_status = f"❌ Neo4j Connection Failed: {str(e)}"
    print(neo4j_status)
    print("⚠️ App will continue with FAISS search only (Graph features disabled)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_graph_info(drug_name):
    """Fetch outgoing knowledge-graph relations for one drug from Neo4j.

    Args:
        drug_name: Exact node name to match against the :Drug label.

    Returns:
        Mapping of relationship type -> list of related node names.
        Empty dict when Neo4j is unavailable or the query fails: graph
        enrichment is best-effort by design, so failures never break search.
    """
    if neo4j_driver is None:
        return {}

    # LIMIT keeps pathological nodes from flooding the LLM prompt.
    query = """
    MATCH (d:Drug {name: $name})-[r]->(n)
    RETURN type(r) AS relation, n.name AS value
    LIMIT 200
    """
    try:
        with neo4j_driver.session(database=NEO4J_DATABASE) as session:
            result = session.run(query, name=drug_name).data()
    except Exception:
        # Deliberate swallow: a transient graph error must not abort the
        # request — the caller treats {} as "no graph data".
        return {}

    graph_dict = {}
    for row in result:
        graph_dict.setdefault(row["relation"], []).append(row["value"])

    return graph_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def semantic_search(query, top_k=5):
    """Embed the query and return metadata for the top_k nearest medicines.

    Args:
        query: Free-text user query.
        top_k: Number of nearest neighbours to retrieve.

    Returns:
        List of metadata dicts, best match first (may be shorter than top_k
        when the index holds fewer vectors).
    """
    # FAISS requires a 2-D float32 array of shape (n_queries, dim).
    query_emb = embedder.encode(query).astype("float32")

    _, indices = faiss_index.search(np.array([query_emb]), top_k)

    # FAISS pads with -1 when fewer than top_k neighbours exist; indexing
    # metadata with -1 would silently return the LAST record, so skip them.
    return [metadata[idx] for idx in indices[0] if idx != -1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def answer_with_groq(query, retrieved, graph_info):
    """Build a context-grounded prompt and ask the Groq LLM for an answer.

    Args:
        query: The user's free-text question.
        retrieved: List of medicine metadata dicts from semantic search.
        graph_info: Mapping of medicine name -> {relation: [values]} from Neo4j.

    Returns:
        The LLM's answer text.
    """
    system_prompt = """
You are a medical question answering assistant.
You must:
- Use the retrieved medicine information.
- Use graph relations (substitutes, side effects, uses, classes).
- Never hallucinate facts.
- Respond using ONLY provided context.
"""

    # One formatted section per retrieved medicine.
    text_block = "".join(
        f"""
Medicine: {item['name']}
Uses: {item['uses']}
Side Effects: {item['side_effects']}
Manufacturer: {item['manufacturer']}
"""
        for item in retrieved
    )

    # Flatten each medicine's graph relations into readable lines.
    graph_lines = []
    for medicine, relations in graph_info.items():
        graph_lines.append(f"\nGraph Data for {medicine}:\n")
        for rel, vals in relations.items():
            graph_lines.append(f"{rel}: {', '.join(vals)}\n")
    graph_text = "".join(graph_lines)

    full_prompt = f"""
{system_prompt}

User Query:
{query}

Retrieved Medicine Data:
{text_block}

Graph Knowledge:
{graph_text}

Final Answer:
"""

    # Low temperature: answers should stay close to the provided context.
    response = groq_client.chat.completions.create(
        model=LLM_MODEL,
        messages=[{"role": "user", "content": full_prompt}],
        temperature=0.2,
    )

    return response.choices[0].message.content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_query(query):
    """Run the full RAG pipeline for one user query.

    Pipeline: FAISS semantic search -> Neo4j graph expansion -> Groq LLM.

    Args:
        query: Free-text medical question from the UI.

    Returns:
        Tuple of (medicines markdown, graph markdown, final answer markdown,
        Neo4j status string) matching the four Gradio output components.
    """
    if not query.strip():
        return "⚠️ Please enter a query.", "", "", neo4j_status

    # Step 1: vector search for the most relevant medicines.
    results = semantic_search(query)

    medicines_text = "### 🔬 Top Relevant Medicines\n\n"
    for r in results:
        medicines_text += f"**{r['name']}** — {r['uses']}\n\n"

    # Step 2: enrich every hit with its knowledge-graph relations.
    # (get_graph_info returns {} when Neo4j is unavailable.)
    graph_dict = {}
    for r in results:
        graph_dict[r["name"]] = get_graph_info(r["name"])

    graph_text = "### 🧬 Graph Relations Found\n\n"
    graph_text += json.dumps(graph_dict, indent=2)

    # Step 3: generate the grounded answer with the LLM.
    answer = answer_with_groq(query, results, graph_dict)

    final_answer = "### 🩺 Final Answer\n\n" + answer

    return medicines_text, graph_text, final_answer, neo4j_status
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_interface():
    """Build and return the Gradio Blocks UI wired to process_query."""
    with gr.Blocks(title="Medicine GraphRAG AI") as demo:
        gr.Markdown("# 💊 Medicine GraphRAG AI")
        gr.Markdown("**Semantic Search + Graph DB + LLM reasoning using Groq GPT-OSS-120B**")

        # Read-only banner showing whether Neo4j connected at startup.
        with gr.Row():
            status_display = gr.Textbox(
                label="Database Status",
                value=neo4j_status,
                interactive=False,
                lines=1
            )

        with gr.Row():
            query_input = gr.Textbox(
                label="Enter your medical query",
                placeholder="e.g., best medicine for acidity",
                lines=2
            )

        with gr.Row():
            search_btn = gr.Button("Search", variant="primary", size="lg")
            clear_btn = gr.Button("Clear", variant="secondary")

        # Results side by side: vector hits left, raw graph relations right.
        with gr.Row():
            with gr.Column():
                medicines_output = gr.Markdown(label="Top Relevant Medicines")
            with gr.Column():
                graph_output = gr.Markdown(label="Graph Relations")

        with gr.Row():
            answer_output = gr.Markdown(label="Final Answer")

        search_btn.click(
            fn=process_query,
            inputs=[query_input],
            outputs=[medicines_output, graph_output, answer_output, status_display]
        )

        # Clear resets the three result panes but keeps the Neo4j status.
        clear_btn.click(
            fn=lambda: ("", "", "", neo4j_status),
            inputs=[],
            outputs=[medicines_output, graph_output, answer_output, status_display]
        )

        gr.Examples(
            examples=[
                ["What is the best medicine for acidity?"],
                ["Show me medicines for headache"],
                ["What are the side effects of paracetamol?"],
                ["Suggest medicine for cold and fever"]
            ],
            inputs=query_input
        )

    return demo
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the UI and start the local Gradio server.
    create_interface().launch()
|
|
|