Elite-Trade-Sentry / inference.py
TheRealAIGuy's picture
Fix: P2 Grading Fix (#1)
3385186 verified
#!/usr/bin/env python3
import os
import sys
import json
import re
import datetime
import traceback
import time
from typing import List
from dotenv import load_dotenv
load_dotenv()
# ── Project root on sys.path so `hft_auditor` .so and `models` are importable ──
# The directory containing this file is pushed to the front of sys.path (once)
# before the project-local imports below are attempted.
_ROOT = os.path.dirname(os.path.abspath(__file__))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)
from openai import OpenAI
from pydantic import BaseModel, ValidationError
# Try the `hft_auditor_env` module first; if it cannot be imported, fall back
# to the environment class shipped under `server/`. Either way the class is
# bound to the same local name, FinAuditorEnvironment.
try:
    from hft_auditor_env import FinAuditorEnv as FinAuditorEnvironment
except ImportError:
    from server.fin_auditor_environment import FinAuditorEnvironment
from models import AuditorAction
class LLMResponse(BaseModel):
    """Schema of the JSON object the model is instructed to return.

    Used by ``_parse_llm_decisions`` to validate a decoded reply before the
    decisions are normalised.
    """

    # Free-text explanation the prompt asks the model to write first.
    reasoning: str
    # One integer per audited trade; the prompt asks for 1 (anomaly) or 0 (clean).
    decisions: List[int]
# Runtime configuration, all sourced from environment variables.
API_BASE_URL: str = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME: str = os.environ.get("MODEL_NAME", "meta-llama/Meta-Llama-3-8B-Instruct")

# The token has no default: refuse to start without it.
HF_TOKEN: str = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("CRITICAL: HF_TOKEN environment variable is missing.")

TASK_ID: str = os.environ.get("TASK_ID", "anomaly_detection_hard")

# Default step budget follows the difficulty named in TASK_ID:
# "easy" -> 5, "medium" -> 10, anything else -> 20. "easy" is checked first.
_DEFAULT_MAX = (
    5 if "easy" in TASK_ID.lower()
    else 10 if "medium" in TASK_ID.lower()
    else 20
)

# An explicit MAX_STEPS environment variable overrides the task-derived default.
MAX_STEPS: int = int(os.environ.get("MAX_STEPS", str(_DEFAULT_MAX)))

# Shared client for every chat-completion request made by this script.
_client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
# System message sent with every chat-completion request (see _call_llm).
# It fixes three things for the model: the 0/1 decision encoding, the
# risk_score thresholds to apply, and the JSON shape of the reply
# ({"reasoning": str, "decisions": [int, ...]}).
SYSTEM_PROMPT = """\
You are a Wall Street Compliance Auditor AI embedded inside a High-Frequency Trading audit engine.
You MUST think step-by-step in the 'reasoning' field before determining your action.
━━━ DECISION VALUES ━━━
You must output ONLY raw integers (0 or 1) in the array. NO strings. NO labels.
1 : Confirmed anomaly.
0 : Trade is clean.
━━━ EVALUATION DIRECTIVE ━━━
1. CRITICAL SIGNAL (FLAG): If risk_score > 0.60, output 1.
2. NOISE SIGNAL (PASS): If risk_score < 0.30, output 0.
3. AMBIGUITY (FLAG): If risk_score is in between, output 1 to be safe.
━━━ CRITICAL JSON FORMAT ━━━
You MUST respond with a valid JSON object. The decisions array MUST contain exactly the requested number of raw integers. NO trailing commas.
Example:
{"reasoning": "Trade 1 has high risk. Trade 2 is safe.", "decisions": [1, 0, 1]}
"""
def _ts() -> str:
    """Return the current UTC time as an ISO-8601 string with milliseconds.

    The result has the shape ``YYYY-MM-DDTHH:MM:SS.mmmZ``.
    """
    # datetime.utcnow() is deprecated since Python 3.12. An aware "now" in UTC
    # carries the same wall-clock fields, so the formatted output is unchanged.
    now = datetime.datetime.now(datetime.timezone.utc)
    # %f renders six microsecond digits; dropping the last three leaves
    # millisecond precision.
    return now.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
def _build_user_prompt(step: int, features: list[list[float]]) -> str:
    """Render the user message describing the trades to audit at one step.

    Args:
        step: Step number shown in the first line of the message.
        features: One row per flagged trade. The first four values of a row
            are printed under the columns time_elapsed, price_delta,
            missing_freq and risk_score; rows with fewer than four values are
            printed verbatim and marked as malformed.

    Returns:
        A newline-joined text block: headline, blank line, table header,
        one line per trade, blank line, and a closing instruction stating
        how many decisions are expected.
    """
    trade_count = len(features)

    def _render(position: int, values: list[float]) -> str:
        # Short rows cannot fill the table columns, so show them raw instead.
        if len(values) < 4:
            return f" {position:3d} | (malformed row: {values})"
        return f" {position:3d} | {values[0]:8.4f} | {values[1]:7.4f} | {values[2]:8.4f} | {values[3]:7.4f}"

    preamble = [
        f"Step {step}: You have {trade_count} flagged trades to audit.",
        "",
        "Trade# | time_elapsed | price_delta | missing_freq | risk_score",
        "-------|--------------|-------------|--------------|----------",
    ]
    # Trades are numbered from 1 in the prompt.
    table = [_render(pos, vals) for pos, vals in enumerate(features, start=1)]
    closing = ["", f"Provide exactly {trade_count} decisions as a JSON object."]
    return "\n".join(preamble + table + closing)
