"""
Multi-Agent Trading API Server.
Uses the PettingZoo AEC MultiAgentTradingEnv with three RL agents
(RiskManager → PortfolioManager → Trader) that negotiate each cycle.
Advisory agents (QuantResearcher, FundamentalAnalyst) run in parallel
to enrich the UI with signal context but do NOT participate in the AEC loop.
"""
from pathlib import Path
import asyncio
import os
import numpy as np
import uvicorn
from fastapi import BackgroundTasks, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from agents.fa_agent import FundamentalAnalyst
from agents.researcher import QuantResearcher
from env.multi_agent_env import (
MultiAgentTradingEnv,
RISK_MANAGER,
PORTFOLIO_MGR,
TRADER,
ALL_AGENTS,
)
# TradingEnv kept for backward compat data generation only (not used in endpoints)
from training.config import TrainingConfig
from training.train_multi_agent import (
RulePortfolioManagerPolicy,
RuleRiskManagerPolicy,
RuleTraderPolicy,
)
from huggingface_hub import snapshot_download
class GRPOAgent:
    """Bridge between the trained GRPO language model and the UI demo.

    The model is optional: ``load()`` returns ``False`` (and ``act()``
    returns ``None``) whenever PyTorch, CUDA, Unsloth or the weights are
    unavailable, so callers can fall back to a rule-based policy.
    """

    def __init__(self, model_id=None):
        """Remember which Hub repo to load; nothing is downloaded here.

        Args:
            model_id: Hugging Face repo id. Falls back to the
                ``GRPO_MODEL_ID`` environment variable, then to the
                built-in default repo.
        """
        self.model_id = model_id or os.getenv("GRPO_MODEL_ID", "ARKAISW/QuantHive-GRPO-Trader")
        self.model = None
        self.tokenizer = None
        self.is_ready = False

    def load(self):
        """Download and load the model.

        Returns:
            ``True`` when the model is ready for inference, ``False`` when
            any prerequisite is missing or loading fails.
        """
        try:
            import torch
        except Exception as e:
            print(f"PyTorch unavailable ({e}). Falling back to rule-based.")
            return False
        if not torch.cuda.is_available():
            print("CUDA not available in this environment. Falling back to rule-based.")
            return False
        try:
            from unsloth import FastLanguageModel
        except Exception as e:
            print(f"Could not import Unsloth: {e}. Falling back to rule-based.")
            return False
        try:
            print(f"Attempting to sync GRPO model from {self.model_id}...")
            # Auto-download from HF Hub if not local
            local_dir = Path("models") / "grpo_hf_trained"
            local_dir.mkdir(parents=True, exist_ok=True)
            snapshot_download(
                repo_id=self.model_id,
                local_dir=local_dir,
                allow_patterns=["*.json", "*.bin", "*.safetensors", "*.txt"],
            )
            print(f"Loading weights from {local_dir}...")
            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=str(local_dir),
                max_seq_length=2048,
                load_in_4bit=True,
            )
            FastLanguageModel.for_inference(self.model)
            self.is_ready = True
            print("✅ GRPO Model loaded successfully.")
            return True
        except Exception as e:
            print(f"Could not load GRPO model: {e}")
            return False

    @staticmethod
    def _parse_direction(text):
        """Map free-form model text to a direction code.

        Returns 1 when a buy/up word is present, otherwise 2 when a
        sell/down/short word is present, otherwise 0. Matching is anchored
        on word boundaries so that words which merely contain "up"
        (for example "support" or "output") do not count as a buy signal.
        """
        import re

        lowered = text.lower()
        if re.search(r"\b(?:buy(?:ing|s)?|up(?:wards?|trend|side)?)\b", lowered):
            return 1
        if re.search(
            r"\b(?:sell(?:ing|s)?|down(?:wards?|trend|side)?|short(?:ing|s)?)\b",
            lowered,
        ):
            return 2
        return 0

    def act(self, obs: np.ndarray) -> "dict | None":
        """Sample a Trader action from the GRPO model.

        Args:
            obs: Observation vector; only its first five entries are placed
                in the prompt.

        Returns:
            A dict with ``direction``, ``size``, ``sl``, ``tp`` and
            ``thought`` keys, or ``None`` when the model is not loaded or
            inference fails (the caller then uses its fallback policy).
        """
        if not self.is_ready:
            return None
        try:
            import torch

            # Construct a prompt that looks like the training scenarios
            prompt = f"Observation: {obs[:5].tolist()}... (truncated)\nResponse:"
            device = getattr(self.model, "device", "cuda")
            inputs = self.tokenizer([prompt], return_tensors="pt").to(device)
            # Short generation keeps the demo loop responsive.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=32,
                    use_cache=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                )
            # generate() returns prompt + continuation for causal LMs; drop
            # the prompt tokens so neither the keyword parse nor the UI
            # "thought" is polluted by our own prompt text.
            prompt_len = inputs["input_ids"].shape[1]
            response = self.tokenizer.decode(
                outputs[0][prompt_len:], skip_special_tokens=True
            ).strip()
            return {
                "direction": self._parse_direction(response),
                "size": np.array([0.15], dtype=np.float32),
                "sl": np.array([0.0], dtype=np.float32),
                "tp": np.array([0.0], dtype=np.float32),
                "thought": response[:100],  # Expose thought to UI
            }
        except Exception as e:
            print(f"GRPO inference error: {e}")
            return None
