Spaces:
Sleeping
Sleeping
| """Baseline inference script for the Scheduling Optimisation Environment. | |
| Runs GPT-4o-mini (or falls back to deterministic mock responses) against all | |
| three tasks and prints a structured score report. | |
| Usage: | |
| OPENAI_API_KEY=sk-... python baseline.py | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import sys | |
| from typing import Any | |
| from environment import INSTANCE_BANK | |
| from graders.grader_classification import ConflictGrader | |
| from graders.grader_detection import FeasibilityGrader | |
| from graders.grader_fix import RepairGrader | |
| from models import Action | |
def _get_openai_client():
    """Build an OpenAI client from ``OPENAI_API_KEY``, or return None.

    None is returned when the environment variable is unset or empty, and
    also when importing ``openai`` or constructing the client raises any
    ``Exception``. ``run_baseline`` treats None as "use the mock responses".
    """
    key = os.environ.get("OPENAI_API_KEY", "")
    client = None
    if key:
        try:
            # Imported lazily so the script still runs in mock mode when the
            # ``openai`` package is not installed.
            from openai import OpenAI

            client = OpenAI(api_key=key)
        except Exception:
            # Best-effort: any failure here means "no client available".
            client = None
    return client
def _llm_response(client, system_prompt: str, user_prompt: str) -> str:
    """Ask GPT-4o-mini for a completion and return its stripped text.

    Args:
        client: Object exposing ``chat.completions.create`` (the value
            returned by ``_get_openai_client``).
        system_prompt: Content of the ``system`` message.
        user_prompt: Content of the ``user`` message.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or ``""`` when the call fails or the response carries no
        text. Failures are reported on stdout and never raised, so one bad
        call cannot abort the whole baseline run.
    """
    try:
        resp = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            max_tokens=1024,
            temperature=0.0,  # deterministic decoding keeps baseline runs comparable
        )
        content = resp.choices[0].message.content
        if content is None:
            # The Chat Completions schema allows ``content`` to be null.
            # Report that explicitly instead of letting ``.strip()`` fail
            # with a confusing AttributeError message.
            print(" [LLM error: response contained no text content]")
            return ""
        return content.strip()
    except Exception as e:
        print(f" [LLM error: {e}]")
        return ""
| # --------------------------------------------------------------------------- | |
| # Mock fallback responses (used when no API key is available) | |
| # --------------------------------------------------------------------------- | |
# Oracle feasibility answers for mock mode, keyed by position in INSTANCE_BANK:
# instances 0-9 are labelled infeasible and instances 10-11 feasible.
_MOCK_FEASIBILITY: dict[int, str] = {
    **dict.fromkeys(range(10), "infeasible"),
    10: "feasible",
    11: "feasible",
}
# Oracle violation types for mock mode, keyed by position in INSTANCE_BANK.
# The five violation types appear in the same order for instances 0-4 and
# again for instances 5-9, hence the tuple repeated twice.
_MOCK_CLASSIFICATION: dict[int, str] = dict(
    enumerate(
        (
            "resource_overload",
            "deadline_violation",
            "precedence_violation",
            "availability_conflict",
            "capacity_exceeded",
        )
        * 2
    )
)
def _mock_repair(instance_idx: int) -> str:
    """Serialise the mock-mode repair answer for one instance as JSON.

    Uses the bank entry's ``optimal_schedule`` when it is present and
    non-empty. Otherwise it falls back to the instance's
    ``proposed_schedule`` unchanged, or to an empty object when that key is
    missing too.
    """
    bank_entry = INSTANCE_BANK[instance_idx]
    schedule = bank_entry.get("optimal_schedule", {}) or bank_entry[
        "instance"
    ].get("proposed_schedule", {})
    return json.dumps(schedule)
| # --------------------------------------------------------------------------- | |
| # Baseline runner | |
| # --------------------------------------------------------------------------- | |
def run_baseline() -> dict[str, Any]:
    """Run all three tasks over ``INSTANCE_BANK`` and return the scores.

    Answers come from GPT-4o-mini when ``_get_openai_client`` yields a
    client, and from the ``_MOCK_*`` tables / ``_mock_repair`` otherwise.
    A per-instance line and per-task average are printed as the run
    proceeds, followed by the overall average.

    Returns:
        A dict with ``"mode"`` (which answer source was used), ``"tasks"``
        (one entry per task id holding ``average_score``, ``num_instances``
        and the raw ``scores`` list) and ``"overall_average"`` (unweighted
        mean of the three task averages, rounded to 4 places).
    """
    client = _get_openai_client()
    use_llm = client is not None
    mode = "GPT-4o-mini" if use_llm else "mock (no API key — oracle responses)"
    rule = "=" * 65

    print(f"\n{rule}")
    print(f" SchedulingOptEnv — Baseline Evaluation ({mode})")
    print(f"{rule}\n")

    results: dict[str, Any] = {"mode": mode, "tasks": {}}

    def _close_task(task_id: str, scores: list[float]) -> float:
        """Record one task's summary in ``results`` and return its mean."""
        # An empty score list averages to 0.0 rather than dividing by zero.
        mean = sum(scores) / len(scores) if scores else 0.0
        results["tasks"][task_id] = {
            "average_score": round(mean, 4),
            "num_instances": len(scores),
            "scores": scores,
        }
        print(f" >> Average: {mean:.3f}\n")
        return mean

    feasibility_prompt = (
        "You are a scheduling expert. Determine if the proposed schedule "
        "satisfies all constraints. Reply with ONLY 'feasible' or 'infeasible'."
    )
    classification_prompt = (
        "You are a scheduling expert. Identify the constraint violation type. "
        "Reply with ONLY one of: resource_overload, deadline_violation, "
        "precedence_violation, availability_conflict, capacity_exceeded."
    )
    repair_prompt = (
        "You are a scheduling expert. Repair the infeasible schedule by "
        "returning a JSON object with key 'assignments': a list of "
        '{"job_id", "machine_id", "start_time"} dicts that satisfies all '
        "constraints and minimises makespan. Return ONLY valid JSON."
    )

    # ----- Task 1: Feasibility Check (every instance) -----
    feas_grader = FeasibilityGrader()
    feas_scores: list[float] = []
    print("Task 1: Feasibility Check (easy)")
    for idx, entry in enumerate(INSTANCE_BANK):
        payload = json.dumps(entry["instance"], indent=2)
        if use_llm:
            answer = _llm_response(client, feasibility_prompt, payload)
        else:
            answer = _MOCK_FEASIBILITY.get(idx, "infeasible")
        score = feas_grader.grade(
            Action(response=answer, task_id="feasibility_check"), entry
        )
        feas_scores.append(score)
        verdict = "CORRECT" if score >= 0.95 else "wrong"
        if entry["is_feasible"]:
            expected = "feasible"
        else:
            expected = "infeasible"
        blurb = entry["description"][:45]
        print(
            f" Instance {idx:2d}: {verdict:7s} (score={score:.2f}) "
            f"expected={expected} [{blurb}]"
        )
    avg_feas = _close_task("feasibility_check", feas_scores)

    # ----- Task 2: Conflict Classification (infeasible instances only) -----
    conf_grader = ConflictGrader()
    conf_scores: list[float] = []
    infeasible_entries = [
        pair for pair in enumerate(INSTANCE_BANK) if not pair[1]["is_feasible"]
    ]
    print("Task 2: Conflict Classification (medium)")
    for idx, entry in infeasible_entries:
        payload = json.dumps(entry["instance"], indent=2)
        if use_llm:
            answer = _llm_response(client, classification_prompt, payload)
        else:
            answer = _MOCK_CLASSIFICATION.get(idx, "resource_overload")
        score = conf_grader.grade(
            Action(response=answer, task_id="conflict_classification"), entry
        )
        conf_scores.append(score)
        if score >= 0.95:
            verdict = "EXACT"
        elif score >= 0.45:
            verdict = "partial"
        else:
            verdict = "wrong"
        print(
            f" Instance {idx:2d}: {verdict:7s} (score={score:.2f}) "
            f"expected={entry['violation_type']}"
        )
    avg_conf = _close_task("conflict_classification", conf_scores)

    # ----- Task 3: Schedule Repair (infeasible instances with a known optimum) -----
    repair_grader = RepairGrader()
    repair_scores: list[float] = []
    repairable = [
        pair
        for pair in enumerate(INSTANCE_BANK)
        if not pair[1]["is_feasible"] and pair[1].get("optimal_schedule")
    ]
    print("Task 3: Schedule Repair (hard)")
    for idx, entry in repairable:
        payload = json.dumps(entry["instance"], indent=2)
        if use_llm:
            answer = _llm_response(client, repair_prompt, payload)
        else:
            answer = _mock_repair(idx)
        score = repair_grader.grade(
            Action(response=answer, task_id="schedule_repair"), entry
        )
        repair_scores.append(score)
        print(
            f" Instance {idx:2d}: score={score:.2f} "
            f"optimal_makespan={entry['optimal_makespan']} "
            f"[{entry['description'][:45]}]"
        )
    avg_repair = _close_task("schedule_repair", repair_scores)

    # ----- Summary: unweighted mean of the three task averages -----
    overall = (avg_feas + avg_conf + avg_repair) / 3
    results["overall_average"] = round(overall, 4)
    print(f"{rule}")
    print(f" Overall Average Score: {overall:.3f}")
    print(f"{rule}\n")
    return results
if __name__ == "__main__":
    # Script entry point: run the baseline once. Any failure is reported on
    # stderr and turned into exit status 1 instead of a traceback.
    try:
        run_baseline()
    except Exception as err:
        sys.stderr.write(f"Baseline failed: {err}\n")
        raise SystemExit(1)