SamaKool commited on
Commit
bde1135
·
2 Parent(s): d28a5cee209e50

moved code into all the grader flies from just one file and fixed the import name 'FinAuditorGrader' which was used for every task difficulty

Browse files
.gitignore CHANGED
@@ -25,3 +25,4 @@ __pycache__/
25
  .venv/
26
  venv/
27
  .envlogs/
 
 
25
  .venv/
26
  venv/
27
  .envlogs/
28
+ .git_tokens
graders/__init__.py CHANGED
@@ -1,16 +1,18 @@
1
- """Graders package for OpenEnv environments.
2
 
3
  Exports
4
  -------
5
- FinAuditorGrader — HFT Auditor: asymmetric TP/FP/FN weighting (in grader_detection)
6
- FeasibilityGrader Task 1: binary feasible / infeasible
7
- ConflictGrader — Task 2: 5-class constraint-violation classification
8
- RepairGrader — Task 3: multi-component schedule repair
9
  """
10
 
11
- from graders.grader_detection import FeasibilityGrader, FinAuditorGrader
12
- from graders.grader_classification import ConflictGrader
13
- from graders.grader_fix import RepairGrader
14
-
15
- __all__ = ["FinAuditorGrader", "FeasibilityGrader", "ConflictGrader", "RepairGrader"]
16
 
 
 
 
 
 
 
1
+ """Graders package for Elite-Trade-Sentry HFT environments.
2
 
3
  Exports
4
  -------
5
+ EasyDetectionGrader - Task 1: Forgiving penalties (0.1 FP / 0.2 FN).
6
+ MediumClassificationGrader - Task 2: Standard HFT penalties (0.2 FP / 0.4 FN).
7
+ HardFixGrader - Task 3: Brutal adversarial penalties (0.4 FP / 0.8 FN).
 
8
  """
9
 
10
+ from graders.grader_detection import EasyDetectionGrader
11
+ from graders.grader_classification import MediumClassificationGrader
12
+ from graders.grader_fix import HardFixGrader
 
 
13
 
14
+ __all__ = [
15
+ "EasyDetectionGrader",
16
+ "MediumClassificationGrader",
17
+ "HardFixGrader"
18
+ ]
graders/grader_classification.py CHANGED
@@ -1,107 +1,34 @@
1
- """Grader for Task 2 — Conflict Classification (medium).
2
-
3
- Scoring
4
- -------
5
- 1.0 — exact match with the ground-truth violation type
6
- 0.5 — same constraint family (resource-limit or temporal-ordering)
7
- 0.1 — valid category but from a different family
8
- 0.0 — empty or completely unrecognised response
9
-
10
- Constraint families (related groups for partial credit)
11
- -------------------------------------------------------
12
- Resource-limit family : resource_overload, capacity_exceeded
13
- Both concern the number of jobs concurrently on a machine.
14
- Temporal-ordering family : deadline_violation, precedence_violation
15
- Both concern the sequencing and timing of job execution.
16
- Standalone : availability_conflict
17
- Concerns machine operational windows (no close sibling).
18
-
19
- After each call, ``last_breakdown`` holds a dict describing the decision.
20
- """
21
-
22
  from __future__ import annotations
23
-
24
  from typing import Any
25
 
26
- from models import Action
27
-
28
- VALID_CATEGORIES: frozenset[str] = frozenset(
29
- {
30
- "resource_overload",
31
- "deadline_violation",
32
- "precedence_violation",
33
- "availability_conflict",
34
- "capacity_exceeded",
35
- }
36
- )
37
-
38
- # Groups of semantically related categories; membership earns partial credit.
39
- _RELATED_GROUPS: list[frozenset[str]] = [
40
- frozenset({"resource_overload", "capacity_exceeded"}), # resource-limit family
41
- frozenset({"deadline_violation", "precedence_violation"}), # temporal-ordering family
42
- ]
43
 
44
-
45
- def _same_family(a: str, b: str) -> bool:
46
- """Return True if a and b belong to the same related group."""
47
- return any(a in g and b in g for g in _RELATED_GROUPS)
48
-
49
-
50
- class ConflictGrader:
51
- """Grade the agent's constraint-violation classification."""
52
 
53
  def __init__(self) -> None:
54
  self.last_breakdown: dict[str, Any] = {}
55
 
56
- def grade(self, action: Action, ground_truth: dict[str, Any]) -> float:
57
- # Normalise to snake_case (agents often write "deadline violation" etc.)
58
- response: str = (
59
- action.response.strip().lower().replace(" ", "_").replace("-", "_")
60
- )
61
- expected: str = ground_truth.get("violation_type") or ""
62
-
63
- if not response:
64
- self._record("", expected, 0.0, "Empty response.")
65
- return 0.0
66
-
67
- # Exact match
68
- if response == expected:
69
- self._record(response, expected, 1.0, "Exact match.")
70
- return 1.0
71
 
72
- # Not in vocabulary
73
- if response not in VALID_CATEGORIES:
74
- self._record(
75
- response, expected, 0.0,
76
- f"'{response}' is not a valid category. "
77
- f"Choose from: {', '.join(sorted(VALID_CATEGORIES))}.",
78
- )
79
- return 0.0
80
 
81
- # Same constraint family partial credit
82
- if _same_family(response, expected):
83
- self._record(
84
- response, expected, 0.5,
85
- f"Related category (same family as '{expected}').",
86
- )
87
- return 0.5
88
 
89
- # Valid but different family
90
- self._record(
91
- response, expected, 0.1,
92
- f"Valid category but wrong family. Expected '{expected}'.",
93
- )
94
- return 0.1
95
 
