Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import sys | |
| from datetime import datetime | |
| import traceback | |
| sys.path.append(os.path.join(os.path.dirname(__file__), '..')) | |
| from rag_functions import get_direct_answer, get_answer_with_query_engine | |
| from utils import get_index | |
| print("β Successfully imported RAG functions") | |
| class PregnancyRiskAgent: | |
| def __init__(self): | |
| self.conversation_history = [] | |
| self.current_symptoms = {} | |
| self.risk_assessment_done = False | |
| self.user_context = {} | |
| self.last_user_query = "" | |
| self.symptom_questions = [ | |
| "Are you currently experiencing any unusual bleeding or discharge?", | |
| "How would you describe your baby's movements today compared to yesterday?", | |
| "Have you had any headaches that won't go away or that affect your vision?", | |
| "Do you feel any pressure or pain in your pelvis or lower back?", | |
| "Are you experiencing any other symptoms? (If yes, please describe briefly)" | |
| ] | |
| self.current_question_index = 0 | |
| self.waiting_for_first_response = True | |
| def add_to_conversation_history(self, role, message): | |
| self.conversation_history.append({ | |
| "role": role, | |
| "message": message, | |
| "timestamp": datetime.now().isoformat() | |
| }) | |
| if len(self.conversation_history) > 20: | |
| self.conversation_history = self.conversation_history[-20:] | |
| def get_conversation_context(self): | |
| context_parts = [] | |
| recent_history = self.conversation_history[-10:] | |
| for entry in recent_history: | |
| if entry["role"] == "user": | |
| context_parts.append(f"User: {entry['message']}") | |
| else: | |
| context_parts.append(f"Assistant: {entry['message'][:200]}...") | |
| return "\n".join(context_parts) | |
| def is_follow_up_question(self, user_input): | |
| follow_up_indicators = [ | |
| "what about", "can you explain", "what does", "why", "how", | |
| "tell me more", "what should i", "is it normal", "should i be worried", | |
| "what if", "when should", "how long", "what causes", "is this" | |
| ] | |
| user_lower = user_input.lower() | |
| return any(indicator in user_lower for indicator in follow_up_indicators) | |
| def process_user_input(self, user_input, chat_history): | |
| try: | |
| self.last_user_query = user_input | |
| self.add_to_conversation_history("user", user_input) | |
| if self.waiting_for_first_response: | |
| self.current_symptoms[f"question_0"] = user_input | |
| self.waiting_for_first_response = False | |
| self.current_question_index = 1 | |
| if self.current_question_index < len(self.symptom_questions): | |
| bot_response = f"{self.symptom_questions[self.current_question_index]}" | |
| else: | |
| bot_response = self.provide_risk_assessment() | |
| self.risk_assessment_done = True | |
| self.add_to_conversation_history("assistant", bot_response) | |
| return bot_response | |
| elif self.current_question_index < len(self.symptom_questions) and not self.risk_assessment_done: | |
| self.current_symptoms[f"question_{self.current_question_index}"] = user_input | |
| self.current_question_index += 1 | |
| if self.current_question_index < len(self.symptom_questions): | |
| bot_response = f"{self.symptom_questions[self.current_question_index]}" | |
| else: | |
| bot_response = self.provide_risk_assessment() | |
| self.risk_assessment_done = True | |
| self.add_to_conversation_history("assistant", bot_response) | |
| return bot_response | |
| else: | |
| bot_response = self.handle_follow_up_conversation(user_input) | |
| self.add_to_conversation_history("assistant", bot_response) | |
| return bot_response | |
| except Exception as e: | |
| print(f"β Error in process_user_input: {e}") | |
| traceback.print_exc() | |
| error_response = "I encountered an error. Please try again or consult your healthcare provider." | |
| self.add_to_conversation_history("assistant", error_response) | |
| return error_response | |
| def handle_follow_up_conversation(self, user_input): | |
| try: | |
| print(f"π Processing follow-up question: {user_input}") | |
| symptom_summary = self.create_symptom_summary() | |
| conversation_context = self.get_conversation_context() | |
| if any(word in user_input.lower() for word in ["last", "previous", "what did i ask", "my question"]): | |
| if self.last_user_query: | |
| return f"Your last question was: \"{self.last_user_query}\"\n\nWould you like me to elaborate on that topic or do you have a different question?" | |
| else: | |
| return "I don't have a record of your previous question. Could you please rephrase what you'd like to know?" | |
| rag_response = get_direct_answer(user_input, symptom_summary, conversation_context=conversation_context, is_risk_assessment=False) | |
| if "Error" in rag_response or len(rag_response) < 50: | |
| print("π Trying alternative method...") | |
| rag_response = get_answer_with_query_engine(user_input) | |
| bot_response = f"""Based on your symptoms and medical literature: | |
| {rag_response}""" | |
| return bot_response | |
| except Exception as e: | |
| print(f"β Error in follow-up conversation: {e}") | |
| return "I encountered an error processing your question. Could you please rephrase it or consult your healthcare provider?" | |
| def create_symptom_summary(self): | |
| if not self.current_symptoms: | |
| return "No specific symptoms reported yet" | |
| summary_parts = [] | |
| for i, (key, response) in enumerate(self.current_symptoms.items()): | |
| if i < len(self.symptom_questions): | |
| question = self.symptom_questions[i] | |
| summary_parts.append(f"{question}: {response}") | |
| return "\n".join(summary_parts) | |
| def parse_risk_level(self, text): | |
| import re | |
| patterns = [ | |
| r'\*\*Risk Level:\*\*\s*(Low|Medium|High)', | |
| r'Risk Level:\s*\*\*(Low|Medium|High)\*\*', | |
| r'Risk Level:\s*(Low|Medium|High)', | |
| r'\*\*Risk Level:\*\*\s*<(Low|Medium|High)>', | |
| r'Risk Level.*?<(Low|Medium|High)>', | |
| ] | |
| for pattern in patterns: | |
| match = re.search(pattern, text, re.IGNORECASE) | |
| if match: | |
| risk_level = match.group(1).capitalize() | |
| print(f"β Successfully parsed risk level: {risk_level}") | |
| return risk_level | |
| print(f"β Could not parse risk level from: {text[:200]}...") | |
| return None | |
| def provide_risk_assessment(self): | |
| all_symptoms = self.create_symptom_summary() | |
| rag_query = f"Analyze these pregnancy symptoms for risk assessment:\n{all_symptoms}\n\nProvide risk level and medical recommendations." | |
| detailed_analysis = get_direct_answer(rag_query, all_symptoms, is_risk_assessment=True) | |
| print(f"π RAG Response: {detailed_analysis[:300]}...") | |
| llm_risk_level = self.parse_risk_level(detailed_analysis) | |
| if llm_risk_level: | |
| risk_level = llm_risk_level | |
| if risk_level == "Low": | |
| action = "β Continue routine prenatal care and self-monitoring" | |
| elif risk_level == "Medium": | |
| action = "β οΈ Contact your doctor within 24 hours" | |
| elif risk_level == "High": | |
| action = "π¨ Immediate visit to ER or OB emergency care required" | |
| else: | |
| print("β οΈ RAG assessment failed, using fallback") | |
| risk_level = "Medium" | |
| action = "β οΈ Contact your doctor within 24 hours" | |
| symptom_list = [] | |
| for i, (key, symptom) in enumerate(self.current_symptoms.items()): | |
| question = self.symptom_questions[i] if i < len(self.symptom_questions) else f"Question {i+1}" | |
| symptom_list.append(f"β’ **{question}**: {symptom}") | |
| assessment = f""" | |
| ## π₯ **Risk Assessment Complete** | |
| **Risk Level: {risk_level}** | |
| **Recommended Action: {action}** | |
| ### π **Your Reported Symptoms:** | |
| {chr(10).join(symptom_list)} | |
| ### π¬ **Medical Analysis:** | |
| {detailed_analysis} | |
| ### π‘ **Next Steps:** | |
| - Follow the recommended action above | |
| - Keep monitoring your symptoms | |
| - Contact your healthcare provider if symptoms worsen | |
| - Feel free to ask me any follow-up questions about pregnancy health | |
| """ | |
| return assessment | |
| def reset_conversation(self): | |
| self.conversation_history = [] | |
| self.current_symptoms = {} | |
| self.current_question_index = 0 | |
| self.risk_assessment_done = False | |
| self.waiting_for_first_response = True | |
| self.user_context = {} | |
| self.last_user_query = "" | |
| return get_welcome_message() | |
| def get_welcome_message(): | |
| return """Hello! I'm here to help assess pregnancy-related symptoms and provide risk insights based on medical literature. | |
| I'll ask you a few important questions about your current symptoms, then provide a risk assessment and recommendations. After that, feel free to ask any follow-up questions! | |
| **To get started, please tell me:** | |
| Are you currently experiencing any unusual bleeding or discharge? | |
| --- | |
| β οΈ **Important**: This tool is for informational purposes only and should not replace professional medical care. In case of emergency, contact your healthcare provider immediately.""" | |
| def create_new_agent(): | |
| return PregnancyRiskAgent() | |
| agent = create_new_agent() | |
| def chat_interface_with_reset(user_input, history): | |
| global agent | |
| if user_input.lower() in ["reset", "restart", "new assessment"]: | |
| agent = create_new_agent() | |
| return get_welcome_message() | |
| response = agent.process_user_input(user_input, history) | |
| return response | |
| def reset_chat(): | |
| global agent | |
| agent = create_new_agent() | |
| return [{"role": "assistant", "content": get_welcome_message()}], "" | |
| custom_css = """ | |
| body, .gradio-container { | |
| color: yellow !important; | |
| } | |
| .header { | |
| background: linear-gradient(135deg, #ff9a9e 0%, #fecfef 100%); | |
| padding: 2rem; | |
| border-radius: 1rem; | |
| text-align: center; | |
| margin-bottom: 2rem; | |
| box-shadow: 0 4px 12px rgba(0,0,0,0.1); | |
| } | |
| .header h1 { | |
| color: black !important; | |
| margin-bottom: 0.5rem; | |
| font-size: 2.5rem; | |
| } | |
| .header p { | |
| color: black !important; | |
| font-size: 1.1rem; | |
| margin: 0.5rem 0; | |
| } | |
| .warning { | |
| background-color: #fff4e6; | |
| border-left: 6px solid #ff7f00; | |
| padding: 15px; | |
| border-radius: 5px; | |
| margin: 10px 0; | |
| } | |
| .warning h3 { | |
| color: black !important; | |
| margin-top: 0; | |
| } | |
| .warning p { | |
| color: black !important; | |
| line-height: 1.6; | |
| } | |
| div[style*="background-color: #e8f5e8"] { | |
| color: black !important; | |
| } | |
| div[style*="background-color: #e8f5e8"] h3 { | |
| color: black !important; | |
| } | |
| div[style*="background-color: #e8f5e8"] li { | |
| color: black !important; | |
| } | |
| .chatbot { | |
| color: black !important; | |
| } | |
| .message { | |
| color: black !important; | |
| } | |
| /* Hide Gradio footer elements */ | |
| .footer { | |
| display: none !important; | |
| } | |
| .gradio-container .footer { | |
| display: none !important; | |
| } | |
| footer { | |
| display: none !important; | |
| } | |
| .api-docs { | |
| display: none !important; | |
| } | |
| .built-with { | |
| display: none !important; | |
| } | |
| .gradio-container > .built-with { | |
| display: none !important; | |
| } | |
| .settings { | |
| display: none !important; | |
| } | |
| div[class*="footer"] { | |
| display: none !important; | |
| } | |
| div[class*="built"] { | |
| display: none !important; | |
| } | |
| *:contains("Built with Gradio") { | |
| display: none !important; | |
| } | |
| *:contains("Use via API") { | |
| display: none !important; | |
| } | |
| *:contains("Settings") { | |
| display: none !important; | |
| } | |
| """ | |
| with gr.Blocks(css=custom_css) as demo: | |
| gr.HTML(""" | |
| <div class="header"> | |
| <h1>π€± Pregnancy RAG Chatbot</h1> | |
| <p><strong style="color: black !important;">Proactive RAG-powered pregnancy risk management</strong></p> | |
| </div> | |
| """) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.HTML(""" | |
| <div class="warning"> | |
| <h3>β οΈ Medical Disclaimer</h3> | |
| <p>This AI assistant provides information based on medical literature but is NOT a substitute for professional medical advice, diagnosis, or treatment.</p> | |
| <p><strong style="color: black !important;">In emergencies, call emergency services immediately.</strong></p> | |
| </div> | |
| """) | |
| chatbot = gr.ChatInterface( | |
| fn=chat_interface_with_reset, | |
| chatbot=gr.Chatbot( | |
| value=[{"role": "assistant", "content": get_welcome_message()}], | |
| show_label=False, | |
| type='messages' | |
| ), | |
| textbox=gr.Textbox( | |
| placeholder="Type your response here...", | |
| show_label=False, | |
| max_length=1000, | |
| submit_btn=True | |
| ) | |
| ) | |
| with gr.Row(): | |
| reset_btn = gr.Button("π Start New Assessment", variant="secondary") | |
| reset_btn.click( | |
| fn=reset_chat, | |
| outputs=[chatbot.chatbot, chatbot.textbox], | |
| show_progress=False | |
| ) | |
| def check_groq_connection(): | |
| try: | |
| from backend.utils import llm | |
| test_response = llm.complete("Hello") | |
| print("β Groq connection successful") | |
| return True | |
| except Exception as e: | |
| print(f"β Groq connection failed: {e}") | |
| return False | |
| def refresh_page(): | |
| """Force a complete page refresh""" | |
| return None | |
| if __name__ == "__main__": | |
| print("π Starting GraviLog Pregnancy Risk Assessment Agent...") | |
| check_groq_connection() | |
| is_hf_space = os.getenv('SPACE_ID') is not None | |
| if is_hf_space: | |
| print("π Running on Hugging Face Spaces") | |
| print("π Each page refresh will start a new conversation") | |
| demo.queue().launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=False, | |
| debug=False | |
| ) | |
| else: | |
| print("π Running locally") | |
| print("π Using Groq API for LLM processing") | |
| print("π Make sure your GROQ_API_KEY is set in environment variables") | |
| print("π Make sure your Pinecone index is set up and populated") | |
| demo.queue().launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=True, | |
| debug=True, | |
| show_error=True | |
| ) |