moazx committed on
Commit
2587e2c
·
1 Parent(s): ecaa031

Refactor agent and tools for session-based memory management and side effect reporting. Removed medical answer validation tool, added session memory management class, and enhanced side effect reporting with LLM classification. Updated agent functions to support session IDs for improved conversation tracking.

Browse files
Files changed (2) hide show
  1. core/agent.py +90 -20
  2. core/tools.py +62 -203
core/agent.py CHANGED
@@ -17,7 +17,6 @@ from .tools import (
17
  compare_providers_tool,
18
  get_current_datetime_tool,
19
  side_effect_recording_tool,
20
- medical_answer_validation_tool,
21
  )
22
 
23
  # LangSmith tracing utilities
@@ -85,7 +84,6 @@ AVAILABLE_TOOLS = [
85
  compare_providers_tool,
86
  get_current_datetime_tool,
87
  side_effect_recording_tool,
88
- medical_answer_validation_tool,
89
  ]
90
 
91
 
@@ -94,11 +92,16 @@ SYSTEM_MESSAGE = """
94
  You are an advanced Medical Advisor Chatbot for healthcare professionals.
95
  Your primary purpose is to answer clinical and medical questions strictly based on authoritative medical guidelines using the tool "medical_guidelines_knowledge_tool".
96
 
 
 
 
97
  **INSTRUCTIONS:**
98
  - Always answer using only the information retrieved from medical guidelines via "medical_guidelines_knowledge_tool".
99
  - **SIDE EFFECT REPORTING**: When a healthcare professional reports an adverse drug reaction, side effect, or medication-related complication, ALWAYS use the "side_effect_recording_tool" first to document the information. Return the tool's response directly to the user without modification. DO NOT use validation or generate additional reports for side effect reporting queries.
100
  - Use the side effect recording tool when the input contains phrases like: "patient experienced", "side effect", "adverse reaction", "drug reaction", "medication caused", "developed after taking", etc.
101
  - When the side effect recording tool requests additional information, present the request exactly as provided by the tool.
 
 
102
  - For every answer, you MUST provide detailed citations including:
103
  * Source file name
104
  * Page number
@@ -166,12 +169,47 @@ def get_agent_executor():
166
  max_execution_time=90, # tighten a bit to help responsiveness
167
  )
168
 
169
- # Initialize memory
170
- memory = ConversationBufferWindowMemory(
171
- memory_key="chat_history",
172
- return_messages=True,
173
- max_window_size=10
174
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
 
177
  # ============================================================================
@@ -304,7 +342,7 @@ def _perform_automatic_validation(user_input: str, response: str) -> str:
304
  # ============================================================================
305
 
306
  # @traceable(name="run_agent_streaming")
307
- async def run_agent_streaming(user_input: str, max_retries: int = 3) -> AsyncGenerator[str, None]:
308
  """
309
  Run the agent with streaming support and comprehensive error handling.
310
 
@@ -313,6 +351,7 @@ async def run_agent_streaming(user_input: str, max_retries: int = 3) -> AsyncGen
313
 
314
  Args:
315
  user_input (str): The user's input message to process
 
316
  max_retries (int, optional): Maximum number of retries for recoverable errors.
317
  Defaults to 3.
318
 
@@ -343,7 +382,8 @@ async def run_agent_streaming(user_input: str, max_retries: int = 3) -> AsyncGen
343
  # Tracing for streaming disabled to avoid duplicate traces.
344
  # We keep tracing only for the AgentExecutor in run_agent().
345
  current_run_id = None
346
- # Load conversation history from memory
 
347
  chat_history = memory.load_memory_variables({})["chat_history"]
348
 
349
  logger.info(f"Processing user input (attempt {retry_count + 1}): {user_input[:50]}...")
@@ -547,7 +587,7 @@ async def run_agent_streaming(user_input: str, max_retries: int = 3) -> AsyncGen
547
  yield "Sorry, I was unable to process your request after several attempts. Please try again later."
548
 
549
 
550
- async def safe_run_agent_streaming(user_input: str) -> AsyncGenerator[str, None]:
551
  """
552
  Streaming wrapper function with additional safety checks and input validation.
553
 
@@ -557,6 +597,7 @@ async def safe_run_agent_streaming(user_input: str) -> AsyncGenerator[str, None]
557
 
558
  Args:
559
  user_input (str): The user's input message to process
 
560
 
561
  Yields:
562
  str: Chunks of the agent's response as they are generated
@@ -585,7 +626,7 @@ async def safe_run_agent_streaming(user_input: str) -> AsyncGenerator[str, None]
585
  return
586
 
587
  # Stream the response through the main agent function
588
- async for chunk in run_agent_streaming(user_input):
589
  yield chunk
590
 
591
  except Exception as e:
@@ -595,7 +636,7 @@ async def safe_run_agent_streaming(user_input: str) -> AsyncGenerator[str, None]
595
 
596
 
597
  @traceable(name="run_agent")
598
- async def run_agent(user_input: str, max_retries: int = 3) -> str:
599
  """
