""" Submission inference runner. Requirements covered: - Script name is inference.py in repo root. - Uses OpenAI client for model calls. - Uses internal Groq + Llama 8B defaults (overridable via environment). - Emits structured stdout logs with [START], [STEP], [END]. """ import argparse from copy import deepcopy import json import os import pickle import re import sys from typing import Any, Dict, List, Optional from openai import OpenAI from environment import Action, ActionType, CabinClass, FlightRebookingEnv, PriorityTier from ml_policy import choose_action_from_ranked_types, observation_to_features from tasks import TASKS, grade_task SYSTEM_PROMPT = """You are an airline disruption operations agent. Return exactly one JSON object on each turn with this schema: { \"action_type\": \"rebook_passenger\" | \"offer_downgrade\" | \"book_hotel\" | \"rebook_on_partner\" | \"mark_no_solution\" | \"finalize\", \"passenger_id\": \"optional passenger id\", \"flight_id\": \"optional flight id\" } Policy: - Process one pending passenger per step. - Respect tiers (Platinum > Gold > Silver > Standard). - Prefer earlier departures for deadline passengers. - Prefer same-airline rebooking over partner when feasible. - Minimize budget usage. - Output raw JSON only. """ DEFAULT_API_BASE_URL = "https://api.groq.com/openai/v1" DEFAULT_LLM_MODEL = "llama-3.1-8b-instant" INTERNAL_GROQ_API_KEY = "" BENCHMARK_NAME = os.getenv("BENCHMARK", "flight-rebooking-openenv") SUCCESS_SCORE_THRESHOLD = 0.1 GIT_LFS_POINTER_HEADER = "version https://git-lfs.github.com/spec/v1" def _first_non_empty(*values: str) -> str: for value in values: cleaned = (value or "").strip() if cleaned: return cleaned return "" def _resolve_model_config() -> Dict[str, str]: api_base_url = _first_non_empty( os.getenv("API_BASE_URL", ""), os.getenv("OPENAI_BASE_URL", ""), DEFAULT_API_BASE_URL, ) model_name = _first_non_empty( os.getenv("MODEL_NAME", ""), os.getenv("OPENAI_MODEL", ""), DEFAULT_LLM_MODEL, ) api_key = _first_non_empty( os.getenv("GROQ_API_KEY", ""), os.getenv("HF_TOKEN", ""), os.getenv("OPENAI_API_KEY", ""), INTERNAL_GROQ_API_KEY, ) if not api_key: raise SystemExit( "No API key configured. Set GROQ_API_KEY (preferred), OPENAI_API_KEY, or HF_TOKEN." ) return { "api_base_url": api_base_url, "model_name": model_name, "api_key": api_key, } def _load_ml_policy_artifact(path: str) -> Optional[Dict[str, Any]]: if not path: return None if not os.path.exists(path): return None try: with open(path, "rb") as handle: artifact = pickle.load(handle) except Exception as exc: print(f"[WARN] Failed to load ML policy artifact at {path}: {exc}", file=sys.stderr) return None if not isinstance(artifact, dict) or "model" not in artifact: print(f"[WARN] Invalid ML policy artifact format at {path}; ignoring.", file=sys.stderr) return None return artifact def _is_git_lfs_pointer_file(path: str) -> bool: try: with open(path, "r", encoding="utf-8") as handle: lines = [handle.readline().strip() for _ in range(3)] except (UnicodeDecodeError, OSError): return False if not lines or lines[0] != GIT_LFS_POINTER_HEADER: return False return any(line.startswith("oid sha256:") for line in lines[1:]) def _ml_policy_fix_instructions(path: str) -> str: return ( "Fix options:\n" "1) Materialize artifact bytes with Git LFS (if this repo stores models in LFS):\n" f" git lfs pull --include \"{path}\"\n" "2) Regenerate the artifact locally:\n" " python train_ml_policy.py --episodes-per-task 450 --seed 42 --output artifacts/ml_policy.pkl --report artifacts/ml_policy_report.json" ) def _require_ml_policy_artifact(path: str, policy_name: str) -> Dict[str, Any]: if not path: raise SystemExit( f"Policy '{policy_name}' requires --ml-policy-path.\n" + _ml_policy_fix_instructions("artifacts/ml_policy.pkl") ) if not os.path.exists(path): raise SystemExit( f"Policy '{policy_name}' requires an ML artifact, but '{path}' was not found.\n" + _ml_policy_fix_instructions(path) ) if _is_git_lfs_pointer_file(path): raise SystemExit( f"Policy '{policy_name}' cannot run because '{path}' is a Git LFS pointer, not a pickle artifact.\n" + _ml_policy_fix_instructions(path) ) artifact = _load_ml_policy_artifact(path) if artifact is None: raise SystemExit( f"Policy '{policy_name}' requires a valid ML artifact, but '{path}' could not be loaded as a pickle.\n" + _ml_policy_fix_instructions(path) ) return artifact def _rank_action_types_from_model(model: Any, features: List[float]) -> List[str]: ranked: List[str] if hasattr(model, "predict_proba") and hasattr(model, "classes_"): probabilities = model.predict_proba([features])[0] classes = [str(cls) for cls in model.classes_] ranked = [ label for _, label in sorted( zip(probabilities, classes), key=lambda item: item[0], reverse=True, ) ] else: ranked = [str(model.predict([features])[0])] for action_type in ( ActionType.REBOOK_PASSENGER.value, ActionType.OFFER_DOWNGRADE.value, ActionType.REBOOK_ON_PARTNER.value, ActionType.BOOK_HOTEL.value, ActionType.MARK_NO_SOLUTION.value, ActionType.FINALIZE.value, ): if action_type not in ranked: ranked.append(action_type) return ranked def _predict_ml_policy_action(observation: Dict[str, Any], ml_policy_artifact: Dict[str, Any]) -> Dict[str, Any]: model = ml_policy_artifact["model"] features = observation_to_features(observation) ranked_action_types = _rank_action_types_from_model(model, features) return choose_action_from_ranked_types(observation, ranked_action_types) def _predict_ml_ranked_action_types(observation: Dict[str, Any], ml_policy_artifact: Dict[str, Any]) -> List[str]: model = ml_policy_artifact["model"] features = observation_to_features(observation) return _rank_action_types_from_model(model, features) def _feasible_actions_from_observation(observation: Dict[str, Any]) -> List[Action]: pending = list(observation.get("pending_passengers", [])) flights = list(observation.get("available_flights", [])) budget_remaining = float(observation.get("budget_remaining", 0.0)) if not pending: return [Action(action_type=ActionType.FINALIZE)] actions: List[Action] = [] for passenger in pending: for flight in flights: if (not flight.get("is_partner", False)) and _has_seat(flight, str(passenger.get("cabin_class", ""))): actions.append( Action( action_type=ActionType.REBOOK_PASSENGER, passenger_id=passenger["id"], flight_id=flight["id"], ) ) if ( passenger.get("cabin_class") == CabinClass.BUSINESS.value and (not flight.get("is_partner", False)) and int(flight.get("economy_seats", 0)) > 0 and budget_remaining >= 500.0 ): actions.append( Action( action_type=ActionType.OFFER_DOWNGRADE, passenger_id=passenger["id"], flight_id=flight["id"], ) ) if ( flight.get("is_partner", False) and _has_seat(flight, str(passenger.get("cabin_class", ""))) and budget_remaining >= 800.0 ): actions.append( Action( action_type=ActionType.REBOOK_ON_PARTNER, passenger_id=passenger["id"], flight_id=flight["id"], ) ) if budget_remaining >= 250.0: actions.append( Action( action_type=ActionType.BOOK_HOTEL, passenger_id=passenger["id"], ) ) actions.append( Action( action_type=ActionType.MARK_NO_SOLUTION, passenger_id=passenger["id"], ) ) actions.append(Action(action_type=ActionType.FINALIZE)) return actions def _action_cost(action_type: ActionType) -> float: return { ActionType.REBOOK_PASSENGER: 0.0, ActionType.OFFER_DOWNGRADE: 500.0, ActionType.BOOK_HOTEL: 250.0, ActionType.REBOOK_ON_PARTNER: 800.0, ActionType.MARK_NO_SOLUTION: 0.0, ActionType.FINALIZE: 0.0, }.get(action_type, 0.0) def _action_priority_score(observation: Dict[str, Any], action: Action) -> float: pending = list(observation.get("pending_passengers", [])) if action.action_type == ActionType.FINALIZE: return 10.0 if not pending else -10.0 pending_by_id = {p["id"]: p for p in pending} flights_by_id = {f["id"]: f for f in observation.get("available_flights", [])} passenger = pending_by_id.get(action.passenger_id or "") if passenger is None: return -100.0 tier_component = _tier_weight(str(passenger.get("priority_tier", ""))) / 4.0 deadline = passenger.get("connection_deadline_hrs") if deadline is None: deadline_component = 0.0 else: deadline_component = (12.0 - min(max(float(deadline), 0.0), 12.0)) / 12.0 score = (0.65 * tier_component) + (0.35 * deadline_component) type_bonus = { ActionType.REBOOK_PASSENGER: 0.60, ActionType.OFFER_DOWNGRADE: 0.30, ActionType.REBOOK_ON_PARTNER: 0.18, ActionType.BOOK_HOTEL: 0.10, ActionType.MARK_NO_SOLUTION: -0.60, ActionType.FINALIZE: 0.0, }[action.action_type] score += type_bonus if action.flight_id: flight = flights_by_id.get(action.flight_id) if flight is not None and deadline is not None: departure = float(flight.get("departure_hrs", 99.0)) if departure <= float(deadline): score += 0.22 else: score -= 0.22 budget_remaining = float(observation.get("budget_remaining", 0.0)) budget_spent = float(observation.get("budget_spent", 0.0)) budget_total = max(budget_remaining + budget_spent, 1.0) score -= 0.35 * min(_action_cost(action.action_type) / budget_total, 1.0) return score def _prune_candidate_actions( observation: Dict[str, Any], actions: List[Action], ranked_action_types: Optional[List[str]], max_candidates: int, ) -> List[Action]: deduped: List[Action] = [] seen = set() for action in actions: signature = (action.action_type.value, action.passenger_id, action.flight_id) if signature in seen: continue seen.add(signature) deduped.append(action) rank_index: Dict[str, int] = {} if ranked_action_types: rank_index = {action_type: idx for idx, action_type in enumerate(ranked_action_types)} deduped.sort( key=lambda action: ( rank_index.get(action.action_type.value, 999), -_action_priority_score(observation, action), ) ) return deduped[: max(1, max_candidates)] def _rollout_heuristic_to_end(env: FlightRebookingEnv) -> None: done = False while not done: observation = env._get_observation().model_dump(mode="json") action = Action(**_heuristic_action(observation)) _, _, done, _ = env.step(action) def _evaluate_state_with_lookahead( env: FlightRebookingEnv, task_key: str, lookahead_depth: int, lookahead_width: int, ranked_action_types: Optional[List[str]] = None, ) -> float: observation = env._get_observation().model_dump(mode="json") candidate_actions = _feasible_actions_from_observation(observation) if ranked_action_types: preferred_types = set(ranked_action_types[:5]) preferred_types.add(ActionType.FINALIZE.value) preferred_types.add(ActionType.MARK_NO_SOLUTION.value) preferred_candidates = [a for a in candidate_actions if a.action_type.value in preferred_types] if preferred_candidates: candidate_actions = preferred_candidates candidate_actions = _prune_candidate_actions( observation=observation, actions=candidate_actions, ranked_action_types=ranked_action_types, max_candidates=lookahead_width, ) best_score = -1.0 for action in candidate_actions: env_copy = deepcopy(env) _, _, done, _ = env_copy.step(action) if done: score = float(grade_task(task_key, env_copy.state(), TASKS[task_key]["max_budget"])) elif lookahead_depth <= 1: _rollout_heuristic_to_end(env_copy) score = float(grade_task(task_key, env_copy.state(), TASKS[task_key]["max_budget"])) else: score = _evaluate_state_with_lookahead( env=env_copy, task_key=task_key, lookahead_depth=lookahead_depth - 1, lookahead_width=lookahead_width, ranked_action_types=None, ) if score > best_score: best_score = score if best_score >= 0.0: return best_score env_fallback = deepcopy(env) _rollout_heuristic_to_end(env_fallback) return float(grade_task(task_key, env_fallback.state(), TASKS[task_key]["max_budget"])) def _projected_score_for_action( env: FlightRebookingEnv, task_key: str, action: Action, lookahead_depth: int, lookahead_width: int, ) -> float: env_copy = deepcopy(env) _, _, done, _ = env_copy.step(action) if done: return float(grade_task(task_key, env_copy.state(), TASKS[task_key]["max_budget"])) if lookahead_depth <= 1: _rollout_heuristic_to_end(env_copy) return float(grade_task(task_key, env_copy.state(), TASKS[task_key]["max_budget"])) return _evaluate_state_with_lookahead( env=env_copy, task_key=task_key, lookahead_depth=lookahead_depth - 1, lookahead_width=lookahead_width, ranked_action_types=None, ) def _choose_lookahead_action( env: FlightRebookingEnv, task_key: str, lookahead_depth: int, lookahead_width: int, ranked_action_types: Optional[List[str]] = None, ) -> Dict[str, Any]: observation = env._get_observation().model_dump(mode="json") candidate_actions = _feasible_actions_from_observation(observation) if ranked_action_types: preferred_types = set(ranked_action_types[:5]) preferred_types.add(ActionType.FINALIZE.value) preferred_types.add(ActionType.MARK_NO_SOLUTION.value) preferred_candidates = [a for a in candidate_actions if a.action_type.value in preferred_types] if preferred_candidates: candidate_actions = preferred_candidates best_action = candidate_actions[0] best_score = -1.0 for action in candidate_actions: try: projected_score = _projected_score_for_action( env=env, task_key=task_key, action=action, lookahead_depth=lookahead_depth, lookahead_width=lookahead_width, ) except Exception: continue if projected_score > best_score: best_score = projected_score best_action = action return best_action.model_dump(mode="json") def _pick_best_payload_by_projection( env: FlightRebookingEnv, task_key: str, payloads: List[Dict[str, Any]], lookahead_depth: int, lookahead_width: int, ) -> Dict[str, Any]: best_payload = payloads[0] best_score = -1.0 seen_signatures = set() for payload in payloads: try: action = Action(**payload) except Exception: continue signature = (action.action_type.value, action.passenger_id, action.flight_id) if signature in seen_signatures: continue seen_signatures.add(signature) try: projected_score = _projected_score_for_action( env=env, task_key=task_key, action=action, lookahead_depth=lookahead_depth, lookahead_width=lookahead_width, ) except Exception: continue if projected_score > best_score: best_score = projected_score best_payload = action.model_dump(mode="json") return best_payload def _extract_json(text: str) -> Dict[str, Any]: text = (text or "").strip() try: return json.loads(text) except json.JSONDecodeError: pass fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL) if fenced: return json.loads(fenced.group(1)) inline = re.search(r"\{.*\}", text, re.DOTALL) if inline: return json.loads(inline.group(0)) raise ValueError("No valid JSON action in model output") def _tier_weight(tier: str) -> int: return { PriorityTier.PLATINUM.value: 4, PriorityTier.GOLD.value: 3, PriorityTier.SILVER.value: 2, PriorityTier.STANDARD.value: 1, }.get(tier, 1) def _has_seat(flight: Dict[str, Any], cabin_class: str) -> bool: if cabin_class == CabinClass.BUSINESS.value: return flight["business_seats"] > 0 return flight["economy_seats"] > 0 def _heuristic_action(observation: Dict[str, Any]) -> Dict[str, Any]: pending = list(observation["pending_passengers"]) if not pending: return {"action_type": ActionType.FINALIZE.value} pending.sort( key=lambda p: ( -_tier_weight(p["priority_tier"]), p["connection_deadline_hrs"] if p["connection_deadline_hrs"] is not None else 10**9, ) ) passenger = pending[0] flights = sorted(observation["available_flights"], key=lambda f: f["departure_hrs"]) for flight in flights: if flight["is_partner"]: continue if _has_seat(flight, passenger["cabin_class"]): return { "action_type": ActionType.REBOOK_PASSENGER.value, "passenger_id": passenger["id"], "flight_id": flight["id"], } if passenger["cabin_class"] == CabinClass.BUSINESS.value: for flight in flights: if flight["is_partner"]: continue if flight["economy_seats"] > 0 and observation["budget_remaining"] >= 500: return { "action_type": ActionType.OFFER_DOWNGRADE.value, "passenger_id": passenger["id"], "flight_id": flight["id"], } for flight in flights: if not flight["is_partner"]: continue if _has_seat(flight, passenger["cabin_class"]) and observation["budget_remaining"] >= 800: return { "action_type": ActionType.REBOOK_ON_PARTNER.value, "passenger_id": passenger["id"], "flight_id": flight["id"], } if observation["budget_remaining"] >= 250: return { "action_type": ActionType.BOOK_HOTEL.value, "passenger_id": passenger["id"], } return { "action_type": ActionType.MARK_NO_SOLUTION.value, "passenger_id": passenger["id"], } def _is_action_feasible(observation: Dict[str, Any], payload: Dict[str, Any]) -> bool: action_type = payload["action_type"] if action_type == ActionType.FINALIZE.value: return True pending_by_id = {p["id"]: p for p in observation["pending_passengers"]} flights_by_id = {f["id"]: f for f in observation["available_flights"]} budget_remaining = float(observation["budget_remaining"]) passenger = pending_by_id.get(payload.get("passenger_id")) if passenger is None: return False if action_type == ActionType.BOOK_HOTEL.value: return budget_remaining >= 250 if action_type == ActionType.MARK_NO_SOLUTION.value: return True flight = flights_by_id.get(payload.get("flight_id")) if flight is None: return False passenger_cabin = passenger["cabin_class"] needs_business = passenger_cabin == CabinClass.BUSINESS.value has_matching_cabin_seat = (flight["business_seats"] > 0) if needs_business else (flight["economy_seats"] > 0) if action_type == ActionType.REBOOK_PASSENGER.value: return (not flight["is_partner"]) and has_matching_cabin_seat if action_type == ActionType.OFFER_DOWNGRADE.value: return ( passenger_cabin == CabinClass.BUSINESS.value and budget_remaining >= 500 and flight["economy_seats"] > 0 ) if action_type == ActionType.REBOOK_ON_PARTNER.value: return flight["is_partner"] and budget_remaining >= 800 and has_matching_cabin_seat return False def _sanitize_action_payload(observation: Dict[str, Any], payload: Any) -> Dict[str, Any]: fallback = _heuristic_action(observation) if not isinstance(payload, dict): return fallback valid_action_types = {action_type.value for action_type in ActionType} action_type = str(payload.get("action_type", "")).strip() if action_type not in valid_action_types: return fallback sanitized: Dict[str, Any] = {"action_type": action_type} passenger_id = str(payload.get("passenger_id", "")).strip() flight_id = str(payload.get("flight_id", "")).strip() if passenger_id: sanitized["passenger_id"] = passenger_id if flight_id: sanitized["flight_id"] = flight_id if action_type == ActionType.FINALIZE.value: return sanitized pending_ids = {p["id"] for p in observation["pending_passengers"]} if sanitized.get("passenger_id") not in pending_ids: return fallback if action_type in { ActionType.REBOOK_PASSENGER.value, ActionType.OFFER_DOWNGRADE.value, ActionType.REBOOK_ON_PARTNER.value, }: flight_ids = {f["id"] for f in observation["available_flights"]} if sanitized.get("flight_id") not in flight_ids: return fallback if not _is_action_feasible(observation, sanitized): return fallback return sanitized def _query_openai_action( client: OpenAI, model_name: str, seed: int, observation_json: str, policy_hint_json: Optional[str] = None, max_retries: int = 2, ) -> Dict[str, Any]: last_error: Optional[Exception] = None for _ in range(max_retries + 1): try: user_content = f"Current observation: {observation_json}" if policy_hint_json: user_content += ( "\nSuggested safe action from a trained policy: " f"{policy_hint_json}" "\nPrefer this if it is valid for the current observation." ) kwargs: Dict[str, Any] = { "model": model_name, "messages": [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_content}, ], "temperature": 0, "top_p": 1, "max_tokens": 220, "seed": seed, } response = None try: response = client.chat.completions.create(**kwargs) except TypeError: kwargs.pop("seed", None) response = client.chat.completions.create(**kwargs) content = response.choices[0].message.content or "" return _extract_json(content) except Exception as exc: last_error = exc raise RuntimeError(f"OpenAI call failed after retries: {last_error}") def _emit_start(task_name: str, benchmark: str, model_name: str) -> None: print(f"[START] task={task_name} env={benchmark} model={model_name}", flush=True) def _format_action_for_log(action: Action) -> str: payload = { "action_type": action.action_type.value, "passenger_id": action.passenger_id, "flight_id": action.flight_id, } return json.dumps(payload, separators=(",", ":"), ensure_ascii=True) def _emit_step( step_index: int, action_text: str, reward_value: float, done: bool, error: Optional[str], ) -> None: done_value = str(bool(done)).lower() error_value = error if error else "null" print( "[STEP] " f"step={step_index} " f"action={action_text} " f"reward={reward_value:.2f} " f"done={done_value} " f"error={error_value}", flush=True, ) def _emit_end(success: bool, steps: int, score: float, rewards: List[float]) -> None: rewards_text = ",".join(f"{value:.2f}" for value in rewards) success_value = str(bool(success)).lower() print( "[END] " f"success={success_value} " f"steps={steps} " f"score={score:.4f} " f"rewards={rewards_text}", flush=True, ) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Run submission inference across OpenEnv tasks.") parser.add_argument("--task", choices=["all", "easy", "medium", "hard"], default="all") parser.add_argument("--seed", type=int, default=int(os.getenv("BASELINE_SEED", "42"))) parser.add_argument( "--policy", choices=["openai", "heuristic", "trained_ml", "openai_trained"], default="openai_trained", help=( "Policy backend. openai_trained uses Llama with trained-policy hints; " "trained_ml uses the learned policy directly; openai and heuristic remain available." ), ) parser.add_argument( "--ml-policy-path", default=os.getenv("ML_POLICY_PATH", "artifacts/ml_policy.pkl"), help="Path to trained ML policy artifact used by trained_ml/openai_trained modes.", ) parser.add_argument( "--lookahead-depth", type=int, default=int(os.getenv("LOOKAHEAD_DEPTH", "2")), help="Lookahead depth for projected action scoring (>=1).", ) parser.add_argument( "--lookahead-width", type=int, default=int(os.getenv("LOOKAHEAD_WIDTH", "12")), help="Maximum candidate actions explored per lookahead level (>=1).", ) parser.add_argument("--json-out", default="", help="Optional JSON output path.") return parser.parse_args() def main() -> None: args = parse_args() args.lookahead_depth = max(1, int(args.lookahead_depth)) args.lookahead_width = max(1, int(args.lookahead_width)) task_keys = ["easy", "medium", "hard"] if args.task == "all" else [args.task] effective_policy = args.policy ml_policy_artifact: Optional[Dict[str, Any]] = None if effective_policy in {"trained_ml", "openai_trained"}: ml_policy_artifact = _require_ml_policy_artifact(args.ml_policy_path, effective_policy) api_base_url = "heuristic" model_name = "heuristic" client: Optional[OpenAI] = None if effective_policy in {"openai", "openai_trained"}: model_config = _resolve_model_config() api_base_url = model_config["api_base_url"] model_name = model_config["model_name"] client = OpenAI(api_key=model_config["api_key"], base_url=api_base_url) results: List[Dict[str, Any]] = [] for task_key in task_keys: task_data = TASKS[task_key] _emit_start(task_name=task_data["task_id"], benchmark=BENCHMARK_NAME, model_name=model_name) env = FlightRebookingEnv(task_data=task_data) observation = None done = False steps = 0 rewards: List[float] = [] score = 0.01 success = False episode_error: Optional[str] = None try: observation = env.reset() while not done: observation_dict = observation.model_dump(mode="json") if effective_policy in {"openai", "openai_trained"}: assert client is not None policy_hint_payload: Optional[Dict[str, Any]] = None if effective_policy == "openai_trained": assert ml_policy_artifact is not None ranked_types = _predict_ml_ranked_action_types(observation_dict, ml_policy_artifact) policy_hint_payload = _choose_lookahead_action( env=env, task_key=task_key, lookahead_depth=args.lookahead_depth, lookahead_width=args.lookahead_width, ranked_action_types=ranked_types, ) raw_payload = _query_openai_action( client=client, model_name=model_name, seed=args.seed, observation_json=observation.model_dump_json(), policy_hint_json=(json.dumps(policy_hint_payload) if policy_hint_payload is not None else None), ) llm_payload = _sanitize_action_payload(observation_dict, raw_payload) if effective_policy == "openai_trained" and policy_hint_payload is not None: action_payload = _pick_best_payload_by_projection( env=env, task_key=task_key, payloads=[policy_hint_payload, llm_payload], lookahead_depth=args.lookahead_depth, lookahead_width=args.lookahead_width, ) else: action_payload = llm_payload elif effective_policy == "trained_ml": assert ml_policy_artifact is not None ranked_types = _predict_ml_ranked_action_types(observation_dict, ml_policy_artifact) action_payload = _choose_lookahead_action( env=env, task_key=task_key, lookahead_depth=args.lookahead_depth, lookahead_width=args.lookahead_width, ranked_action_types=ranked_types, ) else: action_payload = _heuristic_action(observation_dict) try: action = Action(**action_payload) except Exception: action = Action(action_type=ActionType.FINALIZE) step_error: Optional[str] = None reward_value = 0.0 try: observation, reward, done, info = env.step(action) reward_value = float(reward.value) if isinstance(info, dict) and info.get("error"): step_error = str(info.get("error")) except Exception as exc: done = True step_error = str(exc) episode_error = step_error steps += 1 rewards.append(reward_value) _emit_step( step_index=steps, action_text=_format_action_for_log(action), reward_value=reward_value, done=done, error=step_error, ) try: final_state = env.state() score = float(grade_task(task_key, final_state, task_data["max_budget"])) except Exception as exc: episode_error = str(exc) score = 0.01 except Exception as exc: episode_error = str(exc) score = 0.01 finally: close_fn = getattr(env, "close", None) if callable(close_fn): try: close_fn() except Exception as exc: if not episode_error: episode_error = str(exc) success = (episode_error is None) and (0.0 <= score <= 1.0) and (score >= SUCCESS_SCORE_THRESHOLD) _emit_end(success=success, steps=steps, score=score, rewards=rewards) try: final_state = env.state() avg_step_reward = sum(rewards) / max(len(rewards), 1) results.append( { "task": task_key, "task_id": task_data["task_id"], "difficulty": task_data["difficulty"], "steps": steps, "avg_step_reward": round(avg_step_reward, 4), "score": round(score, 4), "budget_spent": round(final_state.budget_spent, 2), "budget_max": task_data["max_budget"], "invalid_actions": final_state.invalid_actions, "success": success, "error": episode_error, } ) except Exception: avg_step_reward = sum(rewards) / max(len(rewards), 1) results.append( { "task": task_key, "task_id": task_data["task_id"], "difficulty": task_data["difficulty"], "steps": steps, "avg_step_reward": round(avg_step_reward, 4), "score": round(score, 4), "budget_spent": None, "budget_max": task_data["max_budget"], "invalid_actions": None, "success": success, "error": episode_error, } ) overall = sum(item["score"] for item in results) / max(len(results), 1) if args.json_out: payload = { "policy_requested": args.policy, "policy_effective": effective_policy, "seed": args.seed, "api_base_url": api_base_url, "model_name": model_name, "ml_policy_path": args.ml_policy_path, "ml_policy_loaded": ml_policy_artifact is not None, "lookahead_depth": args.lookahead_depth, "lookahead_width": args.lookahead_width, "overall_score": round(overall, 4), "tasks": results, } with open(args.json_out, "w", encoding="utf-8") as handle: json.dump(payload, handle, indent=2) if __name__ == "__main__": main()