her / engine /contract.py
geekwrestler's picture
Squash history (purge pre-scrub demo session blobs)
5f43c7d
"""The Event/Turn contract — the seam between loaders and the engine.
Non-negotiable #3: the engine consumes ONLY these normalized shapes; it never
reaches into raw JSONL. Loaders (jsonl now, hf later) emit these dataclasses.
Each dataclass carries a `to_dict()` that emits the CONTRACT JSON exactly:
Event = { id, turn, role, kind, tool?, input?, resultText?, tokens?, ts, mcp? }
Turn = { i, prompt, origin, reply, ts, tools:[ToolCall], tokens:Tokens,
reqs, ctxStart, ctxPeak, ctxEnd, direct, indirect, heavy, overBudget, guide? }
ToolCall= { id, name, input, summary, mcp?, provenance, sourceTool, flowValue, errored }
Tokens = { in, out, cacheRead, cacheCreate, cost } # NB: JSON key is "in", not "in_"; "cost" is derived by Tokens.cost()
Phase-1 note: provenance / direct / indirect / heavy / guide are populated by the
Phase-2 engine. The loader leaves them at their neutral defaults. NO model here.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Optional
# --------------------------------------------------------------------------- #
# Tokens
# --------------------------------------------------------------------------- #
# Relative price of each token class, expressed in units of one base (uncached)
# input token on Anthropic's pricing. The ratios are UNIFORM across Opus /
# Sonnet / Haiku (output is 5x input, a cache write is 1.25x, a cache read is
# 0.1x on every Claude model), so one weighted sum is a faithful,
# model-agnostic proxy for what a run actually costs.
# `cost()` therefore returns "input-token-equivalents": multiply by the model's
# input $/token to get dollars.
# cacheCreate uses the 5-minute-TTL write multiplier (Claude Code's default
# ephemeral cache). A 1-hour write would be 2.0, but the usage object does not
# let us tell the two apart, so 1.25 is the documented assumption.
COST_WEIGHTS = {
    "in": 1.0,
    "cacheCreate": 1.25,
    "cacheRead": 0.1,
    "out": 5.0,
}
@dataclass
class Tokens:
    """Token rollup.

    The field `in_` serializes to the JSON key "in" (`in` is a reserved word
    in Python, so it cannot be used as the attribute name).
    """

    in_: int = 0
    out: int = 0
    cacheRead: int = 0
    cacheCreate: int = 0

    def cost(self) -> int:
        """Return cost-weighted tokens (input-token-equivalents).

        This is the real-money signal, NOT raw cacheRead: cacheRead is cheap
        (0.1x) and re-paid every round-trip, while a turn's true cost is
        dominated by generation (5x).

        The weighted sum is linear in the fields, but the result is passed
        through `round()` (ties go to the even integer) to yield an int.
        Summing already-rounded per-turn costs can therefore differ from the
        cost of the summed totals by rounding error; `add()` the Tokens
        first and call `cost()` once when an exact total matters.
        """
        weighted = (
            self.in_ * COST_WEIGHTS["in"]
            + self.cacheCreate * COST_WEIGHTS["cacheCreate"]
            + self.cacheRead * COST_WEIGHTS["cacheRead"]
            + self.out * COST_WEIGHTS["out"]
        )
        return round(weighted)

    def to_dict(self) -> dict[str, int]:
        """Return the JSON shape: the four raw counters plus derived "cost"."""
        return {
            "in": self.in_,  # JSON key is "in", not "in_"
            "out": self.out,
            "cacheRead": self.cacheRead,
            "cacheCreate": self.cacheCreate,
            "cost": self.cost(),  # cost-weighted total (the ranking key)
        }

    def add(self, other: "Tokens") -> "Tokens":
        """Return a NEW Tokens holding the field-wise sum of self and other.

        Neither operand is modified.
        """
        return Tokens(
            in_=self.in_ + other.in_,
            out=self.out + other.out,
            cacheRead=self.cacheRead + other.cacheRead,
            cacheCreate=self.cacheCreate + other.cacheCreate,
        )
# --------------------------------------------------------------------------- #
# ToolCall
# --------------------------------------------------------------------------- #
@dataclass
class ToolCall:
    """One tool_use, linked to its tool_result.

    `to_dict` emits id, name, input, summary, mcp, provenance, sourceTool,
    flowValue and errored (the contract fields plus `id`, which the UI needs
    as the stable node id). Every one of those keys is always present; unset
    values serialize as None. `result_text` and `ts` are carried on the
    object for later phases but are NOT serialized by `to_dict`.
    """

    name: str
    input: Any
    summary: str
    # contract fields the engine/Phase-2 will fill; neutral defaults in Phase 1
    mcp: Optional[dict[str, str]] = None
    provenance: Optional[str] = None  # 'direct' | 'indirect' (Phase 2)
    sourceTool: Optional[str] = None  # (Phase 2)
    flowValue: Optional[str] = None  # (Phase 2)
    errored: Optional[bool] = None  # (Phase 2)
    # internal carriers (also useful to the UI / Phase 2)
    id: Optional[str] = None
    result_text: Optional[str] = None
    ts: Optional[str] = None

    def to_dict(self) -> dict[str, Any]:
        """Return the serialized tool call; omits `result_text` and `ts`."""
        return {
            "id": self.id,
            "name": self.name,
            "input": self.input,
            "summary": self.summary,
            "mcp": self.mcp,
            "provenance": self.provenance,
            "sourceTool": self.sourceTool,
            "flowValue": self.flowValue,
            "errored": self.errored,
        }
# --------------------------------------------------------------------------- #
# Turn
# --------------------------------------------------------------------------- #
@dataclass
class Turn:
    """A reconstructed query: one real prompt + everything it caused.

    `to_dict` always emits every field below except `guide`, which is added
    only when it is not None.
    """

    i: int
    prompt: str
    origin: str  # 'human' | 'system'
    reply: str = ""
    ts: Optional[str] = None
    tools: list[ToolCall] = field(default_factory=list)
    tokens: Tokens = field(default_factory=Tokens)
    reqs: int = 0
    # Point-in-time context-WINDOW occupancy (input + cacheRead + cacheCreate
    # of a single request) — the "fuel gauge", NOT the cumulative token sums
    # above. Bounded by the model's window (≤1M on Opus 4.8); a single request
    # can never exceed it. Excludes sidechain (sub-agent) requests — those run
    # in their own window. ctxStart/ctxEnd/ctxPeak are the first/last/max
    # occupancy across this turn's main-thread requests; a sharp drop between
    # turns signals a compaction. 0 == no usage seen.
    ctxStart: int = 0
    ctxPeak: int = 0
    ctxEnd: int = 0
    # Phase-2 (engine) fields — neutral defaults in Phase 1
    direct: int = 0
    indirect: int = 0
    heavy: bool = False  # top-N by COST (relative "heaviest" — drives the graph glow)
    overBudget: bool = False  # cost >= an absolute floor (every expensive turn, not just top-N)
    guide: Optional[dict[str, str]] = None

    def to_dict(self) -> dict[str, Any]:
        """Return the serialized turn, delegating to each tool's and the
        token rollup's own `to_dict()`."""
        out: dict[str, Any] = {
            "i": self.i,
            "prompt": self.prompt,
            "origin": self.origin,
            "reply": self.reply,
            "ts": self.ts,
            "tools": [t.to_dict() for t in self.tools],
            "tokens": self.tokens.to_dict(),
            "reqs": self.reqs,
            "ctxStart": self.ctxStart,
            "ctxPeak": self.ctxPeak,
            "ctxEnd": self.ctxEnd,
            "direct": self.direct,
            "indirect": self.indirect,
            "heavy": self.heavy,
            "overBudget": self.overBudget,
        }
        # `guide` is present ONLY when a pattern fires (Phase 2). Silence otherwise.
        if self.guide is not None:
            out["guide"] = self.guide
        return out
