Aksel Joonas Reedi commited on
Commit
59b2038
·
unverified ·
1 Parent(s): e8ed637

Preserve thinking state across tool turns (#143)

Browse files

* Preserve thinking state across tool turns

Anthropic thinking responses need their thinking_blocks and reasoning_content replayed with assistant tool-call messages. The loop was rebuilding assistant history from only content and tool calls, causing LiteLLM to strip thinking on continuation turns.

Constraint: Non-thinking providers and responses without reasoning fields must keep the existing message shape.

Rejected: Disable extended thinking for tool-using runs | avoids the warning by removing the feature that improves reasoning quality.

Confidence: high

Scope-risk: moderate

Directive: Any future assistant-message reconstruction must preserve provider reasoning fields when present.

Tested: UV_CACHE_DIR=/tmp/uv-cache uv run --extra dev pytest tests/unit/test_thinking_history.py tests/unit/test_dangling_tool_calls.py tests/unit/test_malformed_args_recovery.py

* Replay thinking metadata only for Anthropic

Review caught that reasoning_content is not safe to echo through OpenAI-compatible schemas such as the HF router. Gate replay and streaming chunk rebuilding to direct Anthropic models, where thinking metadata is required for tool continuations.

Constraint: HF router and OpenAI-compatible providers reject reasoning_content in assistant history.

Rejected: Preserve reasoning_content for all providers | reproduces the schema rejection already avoided in the research loop.

Confidence: high

Scope-risk: moderate

Tested: UV_CACHE_DIR=/tmp/uv-cache uv run --extra dev pytest tests/unit/test_thinking_history.py tests/unit/test_dangling_tool_calls.py tests/unit/test_malformed_args_recovery.py

agent/core/agent_loop.py CHANGED
@@ -8,8 +8,14 @@ import logging
8
  import os
9
  import time
10
  from dataclasses import dataclass, field
11
-
12
- from litellm import ChatCompletionMessageToolCall, Message, acompletion
 
 
 
 
 
 
13
  from litellm.exceptions import ContextWindowExceededError
14
 
15
  from agent.config import Config
@@ -396,6 +402,43 @@ class LLMResult:
396
  token_count: int
397
  finish_reason: str | None
398
  usage: dict = field(default_factory=dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
 
400
 
401
  async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult:
@@ -448,8 +491,10 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
448
  token_count = 0
449
  finish_reason = None
450
  final_usage_chunk = None
 
451
 
452
  async for chunk in response:
 
453
  if session.is_cancelled:
454
  tool_calls_acc.clear()
455
  break
@@ -498,6 +543,16 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
498
  latency_ms=int((time.monotonic() - t_start) * 1000),
499
  finish_reason=finish_reason,
500
  )
 
 
 
 
 
 
 
 
 
 
501
 
502
  return LLMResult(
503
  content=full_content or None,
@@ -505,6 +560,8 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
505
  token_count=token_count,
506
  finish_reason=finish_reason,
507
  usage=usage,
 
 
508
  )
509
 
510
 
@@ -557,6 +614,7 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
557
  content = message.content or None
558
  finish_reason = choice.finish_reason
559
  token_count = response.usage.total_tokens if response.usage else 0
 
560
 
561
  # Build tool_calls_acc in the same format as streaming
562
  tool_calls_acc: dict[int, dict] = {}
@@ -591,6 +649,8 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
591
  token_count=token_count,
592
  finish_reason=finish_reason,
593
  usage=usage,
 
 
594
  )
595
 
596
 
@@ -754,7 +814,10 @@ class Handlers:
754
  " • For other tools: reduce the size of your arguments or use bash."
755
  )
756
  if content:
757
- assistant_msg = Message(role="assistant", content=content)
 
 
 
758
  session.context_manager.add_message(assistant_msg, token_count)
759
  session.context_manager.add_message(
760
  Message(role="user", content=f"[SYSTEM: {truncation_hint}]")
@@ -810,7 +873,10 @@ class Handlers:
810
  (content or "")[:500],
811
  )
812
  if content:
813
- assistant_msg = Message(role="assistant", content=content)
 
 
 
814
  session.context_manager.add_message(assistant_msg, token_count)
815
  final_response = content
