TheRealAIGuy commited on
Commit
d28a5ce
·
verified ·
1 Parent(s): c0240f7

Rolled back changes to be on the safe side. Bigger push incoming

Browse files
Files changed (1) hide show
  1. server/fin_auditor_environment.py +10 -18
server/fin_auditor_environment.py CHANGED
@@ -82,16 +82,7 @@ class FinAuditorEnvironment(Environment):
82
  self.engine = hft_auditor.ReconciliationEngine(self._RING_BUFFER_CAPACITY)
83
  self.sim_time_ns = 0
84
 
85
- # We default to HARD, but the actual routing happens in reset()
86
- self.difficulty = hft_auditor.Difficulty.HARD
87
- self._MAX_EPISODE_STEPS = 20
88
-
89
- # FIX 1: Add *args, **kwargs to prevent TypeError when OpenEnv injects task_id
90
- def reset(self, *args, **kwargs) -> AuditorObservation:
91
- self._state = State(episode_id=str(uuid4()), step_count=0)
92
-
93
- # FIX 2: Dynamically shift difficulty based on OpenEnv's requested task
94
- task_id = kwargs.get("task_id", os.getenv("TASK_ID", "anomaly_detection_hard")).lower()
95
 
96
  if "easy" in task_id:
97
  self.difficulty = hft_auditor.Difficulty.EASY
@@ -102,6 +93,9 @@ class FinAuditorEnvironment(Environment):
102
  else:
103
  self.difficulty = hft_auditor.Difficulty.HARD
104
  self._MAX_EPISODE_STEPS = 20
 
 
 
105
 
106
  # 1. Initialize Cumulative Counters for the Grader
107
  self._state.total_tp = 0
@@ -124,7 +118,7 @@ class FinAuditorEnvironment(Environment):
124
  return FinAuditorObservation(
125
  features=anomalies,
126
  message=f"Fin Auditor engine ready. {len(anomalies)} trades loaded.",
127
- reward=0.01,
128
  done=False
129
  )
130
 
@@ -133,12 +127,7 @@ class FinAuditorEnvironment(Environment):
133
 
134
  # 1. EVALUATE AGENT DECISIONS
135
  if action and action.decisions:
136
- # Protect C++ engine from generic OpenEnv agents (like Nemotron)
137
- safe_decisions = action.decisions[:self._INGEST_CHUNK_SIZE]
138
- while len(safe_decisions) < self._INGEST_CHUNK_SIZE:
139
- safe_decisions.append(1)
140
-
141
- action_array = np.array(safe_decisions, dtype=np.uint8)
142
  self.engine.compute_reward(action_array)
143
 
144
  # ACCUMULATE metrics across the ENTIRE episode for the Grader!
@@ -162,7 +151,9 @@ class FinAuditorEnvironment(Environment):
162
  anomalies: list[list[float]] = self.engine.get_anomaly_matrix().tolist()
163
  done = self._state.step_count >= self._MAX_EPISODE_STEPS
164
 
165
- # 4. COMPUTE LIVE STEP REWARD
 
 
166
  tp = float(self._state.total_tp)
167
  tn = float(self._state.total_tn)
168
  fp = float(self._state.total_fp)
@@ -184,6 +175,7 @@ class FinAuditorEnvironment(Environment):
184
  done=done
185
  )
186
 
 
187
  @property
188
  def state(self) -> State:
189
  return self._state
 
82
  self.engine = hft_auditor.ReconciliationEngine(self._RING_BUFFER_CAPACITY)
83
  self.sim_time_ns = 0
84
 
85
+ task_id = os.getenv("TASK_ID", "anomaly_detection_hard").lower()
 
 
 
 
 
 
 
 
 
86
 
87
  if "easy" in task_id:
88
  self.difficulty = hft_auditor.Difficulty.EASY
 
93
  else:
94
  self.difficulty = hft_auditor.Difficulty.HARD
95
  self._MAX_EPISODE_STEPS = 20
96
+
97
+ def reset(self) -> AuditorObservation:
98
+ self._state = State(episode_id=str(uuid4()), step_count=0)
99
 
100
  # 1. Initialize Cumulative Counters for the Grader
101
  self._state.total_tp = 0
 
118
  return FinAuditorObservation(
119
  features=anomalies,
120
  message=f"Fin Auditor engine ready. {len(anomalies)} trades loaded.",
121
+ reward=0.0,
122
  done=False
123
  )
124
 
 
127
 
128
  # 1. EVALUATE AGENT DECISIONS
129
  if action and action.decisions:
130
+ action_array = np.array(action.decisions, dtype=np.uint8)
 
 
 
 
 
131
  self.engine.compute_reward(action_array)
132
 
133
  # ACCUMULATE metrics across the ENTIRE episode for the Grader!
 
151
  anomalies: list[list[float]] = self.engine.get_anomaly_matrix().tolist()
152
  done = self._state.step_count >= self._MAX_EPISODE_STEPS
153
 
154
+ # 4. COMPUTE LIVE STEP REWARD from cumulative episode performance
155
+ # Uses same asymmetric weights as FinAuditorGrader so the dashboard
156
+ # value is consistent with the official final episode score.
157
  tp = float(self._state.total_tp)
158
  tn = float(self._state.total_tn)
159
  fp = float(self._state.total_fp)
 
175
  done=done
176
  )
177
 
178
+
179
  @property
180
  def state(self) -> State:
181
  return self._state