"""
Submission inference runner.
Requirements covered:
- Script name is inference.py in repo root.
- Uses OpenAI client for model calls.
- Uses internal Groq + Llama 8B defaults (overridable via environment).
- Emits structured stdout logs with [START], [STEP], [END].
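Example invocation:
    python inference.py --task all --policy openai_trained --json-out results.json
Log line formats (produced by the _emit_* helpers below):
    [START] task=<task_id> env=<benchmark> model=<model_name>
    [STEP] step=<n> action=<json> reward=<float> done=<true|false> error=<message|null>
    [END] success=<true|false> steps=<n> score=<float> rewards=<comma-separated floats>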
"""
import argparse
from copy import deepcopy
import json
import os
import pickle
import re
import sys
from typing import Any, Dict, List, Optional
from openai import OpenAI
from environment import Action, ActionType, CabinClass, FlightRebookingEnv, PriorityTier
from ml_policy import choose_action_from_ranked_types, observation_to_features
from tasks import TASKS, grade_task
SYSTEM_PROMPT = """You are an airline disruption operations agent.
Return exactly one JSON object on each turn with this schema:
{
\"action_type\": \"rebook_passenger\" | \"offer_downgrade\" | \"book_hotel\" | \"rebook_on_partner\" | \"mark_no_solution\" | \"finalize\",
\"passenger_id\": \"optional passenger id\",
\"flight_id\": \"optional flight id\"
}
Policy:
- Process one pending passenger per step.
- Respect tiers (Platinum > Gold > Silver > Standard).
- Prefer earlier departures for deadline passengers.
- Prefer same-airline rebooking over partner when feasible.
- Minimize budget usage.
- Output raw JSON only.
"""
DEFAULT_API_BASE_URL = "https://api.groq.com/openai/v1"
DEFAULT_LLM_MODEL = "llama-3.1-8b-instant"
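# Empty by default; supply a key via GROQ_API_KEY, OPENAI_API_KEY, or HF_TOKEN (resolved in _resolve_model_config).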
INTERNAL_GROQ_API_KEY = ""
BENCHMARK_NAME = os.getenv("BENCHMARK", "flight-rebooking-openenv")
SUCCESS_SCORE_THRESHOLD = 0.1
GIT_LFS_POINTER_HEADER = "version https://git-lfs.github.com/spec/v1"
def _first_non_empty(*values: str) -> str:
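    """Return the first argument that is non-empty after stripping whitespace, or an empty string."""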
for value in values:
cleaned = (value or "").strip()
if cleaned:
return cleaned
return ""
def _resolve_model_config() -> Dict[str, str]:
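    """Resolve base URL, model name, and API key from the environment, falling back to the Groq/Llama defaults; exit if no key is available."""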
api_base_url = _first_non_empty(
os.getenv("API_BASE_URL", ""),
os.getenv("OPENAI_BASE_URL", ""),
DEFAULT_API_BASE_URL,
)
model_name = _first_non_empty(
os.getenv("MODEL_NAME", ""),
os.getenv("OPENAI_MODEL", ""),
DEFAULT_LLM_MODEL,
)
api_key = _first_non_empty(
os.getenv("GROQ_API_KEY", ""),
os.getenv("HF_TOKEN", ""),
os.getenv("OPENAI_API_KEY", ""),
INTERNAL_GROQ_API_KEY,
)
if not api_key:
raise SystemExit(
"No API key configured. Set GROQ_API_KEY (preferred), OPENAI_API_KEY, or HF_TOKEN."
)
return {
"api_base_url": api_base_url,
"model_name": model_name,
"api_key": api_key,
}
def _load_ml_policy_artifact(path: str) -> Optional[Dict[str, Any]]:
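    """Load the pickled ML policy artifact, returning None if the path is unset, missing, unreadable, or not a dict with a 'model' key."""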
if not path:
return None
if not os.path.exists(path):
return None
try:
with open(path, "rb") as handle:
artifact = pickle.load(handle)
except Exception as exc:
print(f"[WARN] Failed to load ML policy artifact at {path}: {exc}", file=sys.stderr)
return None
if not isinstance(artifact, dict) or "model" not in artifact:
print(f"[WARN] Invalid ML policy artifact format at {path}; ignoring.", file=sys.stderr)
return None
return artifact
def _is_git_lfs_pointer_file(path: str) -> bool:
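    """Heuristically detect a Git LFS pointer file from its version header and 'oid sha256:' line."""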
try:
with open(path, "r", encoding="utf-8") as handle:
lines = [handle.readline().strip() for _ in range(3)]
except (UnicodeDecodeError, OSError):
return False
if not lines or lines[0] != GIT_LFS_POINTER_HEADER:
return False
return any(line.startswith("oid sha256:") for line in lines[1:])
def _ml_policy_fix_instructions(path: str) -> str:
return (
"Fix options:\n"
"1) Materialize artifact bytes with Git LFS (if this repo stores models in LFS):\n"
f" git lfs pull --include \"{path}\"\n"
"2) Regenerate the artifact locally:\n"
" python train_ml_policy.py --episodes-per-task 450 --seed 42 --output artifacts/ml_policy.pkl --report artifacts/ml_policy_report.json"
)
def _require_ml_policy_artifact(path: str, policy_name: str) -> Dict[str, Any]:
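    """Load the ML artifact or exit with fix instructions when the path is unset, missing, an LFS pointer, or not a valid pickle."""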
if not path:
raise SystemExit(
f"Policy '{policy_name}' requires --ml-policy-path.\n"
+ _ml_policy_fix_instructions("artifacts/ml_policy.pkl")
)
if not os.path.exists(path):
raise SystemExit(
f"Policy '{policy_name}' requires an ML artifact, but '{path}' was not found.\n"
+ _ml_policy_fix_instructions(path)
)
if _is_git_lfs_pointer_file(path):
raise SystemExit(
f"Policy '{policy_name}' cannot run because '{path}' is a Git LFS pointer, not a pickle artifact.\n"
+ _ml_policy_fix_instructions(path)
)
artifact = _load_ml_policy_artifact(path)
if artifact is None:
raise SystemExit(
f"Policy '{policy_name}' requires a valid ML artifact, but '{path}' could not be loaded as a pickle.\n"
+ _ml_policy_fix_instructions(path)
)
return artifact
def _rank_action_types_from_model(model: Any, features: List[float]) -> List[str]:
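    """Rank action types by predicted probability (or the single predicted label), then append any missing types in a fixed fallback order."""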
ranked: List[str]
if hasattr(model, "predict_proba") and hasattr(model, "classes_"):
probabilities = model.predict_proba([features])[0]
classes = [str(cls) for cls in model.classes_]
ranked = [
label
for _, label in sorted(
zip(probabilities, classes),
key=lambda item: item[0],
reverse=True,
)
]
else:
ranked = [str(model.predict([features])[0])]
for action_type in (
ActionType.REBOOK_PASSENGER.value,
ActionType.OFFER_DOWNGRADE.value,
ActionType.REBOOK_ON_PARTNER.value,
ActionType.BOOK_HOTEL.value,
ActionType.MARK_NO_SOLUTION.value,
ActionType.FINALIZE.value,
):
if action_type not in ranked:
ranked.append(action_type)
return ranked
def _predict_ml_policy_action(observation: Dict[str, Any], ml_policy_artifact: Dict[str, Any]) -> Dict[str, Any]:
model = ml_policy_artifact["model"]
features = observation_to_features(observation)
ranked_action_types = _rank_action_types_from_model(model, features)
return choose_action_from_ranked_types(observation, ranked_action_types)
def _predict_ml_ranked_action_types(observation: Dict[str, Any], ml_policy_artifact: Dict[str, Any]) -> List[str]:
model = ml_policy_artifact["model"]
features = observation_to_features(observation)
return _rank_action_types_from_model(model, features)
def _feasible_actions_from_observation(observation: Dict[str, Any]) -> List[Action]:
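    """Enumerate candidate actions for the observation (rebook, downgrade, partner rebook, hotel, no-solution per pending passenger, plus FINALIZE)."""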
pending = list(observation.get("pending_passengers", []))
flights = list(observation.get("available_flights", []))
budget_remaining = float(observation.get("budget_remaining", 0.0))
if not pending:
return [Action(action_type=ActionType.FINALIZE)]
actions: List[Action] = []
for passenger in pending:
for flight in flights:
if (not flight.get("is_partner", False)) and _has_seat(flight, str(passenger.get("cabin_class", ""))):
actions.append(
Action(
action_type=ActionType.REBOOK_PASSENGER,
passenger_id=passenger["id"],
flight_id=flight["id"],
)
)
if (
passenger.get("cabin_class") == CabinClass.BUSINESS.value
and (not flight.get("is_partner", False))
and int(flight.get("economy_seats", 0)) > 0
and budget_remaining >= 500.0
):
actions.append(
Action(
action_type=ActionType.OFFER_DOWNGRADE,
passenger_id=passenger["id"],
flight_id=flight["id"],
)
)
if (
flight.get("is_partner", False)
and _has_seat(flight, str(passenger.get("cabin_class", "")))
and budget_remaining >= 800.0
):
actions.append(
Action(
action_type=ActionType.REBOOK_ON_PARTNER,
passenger_id=passenger["id"],
flight_id=flight["id"],
)
)
if budget_remaining >= 250.0:
actions.append(
Action(
action_type=ActionType.BOOK_HOTEL,
passenger_id=passenger["id"],
)
)
actions.append(
Action(
action_type=ActionType.MARK_NO_SOLUTION,
passenger_id=passenger["id"],
)
)
actions.append(Action(action_type=ActionType.FINALIZE))
return actions
def _action_cost(action_type: ActionType) -> float:
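    """Budget cost assumed for each action type when penalizing candidates in _action_priority_score."""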
return {
ActionType.REBOOK_PASSENGER: 0.0,
ActionType.OFFER_DOWNGRADE: 500.0,
ActionType.BOOK_HOTEL: 250.0,
ActionType.REBOOK_ON_PARTNER: 800.0,
ActionType.MARK_NO_SOLUTION: 0.0,
ActionType.FINALIZE: 0.0,
}.get(action_type, 0.0)
def _action_priority_score(observation: Dict[str, Any], action: Action) -> float:
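    """Score an action by passenger tier, connection-deadline urgency, action-type preference, departure-vs-deadline fit, and budget impact."""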
pending = list(observation.get("pending_passengers", []))
if action.action_type == ActionType.FINALIZE:
return 10.0 if not pending else -10.0
pending_by_id = {p["id"]: p for p in pending}
flights_by_id = {f["id"]: f for f in observation.get("available_flights", [])}
passenger = pending_by_id.get(action.passenger_id or "")
if passenger is None:
return -100.0
tier_component = _tier_weight(str(passenger.get("priority_tier", ""))) / 4.0
deadline = passenger.get("connection_deadline_hrs")
if deadline is None:
deadline_component = 0.0
else:
deadline_component = (12.0 - min(max(float(deadline), 0.0), 12.0)) / 12.0
score = (0.65 * tier_component) + (0.35 * deadline_component)
type_bonus = {
ActionType.REBOOK_PASSENGER: 0.60,
ActionType.OFFER_DOWNGRADE: 0.30,
ActionType.REBOOK_ON_PARTNER: 0.18,
ActionType.BOOK_HOTEL: 0.10,
ActionType.MARK_NO_SOLUTION: -0.60,
ActionType.FINALIZE: 0.0,
}[action.action_type]
score += type_bonus
if action.flight_id:
flight = flights_by_id.get(action.flight_id)
if flight is not None and deadline is not None:
departure = float(flight.get("departure_hrs", 99.0))
if departure <= float(deadline):
score += 0.22
else:
score -= 0.22
budget_remaining = float(observation.get("budget_remaining", 0.0))
budget_spent = float(observation.get("budget_spent", 0.0))
budget_total = max(budget_remaining + budget_spent, 1.0)
score -= 0.35 * min(_action_cost(action.action_type) / budget_total, 1.0)
return score
def _prune_candidate_actions(
observation: Dict[str, Any],
actions: List[Action],
ranked_action_types: Optional[List[str]],
max_candidates: int,
) -> List[Action]:
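    """Deduplicate candidate actions, sort by ML rank then descending priority score, and keep at most max_candidates."""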
deduped: List[Action] = []
seen = set()
for action in actions:
signature = (action.action_type.value, action.passenger_id, action.flight_id)
if signature in seen:
continue
seen.add(signature)
deduped.append(action)
rank_index: Dict[str, int] = {}
if ranked_action_types:
rank_index = {action_type: idx for idx, action_type in enumerate(ranked_action_types)}
deduped.sort(
key=lambda action: (
rank_index.get(action.action_type.value, 999),
-_action_priority_score(observation, action),
)
)
return deduped[: max(1, max_candidates)]
def _rollout_heuristic_to_end(env: FlightRebookingEnv) -> None:
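    """Advance the environment with the heuristic policy until the episode ends (mutates env in place)."""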
done = False
while not done:
observation = env._get_observation().model_dump(mode="json")
action = Action(**_heuristic_action(observation))
_, _, done, _ = env.step(action)
def _evaluate_state_with_lookahead(
env: FlightRebookingEnv,
task_key: str,
lookahead_depth: int,
lookahead_width: int,
ranked_action_types: Optional[List[str]] = None,
) -> float:
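    """Estimate the best reachable task score from the current state using depth-limited lookahead with heuristic rollouts at the leaves."""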
observation = env._get_observation().model_dump(mode="json")
candidate_actions = _feasible_actions_from_observation(observation)
if ranked_action_types:
preferred_types = set(ranked_action_types[:5])
preferred_types.add(ActionType.FINALIZE.value)
preferred_types.add(ActionType.MARK_NO_SOLUTION.value)
preferred_candidates = [a for a in candidate_actions if a.action_type.value in preferred_types]
if preferred_candidates:
candidate_actions = preferred_candidates
candidate_actions = _prune_candidate_actions(
observation=observation,
actions=candidate_actions,
ranked_action_types=ranked_action_types,
max_candidates=lookahead_width,
)
best_score = -1.0
for action in candidate_actions:
env_copy = deepcopy(env)
_, _, done, _ = env_copy.step(action)
if done:
score = float(grade_task(task_key, env_copy.state(), TASKS[task_key]["max_budget"]))
elif lookahead_depth <= 1:
_rollout_heuristic_to_end(env_copy)
score = float(grade_task(task_key, env_copy.state(), TASKS[task_key]["max_budget"]))
else:
score = _evaluate_state_with_lookahead(
env=env_copy,
task_key=task_key,
lookahead_depth=lookahead_depth - 1,
lookahead_width=lookahead_width,
ranked_action_types=None,
)
if score > best_score:
best_score = score
if best_score >= 0.0:
return best_score
env_fallback = deepcopy(env)
_rollout_heuristic_to_end(env_fallback)
return float(grade_task(task_key, env_fallback.state(), TASKS[task_key]["max_budget"]))
def _projected_score_for_action(
env: FlightRebookingEnv,
task_key: str,
action: Action,
lookahead_depth: int,
lookahead_width: int,
) -> float:
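    """Apply one action to a copy of the environment and return the estimated final task score."""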
env_copy = deepcopy(env)
_, _, done, _ = env_copy.step(action)
if done:
return float(grade_task(task_key, env_copy.state(), TASKS[task_key]["max_budget"]))
if lookahead_depth <= 1:
_rollout_heuristic_to_end(env_copy)
return float(grade_task(task_key, env_copy.state(), TASKS[task_key]["max_budget"]))
return _evaluate_state_with_lookahead(
env=env_copy,
task_key=task_key,
lookahead_depth=lookahead_depth - 1,
lookahead_width=lookahead_width,
ranked_action_types=None,
)
def _choose_lookahead_action(
env: FlightRebookingEnv,
task_key: str,
lookahead_depth: int,
lookahead_width: int,
ranked_action_types: Optional[List[str]] = None,
) -> Dict[str, Any]:
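    """Pick the feasible action with the best projected score, optionally restricting candidates to the ML-ranked action types."""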
observation = env._get_observation().model_dump(mode="json")
candidate_actions = _feasible_actions_from_observation(observation)
if ranked_action_types:
preferred_types = set(ranked_action_types[:5])
preferred_types.add(ActionType.FINALIZE.value)
preferred_types.add(ActionType.MARK_NO_SOLUTION.value)
preferred_candidates = [a for a in candidate_actions if a.action_type.value in preferred_types]
if preferred_candidates:
candidate_actions = preferred_candidates
best_action = candidate_actions[0]
best_score = -1.0
for action in candidate_actions:
try:
projected_score = _projected_score_for_action(
env=env,
task_key=task_key,
action=action,
lookahead_depth=lookahead_depth,
lookahead_width=lookahead_width,
)
except Exception:
continue
if projected_score > best_score:
best_score = projected_score
best_action = action
return best_action.model_dump(mode="json")
def _pick_best_payload_by_projection(
env: FlightRebookingEnv,
task_key: str,
payloads: List[Dict[str, Any]],
lookahead_depth: int,
lookahead_width: int,
) -> Dict[str, Any]:
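    """Pick the action payload (e.g. trained-policy hint vs. sanitized LLM output) with the highest projected score."""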
best_payload = payloads[0]
best_score = -1.0
seen_signatures = set()
for payload in payloads:
try:
action = Action(**payload)
except Exception:
continue
signature = (action.action_type.value, action.passenger_id, action.flight_id)
if signature in seen_signatures:
continue
seen_signatures.add(signature)
try:
projected_score = _projected_score_for_action(
env=env,
task_key=task_key,
action=action,
lookahead_depth=lookahead_depth,
lookahead_width=lookahead_width,
)
except Exception:
continue
if projected_score > best_score:
best_score = projected_score
best_payload = action.model_dump(mode="json")
return best_payload
def _extract_json(text: str) -> Dict[str, Any]:
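    """Extract a JSON object from raw model output, tolerating code fences and surrounding prose."""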
text = (text or "").strip()
try:
return json.loads(text)
except json.JSONDecodeError:
pass
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
if fenced:
return json.loads(fenced.group(1))
inline = re.search(r"\{.*\}", text, re.DOTALL)
if inline:
return json.loads(inline.group(0))
raise ValueError("No valid JSON action in model output")
def _tier_weight(tier: str) -> int:
return {
PriorityTier.PLATINUM.value: 4,
PriorityTier.GOLD.value: 3,
PriorityTier.SILVER.value: 2,
PriorityTier.STANDARD.value: 1,
}.get(tier, 1)
def _has_seat(flight: Dict[str, Any], cabin_class: str) -> bool:
if cabin_class == CabinClass.BUSINESS.value:
return flight["business_seats"] > 0
return flight["economy_seats"] > 0
def _heuristic_action(observation: Dict[str, Any]) -> Dict[str, Any]:
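    """Rule-based policy: take the highest-tier, tightest-deadline pending passenger and try same-airline rebooking, then an economy downgrade, partner rebooking, a hotel, and finally mark no solution."""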
pending = list(observation["pending_passengers"])
if not pending:
return {"action_type": ActionType.FINALIZE.value}
pending.sort(
key=lambda p: (
-_tier_weight(p["priority_tier"]),
p["connection_deadline_hrs"] if p["connection_deadline_hrs"] is not None else 10**9,
)
)
passenger = pending[0]
flights = sorted(observation["available_flights"], key=lambda f: f["departure_hrs"])
for flight in flights:
if flight["is_partner"]:
continue
if _has_seat(flight, passenger["cabin_class"]):
return {
"action_type": ActionType.REBOOK_PASSENGER.value,
"passenger_id": passenger["id"],
"flight_id": flight["id"],
}
if passenger["cabin_class"] == CabinClass.BUSINESS.value:
for flight in flights:
if flight["is_partner"]:
continue
if flight["economy_seats"] > 0 and observation["budget_remaining"] >= 500:
return {
"action_type": ActionType.OFFER_DOWNGRADE.value,
"passenger_id": passenger["id"],
"flight_id": flight["id"],
}
for flight in flights:
if not flight["is_partner"]:
continue
if _has_seat(flight, passenger["cabin_class"]) and observation["budget_remaining"] >= 800:
return {
"action_type": ActionType.REBOOK_ON_PARTNER.value,
"passenger_id": passenger["id"],
"flight_id": flight["id"],
}
if observation["budget_remaining"] >= 250:
return {
"action_type": ActionType.BOOK_HOTEL.value,
"passenger_id": passenger["id"],
}
return {
"action_type": ActionType.MARK_NO_SOLUTION.value,
"passenger_id": passenger["id"],
}
def _is_action_feasible(observation: Dict[str, Any], payload: Dict[str, Any]) -> bool:
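    """Validate a proposed action against the observation: pending passenger, flight availability, seat class, and remaining budget, as applicable."""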
action_type = payload["action_type"]
if action_type == ActionType.FINALIZE.value:
return True
pending_by_id = {p["id"]: p for p in observation["pending_passengers"]}
flights_by_id = {f["id"]: f for f in observation["available_flights"]}
budget_remaining = float(observation["budget_remaining"])
passenger = pending_by_id.get(payload.get("passenger_id"))
if passenger is None:
return False
if action_type == ActionType.BOOK_HOTEL.value:
return budget_remaining >= 250
if action_type == ActionType.MARK_NO_SOLUTION.value:
return True
flight = flights_by_id.get(payload.get("flight_id"))
if flight is None:
return False
passenger_cabin = passenger["cabin_class"]
needs_business = passenger_cabin == CabinClass.BUSINESS.value
has_matching_cabin_seat = (flight["business_seats"] > 0) if needs_business else (flight["economy_seats"] > 0)
if action_type == ActionType.REBOOK_PASSENGER.value:
return (not flight["is_partner"]) and has_matching_cabin_seat
if action_type == ActionType.OFFER_DOWNGRADE.value:
return (
passenger_cabin == CabinClass.BUSINESS.value
and budget_remaining >= 500
and flight["economy_seats"] > 0
)
if action_type == ActionType.REBOOK_ON_PARTNER.value:
return flight["is_partner"] and budget_remaining >= 800 and has_matching_cabin_seat
return False
def _sanitize_action_payload(observation: Dict[str, Any], payload: Any) -> Dict[str, Any]:
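    """Normalize an LLM-proposed payload and verify its IDs and feasibility; fall back to the heuristic action on any problem."""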
fallback = _heuristic_action(observation)
if not isinstance(payload, dict):
return fallback
valid_action_types = {action_type.value for action_type in ActionType}
action_type = str(payload.get("action_type", "")).strip()
if action_type not in valid_action_types:
return fallback
sanitized: Dict[str, Any] = {"action_type": action_type}
passenger_id = str(payload.get("passenger_id", "")).strip()
flight_id = str(payload.get("flight_id", "")).strip()
if passenger_id:
sanitized["passenger_id"] = passenger_id
if flight_id:
sanitized["flight_id"] = flight_id
if action_type == ActionType.FINALIZE.value:
return sanitized
pending_ids = {p["id"] for p in observation["pending_passengers"]}
if sanitized.get("passenger_id") not in pending_ids:
return fallback
if action_type in {
ActionType.REBOOK_PASSENGER.value,
ActionType.OFFER_DOWNGRADE.value,
ActionType.REBOOK_ON_PARTNER.value,
}:
flight_ids = {f["id"] for f in observation["available_flights"]}
if sanitized.get("flight_id") not in flight_ids:
return fallback
if not _is_action_feasible(observation, sanitized):
return fallback
return sanitized
def _query_openai_action(
client: OpenAI,
model_name: str,
seed: int,
observation_json: str,
policy_hint_json: Optional[str] = None,
max_retries: int = 2,
) -> Dict[str, Any]:
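    """Request one JSON action from the chat model (optionally appending a trained-policy hint), retrying on failures and dropping the seed argument if the client rejects it."""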
last_error: Optional[Exception] = None
for _ in range(max_retries + 1):
try:
user_content = f"Current observation: {observation_json}"
if policy_hint_json:
user_content += (
"\nSuggested safe action from a trained policy: "
f"{policy_hint_json}"
"\nPrefer this if it is valid for the current observation."
)
kwargs: Dict[str, Any] = {
"model": model_name,
"messages": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_content},
],
"temperature": 0,
"top_p": 1,
"max_tokens": 220,
"seed": seed,
}
response = None
try:
response = client.chat.completions.create(**kwargs)
except TypeError:
kwargs.pop("seed", None)
response = client.chat.completions.create(**kwargs)
content = response.choices[0].message.content or ""
return _extract_json(content)
except Exception as exc:
last_error = exc
raise RuntimeError(f"OpenAI call failed after retries: {last_error}")
def _emit_start(task_name: str, benchmark: str, model_name: str) -> None:
print(f"[START] task={task_name} env={benchmark} model={model_name}", flush=True)
def _format_action_for_log(action: Action) -> str:
payload = {
"action_type": action.action_type.value,
"passenger_id": action.passenger_id,
"flight_id": action.flight_id,
}
return json.dumps(payload, separators=(",", ":"), ensure_ascii=True)
def _emit_step(
step_index: int,
action_text: str,
reward_value: float,
done: bool,
error: Optional[str],
) -> None:
done_value = str(bool(done)).lower()
error_value = error if error else "null"
print(
"[STEP] "
f"step={step_index} "
f"action={action_text} "
f"reward={reward_value:.2f} "
f"done={done_value} "
f"error={error_value}",
flush=True,
)
def _emit_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
rewards_text = ",".join(f"{value:.2f}" for value in rewards)
success_value = str(bool(success)).lower()
print(
"[END] "
f"success={success_value} "
f"steps={steps} "
f"score={score:.4f} "
f"rewards={rewards_text}",
flush=True,
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Run submission inference across OpenEnv tasks.")
parser.add_argument("--task", choices=["all", "easy", "medium", "hard"], default="all")
parser.add_argument("--seed", type=int, default=int(os.getenv("BASELINE_SEED", "42")))
parser.add_argument(
"--policy",
choices=["openai", "heuristic", "trained_ml", "openai_trained"],
default="openai_trained",
help=(
"Policy backend. openai_trained uses Llama with trained-policy hints; "
"trained_ml uses the learned policy directly; openai and heuristic remain available."
),
)
parser.add_argument(
"--ml-policy-path",
default=os.getenv("ML_POLICY_PATH", "artifacts/ml_policy.pkl"),
help="Path to trained ML policy artifact used by trained_ml/openai_trained modes.",
)
parser.add_argument(
"--lookahead-depth",
type=int,
default=int(os.getenv("LOOKAHEAD_DEPTH", "2")),
help="Lookahead depth for projected action scoring (>=1).",
)
parser.add_argument(
"--lookahead-width",
type=int,
default=int(os.getenv("LOOKAHEAD_WIDTH", "12")),
help="Maximum candidate actions explored per lookahead level (>=1).",
)
parser.add_argument("--json-out", default="", help="Optional JSON output path.")
return parser.parse_args()
def main() -> None:
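    """Run inference over the selected tasks, emitting [START]/[STEP]/[END] logs and optionally writing a JSON summary to --json-out."""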
args = parse_args()
args.lookahead_depth = max(1, int(args.lookahead_depth))
args.lookahead_width = max(1, int(args.lookahead_width))
task_keys = ["easy", "medium", "hard"] if args.task == "all" else [args.task]
effective_policy = args.policy
ml_policy_artifact: Optional[Dict[str, Any]] = None
if effective_policy in {"trained_ml", "openai_trained"}:
ml_policy_artifact = _require_ml_policy_artifact(args.ml_policy_path, effective_policy)
api_base_url = "heuristic"
model_name = "heuristic"
client: Optional[OpenAI] = None
if effective_policy in {"openai", "openai_trained"}:
model_config = _resolve_model_config()
api_base_url = model_config["api_base_url"]
model_name = model_config["model_name"]
client = OpenAI(api_key=model_config["api_key"], base_url=api_base_url)
results: List[Dict[str, Any]] = []
for task_key in task_keys:
task_data = TASKS[task_key]
_emit_start(task_name=task_data["task_id"], benchmark=BENCHMARK_NAME, model_name=model_name)
env = FlightRebookingEnv(task_data=task_data)
observation = None
done = False
steps = 0
rewards: List[float] = []
score = 0.01
success = False
episode_error: Optional[str] = None
try:
observation = env.reset()
while not done:
observation_dict = observation.model_dump(mode="json")
if effective_policy in {"openai", "openai_trained"}:
assert client is not None
policy_hint_payload: Optional[Dict[str, Any]] = None
if effective_policy == "openai_trained":
assert ml_policy_artifact is not None
ranked_types = _predict_ml_ranked_action_types(observation_dict, ml_policy_artifact)
policy_hint_payload = _choose_lookahead_action(
env=env,
task_key=task_key,
lookahead_depth=args.lookahead_depth,
lookahead_width=args.lookahead_width,
ranked_action_types=ranked_types,
)
raw_payload = _query_openai_action(
client=client,
model_name=model_name,
seed=args.seed,
observation_json=observation.model_dump_json(),
policy_hint_json=(json.dumps(policy_hint_payload) if policy_hint_payload is not None else None),
)
llm_payload = _sanitize_action_payload(observation_dict, raw_payload)
if effective_policy == "openai_trained" and policy_hint_payload is not None:
action_payload = _pick_best_payload_by_projection(
env=env,
task_key=task_key,
payloads=[policy_hint_payload, llm_payload],
lookahead_depth=args.lookahead_depth,
lookahead_width=args.lookahead_width,
)
else:
action_payload = llm_payload
elif effective_policy == "trained_ml":
assert ml_policy_artifact is not None
ranked_types = _predict_ml_ranked_action_types(observation_dict, ml_policy_artifact)
action_payload = _choose_lookahead_action(
env=env,
task_key=task_key,
lookahead_depth=args.lookahead_depth,
lookahead_width=args.lookahead_width,
ranked_action_types=ranked_types,
)
else:
action_payload = _heuristic_action(observation_dict)
try:
action = Action(**action_payload)
except Exception:
action = Action(action_type=ActionType.FINALIZE)
step_error: Optional[str] = None
reward_value = 0.0
try:
observation, reward, done, info = env.step(action)
reward_value = float(reward.value)
if isinstance(info, dict) and info.get("error"):
step_error = str(info.get("error"))
except Exception as exc:
done = True
step_error = str(exc)
episode_error = step_error
steps += 1
rewards.append(reward_value)
_emit_step(
step_index=steps,
action_text=_format_action_for_log(action),
reward_value=reward_value,
done=done,
error=step_error,
)
try:
final_state = env.state()
score = float(grade_task(task_key, final_state, task_data["max_budget"]))
except Exception as exc:
episode_error = str(exc)
score = 0.01
except Exception as exc:
episode_error = str(exc)
score = 0.01
finally:
close_fn = getattr(env, "close", None)
if callable(close_fn):
try:
close_fn()
except Exception as exc:
if not episode_error:
episode_error = str(exc)
success = (episode_error is None) and (0.0 <= score <= 1.0) and (score >= SUCCESS_SCORE_THRESHOLD)
_emit_end(success=success, steps=steps, score=score, rewards=rewards)
        try:
            final_state = env.state()
            budget_spent = round(final_state.budget_spent, 2)
            invalid_actions = final_state.invalid_actions
        except Exception:
            budget_spent = None
            invalid_actions = None
        avg_step_reward = sum(rewards) / max(len(rewards), 1)
        results.append(
            {
                "task": task_key,
                "task_id": task_data["task_id"],
                "difficulty": task_data["difficulty"],
                "steps": steps,
                "avg_step_reward": round(avg_step_reward, 4),
                "score": round(score, 4),
                "budget_spent": budget_spent,
                "budget_max": task_data["max_budget"],
                "invalid_actions": invalid_actions,
                "success": success,
                "error": episode_error,
            }
        )
overall = sum(item["score"] for item in results) / max(len(results), 1)
if args.json_out:
payload = {
"policy_requested": args.policy,
"policy_effective": effective_policy,
"seed": args.seed,
"api_base_url": api_base_url,
"model_name": model_name,
"ml_policy_path": args.ml_policy_path,
"ml_policy_loaded": ml_policy_artifact is not None,
"lookahead_depth": args.lookahead_depth,
"lookahead_width": args.lookahead_width,
"overall_score": round(overall, 4),
"tasks": results,
}
with open(args.json_out, "w", encoding="utf-8") as handle:
json.dump(payload, handle, indent=2)
if __name__ == "__main__":
main()