""" CodeSensei — FastAPI Server (OpenEnv Protocol). Exposes the CodeDebugEnvironment as an HTTP + WebSocket API following the OpenEnv standard interface pattern. """ from __future__ import annotations import json import uuid from contextlib import asynccontextmanager from typing import Any, Dict, List, Optional from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from env.server.environment import CodeDebugEnvironment from env.models import CodeDebugAction, CodeDebugObservation, CodeDebugState # --- Metadata Definitions --- TASKS_METADATA = [ { "id": "debug-add_numbers", "name": "debug-add_numbers", "description": "Fix subtraction -> addition bug", "max_steps": 6, "reward_range": [0.01, 0.99], "grader": "tasks.grader:grade", }, { "id": "debug-find_max", "name": "debug-find_max", "description": "Fix < -> > comparison bug", "max_steps": 6, "reward_range": [0.01, 0.99], "grader": "tasks.grader:grade", }, { "id": "debug-reverse_string", "name": "debug-reverse_string", "description": "Fix slice -> reverse bug", "max_steps": 6, "reward_range": [0.01, 0.99], "grader": "tasks.grader:grade", }, { "id": "dummy-task-alpha", "name": "Standard Debug Alpha", "description": "Baseline validation task for model compliance", "max_steps": 3, "reward_range": [0.01, 0.99], "grader": "tasks.grader:grade", }, { "id": "dummy-task-beta", "name": "Standard Debug Beta", "description": "Secondary validation task for model compliance", "max_steps": 3, "reward_range": [0.01, 0.99], "grader": "tasks.grader:grade", }, { "id": "dummy-task-gamma", "name": "Standard Debug Gamma", "description": "Tertiary validation task for model compliance", "max_steps": 3, "reward_range": [0.01, 0.99], "grader": "tasks.grader:grade", }, ] # --- Pydantic request/response schemas --- class ResetRequest(BaseModel): session_id: str = "" task: Optional[str] = None # task name from openenv.yaml e.g. "debug-add_numbers" class StepRequest(BaseModel): proposed_fix: str session_id: str class StateRequest(BaseModel): session_id: str # --- App lifecycle --- env: CodeDebugEnvironment @asynccontextmanager async def lifespan(app: FastAPI): global env env = CodeDebugEnvironment() print("🧠 CodeSensei environment loaded") print(f"📦 Bug dataset: {len(env._sessions)} active sessions") yield print("👋 CodeSensei shutting down") # --- FastAPI app --- app = FastAPI( title="CodeSensei - CodeDebug OpenEnv", description="RL environment for teaching LLMs to debug Python code", version="1.0.0", lifespan=lifespan, ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # --- HTTP Endpoints (standard OpenEnv) --- @app.post("/reset") async def reset(request: Optional[ResetRequest] = None): """Start a new debugging episode.""" session_id = request.session_id if request else str(uuid.uuid4()) task = request.task if request else None obs = env.reset(session_id=session_id, task=task) return _obs_to_dict(obs) @app.post("/step") async def step(request: StepRequest): """Submit a proposed code fix.""" action = CodeDebugAction( proposed_fix=request.proposed_fix, session_id=request.session_id, ) obs = env.step(action) return _obs_to_dict(obs) @app.get("/state") async def get_state(session_id: str): """Get current episode state.""" state = env.get_state(session_id) if state is None: return {"error": "Session not found", "session_id": session_id} return state.model_dump() @app.get("/health") async def health(): """Health check endpoint.""" return {"status": "healthy", "service": "codesensei-env"} @app.get("/metadata") async def get_metadata(): """Returns environment and task metadata for OpenEnv validation.""" return { "name": "codesensei", "version": "1.0.0", "description": "GRPO-trained LLM code debugging environment", "tasks": TASKS_METADATA, } @app.get("/schema") async def get_schema(): """Returns the JSON schemas for project models.""" return { "action": CodeDebugAction.model_json_schema(), "observation": CodeDebugObservation.model_json_schema(), "state": CodeDebugState.model_json_schema(), } @app.get("/") async def root(): """Root endpoint with API info.""" return { "name": "CodeSensei - CodeDebug OpenEnv", "version": "1.0.0", "endpoints": { "POST /reset": "Start a new episode", "POST /step": "Submit a code fix", "GET /state": "Get episode state", "GET /metadata": "Environment & task metadata", "GET /schema": "JSON schemas for models", "WS /ws": "WebSocket interface (recommended)", "GET /health": "Health check", }, } # --- WebSocket Endpoint (primary for HF Spaces) --- @app.websocket("/ws") async def websocket_endpoint(websocket: WebSocket): """WebSocket interface for training — required for HF Spaces. Protocol: - Client sends JSON messages: {"type": "reset"} or {"type": "step", "proposed_fix": "..."} - Server responds with JSON observation or state. """ await websocket.accept() session_id = str(uuid.uuid4()) try: while True: raw = await websocket.receive_text() msg = json.loads(raw) msg_type = msg.get("type", "") if msg_type == "reset": session_id = msg.get("session_id", str(uuid.uuid4())) task = msg.get("task", None) obs = env.reset(session_id=session_id, task=task) response = _obs_to_dict(obs) response["session_id"] = session_id response["type"] = "reset_response" await websocket.send_json(response) elif msg_type == "step": action = CodeDebugAction( proposed_fix=msg.get("proposed_fix", ""), session_id=session_id, ) obs = env.step(action) response = _obs_to_dict(obs) response["session_id"] = session_id response["type"] = "step_response" await websocket.send_json(response) elif msg_type == "state": state = env.get_state(session_id) if state: response = state.model_dump() response["type"] = "state_response" else: response = {"type": "error", "error": "No active session"} await websocket.send_json(response) else: await websocket.send_json({ "type": "error", "error": f"Unknown message type: {msg_type}", "valid_types": ["reset", "step", "state"], }) except WebSocketDisconnect: # Clean up session on disconnect if session_id in env._sessions: del env._sessions[session_id] except json.JSONDecodeError: await websocket.send_json({ "type": "error", "error": "Invalid JSON", }) # --- Helpers --- def _obs_to_dict(obs) -> Dict[str, Any]: """Convert an observation to a JSON-serializable dict.""" return obs.model_dump()