from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
from chatlib.state_types import AppState
import json
def remove_tool_call_messages(messages):
    """Return a copy of *messages* without tool-call traffic.

    Drops every AIMessage that carries tool_calls, and every ToolMessage whose
    tool_call_id matches one of those dropped calls. The input list is not
    mutated.
    """
    dropped_call_ids = set()
    kept = []
    for message in messages:
        if isinstance(message, AIMessage) and getattr(message, "tool_calls", None):
            # Remember the call ids so the paired ToolMessages are dropped too.
            dropped_call_ids.update(call["id"] for call in message.tool_calls)
            continue
        if isinstance(message, ToolMessage) and message.tool_call_id in dropped_call_ids:
            continue
        kept.append(message)
    return kept
def summarize_conversation(messages, llm):
    """Condense the Human/AI turns of *messages* into a clinical summary.

    System messages (and any other message types) are ignored. Returns the
    summary text produced by *llm*.
    """
    def _speaker(message):
        return "User" if isinstance(message, HumanMessage) else "Assistant"

    transcript = "\n\n".join(
        f"{_speaker(message)}: {message.content}"
        for message in messages
        if isinstance(message, (HumanMessage, AIMessage))
    )
    prompt = (
        "Summarize the clinical conversation below in a way that retains all key clinical facts and decisions.\n\n"
        f"{transcript}\n\nSummary:"
    )
    return llm.invoke([SystemMessage(content=prompt)]).content
