ayushozha Claude Opus 4.6 commited on
Commit
2f4ed4a
·
1 Parent(s): 5f8c92c

Type EpisodeState and EpisodeLog with Protocol, ConversationEntry, RewardBreakdown (MOD 04)

Browse files

Replace loose dict fields with typed Pydantic models:
- EpisodeState.current_protocol: Optional[Protocol]
- EpisodeState.conversation_history: list[ConversationEntry]
- EpisodeLog.transcript: list[ConversationEntry]
- EpisodeLog.reward_breakdown: Optional[RewardBreakdown]

Stub server now explicitly constructs Protocol and ConversationEntry
objects instead of raw dicts. Fix WS handler info serialization.
Add 8 tests covering typed construction, JSON round-trip, and nesting.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (3) hide show
  1. replicalab/models.py +4 -4
  2. server/app.py +30 -28
  3. tests/test_models.py +180 -0
replicalab/models.py CHANGED
@@ -374,8 +374,8 @@ class EpisodeState(BaseModel):
374
  lab_reagents: list[str] = Field(default_factory=list)
375
  lab_staff_count: int = 0
376
  lab_time_limit_days: int = 0
377
- current_protocol: Optional[dict] = None
378
- conversation_history: list[dict] = Field(default_factory=list)
379
  round_number: int = 0
380
  max_rounds: int = 0
381
  done: bool = False
@@ -394,8 +394,8 @@ class EpisodeLog(BaseModel):
394
  scenario_template: str = ""
395
  difficulty: str = "easy"
396
  final_state: Optional[EpisodeState] = None
397
- transcript: list[dict] = Field(default_factory=list)
398
- reward_breakdown: dict = Field(default_factory=dict)
399
  total_reward: float = 0.0
400
  rounds_used: int = 0
401
  agreement_reached: bool = False
 
374
  lab_reagents: list[str] = Field(default_factory=list)
375
  lab_staff_count: int = 0
376
  lab_time_limit_days: int = 0
377
+ current_protocol: Optional[Protocol] = None
378
+ conversation_history: list[ConversationEntry] = Field(default_factory=list)
379
  round_number: int = 0
380
  max_rounds: int = 0
381
  done: bool = False
 
394
  scenario_template: str = ""
395
  difficulty: str = "easy"
396
  final_state: Optional[EpisodeState] = None
397
+ transcript: list[ConversationEntry] = Field(default_factory=list)
398
+ reward_breakdown: Optional[RewardBreakdown] = None
399
  total_reward: float = 0.0
400
  rounds_used: int = 0
401
  agreement_reached: bool = False
server/app.py CHANGED
@@ -42,10 +42,12 @@ from replicalab.config import (
42
  )
43
  from replicalab.scenarios import available_scenario_families, generate_scenario
44
  from replicalab.models import (
 
45
  EpisodeLog,
46
  EpisodeState,
47
  LabManagerObservation,
48
  Observation,
 
49
  RewardBreakdown,
50
  ScientistAction,
51
  ScientistObservation,
@@ -117,7 +119,7 @@ class _StubEnv:
117
 
118
  def __init__(self) -> None:
119
  self._state = EpisodeState()
120
- self._logs: list[dict] = []
121
  self._episode_id: str = ""
122
 
123
  # ── public interface (matches ReplicaLabEnv) ──────────────────────────
@@ -202,46 +204,46 @@ class _StubEnv:
202
 
203
  # ── internal helpers ──────────────────────────────────────────────────
204
 
205
- def _scientist_log_entry(self, action: ScientistAction) -> dict[str, Any]:
206
  action_type = (
207
  action.action_type.value
208
  if hasattr(action.action_type, "value")
209
  else str(action.action_type)
210
  )
211
  message = action.rationale or f"Scientist chose action '{action_type}'."
212
- return {
213
- "role": "scientist",
214
- "message": message,
215
- "round_number": self._state.round_number,
216
- "action_type": action_type,
217
- }
218
-
219
- def _lab_manager_log_entry(self, action: ScientistAction) -> dict[str, Any]:
220
  if action.action_type == "accept":
221
  message = "Stub review: agreement recorded and episode will close."
222
  action_type = "accept"
223
  else:
224
  message = "Stub review: proposal received and remains feasible under the stub lab."
225
  action_type = "report_feasibility"
226
- return {
227
- "role": "lab_manager",
228
- "message": message,
229
- "round_number": self._state.round_number,
230
- "action_type": action_type,
231
- }
232
-
233
- def _protocol_from_action(self, action: ScientistAction) -> dict[str, Any] | None:
234
  if action.action_type not in {"propose_protocol", "revise_protocol"}:
235
  return self._state.current_protocol
236
- return {
237
- "technique": action.technique,
238
- "sample_size": action.sample_size,
239
- "controls": list(action.controls),
240
- "duration_days": action.duration_days,
241
- "required_equipment": list(action.required_equipment),
242
- "required_reagents": list(action.required_reagents),
243
- "rationale": action.rationale,
244
- }
245
 
246
  def _make_observation(self) -> Observation:
247
  s = self._state
@@ -572,7 +574,7 @@ async def websocket_endpoint(ws: WebSocket):
572
  else None,
573
  "reward": result.reward,
574
  "done": result.done,
575
- "info": result.info,
576
  },
