| """Core RecallTrace environment with deterministic action execution."""
|
|
|
| from __future__ import annotations
|
|
|
| from copy import deepcopy
|
| from typing import Any, Dict, Tuple
|
|
|
| from env.models import EnvironmentState, InspectionEvidence, RecallAction, RecallObservation, RewardSignal, StepInfo, TaskDefinition
|
| from scenario.scenario import build_scenario, list_task_specs
|
|
|
|
|
class RecallTraceEnv:
    """Deterministic OpenEnv-style environment for product recall containment.

    Lifecycle: construct with a scenario (explicit ``scenario_data`` or one built
    by ``build_scenario``), call :meth:`reset` to start an episode, then call
    :meth:`step` until the episode is done. Each name in ``ACTIONS`` is dispatched
    to the matching ``_handle_<name>`` method, which mutates ``state_data`` and
    returns a shaped :class:`RewardSignal`. The ``finalize`` action grades the
    response against ``ground_truth`` and ends the episode.
    """

    # Action names advertised in observations; each has a ``_handle_<name>`` method.
    ACTIONS = [
        "inspect_node",
        "trace_lot",
        "quarantine",
        "notify",
        "finalize",
    ]

    def __init__(
        self,
        scenario_data: Dict[str, Any] | None = None,
        task_id: str | None = None,
        phase: int | None = 1,
    ):
        """Create the environment from explicit scenario data or a generated scenario.

        Args:
            scenario_data: Optional scenario dict. It is deep-copied, so the caller's
                object is never mutated. When omitted, ``build_scenario(task_id, phase)``
                supplies the template.
            task_id: Task to build when ``scenario_data`` is not given.
            phase: Phase to build when ``scenario_data`` is not given.

        Episode state stays empty until :meth:`reset` is called.
        """
        if scenario_data is not None:
            self._scenario_template = deepcopy(scenario_data)
        else:
            self._scenario_template = build_scenario(task_id=task_id, phase=phase)
        self.task = self._build_task_definition(self._scenario_template)
        self.state_data: Dict[str, Any] = {}
        self.ground_truth: Dict[str, Any] = {}
        self.done = False
        self.last_reward = RewardSignal(value=0.0, reason="Environment initialized.", components={})

    @classmethod
    def available_tasks(cls) -> list[TaskDefinition]:
        """Return a ``TaskDefinition`` for every task spec listed by the scenario module."""
        return [TaskDefinition(**task_spec) for task_spec in list_task_specs()]

    def reset(self, task_id: str | None = None, phase: int | None = None) -> RecallObservation:
        """Start a new deterministic scenario and recompute ground truth.

        When ``task_id`` or ``phase`` is given, a fresh template is built with
        ``build_scenario``. Otherwise the current template is reused. The template
        is deep-copied, so episode mutations never leak into later resets.
        """
        if task_id is not None or phase is not None:
            self._scenario_template = build_scenario(task_id=task_id, phase=phase)
            self.task = self._build_task_definition(self._scenario_template)

        self.done = False
        self.last_reward = RewardSignal(value=0.0, reason="Episode reset.", components={})

        scenario = deepcopy(self._scenario_template)
        self.state_data = {
            "task_id": scenario["task_id"],
            "phase": scenario["phase"],
            "recall_notice": scenario["recall_notice"],
            "contaminated_lot_hint": scenario["contaminated_lot"],
            "shipment_graph": scenario["shipment_graph"],
            "lot_catalog": scenario["lot_catalog"],
            "nodes": scenario["nodes"],
            "history": [],
            "discovered_shipments": {},
            "inspected_nodes": set(),
            "inspection_results": {},
            "traced_lots": {},
            "notified_nodes": set(),
            "quarantine_log": [],
            "steps_taken": 0,
            "max_steps": scenario["max_steps"],
        }
        self.ground_truth = self._build_ground_truth(scenario)
        return self._get_observation()

    def step(self, action: RecallAction | Dict[str, Any]) -> Tuple[RecallObservation, float, bool, Dict[str, Any]]:
        """Execute an action and return observation, reward, done, info.

        ``action`` may be a ``RecallAction`` or a dict, which is validated with
        ``RecallAction.model_validate``. Once the step budget is used up without a
        ``finalize``, the episode ends and a ``-0.25`` timeout penalty is applied
        to that step's reward (floored at -1.0).

        Raises:
            RuntimeError: If called before :meth:`reset`.
            ValueError: If a handler rejects the action (missing/unknown fields).
        """
        if not self.state_data:
            # Without this guard the first state lookup fails with an opaque KeyError.
            raise RuntimeError("Call reset() before step().")

        if self.done:
            return self._get_observation(), 0.0, True, {
                "message": "Environment already finalized.",
                "action_type": "noop",
                "reward_breakdown": {},
            }

        validated_action = action if isinstance(action, RecallAction) else RecallAction.model_validate(action)
        self.state_data["steps_taken"] += 1

        handler = getattr(self, f"_handle_{validated_action.type.value}")
        reward_signal, info = handler(validated_action)
        self.last_reward = reward_signal

        if not self.done and self.state_data["steps_taken"] >= self.state_data["max_steps"]:
            self.done = True
            timeout_penalty = -0.25
            reward_signal = RewardSignal(
                value=max(-1.0, reward_signal.value + timeout_penalty),
                reason="Step budget exhausted before finalizing containment.",
                components={**reward_signal.components, "timeout_penalty": timeout_penalty},
            )
            info = {
                **info,
                "message": "Step budget exhausted before finalizing containment.",
                "reward_breakdown": reward_signal.components,
            }
            self._record_history("Episode terminated after exhausting the step budget")
            self.last_reward = reward_signal

        return self._get_observation(), reward_signal.value, self.done, info

    def state(self) -> EnvironmentState:
        """Return the full internal state for debugging and graders.

        Sets are serialized as sorted lists and pydantic models are dumped. Both
        state and ground truth are deep-copied, so callers cannot mutate the env.
        """
        return EnvironmentState(
            done=self.done,
            task=self.task,
            steps_taken=self.state_data.get("steps_taken", 0),
            state_data=deepcopy(self._serialize_state(self.state_data)),
            ground_truth=deepcopy(self.ground_truth),
        )

    def _get_observation(self) -> RecallObservation:
        """Build the agent-facing observation; mutable containers are copied."""
        return RecallObservation(
            task_id=self.state_data["task_id"],
            phase=self.state_data["phase"],
            recall_notice=self.state_data["recall_notice"],
            available_actions=list(self.ACTIONS),
            inventory=self._inventory_snapshot(),
            discovered_shipments=deepcopy(self.state_data["discovered_shipments"]),
            inspected_nodes=sorted(self.state_data["inspected_nodes"]),
            inspection_results=deepcopy(self.state_data["inspection_results"]),
            trace_results=deepcopy(self.state_data["traced_lots"]),
            notified_nodes=sorted(self.state_data["notified_nodes"]),
            quarantined_inventory=self._quarantine_snapshot(),
            history=list(self.state_data["history"]),
            steps_taken=self.state_data["steps_taken"],
            remaining_step_budget=max(0, self.state_data["max_steps"] - self.state_data["steps_taken"]),
        )

    def _handle_inspect_node(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]:
        """Reveal a node's outbound shipments and inspection findings.

        A first inspection earns ``0.08`` plus up to ``0.12`` scaled by the node's
        total unsafe quantity. Re-inspecting an already inspected node costs ``-0.03``.

        Raises:
            ValueError: If ``node_id`` is missing or unknown.
        """
        node_id = self._require_node(action.node_id)
        node = self.state_data["nodes"][node_id]
        repeated = node_id in self.state_data["inspected_nodes"]

        self.state_data["inspected_nodes"].add(node_id)
        self.state_data["discovered_shipments"][node_id] = list(self.state_data["shipment_graph"].get(node_id, []))
        findings = {
            lot_id: InspectionEvidence.model_validate(payload)
            for lot_id, payload in node.get("inspection_findings", {}).items()
        }
        self.state_data["inspection_results"][node_id] = findings
        self._record_history(f"Inspected node {node_id}")

        unsafe_total = sum(item.unsafe_quantity for item in findings.values())
        value = -0.03 if repeated else 0.08 + min(0.12, unsafe_total / 500.0)
        reason = "Repeated inspection provided no new information." if repeated else "Inspection revealed inventory evidence."
        reward = RewardSignal(
            value=round(value, 4),
            reason=reason,
            components={
                "inspection_value": round(value, 4),
            },
        )
        info = StepInfo(
            message=f"Inspected node {node_id} and collected node evidence.",
            action_type=action.type.value,
            reward_breakdown=reward.components,
        ).model_dump()
        info.update(
            {
                "node_id": node_id,
                "inventory": deepcopy(node["inventory"]),
                "quarantined_inventory": deepcopy(node["quarantined_inventory"]),
                "outbound_shipments": list(self.state_data["shipment_graph"].get(node_id, [])),
                "inspection_findings": {lot_id: item.model_dump() for lot_id, item in findings.items()},
            }
        )
        return reward, info

    def _handle_trace_lot(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]:
        """Trace a lot's lineage and report every node holding related stock.

        Per-node quantities count both available and quarantined units, so a trace
        shows where the lineage sits regardless of containment progress.

        Raises:
            ValueError: If the action has no ``lot_id``.
        """
        lot_id = action.lot_id
        if not lot_id:
            raise ValueError("trace_lot action requires 'lot_id'.")

        root_lot = self._root_lot_for(lot_id)
        traced_lots = self._resolve_related_lots(lot_id)
        # Catalogued lots that share a root resolve to the same related-lot set, and
        # traced totals include quarantined stock. Re-tracing a lineage therefore reveals
        # nothing new; without this check the positive trace reward could be farmed.
        repeated = any(
            entry["root_lot"] == root_lot for entry in self.state_data["traced_lots"].values()
        )
        impacted_nodes = []
        impacted_quantities = {}
        impacted_lots = {}
        discovered_nodes = 0

        for node_id, node_data in self.state_data["nodes"].items():
            node_total = 0
            node_lots = []
            for candidate_lot in traced_lots:
                available_qty = node_data["inventory"].get(candidate_lot, 0)
                quarantined_qty = node_data["quarantined_inventory"].get(candidate_lot, 0)
                total_qty = available_qty + quarantined_qty
                if total_qty > 0:
                    node_total += total_qty
                    node_lots.append(candidate_lot)
            if node_total > 0:
                impacted_nodes.append(node_id)
                impacted_quantities[node_id] = node_total
                impacted_lots[node_id] = node_lots
                # Nodes not yet inspected count as newly discovered by this trace.
                if node_id not in self.state_data["discovered_shipments"]:
                    discovered_nodes += 1

        self.state_data["traced_lots"][lot_id] = {
            "root_lot": root_lot,
            "matched_lots": sorted(traced_lots),
            "affected_nodes": impacted_nodes,
            "lots_by_node": impacted_lots,
            "quantities_by_node": impacted_quantities,
        }
        self._record_history(f"Traced lot {lot_id} across {', '.join(sorted(traced_lots))}")

        reward_value, reason = self._trace_reward(
            root_lot, impacted_nodes, discovered_nodes, len(traced_lots), repeated
        )
        reward = RewardSignal(
            value=round(reward_value, 4),
            reason=reason,
            components={
                "trace_value": round(reward_value, 4),
            },
        )
        info = StepInfo(
            message=f"Traced lot {lot_id} across the shipment network.",
            action_type=action.type.value,
            reward_breakdown=reward.components,
        ).model_dump()
        info.update(
            {
                "lot_id": lot_id,
                "root_lot": root_lot,
                "matched_lots": sorted(traced_lots),
                "affected_nodes": impacted_nodes,
                "lots_by_node": impacted_lots,
                "quantities_by_node": impacted_quantities,
                "total_quantity": sum(impacted_quantities.values()),
            }
        )
        return reward, info

    def _trace_reward(
        self,
        root_lot: str,
        impacted_nodes: list[str],
        discovered_nodes: int,
        matched_lot_count: int,
        repeated: bool,
    ) -> tuple[float, str]:
        """Return the unrounded shaped ``(value, reason)`` pair for a ``trace_lot`` action."""
        if not impacted_nodes:
            return -0.1, "Trace returned no impacted nodes."
        if repeated:
            # Same penalty as a repeated inspection: the trace produced no new information.
            return -0.03, "Repeated trace provided no new information."
        if root_lot in self.ground_truth["affected_roots"]:
            bonus = min(0.13, discovered_nodes * 0.03 + matched_lot_count * 0.02)
            return 0.12 + bonus, "Trace identified the affected lineage across the network."
        return 0.02, "Trace ran, but the lot is outside the affected lineage."

    def _handle_quarantine(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]:
        """Move units of one lot at one node from inventory into quarantine.

        ``action.quantity`` of ``None`` quarantines every available unit. Larger
        requests are clamped to what is available. The reward compares the node's
        cumulative quarantined units for the lot with the ground-truth unsafe
        quantity; quarantining stock outside the recall scope is penalized.

        Raises:
            ValueError: If ``node_id`` is missing/unknown, ``lot_id`` is missing, or
                ``quantity`` is given but not positive.
        """
        node_id = self._require_node(action.node_id)
        lot_id = action.lot_id
        if not lot_id:
            raise ValueError("quarantine action requires 'lot_id'.")
        # Only ``None`` means "everything available". A falsy 0 must not widen to the
        # full stock, and a negative amount would move units out of quarantine.
        if action.quantity is not None and action.quantity <= 0:
            raise ValueError("quarantine action requires a positive 'quantity'.")

        node = self.state_data["nodes"][node_id]
        available_qty = node["inventory"].get(lot_id, 0)
        if available_qty <= 0:
            reward = RewardSignal(
                value=-0.2,
                reason="Attempted to quarantine stock that is not available.",
                components={"invalid_quarantine": -0.2},
            )
            self._record_history(f"Failed quarantine for {lot_id} at {node_id}: no available stock")
            info = StepInfo(
                message="No available stock to quarantine.",
                action_type=action.type.value,
                reward_breakdown=reward.components,
            ).model_dump()
            info.update({"node_id": node_id, "lot_id": lot_id})
            return reward, info

        requested_qty = available_qty if action.quantity is None else action.quantity
        quarantined_qty = min(requested_qty, available_qty)
        node["inventory"][lot_id] = available_qty - quarantined_qty
        if node["inventory"][lot_id] == 0:
            del node["inventory"][lot_id]
        node["quarantined_inventory"][lot_id] = node["quarantined_inventory"].get(lot_id, 0) + quarantined_qty

        self.state_data["quarantine_log"].append({"node_id": node_id, "lot_id": lot_id, "quantity": quarantined_qty})
        self._record_history(f"Quarantined {quarantined_qty} units of {lot_id} at {node_id}")

        correct_qty = self.ground_truth["correct_quantities"].get(node_id, {}).get(lot_id, 0)
        cumulative_quarantined = node["quarantined_inventory"].get(lot_id, 0)
        delta = cumulative_quarantined - correct_qty

        if correct_qty == 0:
            reward_value = -0.35
            reason = "Quarantined safe inventory outside the recall scope."
        elif delta == 0:
            reward_value = 0.28
            reason = "Quarantine exactly matched the unsafe quantity."
        elif delta < 0:
            reward_value = max(0.05, 0.22 * (cumulative_quarantined / correct_qty))
            reason = "Quarantine made partial progress but missed some unsafe stock."
        else:
            reward_value = max(-0.25, -0.08 * delta)
            reason = "Quarantine overreached and blocked safe inventory."

        reward = RewardSignal(
            value=round(reward_value, 4),
            reason=reason,
            components={
                "quarantine_value": round(reward_value, 4),
                "target_quantity": float(correct_qty),
                "quarantined_quantity": float(cumulative_quarantined),
            },
        )
        info = StepInfo(
            message=f"Updated quarantine for {lot_id} at {node_id}.",
            action_type=action.type.value,
            reward_breakdown=reward.components,
        ).model_dump()
        info.update(
            {
                "node_id": node_id,
                "lot_id": lot_id,
                "quarantined_quantity": quarantined_qty,
                "remaining_inventory": node["inventory"].get(lot_id, 0),
                "cumulative_quarantined": cumulative_quarantined,
                "target_contaminated_quantity": correct_qty,
            }
        )
        return reward, info

    def _handle_notify(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]:
        """Notify one node, or every node when ``node_id`` is omitted/``all``/``all_nodes``.

        Newly notified affected nodes earn ``0.04`` each (capped at ``0.18``), and
        newly notified unaffected nodes cost ``0.01`` each. A notification that
        reaches no new recipients costs ``-0.05``.

        Raises:
            ValueError: If a specific ``node_id`` is unknown.
        """
        requested_target = action.node_id or "all"
        if requested_target in ("all", "all_nodes"):
            targets = list(self.state_data["nodes"].keys())
        else:
            targets = [self._require_node(requested_target)]

        newly_notified = []
        for node_id in targets:
            if node_id not in self.state_data["notified_nodes"]:
                self.state_data["notified_nodes"].add(node_id)
                newly_notified.append(node_id)

        affected_newly_notified = sum(1 for node_id in newly_notified if node_id in self.ground_truth["affected_nodes"])
        unaffected_newly_notified = len(newly_notified) - affected_newly_notified

        if not newly_notified:
            reward_value = -0.05
            reason = "Notification repeated without adding new recipients."
        else:
            reward_value = min(0.18, affected_newly_notified * 0.04) - unaffected_newly_notified * 0.01
            reason = "Notifications dispatched to downstream stakeholders."

        reward = RewardSignal(
            value=round(reward_value, 4),
            reason=reason,
            components={
                "notification_value": round(reward_value, 4),
            },
        )
        if newly_notified:
            self._record_history(f"Sent notifications to {', '.join(newly_notified)}")
        else:
            self._record_history("Notification action repeated without new recipients")

        info = StepInfo(
            message="Processed notification action.",
            action_type=action.type.value,
            reward_breakdown=reward.components,
        ).model_dump()
        info.update({"notified_nodes": targets, "newly_notified": newly_notified})
        return reward, info

    def _handle_finalize(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]:
        """End the episode and return the graded recall response as the reward."""
        del action
        self.done = True
        scores = self._compute_final_scores()
        score = scores["score"]

        components = {
            "quarantine_score": round(scores["quarantine_score"], 4),
            "notification_score": round(scores["notification_score"], 4),
            "investigation_score": round(scores["investigation_score"], 4),
            "efficiency_score": round(scores["efficiency_score"], 4),
        }
        reward = RewardSignal(
            value=score,
            reason="Final recall response scored.",
            components=components,
        )
        self._record_history("Finalized recall response")

        info = StepInfo(
            message="Finalized recall response.",
            action_type="finalize",
            score=score,
            reward_breakdown=reward.components,
        ).model_dump()
        info.update(
            {
                "score": score,
                "quarantine_score": components["quarantine_score"],
                "notification_score": components["notification_score"],
                "investigation_score": components["investigation_score"],
                "efficiency_score": components["efficiency_score"],
                "all_affected_nodes_notified": scores["notification_score"] == 1.0,
                "all_affected_stock_quarantined": scores["missing_quantity_total"] == 0
                and scores["over_quantity_total"] == 0,
                "quarantine_match": scores["quarantine_match"],
            }
        )
        return reward, info

    def _compute_final_scores(self) -> Dict[str, Any]:
        """Grade the current containment state against ``ground_truth``.

        Returns a dict with the unrounded component scores (``quarantine_score``,
        ``notification_score``, ``investigation_score``, ``efficiency_score``), the
        weighted ``score`` rounded to 4 places, the summed ``missing_quantity_total``
        and ``over_quantity_total``, and the raw ``quarantine_match`` breakdown.
        """
        quarantine_match = self._compute_quarantine_match()
        missing_quantity_total = sum(
            quantity
            for lot_quantities in quarantine_match["missing_quantities"].values()
            for quantity in lot_quantities.values()
        )
        over_quantity_total = sum(
            quantity
            for lot_quantities in quarantine_match["over_quarantined_quantities"].values()
            for quantity in lot_quantities.values()
        )
        # Each over-quarantined unit is penalized 1.25x as much as a missed unit.
        total_affected_quantity = self.ground_truth["total_affected_quantity"] or 1
        quarantine_score = max(
            0.0,
            1.0 - ((missing_quantity_total + (1.25 * over_quantity_total)) / total_affected_quantity),
        )

        affected_nodes = set(self.ground_truth["affected_nodes"])
        if affected_nodes:
            notified = affected_nodes.intersection(self.state_data["notified_nodes"])
            inspected = affected_nodes.intersection(self.state_data["inspected_nodes"])
            notification_score = len(notified) / len(affected_nodes)
            investigation_score = len(inspected) / len(affected_nodes)
        else:
            # Nothing is affected, so there is nobody to notify or inspect: both
            # objectives are vacuously met rather than scored as 0.
            notification_score = 1.0
            investigation_score = 1.0

        # Steps beyond max(4, affected nodes + 3) reduce efficiency; ``or 1`` guards a
        # zero-sized step budget against division by zero.
        free_steps = max(4, (len(affected_nodes) or 1) + 3)
        efficiency_penalty_steps = max(0, self.state_data["steps_taken"] - free_steps)
        step_budget = self.state_data["max_steps"] or 1
        efficiency_score = max(0.0, 1.0 - (efficiency_penalty_steps / step_budget))

        score = round(
            (0.55 * quarantine_score) + (0.2 * notification_score) + (0.15 * investigation_score) + (0.1 * efficiency_score),
            4,
        )
        return {
            "score": score,
            "quarantine_score": quarantine_score,
            "notification_score": notification_score,
            "investigation_score": investigation_score,
            "efficiency_score": efficiency_score,
            "missing_quantity_total": missing_quantity_total,
            "over_quantity_total": over_quantity_total,
            "quarantine_match": quarantine_match,
        }

    def _build_ground_truth(self, scenario: Dict[str, Any]) -> Dict[str, Any]:
        """Derive grading targets from a scenario dict.

        ``affected_roots`` are the root lots of catalog entries flagged
        ``contaminated``. Affected nodes/lots and the per-node target quantities come
        from ``inspection_findings`` entries whose ``unsafe_quantity`` is positive.
        """
        contaminated_roots = {
            self._root_lot_for(lot_id, scenario["lot_catalog"])
            for lot_id, lot_data in scenario["lot_catalog"].items()
            if lot_data.get("contaminated", False)
        }

        correct_quantities: Dict[str, Dict[str, int]] = {}
        affected_nodes = set()
        affected_lots = set()

        for node_id, node_data in scenario["nodes"].items():
            for lot_id, finding in node_data.get("inspection_findings", {}).items():
                unsafe_quantity = int(finding.get("unsafe_quantity", 0))
                if unsafe_quantity > 0:
                    affected_nodes.add(node_id)
                    affected_lots.add(lot_id)
                    correct_quantities.setdefault(node_id, {})[lot_id] = unsafe_quantity

        total_affected_quantity = sum(
            quantity
            for node_quantities in correct_quantities.values()
            for quantity in node_quantities.values()
        )
        return {
            "affected_lots": sorted(affected_lots),
            "affected_nodes": sorted(affected_nodes),
            "affected_roots": sorted(contaminated_roots),
            "correct_quantities": correct_quantities,
            "total_affected_quantity": total_affected_quantity,
        }

    def _compute_quarantine_match(self) -> Dict[str, Any]:
        """Compare quarantined stock with ground-truth targets, per node and lot.

        Returns ``missing_quantities`` (target exceeds quarantined) and
        ``over_quarantined_quantities`` (quarantined exceeds target), each as
        ``{node_id: {lot_id: shortfall_or_excess}}``. Exact matches are omitted.
        """
        missing_quantities: Dict[str, Dict[str, int]] = {}
        over_quarantined_quantities: Dict[str, Dict[str, int]] = {}

        for node_id, node_data in self.state_data["nodes"].items():
            expected = self.ground_truth["correct_quantities"].get(node_id, {})
            actual = node_data["quarantined_inventory"]
            relevant_lots = set(expected) | set(actual)

            for lot_id in relevant_lots:
                expected_qty = expected.get(lot_id, 0)
                actual_qty = actual.get(lot_id, 0)
                if actual_qty < expected_qty:
                    missing_quantities.setdefault(node_id, {})[lot_id] = expected_qty - actual_qty
                elif actual_qty > expected_qty:
                    over_quarantined_quantities.setdefault(node_id, {})[lot_id] = actual_qty - expected_qty

        return {
            "missing_quantities": missing_quantities,
            "over_quarantined_quantities": over_quarantined_quantities,
        }

    def _inventory_snapshot(self) -> Dict[str, Dict[str, int]]:
        """Return a copy of every node's available (non-quarantined) inventory."""
        return {node_id: deepcopy(node_data["inventory"]) for node_id, node_data in self.state_data["nodes"].items()}

    def _quarantine_snapshot(self) -> Dict[str, Dict[str, int]]:
        """Return a copy of quarantined inventory for nodes that have any."""
        return {
            node_id: deepcopy(node_data["quarantined_inventory"])
            for node_id, node_data in self.state_data["nodes"].items()
            if node_data["quarantined_inventory"]
        }

    def _resolve_related_lots(self, lot_id: str) -> set[str]:
        """Return catalog lots sharing ``lot_id``'s root, plus ``lot_id`` if catalogued."""
        root_lot = self._root_lot_for(lot_id)
        return {
            candidate_lot
            for candidate_lot in self.state_data["lot_catalog"].keys()
            if self._root_lot_for(candidate_lot) == root_lot or candidate_lot == lot_id
        }

    def _root_lot_for(self, lot_id: str, lot_catalog: Dict[str, Dict[str, Any]] | None = None) -> str:
        """Return the ``root_lot`` recorded for ``lot_id``, or ``lot_id`` itself.

        Only one level is resolved: the catalog entry's ``root_lot`` field is
        returned as-is. ``lot_catalog`` defaults to the current episode's catalog;
        an explicitly passed catalog, even an empty one, is used as given.
        """
        catalog = self.state_data.get("lot_catalog", {}) if lot_catalog is None else lot_catalog
        if lot_id not in catalog:
            return lot_id
        return catalog[lot_id].get("root_lot", lot_id)

    def _build_task_definition(self, scenario: Dict[str, Any]) -> TaskDefinition:
        """Extract the public task description fields from a scenario dict."""
        return TaskDefinition(
            task_id=scenario["task_id"],
            name=scenario["name"],
            difficulty=scenario["difficulty"],
            objective=scenario["objective"],
            max_steps=scenario["max_steps"],
        )

    def _require_node(self, node_id: str | None) -> str:
        """Return ``node_id`` if it names a node in the current episode.

        Raises:
            ValueError: If ``node_id`` is empty/``None`` or not a known node.
        """
        if not node_id:
            raise ValueError("Action requires 'node_id'.")
        if node_id not in self.state_data["nodes"]:
            raise ValueError(f"Unknown node_id '{node_id}'.")
        return node_id

    def _record_history(self, message: str) -> None:
        """Append a human-readable entry to the episode history."""
        self.state_data["history"].append(message)

    def _serialize_state(self, value: Any) -> Any:
        """Recursively convert state into plain data.

        Sets become sorted lists, objects exposing ``model_dump`` are dumped, and
        dicts/lists are rebuilt element by element. Other values pass through unchanged.
        """
        if isinstance(value, dict):
            return {key: self._serialize_state(item) for key, item in value.items()}
        if isinstance(value, set):
            return sorted(value)
        if isinstance(value, list):
            return [self._serialize_state(item) for item in value]
        if hasattr(value, "model_dump"):
            return value.model_dump()
        return value
|
|
|