# --------------------------------------------------------------------------- #
# Event
# --------------------------------------------------------------------------- #
@dataclass
class Event:
    """A flat, ordered atom of the session timeline (prompt/text/tool_use/tool_result).

    `to_dict` always emits id, turn, role, kind and ts; the remaining fields
    are optional and appear in the output only when they are not None.
    """

    id: str
    turn: int
    role: str  # 'user' | 'assistant'
    kind: str  # 'prompt' | 'tool_use' | 'tool_result' | 'text'
    ts: Optional[str] = None
    tool: Optional[str] = None
    input: Any = None
    resultText: Optional[str] = None
    tokens: Optional["Tokens"] = None
    mcp: Optional[dict[str, str]] = None

    def to_dict(self) -> dict[str, Any]:
        """Return the serialized event, omitting optional fields that are None.

        The check is `is not None`, so falsy-but-present values (empty
        string, 0, empty dict) are still emitted.
        """
        out: dict[str, Any] = {
            "id": self.id,
            "turn": self.turn,
            "role": self.role,
            "kind": self.kind,
            "ts": self.ts,
        }
        # optional fields — only emit when present, keeps the contract clean
        if self.tool is not None:
            out["tool"] = self.tool
        if self.input is not None:
            out["input"] = self.input
        if self.resultText is not None:
            out["resultText"] = self.resultText
        if self.tokens is not None:
            out["tokens"] = self.tokens.to_dict()
        if self.mcp is not None:
            out["mcp"] = self.mcp
        return out
# --------------------------------------------------------------------------- #
# Helper
# --------------------------------------------------------------------------- #
def to_jsonable(obj: Any) -> Any:
    """Recursively turn contract dataclasses (and containers of them) into
    plain JSON-serializable structures. Used to dump load() output to JSON.

    - An object exposing a callable `to_dict` is replaced by `to_dict()`'s
      return value, which is returned as-is (not walked again).
    - A dict keeps its keys unchanged; its values are converted recursively.
    - A list or tuple becomes a list of recursively converted items.
    - Anything else is returned unchanged.
    """
    to_dict = getattr(obj, "to_dict", None)
    if callable(to_dict):
        return to_dict()
    if isinstance(obj, dict):
        return {key: to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(item) for item in obj]
    return obj