600
  Run the agent with comprehensive error handling and retry logic.
601
 
@@ -605,6 +646,7 @@ async def run_agent(user_input: str, max_retries: int = 3) -> str:
605
 
606
  Args:
607
  user_input (str): The user's input message to process
 
608
  max_retries (int, optional): Maximum number of retries for recoverable errors.
609
  Defaults to 3.
610
 
@@ -626,7 +668,8 @@ async def run_agent(user_input: str, max_retries: int = 3) -> str:
626
 
627
  while retry_count <= max_retries:
628
  try:
629
- # Load conversation history from memory
 
630
  chat_history = memory.load_memory_variables({})["chat_history"]
631
 
632
  logger.info(f"Processing user input (attempt {retry_count + 1}): {user_input[:50]}...")
@@ -766,7 +809,7 @@ async def run_agent(user_input: str, max_retries: int = 3) -> str:
766
  return "Sorry, I was unable to process your request after several attempts. Please try again later."
767
 
768
 
769
- async def safe_run_agent(user_input: str) -> str:
770
  """
771
  Wrapper function for run_agent with additional safety checks and input validation.
772
 
@@ -776,6 +819,7 @@ async def safe_run_agent(user_input: str) -> str:
776
 
777
  Args:
778
  user_input (str): The user's input message to process
 
779
 
780
  Returns:
781
  str: The agent's response or an appropriate error message in English
@@ -801,7 +845,7 @@ async def safe_run_agent(user_input: str) -> str:
801
  return "Sorry, I didn't receive any questions. Please enter your question or request."
802
 
803
  # Process the input through the main agent function
804
- return await run_agent(user_input)
805
 
806
  except Exception as e:
807
  logger.critical(f"Critical error in safe_run_agent: {str(e)}")
@@ -817,23 +861,49 @@ def clear_memory() -> None:
817
  effectively starting a fresh conversation session.
818
  """
819
  try:
820
- memory.clear()
821
  logger.info("Conversation memory cleared successfully")
822
  except Exception as e:
823
  logger.error(f"Error clearing memory: {str(e)}")
824
 
825
 
826
- def get_memory_summary() -> str:
827
  """
828
- Get a summary of the current conversation memory.
 
 
 
829
 
830
  Returns:
831
  str: A summary of the conversation history stored in memory
832
  """
833
  try:
 
834
  memory_vars = memory.load_memory_variables({})
835
  return str(memory_vars.get("chat_history", "No conversation history available"))
836
  except Exception as e:
837
  logger.error(f"Error getting memory summary: {str(e)}")
838
- return "Error retrieving memory summary"
839
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  compare_providers_tool,
18
  get_current_datetime_tool,
19
  side_effect_recording_tool,
 
20
  )
21
 
22
  # LangSmith tracing utilities
 
84
  compare_providers_tool,
85
  get_current_datetime_tool,
86
  side_effect_recording_tool,
 
87
  ]
88
 
89
 
 
92
  You are an advanced Medical Advisor Chatbot for healthcare professionals.
93
  Your primary purpose is to answer clinical and medical questions strictly based on authoritative medical guidelines using the tool "medical_guidelines_knowledge_tool".
94
 
95
+ Your answers must be concise, medically informative, evidence-based responses in an authoritative, precise, and clinical tone.
96
+ You will be responding to practicing medical professionals so adjust your answer and language accordingly.
97
+
98
  **INSTRUCTIONS:**
99
  - Always answer using only the information retrieved from medical guidelines via "medical_guidelines_knowledge_tool".
100
  - **SIDE EFFECT REPORTING**: When a healthcare professional reports an adverse drug reaction, side effect, or medication-related complication, ALWAYS use the "side_effect_recording_tool" first to document the information. Return the tool's response directly to the user without modification. DO NOT use validation or generate additional reports for side effect reporting queries.
101
  - Use the side effect recording tool when the input contains phrases like: "patient experienced", "side effect", "adverse reaction", "drug reaction", "medication caused", "developed after taking", etc.
102
  - When the side effect recording tool requests additional information, present the request exactly as provided by the tool.
103
+ - **PROVIDER COMPARISON**: When the user asks to compare guidance between two providers (e.g., "compare NCCN vs ESMO on ..."), use the "compare_providers_tool" with appropriate `provider_a` and `provider_b` values to retrieve side-by-side, cited results.
104
+ - **TIME/DATE QUERIES**: For any questions about the current date/time or references like "today" or "now", use the "get_current_datetime_tool". Treat this tool as the only reliable source of current time information.
105
  - For every answer, you MUST provide detailed citations including:
