File size: 13,423 Bytes
f209a8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Optional, Sequence

from agent_base.model_profiles import ModelProfile
from agent_base.utils import safe_jsonable


# Header of the synthetic user message that carries compacted memory.
# _split_existing_memory_messages detects earlier summaries by this exact
# prefix, so editing the text stops previously written summaries from being
# recognised as memory.
COMPACT_MEMORY_PREFIX = "\n".join(
    (
        "Runtime memory summary from earlier turns.",
        "This is compressed context, not ground truth.",
        "The workspace files remain authoritative; re-read any file if exact details matter.",
        "",  # these two empty entries produce the trailing blank line
        "",
    )
)


@dataclass
class CompactionOutcome:
    """Result of one context-compaction attempt.

    ``compact_messages`` in this module builds these with ``status`` set to
    ``"ok"`` or ``"error"``; on error the message lists are the unchanged
    input copy and ``error`` explains why.
    """

    # "ok" or "error" in the outcomes built by compact_messages.
    status: str
    # Messages to continue the conversation with (unchanged input on error).
    compacted_messages: list[dict[str, Any]]
    # Memory summary text returned by the model; empty unless compaction succeeded.
    summary_text: str = ""
    # Failure reason; empty on success.
    error: str = ""
    # Not set by compact_messages; presumably filled in by the caller from
    # should_compact_messages' reason -- TODO confirm.
    trigger_reason: str = ""
    # token_counter estimates of the transcript before / after compaction.
    prior_token_estimate: int = 0
    new_token_estimate: int = 0
    # Number of turn groups summarized vs. kept verbatim.
    compacted_group_count: int = 0
    kept_group_count: int = 0
    # Memory text merged in from earlier compactions, if any.
    existing_memory_text: str = ""
    # Summarization request sent to the model and its (JSON-safe) reply,
    # kept for inspection.
    summary_request: list[dict[str, Any]] | None = None
    summary_response: dict[str, Any] | None = None
    # Transcript snapshots before and after compaction.
    pre_messages: list[dict[str, Any]] | None = None
    post_messages: list[dict[str, Any]] | None = None


def should_compact_messages(
    *,
    last_input_tokens: Optional[int],
    current_token_estimate: int,
    model_profile: ModelProfile,
) -> tuple[bool, str]:
    """Decide whether the conversation has grown large enough to compact.

    Two signals are compared (inclusively) against
    ``model_profile.compact_trigger_tokens``: the input-token count reported
    for the last model call, when one is available, and the local token
    estimate of the current messages.

    Returns:
        ``(should_compact, reason)``. ``reason`` names the signals that fired,
        in the order usage then estimate, joined by ``+`` (``"usage"``,
        ``"estimate"`` or ``"usage+estimate"``), or is ``""`` when neither did.
    """
    threshold = model_profile.compact_trigger_tokens
    fired: list[str] = []
    if last_input_tokens is not None and int(last_input_tokens) >= threshold:
        fired.append("usage")
    if current_token_estimate >= threshold:
        fired.append("estimate")
    return bool(fired), "+".join(fired)