# Repository root: this module sits one directory below it.
ROOT_DIR = Path(__file__).resolve().parents[1]
# Location of the built UI bundle served by the static routes.
FRONTEND_DIST = ROOT_DIR.joinpath("ui", "dist")

app = FastAPI()
# Fully permissive CORS so the UI can call the API from any origin.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
def make_initial_state(initial_cash=100000.0, initial_price=50000.0):
    """Build a fresh, fully independent UI state dictionary.

    Args:
        initial_cash: Starting portfolio value and cash shown before the
            first cycle runs.
        initial_price: Placeholder price shown on the chart and the trade
            ticket before the first cycle runs.

    Returns:
        A new nested dict on every call, so mutating one result never
        affects another.
    """
    initial_cash = float(initial_cash)
    initial_price = float(initial_price)
    return {
        "is_running": False,
        "current_step": 0,
        # Five logical agents for the UI (maps to the 3 PZ agents + 2 advisory)
        "agents": {
            "Researcher": {"message": "Scanning the tape.", "confidence": 0.0, "status": "idle"},
            "Fundamental Analyst": {"message": "Watching macro tone.", "confidence": 0.0, "status": "idle"},
            "Risk Manager": {"message": "Limits standing by.", "confidence": 0.0, "status": "idle"},
            "Trader": {"message": "Desk is flat.", "confidence": 0.0, "status": "idle"},
            "Portfolio Manager": {"message": "Waiting for conviction.", "confidence": 0.0, "status": "idle"},
        },
        "portfolio": {"value": initial_cash, "cash": initial_cash, "positions": {}},
        "metrics": {"reward": 0.0, "grade": 0.0, "drawdown": 0.0, "sharpe": 0.0},
        "chart": {"price": initial_price, "trade": None, "price_change": 0.0},
        "trade": {
            "pulse": 0,
            "side": "HOLD",
            "size": 0.0,
            "price": initial_price,
            "sl": 0.0,
            "tp": 0.0,
            "portfolio_delta": 0.0,
            "notional": 0.0,
            "reason": "Waiting for the first coordinated decision.",
            "override": False,
        },
        "flow": [],
        "engine": {
            "name": "Multi-Agent Governance (PettingZoo AEC)",
            "mode": "Rule Fallback",
            "policy_active": False,
            "note": "Three independent RL agents negotiating via AEC turns: RiskManager → PortfolioManager → Trader.",
        },
        # Per-agent negotiation values exposed to the UI each cycle.
        "negotiation": {
            "rm_size_limit": 0.5,
            "rm_allow_new": True,
            "rm_force_reduce": False,
            "pm_cap_alloc": 0.5,
            "pm_override": 0.0,
            "governance_log": [],
        },
    }