106
  * Source file name
107
  * Page number
 
169
  max_execution_time=90, # tighten a bit to help responsiveness
170
  )
171
 
172
+ # ============================================================================
173
+ # SESSION-BASED MEMORY MANAGEMENT
174
+ # ============================================================================
175
+
176
class SessionMemoryManager:
    """Manages conversation memory for multiple sessions.

    Each session ID maps to its own ConversationBufferWindowMemory, created
    lazily on first access, so the histories of different sessions are kept
    in separate memory objects.
    """

    def __init__(self, window_size: int = 10):
        """
        Create a manager with no sessions.

        Args:
            window_size (int, optional): Value passed as ``k`` to every
                ConversationBufferWindowMemory this manager creates.
                Defaults to 10.
        """
        # session_id -> memory object for that session
        self._sessions = {}
        self._default_window_size = window_size

    def get_memory(self, session_id: str = "default") -> "ConversationBufferWindowMemory":
        """
        Get or create memory for a specific session.

        Args:
            session_id (str, optional): Session identifier. Defaults to "default".

        Returns:
            ConversationBufferWindowMemory: The memory object bound to the session;
                the same object is returned on every later call for that ID.
        """
        if session_id not in self._sessions:
            self._sessions[session_id] = ConversationBufferWindowMemory(
                memory_key="chat_history",
                return_messages=True,
                # The window-size field of ConversationBufferWindowMemory is `k`.
                # `max_window_size` is not a documented field, so passing it did
                # not apply the intended window.
                k=self._default_window_size,
            )
        return self._sessions[session_id]

    def clear_session(self, session_id: str) -> bool:
        """
        Clear memory for a specific session and forget the session.

        Args:
            session_id (str): Session identifier to clear.

        Returns:
            bool: True if the session existed and was removed, False otherwise.
        """
        if session_id not in self._sessions:
            return False
        self._sessions[session_id].clear()
        del self._sessions[session_id]
        return True

    def clear_all_sessions(self) -> None:
        """Clear every session's memory, then drop all sessions."""
        for memory in self._sessions.values():
            memory.clear()
        self._sessions.clear()

    def get_active_sessions(self) -> list:
        """
        Get list of active session IDs.

        Returns:
            list: IDs of all sessions currently held, in creation order.
        """
        return list(self._sessions)
210
+
211
+ # Global session memory manager
212
+ _memory_manager = SessionMemoryManager()
213
 
214
 
215
  # ============================================================================
 
342
  # ============================================================================
343
 
344
  # @traceable(name="run_agent_streaming")
345
+ async def run_agent_streaming(user_input: str, session_id: str = "default", max_retries: int = 3) -> AsyncGenerator[str, None]:
346
  """
347
  Run the agent with streaming support and comprehensive error handling.
348
 
 
351
 
352
  Args:
353
  user_input (str): The user's input message to process
354
+ session_id (str, optional): Session identifier for conversation memory. Defaults to "default".
355
  max_retries (int, optional): Maximum number of retries for recoverable errors.
356
  Defaults to 3.
357
 
 
382
  # Tracing for streaming disabled to avoid duplicate traces.
383
  # We keep tracing only for the AgentExecutor in run_agent().
384
  current_run_id = None
385
+ # Load conversation history from session-specific memory
386
+ memory = _memory_manager.get_memory(session_id)
387
  chat_history = memory.load_memory_variables({})["chat_history"]
388
 
389
  logger.info(f"Processing user input (attempt {retry_count + 1}): {user_input[:50]}...")
 
587
  yield "Sorry, I was unable to process your request after several attempts. Please try again later."
588
 
589
 
590
+ async def safe_run_agent_streaming(user_input: str, session_id: str = "default") -> AsyncGenerator[str, None]:
591
  """
592
  Streaming wrapper function with additional safety checks and input validation.
593
 
 
597
 
598
  Args:
599
  user_input (str): The user's input message to process
600
+ session_id (str, optional): Session identifier for conversation memory. Defaults to "default".
601
 
602
  Yields:
603
  str: Chunks of the agent's response as they are generated
 
626
  return
627
 
628
  # Stream the response through the main agent function
629
+ async for chunk in run_agent_streaming(user_input, session_id):
630
  yield chunk
631
 
632
  except Exception as e:
 
636
 
637
 
638
  @traceable(name="run_agent")