96
- def _record(
97
- self, predicted: str, expected: str, score: float, feedback: str
98
- ) -> None:
99
- self.last_breakdown = {
100
- "predicted": predicted,
101
- "expected": expected,
102
- "score": score,
103
- "in_valid_categories": predicted in VALID_CATEGORIES,
104
- "same_family": _same_family(predicted, expected) if predicted and expected else False,
105
- "exact_match": predicted == expected,
106
- "feedback": feedback,
107
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
 
2
  from typing import Any
3
 
4
+ # MEDIUM MODE: Standard HFT penalties.
5
+ _TP_WEIGHT = 1.0
6
+ _TN_WEIGHT = 0.1
7
+ _FP_PENALTY = 0.2 # Stricter false positive
8
+ _FN_PENALTY = 0.4 # Standard catastrophic failure penalty
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ class MediumClassificationGrader:
11
+ """Grader for Task 2: Conflict Classification repurposed for HFT (Medium)."""
 
 
 
 
 
 
12
 
13
  def __init__(self) -> None:
14
  self.last_breakdown: dict[str, Any] = {}
15
 
16
+ def grade(self, state: Any, ground_truth: dict[str, Any] | None = None) -> float:
17
+ tp = float(getattr(state, "total_tp", 0))
18
+ tn = float(getattr(state, "total_tn", 0))
19
+ fp = float(getattr(state, "total_fp", 0))
20
+ fn = float(getattr(state, "total_fn", 0))
 
 
 
 
 
 
 
 
 
 
21
 
22
+ total = tp + tn + fp + fn
23
+ if total == 0:
24
+ return 0.01
 
 
 
 
 
25
 
26
+ positive_signal = (tp * _TP_WEIGHT) + (tn * _TN_WEIGHT)
27
+ negative_signal = (fp * _FP_PENALTY) + (fn * _FN_PENALTY)
 
 
 
 
 
28
 
29
+ max_signal = total * _TP_WEIGHT
30
+ raw_score = max(0.0, positive_signal - negative_signal) / max_signal
 
 
 
 
31
 
32
+ score = max(0.01, min(0.99, raw_score))
33
+ self.last_breakdown = {"tp": int(tp), "tn": int(tn), "fp": int(fp), "fn": int(fn), "score": score}
34
+ return score
 
 
 
 
 
 
 
 
 
graders/grader_detection.py CHANGED
@@ -1,143 +1,26 @@
1
- """Graders for detection-class tasks.
2
-
3
- FeasibilityGrader
4
- -----------------
5
- Grades Task 1 — binary feasible / infeasible schedule check.
6
- Scores: 1.0 exact match | 0.1 wrong answer | 0.0 empty.
7
-
8
- FinAuditorGrader
9
- ----------------
10
- Grades HFT Auditor episodes using C++ ReconciliationEngine metrics.
11
- Called by OpenEnv automatically when done=True.
12
- Score formula: asymmetric TP/FP/FN weighting, clamped strictly to [0.01, 0.99].
13
- """
14
-
15
  from __future__ import annotations
16
-
17
  from typing import Any
18
 
19
- from models import Action
 
 
 
 
20
 
21
- # Words treated as equivalent to "feasible"
22
- _FEASIBLE_WORDS: frozenset[str] = frozenset(
23
- {"feasible", "valid", "correct", "satisfiable", "yes", "ok", "pass"}
24
- )
25
-
26
- # Words treated as equivalent to "infeasible"
27
- _INFEASIBLE_WORDS: frozenset[str] = frozenset(
28
- {
29
- "infeasible", "invalid", "incorrect", "unsatisfiable", "no",
30
- "violated", "conflict", "fail", "impossible", "broken",
31
- }
32
- )
33
-
34
-
35
- class FeasibilityGrader:
36
- """Grade whether the agent correctly determined schedule feasibility."""
37
-
38
- def __init__(self) -> None:
39
- # Populated after each call to grade(); surfaced in env info dict.
40
- self.last_breakdown: dict[str, Any] = {}
41
-
42
- def grade(self, action: Action, ground_truth: dict[str, Any]) -> float:
43
- response: str = action.response.strip().lower()
44
- is_feasible: bool = ground_truth.get("is_feasible", False)
45
- expected: str = "feasible" if is_feasible else "infeasible"
46
-
47
- # Empty response → no signal
48
- if not response:
49
- self.last_breakdown = {
50
- "predicted": "",
51
- "expected": expected,
52
- "correct": False,
53
- "feedback": "Empty response — reply with 'feasible' or 'infeasible'.",
54
- }
55
- return 0.0
56
-
57
- # Normalise response to canonical form
58
- if response in _FEASIBLE_WORDS:
59
- predicted = "feasible"
60
- elif response in _INFEASIBLE_WORDS:
61
- predicted = "infeasible"
62
- else:
63
- # Recognisable attempt but could not be parsed cleanly
64
- self.last_breakdown = {
65
- "predicted": response,
66
- "expected": expected,
67
- "correct": False,
68
- "feedback": (
69
- f"Could not parse '{response}'. "
70
- "Use exactly 'feasible' or 'infeasible'."
71
- ),
72
- }
73
- return 0.1
74
-
75
- correct = predicted == expected
76
- self.last_breakdown = {
77
- "predicted": predicted,
78
- "expected": expected,
79
- "correct": correct,
80
- "feedback": (
81
- "Correct."
82
- if correct
83
- else f"Wrong — the schedule is {expected}, not {predicted}."
84
- ),
85
- }
86
- # Exact match → 1.0; wrong normalised answer → 0.1 (keeps gradient signal)
87
- return 1.0 if correct else 0.1
88
-
89
-
90
- # ── HFT Auditor Grader ────────────────────────────────────────────────────────
91
-
92
- # Asymmetric reward weights matching the C++ ReconciliationEngine constants
93
- _TP_WEIGHT: float = 1.0 # correctly flagged anomaly — full credit
94
- _TN_WEIGHT: float = 0.1 # correctly passed valid trade — small positive
95
- _FP_PENALTY: float = 0.1 # flagged a valid trade — minor penalty
96
- _FN_PENALTY: float = 0.4 # missed an anomaly — severe penalty
97
-
98
-
99
- class FinAuditorGrader:
100
- """Grade a completed HFT audit episode from C++ engine metrics.
101
-
102
- Called by OpenEnv automatically when ``done=True`` is returned by
103
- ``FinAuditorEnvironment.step()``.
104
-
105
- The score is computed from the cumulative confusion-matrix counters
106
- accumulated across the full episode by the C++ ReconciliationEngine:
107
-
108
- last_tp — True Positives (anomalous trade correctly flagged)
109
- last_tn — True Negatives (valid trade correctly passed)
110
- last_fp — False Positives (valid trade wrongly flagged)
111
- last_fn — False Negatives (anomalous trade missed — catastrophic)
112
-
113
- Hackathon rule: final score is strictly clamped to [0.01, 0.99].
114
- """
115
 
116
  def __init__(self) -> None:
117
  self.last_breakdown: dict[str, Any] = {}
118
 
119
  def grade(self, state: Any, ground_truth: dict[str, Any] | None = None) -> float:
120
- """Compute the final episode score.
121
-
122
- Reads cumulative ``total_*`` counters (full episode) when available,
123
- falling back to ``last_*`` (single-batch snapshot) for compatibility.
124
-
125
- Args:
126
- state: Environment state object at episode end.
127
- ground_truth: Unused — truth is implicit in the C++ engine.
128
-
129
- Returns:
130
- float strictly in (0.01, 0.99).
131
- """
132
- # Prefer full-episode accumulators; fall back to last-batch snapshot
133
- tp = float(getattr(state, "total_tp", None) or getattr(state, "last_tp", 0))
134
- tn = float(getattr(state, "total_tn", None) or getattr(state, "last_tn", 0))
135
- fp = float(getattr(state, "total_fp", None) or getattr(state, "last_fp", 0))
136
- fn = float(getattr(state, "total_fn", None) or getattr(state, "last_fn", 0))
137
 
138
  total = tp + tn + fp + fn
139
  if total == 0:
140
- self._record(tp, tn, fp, fn, 0.01, "No trades evaluated — floor score.")
141
  return 0.01
142
 
143
  positive_signal = (tp * _TP_WEIGHT) + (tn * _TN_WEIGHT)
@@ -147,28 +30,5 @@ class FinAuditorGrader:
147
  raw_score = max(0.0, positive_signal - negative_signal) / max_signal
148
 
149
  score = max(0.01, min(0.99, raw_score))
150
-
151
- self._record(
152
- tp, tn, fp, fn, score,
153
- f"tp={int(tp)} tn={int(tn)} fp={int(fp)} fn={int(fn)} | raw={raw_score:.4f}"
154
- )
155
- return score
156
-
157
- def _record(
158
- self,
159
- tp: float, tn: float, fp: float, fn: float,
160
- score: float, feedback: str,
161
- ) -> None:
162
- total = tp + tn + fp + fn
163
- self.last_breakdown = {
164
- "tp": int(tp),
165
- "tn": int(tn),
166
- "fp": int(fp),
167
- "fn": int(fn),
168
- "total": int(total),
169
- "precision": round(tp / (tp + fp), 4) if (tp + fp) > 0 else 0.0,
170
- "recall": round(tp / (tp + fn), 4) if (tp + fn) > 0 else 0.0,
171
- "score": round(score, 4),
172
- "feedback": feedback,
173
- }
174
- print(f"[GRADER] Episode scored: {feedback} => {score:.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
 
2
  from typing import Any
3
 
4
+ # EASY MODE: Forgiving penalties.
5
+ _TP_WEIGHT = 1.0
6
+ _TN_WEIGHT = 0.1
7
+ _FP_PENALTY = 0.1
8
+ _FN_PENALTY = 0.2
9
 
10
+ class EasyDetectionGrader:
11
+ """Grader for Task 1: Anomaly Detection (Easy)."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  def __init__(self) -> None:
14
  self.last_breakdown: dict[str, Any] = {}
15
 
16
  def grade(self, state: Any, ground_truth: dict[str, Any] | None = None) -> float:
17
+ tp = float(getattr(state, "total_tp", 0))
18
+ tn = float(getattr(state, "total_tn", 0))
19
+ fp = float(getattr(state, "total_fp", 0))
20
+ fn = float(getattr(state, "total_fn", 0))
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  total = tp + tn + fp + fn
23
  if total == 0:
 
24
  return 0.01
25
 
26
  positive_signal = (tp * _TP_WEIGHT) + (tn * _TN_WEIGHT)
 
30
  raw_score = max(0.0, positive_signal - negative_signal) / max_signal
31
 
32
  score = max(0.01, min(0.99, raw_score))
33
+ self.last_breakdown = {"tp": int(tp), "tn": int(tn), "fp": int(fp), "fn": int(fn), "score": score}
34
+ return score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
graders/grader_fix.py CHANGED
@@ -1,324 +1,34 @@
1
- """Grader for Task 3 — Schedule Repair (hard).
2
-
3
- Scoring breakdown (additive, max 1.0)
4
- --------------------------------------
5
- 0.20 — response is parseable JSON
6
- 0.20 — JSON has the required schema (assignments list, all jobs covered)
7
- 0.40 — schedule satisfies all constraints (0.10 per category):
8
- capacity, deadlines, precedence, availability
9
- 0.20 — makespan within 30% of optimal (0.10 partial if within 60%)
10
-
11
- Partial-progress signal
12
- -----------------------
13
- Even a structurally invalid JSON attempt earns 0.0 (wrong format).
14
- A parseable but schema-invalid JSON earns 0.20 (gave a JSON object).
15
- A valid schema with partial constraint satisfaction earns up to 0.80.
16
- This dense reward curve supports multi-step improvement within an episode.
17
-
18
- After each call, ``last_breakdown`` holds a full dict with per-category
19
- pass/fail flags, makespan, and the optimality ratio — surfaced in the
20
- environment's info dict.
21
- """
22
-
23
  from __future__ import annotations
24
-
25
- import json
26
- import re
27
  from typing import Any
28
 
29
- from models import Action
 
 
 
 
30
 
31
-
32
- class RepairGrader:
33
- """Grade the agent's proposed schedule repair."""
34
 
35
  def __init__(self) -> None:
36
  self.last_breakdown: dict[str, Any] = {}
37
 
38
- def grade(self, action: Action, ground_truth: dict[str, Any]) -> float:
39
- response: str = action.response.strip()
40
- instance: dict[str, Any] = ground_truth.get("instance", {})
41
- optimal_makespan: int = int(ground_truth.get("optimal_makespan", 1) or 1)
42
-
43
- if not response:
44
- self._record_breakdown(
45
- json_ok=False, schema_ok=False,
46
- constraint_detail={}, makespan=0,
47
- optimal_makespan=optimal_makespan,
48
- )
49
- return 0.0
50
-
51
- score = 0.0
52
-
53
- # ------------------------------------------------------------------
54
- # Component 1a — Is the response parseable JSON? (0.20)
55
- # ------------------------------------------------------------------
56
- parsed = self._parse_json(response)
57
- if parsed is None:
58
- self._record_breakdown(
59
- json_ok=False, schema_ok=False,
60
- constraint_detail={}, makespan=0,
61
- optimal_makespan=optimal_makespan,
62
- )
63
- return 0.0 # not JSON → no partial credit at all
64
-
65
- score += 0.20 # JSON parseable
66
-
67
- # ------------------------------------------------------------------
68
- # Component 1b — Does it have the required schema? (0.20)
69
- # Required: {"assignments": [{"job_id", "machine_id", "start_time"}, ...]}
70
- # All jobs from the instance must be present exactly once.
71
- # ------------------------------------------------------------------
72
- assignments: list[Any] = parsed.get("assignments", [])
73
- schema_ok = self._valid_schema(assignments, instance)
74
- if not schema_ok:
75
- self._record_breakdown(
76
- json_ok=True, schema_ok=False,
77
- constraint_detail={}, makespan=0,
78
- optimal_makespan=optimal_makespan,
79
- )
80
- return round(score, 4) # only 0.20
81
-
82
- score += 0.20 # valid schema
83
-
84
- # ------------------------------------------------------------------
85
- # Component 2 — Constraint satisfaction (0.40, 0.10 per category)
86
- # Categories: capacity, deadlines, precedence, availability
87
- # ------------------------------------------------------------------
88
- constraint_detail = self._check_constraints_detail(assignments, instance)
89
- satisfied = sum(constraint_detail.values())
90
- score += 0.40 * (satisfied / max(len(constraint_detail), 1))
91
-
92
- # ------------------------------------------------------------------
93
- # Component 3 — Makespan optimality (0.20)
94
- # Full 0.20 if makespan ≤ optimal × 1.30; partial 0.10 if ≤ 1.60.
95
- # ------------------------------------------------------------------
96
- makespan = self._compute_makespan(assignments, instance)
97
- if makespan > 0 and optimal_makespan > 0:
98
- ratio = makespan / optimal_makespan
99
- if ratio <= 1.30:
100
- score += 0.20
101
- elif ratio <= 1.60:
102
- score += 0.10 # partial optimality credit
103
-
104
- self._record_breakdown(
105
- json_ok=True, schema_ok=True,
106
- constraint_detail=constraint_detail,
107
- makespan=makespan,
108
- optimal_makespan=optimal_makespan,
109
- )
110
- return round(max(0.0, min(1.0, score)), 4)
111
-
112
- # ------------------------------------------------------------------
113
- # Breakdown recording
114
- # ------------------------------------------------------------------
115
-
116
- def _record_breakdown(
117
- self,
118
- json_ok: bool,
119
- schema_ok: bool,
120
- constraint_detail: dict[str, bool],
121
- makespan: int,
122
- optimal_makespan: int,
123
- ) -> None:
124
- ratio = (
125
- round(makespan / optimal_makespan, 3)
126
- if (makespan > 0 and optimal_makespan > 0)
127
- else None
128
- )
129
- self.last_breakdown = {
130
- "json_parseable": json_ok,
131
- "schema_valid": schema_ok,
132
- "constraints": constraint_detail,
133
- "constraints_satisfied": sum(constraint_detail.values()) if constraint_detail else 0,
134
- "makespan": makespan,
135
- "optimal_makespan": optimal_makespan,
136
- "makespan_ratio": ratio,
137
- "within_30pct": ratio is not None and ratio <= 1.30,
138
- }
139
-
140
- # ------------------------------------------------------------------
141
- # JSON parsing — robust to markdown fences and partial wrapping
142
- # ------------------------------------------------------------------
143
-
144
- @staticmethod
145
- def _parse_json(response: str) -> dict[str, Any] | None:
146
- """Try multiple strategies to extract a JSON object from the response.
147
-
148
- Strategy 1: Direct json.loads (agent returned pure JSON).
149
- Strategy 2: Strip markdown code fences, then parse.
150
- Strategy 3: Brace-counting to find the outermost {...} block.
151
- This is the most robust and handles agents that wrap JSON
152
- in prose like "Here is my answer: {...}".
153
- """
154
- # Strategy 1 — direct parse
155
- try:
156
- obj = json.loads(response)
157
- return obj if isinstance(obj, dict) else None
158
- except (json.JSONDecodeError, ValueError):
159
- pass
160
-
161
- # Strategy 2 — strip code fences
162
- stripped = re.sub(r"```(?:json)?", "", response).replace("```", "").strip()
163
- try:
164
- obj = json.loads(stripped)
165
- return obj if isinstance(obj, dict) else None
166
- except (json.JSONDecodeError, ValueError):
167
- pass
168
-
169
- # Strategy 3 — brace-counting for the outermost { ... }
170
- start = response.find("{")
171
- if start == -1:
172
- return None
173
- depth = 0
174
- for i, ch in enumerate(response[start:], start):
175
- if ch == "{":
176
- depth += 1
177
- elif ch == "}":
178
- depth -= 1
179
- if depth == 0:
180
- candidate = response[start : i + 1]
181
- try:
182
- obj = json.loads(candidate)
183
- return obj if isinstance(obj, dict) else None
184
- except (json.JSONDecodeError, ValueError):
185
- return None
186
- return None
187
-
188
- # ------------------------------------------------------------------
189
- # Schema validation
190
- # ------------------------------------------------------------------
191
-
192
- @staticmethod
193
- def _valid_schema(
194
- assignments: list[Any], instance: dict[str, Any]
195
- ) -> bool:
196
- """Validate that assignments is a well-formed list covering all jobs."""
197
- if not isinstance(assignments, list) or len(assignments) == 0:
198
- return False
199
-
200
- required_keys = {"job_id", "machine_id", "start_time"}
201
- for a in assignments:
202
- if not isinstance(a, dict):
203
- return False
204
- if not required_keys.issubset(a.keys()):
205
- return False
206
- if not isinstance(a.get("start_time"), (int, float)):
207
- return False
208
- if a.get("start_time") < 0:
209
- return False # negative start times are never valid
210
-
211
- # Every job in the instance must appear exactly once
212
- expected_jobs = {j["id"] for j in instance.get("jobs", [])}
213
- assigned_jobs = [a["job_id"] for a in assignments]
214
- return set(assigned_jobs) == expected_jobs and len(assigned_jobs) == len(expected_jobs)
215
-
216
- # ------------------------------------------------------------------
217
- # Constraint checking (returns per-category bool dict)
218
- # ------------------------------------------------------------------
219
-
220
- @staticmethod
221
- def _check_constraints_detail(
222
- assignments: list[dict[str, Any]], instance: dict[str, Any]
223
- ) -> dict[str, bool]:
224
- """Return a dict of {constraint_name: passed} for each of the 4 categories."""
225
- jobs_by_id = {j["id"]: j for j in instance.get("jobs", [])}
226
- machines_by_id = {m["id"]: m for m in instance.get("machines", [])}
227
- assign_by_job = {a["job_id"]: a for a in assignments}
228
-
229
- # ---- (a) Capacity: concurrent jobs on any machine ≤ its capacity ----
230
- machine_intervals: dict[str, list[tuple[float, float]]] = {}
231
- for a in assignments:
232
- mid = a["machine_id"]
233
- st = float(a["start_time"])
234
- dur = float(jobs_by_id.get(a["job_id"], {}).get("duration", 1))
235
- machine_intervals.setdefault(mid, []).append((st, st + dur))
236
-
237
- capacity_ok = True
238
- for mid, intervals in machine_intervals.items():
239
- cap = machines_by_id.get(mid, {}).get("capacity", 1)
240
- for s1, e1 in intervals:
241
- # Count how many intervals overlap with [s1, e1)
242
- concurrent = sum(
243
- 1 for s2, e2 in intervals if s2 < e1 and e2 > s1
244
- )
245
- if concurrent > cap:
246
- capacity_ok = False
247
- break
248
- if not capacity_ok:
249
- break
250
-
251
- # ---- (b) Deadlines: every job finishes by its deadline ----
252
- deadline_ok = True
253
- for a in assignments:
254
- job = jobs_by_id.get(a["job_id"], {})
255
- finish = float(a["start_time"]) + float(job.get("duration", 0))
256
- dl = job.get("deadline", float("inf"))
257
- if finish > dl:
258
- deadline_ok = False
259
- break
260
-
261
- # ---- (c) Precedence: job starts after ALL its predecessors finish ----
262
- precedence_ok = True
263
- for a in assignments:
264
- job = jobs_by_id.get(a["job_id"], {})
265
- for dep_id in job.get("dependencies", []):
266
- dep_a = assign_by_job.get(dep_id)
267
- if dep_a is None:
268
- precedence_ok = False
269
- break
270
- dep_job = jobs_by_id.get(dep_id, {})
271
- dep_finish = float(dep_a["start_time"]) + float(
272
- dep_job.get("duration", 0)
273
- )
274
- if float(a["start_time"]) < dep_finish:
275
- precedence_ok = False
276
- break
277
- if not precedence_ok:
278
- break
279
-
280
- # ---- (d) Availability: job runs within machine availability window ----
281
- availability_ok = True
282
- for a in assignments:
283
- machine = machines_by_id.get(a["machine_id"], {})
284
- avail_start = float(machine.get("available_start", 0))
285
- avail_end = float(machine.get("available_end", float("inf")))
286
- job = jobs_by_id.get(a["job_id"], {})
287
- job_start = float(a["start_time"])
288
- job_end = job_start + float(job.get("duration", 0))
289
- if job_start < avail_start or job_end > avail_end:
290
- availability_ok = False
291
- break
292
 
293
- return {
294
- "capacity": capacity_ok,
295
- "deadlines": deadline_ok,
296
- "precedence": precedence_ok,
297
- "availability": availability_ok,
298
- }
299
 
300
- @staticmethod
301
- def _check_constraints(
302
- assignments: list[dict[str, Any]], instance: dict[str, Any]
303
- ) -> float:
304
- """Convenience wrapper — returns fraction of categories satisfied."""
305
- detail = RepairGrader._check_constraints_detail(assignments, instance)
306
- return sum(detail.values()) / max(len(detail), 1)
307
 
308
- # ------------------------------------------------------------------
309
- # Makespan calculation
310
- # ------------------------------------------------------------------
311
 
312
- @staticmethod
313
- def _compute_makespan(
314
- assignments: list[dict[str, Any]], instance: dict[str, Any]
315
- ) -> int:
316
- """Return the latest finish time across all assigned jobs."""
317
- jobs_by_id = {j["id"]: j for j in instance.get("jobs", [])}
318
- max_finish = 0
319
- for a in assignments:
320
- job = jobs_by_id.get(a["job_id"], {})
321
- finish = int(a["start_time"]) + int(job.get("duration", 0))
322
- if finish > max_finish:
323
- max_finish = finish
324
- return max_finish
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
 
 
 
2
  from typing import Any
3
 
4
+ # HARD MODE: Brutal adversarial penalties.
5
+ _TP_WEIGHT = 1.0
6
+ _TN_WEIGHT = 0.05 # Tiny reward for passing valid trades
7
+ _FP_PENALTY = 0.4 # Heavy penalty for false alarms
8
+ _FN_PENALTY = 0.8 # Massive penalty for missing adversarial trades
9
 
10
+ class HardFixGrader:
11
+ """Grader for Task 3: Code Fix repurposed for HFT (Hard)."""
 
12
 
13
  def __init__(self) -> None:
14
  self.last_breakdown: dict[str, Any] = {}
15
 
16
+ def grade(self, state: Any, ground_truth: dict[str, Any] | None = None) -> float:
17
+ tp = float(getattr(state, "total_tp", 0))
18
+ tn = float(getattr(state, "total_tn", 0))
19
+ fp = float(getattr(state, "total_fp", 0))
20
+ fn = float(getattr(state, "total_fn", 0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ total = tp + tn + fp + fn
23
+ if total == 0:
24
+ return 0.01
 
 
 
25
 
26
+ positive_signal = (tp * _TP_WEIGHT) + (tn * _TN_WEIGHT)
27
+ negative_signal = (fp * _FP_PENALTY) + (fn * _FN_PENALTY)
 
 
 
 
 
28
 
29
+ max_signal = total * _TP_WEIGHT
30
+ raw_score = max(0.0, positive_signal - negative_signal) / max_signal
 
31
 
32
+ score = max(0.01, min(0.99, raw_score))
33
+ self.last_breakdown = {"tp": int(tp), "tn": int(tn), "fp": int(fp), "fn": int(fn), "score": score}
34
+ return score
 
 
 
 
 
 
 
 
 
 
hf auditor/openenv_fin_auditor.egg-info/PKG-INFO ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: openenv-fin_auditor
3
+ Version: 0.1.0
4
+ Summary: Fin Auditor environment for OpenEnv
5
+ Requires-Python: >=3.10
6
+ Requires-Dist: gymnasium>=1.2.3
7
+ Requires-Dist: openenv-core[core]>=0.2.2
8
+ Requires-Dist: numpy
9
+ Requires-Dist: nanobind
10
+ Requires-Dist: openai
11
+ Requires-Dist: pydantic
12
+ Requires-Dist: fastapi
13
+ Requires-Dist: uvicorn
14
+ Requires-Dist: pandas>=2.3.3
15
+ Requires-Dist: python-dotenv>=1.2.2
16
+ Requires-Dist: stable-baselines3[extra]>=2.8.0
17
+ Provides-Extra: dev
18
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
19
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
hf auditor/openenv_fin_auditor.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ build_engine.py
3
+ final_check.py
4
+ inference.py
5
+ models.py
6
+ pyproject.toml
7
+ test_import.py
8
+ train.py
9
+ ./build_engine.py
10
+ ./final_check.py
11
+ ./inference.py
12
+ ./models.py
13
+ ./test_import.py
14
+ ./train.py
15
+ openenv_fin_auditor.egg-info/PKG-INFO
16
+ openenv_fin_auditor.egg-info/SOURCES.txt
17
+ openenv_fin_auditor.egg-info/dependency_links.txt
18
+ openenv_fin_auditor.egg-info/entry_points.txt
19
+ openenv_fin_auditor.egg-info/requires.txt
20
+ openenv_fin_auditor.egg-info/top_level.txt
21
+ server/__init__.py
22
+ server/app.py
23
+ server/fin_auditor_environment.py
hf auditor/openenv_fin_auditor.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
hf auditor/openenv_fin_auditor.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [console_scripts]
2
+ server = fin_auditor.server.app:main
hf auditor/openenv_fin_auditor.egg-info/requires.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gymnasium>=1.2.3
2
+ openenv-core[core]>=0.2.2
3
+ numpy
4
+ nanobind
5
+ openai
6
+ pydantic
7
+ fastapi
8
+ uvicorn
9
+ pandas>=2.3.3
10
+ python-dotenv>=1.2.2
11
+ stable-baselines3[extra]>=2.8.0
12
+
13
+ [dev]
14
+ pytest>=8.0.0
15
+ pytest-cov>=4.0.0
hf auditor/openenv_fin_auditor.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ fin_auditor
server/app.py CHANGED
@@ -25,10 +25,13 @@ if _ROOT_DIR not in sys.path:
25
  if _CURRENT_DIR not in sys.path:
26
  sys.path.insert(0, _CURRENT_DIR)
27
 
 
 
 
 
28
  try:
29
- from fin_auditor_environment import FinAuditorEnvironment, hft_auditor
30
  from models import AuditorAction, AuditorObservation
31
- from graders.grader_detection import FinAuditorGrader
32
  from tasks import task1_easy, task2_medium, task3_hard
33
 
34
  HAS_ENV = True
@@ -47,32 +50,40 @@ except ImportError as e:
47
  # ==============================================================================
48
 
49
  if HAS_ENV and NATIVE_VERIFIED:
50
- # 1. Create a single tracked instance so your custom dashboard can read live metrics
51
- global active_env_instance
 
 
52
  active_env_instance = FinAuditorEnvironment()
53
-
54
- # 2. Register environment, action/obs models, and pass the RAW TASK MODULES
55
- try:
56
- app = create_app(
57
- lambda: active_env_instance, # This preserves your custom UI dashboard!
58
- AuditorAction,
59
- AuditorObservation,
60
- tasks=[
61
- task1_easy,
62
- task2_medium,
63
- task3_hard
64
- ]
65
- )
66
- except TypeError as e:
67
- print(f"[CRITICAL] OpenEnv framework version mismatch: {e}")
68
- # Fallback if the local OpenEnv version doesn't support tasks kwargs
69
- app = create_app(lambda: active_env_instance, AuditorAction, AuditorObservation)
 
 
 
 
 
 
70
 
71
  else:
72
  # Fallback for local development without the C++ binary
73
  app = FastAPI(title="PayGorn (MOCK MODE)")
74
  @app.post("/reset")
75
- async def mock_reset(): return {"reward": 0.0}
76
  @app.post("/step")
77
  async def mock_step(action: dict): return {"reward": 0.5, "done": False, "step_count": 0}
78
 
@@ -230,7 +241,48 @@ async def get_dashboard_action(req: ActionRequest):
230
  decisions = await execute_llm_step(api_key, base_url, model_name, batch_size)
231
  else:
232
  decisions = [random.choice([0, 1]) for _ in range(batch_size)]
233
- return {"decisions": decisions}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
  @app.post("/config/llm")
236
  async def config_llm(cfg: LLMConfig):
@@ -704,10 +756,10 @@ async def root_dashboard():
704
  async function executeReset() {
705
  logMsg("SPSC_BUFFER_FLUSHING...", "warn");
706
  try {
707
- await fetch('/reset', {method: 'POST'});
708
- ledgerBody.innerHTML = '';
709
  updateState();
710
- logMsg("Memory pool purged.", "success");
711
  } catch(e) {
712
  logMsg("Reset failed: " + e.message, "err");
713
  }
@@ -721,48 +773,33 @@ async def root_dashboard():
721
  headers: {'Content-Type': 'application/json'},
722
  body: JSON.stringify({action_type: actionType})
723
  });
724
-
725
  if (!actionRes.ok) {
726
  const errData = await actionRes.json();
727
  logMsg("LLM Error: " + (errData.detail || "Failed to generate decisions"), "err");
728
  return;
729
  }
730
-
731
  const actionData = await actionRes.json();
732
- if(!actionData.decisions) {
733
- logMsg("Decision matrix generation failed.", "err"); return;
734
- }
735
 
736
  logMsg(`Executing Step with ${actionData.decisions.length} decisions...`, "info");
737
 
738
- const res = await fetch('/step', {
 
739
  method: 'POST',
740
  headers: {'Content-Type': 'application/json'},
741
- body: JSON.stringify({ action: actionData })
742
  });
743
-
744
  if (!res.ok) {
745
  const errorData = await res.json();
746
- console.error("Validation Details:", errorData);
747
- logMsg(`Server Error: ${res.status}. Check browser console.`, "err");
748
  return;
749
  }
750
-
751
  const data = await res.json();
 
 
 
752
 
753
- // FIX: Robust payload extraction handling regardless of OpenEnv wrapper depth
754
- const reward = data.reward ?? data.observation?.reward ?? data.info?.reward ?? 0.0;
755
- const done = data.done ?? data.observation?.done ?? data.info?.done ?? false;
756
-
757
- // Fetch the authoritative step count from /dashboard/state
758
- let step = 'N/A';
759
- try {
760
- const stateRes = await fetch('/dashboard/state');
761
- const stateData = await stateRes.json();
762
- step = stateData.step_count ?? 'N/A';
763
- } catch(se) {} // Swallow — non-critical
764
-
765
- logMsg(`[RECON] Reward: ${reward.toFixed(4)} | Success`, reward >= 0.8 ? 'success' : 'warn');
766
 
767
  const row = document.createElement('tr');
768
  row.innerHTML = `
@@ -774,13 +811,13 @@ async def root_dashboard():
774
  `;
775
  if(ledgerBody.children.length >= 5) { ledgerBody.removeChild(ledgerBody.firstChild); }
776
  ledgerBody.appendChild(row);
777
-
778
  updateState();
779
  } catch(e) {
780
  logMsg("Step Execution Error: " + e.message, "err");
781
  }
782
  }
783
 
 
784
  // Auto-Reset the environment on boot so it actually has data to process,
785
  // then try to authenticate with the default HF_TOKEN
786
  window.addEventListener('DOMContentLoaded', async () => {
 
25
  if _CURRENT_DIR not in sys.path:
26
  sys.path.insert(0, _CURRENT_DIR)
27
 
28
+ # Always define this as a safe global so /dashboard/state never throws NameError
29
+ # even when NATIVE_VERIFIED is False (C++ binary missing).
30
+ active_env_instance = None
31
+
32
  try:
33
+ from server.fin_auditor_environment import FinAuditorEnvironment, hft_auditor
34
  from models import AuditorAction, AuditorObservation
 
35
  from tasks import task1_easy, task2_medium, task3_hard
36
 
37
  HAS_ENV = True
 
50
  # ==============================================================================
51
 
52
  if HAS_ENV and NATIVE_VERIFIED:
53
+ # ── Dashboard singleton ────────────────────────────────────────────────────
54
+ # Used ONLY by /dashboard/* endpoints and /ws/telemetry for live telemetry.
55
+ # NEVER passed to create_app — OpenEnv gets a factory that produces isolated
56
+ # instances per session so close() cannot corrupt the dashboard engine.
57
  active_env_instance = FinAuditorEnvironment()
58
+ # Pre-load the first batch so the dashboard has real data immediately.
59
+ # (reset() calls generate_batch internally, so just call reset here.)
60
+ active_env_instance.reset()
61
+
62
+ # ── OpenEnv factory ────────────────────────────────────────────────────────
63
+ # CRITICAL: OpenEnv's WebSocket server creates ONE env per session via
64
+ # env_factory() and then sends reset + step messages to the SAME instance.
65
+ # This means reset() IS called before step(), so the C++ engine has data.
66
+ # For HTTP mode (stateless), each request gets its own env — step() is called
67
+ # on a cold engine, but our __init__ initialises counters to 0 so it won't
68
+ # crash; it will just return the floor reward of 0.01 (acceptable for Phase 2).
69
+ def env_factory() -> FinAuditorEnvironment:
70
+ """Create a fresh, self-contained FinAuditorEnvironment per OpenEnv session."""
71
+ return FinAuditorEnvironment()
72
+
73
+ # NOTE: create_app() has no `tasks=` parameter in openenv-core >= 0.2.x.
74
+ # Task routing (easy/medium/hard difficulty) is handled inside reset() via
75
+ # the task_id kwarg that Phase 2 injects into the reset message body.
76
+ app = create_app(
77
+ env_factory,
78
+ AuditorAction,
79
+ AuditorObservation,
80
+ )
81
 
82
  else:
83
  # Fallback for local development without the C++ binary
84
  app = FastAPI(title="PayGorn (MOCK MODE)")
85
  @app.post("/reset")
86
+ async def mock_reset(): return {"reward": 0.01}
87
  @app.post("/step")
88
  async def mock_step(action: dict): return {"reward": 0.5, "done": False, "step_count": 0}
89
 
 
241
  decisions = await execute_llm_step(api_key, base_url, model_name, batch_size)
242
  else:
243
  decisions = [random.choice([0, 1]) for _ in range(batch_size)]
244
+ return {"decisions": decisions}
245
+
246
+
247
+ class DashboardStepRequest(BaseModel):
248
+ """Action payload for the dashboard-native step endpoint."""
249
+ decisions: List[int]
250
+
251
+
252
+ @app.post("/dashboard/step")
253
+ async def dashboard_step(req: DashboardStepRequest):
254
+ """
255
+ Dashboard-native step: runs on the SINGLETON engine (warm, with real trade data),
256
+ not on the OpenEnv /step route which creates a cold engine per request.
257
+
258
+ This is what the three dashboard buttons (OPTIMAL / STRESS / LLM) call so that
259
+ rewards reflect actual confusion-matrix scoring rather than the 0.01 floor.
260
+ """
261
+ if not active_env_instance or not NATIVE_VERIFIED:
262
+ raise HTTPException(status_code=503, detail="Native engine not available")
263
+
264
+ from models import AuditorAction
265
+ action = AuditorAction(decisions=req.decisions)
266
+ obs = active_env_instance.step(action)
267
+ return {
268
+ "reward": obs.reward,
269
+ "done": obs.done,
270
+ "step_count": active_env_instance.state.step_count,
271
+ "features_shape": [len(obs.features), len(obs.features[0]) if obs.features else 0],
272
+ }
273
+
274
+
275
+ @app.post("/dashboard/reset")
276
+ async def dashboard_reset():
277
+ """
278
+ Reset the dashboard singleton: re-seeds the ring buffer with fresh trade data.
279
+ Called by the [FLUSH_SPSC_BUFFER] button in the dashboard JS.
280
+ """
281
+ if not active_env_instance or not NATIVE_VERIFIED:
282
+ raise HTTPException(status_code=503, detail="Native engine not available")
283
+ active_env_instance.reset()
284
+ return {"status": "ok", "step_count": active_env_instance.state.step_count}
285
+
286
 
287
  @app.post("/config/llm")
288
  async def config_llm(cfg: LLMConfig):
 
756
  async function executeReset() {
757
  logMsg("SPSC_BUFFER_FLUSHING...", "warn");
758
  try {
759
+ await fetch('/dashboard/reset', {method: 'POST'});
760
+ ledgerBody.innerHTML = '';
761
  updateState();
762
+ logMsg("Memory pool purged and re-seeded.", "success");
763
  } catch(e) {
764
  logMsg("Reset failed: " + e.message, "err");
765
  }
 
773
  headers: {'Content-Type': 'application/json'},
774
  body: JSON.stringify({action_type: actionType})
775
  });
 
776
  if (!actionRes.ok) {
777
  const errData = await actionRes.json();
778
  logMsg("LLM Error: " + (errData.detail || "Failed to generate decisions"), "err");
779
  return;
780
  }
 
781
  const actionData = await actionRes.json();
782
+ if(!actionData.decisions) { logMsg("Decision matrix generation failed.", "err"); return; }
 
 
783
 
784
  logMsg(`Executing Step with ${actionData.decisions.length} decisions...`, "info");
785
 
786
+ // POST to /dashboard/step (warm singleton) NOT /step (cold factory engine)
787
+ const res = await fetch('/dashboard/step', {
788
  method: 'POST',
789
  headers: {'Content-Type': 'application/json'},
790
+ body: JSON.stringify({ decisions: actionData.decisions })
791
  });
 
792
  if (!res.ok) {
793
  const errorData = await res.json();
794
+ logMsg(`Server Error: ${res.status} — ${errorData.detail || 'check logs'}`, "err");
 
795
  return;
796
  }
 
797
  const data = await res.json();
798
+ const reward = data.reward ?? 0.0;
799
+ const done = data.done ?? false;
800
+ const step = data.step_count ?? 'N/A';
801
 
802
+ logMsg(`[RECON] Reward: ${reward.toFixed(4)} | Step: ${step}`, reward >= 0.8 ? 'success' : 'warn');
 
 
 
 
 
 
 
 
 
 
 
 
803
 
804
  const row = document.createElement('tr');
805
  row.innerHTML = `
 
811
  `;
812
  if(ledgerBody.children.length >= 5) { ledgerBody.removeChild(ledgerBody.firstChild); }
813
  ledgerBody.appendChild(row);
 
814
  updateState();
815
  } catch(e) {
816
  logMsg("Step Execution Error: " + e.message, "err");
817
  }
818
  }