816
  break
@@ -832,9 +898,9 @@ class Handlers:
832
  bad_tools.append(tc)
833
 
834
  # Add assistant message with all tool calls to context
835
- assistant_msg = Message(
836
- role="assistant",
837
- content=content,
838
  tool_calls=tool_calls,
839
  )
840
  session.context_manager.add_message(assistant_msg, token_count)
 
8
  import os
9
  import time
10
  from dataclasses import dataclass, field
11
+ from typing import Any
12
+
13
+ from litellm import (
14
+ ChatCompletionMessageToolCall,
15
+ Message,
16
+ acompletion,
17
+ stream_chunk_builder,
18
+ )
19
  from litellm.exceptions import ContextWindowExceededError
20
 
21
  from agent.config import Config
 
402
  token_count: int
403
  finish_reason: str | None
404
  usage: dict = field(default_factory=dict)
405
+ thinking_blocks: list[dict[str, Any]] | None = None
406
+ reasoning_content: str | None = None
407
+
408
+
409
+ def _extract_thinking_state(
410
+ message: Any,
411
+ ) -> tuple[list[dict[str, Any]] | None, str | None]:
412
+ """Return provider reasoning fields that must be replayed after tool calls."""
413
+ thinking_blocks = getattr(message, "thinking_blocks", None) or None
414
+ reasoning_content = getattr(message, "reasoning_content", None) or None
415
+ return thinking_blocks, reasoning_content
416
+
417
+
418
+ def _should_replay_thinking_state(model_name: str | None) -> bool:
419
+ """Only Anthropic's native adapter accepts replayed thinking metadata."""
420
+ return bool(model_name and model_name.startswith("anthropic/"))
421
+
422
+
423
+ def _assistant_message_from_result(
424
+ llm_result: LLMResult,
425
+ *,
426
+ model_name: str | None,
427
+ tool_calls: list[ToolCall] | None = None,
428
+ ) -> Message:
429
+ """Build an assistant history message without dropping reasoning state."""
430
+ kwargs: dict[str, Any] = {
431
+ "role": "assistant",
432
+ "content": llm_result.content,
433
+ }
434
+ if tool_calls is not None:
435
+ kwargs["tool_calls"] = tool_calls
436
+ if _should_replay_thinking_state(model_name):
437
+ if llm_result.thinking_blocks:
438
+ kwargs["thinking_blocks"] = llm_result.thinking_blocks
439
+ if llm_result.reasoning_content:
440
+ kwargs["reasoning_content"] = llm_result.reasoning_content
441
+ return Message(**kwargs)
442
 
443
 
444
  async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult:
 
491
  token_count = 0
492
  finish_reason = None
493
  final_usage_chunk = None
494
+ chunks = []
495
 
496
  async for chunk in response:
497
+ chunks.append(chunk)
498
  if session.is_cancelled:
499
  tool_calls_acc.clear()
500
  break
 
543
  latency_ms=int((time.monotonic() - t_start) * 1000),
544
  finish_reason=finish_reason,
545
  )
546
+ thinking_blocks = None
547
+ reasoning_content = None
548
+ if chunks and _should_replay_thinking_state(llm_params.get("model")):
549
+ try:
550
+ rebuilt = stream_chunk_builder(chunks, messages=messages)
551
+ if rebuilt and getattr(rebuilt, "choices", None):
552
+ rebuilt_msg = rebuilt.choices[0].message
553
+ thinking_blocks, reasoning_content = _extract_thinking_state(rebuilt_msg)
554
+ except Exception:
555
+ logger.debug("Failed to rebuild streaming thinking state", exc_info=True)
556
 
557
  return LLMResult(
558
  content=full_content or None,
 
560
  token_count=token_count,
561
  finish_reason=finish_reason,
562
  usage=usage,
563
+ thinking_blocks=thinking_blocks,
564
+ reasoning_content=reasoning_content,
565
  )
566
 
567
 
 
614
  content = message.content or None
615
  finish_reason = choice.finish_reason
616
  token_count = response.usage.total_tokens if response.usage else 0
617
+ thinking_blocks, reasoning_content = _extract_thinking_state(message)
618
 
619
  # Build tool_calls_acc in the same format as streaming