639
+ async def run_agent(user_input: str, session_id: str = "default", max_retries: int = 3) -> str:
640
  """
641
  Run the agent with comprehensive error handling and retry logic.
642
 
 
646
 
647
  Args:
648
  user_input (str): The user's input message to process
649
+ session_id (str, optional): Session identifier for conversation memory. Defaults to "default".
650
  max_retries (int, optional): Maximum number of retries for recoverable errors.
651
  Defaults to 3.
652
 
 
668
 
669
  while retry_count <= max_retries:
670
  try:
671
+ # Load conversation history from session-specific memory
672
+ memory = _memory_manager.get_memory(session_id)
673
  chat_history = memory.load_memory_variables({})["chat_history"]
674
 
675
  logger.info(f"Processing user input (attempt {retry_count + 1}): {user_input[:50]}...")
 
809
  return "Sorry, I was unable to process your request after several attempts. Please try again later."
810
 
811
 
812
+ async def safe_run_agent(user_input: str, session_id: str = "default") -> str:
813
  """
814
  Wrapper function for run_agent with additional safety checks and input validation.
815
 
 
819
 
820
  Args:
821
  user_input (str): The user's input message to process
822
+ session_id (str, optional): Session identifier for conversation memory. Defaults to "default".
823
 
824
  Returns:
825
  str: The agent's response or an appropriate error message in English
 
845
  return "Sorry, I didn't receive any questions. Please enter your question or request."
846
 
847
  # Process the input through the main agent function
848
+ return await run_agent(user_input, session_id)
849
 
850
  except Exception as e:
851
  logger.critical(f"Critical error in safe_run_agent: {str(e)}")
 
861
  effectively starting a fresh conversation session.
862
  """
863
  try:
864
+ _memory_manager.clear_all_sessions()
865
  logger.info("Conversation memory cleared successfully")
866
  except Exception as e:
867
  logger.error(f"Error clearing memory: {str(e)}")
868
 
869
 
870
def get_memory_summary(session_id: str = "default") -> str:
    """
    Return the stored conversation history of one session as a string.

    Args:
        session_id (str, optional): Session identifier. Defaults to "default".

    Returns:
        str: String form of the session's "chat_history" value, a placeholder
            text when that key is absent, or an error message if retrieval fails.
    """
    try:
        # NOTE: get_memory() creates the session when it does not exist yet.
        session_memory = _memory_manager.get_memory(session_id)
        variables = session_memory.load_memory_variables({})
        history = variables.get("chat_history", "No conversation history available")
        return str(history)
    except Exception as e:
        logger.error(f"Error getting memory summary: {str(e)}")
        return "Error retrieving conversation history"
887
 
888
+
889
def clear_session_memory(session_id: str) -> bool:
    """
    Drop the conversation memory held for a single session.

    Delegates to the module-level session memory manager.

    Args:
        session_id (str): Session identifier to clear

    Returns:
        bool: True if session was cleared, False if session didn't exist
    """
    was_cleared = _memory_manager.clear_session(session_id)
    return was_cleared
900
+
901
+
902
def get_active_sessions() -> list:
    """
    List the IDs of every session the memory manager currently holds.

    Returns:
        list: List of active session identifiers
    """
    session_ids = _memory_manager.get_active_sessions()
    return session_ids
core/tools.py CHANGED
@@ -6,11 +6,12 @@ from datetime import datetime
6
  from typing import Optional, List
7
 
8
  import pytz
9
- from langchain.schema import Document
10
  from langchain.tools import tool
11
  from .retrievers import hybrid_search, vector_search, bm25_search
12
  from .validation import validate_medical_answer
13
  from .github_storage import get_github_storage
 
14
 
15
  CANONICAL_PROVIDERS = {"Manus", "ASCO", "NCCN", "ESMO", "NICE"}
16
 
@@ -28,6 +29,38 @@ def store_user_question(user_question: str):
28
  global _last_user_question
29
  _last_user_question = user_question
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  # Map lowercase variants and full names to canonical provider codes
32
  _PROVIDER_ALIASES = {
33
  # NCCN
@@ -289,7 +322,7 @@ def side_effect_recording_tool(user_input: str) -> str:
289
  str: Interactive form for collecting missing information or confirmation of data recording
290
  """
291
  try:
292
- # Define keywords that indicate side effect reporting
293
  side_effect_keywords = [
294
  'side effect', 'adverse reaction', 'adverse event', 'drug reaction',
295
  'medication reaction', 'allergic reaction', 'complication', 'toxicity',
@@ -297,7 +330,6 @@ def side_effect_recording_tool(user_input: str) -> str:
297
  'overdose', 'poisoning', 'drug-induced', 'medication-induced',
298
  'experienced after taking', 'developed after', 'caused by medication',
299
  'drug-related', 'medication-related', 'pharmaceutical reaction',
300
- # Add more comprehensive problem/symptom keywords
301
  'kidney problems', 'liver problems', 'heart problems', 'breathing problems',
302
  'skin problems', 'stomach problems', 'nausea', 'vomiting', 'diarrhea',
303
  'headache', 'dizziness', 'fatigue', 'weakness', 'rash', 'swelling',
@@ -305,8 +337,8 @@ def side_effect_recording_tool(user_input: str) -> str:
305
  'has these', 'has serious', 'causes', 'resulted in', 'led to',
306
  'problems with', 'issues with', 'complications from'
307
  ]
