File size: 1,747 Bytes
8986591
 
 
2633945
 
 
8986591
 
 
 
 
2633945
8986591
 
 
 
 
07e18ac
 
8986591
 
 
 
 
 
 
 
 
 
 
2633945
8986591
 
 
07e18ac
 
2633945
 
 
 
 
07e18ac
2633945
07e18ac
8986591
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""
app/nodes/memory.py — CHECKPOINT 5: Memory

Only summarizes safe, on-topic conversation turns.
Harmful turns are never in messages (scrubbed by output_node).
Memory summary is topic-neutral — no harmful context bleeds through.
"""
from langchain_core.messages import HumanMessage, AIMessage
from app.state import AgentState
from app.utils.llm import llm

# Minimum number of "clean" messages (user turns plus assistant replies that
# carry no tool calls) required before memory_node produces a summary. It is
# also the size of the trailing window of clean messages that gets summarised.
SUMMARY_THRESHOLD = 6  # min messages before summarizing


def _content_text(content: object) -> str:
    """Return the plain-text portion of a LangChain message ``content``.

    ``content`` may be a plain string or a list of content blocks, where each
    block is either a string or a dict carrying a ``"text"`` entry. Blocks
    without text (tool-use payloads, images, ...) are skipped so they are never
    rendered as Python reprs inside a prompt or a stored summary.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                parts.append(block["text"])
        return "".join(parts)
    return "" if content is None else str(content)


def memory_node(state: AgentState) -> AgentState:
    """Refresh ``memory_summary`` from the most recent clean conversation turns.

    The node always appends ``"memory"`` to ``node_log``. It leaves
    ``memory_summary`` untouched when:
      * the turn is flagged ``is_harmful``;
      * fewer than ``SUMMARY_THRESHOLD`` clean messages exist; or
      * the summarisation LLM call raises (best-effort: the error is printed).

    "Clean" messages are user messages plus assistant replies that carry no
    tool calls. Only the last ``SUMMARY_THRESHOLD`` of them are sent to the
    LLM, each truncated to 400 characters of text.

    Args:
        state: The current agent state.

    Returns:
        A new state dict (the input is not mutated) with an updated
        ``node_log`` and, on success, a new ``memory_summary`` string.
    """
    # `or []` also guards against an explicit None stored under the key.
    log = (state.get("node_log") or []) + ["memory"]

    if state.get("is_harmful"):
        return {**state, "node_log": log}

    # Keep user turns and final assistant answers; assistant messages that only
    # request tool calls (and any other message types) add nothing to a summary.
    clean = [
        m for m in (state.get("messages") or [])
        if isinstance(m, HumanMessage)
        or (isinstance(m, AIMessage) and not getattr(m, "tool_calls", []))
    ]

    if len(clean) < SUMMARY_THRESHOLD:
        return {**state, "node_log": log}

    # Message content can be a list of blocks, so extract text before truncating.
    recent_text = "\n".join(
        f"{'User' if isinstance(m, HumanMessage) else 'Assistant'}: "
        f"{_content_text(m.content)[:400]}"
        for m in clean[-SUMMARY_THRESHOLD:]
    )

    try:
        response = llm.invoke([HumanMessage(content=(
            "Summarise this conversation in 2-3 sentences.\n"
            "Include ONLY factual topics discussed (concepts, tools, questions answered).\n"
            "Do NOT include any violent, harmful, or sensitive content in the summary.\n"
            "If the conversation contains harmful topics, summarise only the safe parts.\n\n"
            + recent_text
        ))])
        # Normalise so memory_summary is always a plain string.
        summary = _content_text(response.content)
        print("[MEMORY] Summary updated.")
        print(f"[MEMORY] Summary : {summary}")
        return {**state, "memory_summary": summary, "node_log": log}
    except Exception as e:
        # Deliberate best-effort: a failed summary must not break the graph run.
        print(f"[MEMORY] Summarisation failed: {e}")
        return {**state, "node_log": log}