File size: 9,627 Bytes
fac69f9
6781788
afe6838
 
e23fefd
fac69f9
 
 
 
 
 
 
 
 
 
 
 
 
e23fefd
fac69f9
 
 
e23fefd
 
 
 
fac69f9
 
 
 
 
 
 
6781788
 
24e3e87
 
 
233d8ee
 
24e3e87
e670011
24e3e87
 
 
 
 
 
 
 
 
 
 
 
fac69f9
 
 
 
233d8ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e23fefd
 
 
fac69f9
afe6838
fac69f9
 
 
 
 
afe6838
24e3e87
fac69f9
 
afe6838
fac69f9
 
24e3e87
 
 
 
 
 
 
 
 
 
 
233d8ee
 
 
24e3e87
 
 
 
 
 
 
 
d05878e
fac69f9
 
24e3e87
 
 
 
 
233d8ee
 
 
 
24e3e87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233d8ee
 
 
24e3e87
 
fac69f9
 
 
e23fefd
fac69f9
 
 
 
 
 
 
 
 
 
 
 
 
e670011
 
 
 
 
 
 
 
afe6838
fac69f9
 
 
 
6781788
fac69f9
6781788
fac69f9
e23fefd
 
 
fac69f9
 
 
 
 
 
 
24e3e87
233d8ee
 
 
 
24e3e87
 
fac69f9
 
 
 
6781788
fac69f9
 
 
6781788
e23fefd
fac69f9
e23fefd
 
 
fac69f9
24e3e87
fac69f9
e23fefd
 
 
fac69f9
 
e23fefd
 
 
fac69f9
 
 
 
6781788
fac69f9
6781788
fac69f9
afe6838
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
from chatlib.state_types import AppState
import json


def remove_tool_call_messages(messages):
    """Strip tool-call exchanges from a message history.

    Every AIMessage that carries ``tool_calls`` is dropped, together with
    each ToolMessage whose ``tool_call_id`` matches one of those calls.
    All remaining messages are returned in their original order.
    """
    dropped_call_ids = set()
    kept = []
    for message in messages:
        # Tool-calling AI messages are removed; remember their call ids so
        # the paired ToolMessage results can be removed as well.
        if isinstance(message, AIMessage) and getattr(message, "tool_calls", None):
            dropped_call_ids.update(call["id"] for call in message.tool_calls)
            continue
        if isinstance(message, ToolMessage) and message.tool_call_id in dropped_call_ids:
            continue
        kept.append(message)
    return kept


def summarize_conversation(messages, llm):
    """Condense the Human/Assistant exchange into a clinical summary.

    System and tool messages are ignored; the remaining turns are rendered
    as "User:"/"Assistant:" lines and sent to *llm* with a summarization
    instruction. Returns the summary text produced by the model.
    """
    turns = []
    for msg in messages:
        if isinstance(msg, HumanMessage):
            turns.append(f"User: {msg.content}")
        elif isinstance(msg, AIMessage):
            turns.append(f"Assistant: {msg.content}")
    transcript = "\n\n".join(turns)
    instruction = (
        "Summarize the clinical conversation below in a way that retains all key clinical facts and decisions.\n\n"
        f"{transcript}\n\nSummary:"
    )
    return llm.invoke([SystemMessage(content=instruction)]).content


