chessecon / backend /websocket_server.py
suvasis's picture
code add
e4d7d50
"""
websocket_server.py (v2 — OpenEnv + Dual Agent)
─────────────────────────────────────────────────
FastAPI application that:
1. Loads TWO models at startup:
White → Qwen/Qwen2.5-0.5B-Instruct
Black → meta-llama/Llama-3.2-1B-Instruct
2. Registers the OpenEnv 0.1 HTTP API at /env/*
3. Runs continuous self-play games (white=Qwen vs black=Llama).
4. Streams every game event to all connected WebSocket clients.
5. Runs GRPO on the WHITE model only (Qwen) — Llama acts as fixed opponent.
OpenEnv endpoints (for external RL trainers):
POST /env/reset start a new episode
POST /env/step apply one action
GET /env/state inspect current state
GET /env/env_info environment metadata (HF Hub discoverability)
WebSocket endpoint: /ws
Health check: /health
API docs: /docs
"""
import asyncio
import json
import logging
import time
from contextlib import asynccontextmanager
from typing import Any
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from settings import settings
from chess_engine import ChessEngine
from agents.model_agent import ModelAgent
from grpo_trainer import GRPOTrainer
from openenv.router import router as openenv_router, init_env
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
# ── Global state ──────────────────────────────────────────────────────────────
connected_clients: set[WebSocket] = set()
paused = False
game_count = 0
wallet_white = settings.starting_wallet
wallet_black = settings.starting_wallet
# Initialised in lifespan
white_agent: ModelAgent | None = None
black_agent: ModelAgent | None = None
trainer: GRPOTrainer | None = None
# ── Lifespan ──────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
global white_agent, black_agent, trainer
logger.info("Loading WHITE model (%s) …", settings.white_model)
white_agent = ModelAgent(settings.white_model).load()
logger.info("Loading BLACK model (%s) …", settings.black_model)
black_agent = ModelAgent(settings.black_model).load()
# GRPO trains the WHITE agent (Qwen); Llama is a fixed opponent
trainer = GRPOTrainer(white_agent.model, white_agent.tokenizer)
# Initialise the OpenEnv environment (used by /env/* HTTP endpoints)
init_env(
white_model_id=settings.white_model,
black_model_id=settings.black_model,
)
logger.info("Both models ready. Starting auto-play loop …")
asyncio.create_task(game_loop())
yield
logger.info("Shutting down.")
app = FastAPI(
title="ChessEcon",
description=(
"Multi-Agent Chess Economy — OpenEnv 0.1 compliant environment. "
"White: Qwen2.5-0.5B | Black: Llama-3.2-1B | Training: GRPO"
),
version="2.0.0",
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# Register OpenEnv HTTP router at /env/*
app.include_router(openenv_router)
# ── Health ────────────────────────────────────────────────────────────────────
@app.get("/health")
async def health():
return {
"status": "ok",
"service": "chessecon",
"version": "2.0.0",
"openenv_version": "0.1",
"white_model": settings.white_model,
"black_model": settings.black_model,
"ws_clients": len(connected_clients),
"games_played": game_count,
}
# ── WebSocket endpoint ────────────────────────────────────────────────────────
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
await ws.accept()
connected_clients.add(ws)
logger.info("WS client connected (%d total)", len(connected_clients))
# Send current state snapshot to new client immediately
try:
await ws.send_text(json.dumps({
"type": "status",
"data": {
"game_id": game_count,
"wallet_white": round(wallet_white, 2),
"wallet_black": round(wallet_black, 2),
"grpo_step": trainer._step if trainer else 0,
"message": f"Connected — game #{game_count} in progress",
}
}))
except Exception:
pass
try:
while True:
raw = await ws.receive_text()
try:
msg = json.loads(raw)
await handle_client_message(ws, msg)
except json.JSONDecodeError:
pass
except WebSocketDisconnect:
connected_clients.discard(ws)
logger.info("WS client disconnected (%d total)", len(connected_clients))
async def handle_client_message(ws: WebSocket, msg: dict):
global paused
action = msg.get("action", "")
if action == "ping":
await ws.send_text(json.dumps({"type": "pong", "data": {}}))
elif action == "pause":
paused = True
logger.info("Game loop paused")
elif action == "resume":
paused = False
logger.info("Game loop resumed")
# ── Broadcast helper ──────────────────────────────────────────────────────────
async def broadcast(event_type: str, data: dict[str, Any]):
if not connected_clients:
return
payload = json.dumps({"type": event_type, "data": data})
dead: set[WebSocket] = set()
for ws in list(connected_clients):
try:
await ws.send_text(payload)
except Exception:
dead.add(ws)
connected_clients.difference_update(dead)
# ── Main game loop ────────────────────────────────────────────────────────────
async def game_loop():
global game_count, wallet_white, wallet_black, paused
while True:
while paused:
await asyncio.sleep(0.5)
game_count += 1
engine = ChessEngine()
wallet_white -= settings.entry_fee
wallet_black -= settings.entry_fee
prize_pool = settings.entry_fee * 2 * settings.prize_pool_fraction
await broadcast("game_start", {
"game_id": game_count,
"wallet_white": round(wallet_white, 2),
"wallet_black": round(wallet_black, 2),
"prize_pool": round(prize_pool, 2),
"white_model": settings.white_model,
"black_model": settings.black_model,
"message": (
f"Game #{game_count} — "
f"Qwen(W) vs Llama(B) — "
f"Prize pool: {prize_pool:.1f} units"
),
})
trainer.start_game("white") # type: ignore[union-attr]
move_history: list[str] = []
# ── Play the game ─────────────────────────────────────────────────
while not engine.is_game_over and engine.move_number <= settings.max_moves:
while paused:
await asyncio.sleep(0.5)
current_color = engine.turn
# Select the right agent
active_agent = white_agent if current_color == "white" else black_agent
san, log_prob = await asyncio.get_event_loop().run_in_executor(
None,
active_agent.get_move, # type: ignore[union-attr]
engine, current_color, move_history,
)
# KL reference: only needed for WHITE (GRPO training target)
if current_color == "white":
ref_log_prob = await asyncio.get_event_loop().run_in_executor(
None,
white_agent.get_move_log_prob_only, # type: ignore[union-attr]
engine, current_color, move_history, san,
)
else:
ref_log_prob = log_prob # Black is fixed; KL = 0
uci = engine.apply_move_san(san)
if uci is None:
fallback = engine.random_legal_move_san()
if fallback is None:
break
san = fallback
uci = engine.apply_move_san(san) or ""
log_prob = 0.0
ref_log_prob = 0.0
trainer.record_move(log_prob, ref_log_prob) # type: ignore[union-attr]
move_history.append(san)
await broadcast("move", {
"game_id": game_count,
"player": current_color,
"model": settings.white_model if current_color == "white" else settings.black_model,
"move": san,
"uci": uci,
"fen": engine.fen,
"move_number": engine.move_number,
"turn": engine.turn,
"wallet_white": round(wallet_white, 2),
"wallet_black": round(wallet_black, 2),
"message": f"{'Qwen' if current_color == 'white' else 'Llama'} plays {san}",
})
await asyncio.sleep(settings.move_delay)
# ── Game over ─────────────────────────────────────────────────────
# If game ended by chess rules use that result; otherwise adjudicate by material
if engine.result:
result = engine.result
else:
# Count material: Q=9 R=5 B=3 N=3 P=1
piece_values = {1: 1, 2: 3, 3: 3, 4: 5, 5: 9} # pawn,knight,bishop,rook,queen
import chess as _chess
white_mat = sum(
piece_values.get(pt, 0)
for pt in range(1, 6)
for _ in engine.board.pieces(pt, _chess.WHITE)
)
black_mat = sum(
piece_values.get(pt, 0)
for pt in range(1, 6)
for _ in engine.board.pieces(pt, _chess.BLACK)
)
result = '1-0' if white_mat >= black_mat else '0-1' # always decisive
white_reward = 1.0 if result == "1-0" else (-1.0 if result == "0-1" else 0.0)
black_reward = 1.0 if result == "0-1" else (-1.0 if result == "1-0" else 0.0)
if result == "1-0":
wallet_white += prize_pool
elif result == "0-1":
wallet_black += prize_pool
else:
wallet_white += prize_pool / 2
wallet_black += prize_pool / 2
white_pnl = (
prize_pool if result == "1-0"
else prize_pool / 2 if result == "1/2-1/2"
else 0
) - settings.entry_fee
black_pnl = (
prize_pool if result == "0-1"
else prize_pool / 2 if result == "1/2-1/2"
else 0
) - settings.entry_fee
await broadcast("game_end", {
"game_id": game_count,
"result": result,
"reward": white_reward,
"wallet_white": round(wallet_white, 2),
"wallet_black": round(wallet_black, 2),
"prize_income": round(
prize_pool if result == "1-0"
else prize_pool / 2 if result == "1/2-1/2"
else 0, 2
),
"coaching_cost": 0,
"entry_fee": settings.entry_fee,
"net_pnl_white": round(white_pnl, 2),
"net_pnl_black": round(black_pnl, 2),
"move_count": len(move_history),
"white_model": settings.white_model,
"black_model": settings.black_model,
"message": f"Game #{game_count} ended — {result}",
})
# GRPO update (WHITE model only)
training_metrics = trainer.end_game( # type: ignore[union-attr]
reward=white_reward,
profit=white_pnl,
coaching_calls=0,
)
if training_metrics is not None:
await broadcast("training_step", {
"step": training_metrics.step,
"loss": round(training_metrics.loss, 6),
"reward": round(training_metrics.policy_reward, 4),
"kl_div": round(training_metrics.kl_div, 6),
"win_rate": round(training_metrics.win_rate, 4),
"avg_profit": round(training_metrics.avg_profit, 4),
"coaching_rate": round(training_metrics.coaching_rate, 4),
"model": settings.white_model,
"message": (
f"GRPO step {training_metrics.step} | "
f"loss={training_metrics.loss:.4f} "
f"win_rate={training_metrics.win_rate:.2%}"
),
})
await asyncio.sleep(1.0)
# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
uvicorn.run(
"websocket_server:app",
host=settings.host,
port=settings.port,
reload=False,
log_level="info",
)