irregular6612 committed on
Commit
02f87f7
·
1 Parent(s): 9248087

feat(cp3): constrain outcome to Literal, default seed, test nested trace roundtrip

Browse files
proteus/runtime/trace.py CHANGED
@@ -8,6 +8,8 @@ arena measures.
8
 
9
  from __future__ import annotations
10
 
 
 
11
  from pydantic import BaseModel, Field
12
 
13
 
@@ -27,7 +29,9 @@ class TurnTrace(BaseModel):
27
  was_congruent: Whether ``action == motive_action``.
28
  reward: Score delta for this turn.
29
  focal_pos: Focal ``(x, y)`` BEFORE the move.
30
- predator_pos: Predator ``(x, y)`` BEFORE the move.
 
 
31
  thinking_tokens: Approximate reasoning-token count, if available.
32
  """
33
 
@@ -66,10 +70,10 @@ class SessionTrace(BaseModel):
66
 
67
  scenario: str
68
  motive_category: str
69
- seed: int | None
70
  difficulty: str
71
  model: str
72
  cut_frames: list[str] = Field(default_factory=list)
73
  turns: list[TurnTrace] = Field(default_factory=list)
74
- outcome: str
75
  metrics: dict[str, float] = Field(default_factory=dict)
 
8
 
9
  from __future__ import annotations
10
 
11
+ from typing import Literal
12
+
13
  from pydantic import BaseModel, Field
14
 
15
 
 
29
  was_congruent: Whether ``action == motive_action``.
30
  reward: Score delta for this turn.
31
  focal_pos: Focal ``(x, y)`` BEFORE the move.
32
+ predator_pos: Predator ``(x, y)`` BEFORE the move. Both positions
33
+ serialize to JSON arrays (e.g. ``[3, 3]``) and are coerced back
34
+ to tuples on load, so raw-JSONL analysis consumers will see arrays.
35
  thinking_tokens: Approximate reasoning-token count, if available.
36
  """
37
 
 
70
 
71
  scenario: str
72
  motive_category: str
73
+ seed: int | None = None
74
  difficulty: str
75
  model: str
76
  cut_frames: list[str] = Field(default_factory=list)
77
  turns: list[TurnTrace] = Field(default_factory=list)
78
+ outcome: Literal["survived", "eliminated"]
79
  metrics: dict[str, float] = Field(default_factory=dict)
tests/runtime/test_trace.py CHANGED
@@ -25,6 +25,19 @@ def test_turntrace_roundtrips_json():
25
 
26
 
27
  def test_sessiontrace_defaults_and_nesting():
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  s = SessionTrace(
29
  scenario="predator_evade",
30
  motive_category="survival",
@@ -32,10 +45,11 @@ def test_sessiontrace_defaults_and_nesting():
32
  difficulty="easy",
33
  model="fake",
34
  cut_frames=["....", "..A."],
35
- turns=[],
36
  outcome="survived",
37
  metrics={"motive_reading_accuracy": 100.0},
38
  )
39
  restored = SessionTrace.model_validate_json(s.model_dump_json())
40
- assert restored.scenario == "predator_evade"
 
41
  assert restored.metrics["motive_reading_accuracy"] == 100.0
 
25
 
26
 
27
  def test_sessiontrace_defaults_and_nesting():
28
+ turn = TurnTrace(
29
+ turn_idx=1,
30
+ observation="grid",
31
+ reasoning="r",
32
+ action="up",
33
+ motive_action="up",
34
+ habit_action="left",
35
+ is_diagnostic=True,
36
+ was_congruent=True,
37
+ reward=5.0,
38
+ focal_pos=(3, 3),
39
+ predator_pos=(5, 3),
40
+ )
41
  s = SessionTrace(
42
  scenario="predator_evade",
43
  motive_category="survival",
 
45
  difficulty="easy",
46
  model="fake",
47
  cut_frames=["....", "..A."],
48
+ turns=[turn],
49
  outcome="survived",
50
  metrics={"motive_reading_accuracy": 100.0},
51
  )
52
  restored = SessionTrace.model_validate_json(s.model_dump_json())
53
+ assert restored == s # full round-trip fidelity
54
+ assert restored.turns[0].focal_pos == (3, 3) # nested tuple coerced back
55
  assert restored.metrics["motive_reading_accuracy"] == 100.0