def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
    """Agent node for one conversational turn.

    Rebuilds the message list (base system prompt, patient/site identifier
    system messages, versioned tool-derived context), invokes the tool-bound
    LLM, and either returns early so a pending tool call can execute, or
    finalizes the answer (RAG synthesis, one-time IDSR disclaimer, history
    summarization) and writes it back into ``state``.

    Args:
        state: Mutable conversation state dict (messages, context versions,
            flags). Mutated in place and also returned.
        sys_msg: Base SystemMessage that always heads the rebuilt message list.
        llm: Plain LLM used for RAG answering and conversation summarization.
        llm_with_tools: Tool-bound LLM used for the main turn (may emit
            ``tool_calls``).

    Returns:
        The updated ``state``.
    """

    # Initialize missing keys with defaults
    state.setdefault("question", "")
    state.setdefault("pk_hash", "")
    state.setdefault("sitecode", "")
    state.setdefault("rag_result", "")
    state.setdefault("rag_sources", "")
    state.setdefault("answer", "")
    state.setdefault("last_answer", None)
    state.setdefault("last_user_message", None)
    state.setdefault("last_tool", None)
    state.setdefault("idsr_disclaimer_shown", False)
    state.setdefault("summary", None)
    state.setdefault("context", None)
    # Per-tool monotonically increasing context versions, used to decide
    # whether a context system message still needs to be (re-)injected.
    state.setdefault("context_versions", {})
    state.setdefault("last_context_injected_versions", {})
    state.setdefault("context_version_ready_for_injection", 0)
    state.setdefault("context_first_response_sent", True)

    # Rebuild the list with sys_msg first and all prior SystemMessages
    # removed, so system-level content is re-derived fresh each turn.
    messages = state.get("messages", [])
    base_messages = [sys_msg]
    messages = base_messages + [m for m in messages if not isinstance(m, SystemMessage)]

    # Filter out existing pk_hash and sitecode system messages and add new ones
    # (redundant after the blanket SystemMessage strip above, but harmless;
    # guards against identifier messages arriving through another path).
    messages = [
        m
        for m in messages
        if not (
            isinstance(m, SystemMessage)
            and (
                m.content.startswith("Patient identifier (pk_hash):")
                or m.content.startswith("Site code:")
            )
        )
    ]

    # Inject pk_hash and sitecode as system messages if they exist and are non-empty
    pk_hash_value = state.get("pk_hash")
    if pk_hash_value:
        pk_hash_msg = SystemMessage(
            content=f"Patient identifier (pk_hash): {pk_hash_value}"
        )
        messages.append(pk_hash_msg)

    sitecode_value = state.get("sitecode")
    if sitecode_value:
        sitecode_msg = SystemMessage(content=f"Site code: {sitecode_value}")
        messages.append(sitecode_msg)

    # Most recent human utterance; "" when there is none yet.
    latest_question = next(
        (m.content for m in reversed(messages) if isinstance(m, HumanMessage)), ""
    )
    user_message_changed = latest_question != state.get("last_user_message")

    if user_message_changed:
        # Clean old tool calls before invoking new ones, and reset the
        # per-question answer/RAG caches so stale results aren't reused.
        messages = remove_tool_call_messages(messages)
        state["answer"] = ""
        state["rag_result"] = ""

    # Process latest ToolMessage and update context_version.
    # Only the newest ToolMessage is consumed (break on first hit, parsed
    # or not); its JSON payload is merged into state, with "context" and
    # "last_tool" handled specially for version tracking.
    for msg in reversed(messages):
        if isinstance(msg, ToolMessage):
            try:
                content = msg.content
                data = json.loads(content) if isinstance(content, str) else content

                tool_name = data.get("last_tool")
                new_context = data.get("context")

                if tool_name:
                    old_context = state.get("context", "")
                    old_version = state["context_versions"].get(tool_name, 0)

                    # Bump the tool's context version only when the payload
                    # actually carries a different context string.
                    if new_context is not None and new_context != old_context:
                        state["context"] = new_context
                        state["context_versions"][tool_name] = old_version + 1
                        state["context_first_response_sent"] = (
                            False  # Reset flag on new context
                        )

                    state["last_tool"] = tool_name

                # Copy every other payload field straight into state
                # (e.g. "answer", "rag_result" set by tools).
                for k, v in data.items():
                    if k not in ("context", "last_tool"):
                        state[k] = v

                break
            except json.JSONDecodeError:
                # Unparseable tool output: ignore it and stop scanning.
                break

    # NOTE(review): the injection logic below is hard-wired to idsr_check;
    # other tools' contexts are versioned above but never injected here.
    tool_name = "idsr_check"
    current_version = state["context_versions"].get(tool_name, 0)
    last_injected_version = state["last_context_injected_versions"].get(tool_name, 0)

    # On turns where user message is unchanged, advance ready_for_injection to current_version
    # (i.e. context produced during this question's tool round-trips becomes
    # eligible for injection on the follow-up pass).
    if (
        not user_message_changed
        and state["context_version_ready_for_injection"] < current_version
    ):
        state["context_version_ready_for_injection"] = current_version

    # Inject context system message only if:
    # - last_tool matches tool_name
    # - context exists
    # - ready_for_injection > last injected version
    # - AND first AI response after new context has been sent
    if (
        state.get("last_tool") == tool_name
        and state.get("context")
        and state["context_version_ready_for_injection"] > last_injected_version
        and state.get("context_first_response_sent", True)
    ):
        context_msg = SystemMessage(
            content=(
                f"The following information was retrieved from the {tool_name.upper()} database and may help answer the user's question:\n\n"
                f"{state['context']}\n\n"
                "Use this information when responding."
            )
        )
        messages.append(context_msg)

        # Record the injected version and clear last_tool so the same
        # context is not injected again on subsequent turns.
        state["last_context_injected_versions"][tool_name] = state[
            "context_version_ready_for_injection"
        ]
        state["last_tool"] = None

    # Invoke LLM with tools (this returns AIMessage with tool_calls if tool call is needed)
    new_message = llm_with_tools.invoke(messages)
    messages.append(new_message)

    # If the new_message has tool_calls, it means a tool call is pending; return now so tool node runs
    if getattr(new_message, "tool_calls", None):
        state["messages"] = messages
        state["last_user_message"] = latest_question
        return state

    # No more tool calls: generate final answer from state or AIMessage content.
    # Priority: tool-provided answer > RAG synthesis > raw LLM content.
    if state.get("answer"):
        final_content = state["answer"]

    elif state.get("rag_result"):
        # Use conversation history + a system message to inject RAG guidance
        rag_msg = SystemMessage(
            content = (
                "Based on the following clinical guideline excerpts, answer the clinician's question as precisely as possible.\n\n"
                "Focus only on information that directly addresses the question.\n"
                "Do not include background or general recommendations unless they are explicitly relevant.\n\n"
                "Guideline excerpts:\n"
                f"{state['rag_result']}\n\n"
                "Respond with a focused summary tailored to the question about advanced HIV disease."
            )
        )
        messages_with_rag = messages + [rag_msg]
        llm_response = llm.invoke(messages_with_rag)
        final_content = llm_response.content

    else:
        final_content = new_message.content

    # Add disclaimer if needed — shown at most once per conversation
    # (idsr_disclaimer_shown latches True and is never reset here).
    if state.get("last_tool") == "idsr_check" and not state.get(
        "idsr_disclaimer_shown", False
    ):
        disclaimer = (
            "Disclaimer: This is not a diagnosis. This is meant to help "
            "identify possible matches based on priority IDSR diseases for clinician awareness.\n\n"
        )
        final_content = disclaimer + final_content
        state["idsr_disclaimer_shown"] = True

    # After generating AI message, mark first response sent
    if (
        state.get("last_tool") == tool_name
        or state.get("context_first_response_sent") is False
    ):
        state["context_first_response_sent"] = True

    # Replace the last AIMessage content with final_content to avoid duplicates
    for i in reversed(range(len(messages))):
        if isinstance(messages[i], AIMessage):
            messages[i] = AIMessage(content=final_content)
            break
    else:
        # fallback: append if no AIMessage found (rare)
        messages.append(AIMessage(content=final_content))

    # Summarization logic: once more than 10 Human/AI turns have accrued,
    # collapse history into a summary system message plus the 5 most recent
    # Human/AI messages (the summary itself covers the whole conversation,
    # so the kept tail overlaps with it by design).
    non_sys_messages = [m for m in messages if not isinstance(m, SystemMessage)]
    human_ai_messages = [
        m for m in non_sys_messages if isinstance(m, (HumanMessage, AIMessage))
    ]

    if len(human_ai_messages) > 10:
        summary_text = summarize_conversation(messages, llm)
        summary_msg = SystemMessage(
            content="Summary of earlier conversation:\n" + summary_text
        )

        # Keep sys_msg, the new summary message, and the last 5 Human/AI messages
        recent_msgs = [
            m for m in reversed(messages) if isinstance(m, (HumanMessage, AIMessage))
        ][:5]
        recent_msgs.reverse()
        messages = [sys_msg, summary_msg] + recent_msgs

    state["answer"] = final_content
    state["messages"] = messages
    state["last_user_message"] = latest_question
    state["question"] = latest_question

    return state