# Module-wide state shared by the simulation runner and the HTTP endpoints.
sim_state = make_initial_state()
class SimulationRunner:
    """Drive the PettingZoo AEC loop and publish each cycle to ``sim_state``.

    Every ``step()`` call runs one full AEC cycle on the demo environment:
    RiskManager → PortfolioManager → Trader → market advance.
    The advisory agents (Researcher, Fundamental Analyst) are evaluated on
    the same observation purely to populate the UI; their output is never
    passed to ``env.step``.

    A second, independent environment (``_openenv_env``) with its own
    rule-based RM/PM policies is kept for the OpenEnv endpoints.
    """

    # Index = executed direction code, value = label shown in the UI.
    _DIRECTION_LABELS = ("HOLD", "BUY", "SELL")

    def __init__(self):
        """Build both environments, the policies and the optional GRPO agent.

        Side effect: overwrites ``sim_state["engine"]`` to reflect whether
        the trained model or the rule fallback is driving the Trader.
        """
        self.config = TrainingConfig(tickers=["AAPL"], fast_mode=True, max_steps=100)
        # Reduced commission for demo realism (preventing bleed from rule-based noise)
        self.config.commission = 0.0001

        # Demo environment driven by step().
        self.env = self._make_env()
        self.policies = {
            RISK_MANAGER: RuleRiskManagerPolicy(),
            PORTFOLIO_MGR: RulePortfolioManagerPolicy(),
            TRADER: RuleTraderPolicy(),
        }

        # Advisory agents (UI flavor only).
        self.researcher = QuantResearcher()
        self.fa_agent = FundamentalAnalyst(fast_mode=self.config.fast_mode)

        # Separate environment instance for the OpenEnv endpoints; the
        # Trader action there is supplied by the external caller, so only
        # RM and PM need policies.
        self._openenv_env = self._make_env()
        self._openenv_policies = {
            RISK_MANAGER: RuleRiskManagerPolicy(),
            PORTFOLIO_MGR: RulePortfolioManagerPolicy(),
        }
        self._openenv_env.reset()

        # Optional trained model; load() returns False when unavailable.
        self.grpo_agent = GRPOAgent()
        self.is_ml_active = self.grpo_agent.load()

        self.env.reset()
        self.done = False
        sim_state["engine"] = {
            "name": "Multi-Agent Governance (PettingZoo AEC)",
            "mode": "GRPO (Trained Model)" if self.is_ml_active else "Rule Fallback",
            "policy_active": self.is_ml_active,
            "note": "Three independent RL agents negotiating via AEC turns: RiskManager → PortfolioManager → Trader.",
        }

    def _make_env(self):
        """Create a MultiAgentTradingEnv configured from ``self.config``."""
        return MultiAgentTradingEnv(
            df=None,
            initial_cash=self.config.initial_cash,
            ticker=self.config.tickers[0],
            commission=self.config.commission,
            max_steps=self.config.max_steps,
        )

    def _choose_action(self, agent, obs):
        """Return the action for ``agent``: GRPO for the Trader when active, else rules."""
        if self.is_ml_active and agent == TRADER:
            ml_action = self.grpo_agent.act(obs)
            if ml_action:
                return ml_action
        return self.policies[agent].act(obs)

    def step(self):
        """Run one full AEC cycle and mirror the outcome into ``sim_state``.

        When the previous cycle ended the episode, the environment and the
        fundamental analyst are reset first.

        NOTE(review): the "previous" portfolio value and price are read
        from ``sim_state``, so the first cycle after such a reset reports
        deltas relative to the last value of the finished episode.
        """
        if self.done:
            self.env.reset()
            self.fa_agent.reset()
            self.done = False

        previous_value = sim_state["portfolio"]["value"]
        previous_price = sim_state["chart"]["price"]

        # One shared observation feeds both advisory agents.
        base_obs = self.env.observe(RISK_MANAGER)

        # Advisory: Researcher. Confidence is normalised to a builtin float
        # once so every sim_state section stores the same plain type.
        r_sig, r_conf, r_reasoning = self.researcher(base_obs)
        r_conf = float(r_conf)
        sim_state["agents"]["Researcher"] = {
            "message": f"{r_sig.title()} bias. {r_reasoning}",
            "confidence": r_conf,
            "status": "active",
        }

        # Advisory: Fundamental Analyst. Confidence is the distance of the
        # sentiment from 0.5, rescaled by the expression below.
        fa_sent, fa_reasoning = self.fa_agent(base_obs)
        fa_conf = float(abs((fa_sent * 2.0) - 1.0))
        sim_state["agents"]["Fundamental Analyst"] = {
            "message": fa_reasoning,
            "confidence": fa_conf,
            "status": "active",
        }

        # AEC cycle: each agent observes, acts, and its reward is recorded.
        cycle_rewards = {}
        for agent in (RISK_MANAGER, PORTFOLIO_MGR, TRADER):
            if not self.env.agents:
                self.done = True
                break
            obs = self.env.observe(agent)
            self.env.step(self._choose_action(agent, obs))
            cycle_rewards[agent] = self.env.rewards.get(agent, 0.0)

        # Episode is over once no agents remain or all are terminated.
        if not self.env.agents or all(
            self.env.terminations.get(ag, False) for ag in ALL_AGENTS
        ):
            self.done = True

        # Snapshot of the environment after the cycle.
        env_state = self.env.state()
        trader_info = self.env.infos.get(TRADER, {})
        current_price = float(env_state["price"])
        portfolio_value = float(env_state["portfolio_value"])
        portfolio_delta = portfolio_value - previous_value
        price_change = current_price - previous_price

        # Negotiation messages, with defaults when the env omits or
        # truncates them.
        rm_msg = env_state.get("rm_message", [0.5, 1.0, 0.0])
        pm_msg = env_state.get("pm_message", [0.5, 0.0])
        rm_size_limit = float(rm_msg[0]) if len(rm_msg) > 0 else 0.5
        rm_allow_new = bool(rm_msg[1] > 0.5) if len(rm_msg) > 1 else True
        rm_force_reduce = bool(rm_msg[2] > 0.5) if len(rm_msg) > 2 else False
        pm_cap_alloc = float(pm_msg[0]) if len(pm_msg) > 0 else 0.5
        pm_override_s = float(pm_msg[1]) if len(pm_msg) > 1 else 0.0

        # UI: Risk Manager.
        rm_reasoning = f"Limit {rm_size_limit:.2f}"
        if rm_force_reduce:
            rm_reasoning += " | FORCE REDUCE active"
        if not rm_allow_new:
            rm_reasoning += " | New positions BLOCKED"
        sim_state["agents"]["Risk Manager"] = {
            "message": rm_reasoning,
            "confidence": 1.0 - rm_size_limit,
            "status": "active" if rm_size_limit < 0.4 or rm_force_reduce else "idle",
        }

        # UI: Portfolio Manager.
        pm_message = f"Capital allocation: {pm_cap_alloc:.0%}"
        if pm_override_s > 0.7:
            pm_message += " | VETO signal active"
        sim_state["agents"]["Portfolio Manager"] = {
            "message": pm_message,
            "confidence": pm_cap_alloc,
            "status": "active" if pm_override_s > 0.5 or pm_cap_alloc < 0.3 else "idle",
        }

        # UI: Trader, based on what governance actually executed.
        gov = trader_info.get("governance", {}) or {}
        executed = gov.get("executed", {}) or {}
        direction = int(executed.get("direction", 0))
        if not 0 <= direction < len(self._DIRECTION_LABELS):
            # An unknown code is shown as HOLD; without this guard a
            # negative value would silently index from the end ("SELL").
            direction = 0
        size = float(executed.get("size", 0.0))
        sl = float(executed.get("sl", 0.0))
        tp = float(executed.get("tp", 0.0))
        interventions = gov.get("interventions", [])
        was_compliant = gov.get("was_compliant", True)

        dir_str = self._DIRECTION_LABELS[direction]
        trader_reasoning = f"{dir_str} {size:.2f}"
        if not was_compliant:
            intervention_types = [i.get("type", "?") for i in interventions]
            trader_reasoning += f" (overridden: {', '.join(intervention_types)})"
        else:
            trader_reasoning += " (compliant — no governance intervention)"
        sim_state["agents"]["Trader"] = {
            "message": trader_reasoning,
            "confidence": size,
            "status": "active" if direction != 0 else "idle",
        }

        # Headline state.
        sim_state["current_step"] = env_state["step"]
        sim_state["portfolio"] = {
            "value": portfolio_value,
            "cash": env_state["cash"],
            "positions": env_state["positions"],
        }
        sim_state["metrics"] = {
            "reward": float(cycle_rewards.get(TRADER, 0.0)),
            "grade": trader_info.get("grade", 0.0),
            "drawdown": env_state["max_drawdown"],
            "sharpe": env_state["sharpe_ratio"],
        }
        sim_state["chart"] = {
            "price": current_price,
            "trade": dir_str if direction != 0 else None,
            "price_change": price_change,
        }
        sim_state["trade"] = {
            # Monotonic counter; incremented once per cycle.
            "pulse": sim_state["trade"]["pulse"] + 1,
            "side": dir_str,
            "size": size,
            "price": current_price,
            "sl": sl,
            "tp": tp,
            "portfolio_delta": float(portfolio_delta),
            "notional": float(portfolio_value * size if direction != 0 else 0.0),
            "reason": trader_reasoning,
            "override": not was_compliant,
        }

        # Flow graph edges for the UI.
        sim_state["flow"] = [
            {"from": "Researcher", "to": "Risk Manager", "strength": r_conf, "active": True, "tone": "signal"},
            {"from": "Researcher", "to": "Portfolio Manager", "strength": r_conf, "active": r_sig != "neutral", "tone": "research"},
            {"from": "Fundamental Analyst", "to": "Portfolio Manager", "strength": fa_conf, "active": True, "tone": "macro"},
            {"from": "Risk Manager", "to": "Trader", "strength": float(1.0 - rm_size_limit), "active": True, "tone": "risk"},
            {"from": "Portfolio Manager", "to": "Trader", "strength": float(pm_cap_alloc), "active": True, "tone": "approval"},
            {"from": "Trader", "to": "Market", "strength": size, "active": direction != 0, "tone": dir_str.lower()},
        ]

        # Negotiation state (multi-agent-specific).
        sim_state["negotiation"] = {
            "rm_size_limit": rm_size_limit,
            "rm_allow_new": rm_allow_new,
            "rm_force_reduce": rm_force_reduce,
            "pm_cap_alloc": pm_cap_alloc,
            "pm_override": pm_override_s,
            "governance_log": env_state.get("governance_log", []),
        }