def compact_messages(
    *,
    messages: Sequence[dict[str, Any]],
    original_prompt_text: str,
    model_name: str,
    model_profile: ModelProfile,
    llm_caller: Callable[..., dict[str, Any]],
    token_counter: Callable[[Sequence[dict[str, Any]]], int],
    runtime_deadline: Optional[float] = None,
) -> CompactionOutcome:
    """Replace older conversational turns with a model-written memory summary.

    The first two messages are always kept as-is (presumably the system prompt
    and the original task prompt -- confirm against the caller). Memory
    messages left by earlier compactions directly after them are merged into
    the new summary. The remaining messages are split into turn groups; the
    older groups are rendered to text and summarized by ``llm_caller``, the
    most recent groups are kept verbatim. On success the new transcript is
    ``messages[:2] + [memory message] + recent groups``.

    Args:
        messages: Current transcript; shallow-copied, never mutated.
        original_prompt_text: Task text quoted in the summarization prompt.
        model_name: Accepted for interface compatibility; not used here.
        model_profile: Supplies ``context_window``,
            ``compact_summary_max_tokens`` and the recent-history budget.
        llm_caller: Called as ``llm_caller(request, runtime_deadline=...,
            max_output_tokens=...)``. The reply counts as a success only if it
            is a dict with ``status == "ok"``, no ``tool_calls`` and non-empty
            ``content``.
        token_counter: Token estimator applied before and after compaction.
        runtime_deadline: Forwarded unchanged to ``llm_caller``.

    Returns:
        A ``CompactionOutcome``. On failure ``status`` is ``"error"``,
        ``error`` is a non-empty reason when the model reply is a dict (a
        non-dict reply is reported via ``str(reply)``), and
        ``compacted_messages`` is the unmodified copy of the input.
        Exceptions raised by ``llm_caller`` or ``token_counter`` are not
        caught here.
    """
    safe_messages = [dict(message) for message in messages]

    def _failure(error: str, **details: Any) -> CompactionOutcome:
        # Every failure leaves the conversation untouched: the input copy is
        # reported as both the pre- and post-compaction transcript.
        return CompactionOutcome(
            status="error",
            compacted_messages=safe_messages,
            pre_messages=safe_messages,
            post_messages=safe_messages,
            error=error,
            **details,
        )

    if len(safe_messages) <= 2:
        return _failure(
            "context compaction requires at least one conversational turn beyond the initial prompt"
        )

    prior_token_estimate = token_counter(safe_messages)
    existing_memory_text, eligible_messages = _split_existing_memory_messages(safe_messages[2:])
    base_details: dict[str, Any] = {
        "prior_token_estimate": prior_token_estimate,
        "existing_memory_text": existing_memory_text,
    }

    turn_groups = _turn_groups(eligible_messages)
    if not turn_groups:
        return _failure("context compaction found no eligible conversational turns", **base_details)

    compacted_groups, recent_groups = _split_turn_groups(turn_groups, model_profile)
    if not compacted_groups:
        return _failure("context compaction did not find any older turns to summarize", **base_details)

    summary_request = _build_summary_request(
        original_prompt_text=original_prompt_text,
        history_text=_render_history_text(compacted_groups, model_profile),
        existing_memory_text=existing_memory_text,
        model_profile=model_profile,
    )
    # Details known once the request exists; attached to every later outcome.
    request_details: dict[str, Any] = {
        **base_details,
        "summary_request": summary_request,
        "compacted_group_count": len(compacted_groups),
        "kept_group_count": len(recent_groups),
    }

    summary_reply = llm_caller(
        summary_request,
        runtime_deadline=runtime_deadline,
        max_output_tokens=model_profile.compact_summary_max_tokens,
    )
    if not isinstance(summary_reply, dict):
        error = str(summary_reply)
        return _failure(error, summary_response={"status": "error", "error": error}, **request_details)
    if summary_reply.get("status") != "ok":
        # `or` (not a .get default): a present-but-empty/None "error" field
        # must not produce an error outcome without a message.
        error = str(summary_reply.get("error") or "context compaction summary call failed")
        return _failure(error, summary_response=safe_jsonable(summary_reply), **request_details)
    if summary_reply.get("tool_calls"):
        return _failure(
            "context compaction summary call returned tool calls",
            summary_response=safe_jsonable(summary_reply),
            **request_details,
        )

    summary_text = str(summary_reply.get("content", "") or "").strip()
    if not summary_text:
        return _failure(
            "context compaction summary call returned empty text",
            summary_response=safe_jsonable(summary_reply),
            **request_details,
        )

    summary_message = {"role": "user", "content": COMPACT_MEMORY_PREFIX + summary_text}
    compacted_messages = safe_messages[:2] + [summary_message]
    for group in recent_groups:
        compacted_messages.extend(group)
    return CompactionOutcome(
        status="ok",
        compacted_messages=compacted_messages,
        summary_text=summary_text,
        new_token_estimate=token_counter(compacted_messages),
        summary_response=safe_jsonable(summary_reply),
        pre_messages=safe_messages,
        post_messages=compacted_messages,
        **request_details,
    )