577
  )
578
  except Exception as exc:
 
42
  )
43
  from replicalab.scenarios import available_scenario_families, generate_scenario
44
  from replicalab.models import (
45
+ ConversationEntry,
46
  EpisodeLog,
47
  EpisodeState,
48
  LabManagerObservation,
49
  Observation,
50
+ Protocol,
51
  RewardBreakdown,
52
  ScientistAction,
53
  ScientistObservation,
 
119
 
120
  def __init__(self) -> None:
121
  self._state = EpisodeState()
122
+ self._logs: list[ConversationEntry] = []
123
  self._episode_id: str = ""
124
 
125
  # ── public interface (matches ReplicaLabEnv) ──────────────────────────
 
204
 
205
  # ── internal helpers ──────────────────────────────────────────────────
206
 
207
+ def _scientist_log_entry(self, action: ScientistAction) -> ConversationEntry:
208
  action_type = (
209
  action.action_type.value
210
  if hasattr(action.action_type, "value")
211
  else str(action.action_type)
212
  )
213
  message = action.rationale or f"Scientist chose action '{action_type}'."
214
+ return ConversationEntry(
215
+ role="scientist",
216
+ message=message,
217
+ round_number=self._state.round_number,
218
+ action_type=action_type,
219
+ )
220
+
221
+ def _lab_manager_log_entry(self, action: ScientistAction) -> ConversationEntry:
222
  if action.action_type == "accept":
223
  message = "Stub review: agreement recorded and episode will close."
224
  action_type = "accept"
225
  else:
226
  message = "Stub review: proposal received and remains feasible under the stub lab."
227
  action_type = "report_feasibility"
228
+ return ConversationEntry(
229
+ role="lab_manager",
230
+ message=message,
231
+ round_number=self._state.round_number,
232
+ action_type=action_type,
233
+ )
234
+
235
+ def _protocol_from_action(self, action: ScientistAction) -> Optional[Protocol]:
236
  if action.action_type not in {"propose_protocol", "revise_protocol"}:
237
  return self._state.current_protocol
238
+ return Protocol(
239
+ technique=action.technique,
240
+ sample_size=action.sample_size,
241
+ controls=list(action.controls),
242
+ duration_days=action.duration_days,
243
+ required_equipment=list(action.required_equipment),
244
+ required_reagents=list(action.required_reagents),
245
+ rationale=action.rationale,
246
+ )
247
 
248
  def _make_observation(self) -> Observation:
249
  s = self._state
 
574
  else None,
575
  "reward": result.reward,
576
  "done": result.done,
577
+ "info": result.info.model_dump(),
578
  },
579
  )
580
  except Exception as exc:
tests/test_models.py CHANGED
@@ -5,14 +5,19 @@ from pydantic import ValidationError
5
 
