Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| import os | |
| import sys | |
| import json | |
| import re | |
| import datetime | |
| import traceback | |
| import time | |
| from typing import List | |
| from dotenv import load_dotenv | |
# Load variables from a local .env file (python-dotenv) into os.environ. This
# must run before the os.getenv() configuration reads further down the file.
load_dotenv()
# ── Put the project root first on sys.path so the `hft_auditor` extension (.so)
# and the local `hft_auditor_env`, `server` and `models` modules import from
# this checkout rather than from anywhere else on the path ──
_ROOT = os.path.dirname(os.path.abspath(__file__))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)
from openai import OpenAI
from pydantic import BaseModel, ValidationError
# Prefer `hft_auditor_env.FinAuditorEnv` when it is importable; otherwise fall
# back to `server.fin_auditor_environment`. Both are bound to the name
# FinAuditorEnvironment, so the rest of the file does not care which loaded.
try:
    from hft_auditor_env import FinAuditorEnv as FinAuditorEnvironment
except ImportError:
    from server.fin_auditor_environment import FinAuditorEnvironment
from models import AuditorAction
class LLMResponse(BaseModel):
    """Schema of the JSON object the model is instructed to return.

    Mirrors the ``{"reasoning": ..., "decisions": [...]}`` format spelled out
    in the system prompt.
    """

    # Free-text explanation the model writes before committing to decisions.
    reasoning: str
    # One verdict per trade, in prompt order: 1 = flag as anomaly, 0 = clean.
    decisions: list[int]
# ── Runtime configuration (environment variables, optionally loaded from .env) ──
API_BASE_URL: str = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME: str = os.getenv("MODEL_NAME", "meta-llama/Meta-Llama-3-8B-Instruct")

# The token is passed as the API key of the OpenAI client below; fail fast at
# import time instead of on the first request.
HF_TOKEN: str = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("CRITICAL: HF_TOKEN environment variable is missing.")

TASK_ID: str = os.getenv("TASK_ID", "anomaly_detection_hard")

# Default step budget follows the difficulty keyword found in TASK_ID
# (case-insensitive; "easy" is checked before "medium"); anything else gets 20.
# An explicit MAX_STEPS environment variable always overrides the default.
_DEFAULT_MAX = next(
    (
        budget
        for keyword, budget in (("easy", 5), ("medium", 10))
        if keyword in TASK_ID.lower()
    ),
    20,
)
MAX_STEPS: int = int(os.getenv("MAX_STEPS", str(_DEFAULT_MAX)))

_client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
# System prompt sent with every chat request. It defines the reply schema that
# LLMResponse / _parse_llm_decisions expect ({"reasoning": str, "decisions": [int]})
# and the flagging policy: risk_score > 0.60 -> 1, < 0.30 -> 0, otherwise 1,
# i.e. only scores below 0.30 are passed as clean.
# NOTE(review): the example's reasoning mentions two trades while its
# "decisions" array has three entries; consider making them agree.
SYSTEM_PROMPT = """\
You are a Wall Street Compliance Auditor AI embedded inside a High-Frequency Trading audit engine.
You MUST think step-by-step in the 'reasoning' field before determining your action.
━━━ DECISION VALUES ━━━
You must output ONLY raw integers (0 or 1) in the array. NO strings. NO labels.
1 : Confirmed anomaly.
0 : Trade is clean.
━━━ EVALUATION DIRECTIVE ━━━
1. CRITICAL SIGNAL (FLAG): If risk_score > 0.60, output 1.
2. NOISE SIGNAL (PASS): If risk_score < 0.30, output 0.
3. AMBIGUITY (FLAG): If risk_score is in between, output 1 to be safe.
━━━ CRITICAL JSON FORMAT ━━━
You MUST respond with a valid JSON object. The decisions array MUST contain exactly the requested number of raw integers. NO trailing commas.
Example:
{"reasoning": "Trade 1 has high risk. Trade 2 is safe.", "decisions": [1, 0, 1]}
"""
| def _ts() -> str: | |
| return datetime.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" | |
| def _build_user_prompt(step: int, features: list[list[float]]) -> str: | |
| lines = [ | |
| f"Step {step}: You have {len(features)} flagged trades to audit.", | |
| "", | |
| "Trade# | time_elapsed | price_delta | missing_freq | risk_score", | |
| "-------|--------------|-------------|--------------|----------", | |
| ] | |
| for i, row in enumerate(features): | |
| if len(row) >= 4: | |
| lines.append(f" {i+1:3d} | {row[0]:8.4f} | {row[1]:7.4f} | {row[2]:8.4f} | {row[3]:7.4f}") | |
| else: | |
| lines.append(f" {i+1:3d} | (malformed row: {row})") | |
| lines.append("") | |
| lines.append(f"Provide exactly {len(features)} decisions as a JSON object.") | |
| return "\n".join(lines) | |
# Reasoning text surfaced in the [STEP] log lines. _parse_llm_decisions stores
# the model's own reasoning here when it can extract it; other helpers in this
# file overwrite it with placeholder messages.
_last_reasoning: str = ""


