Spaces:
Sleeping
Sleeping
| """ | |
| Simple Medical Chatbot Interface v2.0 | |
| Beautiful Gradio interface for the simplified medical RAG system | |
| """ | |
| import gradio as gr | |
| import time | |
| import json | |
| from datetime import datetime | |
| from typing import List, Tuple, Dict, Any | |
| # Import our simplified medical RAG system | |
| from simple_medical_rag import SimpleMedicalRAG, MedicalResponse | |
class SimpleMedicalChatbot:
    """Medical chatbot front-end wrapping a ``SimpleMedicalRAG`` instance.

    Holds one RAG system plus running session statistics (query count,
    cumulative response time, mean confidence) and renders RAG responses as
    Markdown for the Gradio UI.
    """

    def __init__(self):
        """Reset session statistics and start the RAG system."""
        self.rag_system = None
        # NOTE(review): process_query never appends to this list; the Gradio
        # Chatbot component carries the visible history. Kept for compatibility.
        self.chat_history = []
        self.session_stats = {
            "queries_processed": 0,
            "total_response_time": 0,
            "avg_confidence": 0,
            "session_start": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }
        self._initialize_rag_system()

    def _initialize_rag_system(self):
        """(Re)create the RAG system; on failure leave ``rag_system`` as None."""
        try:
            print("🔄 Initializing Medical RAG System...")
            self.rag_system = SimpleMedicalRAG()
            print("✅ Medical RAG System initialized successfully!")
        except Exception as e:
            # Deliberately broad: any start-up failure degrades to an error
            # message in the UI instead of crashing the app.
            print(f"❌ Error initializing RAG system: {e}")
            self.rag_system = None

    def process_query(self, query: str, history: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], str]:
        """Answer ``query`` and append the exchange to the chat history.

        Args:
            query: The user's question. Blank (or None) queries are ignored.
            history: Existing (user, assistant) message pairs; None is treated
                as an empty history. The caller's list is not modified.

        Returns:
            ``(history, "")``: the updated history (a new list) and an empty
            string used to clear the input textbox.
        """
        history = list(history or [])

        # Ignore blank input first, so an empty submit never produces an error
        # bubble, even when the RAG system is unavailable.
        if not query or not query.strip():
            return history, ""

        if not self.rag_system:
            error_msg = "❌ **System Error**: Medical RAG system not initialized. Please refresh and try again."
            history.append((query, error_msg))
            return history, ""

        start_time = time.perf_counter()  # monotonic clock for durations
        try:
            response = self.rag_system.query(query, k=5)
            formatted_response = self._format_response_for_display(response)
            self._update_session_stats(time.perf_counter() - start_time, response.confidence)
        except Exception as e:
            # UI boundary: report the failure in the chat instead of raising.
            error_msg = (
                f"❌ **Error processing query**: {e}\n\n"
                "⚠️ Please try rephrasing your question or contact support."
            )
            history.append((query, error_msg))
        else:
            history.append((query, formatted_response))
        return history, ""

    def _format_response_for_display(self, response: "MedicalResponse") -> str:
        """Render a RAG response as Markdown for the chat window.

        Reads ``answer``, ``confidence`` (a fraction, shown as a percentage),
        ``response_type``, ``sources`` (items subscripted by ``'document'`` and
        ``'relevance_score'``) and ``medical_disclaimer`` from ``response``.
        """
        confidence = response.confidence
        if confidence > 0.7:
            confidence_emoji = "🟢"
        elif confidence > 0.5:
            confidence_emoji = "🟡"
        else:
            confidence_emoji = "🔴"

        response_type = response.response_type or ""
        if "dosage" in response_type:
            type_emoji = "💊"
        elif "emergency" in response_type:
            type_emoji = "🚨"
        else:
            type_emoji = "🏥"

        sources = response.sources or []
        # Blank lines separate Markdown paragraphs so sections don't merge.
        parts = [
            f"{type_emoji} **Medical Information**",
            "",
            f"{response.answer}",
            "",
            "---",
            "",
            "📊 **Response Details**",
            "",
            f"{confidence_emoji} **Confidence: {confidence:.1%}**",
            "",
            f"📚 **Sources**: {len(sources)} documents referenced",
            "",
        ]

        if sources:
            parts.append("📖 **Primary Sources**:")
            for rank, source in enumerate(sources[:3], 1):
                doc_name = source['document'].replace('.pdf', '').replace('-', ' ').title()
                parts.append(f"{rank}. {doc_name} (Relevance: {source['relevance_score']:.1%})")
            parts.append("")

        parts += [
            "---",
            "",
            f"{response.medical_disclaimer}",
            "",
            "📝 **Note**: This response is based on Sri Lankan maternal health guidelines "
            "and should be used in conjunction with current clinical protocols.",
        ]
        return "\n".join(parts)

    def _update_session_stats(self, query_time: float, confidence: float):
        """Record one answered query's duration (seconds) and confidence."""
        stats = self.session_stats
        stats["queries_processed"] += 1
        stats["total_response_time"] += query_time
        # Incremental running mean, so individual confidences need not be stored.
        stats["avg_confidence"] += (confidence - stats["avg_confidence"]) / stats["queries_processed"]

    def get_system_info(self) -> str:
        """Return a Markdown summary of the RAG system's knowledge base.

        Returns an error line instead of raising if the stats are unavailable.
        """
        if not self.rag_system:
            return "❌ **System Status**: Not initialized"

        try:
            stats = self.rag_system.get_system_stats()
            store = stats['vector_store']
            total_chunks = store['total_chunks']
            lines = [
                "🏥 **Sri Lankan Maternal Health Assistant v2.0**",
                "",
                f"📈 **System Status**: {stats['status'].upper()} ✅",
                "",
                "**Knowledge Base**:",
                "",
                f"- 📄 Total Documents: {total_chunks:,} medical chunks",
                f"- 🧠 Embedding Model: {store['embedding_model']}",
                f"- 💾 Vector Store Size: {store['vector_store_size_mb']} MB",
                "- ⚡ Approach: Simplified document-based retrieval",
                "",
                "**Content Distribution**:",
                "",
            ]
            for content_type, count in store['content_type_distribution'].items():
                # An empty index would otherwise raise ZeroDivisionError.
                percentage = (count / total_chunks) * 100 if total_chunks else 0.0
                label = content_type.replace('_', ' ').title()
                lines.append(f"- {label}: {count:,} chunks ({percentage:.1f}%)")
            return "\n".join(lines) + "\n"
        except Exception as e:
            return f"❌ **Error retrieving system info**: {e}"

    def get_session_stats(self) -> str:
        """Return a Markdown summary of this chatbot's query statistics."""
        processed = self.session_stats["queries_processed"]
        if processed == 0:
            return "📈 **Session Statistics**: No queries processed yet"

        avg_response_time = self.session_stats["total_response_time"] / processed
        return "\n".join([
            "📈 **Session Statistics**",
            "",
            f"- 🕐 **Session Started**: {self.session_stats['session_start']}",
            f"- 🔍 **Queries Processed**: {processed}",
            f"- ⚡ **Avg Response Time**: {avg_response_time:.2f} seconds",
            f"- 🎯 **Avg Confidence**: {self.session_stats['avg_confidence']:.1%}",
        ])

    def clear_chat(self) -> Tuple[List, str]:
        """Clear the stored history; return an empty chat and empty textbox."""
        self.chat_history = []
        return [], ""

    def get_example_queries(self) -> List[str]:
        """Return sample maternal-health questions for the example buttons."""
        return [
            "What is the dosage of magnesium sulfate for preeclampsia?",
            "How to manage postpartum hemorrhage emergency?",
            "Normal fetal heart rate during labor monitoring?",
            "Management protocol for breech delivery?",
            "Antenatal care schedule for high-risk pregnancies?",
            "Signs and symptoms of preeclampsia?",
            "When to perform cesarean delivery?",
            "Postpartum care guidelines for new mothers?",
        ]
