File size: 2,848 Bytes
f209a8f
 
 
 
 
 
 
 
 
 
 
49476b4
f209a8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49476b4
 
 
 
 
f209a8f
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from __future__ import annotations

import json
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Optional, Sequence

from agent_base.model_profiles import ModelProfile
from agent_base.utils import safe_jsonable


# Filename prefix for session-state snapshots: resolve_session_state_path
# builds "<prefix>_<suffix>.json" in the same directory as the trace file.
SESSION_STATE_PREFIX: str = "session_state"


@dataclass
class CompactionRecord:
    """Record of a single compaction event.

    Instances are stored in ``AgentSessionState.compactions`` and converted
    with ``dataclasses.asdict`` (then ``safe_jsonable``) by
    ``AgentSessionState.payload()``. ``asdict`` emits keys in the declaration
    order below, and the positional constructor follows the same order, so
    do not reorder fields.
    """

    # Turn counter when this compaction ran — presumably mirrors
    # AgentSessionState.turn_index; confirm with the caller.
    turn_index: int
    # Outcome label; the set of valid values is defined by the caller, not here.
    status: str
    # Free-form reason the compaction was triggered.
    trigger_reason: str
    # Token estimate / message count before compaction ("prior").
    prior_token_estimate: int
    prior_message_count: int
    # Presumably: number of message groups folded away vs. kept verbatim.
    compacted_group_count: int = 0
    kept_group_count: int = 0
    # Post-compaction measurements; None means "not recorded"
    # (NOTE(review): likely left unset when compaction fails — confirm).
    new_token_estimate: Optional[int] = None
    new_message_count: Optional[int] = None
    # Summary produced by the compaction; empty string when none was recorded.
    summary_text: str = ""
    # Error description; empty string when no error was recorded.
    error: str = ""


@dataclass
class AgentSessionState:
    """Per-run agent state that can be persisted as JSON.

    The dataclass fields hold the live in-memory state; ``payload()``
    converts them into the versioned dict that ``persist_session_state``
    serializes.
    """

    # Run identity and context.
    run_id: str
    model_name: str
    workspace_root: str
    prompt: str
    trace_path: str = ""
    # Progress counters and limits (their exact meaning is set by the caller).
    turn_index: int = 0
    max_rounds: int = 0
    max_input_tokens: int = 0
    max_output_tokens: int = 0
    # None until a value has been recorded.
    last_input_tokens: Optional[int] = None
    current_token_estimate: int = 0
    # Outcome fields; empty string means "not set".
    termination: str = ""
    error: str = ""
    # Conversation history; capture_messages() stores it via safe_jsonable.
    messages: list[dict[str, Any]] = field(default_factory=list)
    compactions: list[CompactionRecord] = field(default_factory=list)
    model_profile: Optional[ModelProfile] = None

    def capture_messages(self, messages: Sequence[dict[str, Any]]) -> None:
        """Replace ``self.messages`` with a JSON-safe rendering of *messages*.

        The incoming sequence is first copied into a new list, then passed
        through ``safe_jsonable``; whatever that returns becomes the stored
        history.
        """
        snapshot = list(messages)
        self.messages = safe_jsonable(snapshot)

    def payload(self) -> dict[str, Any]:
        """Build the JSON-ready dict for this state.

        Keys come out in a fixed order, which ``json.dumps`` preserves.
        ``compactions`` and ``model_profile`` are converted with
        ``dataclasses.asdict`` and then ``safe_jsonable``. ``asdict``
        requires ``model_profile`` to be a dataclass instance when it is set.
        ``messages`` is returned as stored (same list object, not a copy).
        """
        compaction_dicts = [safe_jsonable(asdict(entry)) for entry in self.compactions]

        current_profile = self.model_profile
        profile_dict = None
        if current_profile is not None:
            profile_dict = safe_jsonable(asdict(current_profile))

        return dict(
            version=1,  # payload format version (fixed at 1 here)
            run_id=self.run_id,
            model_name=self.model_name,
            workspace_root=self.workspace_root,
            prompt=self.prompt,
            trace_path=self.trace_path,
            turn_index=self.turn_index,
            max_rounds=self.max_rounds,
            max_input_tokens=self.max_input_tokens,
            max_output_tokens=self.max_output_tokens,
            last_input_tokens=self.last_input_tokens,
            current_token_estimate=self.current_token_estimate,
            termination=self.termination,
            error=self.error,
            messages=self.messages,
            compactions=compaction_dicts,
            model_profile=profile_dict,
        )


def resolve_session_state_path(trace_path: str | Path) -> Path:
    """Return the session-state JSON path that sits beside *trace_path*.

    The trace file's stem has only its final suffix dropped (``Path.stem``).
    A leading ``trace_`` is then stripped, and the remainder is appended to
    ``SESSION_STATE_PREFIX``. With the default prefix,
    ``runs/trace_abc.jsonl`` maps to ``runs/session_state_abc.json``. Stems
    without the ``trace_`` prefix are used unchanged.
    """
    trace_file = Path(trace_path)
    marker = "trace_"
    run_tag = trace_file.stem
    if run_tag.startswith(marker):
        run_tag = run_tag[len(marker):]
    return trace_file.with_name(f"{SESSION_STATE_PREFIX}_{run_tag}.json")


def persist_session_state(path: str | Path, state: AgentSessionState) -> None:
    """Write ``state.payload()`` to *path* as indented UTF-8 JSON.

    Missing parent directories are created. The JSON is first written to a
    sibling temporary file and then moved over *path* with ``Path.replace``.
    A crash or I/O error mid-write therefore leaves any previous snapshot at
    *path* intact instead of truncated.

    Raises:
        TypeError / ValueError: if the payload cannot be serialized by
            ``json.dumps`` (raised before anything is written).
        OSError: if creating the directory, writing the temp file, or the
            final rename fails.
    """
    import os  # local import: only needed for the temp-file name

    output_path = Path(path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Serialize up front so an unserializable payload never touches disk.
    text = json.dumps(state.payload(), ensure_ascii=False, indent=2) + "\n"
    # Keep the temp file in the target directory: a rename cannot cross
    # filesystems, and os.replace is atomic on POSIX within one filesystem.
    # The PID keeps concurrent processes from sharing a temp file.
    tmp_path = output_path.with_name(f".{output_path.name}.{os.getpid()}.tmp")
    try:
        tmp_path.write_text(text, encoding="utf-8")
        tmp_path.replace(output_path)
    except BaseException:
        # Best-effort removal of the partial temp file; the original error
        # is what the caller needs to see, so it is always re-raised.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise