"""Conversation helpers and the ``assistant`` step that drives the tool-calling LLM."""

import json

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage

from chatlib.state_types import AppState

# Name of the tool whose retrieved context is versioned and re-injected.
_IDSR_TOOL = "idsr_check"

# Prefixes of the system messages this module generates itself.
_PK_HASH_PREFIX = "Patient identifier (pk_hash):"
_SITECODE_PREFIX = "Site code:"
_SUMMARY_PREFIX = "Summary of earlier conversation:\n"

# History is compressed once it holds more than _MAX_HISTORY Human/AI
# messages; the newest _KEEP_RECENT of them are kept verbatim.
_MAX_HISTORY = 10
_KEEP_RECENT = 5


def remove_tool_call_messages(messages):
    """Return ``messages`` without tool-call round trips.

    Drops every ``AIMessage`` that carries ``tool_calls`` and every
    ``ToolMessage`` whose ``tool_call_id`` belongs to one of those dropped
    messages. Matching is order dependent: a ``ToolMessage`` is only removed
    if its originating ``AIMessage`` appeared earlier in the list.

    Args:
        messages: Iterable of message objects.

    Returns:
        A new list; the input is not modified.
    """
    kept = []
    dropped_call_ids = set()
    for msg in messages:
        if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
            # Remember the ids so the matching tool results are dropped too.
            dropped_call_ids.update(call["id"] for call in msg.tool_calls)
            continue
        if isinstance(msg, ToolMessage) and msg.tool_call_id in dropped_call_ids:
            continue
        kept.append(msg)
    return kept


def summarize_conversation(messages, llm, previous_summary=None):
    """Summarize the Human/AI turns of a conversation with ``llm``.

    System and tool messages are excluded from the transcript.

    Args:
        messages: Iterable of message objects.
        llm: Object exposing ``invoke(list_of_messages)`` whose result has a
            ``content`` attribute.
        previous_summary: Optional summary of turns that are no longer in
            ``messages``. When given it is placed ahead of the transcript so
            earlier facts survive repeated summarization.

    Returns:
        The ``content`` of the LLM response.
    """
    history = [m for m in messages if isinstance(m, (HumanMessage, AIMessage))]
    text = "\n\n".join(
        f"{'User' if isinstance(m, HumanMessage) else 'Assistant'}: {m.content}"
        for m in history
    )
    if previous_summary:
        text = f"{_SUMMARY_PREFIX}{previous_summary}\n\n{text}"
    prompt = (
        "Summarize the clinical conversation below in a way that retains all key clinical facts and decisions.\n\n"
        f"{text}\n\nSummary:"
    )
    response = llm.invoke([SystemMessage(content=prompt)])
    return response.content


def _ensure_state_defaults(state):
    """Fill in every state key ``assistant`` relies on, leaving existing values alone."""
    # The dict literal is rebuilt on each call so the mutable defaults
    # ({}), are never shared between different state objects.
    defaults = {
        "question": "",
        "pk_hash": "",
        "sitecode": "",
        "rag_result": "",
        "rag_sources": "",
        "answer": "",
        "last_answer": None,
        "last_user_message": None,
        "last_tool": None,
        "idsr_disclaimer_shown": False,
        "summary": None,
        "context": None,
        "context_versions": {},
        "last_context_injected_versions": {},
        "context_version_ready_for_injection": 0,
        "context_first_response_sent": True,
    }
    for key, value in defaults.items():
        state.setdefault(key, value)


def _build_prompt_messages(state, sys_msg):
    """Assemble the message list for this turn from ``sys_msg`` and stored history.

    All previously stored system messages are discarded and regenerated:
    ``sys_msg`` first, then the stored conversation summary (if any), then the
    non-system history, and finally the patient/site identifier messages.
    """
    history = [
        m for m in (state.get("messages") or []) if not isinstance(m, SystemMessage)
    ]

    messages = [sys_msg]
    summary = state.get("summary")
    if summary:
        # Stored system messages were stripped above, so the summary has to be
        # rebuilt from state or the LLM would never see it after the turn that
        # produced it.
        messages.append(SystemMessage(content=_SUMMARY_PREFIX + summary))
    messages.extend(history)

    # Defensive: make sure no stale identifier message survives (this can only
    # match sys_msg itself, since other system messages were removed above).
    messages = [
        m
        for m in messages
        if not (
            isinstance(m, SystemMessage)
            and isinstance(m.content, str)
            and m.content.startswith((_PK_HASH_PREFIX, _SITECODE_PREFIX))
        )
    ]

    # Inject pk_hash and sitecode as system messages when they are non-empty.
    pk_hash_value = state.get("pk_hash")
    if pk_hash_value:
        messages.append(
            SystemMessage(content=f"Patient identifier (pk_hash): {pk_hash_value}")
        )
    sitecode_value = state.get("sitecode")
    if sitecode_value:
        messages.append(SystemMessage(content=f"Site code: {sitecode_value}"))
    return messages


def _apply_latest_tool_result(state, messages):
    """Merge the payload of the most recent ``ToolMessage`` into ``state``.

    Only the latest ``ToolMessage`` is examined. Its content is used directly
    when it is not a string, otherwise it is decoded as JSON. Payloads that
    are not valid JSON or do not decode to a dict are ignored.
    """
    latest = next((m for m in reversed(messages) if isinstance(m, ToolMessage)), None)
    if latest is None:
        return

    content = latest.content
    if isinstance(content, str):
        try:
            data = json.loads(content)
        except json.JSONDecodeError:
            return
    else:
        data = content
    if not isinstance(data, dict):
        # e.g. a list of content blocks or a bare JSON scalar: nothing to merge.
        return

    reported_tool = data.get("last_tool")
    new_context = data.get("context")
    if reported_tool:
        old_context = state.get("context", "")
        old_version = state["context_versions"].get(reported_tool, 0)
        if new_context is not None and new_context != old_context:
            state["context"] = new_context
            state["context_versions"][reported_tool] = old_version + 1
            # New context: hold back injection until one answer has been sent.
            state["context_first_response_sent"] = False
        state["last_tool"] = reported_tool

    # NOTE(review): every other payload key is copied into state verbatim, so
    # a tool can overwrite any state key -- confirm tools are trusted.
    for key, value in data.items():
        if key not in ("context", "last_tool"):
            state[key] = value


