# Provenance: Hugging Face Space "app.py" by FurqanIshaq (commit 2a45158, verified)
# ==========================================================
# MedMentor β€” Virtual Patient Chatbot (Hugging Face Prototype)
# Author: Furqan Ishaq (Supervised Build)
# ==========================================================
import os
import json
import gradio as gr
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from google.generativeai import configure, GenerativeModel
from datetime import datetime
from utils.rag_utils import RAGRetriever
from utils.response_generator import ResponseGenerator
from utils.rule_engine import RuleEngine
# ================================
# 1) Gemini + embedding model setup
# ================================
# GEMINI_API_KEY must be configured as a secret in the HF Space settings.
configure(api_key=os.getenv("GEMINI_API_KEY"))
model = GenerativeModel("gemini-2.5-pro")
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# ================================
# 2) Paths & configuration
# ================================
DATA_DIR = "datasets/diabetes"
VECTOR_DIR = os.path.join(DATA_DIR, "vectorstores")
RAG_DIR = os.path.join(DATA_DIR, "rag_text_enriched")
JSON_DIR = os.path.join(DATA_DIR, "cleaned")
SESSIONS_DIR = "sessions"
os.makedirs(SESSIONS_DIR, exist_ok=True)

# Retriever + response generator for the diabetes dataset.
retriever = RAGRetriever(DATA_DIR)
response_gen = ResponseGenerator()

# Rule engine for diagnosis evaluation.
# BUG FIX: the original passed a hard-coded Google Drive path
# (/content/drive/MyDrive/...), which only exists inside a Colab runtime
# and breaks on Hugging Face. Use the repo-relative DATA_DIR instead.
rule_engine = RuleEngine(dataset_path=DATA_DIR)
# ================================
# 3️⃣ LOAD PATIENT CONTEXT
# ================================
def load_patient_context(patient_id):
    """Fetch the vector index and enriched text chunks for one patient.

    Delegates to the module-level RAGRetriever and returns
    (text_chunks, index) for use by retrieve_context().
    """
    faiss_index, chunks, _ = retriever.load_patient_data(patient_id)
    return chunks, faiss_index
# ================================
# 4️⃣ RETRIEVE CONTEXT
# ================================
def retrieve_context(query, text_chunks, index, top_k=5):
    """Return the top_k chunks most relevant to *query*, newline-joined.

    Embeds the query with the module-level sentence-transformer, searches
    the FAISS index, and drops any neighbor id outside the chunk list.
    """
    query_vec = embedding_model.encode([query])
    _, neighbor_ids = index.search(query_vec, top_k)
    hits = [text_chunks[pos] for pos in neighbor_ids[0] if pos < len(text_chunks)]
    return "\n".join(hits)
# ================================
# 5️⃣ PATIENT RESPONSE GENERATION
# ================================
def generate_patient_response(user_input, context):
    """Ask Gemini to reply in-character as the patient, grounded in *context*.

    The prompt forbids medical advice so the model stays in the patient role.
    """
    prompt = f"""
You are a virtual patient in a medical consultation.
Speak ONLY as the patient β€” do not provide medical advice.
Use the following medical context:
[PATIENT CONTEXT]
{context}
[DOCTOR'S QUESTION]
{user_input}
Respond naturally, in short conversational style, as if you are the patient.
"""
    reply = model.generate_content(prompt)
    return reply.text.strip()
# ================================
# 6️⃣ SESSION MANAGEMENT
# ================================
def load_or_create_session(patient_id):
    """Return the persisted session dict for *patient_id*, or a fresh one.

    Sessions live in SESSIONS_DIR as <patient_id>_session.json.
    """
    session_file = os.path.join(SESSIONS_DIR, f"{patient_id}_session.json")
    if not os.path.exists(session_file):
        # First interaction with this patient: start an empty history.
        return {"patient_id": patient_id, "chat_history": [], "diagnosed": False}
    with open(session_file, "r") as fh:
        return json.load(fh)
def save_session(patient_id, session_data):
    """Persist *session_data* to the patient's JSON session file."""
    session_file = os.path.join(SESSIONS_DIR, f"{patient_id}_session.json")
    with open(session_file, "w") as fh:
        json.dump(session_data, fh, indent=2)