819
 
820
+
821
  // Auto-Reset the environment on boot so it actually has data to process,
822
  // then try to authenticate with the default HF_TOKEN
823
  window.addEventListener('DOMContentLoaded', async () => {
server/fin_auditor_environment.py CHANGED
@@ -81,8 +81,30 @@ class FinAuditorEnvironment(Environment):
81
  self._state = State(episode_id=str(uuid4()), step_count=0)
82
  self.engine = hft_auditor.ReconciliationEngine(self._RING_BUFFER_CAPACITY)
83
  self.sim_time_ns = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- task_id = os.getenv("TASK_ID", "anomaly_detection_hard").lower()
 
86
 
87
  if "easy" in task_id:
88
  self.difficulty = hft_auditor.Difficulty.EASY
@@ -93,9 +115,6 @@ class FinAuditorEnvironment(Environment):
93
  else:
94
  self.difficulty = hft_auditor.Difficulty.HARD
95
  self._MAX_EPISODE_STEPS = 20
96
-
97
- def reset(self) -> AuditorObservation:
98
- self._state = State(episode_id=str(uuid4()), step_count=0)
99
 
100
  # 1. Initialize Cumulative Counters for the Grader
101
  self._state.total_tp = 0
@@ -175,6 +194,14 @@ class FinAuditorEnvironment(Environment):
175
  done=done
176
  )
177
 
 
 
 
 
 
 
 
 
178
 
179
  @property
180
  def state(self) -> State:
 
81
  self._state = State(episode_id=str(uuid4()), step_count=0)
82
  self.engine = hft_auditor.ReconciliationEngine(self._RING_BUFFER_CAPACITY)
83
  self.sim_time_ns = 0
84
+
85
+ # We default to HARD, but the actual routing happens in reset()
86
+ self.difficulty = hft_auditor.Difficulty.HARD
87
+ self._MAX_EPISODE_STEPS = 20
88
+
89
+ # Initialize confusion-matrix counters here so they always exist on
90
+ # the State object — even when step() is called on a fresh env that
91
+ # has not yet had reset() called (OpenEnv HTTP stateless mode creates
92
+ # a new env per request, so step_handler calls step() directly).
93
+ self._state.total_tp = 0
94
+ self._state.total_tn = 0
95
+ self._state.total_fp = 0
96
+ self._state.total_fn = 0
97
+ self._state.last_tp = 0
98
+ self._state.last_tn = 0
99
+ self._state.last_fp = 0
100
+ self._state.last_fn = 0
101
+
102
+ # FIX 1: Add *args, **kwargs to prevent TypeError when OpenEnv injects task_id
103
+ def reset(self, *args, **kwargs) -> AuditorObservation:
104
+ self._state = State(episode_id=str(uuid4()), step_count=0)
105
 
106
+ # FIX 2: Dynamically shift difficulty based on OpenEnv's requested task
107
+ task_id = kwargs.get("task_id", os.getenv("TASK_ID", "anomaly_detection_hard")).lower()
108
 
109
  if "easy" in task_id:
110
  self.difficulty = hft_auditor.Difficulty.EASY
 
115
  else:
116
  self.difficulty = hft_auditor.Difficulty.HARD
117
  self._MAX_EPISODE_STEPS = 20
 
 
 
118
 
119
  # 1. Initialize Cumulative Counters for the Grader
120
  self._state.total_tp = 0
 
194
  done=done
195
  )
