# NOTE(review): the lines "Spaces:" / "Runtime error" that preceded this
# file were Hugging Face Spaces error-page residue captured with the paste,
# not part of the program source. Kept here as a comment for the record.
| # ========================================================== | |
| # MedMentor β Virtual Patient Chatbot (Hugging Face Prototype) | |
| # Author: Furqan Ishaq (Supervised Build) | |
| # ========================================================== | |
| import os | |
| import json | |
| import gradio as gr | |
| import faiss | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from google.generativeai import configure, GenerativeModel | |
| from datetime import datetime | |
| from utils.rag_utils import RAGRetriever | |
| from utils.response_generator import ResponseGenerator | |
| from utils.rule_engine import RuleEngine | |
# ================================
# 1) SETUP GEMINI + EMBEDDINGS
# ================================
# GEMINI_API_KEY must be provided via the environment (HF Spaces "Secrets").
configure(api_key=os.getenv("GEMINI_API_KEY"))
model = GenerativeModel("gemini-2.5-pro")
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# ================================
# 2) PATHS & CONFIG
# ================================
DATA_DIR = "datasets/diabetes"
VECTOR_DIR = os.path.join(DATA_DIR, "vectorstores")
RAG_DIR = os.path.join(DATA_DIR, "rag_text_enriched")
JSON_DIR = os.path.join(DATA_DIR, "cleaned")
SESSIONS_DIR = "sessions"
os.makedirs(SESSIONS_DIR, exist_ok=True)

# Initialize retriever (for diabetes dataset)
retriever = RAGRetriever(DATA_DIR)
response_gen = ResponseGenerator()

# Initialize Rule Engine.
# BUG FIX: this previously pointed at a Google-Colab-only absolute path
# ("/content/drive/MyDrive/..."), which does not exist on Hugging Face
# Spaces and crashes at import time. Use the same relative DATA_DIR the
# retriever uses so both components read the same dataset.
rule_engine = RuleEngine(dataset_path=DATA_DIR)
| # ================================ | |
| # 3οΈβ£ LOAD PATIENT CONTEXT | |
| # ================================ | |
def load_patient_context(patient_id):
    """Fetch the FAISS index and enriched text chunks for one patient.

    Thin wrapper over the shared RAGRetriever; the third element of its
    return tuple is not needed here and is discarded.
    """
    faiss_index, chunks, _unused = retriever.load_patient_data(patient_id)
    return chunks, faiss_index
| # ================================ | |
| # 4οΈβ£ RETRIEVE CONTEXT | |
| # ================================ | |
def retrieve_context(query, text_chunks, index, top_k=5):
    """Return the top_k most relevant chunks for `query`, newline-joined.

    Args:
        query: The doctor's free-text question.
        text_chunks: Text chunks aligned row-for-row with the FAISS index.
        index: FAISS index built over embeddings of `text_chunks`.
        top_k: Number of nearest neighbours to retrieve.
    """
    query_emb = embedding_model.encode([query])
    _distances, ids = index.search(query_emb, top_k)
    # BUG FIX: FAISS pads the id array with -1 when fewer than top_k
    # vectors exist. The old `i < len(text_chunks)` check let -1 through,
    # silently wrapping around to the LAST chunk. Reject negatives.
    retrieved = [text_chunks[i] for i in ids[0] if 0 <= i < len(text_chunks)]
    return "\n".join(retrieved)
| # ================================ | |
| # 5οΈβ£ PATIENT RESPONSE GENERATION | |
| # ================================ | |
def generate_patient_response(user_input, context):
    """Ask Gemini to answer in-character as the virtual patient.

    Args:
        user_input: The doctor's latest question.
        context: Retrieved RAG context describing this patient's case.

    Returns:
        The model's reply text, stripped of surrounding whitespace.
    """
    # FIX: the dash in the role instruction was a mojibake byte in the
    # pasted source; restored to a plain em dash.
    prompt = f"""
You are a virtual patient in a medical consultation.
Speak ONLY as the patient — do not provide medical advice.
Use the following medical context:

[PATIENT CONTEXT]
{context}

[DOCTOR'S QUESTION]
{user_input}

Respond naturally, in short conversational style, as if you are the patient.
"""
    response = model.generate_content(prompt)
    return response.text.strip()
| # ================================ | |
| # 6οΈβ£ SESSION MANAGEMENT | |
| # ================================ | |
def load_or_create_session(patient_id):
    """Return the persisted session for `patient_id`, or a fresh one.

    A fresh session starts with an empty chat history and an
    undiagnosed flag.
    """
    session_path = os.path.join(SESSIONS_DIR, f"{patient_id}_session.json")
    if not os.path.exists(session_path):
        return {"patient_id": patient_id, "chat_history": [], "diagnosed": False}
    with open(session_path, "r") as fh:
        return json.load(fh)
def save_session(patient_id, session_data):
    """Persist `session_data` as pretty-printed JSON under SESSIONS_DIR."""
    target = os.path.join(SESSIONS_DIR, f"{patient_id}_session.json")
    with open(target, "w") as fh:
        json.dump(session_data, fh, indent=2)