def _build_summary_request(
    *,
    original_prompt_text: str,
    history_text: str,
    existing_memory_text: str,
    model_profile: ModelProfile,
) -> list[dict[str, Any]]:
    """Build the chat request asking the model to rewrite old history as working memory."""
    prior_memory_block = ""
    if existing_memory_text:
        # Bound how much earlier memory is quoted back into the prompt.
        prior_memory_limit = max(1200, model_profile.context_window // 3)
        prior_memory_block = (
            "Previously compressed memory to preserve and refine:\n"
            f"{_truncate_summary_text(existing_memory_text, max_chars=prior_memory_limit)}\n\n"
        )
    return [
        {
            "role": "system",
            "content": (
                "You compress older tool-using agent history into short working memory for continued execution. "
                "Return plain text only. Do not call tools. Do not invent facts."
            ),
        },
        {
            "role": "user",
            "content": (
                "Summarize the earlier conversation history for a tool-using agent.\n\n"
                f"Original task:\n{original_prompt_text}\n\n"
                "Write a concise working memory with these sections:\n"
                "- Goal\n"
                "- Constraints\n"
                "- Files and artifacts\n"
                "- Evidence and results\n"
                "- Open issues\n"
                "- Next useful actions\n\n"
                "Rules:\n"
                "- Prefer concrete file paths, numeric results, and grounded facts.\n"
                "- Mention uncertainty when details may need to be re-read from files.\n"
                "- Merge any prior compressed memory with the newer history below into one refreshed memory.\n"
                "- Deduplicate repeated sections and do not repeat earlier summaries verbatim.\n"
                "- The workspace remains authoritative.\n\n"
                f"{prior_memory_block}"
                f"Older history to compress:\n{history_text}"
            ),
        },
    ]


def _turn_groups(messages: Sequence[dict[str, Any]]) -> list[list[dict[str, Any]]]:
    groups: list[list[dict[str, Any]]] = []
    current_group: list[dict[str, Any]] = []
    for message in messages:
        role = str(message.get("role", ""))
        if role == "assistant" and current_group:
            groups.append(current_group)
            current_group = [message]
            continue
        current_group.append(message)
    if current_group:
        groups.append(current_group)
    return groups


def _split_existing_memory_messages(messages: Sequence[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
    """Peel compacted-memory messages off the front of ``messages``.

    A message counts as memory only if it is a ``user`` message whose string
    content starts with ``COMPACT_MEMORY_PREFIX`` and every message before it
    also counted as memory. The first ordinary message ends the run; later
    memory-looking messages are treated as ordinary messages.

    Returns:
        ``(merged_memory, rest)``: the memory bodies with the prefix removed
        and surrounding whitespace stripped, empty ones skipped, joined by
        blank lines; and shallow copies of all remaining messages, in order.
    """
    marker = COMPACT_MEMORY_PREFIX
    bodies: list[str] = []
    rest: list[dict[str, Any]] = []
    for msg in messages:
        body = msg.get("content", "")
        # `not rest` stays true only until the first ordinary message is seen.
        if (
            not rest
            and str(msg.get("role", "")) == "user"
            and isinstance(body, str)
            and body.startswith(marker)
        ):
            bodies.append(body[len(marker) :].strip())
        else:
            rest.append(dict(msg))
    return "\n\n".join(filter(None, bodies)).strip(), rest


def _split_turn_groups(
    turn_groups: Sequence[Sequence[dict[str, Any]]],
    model_profile: ModelProfile,
    *,
    max_recent_groups: int = 4,
) -> tuple[list[list[dict[str, Any]]], list[list[dict[str, Any]]]]:
    """Partition turn groups into ``(groups_to_summarize, groups_to_keep)``.

    Groups are taken newest-first for the kept side. The newest group is
    always taken. Each further group is taken while the kept groups' rendered
    size is below the recent-history budget
    (``recent_history_budget_tokens * 4`` characters, minimum 400) and fewer
    than ``max_recent_groups`` groups are kept. The last group taken may
    therefore overshoot the budget. If this would keep every group, the oldest
    kept group is moved back to the summarized side so that something is
    always left to summarize.

    Both returned lists hold shallow copies of the message dicts in original
    (oldest-first) order.
    """
    import sys

    recent_char_budget = max(400, model_profile.recent_history_budget_tokens * 4)
    recent_groups: list[list[dict[str, Any]]] = []
    recent_chars = 0

    for group in reversed(turn_groups):
        if recent_groups and recent_chars >= recent_char_budget:
            break
        recent_groups.insert(0, [dict(message) for message in group])
        # Kept groups are re-sent verbatim, so measure them at full length;
        # short per-message excerpts would let large tool outputs slip under
        # the budget.
        recent_chars += len(_render_group(group, max_chars_per_message=sys.maxsize))
        if len(recent_groups) >= max_recent_groups:
            break

    if len(recent_groups) >= len(turn_groups):
        recent_groups = recent_groups[1:]
    compacted_count = max(0, len(turn_groups) - len(recent_groups))
    compacted_groups = [[dict(message) for message in group] for group in turn_groups[:compacted_count]]
    return compacted_groups, recent_groups


def _render_history_text(turn_groups: Sequence[Sequence[dict[str, Any]]], model_profile: ModelProfile) -> str:
    """Render turn groups as labelled text for the summarization prompt.

    Each group becomes a ``[Turn group N]`` header followed by one
    ``role: excerpt`` line per message, with each excerpt capped at
    ``max_chars_per_message``. Groups are emitted oldest-first until their
    total length would exceed ``max_history_chars`` (``context_window * 2``
    clamped to [600, 64000]). The group crossing the cap is cut short (this
    includes the very first group), later groups are dropped, and a
    ``...[history truncated]`` marker records the cut. The blank-line
    separators between groups are not counted against the cap.
    """
    max_history_chars = max(600, min(64000, model_profile.context_window * 2))
    max_chars_per_message = max(200, min(4000, max_history_chars // 10))
    parts: list[str] = []
    used = 0
    for index, group in enumerate(turn_groups, start=1):
        rendered = f"[Turn group {index}]\n{_render_group(group, max_chars_per_message=max_chars_per_message)}"
        if used + len(rendered) > max_history_chars:
            # Applies to the first group too, so one oversized group cannot
            # bypass the cap.
            remaining = max_history_chars - used
            if remaining > 80:
                parts.append(rendered[: remaining - 40].rstrip() + "\n...[history truncated]")
            else:
                # No room for a useful partial group, but still tell the
                # summarizer that history was cut.
                parts.append("...[history truncated]")
            break
        parts.append(rendered)
        used += len(rendered)
    return "\n\n".join(parts).strip()


def _render_group(group: Sequence[dict[str, Any]], *, max_chars_per_message: int) -> str:
    """Render a turn group as newline-separated ``role: excerpt`` lines.

    Each excerpt comes from ``_message_excerpt`` capped at
    ``max_chars_per_message``; the joined text is stripped of surrounding
    whitespace.
    """
    rendered_lines = (
        f"{str(msg.get('role', ''))}: {_message_excerpt(msg, max_chars=max_chars_per_message)}"
        for msg in group
    )
    return "\n".join(rendered_lines).strip()


def _message_excerpt(message: dict[str, Any], *, max_chars: int) -> str:
    content = message.get("content", "")
    text: str
    if isinstance(content, str):
        text = content
    elif isinstance(content, list):
        parts: list[str] = []
        for part in content:
            if isinstance(part, dict) and part.get("type") == "text":
                parts.append(str(part.get("text", "")))
            elif isinstance(part, dict) and part.get("type") == "image_url":
                parts.append("[image_url]")
            else:
                parts.append(str(part))
        text = " ".join(part for part in parts if part)
    else:
        text = str(content)
    tool_calls = message.get("tool_calls")
    if tool_calls:
        tool_names = []
        for tool_call in tool_calls:
            function_block = tool_call.get("function", {}) if isinstance(tool_call, dict) else {}
            tool_names.append(str(function_block.get("name", "")))
        if tool_names:
            text = (text + "\nTool calls: " + ", ".join(name for name in tool_names if name)).strip()
    compacted = " ".join(text.split())
    if len(compacted) <= max_chars:
        return compacted
    return compacted[: max_chars - 16].rstrip() + "...[truncated]"


def _truncate_summary_text(text: str, *, max_chars: int) -> str:
    compacted = " ".join(str(text).split())
    if len(compacted) <= max_chars:
        return compacted
    return compacted[: max_chars - 16].rstrip() + "...[truncated]"