"""Medicine GraphRAG backend: FAISS semantic search, Neo4j expansion, Groq LLM.

Importing this module loads the FAISS index, the metadata file, the embedding
model and the Groq client, and attempts a Neo4j connection. A Neo4j failure is
tolerated: the app keeps running with FAISS search only.
"""

import json
import os

# Third-party imports keep the file's original relative order.
import gradio as gr
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from groq import Groq
from neo4j import GraphDatabase
from dotenv import load_dotenv

load_dotenv()

# Load credentials from environment or Hugging Face Spaces secrets
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USER = os.getenv("NEO4J_USERNAME")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
NEO4J_DATABASE = os.getenv("NEO4J_DATABASE", "neo4j")

FAISS_INDEX_PATH = "db/medicine_embeddings.index"
METADATA_PATH = "db/metadata.json"

EMBED_MODEL = "BAAI/bge-large-en-v1.5"
LLM_MODEL = "openai/gpt-oss-120b"


# ---------------------------------------------------------
# LOAD MODELS & DATABASES (ON STARTUP)
# ---------------------------------------------------------
def load_faiss():
    """Read the FAISS index stored at ``FAISS_INDEX_PATH``."""
    return faiss.read_index(FAISS_INDEX_PATH)


def load_metadata():
    """Parse and return the JSON document stored at ``METADATA_PATH``."""
    # Explicit encoding: do not depend on the platform's default locale.
    with open(METADATA_PATH, "r", encoding="utf-8") as f:
        return json.load(f)


def load_embedder():
    """Instantiate the sentence-transformers model named by ``EMBED_MODEL``."""
    return SentenceTransformer(EMBED_MODEL)


def load_llm():
    """Create a Groq client using ``GROQ_API_KEY``."""
    return Groq(api_key=GROQ_API_KEY)


def load_neo4j():
    """Create a Neo4j driver and verify that the server is reachable.

    Raises:
        ValueError: if the URI, username or password is not configured.
        Exception: whatever the driver raises when connectivity cannot be
            verified.
    """
    if not all([NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD]):
        raise ValueError("Neo4j credentials not configured")

    driver = GraphDatabase.driver(
        NEO4J_URI,
        auth=(NEO4J_USER, NEO4J_PASSWORD),
        max_connection_lifetime=3600,
        max_connection_pool_size=50,
        connection_acquisition_timeout=120,
    )
    # Test the connection
    driver.verify_connectivity()
    return driver


# Initialize resources
print("Loading FAISS index...")
faiss_index = load_faiss()

print("Loading metadata...")
metadata = load_metadata()

print("Loading embedder model...")
embedder = load_embedder()

print("Loading Groq LLM client...")
groq_client = load_llm()

# Load Neo4j with error handling
neo4j_status = ""
neo4j_driver = None
try:
    print("Connecting to Neo4j...")
    neo4j_driver = load_neo4j()
    neo4j_status = "✅ Connected to Neo4j"
    print(neo4j_status)
except Exception as e:
    # Startup boundary: the graph is optional, so report and carry on.
    neo4j_status = f"❌ Neo4j Connection Failed: {str(e)}"
    print(neo4j_status)
    print("⚠️ App will continue with FAISS search only (Graph features disabled)")


# ---------------------------------------------------------
# GRAPH EXPANSION — FETCH RELATED NODES
# ---------------------------------------------------------
def get_graph_info(drug_name):
    """Return the outgoing relations of a ``Drug`` node, grouped by type.

    Args:
        drug_name: value matched against the ``name`` property of ``Drug``
            nodes.

    Returns:
        A dict mapping relation type to the list of related nodes' ``name``
        values (at most 200 rows are read). Returns ``{}`` when Neo4j is
        unavailable or the query fails.
    """
    if neo4j_driver is None:
        return {}

    query = """
    MATCH (d:Drug {name: $name})-[r]->(n)
    RETURN type(r) AS relation, n.name AS value
    LIMIT 200
    """

    try:
        with neo4j_driver.session(database=NEO4J_DATABASE) as session:
            rows = session.run(query, name=drug_name).data()
    except Exception as e:
        # Best effort: the graph is an optional enrichment, but the failure
        # should be visible in the logs rather than silently dropped.
        print(f"⚠️ Graph lookup failed for {drug_name!r}: {e}")
        return {}

    graph_dict = {}
    for row in rows:
        value = row["value"]
        # Related nodes without a `name` property come back as None; they
        # carry no information and would break the string joins downstream.
        if value is None:
            continue
        graph_dict.setdefault(row["relation"], []).append(value)
    return graph_dict


# ---------------------------------------------------------
# SEMANTIC SEARCH (FAISS)
# ---------------------------------------------------------
def semantic_search(query, top_k=5):
    """Return the metadata entries of the ``top_k`` nearest index vectors.

    Args:
        query: text to embed and search for.
        top_k: number of neighbours requested from FAISS.

    Returns:
        A list of metadata entries in FAISS ranking order. It may be shorter
        than ``top_k`` when the index cannot supply that many neighbours.
    """
    query_emb = embedder.encode(query).astype("float32")
    _, indices = faiss_index.search(np.array([query_emb]), top_k)

    # FAISS pads missing neighbours with -1; without this guard metadata[-1]
    # would silently return the last record as if it were a match.
    return [
        metadata[int(idx)]
        for idx in indices[0]
        if 0 <= idx < len(metadata)
    ]