# ================================
# 7️⃣ RULE-BASED EVALUATION
# ================================
def evaluate_diagnosis(patient_id, doctor_diagnosis):
    """Score the doctor's diagnosis against the patient's ground truth.

    Loads the cleaned patient JSON from JSON_DIR and performs a
    case-insensitive substring match against metadata["disease"].
    Returns a dict with 'status', 'score', and 'missed_aspects'.

    NOTE(review): the score values and missed aspects are hard-coded
    placeholders, not derived from the transcript.
    """
    json_path = os.path.join(JSON_DIR, f"{patient_id}.json")
    with open(json_path, "r") as f:
        data = json.load(f)
    true_disease = data["metadata"]["disease"].lower()
    doctor_diagnosis = doctor_diagnosis.lower().strip()
    # BUG FIX: an empty/whitespace diagnosis is a substring of any string,
    # so it previously scored as a correct diagnosis. Require non-empty.
    if doctor_diagnosis and doctor_diagnosis in true_disease:
        return {
            "status": "βœ… Correct Diagnosis",
            "score": 95,
            "missed_aspects": [],
        }
    return {
        "status": "❌ Incorrect Diagnosis",
        "score": 60,
        "missed_aspects": [
            "Doctor missed polyuria-polydipsia link",
            "Ignored family diabetes history",
        ],
    }
# ================================
# 8️⃣ MAIN CHAT LOGIC
# ================================
def medmentor_chat(user_input, state, patient_id):
    """Route one doctor message: chat turn, diagnosis prompt, or scoring.

    Returns (reply_text, state). The session file is updated for normal
    chat turns and when a diagnosis is submitted.
    """
    session = load_or_create_session(patient_id)
    text_chunks, index = load_patient_context(patient_id)
    lowered = user_input.lower()

    # Doctor signals they are done: ask for the disease name.
    if "diagnosis complete" in lowered:
        return "Please type the name of the disease you diagnosed.", state

    # Doctor submitted a diagnosis ("disease: <name>"): evaluate it.
    if "disease:" in lowered:
        diagnosis = user_input.split("disease:")[-1].strip()
        result = evaluate_diagnosis(patient_id, diagnosis)
        summary = f"""
🩺 Doctor diagnosed: **{diagnosis}**
πŸ“Š Evaluation: {result['status']}
⭐ Score: {result['score']}%
⚠️ Missed Points: {', '.join(result['missed_aspects']) if result['missed_aspects'] else 'None'}
"""
        session["diagnosed"] = True
        save_session(patient_id, session)
        return summary, state

    # Normal turn: retrieve context and answer as the patient.
    retrieved_context = retrieve_context(user_input, text_chunks, index)
    patient_reply = response_gen.generate_response(user_input, retrieved_context.split("\n"))
    session["chat_history"].append({"doctor": user_input, "patient": patient_reply})
    save_session(patient_id, session)
    return patient_reply, state
# ================================
# 9) Gradio interface
# ================================
with gr.Blocks(title="MedMentor Virtual Patient") as demo:
    gr.Markdown("## 🧠 MedMentor: Virtual Patient Consultation\nChat with a synthetic patient powered by RAG + Gemini API.")
    patient_id = gr.Textbox(label="Enter Patient ID", value="DIA_001")
    chatbot = gr.Chatbot(height=400)
    user_input = gr.Textbox(label="Doctor's Question", placeholder="Ask your question here...")
    state = gr.State([])

    def respond(message, chat_history, patient_id):
        # medmentor_chat returns (reply, state); the state half is unused here.
        reply, _ = medmentor_chat(message, chat_history, patient_id)
        chat_history.append((message, reply))
        # Clear the input box and show the updated transcript.
        return "", chat_history

    user_input.submit(respond, [user_input, chatbot, patient_id], [user_input, chatbot])
# NOTE(review): this block was pasted from a Flask app β€” `app`, `request`,
# and `jsonify` are undefined in this Gradio file (Flask is never imported),
# so importing the module raised NameError, and the function name shadowed
# evaluate_diagnosis() defined above. Rewritten as a plain helper that the
# UI (or a future API layer) can call directly.
def evaluate_diagnosis_api(payload):
    """Validate *payload* and score a diagnosis via the rule engine.

    payload: dict expected to contain 'patient_id' and 'doctor_diagnosis'.
    Returns the rule-engine result dict, or an error dict when either
    required field is missing or empty.
    """
    patient_id = payload.get("patient_id")
    doctor_diagnosis = payload.get("doctor_diagnosis")
    if not patient_id or not doctor_diagnosis:
        return {"status": "error", "message": "Missing required fields"}
    return rule_engine.evaluate_diagnosis(patient_id, doctor_diagnosis)
# ================================
# 10) Launch
# ================================
if __name__ == "__main__":
    # Start the Gradio server only when run as a script.
    demo.launch()