308
-
309
  input_lower = user_input.lower().strip()
 
310
 
311
  # Check for special commands first
312
  if input_lower in ['save report', 'save', 'submit report', 'submit']:
@@ -324,8 +356,9 @@ def side_effect_recording_tool(user_input: str) -> str:
324
  extracted_data = _extract_side_effect_data(user_input)
325
  return _process_followup_response(user_input, extracted_data)
326
 
327
- # Check if input contains side effect reporting indicators
328
- contains_side_effect = any(keyword in input_lower for keyword in side_effect_keywords)
 
329
 
330
  if not contains_side_effect:
331
  return "This input does not appear to contain a side effect report. If you are reporting an adverse drug reaction, please include specific details about the medication and symptoms."
@@ -532,74 +565,23 @@ def _save_side_effect_report(extracted_data: dict) -> str:
532
  # Ensure the value is properly formatted
533
  extracted_data[field] = str(value).strip()
534
 
535
- # Save to GitHub repository
536
-
537
-
538
  github_storage = get_github_storage()
539
-
540
-
541
  success = github_storage.save_side_effects_report(extracted_data)
542
-
543
-
544
-
545
-
546
-
547
  if not success:
548
-
549
-
550
- # Fallback to local storage if GitHub fails
551
-
552
-
553
  csv_filename = "side_effects_reports.csv"
554
-
555
-
556
  csv_path = os.path.join(os.getcwd(), csv_filename)
557
-
558
-
559
-
560
-
561
-
562
  file_exists = os.path.exists(csv_path)
563
-
564
-
565
-
566
-
567
-
568
  with open(csv_path, 'a', newline='', encoding='utf-8') as csvfile:
569
-
570
-
571
  writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
572
-
573
-
574
-
575
-
576
-
577
  if not file_exists:
578
-
579
-
580
  writer.writeheader()
581
-
582
-
583
-
584
-
585
-
586
  writer.writerow(extracted_data)
587
-
588
-
589
-
590
-
591
-
592
  storage_location = "locally to side_effects_reports.csv (GitHub upload failed)"
593
-
594
-
595
  else:
596
-
597
-
598
  storage_location = "to GitHub cloud repository"
599
 
600
-
601
-
602
- # Generate confirmation message
603
  drug_name = extracted_data.get('drug_name', 'NaN')
604
  side_effects = extracted_data.get('side_effects', 'NaN')
605
  report_id = extracted_data['timestamp'].replace(':', '').replace('-', '').replace(' ', '_')
@@ -661,14 +643,10 @@ def _extract_side_effect_data_with_llm(user_input: str) -> dict:
661
  Returns:
662
  dict: Structured data extracted from the input