196
 
197
+ def close(self) -> None:
198
+ """No-op: called by OpenEnv HTTP server after every request.
199
+
200
+ With the factory pattern each request gets a *fresh* instance, so
201
+ there is nothing to explicitly clean up here — the C++ engine is
202
+ reference-counted and will be released when the Python object is GC'd.
203
+ """
204
+ pass
205
 
206
  @property
207
  def state(self) -> State:
tasks/task1_easy.py CHANGED
@@ -7,13 +7,15 @@ _ROOT = os.path.abspath(os.path.join(_HERE, ".."))
7
  if _ROOT not in sys.path:
8
  sys.path.insert(0, _ROOT)
9
 
10
- from graders.grader_detection import FinAuditorGrader
 
11
 
12
  TASK_ID = "anomaly_detection_easy"
13
  MAX_STEPS = 5
14
  DIFFICULTY = "easy"
15
 
16
- grader = FinAuditorGrader()
 
17
 
18
  def get_task_config() -> dict:
19
  return {
 
7
  if _ROOT not in sys.path:
8
  sys.path.insert(0, _ROOT)
9
 
10
+ # 1. IMPORT THE EASY GRADER FROM THE DETECTION FILE
11
+ from graders.grader_detection import EasyDetectionGrader
12
 
13
  TASK_ID = "anomaly_detection_easy"
14
  MAX_STEPS = 5
15
  DIFFICULTY = "easy"
16
 
17
+ # 2. INSTANTIATE THE EASY GRADER
18
+ grader = EasyDetectionGrader()
19
 
20
  def get_task_config() -> dict:
21
  return {
tasks/task2_medium.py CHANGED
@@ -7,13 +7,15 @@ _ROOT = os.path.abspath(os.path.join(_HERE, ".."))
7
  if _ROOT not in sys.path:
8
  sys.path.insert(0, _ROOT)
9
 
10
- from graders.grader_detection import FinAuditorGrader
 
11
 
12
  TASK_ID = "anomaly_detection_medium"
13
  MAX_STEPS = 10
14
  DIFFICULTY = "medium"
15
 
16
- grader = FinAuditorGrader()
 
17
 
18
  def get_task_config() -> dict:
19
  return {
 
7
  if _ROOT not in sys.path:
8
  sys.path.insert(0, _ROOT)
9
 
10
+ # 1. IMPORT THE MEDIUM GRADER FROM THE CLASSIFICATION FILE
11
+ from graders.grader_classification import MediumClassificationGrader
12
 
13
  TASK_ID = "anomaly_detection_medium"
14
  MAX_STEPS = 10
15
  DIFFICULTY = "medium"
16
 
17
+ # 2. INSTANTIATE THE MEDIUM GRADER
18
+ grader = MediumClassificationGrader()
19
 
20
  def get_task_config() -> dict:
21
  return {
tasks/task3_hard.py CHANGED
@@ -7,13 +7,15 @@ _ROOT = os.path.abspath(os.path.join(_HERE, ".."))
7
  if _ROOT not in sys.path:
8
  sys.path.insert(0, _ROOT)
9
 
10
- from graders.grader_detection import FinAuditorGrader
 
11
 
12
  TASK_ID = "anomaly_detection_hard"
13
  MAX_STEPS = 20
14
  DIFFICULTY = "hard"
15
 
16
- grader = FinAuditorGrader()
 
17
 
18
  def get_task_config() -> dict:
19
  return {
 
7
  if _ROOT not in sys.path:
8
  sys.path.insert(0, _ROOT)
9
 
10
+ # 1. IMPORT THE HARD GRADER FROM THE FIX FILE
11
+ from graders.grader_fix import HardFixGrader
12
 
13
  TASK_ID = "anomaly_detection_hard"
14
  MAX_STEPS = 20
15
  DIFFICULTY = "hard"
16
 
17
+ # 2. INSTANTIATE THE HARD GRADER
18
+ grader = HardFixGrader()
19
 
20
  def get_task_config() -> dict:
21
  return {