openenv-hackathon / graders /grader_fix.py
Vittal-M's picture
Upload graders/grader_fix.py with huggingface_hub
d51b70a verified
"""Grader for Task 3 — Schedule Repair (hard).
Scoring breakdown (additive, max 1.0)
--------------------------------------
0.20 — response is parseable JSON
0.20 — JSON has the required schema (assignments list, all jobs covered)
0.40 — schedule satisfies all constraints (0.10 per category):
capacity, deadlines, precedence, availability
0.20 — makespan within 30% of optimal (0.10 partial if within 60%)
Partial-progress signal
-----------------------
Even a structurally invalid JSON attempt earns 0.0 (wrong format).
A parseable but schema-invalid JSON earns 0.20 (gave a JSON object).
A valid schema with partial constraint satisfaction earns up to 0.80.
This dense reward curve supports multi-step improvement within an episode.
After each call, ``last_breakdown`` holds a full dict with per-category
pass/fail flags, makespan, and the optimality ratio — surfaced in the
environment's info dict.
"""
from __future__ import annotations
import json
import re
from typing import Any
from models import Action
class RepairGrader:
"""Grade the agent's proposed schedule repair."""
def __init__(self) -> None:
self.last_breakdown: dict[str, Any] = {}
def grade(self, action: Action, ground_truth: dict[str, Any]) -> float:
response: str = action.response.strip()
instance: dict[str, Any] = ground_truth.get("instance", {})
optimal_makespan: int = int(ground_truth.get("optimal_makespan", 1) or 1)
if not response:
self._record_breakdown(
json_ok=False, schema_ok=False,
constraint_detail={}, makespan=0,
optimal_makespan=optimal_makespan,
)
return 0.0
score = 0.0
# ------------------------------------------------------------------
# Component 1a — Is the response parseable JSON? (0.20)
# ------------------------------------------------------------------
parsed = self._parse_json(response)
if parsed is None:
self._record_breakdown(
json_ok=False, schema_ok=False,
constraint_detail={}, makespan=0,
optimal_makespan=optimal_makespan,
)
return 0.0 # not JSON → no partial credit at all
score += 0.20 # JSON parseable
# ------------------------------------------------------------------
# Component 1b — Does it have the required schema? (0.20)
# Required: {"assignments": [{"job_id", "machine_id", "start_time"}, ...]}
# All jobs from the instance must be present exactly once.
# ------------------------------------------------------------------
assignments: list[Any] = parsed.get("assignments", [])
schema_ok = self._valid_schema(assignments, instance)
if not schema_ok:
self._record_breakdown(
json_ok=True, schema_ok=False,
constraint_detail={}, makespan=0,
optimal_makespan=optimal_makespan,
)
return round(score, 4) # only 0.20
score += 0.20 # valid schema
# ------------------------------------------------------------------
# Component 2 — Constraint satisfaction (0.40, 0.10 per category)
# Categories: capacity, deadlines, precedence, availability
# ------------------------------------------------------------------
constraint_detail = self._check_constraints_detail(assignments, instance)
satisfied = sum(constraint_detail.values())
score += 0.40 * (satisfied / max(len(constraint_detail), 1))
# ------------------------------------------------------------------
# Component 3 — Makespan optimality (0.20)
# Full 0.20 if makespan ≤ optimal × 1.30; partial 0.10 if ≤ 1.60.
# ------------------------------------------------------------------
makespan = self._compute_makespan(assignments, instance)
if makespan > 0 and optimal_makespan > 0:
ratio = makespan / optimal_makespan
if ratio <= 1.30:
score += 0.20
elif ratio <= 1.60:
score += 0.10 # partial optimality credit
self._record_breakdown(
json_ok=True, schema_ok=True,
constraint_detail=constraint_detail,
makespan=makespan,
optimal_makespan=optimal_makespan,
)
return round(max(0.0, min(1.0, score)), 4)
# ------------------------------------------------------------------
# Breakdown recording
# ------------------------------------------------------------------
def _record_breakdown(
self,
json_ok: bool,
schema_ok: bool,
constraint_detail: dict[str, bool],
makespan: int,
optimal_makespan: int,
) -> None:
ratio = (
round(makespan / optimal_makespan, 3)
if (makespan > 0 and optimal_makespan > 0)
else None
)
self.last_breakdown = {
"json_parseable": json_ok,
"schema_valid": schema_ok,
"constraints": constraint_detail,
"constraints_satisfied": sum(constraint_detail.values()) if constraint_detail else 0,
"makespan": makespan,
"optimal_makespan": optimal_makespan,
"makespan_ratio": ratio,
"within_30pct": ratio is not None and ratio <= 1.30,
}
# ------------------------------------------------------------------
# JSON parsing — robust to markdown fences and partial wrapping
# ------------------------------------------------------------------
@staticmethod
def _parse_json(response: str) -> dict[str, Any] | None:
"""Try multiple strategies to extract a JSON object from the response.
Strategy 1: Direct json.loads (agent returned pure JSON).
Strategy 2: Strip markdown code fences, then parse.
Strategy 3: Brace-counting to find the outermost {...} block.
This is the most robust and handles agents that wrap JSON
in prose like "Here is my answer: {...}".
"""
# Strategy 1 — direct parse
try:
obj = json.loads(response)
return obj if isinstance(obj, dict) else None
except (json.JSONDecodeError, ValueError):
pass
# Strategy 2 — strip code fences
stripped = re.sub(r"```(?:json)?", "", response).replace("```", "").strip()
try:
obj = json.loads(stripped)
return obj if isinstance(obj, dict) else None
except (json.JSONDecodeError, ValueError):
pass
# Strategy 3 — brace-counting for the outermost { ... }
start = response.find("{")
if start == -1:
return None
depth = 0
for i, ch in enumerate(response[start:], start):
if ch == "{":
depth += 1
elif ch == "}":
depth -= 1
if depth == 0:
candidate = response[start : i + 1]
try:
obj = json.loads(candidate)
return obj if isinstance(obj, dict) else None
except (json.JSONDecodeError, ValueError):
return None
return None
# ------------------------------------------------------------------
# Schema validation
# ------------------------------------------------------------------
@staticmethod
def _valid_schema(
assignments: list[Any], instance: dict[str, Any]
) -> bool:
"""Validate that assignments is a well-formed list covering all jobs."""
if not isinstance(assignments, list) or len(assignments) == 0:
return False
required_keys = {"job_id", "machine_id", "start_time"}
for a in assignments:
if not isinstance(a, dict):
return False
if not required_keys.issubset(a.keys()):
return False
if not isinstance(a.get("start_time"), (int, float)):
return False
if a.get("start_time") < 0:
return False # negative start times are never valid
# Every job in the instance must appear exactly once
expected_jobs = {j["id"] for j in instance.get("jobs", [])}
assigned_jobs = [a["job_id"] for a in assignments]
return set(assigned_jobs) == expected_jobs and len(assigned_jobs) == len(expected_jobs)
# ------------------------------------------------------------------
# Constraint checking (returns per-category bool dict)
# ------------------------------------------------------------------
@staticmethod
def _check_constraints_detail(
assignments: list[dict[str, Any]], instance: dict[str, Any]
) -> dict[str, bool]:
"""Return a dict of {constraint_name: passed} for each of the 4 categories."""
jobs_by_id = {j["id"]: j for j in instance.get("jobs", [])}
machines_by_id = {m["id"]: m for m in instance.get("machines", [])}
assign_by_job = {a["job_id"]: a for a in assignments}
# ---- (a) Capacity: concurrent jobs on any machine ≤ its capacity ----
machine_intervals: dict[str, list[tuple[float, float]]] = {}
for a in assignments:
mid = a["machine_id"]
st = float(a["start_time"])
dur = float(jobs_by_id.get(a["job_id"], {}).get("duration", 1))
machine_intervals.setdefault(mid, []).append((st, st + dur))
capacity_ok = True
for mid, intervals in machine_intervals.items():
cap = machines_by_id.get(mid, {}).get("capacity", 1)
for s1, e1 in intervals:
# Count how many intervals overlap with [s1, e1)
concurrent = sum(
1 for s2, e2 in intervals if s2 < e1 and e2 > s1
)
if concurrent > cap:
capacity_ok = False
break
if not capacity_ok:
break
# ---- (b) Deadlines: every job finishes by its deadline ----
deadline_ok = True
for a in assignments:
job = jobs_by_id.get(a["job_id"], {})
finish = float(a["start_time"]) + float(job.get("duration", 0))
dl = job.get("deadline", float("inf"))
if finish > dl:
deadline_ok = False
break
# ---- (c) Precedence: job starts after ALL its predecessors finish ----
precedence_ok = True
for a in assignments:
job = jobs_by_id.get(a["job_id"], {})
for dep_id in job.get("dependencies", []):
dep_a = assign_by_job.get(dep_id)
if dep_a is None:
precedence_ok = False
break
dep_job = jobs_by_id.get(dep_id, {})
dep_finish = float(dep_a["start_time"]) + float(
dep_job.get("duration", 0)
)
if float(a["start_time"]) < dep_finish:
precedence_ok = False
break
if not precedence_ok:
break
# ---- (d) Availability: job runs within machine availability window ----
availability_ok = True
for a in assignments:
machine = machines_by_id.get(a["machine_id"], {})
avail_start = float(machine.get("available_start", 0))
avail_end = float(machine.get("available_end", float("inf")))
job = jobs_by_id.get(a["job_id"], {})
job_start = float(a["start_time"])
job_end = job_start + float(job.get("duration", 0))
if job_start < avail_start or job_end > avail_end:
availability_ok = False
break
return {
"capacity": capacity_ok,
"deadlines": deadline_ok,
"precedence": precedence_ok,
"availability": availability_ok,
}
@staticmethod
def _check_constraints(
assignments: list[dict[str, Any]], instance: dict[str, Any]
) -> float:
"""Convenience wrapper — returns fraction of categories satisfied."""
detail = RepairGrader._check_constraints_detail(assignments, instance)
return sum(detail.values()) / max(len(detail), 1)
# ------------------------------------------------------------------
# Makespan calculation
# ------------------------------------------------------------------
@staticmethod
def _compute_makespan(
assignments: list[dict[str, Any]], instance: dict[str, Any]
) -> int:
"""Return the latest finish time across all assigned jobs."""
jobs_by_id = {j["id"]: j for j in instance.get("jobs", [])}
max_finish = 0
for a in assignments:
job = jobs_by_id.get(a["job_id"], {})
finish = int(a["start_time"]) + int(job.get("duration", 0))
if finish > max_finish:
max_finish = finish
return max_finish