# Most recent reasoning text, shared between the parser, _call_llm and the
# step logger in run_inference.
_last_reasoning: str = ""


def _parse_llm_decisions(content: str, expected_count: int) -> list[int]:
    """Extract exactly ``expected_count`` 0/1 decisions from a raw model reply.

    Strategies are tried in order, and the first one that succeeds wins:

    1. Strict: the reply (with an optional Markdown code fence removed) is a
       JSON object that validates against ``LLMResponse``.
    2. Lenient: the reply is a JSON object whose ``"decisions"`` entries can
       each be converted with ``int()``.
    3. Regex: the first bracketed list of digits found anywhere in the raw
       reply.
    4. Give up: flag every trade (all ones).

    Args:
        content: Raw text returned by the model.
        expected_count: Number of decisions the caller needs.

    Returns:
        A list of ``expected_count`` integers, each 0 or 1, as produced by
        ``_normalize_decisions`` (or all ones when nothing could be parsed).

    Side effects:
        Updates the module-level ``_last_reasoning`` when the reply carries a
        reasoning string. Unlike the previous implementation, this also
        happens when strict validation fails but the lenient path succeeds.
    """
    global _last_reasoning

    stripped = content.strip()
    if stripped.startswith("```"):
        # Drop an opening fence (with optional language tag) and the closing one.
        stripped = re.sub(r'^```[\w]*\n?', '', stripped)
        stripped = re.sub(r'\n?```$', '', stripped.strip())

    # Decode once; the strict and lenient strategies share the result.
    try:
        parsed = json.loads(stripped)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; RecursionError covers
        # pathologically nested input.
        parsed = None

    if isinstance(parsed, dict) and "decisions" in parsed:
        # Strategy 1: full schema validation.
        try:
            validated = LLMResponse(**parsed)
        except (ValidationError, TypeError, ValueError):
            validated = None
        if validated is not None:
            _last_reasoning = validated.reasoning
            return _normalize_decisions(
                [int(d) for d in validated.decisions], expected_count
            )

        # Strategy 2: tolerate a missing/odd "reasoning" or loosely typed
        # decisions, as long as each decision converts to int.
        try:
            coerced = [int(d) for d in parsed["decisions"]]
        except (TypeError, ValueError, OverflowError):
            coerced = None
        if coerced is not None:
            reasoning = parsed.get("reasoning")
            if isinstance(reasoning, str):
                # Keep the model's explanation for the step log even though
                # the reply did not pass strict validation.
                _last_reasoning = reasoning
            return _normalize_decisions(coerced, expected_count)

    # Strategy 3: pull the first "[digits, ...]" span out of free-form text.
    match = re.search(r'\[[\s\d,]+\]', content)
    if match:
        try:
            found = json.loads(match.group())
            return _normalize_decisions([int(d) for d in found], expected_count)
        except (TypeError, ValueError):
            # e.g. "[1,,0]" matches the pattern but is not valid JSON.
            pass

    # Strategy 4: nothing usable, so flag every trade.
    return [1] * expected_count
def _normalize_decisions(decisions: list[int], expected: int) -> list[int]:
    """Coerce a decision list to exactly ``expected`` binary values.

    Every value >= 1 becomes 1 and everything else becomes 0. Surplus entries
    are cut off, and missing entries are filled with 1 (flag).

    Args:
        decisions: Raw integer decisions.
        expected: Required length of the result.

    Returns:
        A new list of 0/1 integers.
    """
    # Clamp the whole input first, then cut to size.
    binary = [int(bool(value >= 1)) for value in decisions][:expected]
    # A non-positive multiplier yields an empty list, so this pads only when
    # the input was too short.
    binary.extend([1] * (expected - len(binary)))
    return binary