# Lazily created SimulationRunner shared by all endpoints.
runner = None


async def simulation_loop():
    """Step the simulation roughly every 0.4 s while ``is_running`` is set.

    Runner construction and each ``step()`` are blocking calls, so they are
    pushed to the default executor; otherwise the event loop (and every
    other request) would stall while they run.

    If a step raises, ``is_running`` is cleared before the exception
    propagates, so a later start request can launch a new loop instead of
    being ignored because the flag is still set.
    """
    global runner
    loop = asyncio.get_running_loop()
    try:
        if runner is None:
            runner = await loop.run_in_executor(None, SimulationRunner)
        while sim_state["is_running"]:
            await loop.run_in_executor(None, runner.step)
            await asyncio.sleep(0.4)
    except Exception:
        sim_state["is_running"] = False
        raise
@app.get("/state")
@app.get("/api/state")
def get_state():
    """Return the live simulation state dictionary."""
    snapshot = sim_state
    return snapshot
@app.post("/start")
@app.post("/api/start")
async def start_sim(background_tasks: BackgroundTasks):
    """Mark the simulation as running and schedule the loop if it was idle.

    The response is the same whether or not a loop was already active.
    """
    was_idle = not sim_state["is_running"]
    if was_idle:
        sim_state["is_running"] = True
        background_tasks.add_task(simulation_loop)
    return {"status": "started"}
