"""
websocket_server.py (v2 — OpenEnv + Dual Agent)
─────────────────────────────────────────────────
FastAPI application that:
  1. Loads TWO models at startup:
       White → Qwen/Qwen2.5-0.5B-Instruct
       Black → meta-llama/Llama-3.2-1B-Instruct
  2. Registers the OpenEnv 0.1 HTTP API at /env/*
  3. Runs continuous self-play games (white=Qwen vs black=Llama).
  4. Streams every game event to all connected WebSocket clients.
  5. Runs GRPO on the WHITE model only (Qwen) — Llama acts as a fixed opponent.

OpenEnv endpoints (for external RL trainers):
  POST /env/reset      start a new episode
  POST /env/step       apply one action
  GET  /env/state      inspect current state
  GET  /env/env_info   environment metadata (HF Hub discoverability)

WebSocket endpoint: /ws
Health check:       /health
API docs:           /docs
"""
import asyncio
import json
import logging
import time
from contextlib import asynccontextmanager
from typing import Any
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from settings import settings
from chess_engine import ChessEngine
from agents.model_agent import ModelAgent
from grpo_trainer import GRPOTrainer
from openenv.router import router as openenv_router, init_env
# Root logging configuration: timestamp, level, logger name, message.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# ── Global state ──────────────────────────────────────────────────────────────
# Every WebSocket that is currently subscribed to game events.
connected_clients: set[WebSocket] = set()
# Toggled by the "pause" / "resume" client actions; polled by the game loop.
paused = False
# Number of games started so far (also used as the current game id).
game_count = 0
# Both sides start with the same configured balance.
wallet_white = wallet_black = settings.starting_wallet