def _call_llm(step: int, features: list[list[float]]) -> list[int]:
    """Ask the model for one 0/1 decision per trade, with retries and a fallback.

    The request is attempted up to three times. Any exception raised while
    calling the API or parsing the reply is reported on stderr and triggers
    the next attempt, with a one-second pause between attempts. If every
    attempt fails, decisions come from a local heuristic on the fourth
    feature of each row.

    Args:
        step: Step number, forwarded to the prompt builder.
        features: One row of feature values per trade.

    Returns:
        A list with one 0/1 decision per row in ``features``.

    Side effects:
        Resets the module-level ``_last_reasoning`` to "Fallback triggered."
        before the first attempt; a successful parse may overwrite it.
    """
    global _last_reasoning
    _last_reasoning = "Fallback triggered."
    user_prompt = _build_user_prompt(step, features)
    max_retries = 3
    for attempt in range(1, max_retries + 1):
        try:
            response = _client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                max_tokens=1500,
                temperature=0.0,
            )
            content = response.choices[0].message.content or ""
            return _parse_llm_decisions(content, len(features))
        except Exception as exc:
            # Deliberately broad: any failure should lead to a retry rather
            # than abort the episode. stdout carries the [START]/[STEP]/[END]
            # lines, so diagnostics go to stderr.
            print(
                f"[WARN] LLM call failed (attempt {attempt}/{max_retries}): {exc!r}",
                file=sys.stderr,
                flush=True,
            )
            # No point pausing after the last attempt; go straight to the fallback.
            if attempt < max_retries:
                time.sleep(1)

    # Heuristic fallback: flag rows whose fourth value (risk_score in the
    # prompt table) is at least 0.7, and flag rows too short to judge.
    # NOTE(review): SYSTEM_PROMPT tells the model to flag everything that is
    # not below 0.30, while this fallback only flags >= 0.7. Confirm the
    # difference is intentional.
    return [
        1 if len(row) < 4 or row[3] >= 0.7 else 0
        for row in features
    ]
def run_inference() -> None:
    """Run one audit episode and stream its progress as tagged JSON lines.

    Output on stdout:
        ``[START] {...}`` once, with episode id, model, difficulty and
        max_steps; ``[STEP] {...}`` after every environment step, with the
        step reward, the cumulative reward, a truncated reasoning string and
        the tp/tn/fp/fn counters read from ``env.state``; and ``[END] {...}``
        once, with total reward, average reward per completed step and the
        final status.

    The loop stops after ``MAX_STEPS`` steps or as soon as the environment
    reports ``done``. Ctrl-C yields status "INTERRUPTED". Any other exception
    yields status "ERROR" with the traceback on stderr. The ``[END]`` line is
    printed in all cases.
    """
    # The "Empty matrix." assignment below would otherwise make
    # `_last_reasoning` local to this function, and reading it in the step
    # payload would raise UnboundLocalError on every step that went through
    # the LLM.
    global _last_reasoning

    episode_id: str = "unknown"
    total_reward: float = 0.0
    steps_completed: int = 0
    status: str = "SUCCESS"
    try:
        env = FinAuditorEnvironment()
        obs = env.reset()
        episode_id = getattr(env.state, 'episode_id', "test_run")
        start_payload = {
            "episode_id": episode_id,
            "model": MODEL_NAME,
            "difficulty": TASK_ID,
            "max_steps": MAX_STEPS
        }
        print(f"[START] {json.dumps(start_payload)}", flush=True)
        for step_num in range(1, MAX_STEPS + 1):
            step_reward = 0.0
            features = obs.features
            if not features:
                # Nothing to audit this step: submit an empty action without
                # calling the model.
                action = AuditorAction(decisions=[])
                _last_reasoning = "Empty matrix."
            else:
                decisions = _call_llm(step_num, features)
                action = AuditorAction(decisions=decisions)
            obs = env.step(action)
            step_reward = obs.reward if obs.reward is not None else 0.0
            total_reward += step_reward
            steps_completed = step_num
            # Rewards are rounded to 4 decimals so fractional precision is
            # retained in the log.
            step_payload = {
                "step": step_num,
                # `features` may be empty or None here; report 0 in both cases.
                "anomalies": len(features) if features else 0,
                "reward": round(float(step_reward), 4),
                "cumulative_reward": round(float(total_reward), 4),
                "done": bool(obs.done),
                "error": None,
                "reasoning": _last_reasoning[:120].replace('\n', ' ') + "...",
                "tp": getattr(env.state, 'last_tp', 0),
                "tn": getattr(env.state, 'last_tn', 0),
                "fp": getattr(env.state, 'last_fp', 0),
                "fn": getattr(env.state, 'last_fn', 0)
            }
            print(f"[STEP] {json.dumps(step_payload)}", flush=True)
            if obs.done:
                break
    except KeyboardInterrupt:
        status = "INTERRUPTED"
    except Exception:
        # Top-level boundary: record the failure but still emit [END] below.
        status = "ERROR"
        traceback.print_exc(file=sys.stderr)

    # max(..., 1) avoids a division by zero when no step completed.
    avg_reward = total_reward / max(steps_completed, 1)
    end_payload = {
        "total_reward": round(float(total_reward), 4),
        "avg_reward": round(float(avg_reward), 4),
        "status": status
    }
    print(f"[END] {json.dumps(end_payload)}", flush=True)


if __name__ == "__main__":
    run_inference()