def _parse_llm_decisions(content: str, expected_count: int) -> list[int]:
    """Extract exactly ``expected_count`` 0/1 decisions from a raw LLM reply.

    Strategies, tried in order:

    1. Strip an optional Markdown code fence, parse the reply as JSON and
       validate it against LLMResponse; its ``reasoning`` is stored in
       ``_last_reasoning``.
    2. If validation fails but the JSON object still has ``decisions``, coerce
       each item with ``int()``; a string ``reasoning`` is stored as well.
    3. Salvage the first bracketed list of digits/commas found anywhere in the
       raw reply.

    Results from 1-3 go through _normalize_decisions, which clamps values to
    0/1 and pads with 1s or truncates to ``expected_count``. If nothing can be
    parsed, every trade is flagged (``[1] * expected_count``).
    """
    global _last_reasoning

    stripped = content.strip()
    if stripped.startswith("```"):
        # Drop an opening ``` / ```json fence and the matching closing fence.
        stripped = re.sub(r'^```[\w]*\n?', '', stripped)
        stripped = re.sub(r'\n?```$', '', stripped.strip())

    # Parse once; both structured strategies below share the result.
    try:
        parsed = json.loads(stripped)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        parsed = None

    if isinstance(parsed, dict) and "decisions" in parsed:
        # 1. Strict: validate against the response schema.
        #    pydantic's ValidationError subclasses ValueError.
        try:
            response = LLMResponse(**parsed)
        except (ValueError, TypeError):
            response = None
        if response is not None:
            _last_reasoning = response.reasoning
            return _normalize_decisions([int(d) for d in response.decisions], expected_count)

        # 2. Lenient: the schema failed (e.g. missing or non-string reasoning),
        #    but the decisions list may still be usable.
        try:
            decisions = [int(d) for d in parsed["decisions"]]
        except (TypeError, ValueError, OverflowError):
            decisions = None
        if decisions is not None:
            reasoning = parsed.get("reasoning")
            if isinstance(reasoning, str):
                _last_reasoning = reasoning
            return _normalize_decisions(decisions, expected_count)

    # 3. Free-form reply: salvage the first bare integer list in the text.
    match = re.search(r'\[[\s\d,]+\]', content)
    if match:
        try:
            salvaged = json.loads(match.group())
        except ValueError:
            salvaged = None
        if salvaged is not None:
            return _normalize_decisions([int(d) for d in salvaged], expected_count)

    # Nothing usable: flag every trade rather than silently passing them.
    return [1] * expected_count
| def _normalize_decisions(decisions: list[int], expected: int) -> list[int]: | |
| clamped = [1 if d >= 1 else 0 for d in decisions] | |
| clamped = clamped[:expected] | |
| while len(clamped) < expected: | |
| clamped.append(1) | |
| return clamped | |
def _heuristic_decisions(features: list[list[float]]) -> list[int]:
    """Score trades locally using the same thresholds SYSTEM_PROMPT gives the LLM.

    The prompt flags risk_score > 0.60, passes risk_score < 0.30, and flags the
    band in between "to be safe". The local equivalent is therefore
    risk_score >= 0.30 -> 1. Rows too short to carry a risk_score are flagged.
    """
    return [1 if len(row) < 4 or row[3] >= 0.30 else 0 for row in features]