# Populated by the lifespan handler at startup.
white_agent: ModelAgent | None = None
black_agent: ModelAgent | None = None
trainer: GRPOTrainer | None = None
# ── Lifespan ──────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler: start everything up, then tear it down.

    Startup loads the white and black agents, wraps the white agent's model
    and tokenizer in a GRPO trainer, initialises the OpenEnv environment and
    launches the background self-play loop.

    Shutdown cancels the self-play task and waits for it to finish so that it
    is not left running (or destroyed while pending) when the server exits.
    """
    global white_agent, black_agent, trainer

    logger.info("Loading WHITE model (%s) …", settings.white_model)
    white_agent = ModelAgent(settings.white_model).load()

    logger.info("Loading BLACK model (%s) …", settings.black_model)
    black_agent = ModelAgent(settings.black_model).load()

    # GRPO trains the WHITE agent (Qwen); Llama is a fixed opponent
    trainer = GRPOTrainer(white_agent.model, white_agent.tokenizer)

    # Initialise the OpenEnv environment (used by /env/* HTTP endpoints)
    init_env(
        white_model_id=settings.white_model,
        black_model_id=settings.black_model,
    )

    logger.info("Both models ready. Starting auto-play loop …")
    # Hold a strong reference to the task: the event loop itself only keeps a
    # weak one, so an unreferenced task may be garbage-collected mid-run.
    loop_task = asyncio.create_task(game_loop())
    try:
        yield
    finally:
        logger.info("Shutting down.")
        loop_task.cancel()
        try:
            await loop_task
        except asyncio.CancelledError:
            # Expected outcome of the cancel() above.
            pass
        except Exception:
            # The loop had already crashed; surface why instead of hiding it.
            logger.exception("Game loop terminated with an error")
# The ASGI application; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="ChessEcon",
    version="2.0.0",
    description=(
        "Multi-Agent Chess Economy — OpenEnv 0.1 compliant environment. "
        "White: Qwen2.5-0.5B | Black: Llama-3.2-1B | Training: GRPO"
    ),
)

# CORS is fully open: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Mount the OpenEnv HTTP router (serves the /env/* endpoints).
app.include_router(openenv_router)
# ── Health ────────────────────────────────────────────────────────────────────
@app.get("/health")
async def health():
    """Return a small status document: identity, models, and live counters."""
    report = {
        "status": "ok",
        "service": "chessecon",
        "version": "2.0.0",
        "openenv_version": "0.1",
    }
    # Configured model identifiers for each side.
    report["white_model"] = settings.white_model
    report["black_model"] = settings.black_model
    # Runtime counters taken from the module-level state.
    report["ws_clients"] = len(connected_clients)
    report["games_played"] = game_count
    return report
# ── WebSocket endpoint ────────────────────────────────────────────────────────
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    """Serve one WebSocket client for the lifetime of its connection.

    The client is registered for broadcasts, sent a one-off "status" snapshot
    of the current game/wallet state, and then its incoming text frames are
    parsed as JSON and dispatched to ``handle_client_message``.

    Frames that are not valid JSON, or whose JSON value is not an object, are
    ignored. The client is always removed from ``connected_clients`` when this
    handler exits, whether through a normal disconnect or any other error.
    """
    await ws.accept()
    connected_clients.add(ws)
    logger.info("WS client connected (%d total)", len(connected_clients))

    try:
        # Send current state snapshot to new client immediately
        snapshot = {
            "type": "status",
            "data": {
                "game_id": game_count,
                "wallet_white": round(wallet_white, 2),
                "wallet_black": round(wallet_black, 2),
                "grpo_step": trainer._step if trainer else 0,
                "message": f"Connected — game #{game_count} in progress",
            }
        }
        try:
            await ws.send_text(json.dumps(snapshot))
        except Exception:
            # Best effort: a failed snapshot must not abort the session, but
            # leave a trace so the failure is diagnosable.
            logger.debug("Could not send initial status snapshot", exc_info=True)

        while True:
            raw = await ws.receive_text()
            try:
                msg = json.loads(raw)
            except json.JSONDecodeError:
                continue
            # Valid JSON need not be an object (e.g. `[1]` or `"x"`); only
            # objects can carry an "action".
            if not isinstance(msg, dict):
                continue
            await handle_client_message(ws, msg)
    except WebSocketDisconnect:
        pass
    finally:
        # Runs for every exit path so that dead sockets never linger in the
        # broadcast set.
        connected_clients.discard(ws)
        logger.info("WS client disconnected (%d total)", len(connected_clients))
async def handle_client_message(ws: WebSocket, msg: dict):
    """Act on one decoded client message.

    Recognised values of ``msg["action"]``:
      * ``"ping"``   — reply to ``ws`` with a ``pong`` event.
      * ``"pause"``  — set the module-level ``paused`` flag.
      * ``"resume"`` — clear the module-level ``paused`` flag.

    Any other action is ignored. A payload that is not a dict (the caller
    passes whatever ``json.loads`` produced, which may be a list, string or
    number) is ignored as well instead of raising ``AttributeError``.
    """
    global paused

    if not isinstance(msg, dict):
        return

    action = msg.get("action", "")
    if action == "ping":
        await ws.send_text(json.dumps({"type": "pong", "data": {}}))
    elif action == "pause":
        paused = True
        logger.info("Game loop paused")
    elif action == "resume":
        paused = False
        logger.info("Game loop resumed")
# ── Broadcast helper ──────────────────────────────────────────────────────────
async def broadcast(event_type: str, data: dict[str, Any]):
    """Send one ``{"type": ..., "data": ...}`` JSON event to every client.

    Does nothing when no client is connected. Clients are written to one
    after another; any client whose send raises is dropped from
    ``connected_clients`` once the pass is complete.
    """
    if connected_clients:
        message = json.dumps({"type": event_type, "data": data})
        failed = set()
        # Iterate over a snapshot: the live set can change while we await.
        for client in tuple(connected_clients):
            try:
                await client.send_text(message)
            except Exception:
                failed.add(client)
        connected_clients.difference_update(failed)
# ── Main game loop ────────────────────────────────────────────────────────────
def _adjudicate_by_material(board) -> str:
    """Decide an unfinished game by material (P=1, N=3, B=3, R=5, Q=9).

    Returns "1-0" when White has at least as much material as Black, else
    "0-1"; ties therefore go to White and the outcome is always decisive.
    """
    import chess as _chess

    piece_values = {1: 1, 2: 3, 3: 3, 4: 5, 5: 9}  # pawn,knight,bishop,rook,queen

    def material(color) -> int:
        return sum(
            value * len(board.pieces(piece_type, color))
            for piece_type, value in piece_values.items()
        )

    return '1-0' if material(_chess.WHITE) >= material(_chess.BLACK) else '0-1'


def _prize_share(result: str, winning_result: str, prize_pool):
    """Prize earned by the side that wins on `winning_result`: all, half on a draw, or 0."""
    if result == winning_result:
        return prize_pool
    if result == "1/2-1/2":
        return prize_pool / 2
    return 0


async def game_loop():
    """Play self-play games forever, broadcasting events and training White.

    Each iteration: wait while paused, charge both wallets the entry fee,
    play a game (agent moves are computed in the default executor so the
    event loop stays responsive), settle the prize pool, broadcast
    ``game_start`` / ``move`` / ``game_end`` events, and hand the outcome to
    the trainer. When the trainer returns metrics, a ``training_step`` event
    is broadcast as well.
    """
    global game_count, wallet_white, wallet_black, paused

    # We are inside a coroutine, so the running loop is the one to use.
    loop = asyncio.get_running_loop()

    while True:
        while paused:
            await asyncio.sleep(0.5)

        game_count += 1
        engine = ChessEngine()
        wallet_white -= settings.entry_fee
        wallet_black -= settings.entry_fee
        prize_pool = settings.entry_fee * 2 * settings.prize_pool_fraction

        await broadcast("game_start", {
            "game_id": game_count,
            "wallet_white": round(wallet_white, 2),
            "wallet_black": round(wallet_black, 2),
            "prize_pool": round(prize_pool, 2),
            "white_model": settings.white_model,
            "black_model": settings.black_model,
            "message": (
                f"Game #{game_count} — "
                f"Qwen(W) vs Llama(B) — "
                f"Prize pool: {prize_pool:.1f} units"
            ),
        })

        trainer.start_game("white")  # type: ignore[union-attr]
        move_history: list[str] = []

        # ── Play the game ─────────────────────────────────────────────────
        while not engine.is_game_over and engine.move_number <= settings.max_moves:
            while paused:
                await asyncio.sleep(0.5)

            current_color = engine.turn
            is_white = current_color == "white"

            # Select the right agent
            active_agent = white_agent if is_white else black_agent
            san, log_prob = await loop.run_in_executor(
                None,
                active_agent.get_move,  # type: ignore[union-attr]
                engine, current_color, move_history,
            )

            # KL reference: only needed for WHITE (GRPO training target)
            if is_white:
                ref_log_prob = await loop.run_in_executor(
                    None,
                    white_agent.get_move_log_prob_only,  # type: ignore[union-attr]
                    engine, current_color, move_history, san,
                )
            else:
                ref_log_prob = log_prob  # Black is fixed; KL = 0

            uci = engine.apply_move_san(san)
            if uci is None:
                # The proposed move was rejected: substitute a random legal
                # move and zero both log-probs for it.
                fallback = engine.random_legal_move_san()
                if fallback is None:
                    break
                san = fallback
                uci = engine.apply_move_san(san) or ""
                log_prob = 0.0
                ref_log_prob = 0.0

            # NOTE(review): moves of BOTH colours are recorded although the
            # game was started with "white" — confirm the trainer expects this.
            trainer.record_move(log_prob, ref_log_prob)  # type: ignore[union-attr]
            move_history.append(san)

            await broadcast("move", {
                "game_id": game_count,
                "player": current_color,
                "model": settings.white_model if is_white else settings.black_model,
                "move": san,
                "uci": uci,
                "fen": engine.fen,
                "move_number": engine.move_number,
                "turn": engine.turn,
                "wallet_white": round(wallet_white, 2),
                "wallet_black": round(wallet_black, 2),
                "message": f"{'Qwen' if current_color == 'white' else 'Llama'} plays {san}",
            })
            await asyncio.sleep(settings.move_delay)

        # ── Game over ─────────────────────────────────────────────────────
        # If game ended by chess rules use that result; otherwise adjudicate by material
        result = engine.result or _adjudicate_by_material(engine.board)

        white_reward = 1.0 if result == "1-0" else (-1.0 if result == "0-1" else 0.0)

        # NOTE(review): any result other than "1-0"/"0-1" splits the pool in
        # the wallets, while the reported prize/P&L only treats "1/2-1/2" as a
        # draw. The two agree for the three standard result strings.
        if result == "1-0":
            wallet_white += prize_pool
        elif result == "0-1":
            wallet_black += prize_pool
        else:
            wallet_white += prize_pool / 2
            wallet_black += prize_pool / 2

        white_prize = _prize_share(result, "1-0", prize_pool)
        black_prize = _prize_share(result, "0-1", prize_pool)
        white_pnl = white_prize - settings.entry_fee
        black_pnl = black_prize - settings.entry_fee

        await broadcast("game_end", {
            "game_id": game_count,
            "result": result,
            "reward": white_reward,
            "wallet_white": round(wallet_white, 2),
            "wallet_black": round(wallet_black, 2),
            "prize_income": round(white_prize, 2),
            "coaching_cost": 0,
            "entry_fee": settings.entry_fee,
            "net_pnl_white": round(white_pnl, 2),
            "net_pnl_black": round(black_pnl, 2),
            "move_count": len(move_history),
            "white_model": settings.white_model,
            "black_model": settings.black_model,
            "message": f"Game #{game_count} ended — {result}",
        })

        # GRPO update (WHITE model only)
        training_metrics = trainer.end_game(  # type: ignore[union-attr]
            reward=white_reward,
            profit=white_pnl,
            coaching_calls=0,
        )
        if training_metrics is not None:
            await broadcast("training_step", {
                "step": training_metrics.step,
                "loss": round(training_metrics.loss, 6),
                "reward": round(training_metrics.policy_reward, 4),
                "kl_div": round(training_metrics.kl_div, 6),
                "win_rate": round(training_metrics.win_rate, 4),
                "avg_profit": round(training_metrics.avg_profit, 4),
                "coaching_rate": round(training_metrics.coaching_rate, 4),
                "model": settings.white_model,
                "message": (
                    f"GRPO step {training_metrics.step} | "
                    f"loss={training_metrics.loss:.4f} "
                    f"win_rate={training_metrics.win_rate:.2%}"
                ),
            })

        # Short breather between games.
        await asyncio.sleep(1.0)
# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on the configured address,
    # with auto-reload turned off.
    uvicorn.run(
        "websocket_server:app",
        log_level="info",
        reload=False,
        port=settings.port,
        host=settings.host,
    )
|