flight-rebooking / tasks.py
dhnkhr's picture
Production-ready: Clean code with Groq API integration, LoRA model support, and FastAPI app
9753ee2
"""
Task Definitions and Deterministic Graders
===========================================
Provides three progressively difficult real-world disruption tasks and
task-specific grading functions that return normalized scores in [0.0, 1.0].
"""
from typing import Dict, List
from environment import EnvState, PassengerStatus, PriorityTier
EASY_TASK = {
"task_id": "easy_minor_disruption",
"difficulty": "easy",
"objective": "Rebook all passengers on same-airline flights while preserving premium service and minimizing spend.",
"max_budget": 3000,
"max_steps": 40,
"passengers": [
{
"id": "P1",
"name": "Alice Johnson",
"priority_tier": "Platinum",
"original_flight": "FL-100",
"cabin_class": "Business",
"connection_deadline_hrs": None,
},
{
"id": "P2",
"name": "Bob Smith",
"priority_tier": "Gold",
"original_flight": "FL-100",
"cabin_class": "Economy",
"connection_deadline_hrs": None,
},
{
"id": "P3",
"name": "Carol Davis",
"priority_tier": "Standard",
"original_flight": "FL-100",
"cabin_class": "Economy",
"connection_deadline_hrs": None,
},
],
"flights": [
{
"id": "FL-102",
"destination": "New York",
"departure_hrs": 3.0,
"economy_seats": 5,
"business_seats": 2,
"is_partner": False,
},
{
"id": "FL-104",
"destination": "New York",
"departure_hrs": 6.0,
"economy_seats": 10,
"business_seats": 3,
"is_partner": False,
},
{
"id": "FL-201",
"destination": "New York",
"departure_hrs": 4.0,
"economy_seats": 3,
"business_seats": 1,
"is_partner": True,
},
],
}
MEDIUM_TASK = {
"task_id": "medium_connection_crisis",
"difficulty": "medium",
"objective": "Prioritize high-tier passengers with tight deadlines under constrained seats and budget.",
"max_budget": 5000,
"max_steps": 60,
"passengers": [
{
"id": "P1",
"name": "David Lee",
"priority_tier": "Platinum",
"original_flight": "FL-300",
"cabin_class": "Business",
"connection_deadline_hrs": 4.0,
},
{
"id": "P2",
"name": "Emma Wilson",
"priority_tier": "Gold",
"original_flight": "FL-300",
"cabin_class": "Economy",
"connection_deadline_hrs": 2.5,
},
{
"id": "P3",
"name": "Frank Brown",
"priority_tier": "Silver",
"original_flight": "FL-300",
"cabin_class": "Economy",
"connection_deadline_hrs": None,
},
{
"id": "P4",
"name": "Grace Kim",
"priority_tier": "Standard",
"original_flight": "FL-300",
"cabin_class": "Business",
"connection_deadline_hrs": 5.0,
},
{
"id": "P5",
"name": "Henry Park",
"priority_tier": "Gold",
"original_flight": "FL-300",
"cabin_class": "Economy",
"connection_deadline_hrs": 3.0,
},
],
"flights": [
{
"id": "FL-302",
"destination": "Chicago",
"departure_hrs": 2.0,
"economy_seats": 2,
"business_seats": 1,
"is_partner": False,
},
{
"id": "FL-304",
"destination": "Chicago",
"departure_hrs": 5.0,
"economy_seats": 4,
"business_seats": 0,
"is_partner": False,
},
{
"id": "FL-401",
"destination": "Chicago",
"departure_hrs": 3.5,
"economy_seats": 2,
"business_seats": 1,
"is_partner": True,
},
],
}
HARD_TASK = {
"task_id": "hard_multi_wave_disruption",
"difficulty": "hard",
"objective": "Handle mixed loyalty tiers, scarce seats, and multiple urgent connections while staying under budget.",
"max_budget": 7000,
"max_steps": 90,
"passengers": [
{
"id": "P1",
"name": "Iris Patel",
"priority_tier": "Platinum",
"original_flight": "FL-500",
"cabin_class": "Business",
"connection_deadline_hrs": 2.5,
},
{
"id": "P2",
"name": "Jack Rivera",
"priority_tier": "Gold",
"original_flight": "FL-500",
"cabin_class": "Economy",
"connection_deadline_hrs": 2.0,
},
{
"id": "P3",
"name": "Karen Novak",
"priority_tier": "Gold",
"original_flight": "FL-500",
"cabin_class": "Business",
"connection_deadline_hrs": 4.0,
},
{
"id": "P4",
"name": "Liam Chen",
"priority_tier": "Silver",
"original_flight": "FL-500",
"cabin_class": "Economy",
"connection_deadline_hrs": 3.0,
},
{
"id": "P5",
"name": "Maya Brooks",
"priority_tier": "Standard",
"original_flight": "FL-500",
"cabin_class": "Economy",
"connection_deadline_hrs": None,
},
{
"id": "P6",
"name": "Noah Singh",
"priority_tier": "Platinum",
"original_flight": "FL-500",
"cabin_class": "Business",
"connection_deadline_hrs": 3.5,
},
{
"id": "P7",
"name": "Olivia Green",
"priority_tier": "Silver",
"original_flight": "FL-500",
"cabin_class": "Economy",
"connection_deadline_hrs": 5.0,
},
{
"id": "P8",
"name": "Peter Hall",
"priority_tier": "Standard",
"original_flight": "FL-500",
"cabin_class": "Economy",
"connection_deadline_hrs": 2.8,
},
],
"flights": [
{
"id": "FL-502",
"destination": "San Francisco",
"departure_hrs": 1.8,
"economy_seats": 2,
"business_seats": 1,
"is_partner": False,
},
{
"id": "FL-504",
"destination": "San Francisco",
"departure_hrs": 3.0,
"economy_seats": 2,
"business_seats": 1,
"is_partner": False,
},
{
"id": "FL-506",
"destination": "San Francisco",
"departure_hrs": 5.5,
"economy_seats": 3,
"business_seats": 0,
"is_partner": False,
},
{
"id": "FL-701",
"destination": "San Francisco",
"departure_hrs": 2.2,
"economy_seats": 2,
"business_seats": 1,
"is_partner": True,
},
{
"id": "FL-703",
"destination": "San Francisco",
"departure_hrs": 4.4,
"economy_seats": 2,
"business_seats": 1,
"is_partner": True,
},
],
}
TASKS = {
"easy": EASY_TASK,
"medium": MEDIUM_TASK,
"hard": HARD_TASK,
}
_OUTCOME_SCORES = {
PassengerStatus.REBOOKED: 1.00,
PassengerStatus.PARTNER_REBOOKED: 0.85,
PassengerStatus.DOWNGRADED: 0.65,
PassengerStatus.HOTEL_BOOKED: 0.40,
PassengerStatus.NO_SOLUTION: 0.00,
PassengerStatus.PENDING: 0.00,
}
_TIER_WEIGHTS = {
PriorityTier.PLATINUM: 4,
PriorityTier.GOLD: 3,
PriorityTier.SILVER: 2,
PriorityTier.STANDARD: 1,
}
_GRADING_PROFILES = {
"easy": {
"quality": 0.45,
"coverage": 0.20,
"connection": 0.10,
"budget": 0.15,
"policy": 0.10,
},
"medium": {
"quality": 0.38,
"coverage": 0.17,
"connection": 0.22,
"budget": 0.13,
"policy": 0.10,
},
"hard": {
"quality": 0.30,
"coverage": 0.15,
"connection": 0.30,
"budget": 0.15,
"policy": 0.10,
},
}
def _clamp(value: float) -> float:
return max(0.01, min(0.99, value))
def _resolve_tier_weight(tier: PriorityTier) -> int:
if isinstance(tier, str):
tier = PriorityTier(tier)
return _TIER_WEIGHTS.get(tier, 1)
def _resolve_outcome_score(status: PassengerStatus) -> float:
if isinstance(status, str):
status = PassengerStatus(status)
return _OUTCOME_SCORES.get(status, 0.0)
def _connection_score(state: EnvState) -> float:
deadline_passengers = [p for p in state.passengers if p.connection_deadline_hrs is not None]
if not deadline_passengers:
return 0.99
weighted_hits = 0.0
weighted_total = 0.0
flights_by_id = {f.id: f for f in state.flights}
for passenger in deadline_passengers:
weight = _resolve_tier_weight(passenger.priority_tier)
weighted_total += weight
if passenger.assigned_flight is None:
continue
flight = flights_by_id.get(passenger.assigned_flight)
if flight is None:
continue
if flight.departure_hrs <= passenger.connection_deadline_hrs:
weighted_hits += weight
else:
weighted_hits += weight * 0.2
if weighted_total <= 0:
return 0.01
return _clamp(weighted_hits / weighted_total)
def _coverage_score(state: EnvState) -> float:
if not state.passengers:
return 0.01
resolved = sum(1 for p in state.passengers if p.status != PassengerStatus.PENDING)
return _clamp(resolved / len(state.passengers))
def _quality_score(state: EnvState) -> float:
weighted_sum = 0.0
weighted_total = 0.0
for passenger in state.passengers:
weight = _resolve_tier_weight(passenger.priority_tier)
weighted_total += weight
weighted_sum += weight * _resolve_outcome_score(passenger.status)
if weighted_total <= 0:
return 0.01
return _clamp(weighted_sum / weighted_total)
def _budget_score(state: EnvState, max_budget: float) -> float:
if max_budget <= 0:
return 0.99
return _clamp(1.0 - (state.budget_spent / max_budget))
def _policy_score(state: EnvState) -> float:
invalid_actions = max(getattr(state, "invalid_actions", 0), 0)
invalid_penalty = min(invalid_actions * 0.03, 0.3)
order: Dict[str, int] = {}
step = 0
for event in state.actions_taken:
if not event.get("success", False):
continue
action = event.get("action", {})
passenger_id = action.get("passenger_id")
if passenger_id and passenger_id not in order:
order[passenger_id] = step
step += 1
inversion_pairs = 0
total_pairs = 0
passengers = list(state.passengers)
for i in range(len(passengers)):
for j in range(i + 1, len(passengers)):
p_i = passengers[i]
p_j = passengers[j]
w_i = _resolve_tier_weight(p_i.priority_tier)
w_j = _resolve_tier_weight(p_j.priority_tier)
if w_i == w_j:
continue
if p_i.id not in order or p_j.id not in order:
continue
total_pairs += 1
if w_i > w_j and order[p_i.id] > order[p_j.id]:
inversion_pairs += 1
if w_j > w_i and order[p_j.id] > order[p_i.id]:
inversion_pairs += 1
inversion_penalty = (inversion_pairs / total_pairs) if total_pairs > 0 else 0.0
return _clamp(1.0 - invalid_penalty - inversion_penalty)
def _grade_with_profile(state: EnvState, max_budget: float, profile_name: str) -> float:
profile = _GRADING_PROFILES[profile_name]
quality = _quality_score(state)
coverage = _coverage_score(state)
connection = _connection_score(state)
budget = _budget_score(state, max_budget)
policy = _policy_score(state)
final = (
profile["quality"] * quality
+ profile["coverage"] * coverage
+ profile["connection"] * connection
+ profile["budget"] * budget
+ profile["policy"] * policy
)
return _clamp(final)
def grade_easy_episode(state: EnvState, max_budget: float) -> float:
return _grade_with_profile(state, max_budget, "easy")
def grade_medium_episode(state: EnvState, max_budget: float) -> float:
return _grade_with_profile(state, max_budget, "medium")
def grade_hard_episode(state: EnvState, max_budget: float) -> float:
return _grade_with_profile(state, max_budget, "hard")
TASK_GRADERS = {
"easy": grade_easy_episode,
"medium": grade_medium_episode,
"hard": grade_hard_episode,
}
def grade_task(task_key: str, state: EnvState, max_budget: float) -> float:
grader = TASK_GRADERS[task_key]
score = grader(state, max_budget)
# Enforce strict (0, 1) bounds required by the validator
return max(0.01, min(0.99, float(score)))
def grade_episode(state: EnvState, max_budget: float) -> float:
"""Backward-compatible default grader, mapped to medium difficulty."""
score = grade_medium_episode(state, max_budget)
return max(0.01, min(0.99, float(score)))