@app.post("/stop")
@app.post("/api/stop")
def stop_sim():
    """Clear the running flag; the loop exits at its next flag check."""
    sim_state.update(is_running=False)
    return {"status": "stopped"}
@app.post("/api/step")
def step_sim():
    """Advance the demo simulation by exactly one AEC cycle."""
    global runner
    runner = SimulationRunner() if runner is None else runner
    runner.step()
    return {"status": "stepped"}
# --- OpenEnv Standard Endpoints for Judges ---
# These use the PettingZoo MultiAgentTradingEnv directly.
# RM and PM run rule-based policies; the Trader action comes from the external caller.
@app.post("/openenv/reset")
@app.post("/reset")
async def openenv_reset():
    """Reset the OpenEnv environment and return the Trader's first observation.

    The runner is created on first use. ``info`` is always an empty dict.
    """
    global runner
    if runner is None:
        runner = SimulationRunner()
    judge_env = runner._openenv_env
    judge_env.reset()
    first_obs = judge_env.observe(TRADER)
    return {"observation": first_obs.tolist(), "info": {}}
def _parse_trader_action(action):
    """Convert a caller-supplied payload into a Trader action dict.

    Missing fields default to 0. Raises ``ValueError`` for a non-dict
    payload, non-numeric or non-finite values, or a direction outside
    {0, 1, 2}.
    """
    if not isinstance(action, dict):
        raise ValueError("action must be a JSON object")
    try:
        direction = int(action.get("direction", 0))
        scalars = {key: float(action.get(key, 0.0)) for key in ("size", "sl", "tp")}
    except (TypeError, ValueError) as exc:
        raise ValueError(f"malformed trader action: {exc}") from exc
    # 0/1/2 are the only direction codes this file uses
    # (HOLD/BUY/SELL) -- presumably the env's full action range; confirm.
    if direction not in (0, 1, 2):
        raise ValueError("direction must be 0 (hold), 1 (buy) or 2 (sell)")
    for key, value in scalars.items():
        if not np.isfinite(value):
            raise ValueError(f"{key} must be a finite number")
    parsed = {"direction": direction}
    for key, value in scalars.items():
        parsed[key] = np.array([value], dtype=np.float32)
    return parsed