620
  tool_calls_acc: dict[int, dict] = {}
 
649
  token_count=token_count,
650
  finish_reason=finish_reason,
651
  usage=usage,
652
+ thinking_blocks=thinking_blocks,
653
+ reasoning_content=reasoning_content,
654
  )
655
 
656
 
 
814
  " • For other tools: reduce the size of your arguments or use bash."
815
  )
816
  if content:
817
+ assistant_msg = _assistant_message_from_result(
818
+ llm_result,
819
+ model_name=llm_params.get("model"),
820
+ )
821
  session.context_manager.add_message(assistant_msg, token_count)
822
  session.context_manager.add_message(
823
  Message(role="user", content=f"[SYSTEM: {truncation_hint}]")
 
873
  (content or "")[:500],
874
  )
875
  if content:
876
+ assistant_msg = _assistant_message_from_result(
877
+ llm_result,
878
+ model_name=llm_params.get("model"),
879
+ )
880
  session.context_manager.add_message(assistant_msg, token_count)
881
  final_response = content
882
  break
 
898
  bad_tools.append(tc)
899
 
900
  # Add assistant message with all tool calls to context
901
+ assistant_msg = _assistant_message_from_result(
902
+ llm_result,
903
+ model_name=llm_params.get("model"),
904
  tool_calls=tool_calls,
905
  )
906
  session.context_manager.add_message(assistant_msg, token_count)
