| | """ |
| | websocket_server.py (v2 — OpenEnv + Dual Agent) |
| | ───────────────────────────────────────────────── |
| | FastAPI application that: |
| | 1. Loads TWO models at startup: |
| | White → Qwen/Qwen2.5-0.5B-Instruct |
| | Black → meta-llama/Llama-3.2-1B-Instruct |
| | 2. Registers the OpenEnv 0.1 HTTP API at /env/* |
| | 3. Runs continuous self-play games (white=Qwen vs black=Llama). |
| | 4. Streams every game event to all connected WebSocket clients. |
| | 5. Runs GRPO on the WHITE model only (Qwen) — Llama acts as fixed opponent. |
| | |
| | OpenEnv endpoints (for external RL trainers): |
| | POST /env/reset start a new episode |
| | POST /env/step apply one action |
| | GET /env/state inspect current state |
| | GET /env/env_info environment metadata (HF Hub discoverability) |
| | |
| | WebSocket endpoint: /ws |
| | Health check: /health |
| | API docs: /docs |
| | """ |
| |
|
| | import asyncio |
| | import json |
| | import logging |
| | import time |
| | from contextlib import asynccontextmanager |
| | from typing import Any |
| |
|
| | import uvicorn |
| | from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
| | from fastapi.middleware.cors import CORSMiddleware |
| |
|
| | from settings import settings |
| | from chess_engine import ChessEngine |
| | from agents.model_agent import ModelAgent |
| | from grpo_trainer import GRPOTrainer |
| | from openenv.router import router as openenv_router, init_env |
| |
|
| | logging.basicConfig( |
| | level=logging.INFO, |
| | format="%(asctime)s %(levelname)s %(name)s: %(message)s", |
| | ) |
| | logger = logging.getLogger(__name__) |
| |
|
| | |
| | connected_clients: set[WebSocket] = set() |
| | paused = False |
| | game_count = 0 |
| | wallet_white = settings.starting_wallet |
| | wallet_black = settings.starting_wallet |
| |
|
| | |
| | white_agent: ModelAgent | None = None |
| | black_agent: ModelAgent | None = None |
| | trainer: GRPOTrainer | None = None |
| |
|
| |
|
| | |
| | @asynccontextmanager |
| | async def lifespan(app: FastAPI): |
| | global white_agent, black_agent, trainer |
| |
|
| | logger.info("Loading WHITE model (%s) …", settings.white_model) |
| | white_agent = ModelAgent(settings.white_model).load() |
| |
|
| | logger.info("Loading BLACK model (%s) …", settings.black_model) |
| | black_agent = ModelAgent(settings.black_model).load() |
| |
|
| | |
| | trainer = GRPOTrainer(white_agent.model, white_agent.tokenizer) |
| |
|
| | |
| | init_env( |
| | white_model_id=settings.white_model, |
| | black_model_id=settings.black_model, |
| | ) |
| |
|
| | logger.info("Both models ready. Starting auto-play loop …") |
| | asyncio.create_task(game_loop()) |
| | yield |
| | logger.info("Shutting down.") |
| |
|
| |
|
| | app = FastAPI( |
| | title="ChessEcon", |
| | description=( |
| | "Multi-Agent Chess Economy — OpenEnv 0.1 compliant environment. " |
| | "White: Qwen2.5-0.5B | Black: Llama-3.2-1B | Training: GRPO" |
| | ), |
| | version="2.0.0", |
| | lifespan=lifespan, |
| | ) |
| |
|
| | app.add_middleware( |
| | CORSMiddleware, |
| | allow_origins=["*"], |
| | allow_methods=["*"], |
| | allow_headers=["*"], |
| | ) |
| |
|
| | |
| | app.include_router(openenv_router) |
| |
|
| |
|
| | |
| | @app.get("/health") |
| | async def health(): |
| | return { |
| | "status": "ok", |
| | "service": "chessecon", |
| | "version": "2.0.0", |
| | "openenv_version": "0.1", |
| | "white_model": settings.white_model, |
| | "black_model": settings.black_model, |
| | "ws_clients": len(connected_clients), |
| | "games_played": game_count, |
| | } |
| |
|
| |
|
| | |
| | @app.websocket("/ws") |
| | async def websocket_endpoint(ws: WebSocket): |
| | await ws.accept() |
| | connected_clients.add(ws) |
| | logger.info("WS client connected (%d total)", len(connected_clients)) |
| | |
| | try: |
| | await ws.send_text(json.dumps({ |
| | "type": "status", |
| | "data": { |
| | "game_id": game_count, |
| | "wallet_white": round(wallet_white, 2), |
| | "wallet_black": round(wallet_black, 2), |
| | "grpo_step": trainer._step if trainer else 0, |
| | "message": f"Connected — game #{game_count} in progress", |
| | } |
| | })) |
| | except Exception: |
| | pass |
| | try: |
| | while True: |
| | raw = await ws.receive_text() |
| | try: |
| | msg = json.loads(raw) |
| | await handle_client_message(ws, msg) |
| | except json.JSONDecodeError: |
| | pass |
| | except WebSocketDisconnect: |
| | connected_clients.discard(ws) |
| | logger.info("WS client disconnected (%d total)", len(connected_clients)) |
| |
|
| |
|
| | async def handle_client_message(ws: WebSocket, msg: dict): |
| | global paused |
| | action = msg.get("action", "") |
| | if action == "ping": |
| | await ws.send_text(json.dumps({"type": "pong", "data": {}})) |
| | elif action == "pause": |
| | paused = True |
| | logger.info("Game loop paused") |
| | elif action == "resume": |
| | paused = False |
| | logger.info("Game loop resumed") |
| |
|
| |
|
| | |
| | async def broadcast(event_type: str, data: dict[str, Any]): |
| | if not connected_clients: |
| | return |
| | payload = json.dumps({"type": event_type, "data": data}) |
| | dead: set[WebSocket] = set() |
| | for ws in list(connected_clients): |
| | try: |
| | await ws.send_text(payload) |
| | except Exception: |
| | dead.add(ws) |
| | connected_clients.difference_update(dead) |
| |
|
| |
|
| | |
| | async def game_loop(): |
| | global game_count, wallet_white, wallet_black, paused |
| |
|
| | while True: |
| | while paused: |
| | await asyncio.sleep(0.5) |
| |
|
| | game_count += 1 |
| | engine = ChessEngine() |
| |
|
| | wallet_white -= settings.entry_fee |
| | wallet_black -= settings.entry_fee |
| | prize_pool = settings.entry_fee * 2 * settings.prize_pool_fraction |
| |
|
| | await broadcast("game_start", { |
| | "game_id": game_count, |
| | "wallet_white": round(wallet_white, 2), |
| | "wallet_black": round(wallet_black, 2), |
| | "prize_pool": round(prize_pool, 2), |
| | "white_model": settings.white_model, |
| | "black_model": settings.black_model, |
| | "message": ( |
| | f"Game #{game_count} — " |
| | f"Qwen(W) vs Llama(B) — " |
| | f"Prize pool: {prize_pool:.1f} units" |
| | ), |
| | }) |
| |
|
| | trainer.start_game("white") |
| | move_history: list[str] = [] |
| |
|
| | |
| | while not engine.is_game_over and engine.move_number <= settings.max_moves: |
| | while paused: |
| | await asyncio.sleep(0.5) |
| |
|
| | current_color = engine.turn |
| | |
| | active_agent = white_agent if current_color == "white" else black_agent |
| |
|
| | san, log_prob = await asyncio.get_event_loop().run_in_executor( |
| | None, |
| | active_agent.get_move, |
| | engine, current_color, move_history, |
| | ) |
| |
|
| | |
| | if current_color == "white": |
| | ref_log_prob = await asyncio.get_event_loop().run_in_executor( |
| | None, |
| | white_agent.get_move_log_prob_only, |
| | engine, current_color, move_history, san, |
| | ) |
| | else: |
| | ref_log_prob = log_prob |
| |
|
| | uci = engine.apply_move_san(san) |
| | if uci is None: |
| | fallback = engine.random_legal_move_san() |
| | if fallback is None: |
| | break |
| | san = fallback |
| | uci = engine.apply_move_san(san) or "" |
| | log_prob = 0.0 |
| | ref_log_prob = 0.0 |
| |
|
| | trainer.record_move(log_prob, ref_log_prob) |
| | move_history.append(san) |
| |
|
| | await broadcast("move", { |
| | "game_id": game_count, |
| | "player": current_color, |
| | "model": settings.white_model if current_color == "white" else settings.black_model, |
| | "move": san, |
| | "uci": uci, |
| | "fen": engine.fen, |
| | "move_number": engine.move_number, |
| | "turn": engine.turn, |
| | "wallet_white": round(wallet_white, 2), |
| | "wallet_black": round(wallet_black, 2), |
| | "message": f"{'Qwen' if current_color == 'white' else 'Llama'} plays {san}", |
| | }) |
| |
|
| | await asyncio.sleep(settings.move_delay) |
| |
|
| | |
| | |
| | if engine.result: |
| | result = engine.result |
| | else: |
| | |
| | piece_values = {1: 1, 2: 3, 3: 3, 4: 5, 5: 9} |
| | import chess as _chess |
| | white_mat = sum( |
| | piece_values.get(pt, 0) |
| | for pt in range(1, 6) |
| | for _ in engine.board.pieces(pt, _chess.WHITE) |
| | ) |
| | black_mat = sum( |
| | piece_values.get(pt, 0) |
| | for pt in range(1, 6) |
| | for _ in engine.board.pieces(pt, _chess.BLACK) |
| | ) |
| | result = '1-0' if white_mat >= black_mat else '0-1' |
| | white_reward = 1.0 if result == "1-0" else (-1.0 if result == "0-1" else 0.0) |
| | black_reward = 1.0 if result == "0-1" else (-1.0 if result == "1-0" else 0.0) |
| |
|
| | if result == "1-0": |
| | wallet_white += prize_pool |
| | elif result == "0-1": |
| | wallet_black += prize_pool |
| | else: |
| | wallet_white += prize_pool / 2 |
| | wallet_black += prize_pool / 2 |
| |
|
| | white_pnl = ( |
| | prize_pool if result == "1-0" |
| | else prize_pool / 2 if result == "1/2-1/2" |
| | else 0 |
| | ) - settings.entry_fee |
| | black_pnl = ( |
| | prize_pool if result == "0-1" |
| | else prize_pool / 2 if result == "1/2-1/2" |
| | else 0 |
| | ) - settings.entry_fee |
| |
|
| | await broadcast("game_end", { |
| | "game_id": game_count, |
| | "result": result, |
| | "reward": white_reward, |
| | "wallet_white": round(wallet_white, 2), |
| | "wallet_black": round(wallet_black, 2), |
| | "prize_income": round( |
| | prize_pool if result == "1-0" |
| | else prize_pool / 2 if result == "1/2-1/2" |
| | else 0, 2 |
| | ), |
| | "coaching_cost": 0, |
| | "entry_fee": settings.entry_fee, |
| | "net_pnl_white": round(white_pnl, 2), |
| | "net_pnl_black": round(black_pnl, 2), |
| | "move_count": len(move_history), |
| | "white_model": settings.white_model, |
| | "black_model": settings.black_model, |
| | "message": f"Game #{game_count} ended — {result}", |
| | }) |
| |
|
| | |
| | training_metrics = trainer.end_game( |
| | reward=white_reward, |
| | profit=white_pnl, |
| | coaching_calls=0, |
| | ) |
| |
|
| | if training_metrics is not None: |
| | await broadcast("training_step", { |
| | "step": training_metrics.step, |
| | "loss": round(training_metrics.loss, 6), |
| | "reward": round(training_metrics.policy_reward, 4), |
| | "kl_div": round(training_metrics.kl_div, 6), |
| | "win_rate": round(training_metrics.win_rate, 4), |
| | "avg_profit": round(training_metrics.avg_profit, 4), |
| | "coaching_rate": round(training_metrics.coaching_rate, 4), |
| | "model": settings.white_model, |
| | "message": ( |
| | f"GRPO step {training_metrics.step} | " |
| | f"loss={training_metrics.loss:.4f} " |
| | f"win_rate={training_metrics.win_rate:.2%}" |
| | ), |
| | }) |
| |
|
| | await asyncio.sleep(1.0) |
| |
|
| |
|
| | |
| | if __name__ == "__main__": |
| | uvicorn.run( |
| | "websocket_server:app", |
| | host=settings.host, |
| | port=settings.port, |
| | reload=False, |
| | log_level="info", |
| | ) |
| |
|