def _call_llm(step: int, features: list[list[float]], *, max_retries: int = 3) -> list[int]:
    """Ask the LLM for one 0/1 decision per trade row.

    Args:
        step: Step number, echoed in the user prompt.
        features: Trade rows; index 3 is read as risk_score, matching the
            column order of the table built by _build_user_prompt.
        max_retries: API attempts before falling back to _heuristic_decisions
            (keyword-only, default 3).

    Returns:
        A list with one 0/1 entry per row of ``features``.
    """
    global _last_reasoning
    # Placeholder shown in step logs unless _parse_llm_decisions records the
    # model's own reasoning.
    _last_reasoning = "Fallback triggered."
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": _build_user_prompt(step, features)},
    ]

    for attempt in range(1, max_retries + 1):
        try:
            response = _client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                max_tokens=1500,
                temperature=0.0,
            )
            content = response.choices[0].message.content or ""
            return _parse_llm_decisions(content, len(features))
        except Exception as exc:  # API boundary: any client/transport error retries
            # stderr keeps stdout's [START]/[STEP]/[END] stream machine-parseable.
            print(
                f"[WARN] LLM call failed (attempt {attempt}/{max_retries}): {exc!r}",
                file=sys.stderr,
                flush=True,
            )
            # Back off before the next attempt, but not after the last one.
            if attempt < max_retries:
                time.sleep(1)

    return _heuristic_decisions(features)
def run_inference() -> None:
    """Run one audit episode and print machine-readable progress to stdout.

    Emits ``[START] {...}`` with episode metadata, then one ``[STEP] {...}``
    per environment step, and always a final ``[END] {...}``. The ``status``
    in [END] is ``SUCCESS``, ``INTERRUPTED`` (KeyboardInterrupt) or ``ERROR``;
    on ERROR the traceback is printed to stderr. The loop stops after
    MAX_STEPS steps or as soon as an observation reports ``done``.
    """
    # Required: this function assigns _last_reasoning (empty-features branch).
    # Without the declaration Python treats the name as local to the whole
    # function. Reading it for the [STEP] payload after an LLM-driven step
    # would then raise UnboundLocalError and end the episode with ERROR.
    global _last_reasoning

    episode_id: str = "unknown"
    total_reward: float = 0.0
    steps_completed: int = 0
    status: str = "SUCCESS"

    try:
        env = FinAuditorEnvironment()
        obs = env.reset()
        episode_id = getattr(env.state, 'episode_id', "test_run")
        start_payload = {
            "episode_id": episode_id,
            "model": MODEL_NAME,
            "difficulty": TASK_ID,
            "max_steps": MAX_STEPS,
        }
        print(f"[START] {json.dumps(start_payload)}", flush=True)

        for step_num in range(1, MAX_STEPS + 1):
            features = obs.features
            if not features:
                # Nothing to audit this step: submit an empty decision list.
                action = AuditorAction(decisions=[])
                _last_reasoning = "Empty matrix."
            else:
                decisions = _call_llm(step_num, features)
                action = AuditorAction(decisions=decisions)

            obs = env.step(action)
            # A missing reward counts as zero so the running totals stay numeric.
            step_reward = obs.reward if obs.reward is not None else 0.0
            total_reward += step_reward
            steps_completed = step_num

            # Rewards are rounded to 4 decimals, keeping fractional precision.
            # "anomalies" counts the rows audited in this step (pre-step features).
            step_payload = {
                "step": step_num,
                "anomalies": len(features),
                "reward": round(float(step_reward), 4),
                "cumulative_reward": round(float(total_reward), 4),
                "done": bool(obs.done),
                "error": None,
                "reasoning": _last_reasoning[:120].replace('\n', ' ') + "...",
                "tp": getattr(env.state, 'last_tp', 0),
                "tn": getattr(env.state, 'last_tn', 0),
                "fp": getattr(env.state, 'last_fp', 0),
                "fn": getattr(env.state, 'last_fn', 0),
            }
            print(f"[STEP] {json.dumps(step_payload)}", flush=True)
            if obs.done:
                break
    except KeyboardInterrupt:
        status = "INTERRUPTED"
    except Exception:
        status = "ERROR"
        traceback.print_exc(file=sys.stderr)

    # Average over the steps actually completed; max(..., 1) avoids dividing by zero.
    avg_reward = total_reward / max(steps_completed, 1)
    end_payload = {
        "total_reward": round(float(total_reward), 4),
        "avg_reward": round(float(avg_reward), 4),
        "status": status,
    }
    print(f"[END] {json.dumps(end_payload)}", flush=True)


if __name__ == "__main__":
    run_inference()