def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
    """Graph assistant node: rebuild the message list, invoke the LLM, update state.

    Args:
        state: conversation state mapping (AppState); mutated in place and returned.
        sys_msg: base SystemMessage always placed at the head of the message list.
        llm: plain chat model used for RAG-grounded answers and summarization.
        llm_with_tools: chat model with tools bound; its reply may carry tool_calls.

    Returns:
        The updated state. Returns early (without a final answer) when the model
        requests a tool call, so the graph's tool node can execute it; otherwise
        produces the final answer, optionally prefixes the IDSR disclaimer, and
        summarizes long histories down to a summary message plus recent turns.
    """
    # Initialize missing keys with defaults
    state.setdefault("question", "")
    state.setdefault("pk_hash", "")
    state.setdefault("sitecode", "")
    state.setdefault("rag_result", "")
    state.setdefault("rag_sources", "")
    state.setdefault("answer", "")
    state.setdefault("last_answer", None)
    state.setdefault("last_user_message", None)
    state.setdefault("last_tool", None)
    state.setdefault("idsr_disclaimer_shown", False)
    state.setdefault("summary", None)
    state.setdefault("context", None)
    state.setdefault("context_versions", {})
    state.setdefault("last_context_injected_versions", {})
    state.setdefault("context_version_ready_for_injection", 0)
    state.setdefault("context_first_response_sent", True)
    # Put sys_msg first and drop every other SystemMessage from the history.
    messages = state.get("messages", [])
    base_messages = [sys_msg]
    messages = base_messages + [m for m in messages if not isinstance(m, SystemMessage)]
    # Filter out existing pk_hash and sitecode system messages and add new ones
    # NOTE(review): this filter looks redundant — the line above already removed
    # every SystemMessage except sys_msg; confirm before simplifying.
    messages = [
        m
        for m in messages
        if not (
            isinstance(m, SystemMessage)
            and (
                m.content.startswith("Patient identifier (pk_hash):")
                or m.content.startswith("Site code:")
            )
        )
    ]
    # Inject pk_hash and sitecode as system messages if they exist and are non-empty
    pk_hash_value = state.get("pk_hash")
    if pk_hash_value:
        pk_hash_msg = SystemMessage(
            content=f"Patient identifier (pk_hash): {pk_hash_value}"
        )
        messages.append(pk_hash_msg)
    sitecode_value = state.get("sitecode")
    if sitecode_value:
        sitecode_msg = SystemMessage(content=f"Site code: {sitecode_value}")
        messages.append(sitecode_msg)
    # Most recent human turn; used to detect whether this is a new user message.
    latest_question = next(
        (m.content for m in reversed(messages) if isinstance(m, HumanMessage)), ""
    )
    user_message_changed = latest_question != state.get("last_user_message")
    if user_message_changed:
        # Clean old tool calls before invoking new ones
        messages = remove_tool_call_messages(messages)
        state["answer"] = ""
        state["rag_result"] = ""
    # Process latest ToolMessage and update context_version
    # NOTE(review): assumes tool output is a JSON object; a JSON array would
    # raise AttributeError on .get (only JSONDecodeError is caught) — confirm
    # the tools always emit objects.
    for msg in reversed(messages):
        if isinstance(msg, ToolMessage):
            try:
                content = msg.content
                data = json.loads(content) if isinstance(content, str) else content
                tool_name = data.get("last_tool")
                new_context = data.get("context")
                if tool_name:
                    old_context = state.get("context", "")
                    old_version = state["context_versions"].get(tool_name, 0)
                    # Bump the per-tool version only when the context text changed.
                    if new_context is not None and new_context != old_context:
                        state["context"] = new_context
                        state["context_versions"][tool_name] = old_version + 1
                        state["context_first_response_sent"] = (
                            False  # Reset flag on new context
                        )
                    state["last_tool"] = tool_name
                    # Copy any extra payload keys straight into state.
                    for k, v in data.items():
                        if k not in ("context", "last_tool"):
                            state[k] = v
                # Only the most recent ToolMessage is consumed.
                break
            except json.JSONDecodeError:
                break
    tool_name = "idsr_check"
    current_version = state["context_versions"].get(tool_name, 0)
    last_injected_version = state["last_context_injected_versions"].get(tool_name, 0)
    # On turns where user message is unchanged, advance ready_for_injection to current_version
    if (
        not user_message_changed
        and state["context_version_ready_for_injection"] < current_version
    ):
        state["context_version_ready_for_injection"] = current_version
    # Inject context system message only if:
    # - last_tool matches tool_name
    # - context exists
    # - ready_for_injection > last injected version
    # - AND first AI response after new context has been sent
    if (
        state.get("last_tool") == tool_name
        and state.get("context")
        and state["context_version_ready_for_injection"] > last_injected_version
        and state.get("context_first_response_sent", True)
    ):
        context_msg = SystemMessage(
            content=(
                f"The following information was retrieved from the {tool_name.upper()} database and may help answer the user's question:\n\n"
                f"{state['context']}\n\n"
                "Use this information when responding."
            )
        )
        messages.append(context_msg)
        state["last_context_injected_versions"][tool_name] = state[
            "context_version_ready_for_injection"
        ]
        # NOTE(review): clearing last_tool here means the idsr_check disclaimer
        # check further down can never fire on the same turn as an injection —
        # confirm this ordering is intentional.
        state["last_tool"] = None
    # Invoke LLM with tools (this returns AIMessage with tool_calls if tool call is needed)
    new_message = llm_with_tools.invoke(messages)
    messages.append(new_message)
    # If the new_message has tool_calls, it means a tool call is pending; return now so tool node runs
    if getattr(new_message, "tool_calls", None):
        state["messages"] = messages
        state["last_user_message"] = latest_question
        return state
    # No more tool calls: generate final answer from state or AIMessage content
    if state.get("answer"):
        final_content = state["answer"]
    elif state.get("rag_result"):
        # Use conversation history + a system message to inject RAG guidance
        rag_msg = SystemMessage(
            content = (
                "Based on the following clinical guideline excerpts, answer the clinician's question as precisely as possible.\n\n"
                "Focus only on information that directly addresses the question.\n"
                "Do not include background or general recommendations unless they are explicitly relevant.\n\n"
                "Guideline excerpts:\n"
                f"{state['rag_result']}\n\n"
                "Respond with a focused summary tailored to the question about advanced HIV disease."
            )
        )
        messages_with_rag = messages + [rag_msg]
        llm_response = llm.invoke(messages_with_rag)
        final_content = llm_response.content
    else:
        final_content = new_message.content
    # Add disclaimer if needed
    if state.get("last_tool") == "idsr_check" and not state.get(
        "idsr_disclaimer_shown", False
    ):
        disclaimer = (
            "Disclaimer: This is not a diagnosis. This is meant to help "
            "identify possible matches based on priority IDSR diseases for clinician awareness.\n\n"
        )
        final_content = disclaimer + final_content
        state["idsr_disclaimer_shown"] = True
    # After generating AI message, mark first response sent
    if (
        state.get("last_tool") == tool_name
        or state.get("context_first_response_sent") is False
    ):
        state["context_first_response_sent"] = True
    # Replace the last AIMessage content with final_content to avoid duplicates
    for i in reversed(range(len(messages))):
        if isinstance(messages[i], AIMessage):
            messages[i] = AIMessage(content=final_content)
            break
    else:
        # fallback: append if no AIMessage found (rare)
        messages.append(AIMessage(content=final_content))
    # Summarization logic
    non_sys_messages = [m for m in messages if not isinstance(m, SystemMessage)]
    human_ai_messages = [
        m for m in non_sys_messages if isinstance(m, (HumanMessage, AIMessage))
    ]
    if len(human_ai_messages) > 10:
        summary_text = summarize_conversation(messages, llm)
        summary_msg = SystemMessage(
            content="Summary of earlier conversation:\n" + summary_text
        )
        # Keep sys_msg, the new summary message, and the last 5 Human/AI messages
        recent_msgs = [
            m for m in reversed(messages) if isinstance(m, (HumanMessage, AIMessage))
        ][:5]
        recent_msgs.reverse()
        messages = [sys_msg, summary_msg] + recent_msgs
    state["answer"] = final_content
    state["messages"] = messages
    state["last_user_message"] = latest_question
    state["question"] = latest_question
    return state
|