File size: 6,496 Bytes
b8e5043
a7c4301
b8e5043
a7c4301
b8e5043
a7c4301
6d49dc7
b8e5043
a7c4301
 
 
b8e5043
a7c4301
 
 
6d49dc7
b8e5043
a7c4301
 
 
b8e5043
 
a7c4301
 
b8e5043
a7c4301
b8e5043
 
a7c4301
 
b8e5043
a7c4301
 
b8e5043
a7c4301
b8e5043
a7c4301
b8e5043
 
a7c4301
b8e5043
 
 
 
a7c4301
b8e5043
a7c4301
b8e5043
 
a7c4301
b8e5043
 
 
 
 
a7c4301
b8e5043
 
a7c4301
b8e5043
 
 
 
a7c4301
b8e5043
 
a7c4301
b8e5043
 
a7c4301
b8e5043
 
 
 
 
 
 
 
 
a7c4301
b8e5043
a7c4301
b8e5043
 
 
 
 
a7c4301
b8e5043
 
 
a7c4301
b8e5043
a7c4301
b8e5043
 
a7c4301
 
b8e5043
a7c4301
 
b8e5043
a7c4301
b8e5043
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7c4301
 
6d49dc7
a7c4301
6d49dc7
b8e5043
 
 
 
 
 
 
 
 
 
 
 
 
 
6d49dc7
b8e5043
6d49dc7
 
a7c4301
 
6d49dc7
 
 
 
 
 
 
 
a7c4301
b8e5043
a7c4301
b8e5043
 
 
 
 
 
a7c4301
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
"""Context manager - AI summary compaction for long conversations."""

from __future__ import annotations

from loguru import logger

from agent.llm_client import BaseLLMClient
from utils.token_counter import estimate_messages_tokens


