Vittal-M commited on
Commit
325052f
·
verified ·
1 Parent(s): ac4518f

Upload baseline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. baseline.py +232 -0
baseline.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Baseline inference script for the Scheduling Optimisation Environment.
2
+
3
+ Runs GPT-4o-mini (or falls back to deterministic mock responses) against all
4
+ three tasks and prints a structured score report.
5
+
6
+ Usage:
7
+ OPENAI_API_KEY=sk-... python baseline.py
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ import os
14
+ import sys
15
+ from typing import Any
16
+
17
+ from environment import INSTANCE_BANK
18
+ from graders.grader_classification import ConflictGrader
19
+ from graders.grader_detection import FeasibilityGrader
20
+ from graders.grader_fix import RepairGrader
21
+ from models import Action
22
+
23
+
24
+ def _get_openai_client():
25
+ """Return an OpenAI client, or None if unavailable."""
26
+ api_key = os.environ.get("OPENAI_API_KEY", "")
27
+ if not api_key:
28
+ return None
29
+ try:
30
+ from openai import OpenAI
31
+ return OpenAI(api_key=api_key)
32
+ except Exception:
33
+ return None
34
+
35
+
36
+ def _llm_response(client, system_prompt: str, user_prompt: str) -> str:
37
+ """Call GPT-4o-mini and return the response text."""
38
+ try:
39
+ resp = client.chat.completions.create(
40
+ model="gpt-4o-mini",
41
+ messages=[
42
+ {"role": "system", "content": system_prompt},
43
+ {"role": "user", "content": user_prompt},
44
+ ],
45
+ max_tokens=1024,
46
+ temperature=0.0,
47
+ )
48
+ return resp.choices[0].message.content.strip()
49
+ except Exception as e:
50
+ print(f" [LLM error: {e}]")
51
+ return ""
52
+
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # Mock fallback responses (used when no API key is available)
56
+ # ---------------------------------------------------------------------------
57
+
58
+ # Ground-truth feasibility labels — index aligns with INSTANCE_BANK
59
+ _MOCK_FEASIBILITY: dict[int, str] = {
60
+ 0: "infeasible", 1: "infeasible", 2: "infeasible", 3: "infeasible",
61
+ 4: "infeasible", 5: "infeasible", 6: "infeasible", 7: "infeasible",
62
+ 8: "infeasible", 9: "infeasible", 10: "feasible", 11: "feasible",
63
+ }
64
+
65
+ # Ground-truth violation types for infeasible instances
66
+ _MOCK_CLASSIFICATION: dict[int, str] = {
67
+ 0: "resource_overload",
68
+ 1: "deadline_violation",
69
+ 2: "precedence_violation",
70
+ 3: "availability_conflict",
71
+ 4: "capacity_exceeded",
72
+ 5: "resource_overload",
73
+ 6: "deadline_violation",
74
+ 7: "precedence_violation",
75
+ 8: "availability_conflict",
76
+ 9: "capacity_exceeded",
77
+ }
78
+
79
+
80
+ def _mock_repair(instance_idx: int) -> str:
81
+ """Return the known optimal schedule JSON for mock mode."""
82
+ entry = INSTANCE_BANK[instance_idx]
83
+ optimal = entry.get("optimal_schedule", {})
84
+ if not optimal:
85
+ # Return the proposed schedule unchanged as a safe fallback
86
+ optimal = entry["instance"].get("proposed_schedule", {})
87
+ return json.dumps(optimal)
88
+
89
+
90
+ # ---------------------------------------------------------------------------
91
+ # Baseline runner
92
+ # ---------------------------------------------------------------------------
93
+
94
+
95
+ def run_baseline() -> dict[str, Any]:
96
+ """Execute the baseline across all three tasks and return scores."""
97
+ client = _get_openai_client()
98
+ use_llm = client is not None
99
+ mode = "GPT-4o-mini" if use_llm else "mock (no API key — oracle responses)"
100
+ print(f"\n{'='*65}")
101
+ print(f" SchedulingOptEnv — Baseline Evaluation ({mode})")
102
+ print(f"{'='*65}\n")
103
+
104
+ results: dict[str, Any] = {"mode": mode, "tasks": {}}
105
+
106
+ # ----- Task 1: Feasibility Check -----
107
+ feas_grader = FeasibilityGrader()
108
+ feas_scores: list[float] = []
109
+ print("Task 1: Feasibility Check (easy)")
110
+ for i, entry in enumerate(INSTANCE_BANK):
111
+ instance_str = json.dumps(entry["instance"], indent=2)
112
+ if use_llm:
113
+ resp = _llm_response(
114
+ client,
115
+ (
116
+ "You are a scheduling expert. Determine if the proposed schedule "
117
+ "satisfies all constraints. Reply with ONLY 'feasible' or 'infeasible'."
118
+ ),
119
+ instance_str,
120
+ )
121
+ else:
122
+ resp = _MOCK_FEASIBILITY.get(i, "infeasible")
123
+ action = Action(response=resp, task_id="feasibility_check")
124
+ score = feas_grader.grade(action, entry)
125
+ feas_scores.append(score)
126
+ status = "CORRECT" if score >= 0.95 else "wrong"
127
+ expected = "feasible" if entry["is_feasible"] else "infeasible"
128
+ print(
129
+ f" Instance {i:2d}: {status:7s} (score={score:.2f}) "
130
+ f"expected={expected} [{entry['description'][:45]}]"
131
+ )
132
+
133
+ avg_feas = sum(feas_scores) / len(feas_scores) if feas_scores else 0.0
134
+ results["tasks"]["feasibility_check"] = {
135
+ "average_score": round(avg_feas, 4),
136
+ "num_instances": len(feas_scores),
137
+ "scores": feas_scores,
138
+ }
139
+ print(f" >> Average: {avg_feas:.3f}\n")
140
+
141
+ # ----- Task 2: Conflict Classification -----
142
+ conf_grader = ConflictGrader()
143
+ conf_scores: list[float] = []
144
+ infeasible_entries = [(i, e) for i, e in enumerate(INSTANCE_BANK) if not e["is_feasible"]]
145
+ print("Task 2: Conflict Classification (medium)")
146
+ for i, entry in infeasible_entries:
147
+ instance_str = json.dumps(entry["instance"], indent=2)
148
+ if use_llm:
149
+ resp = _llm_response(
150
+ client,
151
+ (
152
+ "You are a scheduling expert. Identify the constraint violation type. "
153
+ "Reply with ONLY one of: resource_overload, deadline_violation, "
154
+ "precedence_violation, availability_conflict, capacity_exceeded."
155
+ ),
156
+ instance_str,
157
+ )
158
+ else:
159
+ resp = _MOCK_CLASSIFICATION.get(i, "resource_overload")
160
+ action = Action(response=resp, task_id="conflict_classification")
161
+ score = conf_grader.grade(action, entry)
162
+ conf_scores.append(score)
163
+ status = "EXACT" if score >= 0.95 else ("partial" if score >= 0.45 else "wrong")
164
+ print(
165
+ f" Instance {i:2d}: {status:7s} (score={score:.2f}) "
166
+ f"expected={entry['violation_type']}"
167
+ )
168
+
169
+ avg_conf = sum(conf_scores) / len(conf_scores) if conf_scores else 0.0
170
+ results["tasks"]["conflict_classification"] = {
171
+ "average_score": round(avg_conf, 4),
172
+ "num_instances": len(conf_scores),
173
+ "scores": conf_scores,
174
+ }
175
+ print(f" >> Average: {avg_conf:.3f}\n")
176
+
177
+ # ----- Task 3: Schedule Repair -----
178
+ repair_grader = RepairGrader()
179
+ repair_scores: list[float] = []
180
+ repairable = [
181
+ (i, e) for i, e in enumerate(INSTANCE_BANK)
182
+ if not e["is_feasible"] and e.get("optimal_schedule")
183
+ ]
184
+ print("Task 3: Schedule Repair (hard)")
185
+ for i, entry in repairable:
186
+ instance_str = json.dumps(entry["instance"], indent=2)
187
+ if use_llm:
188
+ resp = _llm_response(
189
+ client,
190
+ (
191
+ "You are a scheduling expert. Repair the infeasible schedule by "
192
+ "returning a JSON object with key 'assignments': a list of "
193
+ '{"job_id", "machine_id", "start_time"} dicts that satisfies all '
194
+ "constraints and minimises makespan. Return ONLY valid JSON."
195
+ ),
196
+ instance_str,
197
+ )
198
+ else:
199
+ resp = _mock_repair(i)
200
+ action = Action(response=resp, task_id="schedule_repair")
201
+ score = repair_grader.grade(action, entry)
202
+ repair_scores.append(score)
203
+ print(
204
+ f" Instance {i:2d}: score={score:.2f} "
205
+ f"optimal_makespan={entry['optimal_makespan']} "
206
+ f"[{entry['description'][:45]}]"
207
+ )
208
+
209
+ avg_repair = sum(repair_scores) / len(repair_scores) if repair_scores else 0.0
210
+ results["tasks"]["schedule_repair"] = {
211
+ "average_score": round(avg_repair, 4),
212
+ "num_instances": len(repair_scores),
213
+ "scores": repair_scores,
214
+ }
215
+ print(f" >> Average: {avg_repair:.3f}\n")
216
+
217
+ # ----- Summary -----
218
+ overall = (avg_feas + avg_conf + avg_repair) / 3
219
+ results["overall_average"] = round(overall, 4)
220
+ print(f"{'='*65}")
221
+ print(f" Overall Average Score: {overall:.3f}")
222
+ print(f"{'='*65}\n")
223
+
224
+ return results
225
+
226
+
227
+ if __name__ == "__main__":
228
+ try:
229
+ run_baseline()
230
+ except Exception as e:
231
+ print(f"Baseline failed: {e}", file=sys.stderr)
232
+ sys.exit(1)