#!/usr/bin/env python3
"""
train.py β€” PPO training loop for the High-Frequency Risk Compliance Auditor.
=============================================================================
Wraps the FinAuditorEnvironment in a Gymnasium-compatible adapter and trains
a PPO agent using Stable Baselines3.
NaN-collapse fixes applied (see inline comments):
1. Observation space bounded to [0.0, 1.0] instead of Β±inf.
2. Features clipped to [0.0, 1.0] in _process_obs to prevent gradient explosion.
3. Environment is terminated via done=True (set in fin_auditor_environment.py)
so PPO can compute GAE advantages without infinite truncation.
4. Density-based reward in the environment removes the sparse penalty dead zone.
Usage:
python train.py
"""
import os
import sys
import numpy as np
import gymnasium as gym
from gymnasium import spaces
# Stable Baselines3 for PPO training and vectorized-env utilities
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
# Add the project root to sys.path so the compiled hft_auditor .so is importable
_ROOT = os.path.dirname(os.path.abspath(__file__))
if _ROOT not in sys.path:
sys.path.insert(0, _ROOT)
from server.fin_auditor_environment import FinAuditorEnvironment
from models import AuditorAction
# ── Hyperparameters ───────────────────────────────────────────────────────────
N_FEATURES = 4 # [time_elapsed, price_delta, missing_frequency, risk_score]
MAX_TRADES = 40 # maximum anomalies per step (== INGEST_CHUNK_SIZE)
TOTAL_TIMESTEPS = 100_000
SAVE_FREQ = 5_000
LOG_DIR = "./logs/"
SAVE_PATH = os.path.join(LOG_DIR, "rl_model")
class GymnasiumFinAuditorEnv(gym.Env):
"""
Gymnasium wrapper around FinAuditorEnvironment.
Observation: flat float32 array of shape (MAX_TRADES * N_FEATURES,)
Values clipped to [0.0, 1.0] to prevent NaN gradients.
Action: MultiDiscrete([2] * MAX_TRADES)
0=PASS, 1=FLAG per trade slot.
"""
metadata = {"render_modes": []}
def __init__(self) -> None:
super().__init__()
self._env = FinAuditorEnvironment()
obs_size = MAX_TRADES * N_FEATURES
# Normalization and clipping are now handled by VecNormalize wrapper in main()
self.observation_space = spaces.Box(
low=-np.inf,
high=np.inf,
shape=(obs_size,),
dtype=np.float32,
)
# One discrete decision per trade slot
self.action_space = spaces.MultiDiscrete([2] * MAX_TRADES)
def _process_obs(self, features: list[list[float]]) -> np.ndarray:
"""Flatten the anomaly matrix into a fixed-size float32 vector."""
flat = np.zeros(MAX_TRADES * N_FEATURES, dtype=np.float32)
for i, row in enumerate(features[:MAX_TRADES]):
for j, val in enumerate(row[:N_FEATURES]):
flat[i * N_FEATURES + j] = float(val)
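        # Layout: trade i occupies flat[i*N_FEATURES : (i+1)*N_FEATURES] in the
        # order [time_elapsed, price_delta, missing_frequency, risk_score];
        # slots past len(features) remain zero (fixed-size padding).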
        # Normalization is handled later by the VecNormalize wrapper in main().
return flat
def reset(
self,
*,
seed: int | None = None,
options: dict | None = None,
) -> tuple[np.ndarray, dict]:
super().reset(seed=seed)
obs_obj = self._env.reset()
obs = self._process_obs(obs_obj.features)
return obs, {}
def step(
self, action: np.ndarray
) -> tuple[np.ndarray, float, bool, bool, dict]:
        decisions = action.tolist()  # MultiDiscrete -> Python list of ints
action_obj = AuditorAction(decisions=decisions)
obs_obj = self._env.step(action_obj)
obs = self._process_obs(obs_obj.features)
reward = float(obs_obj.reward) if obs_obj.reward is not None else 0.0
done = bool(obs_obj.done) # True when step_count >= _MAX_EPISODE_STEPS
return obs, reward, done, False, {}
def render(self) -> None:
pass
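# ─────────────────────────────────────────────────────────────────────────────
# Evaluation helper (illustrative sketch)
# ─────────────────────────────────────────────────────────────────────────────
def evaluate(model_path: str, stats_path: str, n_episodes: int = 5) -> None:
    """
    Roll out a saved policy with frozen normalization statistics.
    A minimal sketch, assuming main() below saved both the PPO model and the
    VecNormalize statistics; the file paths are caller-supplied, not fixed by
    this project. PPO.load and VecNormalize.load are standard SB3 APIs.
    """
    venv = DummyVecEnv([lambda: GymnasiumFinAuditorEnv()])
    venv = VecNormalize.load(stats_path, venv)
    venv.training = False     # freeze the running mean/std
    venv.norm_reward = False  # report raw environment rewards
    model = PPO.load(model_path, env=venv)
    for ep in range(n_episodes):
        obs = venv.reset()
        done, total = False, 0.0
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, rewards, dones, _ = venv.step(action)
            done = bool(dones[0])
            total += float(rewards[0])
        print(f"[EVAL] Episode {ep + 1}/{n_episodes}: reward={total:.3f}")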
# ─────────────────────────────────────────────────────────────────────────────
# Training entrypoint
# ─────────────────────────────────────────────────────────────────────────────
def main() -> None:
os.makedirs(LOG_DIR, exist_ok=True)
    raw_env = GymnasiumFinAuditorEnv()
    # Sanity-check the raw environment before vectorization
    print("[TRAIN] Running Gymnasium environment check...")
    check_env(raw_env, warn=True)
    print("[TRAIN] Environment check passed.\n")
    # WRAP: DummyVecEnv + VecNormalize for robust training; SB3's normalization
    # wrapper only operates on vectorized environments. Binding raw_env by name
    # avoids the fragile late-binding closure in `DummyVecEnv([lambda: env])`.
    env = DummyVecEnv([lambda: raw_env])
    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0)
    checkpoint_callback = CheckpointCallback(
        save_freq=SAVE_FREQ,
        save_path=LOG_DIR,
        name_prefix="rl_model",
        save_vecnormalize=True,  # also checkpoint normalization stats (SB3 >= 1.7)
        verbose=1,
    )
model = PPO(
"MlpPolicy",
env,
verbose=1,
device="cpu",
n_steps=2048, # rollout buffer length per env per update
batch_size=64,
n_epochs=10,
gamma=0.99,
gae_lambda=0.95,
clip_range=0.2,
ent_coef=0.01, # mild entropy bonus for exploration
vf_coef=0.5,
max_grad_norm=0.5, # gradient clipping prevents NaN proliferation
tensorboard_log=LOG_DIR,
)
print(f"[TRAIN] Starting PPO training for {TOTAL_TIMESTEPS} timesteps...\n")
try:
model.learn(
total_timesteps=TOTAL_TIMESTEPS,
callback=checkpoint_callback,
progress_bar=True,
)
except KeyboardInterrupt:
print("\n[TRAIN] Training interrupted by user.")
final_path = os.path.join(LOG_DIR, "ppo_fin_auditor_final")
model.save(final_path)
print(f"\n[TRAIN] Model saved to: {final_path}.zip")
if __name__ == "__main__":
main()