# Elite-Trade-Sentry / server/fin_auditor_environment.py
# Author: TheRealAIGuy
# Last commit a590d66: "Removed a suppressor that was suppressing a non-existent std::cout"
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Fin Auditor Environment Implementation.
Wraps the compiled C++ ``hft_auditor.ReconciliationEngine``.
"""
import os
import sys
import glob
import importlib.util
from uuid import uuid4
import numpy as np
# ── Native Engine Bridge ─────────────────────────────────────────────────────
def _load_native_engine():
"""Surgically discovers and loads the compiled C++ binary (.so or .pyd)."""
try:
import hft_auditor
return hft_auditor
except ImportError:
pass
_CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
_ROOT_DIR = os.path.abspath(os.path.join(_CURRENT_DIR, ".."))
patterns = [
os.path.join(_ROOT_DIR, "hft_auditor*.pyd"),
os.path.join(_ROOT_DIR, "hft_auditor*.so"),
os.path.join(_ROOT_DIR, "hf auditor/build/**/hft_auditor*.pyd"),
os.path.join(_ROOT_DIR, "hf auditor/build/**/hft_auditor*.so"),
]
lib_files = []
for p in patterns:
lib_files.extend(glob.glob(p, recursive=True))
if not lib_files:
return None
try:
lib_path = lib_files[0]
spec = importlib.util.spec_from_file_location("hft_auditor", lib_path)
if spec is None or spec.loader is None:
return None
module = importlib.util.module_from_spec(spec)
sys.modules["hft_auditor"] = module
spec.loader.exec_module(module)
return module
except Exception as e:
print(f"[CRITICAL] Native loading failed: {e}")
return None
hft_auditor = _load_native_engine()
# ─────────────────────────────────────────────────────────────────────────────
from typing import Any, Dict, Optional
from pydantic import Field
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from models import AuditorAction, AuditorObservation
class FinAuditorObservation(AuditorObservation):
    """Observation model emitted by the Fin Auditor environment.

    Adds three episode-control fields on top of ``AuditorObservation``:
    a ``done`` flag, an optional scalar ``reward`` and a free-form
    ``metadata`` dictionary.
    """

    # Inherit the parent's pydantic configuration verbatim.
    model_config = AuditorObservation.model_config

    done: bool = False
    reward: Optional[float] = None
    # default_factory gives each instance its own dict (no shared mutable default).
    metadata: Dict[str, Any] = Field(default_factory=dict)
class FinAuditorEnvironment(Environment):
    """OpenEnv environment that scores an agent's per-trade anomaly decisions.

    Wraps one native ``hft_auditor.ReconciliationEngine``. Each ``step()``:

    1. scores the agent's decisions against the batch returned by the
       previous ``reset()``/``step()``;
    2. accumulates the engine's confusion-matrix counts on ``self._state``;
    3. ingests the next batch;
    4. returns that batch together with a shaped reward computed over the
       whole episode so far.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True
    # Capacity passed to the native ReconciliationEngine constructor.
    _RING_BUFFER_CAPACITY: int = 1_048_576
    # Number of trades requested from generate_batch() per batch.
    _INGEST_CHUNK_SIZE: int = 40
    # NOTE(review): presumably the engine's maximum reconciliation delay. The
    # simulated clock is advanced past it (plus a margin) after each batch,
    # which is assumed to make the batch's trades show up as expired.
    # Confirm against the C++ engine.
    _DELTA_MAX_NS: int = 5_000_000_000
    # Extra simulated time added on top of _DELTA_MAX_NS after each batch.
    _TICK_MARGIN_NS: int = 1_000_000_000
    # The evaluator rejects rewards of exactly 0.0 or 1.0, so every reward is
    # clamped strictly inside (0, 1).
    _REWARD_FLOOR: float = 0.01
    _REWARD_CEIL: float = 0.99
    # Confusion-matrix cells mirrored onto the State object.
    # total_<name> is the episode-cumulative count used by the grader.
    # last_<name> is the latest batch only, used by the dashboard.
    _COUNTER_NAMES = ("tp", "tn", "fp", "fn")

    def __init__(self) -> None:
        """Create the native engine with HARD defaults (``reset()`` re-routes).

        Raises:
            ImportError: if the ``hft_auditor`` extension could not be loaded
                at module import time.
        """
        if hft_auditor is None:
            raise ImportError(
                "hft_auditor native extension could not be loaded; build or "
                "install it before creating FinAuditorEnvironment"
            )
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.engine = hft_auditor.ReconciliationEngine(self._RING_BUFFER_CAPACITY)
        self.sim_time_ns = 0
        # Default to HARD; reset() picks the real difficulty from task_id.
        self.difficulty = hft_auditor.Difficulty.HARD
        self._MAX_EPISODE_STEPS = 20
        # The counters must exist even when step() runs on a fresh env without
        # reset(). OpenEnv's stateless HTTP mode builds a new env per request
        # and calls step() directly.
        self._reset_counters()

    # ── private helpers ──────────────────────────────────────────────────

    def _reset_counters(self) -> None:
        """Zero every ``total_*`` and ``last_*`` confusion counter on the state."""
        for prefix in ("total", "last"):
            for name in self._COUNTER_NAMES:
                setattr(self._state, f"{prefix}_{name}", 0)

    def _configure_difficulty(self, task_id: str) -> None:
        """Select engine difficulty and episode length from a lower-cased task id.

        A substring match is used: "easy" selects 5 steps, "medium" selects
        10 steps, and anything else selects HARD with 20 steps.
        """
        if "easy" in task_id:
            self.difficulty = hft_auditor.Difficulty.EASY
            self._MAX_EPISODE_STEPS = 5
        elif "medium" in task_id:
            self.difficulty = hft_auditor.Difficulty.MEDIUM
            self._MAX_EPISODE_STEPS = 10
        else:
            self.difficulty = hft_auditor.Difficulty.HARD
            self._MAX_EPISODE_STEPS = 20

    def _ingest_next_batch(self) -> list:
        """Generate a batch, advance the simulated clock and tick the engine.

        Returns:
            The engine's anomaly matrix converted to a list of row lists.
        """
        self.engine.generate_batch(
            self.difficulty, self._INGEST_CHUNK_SIZE, self.sim_time_ns
        )
        self.sim_time_ns += self._DELTA_MAX_NS + self._TICK_MARGIN_NS
        self.engine.tick(self.sim_time_ns)
        return self.engine.get_anomaly_matrix().tolist()

    def _score_pending_batch(self, action: Optional[AuditorAction]) -> None:
        """Grade *action* against the current batch and update the counters.

        The decision list is clamped to the batch size so that a length
        mismatch cannot crash the C++ engine. Extra decisions are dropped, and
        missing decisions are filled with 0 (Pass). A missing or empty
        decision list is treated as "Pass everything", so an agent cannot skip
        grading of a batch (and its FN penalties) by sending nothing.
        """
        batch_size = len(self.engine.get_anomaly_matrix())
        decisions = list(action.decisions or []) if action is not None else []
        if not decisions and batch_size == 0:
            return  # nothing pending and nothing submitted

        if len(decisions) > batch_size:
            del decisions[batch_size:]
        else:
            decisions.extend([0] * (batch_size - len(decisions)))

        self.engine.compute_reward(np.array(decisions, dtype=np.uint8))

        for name in self._COUNTER_NAMES:
            batch_count = getattr(self.engine, f"last_{name}")
            total_attr = f"total_{name}"
            setattr(
                self._state, total_attr, getattr(self._state, total_attr) + batch_count
            )
            setattr(self._state, f"last_{name}", batch_count)

    @classmethod
    def _shaped_reward(cls, tp: float, tn: float, fp: float, fn: float) -> float:
        """Return the episode-so-far reward, clamped to [0.01, 0.99].

        Scoring weights:
        - Earned: each TP is worth 1.0 and each TN is worth 0.1.
        - Penalised: each FP costs 0.1 and each FN costs 0.4.

        The net score, floored at 0, is divided by what a perfect auditor
        would have earned on the same trades. With no graded trades yet, the
        floor reward is returned.
        """
        perfect_signal = (tp + fn) * 1.0 + (tn + fp) * 0.1
        if perfect_signal <= 0:
            return cls._REWARD_FLOOR
        positive = tp * 1.0 + tn * 0.1
        negative = fp * 0.1 + fn * 0.4
        raw = max(0.0, positive - negative) / perfect_signal
        return max(cls._REWARD_FLOOR, min(cls._REWARD_CEIL, raw))

    # ── OpenEnv interface ────────────────────────────────────────────────

    def reset(self, *args, **kwargs) -> AuditorObservation:
        """Start a new episode and return its first batch.

        Accepts arbitrary extra positional and keyword arguments, because
        OpenEnv injects e.g. ``task_id``. The task id comes from the
        ``task_id`` kwarg. If that is missing or None, the ``TASK_ID``
        environment variable is used, and failing that
        ``"anomaly_detection_hard"``. The task id selects the difficulty and
        the episode length. The engine and its simulated clock persist across
        episodes.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # OpenEnv may pass task_id=None explicitly; treat that like "absent".
        task_id = str(
            kwargs.get("task_id") or os.getenv("TASK_ID", "anomaly_detection_hard")
        ).lower()
        self._configure_difficulty(task_id)
        self._reset_counters()

        # Pre-generate the first batch so step 1 actually has data to grade.
        anomalies = self._ingest_next_batch()
        return FinAuditorObservation(
            features=anomalies,
            message=f"Fin Auditor engine ready. {len(anomalies)} trades loaded.",
            reward=self._REWARD_FLOOR,
            done=False,
        )

    def step(self, action: AuditorAction) -> AuditorObservation:  # type: ignore[override]
        """Grade *action* on the pending batch, then return the next batch.

        The reward reflects the cumulative confusion matrix of the whole
        episode (see ``_shaped_reward``). ``done`` becomes True once
        ``step_count`` reaches the episode length chosen in ``reset()``.
        """
        self._state.step_count += 1

        self._score_pending_batch(action)
        anomalies = self._ingest_next_batch()
        done = self._state.step_count >= self._MAX_EPISODE_STEPS

        step_reward = self._shaped_reward(
            float(self._state.total_tp),
            float(self._state.total_tn),
            float(self._state.total_fp),
            float(self._state.total_fn),
        )
        return FinAuditorObservation(
            features=anomalies,
            message=f"Processed batch. Found {len(anomalies)} expired trades.",
            reward=step_reward,
            done=done,
        )

    def close(self) -> None:
        """No-op; the OpenEnv HTTP server calls this after every request.

        With the factory pattern each request gets a fresh instance. The
        native engine is owned by this Python object, so there is nothing to
        release explicitly here.
        """

    @property
    def state(self) -> State:
        """Current episode state, including the confusion-matrix counters."""
        return self._state