663
  """
664
- from langchain.llms import OpenAI
665
- from langchain.prompts import PromptTemplate
666
  import json
667
-
668
  # Get current timestamp
669
  egypt_tz = pytz.timezone('Africa/Cairo')
670
  current_time = datetime.now(egypt_tz).strftime('%Y-%m-%d %H:%M:%S')
671
-
672
  # Initialize extracted data with defaults
673
  extracted_data = {
674
  'timestamp': current_time,
@@ -682,65 +660,32 @@ def _extract_side_effect_data_with_llm(user_input: str) -> dict:
682
  'outcome': 'NaN',
683
  'additional_details': 'NaN',
684
  'reporter_info': 'NaN',
685
- 'raw_input': user_input[:500] # Limit raw input length
686
  }
687
-
688
- # Create extraction prompt
689
- extraction_prompt = PromptTemplate(
690
- input_variables=["user_input"],
691
- template="""Extract medical side effect information from the following text. Return ONLY a JSON object with these exact fields:
692
-
693
- {
694
- "drug_name": "name of the medication/drug mentioned",
695
- "side_effects": "list of side effects or symptoms described",
696
- "patient_age": "patient's age if mentioned",
697
- "patient_gender": "Male or Female if mentioned",
698
- "dosage": "medication dosage if mentioned",
699
- "duration": "treatment duration if mentioned (e.g., '3 months', '2 weeks')",
700
- "severity": "mild, moderate, or severe if mentioned",
701
- "outcome": "current status like ongoing, resolved, recovered if mentioned"
702
- }
703
-
704
- IMPORTANT RULES:
705
- 1. If any information is not found or unclear, use "NaN" as the value
706
- 2. For duration, look for phrases like "Treatment duration: 3 months", "for 2 weeks", "over 6 months", etc.
707
- 3. Extract exact values as mentioned in the text
708
- 4. Return ONLY the JSON object, no other text
709
-
710
- Text to analyze:
711
- {user_input}
712
-
713
- JSON:"""
714
- )
715
-
716
- try:
717
- # Try to use LLM extraction if available
718
  try:
719
- # This would use the actual LLM - for now we'll use a fallback approach
720
- # llm = OpenAI(temperature=0)
721
- # prompt = extraction_prompt.format(user_input=user_input)
722
- # response = llm(prompt)
723
- # extracted_json = json.loads(response.strip())
724
-
725
- # Fallback to improved regex-based extraction with better duration handling
726
- extracted_json = _extract_with_improved_regex(user_input)
727
-
 
 
 
 
728
  except Exception:
729
- # Fallback to improved regex extraction
730
  extracted_json = _extract_with_improved_regex(user_input)
731
-
732
- # Update extracted_data with LLM results, keeping NaN for empty values
733
- for key, value in extracted_json.items():
734
- if key in extracted_data and value and str(value).strip() and str(value).strip().lower() != 'nan':
735
- extracted_data[key] = str(value).strip()
736
-
737
- except Exception as e:
738
- # If LLM extraction fails, use improved regex fallback
739
  extracted_json = _extract_with_improved_regex(user_input)
740
- for key, value in extracted_json.items():
741
- if key in extracted_data and value and str(value).strip() and str(value).strip().lower() != 'nan':
742
- extracted_data[key] = str(value).strip()
743
-
744
  return extracted_data
745
 
746
 
@@ -885,89 +830,3 @@ def _extract_side_effect_data(user_input: str) -> dict:
885
  return _extract_side_effect_data_with_llm(user_input)
886
 
887
 
888
- @tool
889
- def medical_answer_validation_tool(
890
- question: Optional[str] = None,
891
- retrieved_documents: Optional[List] = None,
892
- generated_answer: Optional[str] = None
893
- ) -> str:
894
- """
895
- Validate a medical answer using the comprehensive validation system.
896
-
897
- This tool evaluates medical responses across 6 criteria: Accuracy, Coherence,
898
- Relevance, Completeness, Citations/Attribution, and Length.
899
-
900
- Args:
901
- question: The original medical question (optional - uses stored context if not provided)
902
- retrieved_documents: List of documents used for the answer (optional - uses stored context)
903
- generated_answer: The AI-generated answer to validate (optional - uses stored context)
904
-
905
- Returns:
906
- str: Formatted validation report with scores and improvement recommendations
907
- """
908
- global _last_question, _last_documents, _last_answer, _last_user_question
909
-
910
- try:
911
- # Use provided parameters or fall back to stored context
912
- # Prefer the original user question over the tool query
913
- eval_question = question or _last_user_question or _last_question
914
- eval_documents = retrieved_documents or _last_documents or []
915
- eval_answer = generated_answer or _last_answer
916
-
917
- # Validate that we have the required information
918
- if not eval_question:
919
- return "Error: No question available for validation. Please provide a question or ensure medical_guidelines_knowledge_tool was used first."
920
-
921
- if not eval_answer:
922
- return "Error: No answer available for validation. Please provide an answer to validate."
923
-
924
- if not eval_documents:
925
- return "Warning: No retrieved documents available for validation. Validation will proceed with limited context."
926
-
927
- # Store the answer for future reference
928
- if generated_answer:
929
- _last_answer = generated_answer
930
-
931
- # Perform validation
932
- evaluation = validate_medical_answer(eval_question, eval_documents, eval_answer)
933
-
934
- # Format the validation report for display
935
- report = evaluation.get("validation_report", {})
936
-
937
- formatted_report = f"""
938
- **🔍 MEDICAL ANSWER VALIDATION REPORT**
939
-
940
- **Interaction ID:** {evaluation.get('interaction_id', 'N/A')}
941
- **Timestamp:** {evaluation.get('timestamp', 'N/A')}
942
-
943
- **Overall Score:** {report.get('Overall_Rating', 'N/A')}/100
944
-
945
- **Key Metrics:**
946
-
947
- **Accuracy:** {report.get('Accuracy_Rating', 'N/A')}/100
948
- {report.get('Accuracy_Comment', 'No comment available')}
949
-
950
- **Coherence:** {report.get('Coherence_Rating', 'N/A')}/100
951
- {report.get('Coherence_Comment', 'No comment available')}
952
-
953
- **Relevance:** {report.get('Relevance_Rating', 'N/A')}/100
954
- {report.get('Relevance_Comment', 'No comment available')}
955
-
956
- **Completeness:** {report.get('Completeness_Rating', 'N/A')}/100
957
- {report.get('Completeness_Comment', 'No comment available')}
958
-
959
- **Citations:** {report.get('Citations_Attribution_Rating', 'N/A')}/100
960
- {report.get('Citations_Attribution_Comment', 'No comment available')}
961
-
962
- **Length:** {report.get('Length_Rating', 'N/A')}/100
963
- {report.get('Length_Comment', 'No comment available')}
964
-
965
- **Assessment:** {report.get('Final_Summary_and_Improvement_Plan', 'No improvement plan available')}
966
-
967
- **📁 Data Storage:** Evaluation saved to evaluation_results.json
968
- """
969
-
970
- return formatted_report.strip()
971
-
972
- except Exception as e:
973
- return f"Validation error: {str(e)}. Please ensure all required parameters are provided or that context is available from previous tool usage."
 
6
  from typing import Optional, List
7
 
8
  import pytz
9
+ from langchain.schema import Document, HumanMessage, SystemMessage
10
  from langchain.tools import tool
11
  from .retrievers import hybrid_search, vector_search, bm25_search
12
  from .validation import validate_medical_answer
13
  from .github_storage import get_github_storage
14
+ from langchain_openai import ChatOpenAI
15
 
16
  CANONICAL_PROVIDERS = {"Manus", "ASCO", "NCCN", "ESMO", "NICE"}
17
 
 
29
  global _last_user_question
30
  _last_user_question = user_question
31
 
32
def _get_llm_safe(temperature: float = 0.0, model: str = "gpt-4o"):
    """
    Build a ChatOpenAI client, or return None when construction fails.

    Args:
        temperature (float, optional): Sampling temperature. Defaults to 0.0.
        model (str, optional): Model name passed to ChatOpenAI. Defaults to "gpt-4o".

    Returns:
        The ChatOpenAI instance, or None if creating it raised any exception.
    """
    # Per the original note, ChatOpenAI reads OPENAI_API_KEY from the
    # environment (as in validation.py); no key is passed explicitly here.
    client_options = {
        "model": model,
        "temperature": temperature,
        "max_tokens": 512,
        "request_timeout": 30,
    }
    try:
        client = ChatOpenAI(**client_options)
    except Exception:
        # Best effort: callers treat None as "LLM unavailable" and fall back.
        return None
    return client
39
+
40
def _parse_yes_no(text) -> Optional[bool]:
    """Map a classifier reply to True ("yes"), False ("no"), or None for anything else."""
    if not isinstance(text, str):
        return None
    words = text.strip().lower().split(None, 1)
    if not words:
        return None
    # Compare the whole first word, minus surrounding punctuation, so replies
    # such as "not sure", "none" or "yesterday" are not read as yes/no.
    first_word = words[0].strip(".,;:!?\"'*`")
    if first_word == "yes":
        return True
    if first_word == "no":
        return False
    return None


def _is_side_effect_report_llm(user_input: str) -> Optional[bool]:
    """
    Use an LLM to classify whether the input is an adverse drug reaction report.

    Args:
        user_input (str): Raw user text; only the first 1500 characters are sent.

    Returns:
        Optional[bool]: True or False when the model's reply begins with the
            word "yes" or "no" respectively; None when no LLM client is
            available, the call fails, or the reply is anything else.
    """
    llm = _get_llm_safe()
    if not llm:
        return None
    try:
        system = SystemMessage(content=(
            "You are a medical triage classifier. Decide if the user's text is a report of an adverse drug reaction (side effect) about a medication.\n"
            "Criteria: mentions a medication/drug and symptoms or adverse effects experienced by a patient.\n"
            "Respond with exactly one token: yes or no."
        ))
        human = HumanMessage(content=user_input[:1500])
        resp = llm.invoke([system, human])
        return _parse_yes_no(resp.content)
    except Exception:
        # Best effort: any failure means "no LLM verdict", not an error.
        return None
63
+
64
  # Map lowercase variants and full names to canonical provider codes
65
  _PROVIDER_ALIASES = {
66
  # NCCN
 
322
  str: Interactive form for collecting missing information or confirmation of data recording
323
  """