tests/unit/test_thinking_history.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from types import SimpleNamespace
2
+
3
+ import pytest
4
+ from litellm import ChatCompletionMessageToolCall, Message
5
+
6
+ from agent.core import agent_loop
7
+ from agent.core.agent_loop import (
8
+ LLMResult,
9
+ _call_llm_streaming,
10
+ _assistant_message_from_result,
11
+ _extract_thinking_state,
12
+ )
13
+
14
+
15
+ def test_extract_thinking_state_from_litellm_message():
16
+ message = Message(
17
+ role="assistant",
18
+ content="working",
19
+ thinking_blocks=[{"type": "thinking", "thinking": "reasoned"}],
20
+ reasoning_content="reasoned",
21
+ )
22
+
23
+ thinking_blocks, reasoning_content = _extract_thinking_state(message)
24
+
25
+ assert thinking_blocks == [{"type": "thinking", "thinking": "reasoned"}]
26
+ assert reasoning_content == "reasoned"
27
+
28
+
29
+ def test_assistant_message_from_result_preserves_thinking_with_tool_calls():
30
+ tool_call = ChatCompletionMessageToolCall(
31
+ id="call_1",
32
+ type="function",
33
+ function={"name": "bash", "arguments": '{"command": "date"}'},
34
+ )
35
+ result = LLMResult(
36
+ content=None,
37
+ tool_calls_acc={},
38
+ token_count=12,
39
+ finish_reason="tool_calls",
40
+ thinking_blocks=[{"type": "thinking", "thinking": "reasoned"}],
41
+ reasoning_content="reasoned",
42
+ )
43
+
44
+ message = _assistant_message_from_result(
45
+ result,
46
+ model_name="anthropic/claude-opus-4-6",
47
+ tool_calls=[tool_call],
48
+ )
49
+
50
+ assert message.tool_calls == [tool_call]
51
+ assert message.thinking_blocks == [{"type": "thinking", "thinking": "reasoned"}]
52
+ assert message.reasoning_content == "reasoned"
53
+
54
+
55
+ def test_assistant_message_from_result_strips_non_anthropic_reasoning_content():
56
+ result = LLMResult(
57
+ content=None,
58
+ tool_calls_acc={},
59
+ token_count=12,
60
+ finish_reason="tool_calls",
61
+ thinking_blocks=[{"type": "thinking", "thinking": "reasoned"}],
62
+ reasoning_content="reasoned",
63
+ )
64
+
65
+ message = _assistant_message_from_result(
66
+ result,
67
+ model_name="openai/Qwen/Qwen3-Next-80B-A3B-Instruct",
68
+ )
69
+
70
+ assert getattr(message, "thinking_blocks", None) is None
71
+ assert getattr(message, "reasoning_content", None) is None
72
+
73
+
74
+ def test_assistant_message_from_result_omits_absent_thinking_fields():
75
+ result = LLMResult(
76
+ content="done",
77
+ tool_calls_acc={},
78
+ token_count=12,
79
+ finish_reason="stop",
80
+ )
81
+
82
+ message = _assistant_message_from_result(
83
+ result,
84
+ model_name="anthropic/claude-opus-4-6",
85
+ )
86
+
87
+ assert message.content == "done"
88
+ assert getattr(message, "thinking_blocks", None) is None
89
+ assert getattr(message, "reasoning_content", None) is None
90
+
91
+
92
+ @pytest.mark.asyncio
93
+ async def test_streaming_call_rebuilds_anthropic_thinking_state(monkeypatch):
94
+ async def fake_stream():
95
+ yield SimpleNamespace(
96
+ choices=[
97
+ SimpleNamespace(
98
+ delta=SimpleNamespace(content="done", tool_calls=None),
99
+ finish_reason="stop",
100
+ )
101
+ ],
102
+ )
103
+ yield SimpleNamespace(choices=[], usage=SimpleNamespace(total_tokens=3))
104
+
105
+ async def fake_acompletion(**_kwargs):
106
+ return fake_stream()
107
+
108
+ def fake_chunk_builder(chunks, **_kwargs):
109
+ assert len(chunks) == 2
110
+ return SimpleNamespace(
111
+ choices=[
112
+ SimpleNamespace(
113
+ message=Message(
114
+ role="assistant",
115
+ content="done",
116
+ thinking_blocks=[{"type": "thinking", "thinking": "reasoned"}],
117
+ reasoning_content="reasoned",
118
+ )
119
+ )
120
+ ]
121
+ )
122
+
123
+ events = []
124
+ async def send_event(event):
125
+ events.append(event)
126
+
127
+ session = SimpleNamespace(
128
+ config=SimpleNamespace(model_name="anthropic/claude-opus-4-6"),
129
+ is_cancelled=False,
130
+ send_event=send_event,
131
+ )
132
+ monkeypatch.setattr(agent_loop, "acompletion", fake_acompletion)
133
+ monkeypatch.setattr(agent_loop, "stream_chunk_builder", fake_chunk_builder)
134
+
135
+ result = await _call_llm_streaming(
136
+ session,
137
+ messages=[Message(role="user", content="hi")],
138
+ tools=[],
139
+ llm_params={"model": "anthropic/claude-opus-4-6"},
140
+ )
141
+
142
+ assert result.content == "done"
143
+ assert result.thinking_blocks == [{"type": "thinking", "thinking": "reasoned"}]
144
+ assert result.reasoning_content == "reasoned"
145
+
146
+
147
+ @pytest.mark.asyncio
148
+ async def test_streaming_call_skips_chunk_rebuild_for_non_anthropic(monkeypatch):
149
+ async def fake_stream():
150
+ yield SimpleNamespace(
151
+ choices=[
152
+ SimpleNamespace(
153
+ delta=SimpleNamespace(content="done", tool_calls=None),
154
+ finish_reason="stop",
155
+ )
156
+ ],
157
+ )
158
+
159
+ async def fake_acompletion(**_kwargs):
160
+ return fake_stream()
161
+
162
+ def fail_chunk_builder(*_args, **_kwargs):
163
+ raise AssertionError("stream_chunk_builder should not run")
164
+
165
+ events = []
166
+ async def send_event(event):
167
+ events.append(event)
168
+
169
+ session = SimpleNamespace(
170
+ config=SimpleNamespace(model_name="openai/Qwen/Qwen3"),
171
+ is_cancelled=False,
172
+ send_event=send_event,
173
+ )
174
+ monkeypatch.setattr(agent_loop, "acompletion", fake_acompletion)
175
+ monkeypatch.setattr(agent_loop, "stream_chunk_builder", fail_chunk_builder)
176
+
177
+ result = await _call_llm_streaming(
178
+ session,
179
+ messages=[Message(role="user", content="hi")],
180
+ tools=[],
181
+ llm_params={"model": "openai/Qwen/Qwen3"},
182
+ )
183
+
184
+ assert result.content == "done"
185
+ assert result.thinking_blocks is None
186
+ assert result.reasoning_content is None