Vittal-M committed on
Commit
d51b70a
·
verified ·
1 Parent(s): 34992b8

Upload graders/grader_fix.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. graders/grader_fix.py +324 -0
graders/grader_fix.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Grader for Task 3 — Schedule Repair (hard).
2
+
3
+ Scoring breakdown (additive, max 1.0)
4
+ --------------------------------------
5
+ 0.20 — response is parseable JSON
6
+ 0.20 — JSON has the required schema (assignments list, all jobs covered)
7
+ 0.40 — schedule satisfies all constraints (0.10 per category):
8
+ capacity, deadlines, precedence, availability
9
+ 0.20 — makespan within 30% of optimal (0.10 partial if within 60%)
10
+
11
+ Partial-progress signal
12
+ -----------------------
13
+ A response that cannot be parsed as a JSON object earns 0.0 (wrong format).
14
+ A parseable but schema-invalid JSON earns 0.20 (gave a JSON object).
15
+ A valid schema earns 0.40, plus 0.10 per satisfied constraint category (up to 0.80 before the makespan bonus).
16
+ This dense reward curve supports multi-step improvement within an episode.
17
+
18
+ After each call, ``last_breakdown`` holds a full dict with per-category
19
+ pass/fail flags, makespan, and the optimality ratio — surfaced in the
20
+ environment's info dict.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import json
26
+ import re
27
+ from typing import Any
28
+
29
+ from models import Action
30
+
31
+
32
class RepairGrader:
    """Grade the agent's proposed schedule repair.

    The score is additive and capped at 1.0:

    * 0.20 — the response contains a parseable JSON object;
    * 0.20 — the object matches the schema and covers every job exactly once;
    * 0.40 — 0.10 for each satisfied constraint category
      (capacity, deadlines, precedence, availability);
    * 0.20 — makespan within 30% of optimal (0.10 if within 60%).

    After every ``grade`` call, ``last_breakdown`` holds the per-component
    details of that call.
    """

    def __init__(self) -> None:
        # Details of the most recent grade() call; empty until first use.
        self.last_breakdown: dict[str, Any] = {}

    def grade(self, action: Action, ground_truth: dict[str, Any]) -> float:
        """Score ``action.response`` against ``ground_truth``.

        Args:
            action: Object whose ``response`` attribute holds the agent's
                answer text. ``None`` is treated like an empty response.
            ground_truth: Dict read for the keys ``"instance"`` (the
                scheduling problem) and ``"optimal_makespan"``.

        Returns:
            A score in ``[0.0, 1.0]`` rounded to four decimals.
        """
        # A missing response is graded like an empty one instead of crashing.
        response: str = (action.response or "").strip()
        instance: dict[str, Any] = ground_truth.get("instance", {})
        # ``or 1`` also covers an explicit 0/None so the ratio is well defined.
        optimal_makespan: int = int(ground_truth.get("optimal_makespan", 1) or 1)

        # Component 1a — parseable JSON object (0.20).
        parsed = self._parse_json(response) if response else None
        if parsed is None:
            self._record_breakdown(
                json_ok=False, schema_ok=False,
                constraint_detail={}, makespan=0,
                optimal_makespan=optimal_makespan,
            )
            return 0.0  # not JSON -> no partial credit at all

        score = 0.20

        # Component 1b — required schema (0.20):
        # {"assignments": [{"job_id", "machine_id", "start_time"}, ...]}
        # with every job of the instance present exactly once.
        assignments: list[Any] = parsed.get("assignments", [])
        if not self._valid_schema(assignments, instance):
            self._record_breakdown(
                json_ok=True, schema_ok=False,
                constraint_detail={}, makespan=0,
                optimal_makespan=optimal_makespan,
            )
            return round(score, 4)  # only 0.20

        score += 0.20

        # Component 2 — constraint satisfaction (0.40, split evenly over
        # the categories returned by _check_constraints_detail).
        constraint_detail = self._check_constraints_detail(assignments, instance)
        satisfied = sum(constraint_detail.values())
        score += 0.40 * (satisfied / max(len(constraint_detail), 1))

        # Component 3 — makespan optimality (0.20):
        # full credit if makespan <= 1.30 x optimal, half if <= 1.60 x.
        makespan = self._compute_makespan(assignments, instance)
        if makespan > 0 and optimal_makespan > 0:
            ratio = makespan / optimal_makespan
            if ratio <= 1.30:
                score += 0.20
            elif ratio <= 1.60:
                score += 0.10  # partial optimality credit

        self._record_breakdown(
            json_ok=True, schema_ok=True,
            constraint_detail=constraint_detail,
            makespan=makespan,
            optimal_makespan=optimal_makespan,
        )
        return round(max(0.0, min(1.0, score)), 4)

    # ------------------------------------------------------------------
    # Breakdown recording
    # ------------------------------------------------------------------

    def _record_breakdown(
        self,
        json_ok: bool,
        schema_ok: bool,
        constraint_detail: dict[str, bool],
        makespan: int,
        optimal_makespan: int,
    ) -> None:
        """Store the per-component results of a grading call.

        ``makespan_ratio`` is ``None`` whenever either makespan is not
        positive, in which case ``within_30pct`` is ``False``.
        """
        ratio = (
            round(makespan / optimal_makespan, 3)
            if (makespan > 0 and optimal_makespan > 0)
            else None
        )
        self.last_breakdown = {
            "json_parseable": json_ok,
            "schema_valid": schema_ok,
            "constraints": constraint_detail,
            "constraints_satisfied": sum(constraint_detail.values()),
            "makespan": makespan,
            "optimal_makespan": optimal_makespan,
            "makespan_ratio": ratio,
            "within_30pct": ratio is not None and ratio <= 1.30,
        }

    # ------------------------------------------------------------------
    # JSON parsing — robust to markdown fences and surrounding prose
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_json(response: str) -> dict[str, Any] | None:
        """Try several strategies to extract a JSON object from ``response``.

        Strategy 1: parse the whole response directly.
        Strategy 2: remove markdown code fences, then parse.
        Strategy 3: attempt a decode at each ``{`` and return the first
            object that parses. Unlike plain brace counting this is not
            confused by braces inside string values, and it tolerates
            prose before and after the object.

        Returns:
            The decoded dict, or ``None`` if no JSON object was found. A
            response that parses as a non-object (list, number, ...) in
            strategy 1 or 2 also yields ``None``.
        """
        # RecursionError guards against pathologically nested input.
        parse_errors = (ValueError, RecursionError)

        # Strategy 1 — direct parse.
        try:
            obj = json.loads(response)
        except parse_errors:
            pass
        else:
            return obj if isinstance(obj, dict) else None

        # Strategy 2 — strip code fences.
        stripped = re.sub(r"```(?:json)?", "", response).strip()
        try:
            obj = json.loads(stripped)
        except parse_errors:
            pass
        else:
            return obj if isinstance(obj, dict) else None

        # Strategy 3 — decode from each "{" until one succeeds.
        decoder = json.JSONDecoder()
        pos = response.find("{")
        while pos != -1:
            try:
                obj, _ = decoder.raw_decode(response, pos)
            except parse_errors:
                pos = response.find("{", pos + 1)
                continue
            # A successful decode that starts at "{" is always an object.
            return obj if isinstance(obj, dict) else None
        return None

    # ------------------------------------------------------------------
    # Schema validation
    # ------------------------------------------------------------------

    @staticmethod
    def _valid_schema(
        assignments: list[Any], instance: dict[str, Any]
    ) -> bool:
        """Validate that ``assignments`` is well formed and covers all jobs.

        Each entry must be a dict with ``job_id``, ``machine_id`` and
        ``start_time``. ``start_time`` must be a finite, non-negative number
        (booleans, NaN and infinities are rejected because later arithmetic
        cannot handle them). The ids must be hashable JSON scalars. Every
        job id in ``instance["jobs"]`` must appear exactly once.
        """
        import math

        if not isinstance(assignments, list) or not assignments:
            return False

        required_keys = {"job_id", "machine_id", "start_time"}
        for entry in assignments:
            if not isinstance(entry, dict) or not required_keys.issubset(entry):
                return False

            start = entry["start_time"]
            # bool is a subclass of int; JSON true/false is not a time.
            if isinstance(start, bool) or not isinstance(start, (int, float)):
                return False
            try:
                start_value = float(start)
            except OverflowError:  # integer too large to be a float
                return False
            # json.loads accepts NaN/Infinity; neither is a usable time.
            if not math.isfinite(start_value) or start_value < 0:
                return False

            # Lists/dicts cannot be used as set members or dict keys later.
            if isinstance(entry["job_id"], (dict, list)):
                return False
            if isinstance(entry["machine_id"], (dict, list)):
                return False

        # Every job in the instance must appear exactly once.
        expected_jobs = {j["id"] for j in instance.get("jobs", [])}
        assigned_jobs = [entry["job_id"] for entry in assignments]
        return (
            len(assigned_jobs) == len(expected_jobs)
            and set(assigned_jobs) == expected_jobs
        )

    # ------------------------------------------------------------------
    # Constraint checking (returns per-category bool dict)
    # ------------------------------------------------------------------

    @staticmethod
    def _check_constraints_detail(
        assignments: list[dict[str, Any]], instance: dict[str, Any]
    ) -> dict[str, bool]:
        """Return ``{constraint_name: passed}`` for the four categories.

        Expects ``assignments`` to have passed ``_valid_schema``.

        * capacity — the number of jobs running at the same instant on a
          machine never exceeds its ``capacity`` (default 1). Intervals are
          half-open, so a job may start exactly when another ends.
        * deadlines — every job finishes by its ``deadline`` (if any).
        * precedence — every job starts no earlier than the finish of each
          job in its ``dependencies``.
        * availability — every job runs inside its machine's
          ``[available_start, available_end]`` window, and, when the
          instance declares machines, on a declared machine.
        """
        jobs_by_id = {j["id"]: j for j in instance.get("jobs", [])}
        machines_by_id = {m["id"]: m for m in instance.get("machines", [])}
        assign_by_job = {a["job_id"]: a for a in assignments}

        def finish_time(assignment: dict[str, Any]) -> float:
            job = jobs_by_id.get(assignment["job_id"], {})
            return float(assignment["start_time"]) + float(job.get("duration", 0))

        # ---- (a) Capacity: sweep over start/end events per machine ----
        events_by_machine: dict[Any, list[tuple[float, int]]] = {}
        for a in assignments:
            start = float(a["start_time"])
            # A job without a duration is assumed to occupy one time unit.
            duration = float(jobs_by_id.get(a["job_id"], {}).get("duration", 1))
            end = start + duration
            if end <= start:
                continue  # empty interval occupies no capacity
            events = events_by_machine.setdefault(a["machine_id"], [])
            events.append((start, 1))
            events.append((end, -1))

        capacity_ok = True
        for mid, events in events_by_machine.items():
            cap = machines_by_id.get(mid, {}).get("capacity", 1)
            # Tuples sort (t, -1) before (t, +1): a slot is released before
            # it is taken again at the same instant.
            events.sort()
            running = 0
            for _, delta in events:
                running += delta
                if running > cap:
                    capacity_ok = False
                    break
            if not capacity_ok:
                break

        # ---- (b) Deadlines: every job finishes by its deadline ----
        deadline_ok = True
        for a in assignments:
            deadline = jobs_by_id.get(a["job_id"], {}).get("deadline")
            # A missing or null deadline means the job is unconstrained.
            if deadline is not None and finish_time(a) > deadline:
                deadline_ok = False
                break

        # ---- (c) Precedence: start after ALL predecessors finish ----
        precedence_ok = True
        for a in assignments:
            job = jobs_by_id.get(a["job_id"], {})
            for dep_id in job.get("dependencies", []):
                dep_a = assign_by_job.get(dep_id)
                # An unscheduled predecessor can never have finished.
                if dep_a is None or float(a["start_time"]) < finish_time(dep_a):
                    precedence_ok = False
                    break
            if not precedence_ok:
                break

        # ---- (d) Availability: job runs within the machine's window ----
        availability_ok = True
        for a in assignments:
            machine = machines_by_id.get(a["machine_id"])
            if machine is None:
                if machines_by_id:
                    # The instance lists its machines and this is not one
                    # of them; an invented machine is never available.
                    availability_ok = False
                    break
                machine = {}
            avail_start = float(machine.get("available_start", 0))
            avail_end = float(machine.get("available_end", float("inf")))
            job_start = float(a["start_time"])
            if job_start < avail_start or finish_time(a) > avail_end:
                availability_ok = False
                break

        return {
            "capacity": capacity_ok,
            "deadlines": deadline_ok,
            "precedence": precedence_ok,
            "availability": availability_ok,
        }

    @staticmethod
    def _check_constraints(
        assignments: list[dict[str, Any]], instance: dict[str, Any]
    ) -> float:
        """Convenience wrapper — returns the fraction of categories satisfied."""
        detail = RepairGrader._check_constraints_detail(assignments, instance)
        return sum(detail.values()) / max(len(detail), 1)

    # ------------------------------------------------------------------
    # Makespan calculation
    # ------------------------------------------------------------------

    @staticmethod
    def _compute_makespan(
        assignments: list[dict[str, Any]], instance: dict[str, Any]
    ) -> int:
        """Return the latest finish time across all assigned jobs.

        Finish times are computed in floating point and rounded *up* to the
        next integer, so a fractional start time or duration can never make
        a schedule look shorter than it is. Returns 0 for no assignments.
        """
        import math

        jobs_by_id = {j["id"]: j for j in instance.get("jobs", [])}
        max_finish = 0
        for a in assignments:
            job = jobs_by_id.get(a["job_id"], {})
            finish = math.ceil(
                float(a["start_time"]) + float(job.get("duration", 0))
            )
            max_finish = max(max_finish, finish)
        return max_finish