Vittal-M commited on
Commit
32a2564
·
verified ·
1 Parent(s): cbe9d96

Upload graders/grader_classification.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. graders/grader_classification.py +107 -0
graders/grader_classification.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Grader for Task 2 — Conflict Classification (medium).
2
+
3
+ Scoring
4
+ -------
5
+ 1.0 — exact match with the ground-truth violation type
6
+ 0.5 — same constraint family (resource-limit or temporal-ordering)
7
+ 0.1 — valid category but from a different family
8
+ 0.0 — empty or completely unrecognised response
9
+
10
+ Constraint families (related groups for partial credit)
11
+ -------------------------------------------------------
12
+ Resource-limit family : resource_overload, capacity_exceeded
13
+ Both concern the number of jobs concurrently on a machine.
14
+ Temporal-ordering family : deadline_violation, precedence_violation
15
+ Both concern the sequencing and timing of job execution.
16
+ Standalone : availability_conflict
17
+ Concerns machine operational windows (no close sibling).
18
+
19
+ After each call, ``last_breakdown`` holds a dict describing the decision.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ from typing import Any
25
+
26
+ from models import Action
27
+
28
+ VALID_CATEGORIES: frozenset[str] = frozenset(
29
+ {
30
+ "resource_overload",
31
+ "deadline_violation",
32
+ "precedence_violation",
33
+ "availability_conflict",
34
+ "capacity_exceeded",
35
+ }
36
+ )
37
+
38
+ # Groups of semantically related categories; membership earns partial credit.
39
+ _RELATED_GROUPS: list[frozenset[str]] = [
40
+ frozenset({"resource_overload", "capacity_exceeded"}), # resource-limit family
41
+ frozenset({"deadline_violation", "precedence_violation"}), # temporal-ordering family
42
+ ]
43
+
44
+
45
+ def _same_family(a: str, b: str) -> bool:
46
+ """Return True if a and b belong to the same related group."""
47
+ return any(a in g and b in g for g in _RELATED_GROUPS)
48
+
49
+
50
+ class ConflictGrader:
51
+ """Grade the agent's constraint-violation classification."""
52
+
53
+ def __init__(self) -> None:
54
+ self.last_breakdown: dict[str, Any] = {}
55
+
56
+ def grade(self, action: Action, ground_truth: dict[str, Any]) -> float:
57
+ # Normalise to snake_case (agents often write "deadline violation" etc.)
58
+ response: str = (
59
+ action.response.strip().lower().replace(" ", "_").replace("-", "_")
60
+ )
61
+ expected: str = ground_truth.get("violation_type") or ""
62
+
63
+ if not response:
64
+ self._record("", expected, 0.0, "Empty response.")
65
+ return 0.0
66
+
67
+ # Exact match
68
+ if response == expected:
69
+ self._record(response, expected, 1.0, "Exact match.")
70
+ return 1.0
71
+
72
+ # Not in vocabulary
73
+ if response not in VALID_CATEGORIES:
74
+ self._record(
75
+ response, expected, 0.0,
76
+ f"'{response}' is not a valid category. "
77
+ f"Choose from: {', '.join(sorted(VALID_CATEGORIES))}.",
78
+ )
79
+ return 0.0
80
+
81
+ # Same constraint family → partial credit
82
+ if _same_family(response, expected):
83
+ self._record(
84
+ response, expected, 0.5,
85
+ f"Related category (same family as '{expected}').",
86
+ )
87
+ return 0.5
88
+
89
+ # Valid but different family
90
+ self._record(
91
+ response, expected, 0.1,
92
+ f"Valid category but wrong family. Expected '{expected}'.",
93
+ )
94
+ return 0.1
95
+
96
+ def _record(
97
+ self, predicted: str, expected: str, score: float, feedback: str
98
+ ) -> None:
99
+ self.last_breakdown = {
100
+ "predicted": predicted,
101
+ "expected": expected,
102
+ "score": score,
103
+ "in_valid_categories": predicted in VALID_CATEGORIES,
104
+ "same_family": _same_family(predicted, expected) if predicted and expected else False,
105
+ "exact_match": predicted == expected,
106
+ "feedback": feedback,
107
+ }