324
  try:
325
+ # LLM classification (preferred), with keyword fallback to preserve behavior
326
  side_effect_keywords = [
327
  'side effect', 'adverse reaction', 'adverse event', 'drug reaction',
328
  'medication reaction', 'allergic reaction', 'complication', 'toxicity',
 
330
  'overdose', 'poisoning', 'drug-induced', 'medication-induced',
331
  'experienced after taking', 'developed after', 'caused by medication',
332
  'drug-related', 'medication-related', 'pharmaceutical reaction',
 
333
  'kidney problems', 'liver problems', 'heart problems', 'breathing problems',
334
  'skin problems', 'stomach problems', 'nausea', 'vomiting', 'diarrhea',
335
  'headache', 'dizziness', 'fatigue', 'weakness', 'rash', 'swelling',
 
337
  'has these', 'has serious', 'causes', 'resulted in', 'led to',
338
  'problems with', 'issues with', 'complications from'
339
  ]
 
340
  input_lower = user_input.lower().strip()
341
+ llm_decision = _is_side_effect_report_llm(user_input)
342
 
343
  # Check for special commands first
344
  if input_lower in ['save report', 'save', 'submit report', 'submit']:
 
356
  extracted_data = _extract_side_effect_data(user_input)
357
  return _process_followup_response(user_input, extracted_data)
