Spaces:
Sleeping
Sleeping
Upload baseline.py with huggingface_hub
Browse files- baseline.py +232 -0
baseline.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Baseline inference script for the Scheduling Optimisation Environment.

Runs GPT-4o-mini (or falls back to deterministic mock responses) against all
three tasks and prints a structured score report.

Usage:
    OPENAI_API_KEY=sk-... python baseline.py
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import json
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
from typing import Any
|
| 16 |
+
|
| 17 |
+
from environment import INSTANCE_BANK
|
| 18 |
+
from graders.grader_classification import ConflictGrader
|
| 19 |
+
from graders.grader_detection import FeasibilityGrader
|
| 20 |
+
from graders.grader_fix import RepairGrader
|
| 21 |
+
from models import Action
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _get_openai_client():
    """Return an OpenAI client, or None when the API cannot be used.

    The key is read from the ``OPENAI_API_KEY`` environment variable and
    stripped of surrounding whitespace. ``None`` is returned when the key is
    missing, empty or whitespace-only, or when the ``openai`` package cannot
    be imported or its client cannot be constructed. In the latter case the
    reason is written to stderr so the fall back to mock mode is not silent.

    Returns:
        An ``openai.OpenAI`` instance, or ``None``.
    """
    api_key = os.environ.get("OPENAI_API_KEY", "").strip()
    if not api_key:
        return None
    try:
        # Imported lazily so the script still runs without the package.
        from openai import OpenAI

        return OpenAI(api_key=api_key)
    except Exception as exc:
        # Deliberately best-effort: any failure here means "use mock mode",
        # but say why instead of swallowing the error.
        print(f" [OpenAI client unavailable: {exc}]", file=sys.stderr)
        return None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _llm_response(client, system_prompt: str, user_prompt: str) -> str:
    """Call GPT-4o-mini and return the stripped response text.

    Args:
        client: Object exposing ``chat.completions.create`` (an OpenAI client).
        system_prompt: Content of the ``system`` message.
        user_prompt: Content of the ``user`` message.

    Returns:
        The first choice's message content with surrounding whitespace
        removed. An empty string is returned when the message carries no
        text content, or when the call fails for any reason (the error is
        printed and otherwise ignored).
    """
    try:
        resp = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            max_tokens=1024,
            temperature=0.0,
        )
        # ``content`` may be None; treat that as an empty answer rather than
        # letting ``.strip()`` raise and be reported as an "LLM error".
        content = resp.choices[0].message.content
        return (content or "").strip()
    except Exception as e:
        print(f" [LLM error: {e}]")
        return ""
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
# Mock fallback responses (used when no API key is available)
|
| 56 |
+
# ---------------------------------------------------------------------------
|
| 57 |
+
|
| 58 |
+
# Feasibility answers returned in mock mode, keyed by position in
# INSTANCE_BANK: indices 0-9 are labelled infeasible, 10 and 11 feasible.
_MOCK_FEASIBILITY: dict[int, str] = {
    **dict.fromkeys(range(10), "infeasible"),
    **dict.fromkeys((10, 11), "feasible"),
}
|
| 64 |
+
|
| 65 |
+
# Violation-type answers returned in mock mode for the infeasible instances
# (indices 0-9): the same five labels, in this order, repeated twice.
_MOCK_CLASSIFICATION: dict[int, str] = dict(
    enumerate(
        (
            "resource_overload",
            "deadline_violation",
            "precedence_violation",
            "availability_conflict",
            "capacity_exceeded",
        )
        * 2
    )
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _mock_repair(instance_idx: int) -> str:
    """Serialise the mock-mode repair answer for one bank entry as JSON.

    The entry's ``optimal_schedule`` is used when it is present and
    non-empty; otherwise the instance's ``proposed_schedule`` (or an empty
    dict when that is absent too) is returned unchanged as a safe fallback.
    """
    bank_entry = INSTANCE_BANK[instance_idx]
    schedule = bank_entry.get("optimal_schedule", {}) or bank_entry[
        "instance"
    ].get("proposed_schedule", {})
    return json.dumps(schedule)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# ---------------------------------------------------------------------------
|
| 91 |
+
# Baseline runner
|
| 92 |
+
# ---------------------------------------------------------------------------
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# System prompts sent to the model, one per task.
_FEASIBILITY_PROMPT = (
    "You are a scheduling expert. Determine if the proposed schedule "
    "satisfies all constraints. Reply with ONLY 'feasible' or 'infeasible'."
)
_CLASSIFICATION_PROMPT = (
    "You are a scheduling expert. Identify the constraint violation type. "
    "Reply with ONLY one of: resource_overload, deadline_violation, "
    "precedence_violation, availability_conflict, capacity_exceeded."
)
_REPAIR_PROMPT = (
    "You are a scheduling expert. Repair the infeasible schedule by "
    "returning a JSON object with key 'assignments': a list of "
    '{"job_id", "machine_id", "start_time"} dicts that satisfies all '
    "constraints and minimises makespan. Return ONLY valid JSON."
)


def _feasibility_line(index: int, entry: dict[str, Any], score: float) -> str:
    """Format the report line for one feasibility-check instance."""
    status = "CORRECT" if score >= 0.95 else "wrong"
    expected = "feasible" if entry["is_feasible"] else "infeasible"
    return (
        f" Instance {index:2d}: {status:7s} (score={score:.2f}) "
        f"expected={expected} [{entry['description'][:45]}]"
    )


def _classification_line(index: int, entry: dict[str, Any], score: float) -> str:
    """Format the report line for one conflict-classification instance."""
    if score >= 0.95:
        status = "EXACT"
    elif score >= 0.45:
        status = "partial"
    else:
        status = "wrong"
    return (
        f" Instance {index:2d}: {status:7s} (score={score:.2f}) "
        f"expected={entry['violation_type']}"
    )


def _repair_line(index: int, entry: dict[str, Any], score: float) -> str:
    """Format the report line for one schedule-repair instance."""
    return (
        f" Instance {index:2d}: score={score:.2f} "
        f"optimal_makespan={entry['optimal_makespan']} "
        f"[{entry['description'][:45]}]"
    )


def _run_task(
    client,
    *,
    title: str,
    task_id: str,
    grader: Any,
    indexed_entries,
    system_prompt: str,
    mock_answer,
    format_line,
) -> tuple[float, dict[str, Any]]:
    """Grade one task over ``indexed_entries`` and print its report section.

    Args:
        client: OpenAI client, or ``None`` to use ``mock_answer`` instead.
        title: Heading printed before the per-instance lines.
        task_id: Passed to ``Action`` as ``task_id``.
        grader: Object whose ``grade(action, entry)`` returns a numeric score.
        indexed_entries: Iterable of ``(index, entry)`` pairs, where ``index``
            is the entry's position in ``INSTANCE_BANK``.
        system_prompt: System prompt used when ``client`` is available.
        mock_answer: Callable ``index -> str`` used when ``client`` is None.
        format_line: Callable ``(index, entry, score) -> str`` producing the
            per-instance report line.

    Returns:
        ``(average, summary)``: the unrounded mean score (0.0 when there are
        no entries) and the dict stored under ``results["tasks"][task_id]``.
    """
    print(title)
    scores: list[float] = []
    for index, entry in indexed_entries:
        if client is not None:
            # The instance is only serialised when it is actually sent.
            response = _llm_response(
                client, system_prompt, json.dumps(entry["instance"], indent=2)
            )
        else:
            response = mock_answer(index)
        score = grader.grade(Action(response=response, task_id=task_id), entry)
        scores.append(score)
        print(format_line(index, entry, score))

    average = sum(scores) / len(scores) if scores else 0.0
    print(f" >> Average: {average:.3f}\n")
    summary = {
        "average_score": round(average, 4),
        "num_instances": len(scores),
        "scores": scores,
    }
    return average, summary


def run_baseline() -> dict[str, Any]:
    """Execute the baseline across all three tasks and return scores.

    Uses GPT-4o-mini when ``_get_openai_client`` returns a client; otherwise
    answers come from the mock tables and ``_mock_repair``. A score report is
    printed to stdout as the tasks run.

    Returns:
        A dict with keys ``"mode"`` (label of the answer source), ``"tasks"``
        (per-task ``average_score`` / ``num_instances`` / ``scores``, keyed by
        task id) and ``"overall_average"`` (mean of the three task averages,
        rounded to 4 decimals).
    """
    client = _get_openai_client()
    use_llm = client is not None
    mode = "GPT-4o-mini" if use_llm else "mock (no API key — oracle responses)"
    print(f"\n{'='*65}")
    print(f" SchedulingOptEnv — Baseline Evaluation ({mode})")
    print(f"{'='*65}\n")

    results: dict[str, Any] = {"mode": mode, "tasks": {}}

    # ----- Task 1: Feasibility Check (every instance) -----
    avg_feas, results["tasks"]["feasibility_check"] = _run_task(
        client,
        title="Task 1: Feasibility Check (easy)",
        task_id="feasibility_check",
        grader=FeasibilityGrader(),
        indexed_entries=enumerate(INSTANCE_BANK),
        system_prompt=_FEASIBILITY_PROMPT,
        mock_answer=lambda i: _MOCK_FEASIBILITY.get(i, "infeasible"),
        format_line=_feasibility_line,
    )

    # ----- Task 2: Conflict Classification (infeasible instances only) -----
    avg_conf, results["tasks"]["conflict_classification"] = _run_task(
        client,
        title="Task 2: Conflict Classification (medium)",
        task_id="conflict_classification",
        grader=ConflictGrader(),
        indexed_entries=[
            (i, e) for i, e in enumerate(INSTANCE_BANK) if not e["is_feasible"]
        ],
        system_prompt=_CLASSIFICATION_PROMPT,
        mock_answer=lambda i: _MOCK_CLASSIFICATION.get(i, "resource_overload"),
        format_line=_classification_line,
    )

    # ----- Task 3: Schedule Repair (infeasible with a known optimum) -----
    avg_repair, results["tasks"]["schedule_repair"] = _run_task(
        client,
        title="Task 3: Schedule Repair (hard)",
        task_id="schedule_repair",
        grader=RepairGrader(),
        indexed_entries=[
            (i, e)
            for i, e in enumerate(INSTANCE_BANK)
            if not e["is_feasible"] and e.get("optimal_schedule")
        ],
        system_prompt=_REPAIR_PROMPT,
        mock_answer=_mock_repair,
        format_line=_repair_line,
    )

    # ----- Summary -----
    # Averaged from the unrounded task means, then rounded once.
    overall = (avg_feas + avg_conf + avg_repair) / 3
    results["overall_average"] = round(overall, 4)
    print(f"{'='*65}")
    print(f" Overall Average Score: {overall:.3f}")
    print(f"{'='*65}\n")

    return results
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
if __name__ == "__main__":
    # Script entry point: report any failure on stderr and exit with status 1.
    try:
        run_baseline()
    except Exception as exc:
        sys.stderr.write(f"Baseline failed: {exc}\n")
        raise SystemExit(1)
|