@app.post("/openenv/step")
@app.post("/step")
async def openenv_step(action: dict):
    """Run one full AEC cycle on the OpenEnv environment.

    RM and PM act through their rule-based policies; the submitted
    ``action`` is used for the Trader.

    Returns:
        The Trader's observation, reward, terminated, truncated and info.

    Raises:
        HTTPException: 422 when ``action`` is malformed.
    """
    from fastapi import HTTPException

    global runner

    # Validate before touching the env: failing after RM and PM have
    # already stepped would leave the AEC turn order desynchronised for
    # every later request.
    try:
        trader_action = _parse_trader_action(action)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    if runner is None:
        runner = SimulationRunner()
    env = runner._openenv_env
    policies = runner._openenv_policies

    # If the episode is over, auto-reset
    if not env.agents:
        env.reset()

    # Run full AEC cycle: RM, then PM, then Trader.
    for agent in (RISK_MANAGER, PORTFOLIO_MGR, TRADER):
        if not env.agents:
            break
        if agent == TRADER:
            env.step(trader_action)
        else:
            env.step(policies[agent].act(env.observe(agent)))

    # Collect results from the Trader's perspective
    trader_obs = env.observe(TRADER)
    return {
        "observation": trader_obs.tolist(),
        "reward": float(env.rewards.get(TRADER, 0.0)),
        "terminated": bool(env.terminations.get(TRADER, False)),
        "truncated": bool(env.truncations.get(TRADER, False)),
        "info": env.infos.get(TRADER, {}),
    }
if FRONTEND_DIST.exists():

    @app.get("/")
    def serve_index():
        """Serve the single-page app entry point."""
        return FileResponse(FRONTEND_DIST / "index.html")

    @app.get("/{full_path:path}")
    def serve_frontend(full_path: str):
        """Serve a built asset, or fall back to ``index.html``.

        The requested path is resolved and must stay inside the bundle
        directory. Joining an absolute ``full_path`` (as produced by a URL
        such as ``//etc/passwd``) or one containing ``..`` would otherwise
        escape ``FRONTEND_DIST`` and expose arbitrary files.
        """
        index_file = FRONTEND_DIST / "index.html"
        if not full_path:
            return FileResponse(index_file)
        bundle_root = FRONTEND_DIST.resolve()
        candidate = (FRONTEND_DIST / full_path).resolve()
        if bundle_root in candidate.parents and candidate.is_file():
            return FileResponse(candidate)
        # Unknown or out-of-bundle paths get the SPA entry point.
        return FileResponse(index_file)

else:

    @app.get("/")
    def demo_not_built():
        """Explain how to build the UI when no bundle is present."""
        return JSONResponse(
            {
                "message": "Frontend bundle not found. Run `npm install && npm run build` inside `ui/`.",
                "frontend_dist": str(FRONTEND_DIST),
            }
        )
def run_server(host="0.0.0.0", port=7860):
    """Start the uvicorn server for this app.

    Args:
        host: Interface to bind; the default listens on all interfaces.
        port: TCP port to listen on.
    """
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    run_server()