| """Heuristic baseline policy for RecallTrace.""" | |
| from __future__ import annotations | |
| import json | |
| import re | |
| from typing import Any, Dict, Optional | |
| from openai import OpenAI | |
| from env.models import RecallAction, RecallObservation | |
| LOT_PATTERN = re.compile(r"\bLot[A-Za-z0-9_]+\b") | |
| def _extract_root_lot(observation: RecallObservation) -> str: | |
| match = LOT_PATTERN.search(observation.recall_notice) | |
| return match.group(0) if match else "LotA" | |
| def choose_heuristic_action(observation: RecallObservation) -> RecallAction: | |
| """Choose the next deterministic action using only observable state.""" | |
| root_lot = _extract_root_lot(observation) | |
| trace_result = observation.trace_results.get(root_lot) | |
| if trace_result is None: | |
| return RecallAction(type="trace_lot", lot_id=root_lot, rationale="Map the recall lineage first.") | |
| affected_nodes = trace_result.get("affected_nodes", []) | |
| for node_id in affected_nodes: | |
| if node_id not in observation.inspected_nodes: | |
| return RecallAction(type="inspect_node", node_id=node_id, rationale="Collect local evidence before quarantining.") | |
| for node_id, findings in observation.inspection_results.items(): | |
| for lot_id, finding in findings.items(): | |
| unsafe_quantity = finding.unsafe_quantity | |
| quarantined_quantity = observation.quarantined_inventory.get(node_id, {}).get(lot_id, 0) | |
| available_quantity = observation.inventory.get(node_id, {}).get(lot_id, 0) | |
| remaining_target = unsafe_quantity - quarantined_quantity | |
| if remaining_target > 0 and available_quantity > 0: | |
| return RecallAction( | |
| type="quarantine", | |
| node_id=node_id, | |
| lot_id=lot_id, | |
| quantity=min(remaining_target, available_quantity), | |
| rationale="Isolate the exact unsafe quantity discovered during inspection.", | |
| ) | |
| missing_notifications = [node_id for node_id in affected_nodes if node_id not in observation.notified_nodes] | |
| if missing_notifications: | |
| return RecallAction(type="notify", node_id="all", rationale="Alert every impacted stakeholder before closing the incident.") | |
| return RecallAction(type="finalize", rationale="Containment actions are complete.") | |
| def choose_llm_action( | |
| client: Optional[OpenAI], | |
| model_name: str, | |
| observation: RecallObservation, | |
| history: list[dict[str, Any]], | |
| ) -> Optional[RecallAction]: | |
| """Ask an LLM for the next action, returning None on failure.""" | |
| if client is None: | |
| return None | |
| prompt = { | |
| "task_id": observation.task_id, | |
| "phase": observation.phase, | |
| "notice": observation.recall_notice, | |
| "inventory": observation.inventory, | |
| "inspection_results": { | |
| node_id: {lot_id: evidence.model_dump() for lot_id, evidence in findings.items()} | |
| for node_id, findings in observation.inspection_results.items() | |
| }, | |
| "trace_results": observation.trace_results, | |
| "notified_nodes": observation.notified_nodes, | |
| "quarantined_inventory": observation.quarantined_inventory, | |
| "steps_taken": observation.steps_taken, | |
| "remaining_step_budget": observation.remaining_step_budget, | |
| "history": history[-6:], | |
| "instruction": "Return only compact JSON with keys type,node_id,lot_id,quantity,rationale. Use one valid action.", | |
| } | |
| try: | |
| completion = client.chat.completions.create( | |
| model=model_name, | |
| temperature=0, | |
| max_tokens=180, | |
| messages=[ | |
| {"role": "system", "content": "You are operating a deterministic product recall environment. Respond with only valid JSON for the next action."}, | |
| {"role": "user", "content": json.dumps(prompt, sort_keys=True)}, | |
| ], | |
| ) | |
| text = (completion.choices[0].message.content or "").strip() | |
| if not text: | |
| return None | |
| return RecallAction.model_validate_json(text) | |
| except Exception: | |
| return None | |