358
 
359
+ # Combine LLM decision with keyword fallback to avoid behavior regression
360
+ keyword_detected = any(keyword in input_lower for keyword in side_effect_keywords)
361
+ contains_side_effect = (llm_decision is True) or (llm_decision is not False and keyword_detected)
362
 
363
  if not contains_side_effect:
364
  return "This input does not appear to contain a side effect report. If you are reporting an adverse drug reaction, please include specific details about the medication and symptoms."
 
565
  # Ensure the value is properly formatted
566
  extracted_data[field] = str(value).strip()
567
 
568
+ # Save to GitHub repository (fallback to local if needed)
 
 
569
  github_storage = get_github_storage()
 
 
570
  success = github_storage.save_side_effects_report(extracted_data)
 
 
 
 
 
571
  if not success:
 
 
 
 
 
572
  csv_filename = "side_effects_reports.csv"
 
 
573
  csv_path = os.path.join(os.getcwd(), csv_filename)
 
 
 
 
 
574
  file_exists = os.path.exists(csv_path)
 
 
 
 
 
575
  with open(csv_path, 'a', newline='', encoding='utf-8') as csvfile:
 
 
576
  writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
 
 
 
 
 
577
  if not file_exists:
 
 
578
  writer.writeheader()
 
 
 
 
 
579
  writer.writerow(extracted_data)
 
 
 
 
 
580
  storage_location = "locally to side_effects_reports.csv (GitHub upload failed)"
 
 
581
  else:
 
 
582
  storage_location = "to GitHub cloud repository"
583
 
584
+ # Generate confirmation message
 
 
585
  drug_name = extracted_data.get('drug_name', 'NaN')
586
  side_effects = extracted_data.get('side_effects', 'NaN')
587
  report_id = extracted_data['timestamp'].replace(':', '').replace('-', '').replace(' ', '_')
 
643
  Returns:
644
  dict: Structured data extracted from the input
645
  """
 
 
646
  import json
 
647
  # Get current timestamp
648
  egypt_tz = pytz.timezone('Africa/Cairo')
649
  current_time = datetime.now(egypt_tz).strftime('%Y-%m-%d %H:%M:%S')
 
650
  # Initialize extracted data with defaults
651
  extracted_data = {
652
  'timestamp': current_time,
 
660
  'outcome': 'NaN',
661
  'additional_details': 'NaN',
662
  'reporter_info': 'NaN',
663
+ 'raw_input': user_input[:500]
664
  }
665
+ llm = _get_llm_safe()
666
+ if llm:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
667
  try:
668
+ system = SystemMessage(content=(
669
+ "Extract medical side effect information. Return ONLY a JSON object with these exact fields: "
670
+ "drug_name, side_effects, patient_age, patient_gender, dosage, duration, severity, outcome. "
671
+ "If missing/unclear, use 'NaN'."
672
+ ))
673
+ human = HumanMessage(content=user_input[:2000])
674
+ response = llm.invoke([system, human])
675
+ text = (response.content or "").strip()
676
+ # Try parse; if fails, fallback regex
677
+ try:
678
+ extracted_json = json.loads(text)
679
+ except json.JSONDecodeError:
680
+ extracted_json = _extract_with_improved_regex(user_input)
681
  except Exception:
 
682
  extracted_json = _extract_with_improved_regex(user_input)
683
+ else:
 
 
 
 
 
 
 
684
  extracted_json = _extract_with_improved_regex(user_input)
685
+ # Update extracted_data
686
+ for key, value in extracted_json.items():
687
+ if key in extracted_data and value and str(value).strip() and str(value).strip().lower() != 'nan':
688
+ extracted_data[key] = str(value).strip()
689
  return extracted_data
690
 
691
 
 
830
  return _extract_side_effect_data_with_llm(user_input)
831
 
832