Spaces:
Sleeping
Sleeping
Commit ·
02f87f7
1
Parent(s): 9248087
feat(cp3): constrain outcome to Literal, default seed, test nested trace roundtrip
Browse files- proteus/runtime/trace.py +7 -3
- tests/runtime/test_trace.py +16 -2
proteus/runtime/trace.py
CHANGED
|
@@ -8,6 +8,8 @@ arena measures.
|
|
| 8 |
|
| 9 |
from __future__ import annotations
|
| 10 |
|
|
|
|
|
|
|
| 11 |
from pydantic import BaseModel, Field
|
| 12 |
|
| 13 |
|
|
@@ -27,7 +29,9 @@ class TurnTrace(BaseModel):
|
|
| 27 |
was_congruent: Whether ``action == motive_action``.
|
| 28 |
reward: Score delta for this turn.
|
| 29 |
focal_pos: Focal ``(x, y)`` BEFORE the move.
|
| 30 |
-
predator_pos: Predator ``(x, y)`` BEFORE the move.
|
|
|
|
|
|
|
| 31 |
thinking_tokens: Approximate reasoning-token count, if available.
|
| 32 |
"""
|
| 33 |
|
|
@@ -66,10 +70,10 @@ class SessionTrace(BaseModel):
|
|
| 66 |
|
| 67 |
scenario: str
|
| 68 |
motive_category: str
|
| 69 |
-
seed: int | None
|
| 70 |
difficulty: str
|
| 71 |
model: str
|
| 72 |
cut_frames: list[str] = Field(default_factory=list)
|
| 73 |
turns: list[TurnTrace] = Field(default_factory=list)
|
| 74 |
-
outcome:
|
| 75 |
metrics: dict[str, float] = Field(default_factory=dict)
|
|
|
|
| 8 |
|
| 9 |
from __future__ import annotations
|
| 10 |
|
| 11 |
+
from typing import Literal
|
| 12 |
+
|
| 13 |
from pydantic import BaseModel, Field
|
| 14 |
|
| 15 |
|
|
|
|
| 29 |
was_congruent: Whether ``action == motive_action``.
|
| 30 |
reward: Score delta for this turn.
|
| 31 |
focal_pos: Focal ``(x, y)`` BEFORE the move.
|
| 32 |
+
predator_pos: Predator ``(x, y)`` BEFORE the move. Both positions
|
| 33 |
+
serialize to JSON arrays (e.g. ``[3, 3]``) and are coerced back
|
| 34 |
+
to tuples on load, so raw-JSONL analysis consumers will see arrays.
|
| 35 |
thinking_tokens: Approximate reasoning-token count, if available.
|
| 36 |
"""
|
| 37 |
|
|
|
|
| 70 |
|
| 71 |
scenario: str
|
| 72 |
motive_category: str
|
| 73 |
+
seed: int | None = None
|
| 74 |
difficulty: str
|
| 75 |
model: str
|
| 76 |
cut_frames: list[str] = Field(default_factory=list)
|
| 77 |
turns: list[TurnTrace] = Field(default_factory=list)
|
| 78 |
+
outcome: Literal["survived", "eliminated"]
|
| 79 |
metrics: dict[str, float] = Field(default_factory=dict)
|
tests/runtime/test_trace.py
CHANGED
|
@@ -25,6 +25,19 @@ def test_turntrace_roundtrips_json():
|
|
| 25 |
|
| 26 |
|
| 27 |
def test_sessiontrace_defaults_and_nesting():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
s = SessionTrace(
|
| 29 |
scenario="predator_evade",
|
| 30 |
motive_category="survival",
|
|
@@ -32,10 +45,11 @@ def test_sessiontrace_defaults_and_nesting():
|
|
| 32 |
difficulty="easy",
|
| 33 |
model="fake",
|
| 34 |
cut_frames=["....", "..A."],
|
| 35 |
-
turns=[],
|
| 36 |
outcome="survived",
|
| 37 |
metrics={"motive_reading_accuracy": 100.0},
|
| 38 |
)
|
| 39 |
restored = SessionTrace.model_validate_json(s.model_dump_json())
|
| 40 |
-
assert restored
|
|
|
|
| 41 |
assert restored.metrics["motive_reading_accuracy"] == 100.0
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
def test_sessiontrace_defaults_and_nesting():
|
| 28 |
+
turn = TurnTrace(
|
| 29 |
+
turn_idx=1,
|
| 30 |
+
observation="grid",
|
| 31 |
+
reasoning="r",
|
| 32 |
+
action="up",
|
| 33 |
+
motive_action="up",
|
| 34 |
+
habit_action="left",
|
| 35 |
+
is_diagnostic=True,
|
| 36 |
+
was_congruent=True,
|
| 37 |
+
reward=5.0,
|
| 38 |
+
focal_pos=(3, 3),
|
| 39 |
+
predator_pos=(5, 3),
|
| 40 |
+
)
|
| 41 |
s = SessionTrace(
|
| 42 |
scenario="predator_evade",
|
| 43 |
motive_category="survival",
|
|
|
|
| 45 |
difficulty="easy",
|
| 46 |
model="fake",
|
| 47 |
cut_frames=["....", "..A."],
|
| 48 |
+
turns=[turn],
|
| 49 |
outcome="survived",
|
| 50 |
metrics={"motive_reading_accuracy": 100.0},
|
| 51 |
)
|
| 52 |
restored = SessionTrace.model_validate_json(s.model_dump_json())
|
| 53 |
+
assert restored == s # full round-trip fidelity
|
| 54 |
+
assert restored.turns[0].focal_pos == (3, 3) # nested tuple coerced back
|
| 55 |
assert restored.metrics["motive_reading_accuracy"] == 100.0
|