# ---------------------------------------------------------
# LLM ANSWER USING GROQ
# ---------------------------------------------------------
def answer_with_groq(query, retrieved, graph_info):
    """Ask the Groq model to answer ``query`` from the supplied context only.

    Args:
        query: the user's question.
        retrieved: metadata entries; each is read for ``name``, ``uses``,
            ``side_effects`` and ``manufacturer``.
        graph_info: dict mapping medicine name to its relation dict, as
            produced by ``get_graph_info``.

    Returns:
        The model's answer text, or ``""`` if the response has no content.
    """
    system_prompt = (
        "\nYou are a medical question answering assistant.\n"
        "You must:\n"
        "- Use the retrieved medicine information.\n"
        "- Use graph relations (substitutes, side effects, uses, classes).\n"
        "- Never hallucinate facts.\n"
        "- Respond using ONLY provided context.\n"
    )

    # Build context from FAISS metadata
    text_block = "".join(
        f"\nMedicine: {item['name']}\n"
        f"Uses: {item['uses']}\n"
        f"Side Effects: {item['side_effects']}\n"
        f"Manufacturer: {item['manufacturer']}\n"
        for item in retrieved
    )

    # Add graph info
    graph_parts = []
    for medicine, relations in graph_info.items():
        graph_parts.append(f"\nGraph Data for {medicine}:\n")
        for rel, vals in relations.items():
            # str() keeps the join safe if a property value is not a string.
            graph_parts.append(f"{rel}: {', '.join(str(v) for v in vals)}\n")
    graph_text = "".join(graph_parts)

    full_prompt = (
        f"\n{system_prompt}\n\n"
        f"User Query:\n{query}\n\n"
        f"Retrieved Medicine Data:\n{text_block}\n\n"
        f"Graph Knowledge:\n{graph_text}\n\n"
        "Final Answer:\n"
    )

    response = groq_client.chat.completions.create(
        model=LLM_MODEL,
        messages=[{"role": "user", "content": full_prompt}],
        temperature=0.2,
    )
    # The content field may be None; callers concatenate it into a string.
    return response.choices[0].message.content or ""


# ---------------------------------------------------------
# MAIN QUERY FUNCTION
# ---------------------------------------------------------
def process_query(query):
    """Run search, graph expansion and answer generation for one query.

    Args:
        query: the user's question; ``None`` or blank input is rejected.

    Returns:
        A 4-tuple ``(medicines_markdown, graph_markdown, answer_markdown,
        neo4j_status)``.
    """
    if not query or not query.strip():
        return "⚠️ Please enter a query.", "", "", neo4j_status

    # Step 1: Semantic Search
    results = semantic_search(query)

    # Step 2: Format retrieved medicines
    medicines_text = "### 🔬 Top Relevant Medicines\n\n" + "".join(
        f"**{r['name']}** — {r['uses']}\n\n" for r in results
    )

    # Step 3: Graph expansion
    graph_dict = {r["name"]: get_graph_info(r["name"]) for r in results}
    # Fenced so that Markdown rendering keeps the JSON indentation.
    graph_text = (
        "### 🧬 Graph Relations Found\n\n```json\n"
        + json.dumps(graph_dict, indent=2, ensure_ascii=False)
        + "\n```"
    )

    # Step 4: Generate LLM answer
    try:
        answer = answer_with_groq(query, results, graph_dict)
    except Exception as e:
        # UI boundary: keep the retrieval results visible even if the LLM
        # call fails, in line with how Neo4j failures are tolerated.
        print(f"❌ LLM request failed: {e}")
        answer = f"❌ LLM request failed: {e}"

    final_answer = "### 🩺 Final Answer\n\n" + answer
    return medicines_text, graph_text, final_answer, neo4j_status


# ---------------------------------------------------------
# GRADIO UI
# ---------------------------------------------------------
def create_interface():
    """Build and return the Gradio Blocks UI.

    The layout contains a read-only database status box, a query textbox,
    Search and Clear buttons, three Markdown panels (medicines, graph
    relations, final answer) and a set of example queries.

    Returns:
        The assembled ``gr.Blocks`` instance; the caller launches it.
    """
    with gr.Blocks(title="Medicine GraphRAG AI") as demo:
        gr.Markdown("# 💊 Medicine GraphRAG AI")
        gr.Markdown(
            "**Semantic Search + Graph DB + LLM reasoning using Groq GPT-OSS-120B**"
        )

        with gr.Row():
            status_display = gr.Textbox(
                label="Database Status",
                value=neo4j_status,
                interactive=False,
                lines=1,
            )

        with gr.Row():
            query_input = gr.Textbox(
                label="Enter your medical query",
                placeholder="e.g., best medicine for acidity",
                lines=2,
            )

        with gr.Row():
            search_btn = gr.Button("Search", variant="primary", size="lg")
            clear_btn = gr.Button("Clear", variant="secondary")

        with gr.Row():
            with gr.Column():
                medicines_output = gr.Markdown(label="Top Relevant Medicines")
            with gr.Column():
                graph_output = gr.Markdown(label="Graph Relations")

        with gr.Row():
            answer_output = gr.Markdown(label="Final Answer")

        # Event handlers
        search_btn.click(
            fn=process_query,
            inputs=[query_input],
            outputs=[
                medicines_output,
                graph_output,
                answer_output,
                status_display,
            ],
        )

        # Clear resets the query box as well as the result panels; the status
        # box is restored to the startup connection status.
        clear_btn.click(
            fn=lambda: ("", "", "", "", neo4j_status),
            inputs=[],
            outputs=[
                query_input,
                medicines_output,
                graph_output,
                answer_output,
                status_display,
            ],
        )

        # Examples
        gr.Examples(
            examples=[
                ["What is the best medicine for acidity?"],
                ["Show me medicines for headache"],
                ["What are the side effects of paracetamol?"],
                ["Suggest medicine for cold and fever"],
            ],
            inputs=query_input,
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()