6
  from replicalab.models import (
7
  ConversationEntry,
 
 
8
  LabManagerAction,
9
  LabManagerActionType,
10
  LabManagerObservation,
11
  Observation,
12
  Protocol,
 
13
  ScientistAction,
14
  ScientistActionType,
15
  ScientistObservation,
 
 
16
  )
17
 
18
 
@@ -230,3 +235,178 @@ def test_observation_rejects_negative_budget() -> None:
230
 
231
  with pytest.raises(ValidationError):
232
  Observation.model_validate(payload)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  from replicalab.models import (
7
  ConversationEntry,
8
+ EpisodeLog,
9
+ EpisodeState,
10
  LabManagerAction,
11
  LabManagerActionType,
12
  LabManagerObservation,
13
  Observation,
14
  Protocol,
15
+ RewardBreakdown,
16
  ScientistAction,
17
  ScientistActionType,
18
  ScientistObservation,
19
+ StepInfo,
20
+ StepResult,
21
  )
22
 
23
 
 
235
 
236
  with pytest.raises(ValidationError):
237
  Observation.model_validate(payload)
238
+
239
+
240
+ # ---------------------------------------------------------------------------
241
+ # MOD 04 — Typed EpisodeState and EpisodeLog
242
+ # ---------------------------------------------------------------------------
243
+
244
+
245
+ def _sample_protocol() -> Protocol:
246
+ return Protocol(
247
+ sample_size=32,
248
+ controls=["vehicle_control", "positive_control"],
249
+ technique="manual_cell_counting",
250
+ duration_days=5,
251
+ required_equipment=["microscope", "co2_incubator"],
252
+ required_reagents=["dmso", "drug_x", "culture_media"],
253
+ rationale="Uses available equipment while preserving controls.",
254
+ )
255
+
256
+
257
+ def _sample_conversation_entry() -> ConversationEntry:
258
+ return ConversationEntry(
259
+ role="scientist",
260
+ message="I propose a manual counting protocol.",
261
+ round_number=1,
262
+ action_type="propose_protocol",
263
+ )
264
+
265
+
266
+ def test_episode_state_accepts_typed_protocol_and_history() -> None:
267
+ protocol = _sample_protocol()
268
+ entry = _sample_conversation_entry()
269
+ state = EpisodeState(
270
+ seed=42,
271
+ current_protocol=protocol,
272
+ conversation_history=[entry],
273
+ round_number=1,
274
+ max_rounds=6,
275
+ )
276
+
277
+ assert isinstance(state.current_protocol, Protocol)
278
+ assert state.current_protocol.technique == "manual_cell_counting"
279
+ assert isinstance(state.conversation_history[0], ConversationEntry)
280
+ assert state.conversation_history[0].role == "scientist"
281
+
282
+
283
+ def test_episode_state_accepts_none_protocol() -> None:
284
+ state = EpisodeState(current_protocol=None, conversation_history=[])
285
+ assert state.current_protocol is None
286
+ assert state.conversation_history == []
287
+
288
+
289
+ def test_episode_state_json_round_trip() -> None:
290
+ protocol = _sample_protocol()
291
+ entry = _sample_conversation_entry()
292
+ state = EpisodeState(
293
+ seed=7,
294
+ scenario_template="math_reasoning",
295
+ difficulty="hard",
296
+ paper_title="Test Paper",
297
+ current_protocol=protocol,
298
+ conversation_history=[entry],
299
+ round_number=2,
300
+ max_rounds=6,
301
+ )
302
+
303
+ dumped = state.model_dump_json()
304
+ restored = EpisodeState.model_validate_json(dumped)
305
+
306
+ assert isinstance(restored.current_protocol, Protocol)
307
+ assert restored.current_protocol.sample_size == 32
308
+ assert isinstance(restored.conversation_history[0], ConversationEntry)
309
+ assert restored.conversation_history[0].action_type == "propose_protocol"
310
+ assert restored.seed == 7
311
+
312
+
313
+ def test_episode_log_accepts_typed_fields() -> None:
314
+ entry = _sample_conversation_entry()
315
+ breakdown = RewardBreakdown(rigor=0.8, feasibility=0.7, fidelity=0.9)
316
+ log = EpisodeLog(
317
+ episode_id="ep-001",
318
+ seed=42,
319
+ transcript=[entry],
320
+ reward_breakdown=breakdown,
321
+ total_reward=5.0,
322
+ rounds_used=3,
323
+ agreement_reached=True,
324
+ )
325
+
326
+ assert isinstance(log.transcript[0], ConversationEntry)
327
+ assert isinstance(log.reward_breakdown, RewardBreakdown)
328
+ assert log.reward_breakdown.rigor == 0.8
329
+
330
+
331
+ def test_episode_log_none_reward_breakdown() -> None:
332
+ log = EpisodeLog(episode_id="ep-002")
333
+ assert log.reward_breakdown is None
334
+ assert log.transcript == []
335
+
336
+
337
+ def test_episode_log_json_round_trip() -> None:
338
+ entry = _sample_conversation_entry()
339
+ breakdown = RewardBreakdown(
340
+ rigor=0.6, feasibility=0.5, fidelity=0.7,
341
+ efficiency_bonus=0.1, communication_bonus=0.05,
342
+ penalties={"timeout": 0.02},
343
+ )
344
+ state = EpisodeState(
345
+ seed=99,
346
+ current_protocol=_sample_protocol(),
347
+ conversation_history=[entry],
348
+ round_number=3,
349
+ max_rounds=6,
350
+ done=True,
351
+ agreement_reached=True,
352
+ reward=5.0,
353
+ rigor_score=0.6,
354
+ )
355
+ log = EpisodeLog(
356
+ episode_id="ep-round-trip",
357
+ seed=99,
358
+ final_state=state,
359
+ transcript=[entry],
360
+ reward_breakdown=breakdown,
361
+ total_reward=5.0,
362
+ rounds_used=3,
363
+ agreement_reached=True,
364
+ judge_notes="Good protocol.",
365
+ verdict="accept",
366
+ )
367
+
368
+ dumped = log.model_dump_json()
369
+ restored = EpisodeLog.model_validate_json(dumped)
370
+
371
+ assert isinstance(restored.final_state, EpisodeState)
372
+ assert isinstance(restored.final_state.current_protocol, Protocol)
373
+ assert isinstance(restored.final_state.conversation_history[0], ConversationEntry)
374
+ assert isinstance(restored.transcript[0], ConversationEntry)
375
+ assert isinstance(restored.reward_breakdown, RewardBreakdown)
376
+ assert restored.reward_breakdown.penalties == {"timeout": 0.02}
377
+ assert restored.episode_id == "ep-round-trip"
378
+
379
+
380
+ def test_episode_log_nested_state_preserves_typed_fields() -> None:
381
+ protocol = _sample_protocol()
382
+ entry = _sample_conversation_entry()
383
+ state = EpisodeState(
384
+ current_protocol=protocol,
385
+ conversation_history=[entry],
386
+ )
387
+ log = EpisodeLog(final_state=state)
388
+
389
+ assert isinstance(log.final_state.current_protocol, Protocol)
390
+ assert log.final_state.current_protocol.technique == "manual_cell_counting"
391
+ assert isinstance(log.final_state.conversation_history[0], ConversationEntry)
392
+
393
+
394
+ def test_step_result_with_typed_info() -> None:
395
+ breakdown = RewardBreakdown(rigor=0.8, feasibility=0.8, fidelity=0.8)
396
+ info = StepInfo(
397
+ agreement_reached=True,
398
+ reward_breakdown=breakdown,
399
+ judge_notes="All checks passed.",
400
+ verdict="accept",
401
+ round=3,
402
+ stub=True,
403
+ )
404
+ result = StepResult(reward=5.0, done=True, info=info)
405
+
406
+ dumped = result.model_dump_json()
407
+ restored = StepResult.model_validate_json(dumped)
408
+
409
+ assert isinstance(restored.info, StepInfo)
410
+ assert isinstance(restored.info.reward_breakdown, RewardBreakdown)
411
+ assert restored.info.agreement_reached is True
412
+ assert restored.info.verdict == "accept"