Spaces:
Runtime error
Runtime error
| """ | |
| websocket_server.py (v2 — OpenEnv + Dual Agent) | |
| ───────────────────────────────────────────────── | |
| FastAPI application that: | |
| 1. Loads TWO models at startup: | |
| White → Qwen/Qwen2.5-0.5B-Instruct | |
| Black → meta-llama/Llama-3.2-1B-Instruct | |
| 2. Registers the OpenEnv 0.1 HTTP API at /env/* | |
| 3. Runs continuous self-play games (white=Qwen vs black=Llama). | |
| 4. Streams every game event to all connected WebSocket clients. | |
| 5. Runs GRPO on the WHITE model only (Qwen) — Llama acts as fixed opponent. | |
| OpenEnv endpoints (for external RL trainers): | |
| POST /env/reset start a new episode | |
| POST /env/step apply one action | |
| GET /env/state inspect current state | |
| GET /env/env_info environment metadata (HF Hub discoverability) | |
| WebSocket endpoint: /ws | |
| Health check: /health | |
| API docs: /docs | |
| """ | |
| import asyncio | |
| import json | |
| import logging | |
| import time | |
| from contextlib import asynccontextmanager | |
| from typing import Any | |
| import uvicorn | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from settings import settings | |
| from chess_engine import ChessEngine | |
| from agents.model_agent import ModelAgent | |
| from grpo_trainer import GRPOTrainer | |
| from openenv.router import router as openenv_router, init_env | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="%(asctime)s %(levelname)s %(name)s: %(message)s", | |
| ) | |
| logger = logging.getLogger(__name__) | |
| # ── Global state ────────────────────────────────────────────────────────────── | |
| connected_clients: set[WebSocket] = set() | |
| paused = False | |
| game_count = 0 | |
| wallet_white = settings.starting_wallet | |
| wallet_black = settings.starting_wallet | |
| # Initialised in lifespan | |
| white_agent: ModelAgent | None = None | |
| black_agent: ModelAgent | None = None | |
| trainer: GRPOTrainer | None = None | |
| # ── Lifespan ────────────────────────────────────────────────────────────────── | |
| async def lifespan(app: FastAPI): | |
| global white_agent, black_agent, trainer | |
| logger.info("Loading WHITE model (%s) …", settings.white_model) | |
| white_agent = ModelAgent(settings.white_model).load() | |
| logger.info("Loading BLACK model (%s) …", settings.black_model) | |
| black_agent = ModelAgent(settings.black_model).load() | |
| # GRPO trains the WHITE agent (Qwen); Llama is a fixed opponent | |
| trainer = GRPOTrainer(white_agent.model, white_agent.tokenizer) | |
| # Initialise the OpenEnv environment (used by /env/* HTTP endpoints) | |
| init_env( | |
| white_model_id=settings.white_model, | |
| black_model_id=settings.black_model, | |
| ) | |
| logger.info("Both models ready. Starting auto-play loop …") | |
| asyncio.create_task(game_loop()) | |
| yield | |
| logger.info("Shutting down.") | |
| app = FastAPI( | |
| title="ChessEcon", | |
| description=( | |
| "Multi-Agent Chess Economy — OpenEnv 0.1 compliant environment. " | |
| "White: Qwen2.5-0.5B | Black: Llama-3.2-1B | Training: GRPO" | |
| ), | |
| version="2.0.0", | |
| lifespan=lifespan, | |
| ) | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # Register OpenEnv HTTP router at /env/* | |
| app.include_router(openenv_router) | |
| # ── Health ──────────────────────────────────────────────────────────────────── | |
| async def health(): | |
| return { | |
| "status": "ok", | |
| "service": "chessecon", | |
| "version": "2.0.0", | |
| "openenv_version": "0.1", | |
| "white_model": settings.white_model, | |
| "black_model": settings.black_model, | |
| "ws_clients": len(connected_clients), | |
| "games_played": game_count, | |
| } | |
| # ── WebSocket endpoint ──────────────────────────────────────────────────────── | |
| async def websocket_endpoint(ws: WebSocket): | |
| await ws.accept() | |
| connected_clients.add(ws) | |
| logger.info("WS client connected (%d total)", len(connected_clients)) | |
| # Send current state snapshot to new client immediately | |
| try: | |
| await ws.send_text(json.dumps({ | |
| "type": "status", | |
| "data": { | |
| "game_id": game_count, | |
| "wallet_white": round(wallet_white, 2), | |
| "wallet_black": round(wallet_black, 2), | |
| "grpo_step": trainer._step if trainer else 0, | |
| "message": f"Connected — game #{game_count} in progress", | |
| } | |
| })) | |
| except Exception: | |
| pass | |
| try: | |
| while True: | |
| raw = await ws.receive_text() | |
| try: | |
| msg = json.loads(raw) | |
| await handle_client_message(ws, msg) | |
| except json.JSONDecodeError: | |
| pass | |
| except WebSocketDisconnect: | |
| connected_clients.discard(ws) | |
| logger.info("WS client disconnected (%d total)", len(connected_clients)) | |
| async def handle_client_message(ws: WebSocket, msg: dict): | |
| global paused | |
| action = msg.get("action", "") | |
| if action == "ping": | |
| await ws.send_text(json.dumps({"type": "pong", "data": {}})) | |
| elif action == "pause": | |
| paused = True | |
| logger.info("Game loop paused") | |
| elif action == "resume": | |
| paused = False | |
| logger.info("Game loop resumed") | |
| # ── Broadcast helper ────────────────────────────────────────────────────────── | |
| async def broadcast(event_type: str, data: dict[str, Any]): | |
| if not connected_clients: | |
| return | |
| payload = json.dumps({"type": event_type, "data": data}) | |
| dead: set[WebSocket] = set() | |
| for ws in list(connected_clients): | |
| try: | |
| await ws.send_text(payload) | |
| except Exception: | |
| dead.add(ws) | |
| connected_clients.difference_update(dead) | |
| # ── Main game loop ──────────────────────────────────────────────────────────── | |
| async def game_loop(): | |
| global game_count, wallet_white, wallet_black, paused | |
| while True: | |
| while paused: | |
| await asyncio.sleep(0.5) | |
| game_count += 1 | |
| engine = ChessEngine() | |
| wallet_white -= settings.entry_fee | |
| wallet_black -= settings.entry_fee | |
| prize_pool = settings.entry_fee * 2 * settings.prize_pool_fraction | |
| await broadcast("game_start", { | |
| "game_id": game_count, | |
| "wallet_white": round(wallet_white, 2), | |
| "wallet_black": round(wallet_black, 2), | |
| "prize_pool": round(prize_pool, 2), | |
| "white_model": settings.white_model, | |
| "black_model": settings.black_model, | |
| "message": ( | |
| f"Game #{game_count} — " | |
| f"Qwen(W) vs Llama(B) — " | |
| f"Prize pool: {prize_pool:.1f} units" | |
| ), | |
| }) | |
| trainer.start_game("white") # type: ignore[union-attr] | |
| move_history: list[str] = [] | |
| # ── Play the game ───────────────────────────────────────────────── | |
| while not engine.is_game_over and engine.move_number <= settings.max_moves: | |
| while paused: | |
| await asyncio.sleep(0.5) | |
| current_color = engine.turn | |
| # Select the right agent | |
| active_agent = white_agent if current_color == "white" else black_agent | |
| san, log_prob = await asyncio.get_event_loop().run_in_executor( | |
| None, | |
| active_agent.get_move, # type: ignore[union-attr] | |
| engine, current_color, move_history, | |
| ) | |
| # KL reference: only needed for WHITE (GRPO training target) | |
| if current_color == "white": | |
| ref_log_prob = await asyncio.get_event_loop().run_in_executor( | |
| None, | |
| white_agent.get_move_log_prob_only, # type: ignore[union-attr] | |
| engine, current_color, move_history, san, | |
| ) | |
| else: | |
| ref_log_prob = log_prob # Black is fixed; KL = 0 | |
| uci = engine.apply_move_san(san) | |
| if uci is None: | |
| fallback = engine.random_legal_move_san() | |
| if fallback is None: | |
| break | |
| san = fallback | |
| uci = engine.apply_move_san(san) or "" | |
| log_prob = 0.0 | |
| ref_log_prob = 0.0 | |
| trainer.record_move(log_prob, ref_log_prob) # type: ignore[union-attr] | |
| move_history.append(san) | |
| await broadcast("move", { | |
| "game_id": game_count, | |
| "player": current_color, | |
| "model": settings.white_model if current_color == "white" else settings.black_model, | |
| "move": san, | |
| "uci": uci, | |
| "fen": engine.fen, | |
| "move_number": engine.move_number, | |
| "turn": engine.turn, | |
| "wallet_white": round(wallet_white, 2), | |
| "wallet_black": round(wallet_black, 2), | |
| "message": f"{'Qwen' if current_color == 'white' else 'Llama'} plays {san}", | |
| }) | |
| await asyncio.sleep(settings.move_delay) | |
| # ── Game over ───────────────────────────────────────────────────── | |
| # If game ended by chess rules use that result; otherwise adjudicate by material | |
| if engine.result: | |
| result = engine.result | |
| else: | |
| # Count material: Q=9 R=5 B=3 N=3 P=1 | |
| piece_values = {1: 1, 2: 3, 3: 3, 4: 5, 5: 9} # pawn,knight,bishop,rook,queen | |
| import chess as _chess | |
| white_mat = sum( | |
| piece_values.get(pt, 0) | |
| for pt in range(1, 6) | |
| for _ in engine.board.pieces(pt, _chess.WHITE) | |
| ) | |
| black_mat = sum( | |
| piece_values.get(pt, 0) | |
| for pt in range(1, 6) | |
| for _ in engine.board.pieces(pt, _chess.BLACK) | |
| ) | |
| result = '1-0' if white_mat >= black_mat else '0-1' # always decisive | |
| white_reward = 1.0 if result == "1-0" else (-1.0 if result == "0-1" else 0.0) | |
| black_reward = 1.0 if result == "0-1" else (-1.0 if result == "1-0" else 0.0) | |
| if result == "1-0": | |
| wallet_white += prize_pool | |
| elif result == "0-1": | |
| wallet_black += prize_pool | |
| else: | |
| wallet_white += prize_pool / 2 | |
| wallet_black += prize_pool / 2 | |
| white_pnl = ( | |
| prize_pool if result == "1-0" | |
| else prize_pool / 2 if result == "1/2-1/2" | |
| else 0 | |
| ) - settings.entry_fee | |
| black_pnl = ( | |
| prize_pool if result == "0-1" | |
| else prize_pool / 2 if result == "1/2-1/2" | |
| else 0 | |
| ) - settings.entry_fee | |
| await broadcast("game_end", { | |
| "game_id": game_count, | |
| "result": result, | |
| "reward": white_reward, | |
| "wallet_white": round(wallet_white, 2), | |
| "wallet_black": round(wallet_black, 2), | |
| "prize_income": round( | |
| prize_pool if result == "1-0" | |
| else prize_pool / 2 if result == "1/2-1/2" | |
| else 0, 2 | |
| ), | |
| "coaching_cost": 0, | |
| "entry_fee": settings.entry_fee, | |
| "net_pnl_white": round(white_pnl, 2), | |
| "net_pnl_black": round(black_pnl, 2), | |
| "move_count": len(move_history), | |
| "white_model": settings.white_model, | |
| "black_model": settings.black_model, | |
| "message": f"Game #{game_count} ended — {result}", | |
| }) | |
| # GRPO update (WHITE model only) | |
| training_metrics = trainer.end_game( # type: ignore[union-attr] | |
| reward=white_reward, | |
| profit=white_pnl, | |
| coaching_calls=0, | |
| ) | |
| if training_metrics is not None: | |
| await broadcast("training_step", { | |
| "step": training_metrics.step, | |
| "loss": round(training_metrics.loss, 6), | |
| "reward": round(training_metrics.policy_reward, 4), | |
| "kl_div": round(training_metrics.kl_div, 6), | |
| "win_rate": round(training_metrics.win_rate, 4), | |
| "avg_profit": round(training_metrics.avg_profit, 4), | |
| "coaching_rate": round(training_metrics.coaching_rate, 4), | |
| "model": settings.white_model, | |
| "message": ( | |
| f"GRPO step {training_metrics.step} | " | |
| f"loss={training_metrics.loss:.4f} " | |
| f"win_rate={training_metrics.win_rate:.2%}" | |
| ), | |
| }) | |
| await asyncio.sleep(1.0) | |
| # ── Entry point ─────────────────────────────────────────────────────────────── | |
| if __name__ == "__main__": | |
| uvicorn.run( | |
| "websocket_server:app", | |
| host=settings.host, | |
| port=settings.port, | |
| reload=False, | |
| log_level="info", | |
| ) | |