# metaa/env/env.py
# Majen's picture
# Full project upload
# aae7b06 verified
"""Core RecallTrace environment with deterministic action execution."""
from __future__ import annotations
from copy import deepcopy
from typing import Any, Dict, Tuple
from env.models import EnvironmentState, InspectionEvidence, RecallAction, RecallObservation, RewardSignal, StepInfo, TaskDefinition
from scenario.scenario import build_scenario, list_task_specs
class RecallTraceEnv:
    """Deterministic OpenEnv-style environment for product recall containment.

    Typical use: construct the environment, call :meth:`reset` to populate
    ``state_data`` and ``ground_truth``, then call :meth:`step` repeatedly
    until it reports ``done``. ``state_data`` is empty until :meth:`reset`
    has been called, so :meth:`step` must not be called before it.
    """

    # Action names advertised to the agent; each maps to a ``_handle_<name>`` method.
    ACTIONS = [
        "inspect_node",
        "trace_lot",
        "quarantine",
        "notify",
        "finalize",
    ]

    def __init__(
        self,
        scenario_data: Dict[str, Any] | None = None,
        task_id: str | None = None,
        phase: int | None = 1,
    ):
        """Create the environment from an explicit scenario or a built one.

        Args:
            scenario_data: Scenario dictionary to use as the template. It is
                deep-copied so later mutation by the caller has no effect.
            task_id: Forwarded to ``build_scenario`` when ``scenario_data``
                is not given.
            phase: Forwarded to ``build_scenario`` when ``scenario_data`` is
                not given.
        """
        if scenario_data is not None:
            self._scenario_template = deepcopy(scenario_data)
        else:
            self._scenario_template = build_scenario(task_id=task_id, phase=phase)
        self.task = self._build_task_definition(self._scenario_template)
        self.state_data: Dict[str, Any] = {}
        self.ground_truth: Dict[str, Any] = {}
        self.done = False
        self.last_reward = RewardSignal(value=0.0, reason="Environment initialized.", components={})

    @classmethod
    def available_tasks(cls) -> list[TaskDefinition]:
        """Return one ``TaskDefinition`` per spec from ``list_task_specs()``."""
        return [TaskDefinition(**task_spec) for task_spec in list_task_specs()]

    def reset(self, task_id: str | None = None, phase: int | None = None) -> RecallObservation:
        """Start a new deterministic scenario and recompute ground truth.

        When ``task_id`` or ``phase`` is given, the scenario template is
        rebuilt through ``build_scenario``; otherwise the current template
        is reused. The episode always starts from a deep copy of the template.
        """
        if task_id is not None or phase is not None:
            self._scenario_template = build_scenario(task_id=task_id, phase=phase)
            self.task = self._build_task_definition(self._scenario_template)
        self.done = False
        self.last_reward = RewardSignal(value=0.0, reason="Episode reset.", components={})
        scenario = deepcopy(self._scenario_template)
        self.state_data = {
            "task_id": scenario["task_id"],
            "phase": scenario["phase"],
            "recall_notice": scenario["recall_notice"],
            "contaminated_lot_hint": scenario["contaminated_lot"],
            "shipment_graph": scenario["shipment_graph"],
            "lot_catalog": scenario["lot_catalog"],
            "nodes": scenario["nodes"],
            "history": [],
            "discovered_shipments": {},
            "inspected_nodes": set(),
            "inspection_results": {},
            "traced_lots": {},
            "notified_nodes": set(),
            "quarantine_log": [],
            "steps_taken": 0,
            "max_steps": scenario["max_steps"],
        }
        self.ground_truth = self._build_ground_truth(scenario)
        return self._get_observation()

    def step(self, action: RecallAction | Dict[str, Any]) -> Tuple[RecallObservation, float, bool, Dict[str, Any]]:
        """Execute an action and return observation, reward, done, info.

        A plain dict is validated into a ``RecallAction`` first. Once the
        episode is done, further calls are no-ops with zero reward. If the
        step budget runs out on a non-finalizing action, the episode ends and
        a timeout penalty is added to that action's reward (floored at -1.0).

        Raises:
            ValueError: Propagated from a handler when the action lacks a
                required field or names an unknown node.
        """
        if self.done:
            return self._get_observation(), 0.0, True, {
                "message": "Environment already finalized.",
                "action_type": "noop",
                "reward_breakdown": {},
            }

        validated_action = action if isinstance(action, RecallAction) else RecallAction.model_validate(action)
        # The step is counted before dispatch, so a handler that raises still uses budget.
        self.state_data["steps_taken"] += 1
        handler = getattr(self, f"_handle_{validated_action.type.value}")
        reward_signal, info = handler(validated_action)
        self.last_reward = reward_signal

        if not self.done and self.state_data["steps_taken"] >= self.state_data["max_steps"]:
            self.done = True
            timeout_penalty = -0.25
            reward_signal = RewardSignal(
                value=max(-1.0, reward_signal.value + timeout_penalty),
                reason="Step budget exhausted before finalizing containment.",
                components={**reward_signal.components, "timeout_penalty": timeout_penalty},
            )
            info = {
                **info,
                "message": "Step budget exhausted before finalizing containment.",
                "reward_breakdown": reward_signal.components,
            }
            self._record_history("Episode terminated after exhausting the step budget")
            self.last_reward = reward_signal

        return self._get_observation(), reward_signal.value, self.done, info

    def state(self) -> EnvironmentState:
        """Return the full internal state for debugging and graders."""
        return EnvironmentState(
            done=self.done,
            task=self.task,
            steps_taken=self.state_data.get("steps_taken", 0),
            state_data=deepcopy(self._serialize_state(self.state_data)),
            ground_truth=deepcopy(self.ground_truth),
        )

    def _get_observation(self) -> RecallObservation:
        """Build the agent-facing observation from the current state.

        Mutable containers are copied (and sets are sorted into lists) so the
        observation does not alias internal state.
        """
        return RecallObservation(
            task_id=self.state_data["task_id"],
            phase=self.state_data["phase"],
            recall_notice=self.state_data["recall_notice"],
            available_actions=list(self.ACTIONS),
            inventory=self._inventory_snapshot(),
            discovered_shipments=deepcopy(self.state_data["discovered_shipments"]),
            inspected_nodes=sorted(self.state_data["inspected_nodes"]),
            inspection_results=deepcopy(self.state_data["inspection_results"]),
            trace_results=deepcopy(self.state_data["traced_lots"]),
            notified_nodes=sorted(self.state_data["notified_nodes"]),
            quarantined_inventory=self._quarantine_snapshot(),
            history=list(self.state_data["history"]),
            steps_taken=self.state_data["steps_taken"],
            remaining_step_budget=max(0, self.state_data["max_steps"] - self.state_data["steps_taken"]),
        )

    def _handle_inspect_node(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]:
        """Inspect a node, recording its outbound shipments and findings.

        A first inspection earns a base reward plus a bonus that grows with
        the unsafe quantity found (capped); a repeat inspection is penalized.

        Raises:
            ValueError: If ``node_id`` is missing or unknown.
        """
        node_id = self._require_node(action.node_id)
        node = self.state_data["nodes"][node_id]
        repeated = node_id in self.state_data["inspected_nodes"]

        self.state_data["inspected_nodes"].add(node_id)
        self.state_data["discovered_shipments"][node_id] = list(self.state_data["shipment_graph"].get(node_id, []))
        findings = {
            lot_id: InspectionEvidence.model_validate(payload)
            for lot_id, payload in node.get("inspection_findings", {}).items()
        }
        self.state_data["inspection_results"][node_id] = findings
        self._record_history(f"Inspected node {node_id}")

        unsafe_total = sum(item.unsafe_quantity for item in findings.values())
        value = -0.03 if repeated else 0.08 + min(0.12, unsafe_total / 500.0)
        reason = "Repeated inspection provided no new information." if repeated else "Inspection revealed inventory evidence."
        reward = RewardSignal(
            value=round(value, 4),
            reason=reason,
            components={
                "inspection_value": round(value, 4),
            },
        )
        info = StepInfo(
            message=f"Inspected node {node_id} and collected node evidence.",
            action_type=action.type.value,
            reward_breakdown=reward.components,
        ).model_dump()
        info.update(
            {
                "node_id": node_id,
                "inventory": deepcopy(node["inventory"]),
                "quarantined_inventory": deepcopy(node["quarantined_inventory"]),
                "outbound_shipments": list(self.state_data["shipment_graph"].get(node_id, [])),
                "inspection_findings": {lot_id: item.model_dump() for lot_id, item in findings.items()},
            }
        )
        return reward, info

    def _handle_trace_lot(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]:
        """Trace a lot and every catalog lot sharing its root across all nodes.

        For each node, available and quarantined quantities of the matched
        lots are summed; nodes with a positive total are reported as affected.

        Raises:
            ValueError: If ``lot_id`` is missing.
        """
        lot_id = action.lot_id
        if not lot_id:
            raise ValueError("trace_lot action requires 'lot_id'.")

        root_lot = self._root_lot_for(lot_id)
        traced_lots = self._resolve_related_lots(lot_id)
        # Iterate in sorted order: set iteration order for strings can differ
        # between processes, which would make ``lots_by_node`` non-reproducible.
        matched_lots = sorted(traced_lots)

        impacted_nodes = []
        impacted_quantities = {}
        impacted_lots = {}
        discovered_nodes = 0
        for node_id, node_data in self.state_data["nodes"].items():
            node_total = 0
            node_lots = []
            for candidate_lot in matched_lots:
                available_qty = node_data["inventory"].get(candidate_lot, 0)
                quarantined_qty = node_data["quarantined_inventory"].get(candidate_lot, 0)
                total_qty = available_qty + quarantined_qty
                if total_qty > 0:
                    node_total += total_qty
                    node_lots.append(candidate_lot)
            if node_total > 0:
                impacted_nodes.append(node_id)
                impacted_quantities[node_id] = node_total
                impacted_lots[node_id] = node_lots
                # Only nodes not yet inspected count towards the discovery bonus.
                if node_id not in self.state_data["discovered_shipments"]:
                    discovered_nodes += 1

        self.state_data["traced_lots"][lot_id] = {
            "root_lot": root_lot,
            "matched_lots": matched_lots,
            "affected_nodes": impacted_nodes,
            "lots_by_node": impacted_lots,
            "quantities_by_node": impacted_quantities,
        }
        self._record_history(f"Traced lot {lot_id} across {', '.join(matched_lots)}")

        if not impacted_nodes:
            reward_value = -0.1
            reason = "Trace returned no impacted nodes."
        elif root_lot in self.ground_truth["affected_roots"]:
            reward_value = 0.12 + min(0.13, discovered_nodes * 0.03 + len(traced_lots) * 0.02)
            reason = "Trace identified the affected lineage across the network."
        else:
            reward_value = 0.02
            reason = "Trace ran, but the lot is outside the affected lineage."

        reward = RewardSignal(
            value=round(reward_value, 4),
            reason=reason,
            components={
                "trace_value": round(reward_value, 4),
            },
        )
        info = StepInfo(
            message=f"Traced lot {lot_id} across the shipment network.",
            action_type=action.type.value,
            reward_breakdown=reward.components,
        ).model_dump()
        info.update(
            {
                "lot_id": lot_id,
                "root_lot": root_lot,
                "matched_lots": list(matched_lots),
                "affected_nodes": impacted_nodes,
                "lots_by_node": impacted_lots,
                "quantities_by_node": impacted_quantities,
                "total_quantity": sum(impacted_quantities.values()),
            }
        )
        return reward, info

    def _handle_quarantine(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]:
        """Move stock of a lot at a node from inventory into quarantine.

        A missing or zero ``quantity`` quarantines all available stock; a
        larger request is clamped to what is available. The reward compares
        the cumulative quarantined amount with the ground-truth unsafe amount.

        Raises:
            ValueError: If ``node_id`` is missing or unknown, ``lot_id`` is
                missing, or ``quantity`` is negative.
        """
        node_id = self._require_node(action.node_id)
        lot_id = action.lot_id
        if not lot_id:
            raise ValueError("quarantine action requires 'lot_id'.")
        # A negative quantity would otherwise pass through ``min`` below and
        # move stock out of quarantine back into inventory.
        if action.quantity is not None and action.quantity < 0:
            raise ValueError("quarantine action requires a non-negative 'quantity'.")

        node = self.state_data["nodes"][node_id]
        available_qty = node["inventory"].get(lot_id, 0)
        if available_qty <= 0:
            reward = RewardSignal(
                value=-0.2,
                reason="Attempted to quarantine stock that is not available.",
                components={"invalid_quarantine": -0.2},
            )
            self._record_history(f"Failed quarantine for {lot_id} at {node_id}: no available stock")
            info = StepInfo(
                message="No available stock to quarantine.",
                action_type=action.type.value,
                reward_breakdown=reward.components,
            ).model_dump()
            info.update({"node_id": node_id, "lot_id": lot_id})
            return reward, info

        requested_qty = action.quantity or available_qty
        quarantined_qty = min(requested_qty, available_qty)
        node["inventory"][lot_id] = available_qty - quarantined_qty
        if node["inventory"][lot_id] == 0:
            # Drop emptied lots so inventory snapshots list only stock on hand.
            del node["inventory"][lot_id]
        node["quarantined_inventory"][lot_id] = node["quarantined_inventory"].get(lot_id, 0) + quarantined_qty
        self.state_data["quarantine_log"].append({"node_id": node_id, "lot_id": lot_id, "quantity": quarantined_qty})
        self._record_history(f"Quarantined {quarantined_qty} units of {lot_id} at {node_id}")

        correct_qty = self.ground_truth["correct_quantities"].get(node_id, {}).get(lot_id, 0)
        cumulative_quarantined = node["quarantined_inventory"].get(lot_id, 0)
        delta = cumulative_quarantined - correct_qty
        if correct_qty == 0:
            reward_value = -0.35
            reason = "Quarantined safe inventory outside the recall scope."
        elif delta == 0:
            reward_value = 0.28
            reason = "Quarantine exactly matched the unsafe quantity."
        elif delta < 0:
            reward_value = max(0.05, 0.22 * (cumulative_quarantined / correct_qty))
            reason = "Quarantine made partial progress but missed some unsafe stock."
        else:
            reward_value = max(-0.25, -0.08 * delta)
            reason = "Quarantine overreached and blocked safe inventory."

        reward = RewardSignal(
            value=round(reward_value, 4),
            reason=reason,
            components={
                "quarantine_value": round(reward_value, 4),
                "target_quantity": float(correct_qty),
                "quarantined_quantity": float(cumulative_quarantined),
            },
        )
        info = StepInfo(
            message=f"Updated quarantine for {lot_id} at {node_id}.",
            action_type=action.type.value,
            reward_breakdown=reward.components,
        ).model_dump()
        info.update(
            {
                "node_id": node_id,
                "lot_id": lot_id,
                "quarantined_quantity": quarantined_qty,
                "remaining_inventory": node["inventory"].get(lot_id, 0),
                "cumulative_quarantined": cumulative_quarantined,
                "target_contaminated_quantity": correct_qty,
            }
        )
        return reward, info

    def _handle_notify(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]:
        """Notify one node, or every node when the target is missing or "all".

        Newly notified affected nodes earn reward (capped); newly notified
        unaffected nodes cost a little; a notification that reaches nobody
        new is penalized.

        Raises:
            ValueError: If a specific ``node_id`` is given but unknown.
        """
        requested_target = action.node_id or "all"
        if requested_target in ("all", "all_nodes"):
            targets = list(self.state_data["nodes"].keys())
        else:
            targets = [self._require_node(requested_target)]

        newly_notified = []
        for node_id in targets:
            if node_id not in self.state_data["notified_nodes"]:
                self.state_data["notified_nodes"].add(node_id)
                newly_notified.append(node_id)

        affected_newly_notified = sum(1 for node_id in newly_notified if node_id in self.ground_truth["affected_nodes"])
        unaffected_newly_notified = len(newly_notified) - affected_newly_notified
        if not newly_notified:
            reward_value = -0.05
            reason = "Notification repeated without adding new recipients."
        else:
            reward_value = min(0.18, affected_newly_notified * 0.04) - unaffected_newly_notified * 0.01
            reason = "Notifications dispatched to downstream stakeholders."

        reward = RewardSignal(
            value=round(reward_value, 4),
            reason=reason,
            components={
                "notification_value": round(reward_value, 4),
            },
        )
        if newly_notified:
            self._record_history(f"Sent notifications to {', '.join(newly_notified)}")
        else:
            self._record_history("Notification action repeated without new recipients")

        info = StepInfo(
            message="Processed notification action.",
            action_type=action.type.value,
            reward_breakdown=reward.components,
        ).model_dump()
        info.update({"notified_nodes": targets, "newly_notified": newly_notified})
        return reward, info

    def _handle_finalize(self, action: RecallAction) -> tuple[RewardSignal, Dict[str, Any]]:
        """End the episode and score the overall recall response.

        The score is a weighted sum of quarantine accuracy (0.55),
        notification coverage (0.2), investigation coverage (0.15) and step
        efficiency (0.1).
        """
        del action  # finalize takes no parameters
        self.done = True

        quarantine_match = self._compute_quarantine_match()
        missing_quantity_total = sum(
            quantity
            for lot_quantities in quarantine_match["missing_quantities"].values()
            for quantity in lot_quantities.values()
        )
        over_quantity_total = sum(
            quantity
            for lot_quantities in quarantine_match["over_quarantined_quantities"].values()
            for quantity in lot_quantities.values()
        )

        # ``or 1`` keeps each ratio defined when its denominator would be zero.
        total_affected_quantity = self.ground_truth["total_affected_quantity"] or 1
        quarantine_score = max(0.0, 1.0 - ((missing_quantity_total + (1.25 * over_quantity_total)) / total_affected_quantity))

        notified_affected_nodes = set(self.ground_truth["affected_nodes"]).intersection(self.state_data["notified_nodes"])
        affected_node_total = len(self.ground_truth["affected_nodes"]) or 1
        notification_score = len(notified_affected_nodes) / affected_node_total

        investigated_nodes = set(self.state_data["inspected_nodes"]).intersection(self.ground_truth["affected_nodes"])
        investigation_score = len(investigated_nodes) / affected_node_total

        # Steps beyond a small allowance (at least 4) reduce the efficiency score.
        efficiency_penalty_steps = max(0, self.state_data["steps_taken"] - max(4, affected_node_total + 3))
        max_steps = self.state_data["max_steps"] or 1
        efficiency_score = max(0.0, 1.0 - (efficiency_penalty_steps / max_steps))

        score = round(
            (0.55 * quarantine_score) + (0.2 * notification_score) + (0.15 * investigation_score) + (0.1 * efficiency_score),
            4,
        )
        reward = RewardSignal(
            value=score,
            reason="Final recall response scored.",
            components={
                "quarantine_score": round(quarantine_score, 4),
                "notification_score": round(notification_score, 4),
                "investigation_score": round(investigation_score, 4),
                "efficiency_score": round(efficiency_score, 4),
            },
        )
        self._record_history("Finalized recall response")

        info = StepInfo(
            message="Finalized recall response.",
            action_type="finalize",
            score=score,
            reward_breakdown=reward.components,
        ).model_dump()
        info.update(
            {
                "score": score,
                "quarantine_score": round(quarantine_score, 4),
                "notification_score": round(notification_score, 4),
                "investigation_score": round(investigation_score, 4),
                "efficiency_score": round(efficiency_score, 4),
                "all_affected_nodes_notified": notification_score == 1.0,
                "all_affected_stock_quarantined": missing_quantity_total == 0 and over_quantity_total == 0,
                "quarantine_match": quarantine_match,
            }
        )
        return reward, info

    def _build_ground_truth(self, scenario: Dict[str, Any]) -> Dict[str, Any]:
        """Derive the grading reference from a scenario.

        Affected nodes, lots and target quantities come from inspection
        findings with a positive ``unsafe_quantity``; affected roots come from
        catalog entries flagged ``contaminated``.
        """
        contaminated_roots = {
            self._root_lot_for(lot_id, scenario["lot_catalog"])
            for lot_id, lot_data in scenario["lot_catalog"].items()
            if lot_data.get("contaminated", False)
        }

        correct_quantities: Dict[str, Dict[str, int]] = {}
        affected_nodes = set()
        affected_lots = set()
        for node_id, node_data in scenario["nodes"].items():
            for lot_id, finding in node_data.get("inspection_findings", {}).items():
                unsafe_quantity = int(finding.get("unsafe_quantity", 0))
                if unsafe_quantity > 0:
                    affected_nodes.add(node_id)
                    affected_lots.add(lot_id)
                    correct_quantities.setdefault(node_id, {})[lot_id] = unsafe_quantity

        total_affected_quantity = sum(
            quantity
            for node_quantities in correct_quantities.values()
            for quantity in node_quantities.values()
        )
        return {
            "affected_lots": sorted(affected_lots),
            "affected_nodes": sorted(affected_nodes),
            "affected_roots": sorted(contaminated_roots),
            "correct_quantities": correct_quantities,
            "total_affected_quantity": total_affected_quantity,
        }

    def _compute_quarantine_match(self) -> Dict[str, Any]:
        """Compare quarantined stock with the ground-truth unsafe quantities.

        Returns:
            A dict with ``missing_quantities`` and
            ``over_quarantined_quantities``, each mapping node id to a
            ``{lot_id: quantity}`` dict of shortfalls or excesses.
        """
        missing_quantities: Dict[str, Dict[str, int]] = {}
        over_quarantined_quantities: Dict[str, Dict[str, int]] = {}
        for node_id, node_data in self.state_data["nodes"].items():
            expected = self.ground_truth["correct_quantities"].get(node_id, {})
            actual = node_data["quarantined_inventory"]
            relevant_lots = set(expected) | set(actual)
            for lot_id in relevant_lots:
                expected_qty = expected.get(lot_id, 0)
                actual_qty = actual.get(lot_id, 0)
                if actual_qty < expected_qty:
                    missing_quantities.setdefault(node_id, {})[lot_id] = expected_qty - actual_qty
                elif actual_qty > expected_qty:
                    over_quarantined_quantities.setdefault(node_id, {})[lot_id] = actual_qty - expected_qty
        return {
            "missing_quantities": missing_quantities,
            "over_quarantined_quantities": over_quarantined_quantities,
        }

    def _inventory_snapshot(self) -> Dict[str, Dict[str, int]]:
        """Return a copy of every node's available inventory."""
        return {node_id: deepcopy(node_data["inventory"]) for node_id, node_data in self.state_data["nodes"].items()}

    def _quarantine_snapshot(self) -> Dict[str, Dict[str, int]]:
        """Return a copy of quarantined inventory for nodes that have any."""
        return {
            node_id: deepcopy(node_data["quarantined_inventory"])
            for node_id, node_data in self.state_data["nodes"].items()
            if node_data["quarantined_inventory"]
        }

    def _resolve_related_lots(self, lot_id: str) -> set[str]:
        """Return the catalog lots that share ``lot_id``'s root, plus ``lot_id`` itself if catalogued."""
        root_lot = self._root_lot_for(lot_id)
        return {
            candidate_lot
            for candidate_lot in self.state_data["lot_catalog"].keys()
            if self._root_lot_for(candidate_lot) == root_lot or candidate_lot == lot_id
        }

    def _root_lot_for(self, lot_id: str, lot_catalog: Dict[str, Dict[str, Any]] | None = None) -> str:
        """Return the root lot recorded for ``lot_id``, or ``lot_id`` itself.

        Args:
            lot_id: Lot to look up.
            lot_catalog: Catalog to consult. Only ``None`` falls back to the
                catalog in the current state; an explicitly passed empty
                catalog is honoured as given.
        """
        catalog = lot_catalog if lot_catalog is not None else self.state_data.get("lot_catalog", {})
        if lot_id not in catalog:
            return lot_id
        return catalog[lot_id].get("root_lot", lot_id)

    def _build_task_definition(self, scenario: Dict[str, Any]) -> TaskDefinition:
        """Build the ``TaskDefinition`` from the scenario's descriptive fields."""
        return TaskDefinition(
            task_id=scenario["task_id"],
            name=scenario["name"],
            difficulty=scenario["difficulty"],
            objective=scenario["objective"],
            max_steps=scenario["max_steps"],
        )

    def _require_node(self, node_id: str | None) -> str:
        """Return ``node_id`` if it names a node in the current state.

        Raises:
            ValueError: If ``node_id`` is empty or not a known node.
        """
        if not node_id:
            raise ValueError("Action requires 'node_id'.")
        if node_id not in self.state_data["nodes"]:
            raise ValueError(f"Unknown node_id '{node_id}'.")
        return node_id

    def _record_history(self, message: str) -> None:
        """Append a message to the episode history."""
        self.state_data["history"].append(message)

    def _serialize_state(self, value: Any) -> Any:
        """Recursively convert state into plain containers.

        Dicts and lists are rebuilt, sets become sorted lists, and objects
        exposing ``model_dump`` are dumped; anything else is returned as-is.
        """
        if isinstance(value, dict):
            return {key: self._serialize_state(item) for key, item in value.items()}
        if isinstance(value, set):
            return sorted(value)
        if isinstance(value, list):
            return [self._serialize_state(item) for item in value]
        if hasattr(value, "model_dump"):
            return value.model_dump()
        return value