def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
    """Run one assistant step: call the LLM and update the conversation state.

    The step rebuilds the prompt (system message, stored summary, history,
    patient/site identifiers), merges the latest tool result into ``state``,
    optionally injects retrieved context, and invokes ``llm_with_tools``.
    If the reply requests tool calls the state is returned immediately so the
    tools can run; otherwise a final answer is chosen, stored, and the
    history is summarized when it has grown past the limit.

    Args:
        state: Mutable mapping holding the conversation state; updated in
            place and returned.
        sys_msg: The base system message placed first in every prompt.
        llm: Model used without tools (RAG answer and summarization).
        llm_with_tools: Model used for the main call; its reply may carry
            ``tool_calls``.

    Returns:
        The same ``state`` object, updated.
    """
    _ensure_state_defaults(state)
    messages = _build_prompt_messages(state, sys_msg)

    latest_question = next(
        (m.content for m in reversed(messages) if isinstance(m, HumanMessage)), ""
    )
    user_message_changed = latest_question != state.get("last_user_message")

    if user_message_changed:
        # New question: drop earlier tool round trips and stale results.
        messages = remove_tool_call_messages(messages)
        state["answer"] = ""
        state["rag_result"] = ""

    _apply_latest_tool_result(state, messages)

    tool_name = _IDSR_TOOL
    current_version = state["context_versions"].get(tool_name, 0)
    last_injected_version = state["last_context_injected_versions"].get(tool_name, 0)

    # On turns where the user message is unchanged, advance the
    # ready-for-injection marker to the current context version.
    if (
        not user_message_changed
        and state["context_version_ready_for_injection"] < current_version
    ):
        state["context_version_ready_for_injection"] = current_version

    # Inject the retrieved context only if the last tool matches, context
    # exists, there is a version not yet injected, and the first answer after
    # the new context has already been sent.
    if (
        state.get("last_tool") == tool_name
        and state.get("context")
        and state["context_version_ready_for_injection"] > last_injected_version
        and state.get("context_first_response_sent", True)
    ):
        messages.append(
            SystemMessage(
                content=(
                    f"The following information was retrieved from the {tool_name.upper()} database and may help answer the user's question:\n\n"
                    f"{state['context']}\n\n"
                    "Use this information when responding."
                )
            )
        )
        state["last_context_injected_versions"][tool_name] = state[
            "context_version_ready_for_injection"
        ]
        state["last_tool"] = None

    # The reply carries tool_calls when the model wants a tool to run.
    new_message = llm_with_tools.invoke(messages)
    messages.append(new_message)

    if getattr(new_message, "tool_calls", None):
        # A tool call is pending: return now so the tool step can run.
        state["messages"] = messages
        state["last_user_message"] = latest_question
        return state

    # No tool call pending: pick the final answer text.
    if state.get("answer"):
        final_content = state["answer"]
    elif state.get("rag_result"):
        rag_msg = SystemMessage(
            content=(
                "Based on the following clinical guideline excerpts, answer the clinician's question as precisely as possible.\n\n"
                "Focus only on information that directly addresses the question.\n"
                "Do not include background or general recommendations unless they are explicitly relevant.\n\n"
                "Guideline excerpts:\n"
                f"{state['rag_result']}\n\n"
                "Respond with a focused summary tailored to the question about advanced HIV disease."
            )
        )
        final_content = llm.invoke(messages + [rag_msg]).content
    else:
        final_content = new_message.content

    # Prepend the disclaimer the first time an idsr_check result is answered.
    if state.get("last_tool") == _IDSR_TOOL and not state.get(
        "idsr_disclaimer_shown", False
    ):
        disclaimer = (
            "Disclaimer: This is not a diagnosis. This is meant to help "
            "identify possible matches based on priority IDSR diseases for clinician awareness.\n\n"
        )
        final_content = disclaimer + final_content
        state["idsr_disclaimer_shown"] = True

    # An answer has now been produced, so later turns may inject the context.
    if (
        state.get("last_tool") == tool_name
        or state.get("context_first_response_sent") is False
    ):
        state["context_first_response_sent"] = True

    # Replace the last AIMessage with the final content to avoid duplicates.
    for index in range(len(messages) - 1, -1, -1):
        if isinstance(messages[index], AIMessage):
            messages[index] = AIMessage(content=final_content)
            break
    else:
        # Fallback: no AIMessage in the list at all.
        messages.append(AIMessage(content=final_content))

    # Compress long histories into a summary plus the most recent turns.
    human_ai_messages = [
        m for m in messages if isinstance(m, (HumanMessage, AIMessage))
    ]
    if len(human_ai_messages) > _MAX_HISTORY:
        summary_text = summarize_conversation(
            messages, llm, previous_summary=state.get("summary")
        )
        # Persist the summary in state: stored system messages are stripped at
        # the start of every call, so the message alone would be lost.
        state["summary"] = summary_text
        summary_msg = SystemMessage(content=_SUMMARY_PREFIX + summary_text)
        messages = [sys_msg, summary_msg] + human_ai_messages[-_KEEP_RECENT:]

    state["answer"] = final_content
    state["messages"] = messages
    state["last_user_message"] = latest_question
    state["question"] = latest_question
    return state