| # ================================ | |
| # 7οΈβ£ RULE-BASED EVALUATION | |
| # ================================ | |
def evaluate_diagnosis(patient_id, doctor_diagnosis, json_dir=None):
    """Compare the doctor's diagnosis against the dataset ground truth.

    Loads the patient's cleaned JSON record and performs a
    case-insensitive substring match against metadata["disease"].

    Args:
        patient_id: Identifier used to locate "<id>.json".
        doctor_diagnosis: Free-text disease name entered by the doctor.
        json_dir: Directory of cleaned patient records; defaults to the
            module-level JSON_DIR (new optional parameter, backward
            compatible).

    Returns:
        Dict with keys "status", "score", "missed_aspects".
    """
    if json_dir is None:
        json_dir = JSON_DIR
    json_path = os.path.join(json_dir, f"{patient_id}.json")
    with open(json_path, "r") as f:
        data = json.load(f)

    true_disease = data["metadata"]["disease"].lower()
    doctor_diagnosis = doctor_diagnosis.strip().lower()

    # BUG FIX: an empty/whitespace diagnosis is a substring of EVERY
    # string, so the old check marked it "correct". Require a non-empty
    # answer before the substring match.
    # NOTE(review): status markers were mojibake in the pasted source;
    # reconstructed as check/cross emoji — confirm against the original.
    if doctor_diagnosis and doctor_diagnosis in true_disease:
        return {"status": "✅ Correct Diagnosis", "score": 95, "missed_aspects": []}

    return {
        "status": "❌ Incorrect Diagnosis",
        "score": 60,
        # Hard-coded feedback is a prototype placeholder specific to the
        # diabetes dataset.
        "missed_aspects": [
            "Doctor missed polyuria-polydipsia link",
            "Ignored family diabetes history",
        ],
    }
| # ================================ | |
| # 8οΈβ£ MAIN CHAT LOGIC | |
| # ================================ | |
def medmentor_chat(user_input, state, patient_id):
    """Route one doctor message: chat, or start/finish the diagnosis step.

    Args:
        user_input: The doctor's message.
        state: Opaque Gradio state, passed through unchanged.
        patient_id: Which patient's session/context to use.

    Returns:
        Tuple of (reply_text, state).
    """
    session = load_or_create_session(patient_id)
    text_chunks, index = load_patient_context(patient_id)

    lowered = user_input.lower()
    if "diagnosis complete" in lowered:
        # Prompt the doctor to submit their final answer.
        return "Please type the name of the disease you diagnosed.", state

    if "disease:" in lowered:
        # BUG FIX: the marker was detected case-insensitively but the old
        # code split the ORIGINAL-case string on "disease:", so input like
        # "Disease: flu" never split and the whole message became the
        # diagnosis. Locate the marker in the lowercased copy and slice
        # the original string at that offset instead.
        marker_end = lowered.rindex("disease:") + len("disease:")
        diagnosis = user_input[marker_end:].strip()
        result = evaluate_diagnosis(patient_id, diagnosis)
        # NOTE(review): summary emoji were mojibake in the pasted source;
        # reconstructed — confirm against the original file.
        summary = f"""
🩺 Doctor diagnosed: **{diagnosis}**
📊 Evaluation: {result['status']}
⭐ Score: {result['score']}%
⚠️ Missed Points: {', '.join(result['missed_aspects']) if result['missed_aspects'] else 'None'}
"""
        session["diagnosed"] = True
        save_session(patient_id, session)
        return summary, state

    # Normal turn: retrieve RAG context and answer as the patient.
    retrieved_context = retrieve_context(user_input, text_chunks, index)
    patient_reply = response_gen.generate_response(user_input, retrieved_context.split("\n"))
    session["chat_history"].append({"doctor": user_input, "patient": patient_reply})
    save_session(patient_id, session)
    return patient_reply, state
| # ================================ | |
| # 9οΈβ£ BUILD GRADIO INTERFACE | |
| # ================================ | |
# ================================
# 9) BUILD GRADIO INTERFACE
# ================================
with gr.Blocks(title="MedMentor Virtual Patient") as demo:
    gr.Markdown(
        "## 🧠 MedMentor: Virtual Patient Consultation\n"
        "Chat with a synthetic patient powered by RAG + Gemini API."
    )
    patient_id = gr.Textbox(label="Enter Patient ID", value="DIA_001")
    chatbot = gr.Chatbot(height=400)
    user_input = gr.Textbox(label="Doctor's Question", placeholder="Ask your question here...")
    state = gr.State([])

    def respond(message, chat_history, patient_id):
        """Submit callback: append (question, reply) and clear the input box."""
        response, _new_state = medmentor_chat(message, chat_history, patient_id)
        chat_history.append((message, response))
        return "", chat_history

    user_input.submit(respond, [user_input, chatbot, patient_id], [user_input, chatbot])

# NOTE(review): a Flask-style `evaluate_diagnosis()` endpoint was pasted
# here from another application. It referenced undefined names (`request`,
# `jsonify`), shadowed the module-level evaluate_diagnosis(), and was never
# wired to any route, so it could not run in this Gradio app — removed.
# If an HTTP endpoint for rule_engine.evaluate_diagnosis() is needed,
# expose it through Gradio's API (api_name=...) instead.
# ================================
# 10) LAUNCH APP
# ================================
# Launch the Gradio server when run as a script (HF Spaces entry point).
if __name__ == "__main__":
    demo.launch()