def create_medical_chatbot_interface():
    """Build the Gradio Blocks app for the maternal-health assistant.

    Creates one ``SimpleMedicalChatbot`` that every event handler closes over.
    The layout has two columns:

    - Chat column: history, query box, and the Ask / Clear / Refresh buttons.
    - Side column: system info, session statistics, example-query buttons,
      and the disclaimer.

    Returns:
        The (not yet launched) ``gr.Blocks`` instance.
    """
    # NOTE(review): this single instance, including its session_stats, is
    # shared by every visitor of the app, so the "session" stats are app-wide.
    chatbot = SimpleMedicalChatbot()

    # Custom CSS for the medical theme; class names match elem_classes below.
    css = """
    .gradio-container {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    .medical-header {
        background: white;
        padding: 20px;
        border-radius: 10px;
        margin-bottom: 20px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }
    .chat-container {
        background: white;
        border-radius: 15px;
        box-shadow: 0 8px 25px rgba(0, 0, 0, 0.1);
    }
    .medical-disclaimer {
        background: #fff3cd;
        border: 1px solid #ffeaa7;
        border-radius: 8px;
        padding: 15px;
        margin: 10px 0;
        color: #856404;
    }
    .example-queries {
        background: #e8f5e8;
        border-radius: 8px;
        padding: 15px;
        margin: 10px 0;
    }
    """

    with gr.Blocks(css=css, title="Sri Lankan Maternal Health Assistant", theme=gr.themes.Soft()) as interface:
        gr.Markdown(
            """
            # 🏥 Sri Lankan Maternal Health Assistant v2.0

            ### Simplified Document-Based Medical RAG System

            **Professional medical guidance based on Sri Lankan maternal health guidelines**
            """,
            elem_classes=["medical-header"],
        )

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Group(elem_classes=["chat-container"]):
                    gr.Markdown("## 💬 Medical Query Interface")
                    chatbot_display = gr.Chatbot(
                        label="Medical Assistant",
                        height=500,
                        show_label=False,
                        container=True,
                        bubble_full_width=False,
                    )
                    with gr.Row():
                        query_input = gr.Textbox(
                            placeholder="Ask a medical question about maternal health...",
                            label="Your Medical Query",
                            lines=2,
                            scale=4,
                        )
                        submit_btn = gr.Button("🔍 Ask", variant="primary", scale=1)
                    with gr.Row():
                        clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
                        refresh_btn = gr.Button("🔄 Refresh System", variant="secondary")

            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("## 📊 System Information")
                    system_info_display = gr.Markdown(
                        chatbot.get_system_info(),
                        label="System Status",
                    )
                with gr.Group():
                    gr.Markdown("## 📈 Session Statistics")
                    session_stats_display = gr.Markdown(
                        chatbot.get_session_stats(),
                        label="Current Session",
                    )
                with gr.Group(elem_classes=["example-queries"]):
                    gr.Markdown("## 💡 Example Queries")
                    for example in chatbot.get_example_queries()[:4]:
                        example_btn = gr.Button(f"📌 {example}", variant="secondary", size="sm")
                        # Bind `example` as a default argument so each button
                        # fills in its own text (avoids late-binding closures).
                        example_btn.click(fn=lambda text=example: text, outputs=query_input)

        gr.Markdown(
            """
            ## ⚠️ Important Medical Disclaimer

            This system provides information from Sri Lankan maternal health guidelines for **educational purposes only**.

            **Always consult qualified healthcare providers for**:

            - Medical decisions and patient care
            - Emergency medical situations
            - Clinical diagnosis and treatment
            - Medication administration

            This tool is designed to **supplement**, not replace, professional medical judgment.
            """,
            elem_classes=["medical-disclaimer"],
        )

        def submit_query(query, history):
            """Answer the query, clear the textbox and refresh session stats."""
            new_history, _ = chatbot.process_query(query, history)
            return new_history, "", chatbot.get_session_stats()

        def refresh_system():
            """Retry RAG start-up if it failed, then refresh both info panels."""
            # The "not initialized" chat error tells users to refresh; without
            # this retry a failed start-up could only be fixed by a restart.
            if chatbot.rag_system is None:
                chatbot._initialize_rag_system()
            return chatbot.get_system_info(), chatbot.get_session_stats()

        def clear_chat_handler():
            """Empty the chat window and textbox; statistics are kept."""
            new_history, _ = chatbot.clear_chat()
            return new_history, "", chatbot.get_session_stats()

        # The Ask button and pressing Enter in the textbox share one handler.
        submit_wiring = dict(
            fn=submit_query,
            inputs=[query_input, chatbot_display],
            outputs=[chatbot_display, query_input, session_stats_display],
        )
        submit_btn.click(**submit_wiring)
        query_input.submit(**submit_wiring)

        clear_btn.click(
            fn=clear_chat_handler,
            inputs=[],
            outputs=[chatbot_display, query_input, session_stats_display],
        )
        refresh_btn.click(
            fn=refresh_system,
            inputs=[],
            outputs=[system_info_display, session_stats_display],
        )

    return interface
def main():
    """Build the medical chatbot interface and launch the Gradio server."""
    print("🚀 Launching Sri Lankan Maternal Health Assistant v2.0")
    print("=" * 60)

    interface = create_medical_chatbot_interface()

    interface.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        share=True,  # Enable public sharing
        show_error=True,
        inbrowser=True,
        debug=True,
    )


if __name__ == "__main__":
    # Launch only when run as a script, not when imported.
    main()