File size: 23,557 Bytes
d19137b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
"""Core RecallTrace environment with deterministic action execution."""

from __future__ import annotations

from copy import deepcopy
from typing import Any, Dict, Tuple

from env.models import EnvironmentState, InspectionEvidence, RecallAction, RecallObservation, RewardSignal, StepInfo, TaskDefinition
from scenario.scenario import build_scenario, list_task_specs


class RecallTraceEnv:
    """Deterministic OpenEnv-style environment for product recall containment."""

    ACTIONS = [
        "inspect_node",
        "trace_lot",
        "quarantine",
        "notify",
        "finalize",
    ]

    def __init__(
        self,
        scenario_data: Dict[str, Any] | None = None,
        task_id: str | None = None,
        phase: int | None = 1,
    ):
        """Prepare the environment around a scenario template.

        When ``scenario_data`` is given it is deep-copied and used as the
        template; otherwise the template is obtained from ``build_scenario``
        using ``task_id`` and ``phase``. Episode state and ground truth start
        out as empty dicts and are populated by ``reset``.
        """
        if scenario_data is None:
            template = build_scenario(task_id=task_id, phase=phase)
        else:
            # Copy so later changes to the caller's mapping cannot leak in.
            template = deepcopy(scenario_data)

        self._scenario_template = template
        self.task = self._build_task_definition(template)
        self.state_data: Dict[str, Any] = {}
        self.ground_truth: Dict[str, Any] = {}
        self.done = False
        self.last_reward = RewardSignal(value=0.0, reason="Environment initialized.", components={})

    @classmethod
    def available_tasks(cls) -> list[TaskDefinition]:
        """Return one ``TaskDefinition`` per spec produced by ``list_task_specs``."""
        definitions = []
        for spec in list_task_specs():
            definitions.append(TaskDefinition(**spec))
        return definitions

    def reset(self, task_id: str | None = None, phase: int | None = None) -> RecallObservation:
        """Begin a fresh episode and return its first observation.

        If either ``task_id`` or ``phase`` is supplied, the scenario template
        (and the derived task definition) is rebuilt via ``build_scenario``
        first. The episode state is then initialised from a deep copy of the
        template and the ground truth is recomputed from that same copy.
        """
        rebuild_requested = not (task_id is None and phase is None)
        if rebuild_requested:
            self._scenario_template = build_scenario(task_id=task_id, phase=phase)
            self.task = self._build_task_definition(self._scenario_template)

        self.done = False
        self.last_reward = RewardSignal(value=0.0, reason="Episode reset.", components={})

        # Work on a private copy so in-episode mutation never touches the template.
        scenario = deepcopy(self._scenario_template)

        fresh_state: Dict[str, Any] = {key: scenario[key] for key in ("task_id", "phase", "recall_notice")}
        fresh_state["contaminated_lot_hint"] = scenario["contaminated_lot"]
        for key in ("shipment_graph", "lot_catalog", "nodes"):
            fresh_state[key] = scenario[key]

        # Per-episode bookkeeping always starts empty.
        fresh_state["history"] = []
        fresh_state["discovered_shipments"] = {}
        fresh_state["inspected_nodes"] = set()
        fresh_state["inspection_results"] = {}
        fresh_state["traced_lots"] = {}
        fresh_state["notified_nodes"] = set()
        fresh_state["quarantine_log"] = []
        fresh_state["steps_taken"] = 0
        fresh_state["max_steps"] = scenario["max_steps"]

        self.state_data = fresh_state
        self.ground_truth = self._build_ground_truth(scenario)
        return self._get_observation()

    def step(self, action: RecallAction | Dict[str, Any]) -> Tuple[RecallObservation, float, bool, Dict[str, Any]]:
        """Apply one action and report ``(observation, reward, done, info)``.

        A finished episode yields a zero-reward "noop" result. Otherwise the
        action is validated (plain dicts go through
        ``RecallAction.model_validate``), the step counter is incremented, and
        the call is dispatched to the ``_handle_<action type>`` method. If the
        step budget is used up and the episode is still open, the episode ends
        and a -0.25 timeout penalty is added to the reward (floored at -1.0).
        """
        if self.done:
            noop_info = {
                "message": "Environment already finalized.",
                "action_type": "noop",
                "reward_breakdown": {},
            }
            return self._get_observation(), 0.0, True, noop_info

        if isinstance(action, RecallAction):
            validated_action = action
        else:
            validated_action = RecallAction.model_validate(action)

        # The step is counted before the handler runs.
        self.state_data["steps_taken"] += 1

        handler_name = f"_handle_{validated_action.type.value}"
        reward_signal, info = getattr(self, handler_name)(validated_action)
        self.last_reward = reward_signal

        budget_spent = self.state_data["steps_taken"] >= self.state_data["max_steps"]
        if budget_spent and not self.done:
            self.done = True
            timeout_penalty = -0.25
            penalized_components = {**reward_signal.components, "timeout_penalty": timeout_penalty}
            reward_signal = RewardSignal(
                value=max(-1.0, reward_signal.value + timeout_penalty),
                reason="Step budget exhausted before finalizing containment.",
                components=penalized_components,
            )
            info = dict(info)
            info["message"] = "Step budget exhausted before finalizing containment."
            info["reward_breakdown"] = reward_signal.components
            self._record_history("Episode terminated after exhausting the step budget")
            self.last_reward = reward_signal

        observation = self._get_observation()
        return observation, reward_signal.value, self.done, info

    def state(self) -> EnvironmentState:
        """Expose the complete internal state, including ground truth.

        The episode state is first passed through ``_serialize_state`` and then
        deep-copied, as is the ground truth, so the returned object does not
        share mutable containers with the environment.
        """
        serialized_state = self._serialize_state(self.state_data)
        steps_so_far = self.state_data.get("steps_taken", 0)
        return EnvironmentState(
            done=self.done,
            task=self.task,
            steps_taken=steps_so_far,
            state_data=deepcopy(serialized_state),
            ground_truth=deepcopy(self.ground_truth),
        )

    def _get_observation(self) -> RecallObservation:
        """Assemble the agent-facing observation from the current episode state.

        Shipments, inspection results and trace results are deep-copied, the
        inspected/notified node sets become sorted lists, and the remaining
        step budget is clamped at zero.
        """
        data = self.state_data
        fields: Dict[str, Any] = {
            "task_id": data["task_id"],
            "phase": data["phase"],
            "recall_notice": data["recall_notice"],
            "available_actions": [*self.ACTIONS],
            "inventory": self._inventory_snapshot(),
            "discovered_shipments": deepcopy(data["discovered_shipments"]),
            "inspected_nodes": sorted(data["inspected_nodes"]),
            "inspection_results": deepcopy(data["inspection_results"]),
            "trace_results": deepcopy(data["traced_lots"]),
            "notified_nodes": sorted(data["notified_nodes"]),
            "quarantined_inventory": self._quarantine_snapshot(),
            "history": [*data["history"]],
            "steps_taken": data["steps_taken"],
        }
        fields["remaining_step_budget"] = max(0, data["max_steps"] - data["steps_taken"])
        return RecallObservation(**fields)

    def _handle_inspect_node(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]:
        """Inspect one node, recording its outbound shipments and lot findings.

        A first inspection earns 0.08 plus up to 0.12 scaled by the total
        unsafe quantity found (``unsafe_total / 500``); re-inspecting a node
        costs 0.03.

        Raises:
            ValueError: If ``action.node_id`` is missing or unknown.
        """
        node_id = self._require_node(action.node_id)
        state = self.state_data
        node = state["nodes"][node_id]
        already_inspected = node_id in state["inspected_nodes"]

        state["inspected_nodes"].add(node_id)
        state["discovered_shipments"][node_id] = list(state["shipment_graph"].get(node_id, []))

        findings = {}
        for lot_id, payload in node.get("inspection_findings", {}).items():
            findings[lot_id] = InspectionEvidence.model_validate(payload)
        state["inspection_results"][node_id] = findings
        self._record_history(f"Inspected node {node_id}")

        unsafe_total = sum(evidence.unsafe_quantity for evidence in findings.values())
        if already_inspected:
            value = -0.03
            reason = "Repeated inspection provided no new information."
        else:
            value = 0.08 + min(0.12, unsafe_total / 500.0)
            reason = "Inspection revealed inventory evidence."

        reward = RewardSignal(
            value=round(value, 4),
            reason=reason,
            components={"inspection_value": round(value, 4)},
        )

        step_info = StepInfo(
            message=f"Inspected node {node_id} and collected node evidence.",
            action_type=action.type.value,
            reward_breakdown=reward.components,
        )
        info = step_info.model_dump()
        # Node details are copied so callers cannot mutate live state through info.
        node_details = {
            "node_id": node_id,
            "inventory": deepcopy(node["inventory"]),
            "quarantined_inventory": deepcopy(node["quarantined_inventory"]),
            "outbound_shipments": list(state["shipment_graph"].get(node_id, [])),
            "inspection_findings": {lot: evidence.model_dump() for lot, evidence in findings.items()},
        }
        info.update(node_details)
        return reward, info

    def _handle_trace_lot(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]:
        """Trace a lot and its related lots across every node's stock.

        Related lots come from ``_resolve_related_lots``. For each node the
        available and quarantined quantities of those lots are added up; nodes
        with a positive total are reported as affected. The trace record is
        stored under ``state_data["traced_lots"][lot_id]`` and an independent
        copy of it is returned in ``info``.

        Reward: -0.1 when no node holds any matched lot; 0.12 plus a bonus of
        at most 0.13 when the lot's root is in the ground-truth affected roots;
        0.02 otherwise.

        Raises:
            ValueError: If the action carries no ``lot_id``.
        """
        lot_id = action.lot_id
        if not lot_id:
            raise ValueError("trace_lot action requires 'lot_id'.")

        root_lot = self._root_lot_for(lot_id)
        # Sort once up front: iterating the raw set would make the per-node lot
        # lists depend on string hash randomisation between interpreter runs.
        matched_lots = sorted(self._resolve_related_lots(lot_id))

        impacted_nodes = []
        impacted_quantities = {}
        impacted_lots = {}
        discovered_nodes = 0

        for node_id, node_data in self.state_data["nodes"].items():
            inventory = node_data["inventory"]
            quarantined = node_data["quarantined_inventory"]
            node_total = 0
            node_lots = []
            for candidate_lot in matched_lots:
                # Quarantined stock still counts as being present at the node.
                total_qty = inventory.get(candidate_lot, 0) + quarantined.get(candidate_lot, 0)
                if total_qty > 0:
                    node_total += total_qty
                    node_lots.append(candidate_lot)
            if node_total > 0:
                impacted_nodes.append(node_id)
                impacted_quantities[node_id] = node_total
                impacted_lots[node_id] = node_lots
                if node_id not in self.state_data["discovered_shipments"]:
                    discovered_nodes += 1

        trace_record = {
            "root_lot": root_lot,
            "matched_lots": matched_lots,
            "affected_nodes": impacted_nodes,
            "lots_by_node": impacted_lots,
            "quantities_by_node": impacted_quantities,
        }
        self.state_data["traced_lots"][lot_id] = trace_record
        self._record_history(f"Traced lot {lot_id} across {', '.join(matched_lots)}")

        if not impacted_nodes:
            reward_value = -0.1
            reason = "Trace returned no impacted nodes."
        elif root_lot in self.ground_truth["affected_roots"]:
            reward_value = 0.12 + min(0.13, discovered_nodes * 0.03 + len(matched_lots) * 0.02)
            reason = "Trace identified the affected lineage across the network."
        else:
            reward_value = 0.02
            reason = "Trace ran, but the lot is outside the affected lineage."

        reward = RewardSignal(
            value=round(reward_value, 4),
            reason=reason,
            components={
                "trace_value": round(reward_value, 4),
            },
        )
        info = StepInfo(
            message=f"Traced lot {lot_id} across the shipment network.",
            action_type=action.type.value,
            reward_breakdown=reward.components,
        ).model_dump()
        info.update(
            {
                "lot_id": lot_id,
                # Deep copy so a caller mutating info cannot alter the stored trace.
                **deepcopy(trace_record),
                "total_quantity": sum(impacted_quantities.values()),
            }
        )
        return reward, info

    def _handle_quarantine(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]:
        node_id = self._require_node(action.node_id)
        lot_id = action.lot_id
        if not lot_id:
            raise ValueError("quarantine action requires 'lot_id'.")

        node = self.state_data["nodes"][node_id]
        available_qty = node["inventory"].get(lot_id, 0)
        if available_qty <= 0:
            reward = RewardSignal(
                value=-0.2,
                reason="Attempted to quarantine stock that is not available.",
                components={"invalid_quarantine": -0.2},
            )
            self._record_history(f"Failed quarantine for {lot_id} at {node_id}: no available stock")
            info = StepInfo(
                message="No available stock to quarantine.",
                action_type=action.type.value,
                reward_breakdown=reward.components,
            ).model_dump()
            info.update({"node_id": node_id, "lot_id": lot_id})
            return reward, info

        requested_qty = action.quantity or available_qty
        quarantined_qty = min(requested_qty, available_qty)
        node["inventory"][lot_id] = available_qty - quarantined_qty
        if node["inventory"][lot_id] == 0:
            del node["inventory"][lot_id]
        node["quarantined_inventory"][lot_id] = node["quarantined_inventory"].get(lot_id, 0) + quarantined_qty

        self.state_data["quarantine_log"].append({"node_id": node_id, "lot_id": lot_id, "quantity": quarantined_qty})
        self._record_history(f"Quarantined {quarantined_qty} units of {lot_id} at {node_id}")

        correct_qty = self.ground_truth["correct_quantities"].get(node_id, {}).get(lot_id, 0)
        cumulative_quarantined = node["quarantined_inventory"].get(lot_id, 0)
        delta = cumulative_quarantined - correct_qty

        if correct_qty == 0:
            reward_value = -0.35
            reason = "Quarantined safe inventory outside the recall scope."
        elif delta == 0:
            reward_value = 0.28
            reason = "Quarantine exactly matched the unsafe quantity."
        elif delta < 0:
            reward_value = max(0.05, 0.22 * (cumulative_quarantined / correct_qty))
            reason = "Quarantine made partial progress but missed some unsafe stock."
        else:
            reward_value = max(-0.25, -0.08 * delta)
            reason = "Quarantine overreached and blocked safe inventory."

        reward = RewardSignal(
            value=round(reward_value, 4),
            reason=reason,
            components={
                "quarantine_value": round(reward_value, 4),
                "target_quantity": float(correct_qty),
                "quarantined_quantity": float(cumulative_quarantined),
            },
        )
        info = StepInfo(
            message=f"Updated quarantine for {lot_id} at {node_id}.",
            action_type=action.type.value,
            reward_breakdown=reward.components,
        ).model_dump()
        info.update(
            {
                "node_id": node_id,
                "lot_id": lot_id,
                "quarantined_quantity": quarantined_qty,
                "remaining_inventory": node["inventory"].get(lot_id, 0),
                "cumulative_quarantined": cumulative_quarantined,
                "target_contaminated_quantity": correct_qty,
            }
        )
        return reward, info

    def _handle_notify(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]:
        """Notify a single node, or every node, and score the new recipients.

        A missing ``node_id`` or the values ``"all"`` / ``"all_nodes"`` target
        every node. Each newly notified affected node is worth 0.04 (capped at
        0.18 in total) and each newly notified unaffected node costs 0.01. An
        action that reaches nobody new costs 0.05.

        Raises:
            ValueError: If a specific ``node_id`` is given but unknown.
        """
        requested_target = action.node_id or "all"
        if requested_target == "all" or requested_target == "all_nodes":
            targets = list(self.state_data["nodes"])
        else:
            targets = [self._require_node(requested_target)]

        notified = self.state_data["notified_nodes"]
        newly_notified = []
        for candidate in targets:
            if candidate in notified:
                continue
            notified.add(candidate)
            newly_notified.append(candidate)

        affected_nodes = self.ground_truth["affected_nodes"]
        affected_newly_notified = len([node for node in newly_notified if node in affected_nodes])
        unaffected_newly_notified = len(newly_notified) - affected_newly_notified

        if newly_notified:
            reward_value = min(0.18, affected_newly_notified * 0.04) - unaffected_newly_notified * 0.01
            reason = "Notifications dispatched to downstream stakeholders."
            history_entry = f"Sent notifications to {', '.join(newly_notified)}"
        else:
            reward_value = -0.05
            reason = "Notification repeated without adding new recipients."
            history_entry = "Notification action repeated without new recipients"

        rounded_value = round(reward_value, 4)
        reward = RewardSignal(
            value=rounded_value,
            reason=reason,
            components={"notification_value": rounded_value},
        )
        self._record_history(history_entry)

        step_info = StepInfo(
            message="Processed notification action.",
            action_type=action.type.value,
            reward_breakdown=reward.components,
        )
        info = step_info.model_dump()
        info["notified_nodes"] = targets
        info["newly_notified"] = newly_notified
        return reward, info

    def _handle_finalize(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]:
        """End the episode and compute the final weighted score.

        The score combines four components, each in ``[0, 1]``:
        quarantine accuracy (weight 0.55, over-quarantine weighted 1.25x
        against missed stock), share of affected nodes notified (0.2), share
        of affected nodes inspected (0.15) and step efficiency (0.1).
        """
        del action  # the finalize action carries no parameters that are used here
        self.done = True
        quarantine_match = self._compute_quarantine_match()

        def _grand_total(quantities_by_node: Dict[str, Dict[str, int]]) -> int:
            # Sum every per-lot quantity across all nodes.
            total = 0
            for lot_quantities in quantities_by_node.values():
                for quantity in lot_quantities.values():
                    total += quantity
            return total

        missing_quantity_total = _grand_total(quarantine_match["missing_quantities"])
        over_quantity_total = _grand_total(quarantine_match["over_quarantined_quantities"])

        # ``or 1`` keeps the divisions defined when nothing is affected.
        total_affected_quantity = self.ground_truth["total_affected_quantity"] or 1
        weighted_error = missing_quantity_total + (1.25 * over_quantity_total)
        quarantine_score = max(0.0, 1.0 - (weighted_error / total_affected_quantity))

        affected_nodes = self.ground_truth["affected_nodes"]
        affected_node_total = len(affected_nodes) or 1

        notified_affected_nodes = set(affected_nodes).intersection(self.state_data["notified_nodes"])
        notification_score = len(notified_affected_nodes) / affected_node_total

        investigated_nodes = set(self.state_data["inspected_nodes"]).intersection(affected_nodes)
        investigation_score = len(investigated_nodes) / affected_node_total

        # Steps beyond this allowance start to reduce the efficiency component.
        step_allowance = max(4, affected_node_total + 3)
        efficiency_penalty_steps = max(0, self.state_data["steps_taken"] - step_allowance)
        efficiency_score = max(0.0, 1.0 - (efficiency_penalty_steps / self.state_data["max_steps"]))

        score = round(
            (0.55 * quarantine_score) + (0.2 * notification_score) + (0.15 * investigation_score) + (0.1 * efficiency_score),
            4,
        )

        rounded_components = {
            "quarantine_score": round(quarantine_score, 4),
            "notification_score": round(notification_score, 4),
            "investigation_score": round(investigation_score, 4),
            "efficiency_score": round(efficiency_score, 4),
        }
        reward = RewardSignal(
            value=score,
            reason="Final recall response scored.",
            components=rounded_components,
        )
        self._record_history("Finalized recall response")

        step_info = StepInfo(
            message="Finalized recall response.",
            action_type="finalize",
            score=score,
            reward_breakdown=reward.components,
        )
        info = step_info.model_dump()
        info.update(
            {
                "score": score,
                **rounded_components,
                "all_affected_nodes_notified": notification_score == 1.0,
                "all_affected_stock_quarantined": missing_quantity_total == 0 and over_quantity_total == 0,
                "quarantine_match": quarantine_match,
            }
        )
        return reward, info

    def _build_ground_truth(self, scenario: Dict[str, Any]) -> Dict[str, Any]:
        contaminated_roots = {
            self._root_lot_for(lot_id, scenario["lot_catalog"])
            for lot_id, lot_data in scenario["lot_catalog"].items()
            if lot_data.get("contaminated", False)
        }

        correct_quantities: Dict[str, Dict[str, int]] = {}
        affected_nodes = set()
        affected_lots = set()

        for node_id, node_data in scenario["nodes"].items():
            for lot_id, finding in node_data.get("inspection_findings", {}).items():
                unsafe_quantity = int(finding.get("unsafe_quantity", 0))
                if unsafe_quantity > 0:
                    affected_nodes.add(node_id)
                    affected_lots.add(lot_id)
                    correct_quantities.setdefault(node_id, {})[lot_id] = unsafe_quantity

        total_affected_quantity = sum(
            quantity
            for node_quantities in correct_quantities.values()
            for quantity in node_quantities.values()
        )
        return {
            "affected_lots": sorted(affected_lots),
            "affected_nodes": sorted(affected_nodes),
            "affected_roots": sorted(contaminated_roots),
            "correct_quantities": correct_quantities,
            "total_affected_quantity": total_affected_quantity,
        }

    def _compute_quarantine_match(self) -> Dict[str, Any]:
        missing_quantities: Dict[str, Dict[str, int]] = {}
        over_quarantined_quantities: Dict[str, Dict[str, int]] = {}

        for node_id, node_data in self.state_data["nodes"].items():
            expected = self.ground_truth["correct_quantities"].get(node_id, {})
            actual = node_data["quarantined_inventory"]
            relevant_lots = set(expected) | set(actual)

            for lot_id in relevant_lots:
                expected_qty = expected.get(lot_id, 0)
                actual_qty = actual.get(lot_id, 0)
                if actual_qty < expected_qty:
                    missing_quantities.setdefault(node_id, {})[lot_id] = expected_qty - actual_qty
                elif actual_qty > expected_qty:
                    over_quarantined_quantities.setdefault(node_id, {})[lot_id] = actual_qty - expected_qty

        return {
            "missing_quantities": missing_quantities,
            "over_quarantined_quantities": over_quarantined_quantities,
        }

    def _inventory_snapshot(self) -> Dict[str, Dict[str, int]]:
        return {node_id: deepcopy(node_data["inventory"]) for node_id, node_data in self.state_data["nodes"].items()}

    def _quarantine_snapshot(self) -> Dict[str, Dict[str, int]]:
        return {
            node_id: deepcopy(node_data["quarantined_inventory"])
            for node_id, node_data in self.state_data["nodes"].items()
            if node_data["quarantined_inventory"]
        }

    def _resolve_related_lots(self, lot_id: str) -> set[str]:
        root_lot = self._root_lot_for(lot_id)
        return {
            candidate_lot
            for candidate_lot in self.state_data["lot_catalog"].keys()
            if self._root_lot_for(candidate_lot) == root_lot or candidate_lot == lot_id
        }

    def _root_lot_for(self, lot_id: str, lot_catalog: Dict[str, Dict[str, Any]] | None = None) -> str:
        catalog = lot_catalog or self.state_data.get("lot_catalog", {})
        if lot_id not in catalog:
            return lot_id
        return catalog[lot_id].get("root_lot", lot_id)

    def _build_task_definition(self, scenario: Dict[str, Any]) -> TaskDefinition:
        """Build a ``TaskDefinition`` from the scenario's descriptive fields."""
        field_names = ("task_id", "name", "difficulty", "objective", "max_steps")
        task_fields = {field: scenario[field] for field in field_names}
        return TaskDefinition(**task_fields)

    def _require_node(self, node_id: str | None) -> str:
        if not node_id:
            raise ValueError("Action requires 'node_id'.")
        if node_id not in self.state_data["nodes"]:
            raise ValueError(f"Unknown node_id '{node_id}'.")
        return node_id

    def _record_history(self, message: str) -> None:
        self.state_data["history"].append(message)

    def _serialize_state(self, value: Any) -> Any:
        if isinstance(value, dict):
            return {key: self._serialize_state(item) for key, item in value.items()}
        if isinstance(value, set):
            return sorted(value)
        if isinstance(value, list):
            return [self._serialize_state(item) for item in value]
        if hasattr(value, "model_dump"):
            return value.model_dump()
        return value