class ContextManager:
    """Manage context window size using AI-powered summarization.

    When the estimated token count of the message history crosses
    ``context_window * compact_threshold``, older messages are replaced by
    an LLM-generated summary while the most recent messages are kept
    verbatim.
    """

    def __init__(
        self,
        client: BaseLLMClient,
        model: str,
        context_window: int = 200000,
        compact_threshold: float = 0.9,
    ):
        """Initialize the context manager.

        Args:
            client: LLM client used to generate summaries.
            model: Model identifier passed to the client.
            context_window: Total context window size, in tokens.
            compact_threshold: Fraction of the window that triggers
                compaction (0 < threshold <= 1).
        """
        self.client = client
        self.model = model
        self.context_window = context_window
        self.compact_threshold = compact_threshold
        # Token budget: compaction kicks in once the estimate reaches this.
        self.max_tokens = int(context_window * compact_threshold)

    async def maybe_compact(self, messages: list[dict]) -> list[dict]:
        """Check if messages exceed the threshold and compact if needed.

        Args:
            messages: Current message history.

        Returns:
            Possibly compacted message history.
        """
        tokens = estimate_messages_tokens(messages)

        if tokens < self.max_tokens:
            return messages

        logger.info(
            f"Context compaction triggered: {tokens} tokens "
            f"(threshold: {self.max_tokens})"
        )

        return await self._compact(messages)

    @staticmethod
    def _find_safe_keep_count(messages: list[dict], keep_count: int) -> int:
        """Adjust *keep_count* so the kept tail starts at a safe boundary.

        Tool-calling APIs require a ``tool_result`` block to be preceded by
        the message carrying its matching ``tool_use`` block.  If the split
        would make the kept tail begin with a ``tool_result`` message whose
        ``tool_use`` was summarized away, grow the tail until it starts on
        a non-``tool_result`` message (or covers everything).

        Args:
            messages: Full message history being split.
            keep_count: Proposed number of trailing messages to keep.

        Returns:
            A keep count >= the proposed one that starts on a safe message.
        """

        def _starts_with_tool_result(msg: dict) -> bool:
            content = msg.get("content")
            if isinstance(content, list):
                return any(
                    isinstance(block, dict)
                    and block.get("type") == "tool_result"
                    for block in content
                )
            return False

        while keep_count < len(messages) and _starts_with_tool_result(
            messages[-keep_count]
        ):
            keep_count += 1
        return keep_count

    async def _compact(self, messages: list[dict]) -> list[dict]:
        """Compact messages by summarizing older ones.

        Strategy:
        1. Split messages into an old chunk and a recent tail
        2. Summarize the old chunk using the LLM
        3. Keep the recent tail intact
        4. Prepend the summary as the first message
        """
        if len(messages) <= 2:
            return messages

        # Keep at least the last 30% of messages (minimum 2) intact.
        keep_count = max(2, len(messages) * 3 // 10)
        # Never split between a tool_use and its tool_result: a kept tail
        # that begins with an orphaned tool_result is an invalid history
        # for tool-calling APIs.
        keep_count = self._find_safe_keep_count(messages, keep_count)

        to_summarize = messages[:-keep_count]
        to_keep = messages[-keep_count:]

        # keep_count may have grown to cover the entire history.
        if not to_summarize:
            return messages

        # Build the summary of the old chunk.
        summary = await self._summarize_messages(to_summarize)

        # Assemble new message list with the summary prepended.
        summary_message = {
            "role": "user",
            "content": (
                f"[Context Summary - Previous conversation compressed]\n\n"
                f"{summary}\n\n"
                f"[End of Summary - Recent conversation follows]"
            ),
        }

        compacted = [summary_message] + to_keep

        new_tokens = estimate_messages_tokens(compacted)
        logger.info(
            f"Compacted: {len(messages)} -> {len(compacted)} messages, "
            f"{estimate_messages_tokens(messages)} -> {new_tokens} tokens"
        )

        # If still over budget, compact again recursively.  The tail always
        # shrinks (len(compacted) > 3 guard), so recursion terminates.
        if new_tokens > self.max_tokens and len(compacted) > 3:
            return await self._compact(compacted)

        return compacted

    async def _summarize_messages(self, messages: list[dict]) -> str:
        """Use the LLM to summarize a chunk of messages.

        Args:
            messages: Messages to summarize.

        Returns:
            Summary text.  On LLM failure, a best-effort placeholder noting
            how many messages were dropped.
        """
        # Format messages into a readable transcript for the summarizer.
        formatted_parts = []
        for msg in messages:
            role = msg.get("role", "unknown")
            content = msg.get("content", "")
            if isinstance(content, list):
                # Flatten Anthropic-style content blocks into plain text.
                text_parts = []
                for block in content:
                    if isinstance(block, dict):
                        if block.get("type") == "text":
                            text_parts.append(block.get("text", ""))
                        elif block.get("type") == "tool_use":
                            text_parts.append(
                                f"[Tool call: {block.get('name', '?')}({block.get('input', {})})]"
                            )
                        elif block.get("type") == "tool_result":
                            # Cap tool output so one huge result can't
                            # dominate the summarization prompt.
                            result_content = str(block.get("content", ""))
                            if len(result_content) > 500:
                                result_content = result_content[:500] + "..."
                            text_parts.append(f"[Tool result: {result_content}]")
                    else:
                        text_parts.append(str(block))
                content = "\n".join(text_parts)
            formatted_parts.append(f"**{role}**: {content}")

        conversation_text = "\n\n".join(formatted_parts)

        # Truncate if extremely long so the summarization request itself
        # stays within the model's context window.
        if len(conversation_text) > 100000:
            conversation_text = conversation_text[:100000] + "\n\n... (truncated)"

        try:
            response = await self.client.create_message(
                model=self.model,
                system_prompt=(
                    "You are a conversation summarizer. Create a concise but comprehensive "
                    "summary of the following conversation. Preserve:\n"
                    "- Key decisions made\n"
                    "- Important facts and information learned\n"
                    "- Open questions or pending tasks\n"
                    "- User preferences and constraints\n"
                    "- Tool actions taken and their results\n"
                    "- Any errors encountered and how they were resolved\n\n"
                    "Be thorough but concise. Use bullet points."
                ),
                messages=[
                    {
                        "role": "user",
                        "content": f"Summarize this conversation:\n\n{conversation_text}",
                    },
                ],
                tools=[],
                max_tokens=4000,
            )

            # Collect only text blocks from the response.
            summary_parts: list[str] = []
            for block in response.content_blocks:
                if isinstance(block, dict) and block.get("type") == "text":
                    text = block.get("text")
                    if isinstance(text, str):
                        summary_parts.append(text)

            summary = "".join(summary_parts)

            return summary or "No summary generated."
        except Exception as e:
            logger.error(f"Summarization failed: {e}")
            # Fallback: create a basic summary.  Sort the roles so the
            # placeholder text is deterministic across runs.
            return (
                f"[Summarization failed - {len(messages)} messages dropped]\n"
                f"Messages covered roles: "
                f"{', '.join(sorted(set(m.get('role', '?') for m in messages)))}"
            )