Spaces:
Sleeping
Sleeping
Refactor agent and tools for session-based memory management and side effect reporting. Removed medical answer validation tool, added session memory management class, and enhanced side effect reporting with LLM classification. Updated agent functions to support session IDs for improved conversation tracking.
Browse files- core/agent.py +90 -20
- core/tools.py +62 -203
core/agent.py
CHANGED
|
@@ -17,7 +17,6 @@ from .tools import (
|
|
| 17 |
compare_providers_tool,
|
| 18 |
get_current_datetime_tool,
|
| 19 |
side_effect_recording_tool,
|
| 20 |
-
medical_answer_validation_tool,
|
| 21 |
)
|
| 22 |
|
| 23 |
# LangSmith tracing utilities
|
|
@@ -85,7 +84,6 @@ AVAILABLE_TOOLS = [
|
|
| 85 |
compare_providers_tool,
|
| 86 |
get_current_datetime_tool,
|
| 87 |
side_effect_recording_tool,
|
| 88 |
-
medical_answer_validation_tool,
|
| 89 |
]
|
| 90 |
|
| 91 |
|
|
@@ -94,11 +92,16 @@ SYSTEM_MESSAGE = """
|
|
| 94 |
You are an advanced Medical Advisor Chatbot for healthcare professionals.
|
| 95 |
Your primary purpose is to answer clinical and medical questions strictly based on authoritative medical guidelines using the tool "medical_guidelines_knowledge_tool".
|
| 96 |
|
|
|
|
|
|
|
|
|
|
| 97 |
**INSTRUCTIONS:**
|
| 98 |
- Always answer using only the information retrieved from medical guidelines via "medical_guidelines_knowledge_tool".
|
| 99 |
- **SIDE EFFECT REPORTING**: When a healthcare professional reports an adverse drug reaction, side effect, or medication-related complication, ALWAYS use the "side_effect_recording_tool" first to document the information. Return the tool's response directly to the user without modification. DO NOT use validation or generate additional reports for side effect reporting queries.
|
| 100 |
- Use the side effect recording tool when the input contains phrases like: "patient experienced", "side effect", "adverse reaction", "drug reaction", "medication caused", "developed after taking", etc.
|
| 101 |
- When the side effect recording tool requests additional information, present the request exactly as provided by the tool.
|
|
|
|
|
|
|
| 102 |
- For every answer, you MUST provide detailed citations including:
|
| 103 |
* Source file name
|
| 104 |
* Page number
|
|
@@ -166,12 +169,47 @@ def get_agent_executor():
|
|
| 166 |
max_execution_time=90, # tighten a bit to help responsiveness
|
| 167 |
)
|
| 168 |
|
| 169 |
-
#
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
|
| 177 |
# ============================================================================
|
|
@@ -304,7 +342,7 @@ def _perform_automatic_validation(user_input: str, response: str) -> str:
|
|
| 304 |
# ============================================================================
|
| 305 |
|
| 306 |
# @traceable(name="run_agent_streaming")
|
| 307 |
-
async def run_agent_streaming(user_input: str, max_retries: int = 3) -> AsyncGenerator[str, None]:
|
| 308 |
"""
|
| 309 |
Run the agent with streaming support and comprehensive error handling.
|
| 310 |
|
|
@@ -313,6 +351,7 @@ async def run_agent_streaming(user_input: str, max_retries: int = 3) -> AsyncGen
|
|
| 313 |
|
| 314 |
Args:
|
| 315 |
user_input (str): The user's input message to process
|
|
|
|
| 316 |
max_retries (int, optional): Maximum number of retries for recoverable errors.
|
| 317 |
Defaults to 3.
|
| 318 |
|
|
@@ -343,7 +382,8 @@ async def run_agent_streaming(user_input: str, max_retries: int = 3) -> AsyncGen
|
|
| 343 |
# Tracing for streaming disabled to avoid duplicate traces.
|
| 344 |
# We keep tracing only for the AgentExecutor in run_agent().
|
| 345 |
current_run_id = None
|
| 346 |
-
# Load conversation history from memory
|
|
|
|
| 347 |
chat_history = memory.load_memory_variables({})["chat_history"]
|
| 348 |
|
| 349 |
logger.info(f"Processing user input (attempt {retry_count + 1}): {user_input[:50]}...")
|
|
@@ -547,7 +587,7 @@ async def run_agent_streaming(user_input: str, max_retries: int = 3) -> AsyncGen
|
|
| 547 |
yield "Sorry, I was unable to process your request after several attempts. Please try again later."
|
| 548 |
|
| 549 |
|
| 550 |
-
async def safe_run_agent_streaming(user_input: str) -> AsyncGenerator[str, None]:
|
| 551 |
"""
|
| 552 |
Streaming wrapper function with additional safety checks and input validation.
|
| 553 |
|
|
@@ -557,6 +597,7 @@ async def safe_run_agent_streaming(user_input: str) -> AsyncGenerator[str, None]
|
|
| 557 |
|
| 558 |
Args:
|
| 559 |
user_input (str): The user's input message to process
|
|
|
|
| 560 |
|
| 561 |
Yields:
|
| 562 |
str: Chunks of the agent's response as they are generated
|
|
@@ -585,7 +626,7 @@ async def safe_run_agent_streaming(user_input: str) -> AsyncGenerator[str, None]
|
|
| 585 |
return
|
| 586 |
|
| 587 |
# Stream the response through the main agent function
|
| 588 |
-
async for chunk in run_agent_streaming(user_input):
|
| 589 |
yield chunk
|
| 590 |
|
| 591 |
except Exception as e:
|
|
@@ -595,7 +636,7 @@ async def safe_run_agent_streaming(user_input: str) -> AsyncGenerator[str, None]
|
|
| 595 |
|
| 596 |
|
| 597 |
@traceable(name="run_agent")
|
| 598 |
-
async def run_agent(user_input: str, max_retries: int = 3) -> str:
|
| 599 |
"""
|
| 600 |
Run the agent with comprehensive error handling and retry logic.
|
| 601 |
|
|
@@ -605,6 +646,7 @@ async def run_agent(user_input: str, max_retries: int = 3) -> str:
|
|
| 605 |
|
| 606 |
Args:
|
| 607 |
user_input (str): The user's input message to process
|
|
|
|
| 608 |
max_retries (int, optional): Maximum number of retries for recoverable errors.
|
| 609 |
Defaults to 3.
|
| 610 |
|
|
@@ -626,7 +668,8 @@ async def run_agent(user_input: str, max_retries: int = 3) -> str:
|
|
| 626 |
|
| 627 |
while retry_count <= max_retries:
|
| 628 |
try:
|
| 629 |
-
# Load conversation history from memory
|
|
|
|
| 630 |
chat_history = memory.load_memory_variables({})["chat_history"]
|
| 631 |
|
| 632 |
logger.info(f"Processing user input (attempt {retry_count + 1}): {user_input[:50]}...")
|
|
@@ -766,7 +809,7 @@ async def run_agent(user_input: str, max_retries: int = 3) -> str:
|
|
| 766 |
return "Sorry, I was unable to process your request after several attempts. Please try again later."
|
| 767 |
|
| 768 |
|
| 769 |
-
async def safe_run_agent(user_input: str) -> str:
|
| 770 |
"""
|
| 771 |
Wrapper function for run_agent with additional safety checks and input validation.
|
| 772 |
|
|
@@ -776,6 +819,7 @@ async def safe_run_agent(user_input: str) -> str:
|
|
| 776 |
|
| 777 |
Args:
|
| 778 |
user_input (str): The user's input message to process
|
|
|
|
| 779 |
|
| 780 |
Returns:
|
| 781 |
str: The agent's response or an appropriate error message in English
|
|
@@ -801,7 +845,7 @@ async def safe_run_agent(user_input: str) -> str:
|
|
| 801 |
return "Sorry, I didn't receive any questions. Please enter your question or request."
|
| 802 |
|
| 803 |
# Process the input through the main agent function
|
| 804 |
-
return await run_agent(user_input)
|
| 805 |
|
| 806 |
except Exception as e:
|
| 807 |
logger.critical(f"Critical error in safe_run_agent: {str(e)}")
|
|
@@ -817,23 +861,49 @@ def clear_memory() -> None:
|
|
| 817 |
effectively starting a fresh conversation session.
|
| 818 |
"""
|
| 819 |
try:
|
| 820 |
-
|
| 821 |
logger.info("Conversation memory cleared successfully")
|
| 822 |
except Exception as e:
|
| 823 |
logger.error(f"Error clearing memory: {str(e)}")
|
| 824 |
|
| 825 |
|
| 826 |
-
def get_memory_summary() -> str:
|
| 827 |
"""
|
| 828 |
-
Get a summary of the
|
|
|
|
|
|
|
|
|
|
| 829 |
|
| 830 |
Returns:
|
| 831 |
str: A summary of the conversation history stored in memory
|
| 832 |
"""
|
| 833 |
try:
|
|
|
|
| 834 |
memory_vars = memory.load_memory_variables({})
|
| 835 |
return str(memory_vars.get("chat_history", "No conversation history available"))
|
| 836 |
except Exception as e:
|
| 837 |
logger.error(f"Error getting memory summary: {str(e)}")
|
| 838 |
-
return "Error retrieving
|
| 839 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
compare_providers_tool,
|
| 18 |
get_current_datetime_tool,
|
| 19 |
side_effect_recording_tool,
|
|
|
|
| 20 |
)
|
| 21 |
|
| 22 |
# LangSmith tracing utilities
|
|
|
|
| 84 |
compare_providers_tool,
|
| 85 |
get_current_datetime_tool,
|
| 86 |
side_effect_recording_tool,
|
|
|
|
| 87 |
]
|
| 88 |
|
| 89 |
|
|
|
|
| 92 |
You are an advanced Medical Advisor Chatbot for healthcare professionals.
|
| 93 |
Your primary purpose is to answer clinical and medical questions strictly based on authoritative medical guidelines using the tool "medical_guidelines_knowledge_tool".
|
| 94 |
|
| 95 |
+
Your answers must be concise, medically informative, evidence-based responses in an authoritative, precise, and clinical tone.
|
| 96 |
+
You will be responding to practicing medical professionals so adjust your answer and language accordingly.
|
| 97 |
+
|
| 98 |
**INSTRUCTIONS:**
|
| 99 |
- Always answer using only the information retrieved from medical guidelines via "medical_guidelines_knowledge_tool".
|
| 100 |
- **SIDE EFFECT REPORTING**: When a healthcare professional reports an adverse drug reaction, side effect, or medication-related complication, ALWAYS use the "side_effect_recording_tool" first to document the information. Return the tool's response directly to the user without modification. DO NOT use validation or generate additional reports for side effect reporting queries.
|
| 101 |
- Use the side effect recording tool when the input contains phrases like: "patient experienced", "side effect", "adverse reaction", "drug reaction", "medication caused", "developed after taking", etc.
|
| 102 |
- When the side effect recording tool requests additional information, present the request exactly as provided by the tool.
|
| 103 |
+
- **PROVIDER COMPARISON**: When the user asks to compare guidance between two providers (e.g., "compare NCCN vs ESMO on ..."), use the "compare_providers_tool" with appropriate `provider_a` and `provider_b` values to retrieve side-by-side, cited results.
|
| 104 |
+
- **TIME/DATE QUERIES**: For any questions about the current date/time or references like "today" or "now", use the "get_current_datetime_tool". Treat this tool as the only reliable source of current time information.
|
| 105 |
- For every answer, you MUST provide detailed citations including:
|
| 106 |
* Source file name
|
| 107 |
* Page number
|
|
|
|
| 169 |
max_execution_time=90, # tighten a bit to help responsiveness
|
| 170 |
)
|
| 171 |
|
| 172 |
+
# ============================================================================
|
| 173 |
+
# SESSION-BASED MEMORY MANAGEMENT
|
| 174 |
+
# ============================================================================
|
| 175 |
+
|
| 176 |
+
class SessionMemoryManager:
|
| 177 |
+
"""Manages conversation memory for multiple sessions."""
|
| 178 |
+
|
| 179 |
+
def __init__(self):
|
| 180 |
+
self._sessions = {}
|
| 181 |
+
self._default_window_size = 10
|
| 182 |
+
|
| 183 |
+
def get_memory(self, session_id: str = "default") -> ConversationBufferWindowMemory:
|
| 184 |
+
"""Get or create memory for a specific session."""
|
| 185 |
+
if session_id not in self._sessions:
|
| 186 |
+
self._sessions[session_id] = ConversationBufferWindowMemory(
|
| 187 |
+
memory_key="chat_history",
|
| 188 |
+
return_messages=True,
|
| 189 |
+
max_window_size=self._default_window_size
|
| 190 |
+
)
|
| 191 |
+
return self._sessions[session_id]
|
| 192 |
+
|
| 193 |
+
def clear_session(self, session_id: str) -> bool:
|
| 194 |
+
"""Clear memory for a specific session."""
|
| 195 |
+
if session_id in self._sessions:
|
| 196 |
+
self._sessions[session_id].clear()
|
| 197 |
+
del self._sessions[session_id]
|
| 198 |
+
return True
|
| 199 |
+
return False
|
| 200 |
+
|
| 201 |
+
def clear_all_sessions(self):
|
| 202 |
+
"""Clear all session memories."""
|
| 203 |
+
for memory in self._sessions.values():
|
| 204 |
+
memory.clear()
|
| 205 |
+
self._sessions.clear()
|
| 206 |
+
|
| 207 |
+
def get_active_sessions(self) -> list:
|
| 208 |
+
"""Get list of active session IDs."""
|
| 209 |
+
return list(self._sessions.keys())
|
| 210 |
+
|
| 211 |
+
# Global session memory manager
|
| 212 |
+
_memory_manager = SessionMemoryManager()
|
| 213 |
|
| 214 |
|
| 215 |
# ============================================================================
|
|
|
|
| 342 |
# ============================================================================
|
| 343 |
|
| 344 |
# @traceable(name="run_agent_streaming")
|
| 345 |
+
async def run_agent_streaming(user_input: str, session_id: str = "default", max_retries: int = 3) -> AsyncGenerator[str, None]:
|
| 346 |
"""
|
| 347 |
Run the agent with streaming support and comprehensive error handling.
|
| 348 |
|
|
|
|
| 351 |
|
| 352 |
Args:
|
| 353 |
user_input (str): The user's input message to process
|
| 354 |
+
session_id (str, optional): Session identifier for conversation memory. Defaults to "default".
|
| 355 |
max_retries (int, optional): Maximum number of retries for recoverable errors.
|
| 356 |
Defaults to 3.
|
| 357 |
|
|
|
|
| 382 |
# Tracing for streaming disabled to avoid duplicate traces.
|
| 383 |
# We keep tracing only for the AgentExecutor in run_agent().
|
| 384 |
current_run_id = None
|
| 385 |
+
# Load conversation history from session-specific memory
|
| 386 |
+
memory = _memory_manager.get_memory(session_id)
|
| 387 |
chat_history = memory.load_memory_variables({})["chat_history"]
|
| 388 |
|
| 389 |
logger.info(f"Processing user input (attempt {retry_count + 1}): {user_input[:50]}...")
|
|
|
|
| 587 |
yield "Sorry, I was unable to process your request after several attempts. Please try again later."
|
| 588 |
|
| 589 |
|
| 590 |
+
async def safe_run_agent_streaming(user_input: str, session_id: str = "default") -> AsyncGenerator[str, None]:
|
| 591 |
"""
|
| 592 |
Streaming wrapper function with additional safety checks and input validation.
|
| 593 |
|
|
|
|
| 597 |
|
| 598 |
Args:
|
| 599 |
user_input (str): The user's input message to process
|
| 600 |
+
session_id (str, optional): Session identifier for conversation memory. Defaults to "default".
|
| 601 |
|
| 602 |
Yields:
|
| 603 |
str: Chunks of the agent's response as they are generated
|
|
|
|
| 626 |
return
|
| 627 |
|
| 628 |
# Stream the response through the main agent function
|
| 629 |
+
async for chunk in run_agent_streaming(user_input, session_id):
|
| 630 |
yield chunk
|
| 631 |
|
| 632 |
except Exception as e:
|
|
|
|
| 636 |
|
| 637 |
|
| 638 |
@traceable(name="run_agent")
|
| 639 |
+
async def run_agent(user_input: str, session_id: str = "default", max_retries: int = 3) -> str:
|
| 640 |
"""
|
| 641 |
Run the agent with comprehensive error handling and retry logic.
|
| 642 |
|
|
|
|
| 646 |
|
| 647 |
Args:
|
| 648 |
user_input (str): The user's input message to process
|
| 649 |
+
session_id (str, optional): Session identifier for conversation memory. Defaults to "default".
|
| 650 |
max_retries (int, optional): Maximum number of retries for recoverable errors.
|
| 651 |
Defaults to 3.
|
| 652 |
|
|
|
|
| 668 |
|
| 669 |
while retry_count <= max_retries:
|
| 670 |
try:
|
| 671 |
+
# Load conversation history from session-specific memory
|
| 672 |
+
memory = _memory_manager.get_memory(session_id)
|
| 673 |
chat_history = memory.load_memory_variables({})["chat_history"]
|
| 674 |
|
| 675 |
logger.info(f"Processing user input (attempt {retry_count + 1}): {user_input[:50]}...")
|
|
|
|
| 809 |
return "Sorry, I was unable to process your request after several attempts. Please try again later."
|
| 810 |
|
| 811 |
|
| 812 |
+
async def safe_run_agent(user_input: str, session_id: str = "default") -> str:
|
| 813 |
"""
|
| 814 |
Wrapper function for run_agent with additional safety checks and input validation.
|
| 815 |
|
|
|
|
| 819 |
|
| 820 |
Args:
|
| 821 |
user_input (str): The user's input message to process
|
| 822 |
+
session_id (str, optional): Session identifier for conversation memory. Defaults to "default".
|
| 823 |
|
| 824 |
Returns:
|
| 825 |
str: The agent's response or an appropriate error message in English
|
|
|
|
| 845 |
return "Sorry, I didn't receive any questions. Please enter your question or request."
|
| 846 |
|
| 847 |
# Process the input through the main agent function
|
| 848 |
+
return await run_agent(user_input, session_id)
|
| 849 |
|
| 850 |
except Exception as e:
|
| 851 |
logger.critical(f"Critical error in safe_run_agent: {str(e)}")
|
|
|
|
| 861 |
effectively starting a fresh conversation session.
|
| 862 |
"""
|
| 863 |
try:
|
| 864 |
+
_memory_manager.clear_all_sessions()
|
| 865 |
logger.info("Conversation memory cleared successfully")
|
| 866 |
except Exception as e:
|
| 867 |
logger.error(f"Error clearing memory: {str(e)}")
|
| 868 |
|
| 869 |
|
| 870 |
+
def get_memory_summary(session_id: str = "default") -> str:
|
| 871 |
"""
|
| 872 |
+
Get a summary of the conversation history for a specific session.
|
| 873 |
+
|
| 874 |
+
Args:
|
| 875 |
+
session_id (str, optional): Session identifier. Defaults to "default".
|
| 876 |
|
| 877 |
Returns:
|
| 878 |
str: A summary of the conversation history stored in memory
|
| 879 |
"""
|
| 880 |
try:
|
| 881 |
+
memory = _memory_manager.get_memory(session_id)
|
| 882 |
memory_vars = memory.load_memory_variables({})
|
| 883 |
return str(memory_vars.get("chat_history", "No conversation history available"))
|
| 884 |
except Exception as e:
|
| 885 |
logger.error(f"Error getting memory summary: {str(e)}")
|
| 886 |
+
return "Error retrieving conversation history"
|
| 887 |
|
| 888 |
+
|
| 889 |
+
def clear_session_memory(session_id: str) -> bool:
|
| 890 |
+
"""
|
| 891 |
+
Clear conversation memory for a specific session.
|
| 892 |
+
|
| 893 |
+
Args:
|
| 894 |
+
session_id (str): Session identifier to clear
|
| 895 |
+
|
| 896 |
+
Returns:
|
| 897 |
+
bool: True if session was cleared, False if session didn't exist
|
| 898 |
+
"""
|
| 899 |
+
return _memory_manager.clear_session(session_id)
|
| 900 |
+
|
| 901 |
+
|
| 902 |
+
def get_active_sessions() -> list:
|
| 903 |
+
"""
|
| 904 |
+
Get list of all active session IDs.
|
| 905 |
+
|
| 906 |
+
Returns:
|
| 907 |
+
list: List of active session identifiers
|
| 908 |
+
"""
|
| 909 |
+
return _memory_manager.get_active_sessions()
|
core/tools.py
CHANGED
|
@@ -6,11 +6,12 @@ from datetime import datetime
|
|
| 6 |
from typing import Optional, List
|
| 7 |
|
| 8 |
import pytz
|
| 9 |
-
from langchain.schema import Document
|
| 10 |
from langchain.tools import tool
|
| 11 |
from .retrievers import hybrid_search, vector_search, bm25_search
|
| 12 |
from .validation import validate_medical_answer
|
| 13 |
from .github_storage import get_github_storage
|
|
|
|
| 14 |
|
| 15 |
CANONICAL_PROVIDERS = {"Manus", "ASCO", "NCCN", "ESMO", "NICE"}
|
| 16 |
|
|
@@ -28,6 +29,38 @@ def store_user_question(user_question: str):
|
|
| 28 |
global _last_user_question
|
| 29 |
_last_user_question = user_question
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
# Map lowercase variants and full names to canonical provider codes
|
| 32 |
_PROVIDER_ALIASES = {
|
| 33 |
# NCCN
|
|
@@ -289,7 +322,7 @@ def side_effect_recording_tool(user_input: str) -> str:
|
|
| 289 |
str: Interactive form for collecting missing information or confirmation of data recording
|
| 290 |
"""
|
| 291 |
try:
|
| 292 |
-
#
|
| 293 |
side_effect_keywords = [
|
| 294 |
'side effect', 'adverse reaction', 'adverse event', 'drug reaction',
|
| 295 |
'medication reaction', 'allergic reaction', 'complication', 'toxicity',
|
|
@@ -297,7 +330,6 @@ def side_effect_recording_tool(user_input: str) -> str:
|
|
| 297 |
'overdose', 'poisoning', 'drug-induced', 'medication-induced',
|
| 298 |
'experienced after taking', 'developed after', 'caused by medication',
|
| 299 |
'drug-related', 'medication-related', 'pharmaceutical reaction',
|
| 300 |
-
# Add more comprehensive problem/symptom keywords
|
| 301 |
'kidney problems', 'liver problems', 'heart problems', 'breathing problems',
|
| 302 |
'skin problems', 'stomach problems', 'nausea', 'vomiting', 'diarrhea',
|
| 303 |
'headache', 'dizziness', 'fatigue', 'weakness', 'rash', 'swelling',
|
|
@@ -305,8 +337,8 @@ def side_effect_recording_tool(user_input: str) -> str:
|
|
| 305 |
'has these', 'has serious', 'causes', 'resulted in', 'led to',
|
| 306 |
'problems with', 'issues with', 'complications from'
|
| 307 |
]
|
| 308 |
-
|
| 309 |
input_lower = user_input.lower().strip()
|
|
|
|
| 310 |
|
| 311 |
# Check for special commands first
|
| 312 |
if input_lower in ['save report', 'save', 'submit report', 'submit']:
|
|
@@ -324,8 +356,9 @@ def side_effect_recording_tool(user_input: str) -> str:
|
|
| 324 |
extracted_data = _extract_side_effect_data(user_input)
|
| 325 |
return _process_followup_response(user_input, extracted_data)
|
| 326 |
|
| 327 |
-
#
|
| 328 |
-
|
|
|
|
| 329 |
|
| 330 |
if not contains_side_effect:
|
| 331 |
return "This input does not appear to contain a side effect report. If you are reporting an adverse drug reaction, please include specific details about the medication and symptoms."
|
|
@@ -532,74 +565,23 @@ def _save_side_effect_report(extracted_data: dict) -> str:
|
|
| 532 |
# Ensure the value is properly formatted
|
| 533 |
extracted_data[field] = str(value).strip()
|
| 534 |
|
| 535 |
-
# Save to GitHub repository
|
| 536 |
-
|
| 537 |
-
|
| 538 |
github_storage = get_github_storage()
|
| 539 |
-
|
| 540 |
-
|
| 541 |
success = github_storage.save_side_effects_report(extracted_data)
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
if not success:
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
# Fallback to local storage if GitHub fails
|
| 551 |
-
|
| 552 |
-
|
| 553 |
csv_filename = "side_effects_reports.csv"
|
| 554 |
-
|
| 555 |
-
|
| 556 |
csv_path = os.path.join(os.getcwd(), csv_filename)
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
file_exists = os.path.exists(csv_path)
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
with open(csv_path, 'a', newline='', encoding='utf-8') as csvfile:
|
| 569 |
-
|
| 570 |
-
|
| 571 |
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
if not file_exists:
|
| 578 |
-
|
| 579 |
-
|
| 580 |
writer.writeheader()
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
writer.writerow(extracted_data)
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
storage_location = "locally to side_effects_reports.csv (GitHub upload failed)"
|
| 593 |
-
|
| 594 |
-
|
| 595 |
else:
|
| 596 |
-
|
| 597 |
-
|
| 598 |
storage_location = "to GitHub cloud repository"
|
| 599 |
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
# Generate confirmation message
|
| 603 |
drug_name = extracted_data.get('drug_name', 'NaN')
|
| 604 |
side_effects = extracted_data.get('side_effects', 'NaN')
|
| 605 |
report_id = extracted_data['timestamp'].replace(':', '').replace('-', '').replace(' ', '_')
|
|
@@ -661,14 +643,10 @@ def _extract_side_effect_data_with_llm(user_input: str) -> dict:
|
|
| 661 |
Returns:
|
| 662 |
dict: Structured data extracted from the input
|
| 663 |
"""
|
| 664 |
-
from langchain.llms import OpenAI
|
| 665 |
-
from langchain.prompts import PromptTemplate
|
| 666 |
import json
|
| 667 |
-
|
| 668 |
# Get current timestamp
|
| 669 |
egypt_tz = pytz.timezone('Africa/Cairo')
|
| 670 |
current_time = datetime.now(egypt_tz).strftime('%Y-%m-%d %H:%M:%S')
|
| 671 |
-
|
| 672 |
# Initialize extracted data with defaults
|
| 673 |
extracted_data = {
|
| 674 |
'timestamp': current_time,
|
|
@@ -682,65 +660,32 @@ def _extract_side_effect_data_with_llm(user_input: str) -> dict:
|
|
| 682 |
'outcome': 'NaN',
|
| 683 |
'additional_details': 'NaN',
|
| 684 |
'reporter_info': 'NaN',
|
| 685 |
-
'raw_input': user_input[:500]
|
| 686 |
}
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
extraction_prompt = PromptTemplate(
|
| 690 |
-
input_variables=["user_input"],
|
| 691 |
-
template="""Extract medical side effect information from the following text. Return ONLY a JSON object with these exact fields:
|
| 692 |
-
|
| 693 |
-
{
|
| 694 |
-
"drug_name": "name of the medication/drug mentioned",
|
| 695 |
-
"side_effects": "list of side effects or symptoms described",
|
| 696 |
-
"patient_age": "patient's age if mentioned",
|
| 697 |
-
"patient_gender": "Male or Female if mentioned",
|
| 698 |
-
"dosage": "medication dosage if mentioned",
|
| 699 |
-
"duration": "treatment duration if mentioned (e.g., '3 months', '2 weeks')",
|
| 700 |
-
"severity": "mild, moderate, or severe if mentioned",
|
| 701 |
-
"outcome": "current status like ongoing, resolved, recovered if mentioned"
|
| 702 |
-
}
|
| 703 |
-
|
| 704 |
-
IMPORTANT RULES:
|
| 705 |
-
1. If any information is not found or unclear, use "NaN" as the value
|
| 706 |
-
2. For duration, look for phrases like "Treatment duration: 3 months", "for 2 weeks", "over 6 months", etc.
|
| 707 |
-
3. Extract exact values as mentioned in the text
|
| 708 |
-
4. Return ONLY the JSON object, no other text
|
| 709 |
-
|
| 710 |
-
Text to analyze:
|
| 711 |
-
{user_input}
|
| 712 |
-
|
| 713 |
-
JSON:"""
|
| 714 |
-
)
|
| 715 |
-
|
| 716 |
-
try:
|
| 717 |
-
# Try to use LLM extraction if available
|
| 718 |
try:
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 728 |
except Exception:
|
| 729 |
-
# Fallback to improved regex extraction
|
| 730 |
extracted_json = _extract_with_improved_regex(user_input)
|
| 731 |
-
|
| 732 |
-
# Update extracted_data with LLM results, keeping NaN for empty values
|
| 733 |
-
for key, value in extracted_json.items():
|
| 734 |
-
if key in extracted_data and value and str(value).strip() and str(value).strip().lower() != 'nan':
|
| 735 |
-
extracted_data[key] = str(value).strip()
|
| 736 |
-
|
| 737 |
-
except Exception as e:
|
| 738 |
-
# If LLM extraction fails, use improved regex fallback
|
| 739 |
extracted_json = _extract_with_improved_regex(user_input)
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
return extracted_data
|
| 745 |
|
| 746 |
|
|
@@ -885,89 +830,3 @@ def _extract_side_effect_data(user_input: str) -> dict:
|
|
| 885 |
return _extract_side_effect_data_with_llm(user_input)
|
| 886 |
|
| 887 |
|
| 888 |
-
@tool
|
| 889 |
-
def medical_answer_validation_tool(
|
| 890 |
-
question: Optional[str] = None,
|
| 891 |
-
retrieved_documents: Optional[List] = None,
|
| 892 |
-
generated_answer: Optional[str] = None
|
| 893 |
-
) -> str:
|
| 894 |
-
"""
|
| 895 |
-
Validate a medical answer using the comprehensive validation system.
|
| 896 |
-
|
| 897 |
-
This tool evaluates medical responses across 6 criteria: Accuracy, Coherence,
|
| 898 |
-
Relevance, Completeness, Citations/Attribution, and Length.
|
| 899 |
-
|
| 900 |
-
Args:
|
| 901 |
-
question: The original medical question (optional - uses stored context if not provided)
|
| 902 |
-
retrieved_documents: List of documents used for the answer (optional - uses stored context)
|
| 903 |
-
generated_answer: The AI-generated answer to validate (optional - uses stored context)
|
| 904 |
-
|
| 905 |
-
Returns:
|
| 906 |
-
str: Formatted validation report with scores and improvement recommendations
|
| 907 |
-
"""
|
| 908 |
-
global _last_question, _last_documents, _last_answer, _last_user_question
|
| 909 |
-
|
| 910 |
-
try:
|
| 911 |
-
# Use provided parameters or fall back to stored context
|
| 912 |
-
# Prefer the original user question over the tool query
|
| 913 |
-
eval_question = question or _last_user_question or _last_question
|
| 914 |
-
eval_documents = retrieved_documents or _last_documents or []
|
| 915 |
-
eval_answer = generated_answer or _last_answer
|
| 916 |
-
|
| 917 |
-
# Validate that we have the required information
|
| 918 |
-
if not eval_question:
|
| 919 |
-
return "Error: No question available for validation. Please provide a question or ensure medical_guidelines_knowledge_tool was used first."
|
| 920 |
-
|
| 921 |
-
if not eval_answer:
|
| 922 |
-
return "Error: No answer available for validation. Please provide an answer to validate."
|
| 923 |
-
|
| 924 |
-
if not eval_documents:
|
| 925 |
-
return "Warning: No retrieved documents available for validation. Validation will proceed with limited context."
|
| 926 |
-
|
| 927 |
-
# Store the answer for future reference
|
| 928 |
-
if generated_answer:
|
| 929 |
-
_last_answer = generated_answer
|
| 930 |
-
|
| 931 |
-
# Perform validation
|
| 932 |
-
evaluation = validate_medical_answer(eval_question, eval_documents, eval_answer)
|
| 933 |
-
|
| 934 |
-
# Format the validation report for display
|
| 935 |
-
report = evaluation.get("validation_report", {})
|
| 936 |
-
|
| 937 |
-
formatted_report = f"""
|
| 938 |
-
**🔍 MEDICAL ANSWER VALIDATION REPORT**
|
| 939 |
-
|
| 940 |
-
**Interaction ID:** {evaluation.get('interaction_id', 'N/A')}
|
| 941 |
-
**Timestamp:** {evaluation.get('timestamp', 'N/A')}
|
| 942 |
-
|
| 943 |
-
**Overall Score:** {report.get('Overall_Rating', 'N/A')}/100
|
| 944 |
-
|
| 945 |
-
**Key Metrics:**
|
| 946 |
-
|
| 947 |
-
**Accuracy:** {report.get('Accuracy_Rating', 'N/A')}/100
|
| 948 |
-
{report.get('Accuracy_Comment', 'No comment available')}
|
| 949 |
-
|
| 950 |
-
**Coherence:** {report.get('Coherence_Rating', 'N/A')}/100
|
| 951 |
-
{report.get('Coherence_Comment', 'No comment available')}
|
| 952 |
-
|
| 953 |
-
**Relevance:** {report.get('Relevance_Rating', 'N/A')}/100
|
| 954 |
-
{report.get('Relevance_Comment', 'No comment available')}
|
| 955 |
-
|
| 956 |
-
**Completeness:** {report.get('Completeness_Rating', 'N/A')}/100
|
| 957 |
-
{report.get('Completeness_Comment', 'No comment available')}
|
| 958 |
-
|
| 959 |
-
**Citations:** {report.get('Citations_Attribution_Rating', 'N/A')}/100
|
| 960 |
-
{report.get('Citations_Attribution_Comment', 'No comment available')}
|
| 961 |
-
|
| 962 |
-
**Length:** {report.get('Length_Rating', 'N/A')}/100
|
| 963 |
-
{report.get('Length_Comment', 'No comment available')}
|
| 964 |
-
|
| 965 |
-
**Assessment:** {report.get('Final_Summary_and_Improvement_Plan', 'No improvement plan available')}
|
| 966 |
-
|
| 967 |
-
**📁 Data Storage:** Evaluation saved to evaluation_results.json
|
| 968 |
-
"""
|
| 969 |
-
|
| 970 |
-
return formatted_report.strip()
|
| 971 |
-
|
| 972 |
-
except Exception as e:
|
| 973 |
-
return f"Validation error: {str(e)}. Please ensure all required parameters are provided or that context is available from previous tool usage."
|
|
|
|
| 6 |
from typing import Optional, List
|
| 7 |
|
| 8 |
import pytz
|
| 9 |
+
from langchain.schema import Document, HumanMessage, SystemMessage
|
| 10 |
from langchain.tools import tool
|
| 11 |
from .retrievers import hybrid_search, vector_search, bm25_search
|
| 12 |
from .validation import validate_medical_answer
|
| 13 |
from .github_storage import get_github_storage
|
| 14 |
+
from langchain_openai import ChatOpenAI
|
| 15 |
|
| 16 |
CANONICAL_PROVIDERS = {"Manus", "ASCO", "NCCN", "ESMO", "NICE"}
|
| 17 |
|
|
|
|
| 29 |
global _last_user_question
|
| 30 |
_last_user_question = user_question
|
| 31 |
|
| 32 |
+
def _get_llm_safe(temperature: float = 0.0, model: str = "gpt-4o"):
|
| 33 |
+
"""Create a ChatOpenAI client if API key/config is available, else return None."""
|
| 34 |
+
try:
|
| 35 |
+
# ChatOpenAI will read OPENAI_API_KEY from env as in validation.py
|
| 36 |
+
return ChatOpenAI(model=model, temperature=temperature, max_tokens=512, request_timeout=30)
|
| 37 |
+
except Exception:
|
| 38 |
+
return None
|
| 39 |
+
|
| 40 |
+
def _is_side_effect_report_llm(user_input: str) -> Optional[bool]:
|
| 41 |
+
"""Use LLM to classify if input is an adverse drug reaction/side-effect report.
|
| 42 |
+
Returns True/False if confident, or None if unavailable/uncertain.
|
| 43 |
+
"""
|
| 44 |
+
llm = _get_llm_safe()
|
| 45 |
+
if not llm:
|
| 46 |
+
return None
|
| 47 |
+
try:
|
| 48 |
+
system = SystemMessage(content=(
|
| 49 |
+
"You are a medical triage classifier. Decide if the user's text is a report of an adverse drug reaction (side effect) about a medication.\n"
|
| 50 |
+
"Criteria: mentions a medication/drug and symptoms or adverse effects experienced by a patient.\n"
|
| 51 |
+
"Respond with exactly one token: yes or no."
|
| 52 |
+
))
|
| 53 |
+
human = HumanMessage(content=user_input[:1500])
|
| 54 |
+
resp = llm.invoke([system, human])
|
| 55 |
+
ans = (resp.content or "").strip().lower()
|
| 56 |
+
if ans.startswith("yes"):
|
| 57 |
+
return True
|
| 58 |
+
if ans.startswith("no"):
|
| 59 |
+
return False
|
| 60 |
+
return None
|
| 61 |
+
except Exception:
|
| 62 |
+
return None
|
| 63 |
+
|
| 64 |
# Map lowercase variants and full names to canonical provider codes
|
| 65 |
_PROVIDER_ALIASES = {
|
| 66 |
# NCCN
|
|
|
|
| 322 |
str: Interactive form for collecting missing information or confirmation of data recording
|
| 323 |
"""
|
| 324 |
try:
|
| 325 |
+
# LLM classification (preferred), with keyword fallback to preserve behavior
|
| 326 |
side_effect_keywords = [
|
| 327 |
'side effect', 'adverse reaction', 'adverse event', 'drug reaction',
|
| 328 |
'medication reaction', 'allergic reaction', 'complication', 'toxicity',
|
|
|
|
| 330 |
'overdose', 'poisoning', 'drug-induced', 'medication-induced',
|
| 331 |
'experienced after taking', 'developed after', 'caused by medication',
|
| 332 |
'drug-related', 'medication-related', 'pharmaceutical reaction',
|
|
|
|
| 333 |
'kidney problems', 'liver problems', 'heart problems', 'breathing problems',
|
| 334 |
'skin problems', 'stomach problems', 'nausea', 'vomiting', 'diarrhea',
|
| 335 |
'headache', 'dizziness', 'fatigue', 'weakness', 'rash', 'swelling',
|
|
|
|
| 337 |
'has these', 'has serious', 'causes', 'resulted in', 'led to',
|
| 338 |
'problems with', 'issues with', 'complications from'
|
| 339 |
]
|
|
|
|
| 340 |
input_lower = user_input.lower().strip()
|
| 341 |
+
llm_decision = _is_side_effect_report_llm(user_input)
|
| 342 |
|
| 343 |
# Check for special commands first
|
| 344 |
if input_lower in ['save report', 'save', 'submit report', 'submit']:
|
|
|
|
| 356 |
extracted_data = _extract_side_effect_data(user_input)
|
| 357 |
return _process_followup_response(user_input, extracted_data)
|
| 358 |
|
| 359 |
+
# Combine LLM decision with keyword fallback to avoid behavior regression
|
| 360 |
+
keyword_detected = any(keyword in input_lower for keyword in side_effect_keywords)
|
| 361 |
+
contains_side_effect = (llm_decision is True) or (llm_decision is not False and keyword_detected)
|
| 362 |
|
| 363 |
if not contains_side_effect:
|
| 364 |
return "This input does not appear to contain a side effect report. If you are reporting an adverse drug reaction, please include specific details about the medication and symptoms."
|
|
|
|
| 565 |
# Ensure the value is properly formatted
|
| 566 |
extracted_data[field] = str(value).strip()
|
| 567 |
|
| 568 |
+
# Save to GitHub repository (fallback to local if needed)
|
|
|
|
|
|
|
| 569 |
github_storage = get_github_storage()
|
|
|
|
|
|
|
| 570 |
success = github_storage.save_side_effects_report(extracted_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 571 |
if not success:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
csv_filename = "side_effects_reports.csv"
|
|
|
|
|
|
|
| 573 |
csv_path = os.path.join(os.getcwd(), csv_filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 574 |
file_exists = os.path.exists(csv_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 575 |
with open(csv_path, 'a', newline='', encoding='utf-8') as csvfile:
|
|
|
|
|
|
|
| 576 |
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
if not file_exists:
|
|
|
|
|
|
|
| 578 |
writer.writeheader()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
writer.writerow(extracted_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 580 |
storage_location = "locally to side_effects_reports.csv (GitHub upload failed)"
|
|
|
|
|
|
|
| 581 |
else:
|
|
|
|
|
|
|
| 582 |
storage_location = "to GitHub cloud repository"
|
| 583 |
|
| 584 |
+
# Generate confirmation message
|
|
|
|
|
|
|
| 585 |
drug_name = extracted_data.get('drug_name', 'NaN')
|
| 586 |
side_effects = extracted_data.get('side_effects', 'NaN')
|
| 587 |
report_id = extracted_data['timestamp'].replace(':', '').replace('-', '').replace(' ', '_')
|
|
|
|
| 643 |
Returns:
|
| 644 |
dict: Structured data extracted from the input
|
| 645 |
"""
|
|
|
|
|
|
|
| 646 |
import json
|
|
|
|
| 647 |
# Get current timestamp
|
| 648 |
egypt_tz = pytz.timezone('Africa/Cairo')
|
| 649 |
current_time = datetime.now(egypt_tz).strftime('%Y-%m-%d %H:%M:%S')
|
|
|
|
| 650 |
# Initialize extracted data with defaults
|
| 651 |
extracted_data = {
|
| 652 |
'timestamp': current_time,
|
|
|
|
| 660 |
'outcome': 'NaN',
|
| 661 |
'additional_details': 'NaN',
|
| 662 |
'reporter_info': 'NaN',
|
| 663 |
+
'raw_input': user_input[:500]
|
| 664 |
}
|
| 665 |
+
llm = _get_llm_safe()
|
| 666 |
+
if llm:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 667 |
try:
|
| 668 |
+
system = SystemMessage(content=(
|
| 669 |
+
"Extract medical side effect information. Return ONLY a JSON object with these exact fields: "
|
| 670 |
+
"drug_name, side_effects, patient_age, patient_gender, dosage, duration, severity, outcome. "
|
| 671 |
+
"If missing/unclear, use 'NaN'."
|
| 672 |
+
))
|
| 673 |
+
human = HumanMessage(content=user_input[:2000])
|
| 674 |
+
response = llm.invoke([system, human])
|
| 675 |
+
text = (response.content or "").strip()
|
| 676 |
+
# Try parse; if fails, fallback regex
|
| 677 |
+
try:
|
| 678 |
+
extracted_json = json.loads(text)
|
| 679 |
+
except json.JSONDecodeError:
|
| 680 |
+
extracted_json = _extract_with_improved_regex(user_input)
|
| 681 |
except Exception:
|
|
|
|
| 682 |
extracted_json = _extract_with_improved_regex(user_input)
|
| 683 |
+
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 684 |
extracted_json = _extract_with_improved_regex(user_input)
|
| 685 |
+
# Update extracted_data
|
| 686 |
+
for key, value in extracted_json.items():
|
| 687 |
+
if key in extracted_data and value and str(value).strip() and str(value).strip().lower() != 'nan':
|
| 688 |
+
extracted_data[key] = str(value).strip()
|
| 689 |
return extracted_data
|
| 690 |
|
| 691 |
|
|
|
|
| 830 |
return _extract_side_effect_data_with_llm(user_input)
|
| 831 |
|
| 832 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|