File size: 3,577 Bytes
2f1a55a
 
 
 
 
 
 
 
 
1252cb9
 
 
 
2f1a55a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1252cb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f1a55a
1252cb9
 
 
 
 
 
 
 
 
 
 
 
 
 
2f1a55a
 
 
1252cb9
 
2f1a55a
 
 
 
 
 
 
 
1252cb9
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Pure transforms on agent-trace event lists. No I/O.

Supports two on-disk formats:
1. Claude-Code style — `{type: "user"|"assistant", message: {role, content}}`.
   Example dataset: `merve/ml-intern-sessions`.
2. pi-sessions style — `{type: "message", message: {role: "user"|"assistant"|"toolResult", content: [...]}}`.
   Example dataset: `julien-c/pi-sessions`. Tool calls use `toolCall` blocks;
   tool outputs come back as role=toolResult messages which we drop.
"""

from typing import Any


def event_role(ev: dict) -> str | None:
    """Normalised role of a trace event. Returns 'user' / 'assistant' or None
    for non-content events (session metadata, tool results, model_change, etc.)."""
    t = ev.get("type")
    if t in ("user", "assistant"):
        return t
    if t == "message":
        msg = ev.get("message") or {}
        role = msg.get("role")
        if role in ("user", "assistant"):
            return role
        return None
    return None


def event_tool_names(ev: dict) -> list[str]:
    """List the tool names invoked by this event, in block order.

    Recognises both Claude-Code ``tool_use`` blocks and pi-sessions
    ``toolCall`` blocks. Non-dict blocks, and tool blocks whose ``name`` is
    not a non-empty string, are skipped. Events whose ``message.content`` is
    not a list yield an empty list.
    """
    blocks = (ev.get("message") or {}).get("content")
    if not isinstance(blocks, list):
        return []
    tool_block_types = ("tool_use", "toolCall")
    candidate_names = (
        blk.get("name")
        for blk in blocks
        if isinstance(blk, dict) and blk.get("type") in tool_block_types
    )
    return [nm for nm in candidate_names if isinstance(nm, str) and nm]


def _user_content_to_text(content: Any) -> str:
    """Flatten user-message content into plain text.

    A string is returned unchanged. For a list of content blocks:
    ``tool_result`` blocks and non-dict entries are dropped; ``text`` blocks
    contribute their string ``text``; any other block with a string
    ``content`` contributes that. The collected pieces are joined with
    newlines. Any other content type yields "".
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""
    parts: list[str] = []
    for block in content:
        # tool_result blocks carry tool output, which the transcript omits.
        if not isinstance(block, dict) or block.get("type") == "tool_result":
            continue
        text = block.get("text")
        if block.get("type") == "text" and isinstance(text, str):
            parts.append(text)
        elif isinstance(block.get("content"), str):
            # tool_result was already filtered above, so no re-check is needed.
            parts.append(block["content"])
    return "\n".join(parts)


def _assistant_content_to_text(content: Any) -> str:
    """Flatten assistant-message content into its visible reply text.

    Strings pass through unchanged. From a list, only ``text`` blocks whose
    ``text`` is a string are kept; thinking, tool_use, toolCall and any other
    blocks are dropped. The kept texts are concatenated with no separator.
    Any other content type yields "".
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""
    return "".join(
        blk["text"]
        for blk in content
        if isinstance(blk, dict)
        and blk.get("type") == "text"
        and isinstance(blk.get("text"), str)
    )


def events_to_transcript(events: list[dict]) -> str:
    """Render trace events as a plain-text "User: ..." / "Assistant: ..." transcript.

    Events whose normalised role is neither user nor assistant are skipped,
    as are turns whose extracted text is empty after stripping whitespace.
    Consecutive turns are separated by a blank line.
    """
    labels = {"user": "User", "assistant": "Assistant"}
    turns: list[str] = []
    for event in events:
        role = event_role(event)
        if role not in labels:
            continue
        body = (event.get("message") or {}).get("content")
        if role == "user":
            text = _user_content_to_text(body)
        else:
            text = _assistant_content_to_text(body)
        text = text.strip()
        if not text:
            continue
        turns.append(f"{labels[role]}: {text}")
    return "\n\n".join(turns)


def truncate_transcript(text: str, max_chars: int = 40_000) -> str:
    """Shorten ``text`` by keeping its beginning and end around a marker.

    Text of at most ``max_chars`` characters is returned unchanged. Longer
    text keeps the first ``max_chars // 2`` and the last ``max_chars // 4``
    characters, joined by a ``[... truncated ...]`` marker. The remaining
    quarter of the budget is left unused, which leaves room for the marker.
    For very small budgets the marker alone can still exceed ``max_chars``.
    A non-positive ``max_chars`` keeps only the marker.
    """
    if len(text) <= max_chars:
        return text
    # Clamp so a negative budget cannot produce negative slice bounds.
    budget = max(max_chars, 0)
    head_len = budget // 2
    tail_len = budget // 4
    head = text[:head_len]
    # Slice from an explicit start index: ``text[-0:]`` would return the
    # whole string when tail_len is 0 (i.e. max_chars < 4).
    tail = text[len(text) - tail_len:]
    return f"{head}\n\n[... truncated ...]\n\n{tail}"