| """Core RecallTrace environment with deterministic action execution.""" |
|
|
| from __future__ import annotations |
|
|
| from copy import deepcopy |
| from typing import Any, Dict, Tuple |
|
|
| from env.models import EnvironmentState, InspectionEvidence, RecallAction, RecallObservation, RewardSignal, StepInfo, TaskDefinition |
| from scenario.scenario import build_scenario, list_task_specs |
|
|
|
|
| class RecallTraceEnv: |
| """Deterministic OpenEnv-style environment for product recall containment.""" |
|
|
| ACTIONS = [ |
| "inspect_node", |
| "trace_lot", |
| "quarantine", |
| "notify", |
| "finalize", |
| ] |
|
|
| def __init__( |
| self, |
| scenario_data: Dict[str, Any] | None = None, |
| task_id: str | None = None, |
| phase: int | None = 1, |
| ): |
| self._scenario_template = deepcopy(scenario_data) if scenario_data is not None else build_scenario(task_id=task_id, phase=phase) |
| self.task = self._build_task_definition(self._scenario_template) |
| self.state_data: Dict[str, Any] = {} |
| self.ground_truth: Dict[str, Any] = {} |
| self.done = False |
| self.last_reward = RewardSignal(value=0.0, reason="Environment initialized.", components={}) |
|
|
| @classmethod |
| def available_tasks(cls) -> list[TaskDefinition]: |
| return [TaskDefinition(**task_spec) for task_spec in list_task_specs()] |
|
|
| def reset(self, task_id: str | None = None, phase: int | None = None) -> RecallObservation: |
| """Start a new deterministic scenario and recompute ground truth.""" |
| if task_id is not None or phase is not None: |
| self._scenario_template = build_scenario(task_id=task_id, phase=phase) |
| self.task = self._build_task_definition(self._scenario_template) |
|
|
| self.done = False |
| self.last_reward = RewardSignal(value=0.0, reason="Episode reset.", components={}) |
|
|
| scenario = deepcopy(self._scenario_template) |
| self.state_data = { |
| "task_id": scenario["task_id"], |
| "phase": scenario["phase"], |
| "recall_notice": scenario["recall_notice"], |
| "contaminated_lot_hint": scenario["contaminated_lot"], |
| "shipment_graph": scenario["shipment_graph"], |
| "lot_catalog": scenario["lot_catalog"], |
| "nodes": scenario["nodes"], |
| "history": [], |
| "discovered_shipments": {}, |
| "inspected_nodes": set(), |
| "inspection_results": {}, |
| "traced_lots": {}, |
| "notified_nodes": set(), |
| "quarantine_log": [], |
| "steps_taken": 0, |
| "max_steps": scenario["max_steps"], |
| } |
| self.ground_truth = self._build_ground_truth(scenario) |
| return self._get_observation() |
|
|
| def step(self, action: RecallAction | Dict[str, Any]) -> Tuple[RecallObservation, float, bool, Dict[str, Any]]: |
| """Execute an action and return observation, reward, done, info.""" |
| if self.done: |
| return self._get_observation(), 0.0, True, { |
| "message": "Environment already finalized.", |
| "action_type": "noop", |
| "reward_breakdown": {}, |
| } |
|
|
| validated_action = action if isinstance(action, RecallAction) else RecallAction.model_validate(action) |
| self.state_data["steps_taken"] += 1 |
|
|
| handler = getattr(self, f"_handle_{validated_action.type.value}") |
| reward_signal, info = handler(validated_action) |
| self.last_reward = reward_signal |
|
|
| if not self.done and self.state_data["steps_taken"] >= self.state_data["max_steps"]: |
| self.done = True |
| timeout_penalty = -0.25 |
| reward_signal = RewardSignal( |
| value=max(-1.0, reward_signal.value + timeout_penalty), |
| reason="Step budget exhausted before finalizing containment.", |
| components={**reward_signal.components, "timeout_penalty": timeout_penalty}, |
| ) |
| info = { |
| **info, |
| "message": "Step budget exhausted before finalizing containment.", |
| "reward_breakdown": reward_signal.components, |
| } |
| self._record_history("Episode terminated after exhausting the step budget") |
| self.last_reward = reward_signal |
|
|
| return self._get_observation(), reward_signal.value, self.done, info |
|
|
| def state(self) -> EnvironmentState: |
| """Return the full internal state for debugging and graders.""" |
| return EnvironmentState( |
| done=self.done, |
| task=self.task, |
| steps_taken=self.state_data.get("steps_taken", 0), |
| state_data=deepcopy(self._serialize_state(self.state_data)), |
| ground_truth=deepcopy(self.ground_truth), |
| ) |
|
|
| def _get_observation(self) -> RecallObservation: |
| return RecallObservation( |
| task_id=self.state_data["task_id"], |
| phase=self.state_data["phase"], |
| recall_notice=self.state_data["recall_notice"], |
| available_actions=list(self.ACTIONS), |
| inventory=self._inventory_snapshot(), |
| discovered_shipments=deepcopy(self.state_data["discovered_shipments"]), |
| inspected_nodes=sorted(self.state_data["inspected_nodes"]), |
| inspection_results=deepcopy(self.state_data["inspection_results"]), |
| trace_results=deepcopy(self.state_data["traced_lots"]), |
| notified_nodes=sorted(self.state_data["notified_nodes"]), |
| quarantined_inventory=self._quarantine_snapshot(), |
| history=list(self.state_data["history"]), |
| steps_taken=self.state_data["steps_taken"], |
| remaining_step_budget=max(0, self.state_data["max_steps"] - self.state_data["steps_taken"]), |
| ) |
|
|
| def _handle_inspect_node(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]: |
| node_id = self._require_node(action.node_id) |
| node = self.state_data["nodes"][node_id] |
| repeated = node_id in self.state_data["inspected_nodes"] |
|
|
| self.state_data["inspected_nodes"].add(node_id) |
| self.state_data["discovered_shipments"][node_id] = list(self.state_data["shipment_graph"].get(node_id, [])) |
| findings = { |
| lot_id: InspectionEvidence.model_validate(payload) |
| for lot_id, payload in node.get("inspection_findings", {}).items() |
| } |
| self.state_data["inspection_results"][node_id] = findings |
| self._record_history(f"Inspected node {node_id}") |
|
|
| unsafe_total = sum(item.unsafe_quantity for item in findings.values()) |
| value = -0.03 if repeated else 0.08 + min(0.12, unsafe_total / 500.0) |
| reason = "Repeated inspection provided no new information." if repeated else "Inspection revealed inventory evidence." |
| reward = RewardSignal( |
| value=round(value, 4), |
| reason=reason, |
| components={ |
| "inspection_value": round(value, 4), |
| }, |
| ) |
| info = StepInfo( |
| message=f"Inspected node {node_id} and collected node evidence.", |
| action_type=action.type.value, |
| reward_breakdown=reward.components, |
| ).model_dump() |
| info.update( |
| { |
| "node_id": node_id, |
| "inventory": deepcopy(node["inventory"]), |
| "quarantined_inventory": deepcopy(node["quarantined_inventory"]), |
| "outbound_shipments": list(self.state_data["shipment_graph"].get(node_id, [])), |
| "inspection_findings": {lot_id: item.model_dump() for lot_id, item in findings.items()}, |
| } |
| ) |
| return reward, info |
|
|
| def _handle_trace_lot(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]: |
| lot_id = action.lot_id |
| if not lot_id: |
| raise ValueError("trace_lot action requires 'lot_id'.") |
|
|
| traced_lots = self._resolve_related_lots(lot_id) |
| impacted_nodes = [] |
| impacted_quantities = {} |
| impacted_lots = {} |
| discovered_nodes = 0 |
|
|
| for node_id, node_data in self.state_data["nodes"].items(): |
| node_total = 0 |
| node_lots = [] |
| for candidate_lot in traced_lots: |
| available_qty = node_data["inventory"].get(candidate_lot, 0) |
| quarantined_qty = node_data["quarantined_inventory"].get(candidate_lot, 0) |
| total_qty = available_qty + quarantined_qty |
| if total_qty > 0: |
| node_total += total_qty |
| node_lots.append(candidate_lot) |
| if node_total > 0: |
| impacted_nodes.append(node_id) |
| impacted_quantities[node_id] = node_total |
| impacted_lots[node_id] = node_lots |
| if node_id not in self.state_data["discovered_shipments"]: |
| discovered_nodes += 1 |
|
|
| self.state_data["traced_lots"][lot_id] = { |
| "root_lot": self._root_lot_for(lot_id), |
| "matched_lots": sorted(traced_lots), |
| "affected_nodes": impacted_nodes, |
| "lots_by_node": impacted_lots, |
| "quantities_by_node": impacted_quantities, |
| } |
| self._record_history(f"Traced lot {lot_id} across {', '.join(sorted(traced_lots))}") |
|
|
| if not impacted_nodes: |
| reward_value = -0.1 |
| reason = "Trace returned no impacted nodes." |
| elif self._root_lot_for(lot_id) in self.ground_truth["affected_roots"]: |
| reward_value = 0.12 + min(0.13, discovered_nodes * 0.03 + len(traced_lots) * 0.02) |
| reason = "Trace identified the affected lineage across the network." |
| else: |
| reward_value = 0.02 |
| reason = "Trace ran, but the lot is outside the affected lineage." |
|
|
| reward = RewardSignal( |
| value=round(reward_value, 4), |
| reason=reason, |
| components={ |
| "trace_value": round(reward_value, 4), |
| }, |
| ) |
| info = StepInfo( |
| message=f"Traced lot {lot_id} across the shipment network.", |
| action_type=action.type.value, |
| reward_breakdown=reward.components, |
| ).model_dump() |
| info.update( |
| { |
| "lot_id": lot_id, |
| "root_lot": self._root_lot_for(lot_id), |
| "matched_lots": sorted(traced_lots), |
| "affected_nodes": impacted_nodes, |
| "lots_by_node": impacted_lots, |
| "quantities_by_node": impacted_quantities, |
| "total_quantity": sum(impacted_quantities.values()), |
| } |
| ) |
| return reward, info |
|
|
| def _handle_quarantine(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]: |
| node_id = self._require_node(action.node_id) |
| lot_id = action.lot_id |
| if not lot_id: |
| raise ValueError("quarantine action requires 'lot_id'.") |
|
|
| node = self.state_data["nodes"][node_id] |
| available_qty = node["inventory"].get(lot_id, 0) |
| if available_qty <= 0: |
| reward = RewardSignal( |
| value=-0.2, |
| reason="Attempted to quarantine stock that is not available.", |
| components={"invalid_quarantine": -0.2}, |
| ) |
| self._record_history(f"Failed quarantine for {lot_id} at {node_id}: no available stock") |
| info = StepInfo( |
| message="No available stock to quarantine.", |
| action_type=action.type.value, |
| reward_breakdown=reward.components, |
| ).model_dump() |
| info.update({"node_id": node_id, "lot_id": lot_id}) |
| return reward, info |
|
|
| requested_qty = action.quantity or available_qty |
| quarantined_qty = min(requested_qty, available_qty) |
| node["inventory"][lot_id] = available_qty - quarantined_qty |
| if node["inventory"][lot_id] == 0: |
| del node["inventory"][lot_id] |
| node["quarantined_inventory"][lot_id] = node["quarantined_inventory"].get(lot_id, 0) + quarantined_qty |
|
|
| self.state_data["quarantine_log"].append({"node_id": node_id, "lot_id": lot_id, "quantity": quarantined_qty}) |
| self._record_history(f"Quarantined {quarantined_qty} units of {lot_id} at {node_id}") |
|
|
| correct_qty = self.ground_truth["correct_quantities"].get(node_id, {}).get(lot_id, 0) |
| cumulative_quarantined = node["quarantined_inventory"].get(lot_id, 0) |
| delta = cumulative_quarantined - correct_qty |
|
|
| if correct_qty == 0: |
| reward_value = -0.35 |
| reason = "Quarantined safe inventory outside the recall scope." |
| elif delta == 0: |
| reward_value = 0.28 |
| reason = "Quarantine exactly matched the unsafe quantity." |
| elif delta < 0: |
| reward_value = max(0.05, 0.22 * (cumulative_quarantined / correct_qty)) |
| reason = "Quarantine made partial progress but missed some unsafe stock." |
| else: |
| reward_value = max(-0.25, -0.08 * delta) |
| reason = "Quarantine overreached and blocked safe inventory." |
|
|
| reward = RewardSignal( |
| value=round(reward_value, 4), |
| reason=reason, |
| components={ |
| "quarantine_value": round(reward_value, 4), |
| "target_quantity": float(correct_qty), |
| "quarantined_quantity": float(cumulative_quarantined), |
| }, |
| ) |
| info = StepInfo( |
| message=f"Updated quarantine for {lot_id} at {node_id}.", |
| action_type=action.type.value, |
| reward_breakdown=reward.components, |
| ).model_dump() |
| info.update( |
| { |
| "node_id": node_id, |
| "lot_id": lot_id, |
| "quarantined_quantity": quarantined_qty, |
| "remaining_inventory": node["inventory"].get(lot_id, 0), |
| "cumulative_quarantined": cumulative_quarantined, |
| "target_contaminated_quantity": correct_qty, |
| } |
| ) |
| return reward, info |
|
|
| def _handle_notify(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]: |
| requested_target = action.node_id or "all" |
| if requested_target in ("all", "all_nodes"): |
| targets = list(self.state_data["nodes"].keys()) |
| else: |
| targets = [self._require_node(requested_target)] |
|
|
| newly_notified = [] |
| for node_id in targets: |
| if node_id not in self.state_data["notified_nodes"]: |
| self.state_data["notified_nodes"].add(node_id) |
| newly_notified.append(node_id) |
|
|
| affected_newly_notified = sum(1 for node_id in newly_notified if node_id in self.ground_truth["affected_nodes"]) |
| unaffected_newly_notified = len(newly_notified) - affected_newly_notified |
|
|
| if not newly_notified: |
| reward_value = -0.05 |
| reason = "Notification repeated without adding new recipients." |
| else: |
| reward_value = min(0.18, affected_newly_notified * 0.04) - unaffected_newly_notified * 0.01 |
| reason = "Notifications dispatched to downstream stakeholders." |
|
|
| reward = RewardSignal( |
| value=round(reward_value, 4), |
| reason=reason, |
| components={ |
| "notification_value": round(reward_value, 4), |
| }, |
| ) |
| if newly_notified: |
| self._record_history(f"Sent notifications to {', '.join(newly_notified)}") |
| else: |
| self._record_history("Notification action repeated without new recipients") |
|
|
| info = StepInfo( |
| message="Processed notification action.", |
| action_type=action.type.value, |
| reward_breakdown=reward.components, |
| ).model_dump() |
| info.update({"notified_nodes": targets, "newly_notified": newly_notified}) |
| return reward, info |
|
|
| def _handle_finalize(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]: |
| del action |
| self.done = True |
| quarantine_match = self._compute_quarantine_match() |
|
|
| missing_quantity_total = sum( |
| quantity |
| for lot_quantities in quarantine_match["missing_quantities"].values() |
| for quantity in lot_quantities.values() |
| ) |
| over_quantity_total = sum( |
| quantity |
| for lot_quantities in quarantine_match["over_quarantined_quantities"].values() |
| for quantity in lot_quantities.values() |
| ) |
| total_affected_quantity = self.ground_truth["total_affected_quantity"] or 1 |
| quarantine_score = max(0.0, 1.0 - ((missing_quantity_total + (1.25 * over_quantity_total)) / total_affected_quantity)) |
|
|
| notified_affected_nodes = set(self.ground_truth["affected_nodes"]).intersection(self.state_data["notified_nodes"]) |
| affected_node_total = len(self.ground_truth["affected_nodes"]) or 1 |
| notification_score = len(notified_affected_nodes) / affected_node_total |
|
|
| investigated_nodes = set(self.state_data["inspected_nodes"]).intersection(self.ground_truth["affected_nodes"]) |
| investigation_score = len(investigated_nodes) / affected_node_total |
|
|
| efficiency_penalty_steps = max(0, self.state_data["steps_taken"] - max(4, affected_node_total + 3)) |
| efficiency_score = max(0.0, 1.0 - (efficiency_penalty_steps / self.state_data["max_steps"])) |
|
|
| score = round( |
| (0.55 * quarantine_score) + (0.2 * notification_score) + (0.15 * investigation_score) + (0.1 * efficiency_score), |
| 4, |
| ) |
|
|
| reward = RewardSignal( |
| value=score, |
| reason="Final recall response scored.", |
| components={ |
| "quarantine_score": round(quarantine_score, 4), |
| "notification_score": round(notification_score, 4), |
| "investigation_score": round(investigation_score, 4), |
| "efficiency_score": round(efficiency_score, 4), |
| }, |
| ) |
| self._record_history("Finalized recall response") |
|
|
| info = StepInfo( |
| message="Finalized recall response.", |
| action_type="finalize", |
| score=score, |
| reward_breakdown=reward.components, |
| ).model_dump() |
| info.update( |
| { |
| "score": score, |
| "quarantine_score": round(quarantine_score, 4), |
| "notification_score": round(notification_score, 4), |
| "investigation_score": round(investigation_score, 4), |
| "efficiency_score": round(efficiency_score, 4), |
| "all_affected_nodes_notified": notification_score == 1.0, |
| "all_affected_stock_quarantined": missing_quantity_total == 0 and over_quantity_total == 0, |
| "quarantine_match": quarantine_match, |
| } |
| ) |
| return reward, info |
|
|
| def _build_ground_truth(self, scenario: Dict[str, Any]) -> Dict[str, Any]: |
| contaminated_roots = { |
| self._root_lot_for(lot_id, scenario["lot_catalog"]) |
| for lot_id, lot_data in scenario["lot_catalog"].items() |
| if lot_data.get("contaminated", False) |
| } |
|
|
| correct_quantities: Dict[str, Dict[str, int]] = {} |
| affected_nodes = set() |
| affected_lots = set() |
|
|
| for node_id, node_data in scenario["nodes"].items(): |
| for lot_id, finding in node_data.get("inspection_findings", {}).items(): |
| unsafe_quantity = int(finding.get("unsafe_quantity", 0)) |
| if unsafe_quantity > 0: |
| affected_nodes.add(node_id) |
| affected_lots.add(lot_id) |
| correct_quantities.setdefault(node_id, {})[lot_id] = unsafe_quantity |
|
|
| total_affected_quantity = sum( |
| quantity |
| for node_quantities in correct_quantities.values() |
| for quantity in node_quantities.values() |
| ) |
| return { |
| "affected_lots": sorted(affected_lots), |
| "affected_nodes": sorted(affected_nodes), |
| "affected_roots": sorted(contaminated_roots), |
| "correct_quantities": correct_quantities, |
| "total_affected_quantity": total_affected_quantity, |
| } |
|
|
| def _compute_quarantine_match(self) -> Dict[str, Any]: |
| missing_quantities: Dict[str, Dict[str, int]] = {} |
| over_quarantined_quantities: Dict[str, Dict[str, int]] = {} |
|
|
| for node_id, node_data in self.state_data["nodes"].items(): |
| expected = self.ground_truth["correct_quantities"].get(node_id, {}) |
| actual = node_data["quarantined_inventory"] |
| relevant_lots = set(expected) | set(actual) |
|
|
| for lot_id in relevant_lots: |
| expected_qty = expected.get(lot_id, 0) |
| actual_qty = actual.get(lot_id, 0) |
| if actual_qty < expected_qty: |
| missing_quantities.setdefault(node_id, {})[lot_id] = expected_qty - actual_qty |
| elif actual_qty > expected_qty: |
| over_quarantined_quantities.setdefault(node_id, {})[lot_id] = actual_qty - expected_qty |
|
|
| return { |
| "missing_quantities": missing_quantities, |
| "over_quarantined_quantities": over_quarantined_quantities, |
| } |
|
|
| def _inventory_snapshot(self) -> Dict[str, Dict[str, int]]: |
| return {node_id: deepcopy(node_data["inventory"]) for node_id, node_data in self.state_data["nodes"].items()} |
|
|
| def _quarantine_snapshot(self) -> Dict[str, Dict[str, int]]: |
| return { |
| node_id: deepcopy(node_data["quarantined_inventory"]) |
| for node_id, node_data in self.state_data["nodes"].items() |
| if node_data["quarantined_inventory"] |
| } |
|
|
| def _resolve_related_lots(self, lot_id: str) -> set[str]: |
| root_lot = self._root_lot_for(lot_id) |
| return { |
| candidate_lot |
| for candidate_lot in self.state_data["lot_catalog"].keys() |
| if self._root_lot_for(candidate_lot) == root_lot or candidate_lot == lot_id |
| } |
|
|
| def _root_lot_for(self, lot_id: str, lot_catalog: Dict[str, Dict[str, Any]] | None = None) -> str: |
| catalog = lot_catalog or self.state_data.get("lot_catalog", {}) |
| if lot_id not in catalog: |
| return lot_id |
| return catalog[lot_id].get("root_lot", lot_id) |
|
|
| def _build_task_definition(self, scenario: Dict[str, Any]) -> TaskDefinition: |
| return TaskDefinition( |
| task_id=scenario["task_id"], |
| name=scenario["name"], |
| difficulty=scenario["difficulty"], |
| objective=scenario["objective"], |
| max_steps=scenario["max_steps"], |
| ) |
|
|
| def _require_node(self, node_id: str | None) -> str: |
| if not node_id: |
| raise ValueError("Action requires 'node_id'.") |
| if node_id not in self.state_data["nodes"]: |
| raise ValueError(f"Unknown node_id '{node_id}'.") |
| return node_id |
|
|
| def _record_history(self, message: str) -> None: |
| self.state_data["history"].append(message) |
|
|
| def _serialize_state(self, value: Any) -> Any: |
| if isinstance(value, dict): |
| return {key: self._serialize_state(item) for key, item in value.items()} |
| if isinstance(value, set): |
| return sorted(value) |
| if isinstance(value, list): |
| return [self._serialize_state(item) for item in value] |
| if hasattr(value, "model_dump"): |
| return value.model_dump() |
| return value |
|
|