vineetshukla.work@gmail.com
fix: resolve 500 error on /schema and add extra validation tasks
52fe477
"""
CodeSensei — FastAPI Server (OpenEnv Protocol).
Exposes the CodeDebugEnvironment as an HTTP + WebSocket API
following the OpenEnv standard interface pattern.
"""
from __future__ import annotations
import json
import uuid
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from env.server.environment import CodeDebugEnvironment
from env.models import CodeDebugAction, CodeDebugObservation, CodeDebugState
# --- Metadata Definitions ---
# Task catalogue served by GET /metadata. Each row is
# (id, name, description, max_steps); every task shares the same reward
# bounds and grader entry point, so those are filled in uniformly below.
# NOTE(review): nothing in this module shows CodeDebugEnvironment.reset
# recognising the "dummy-task-*" ids — confirm they map to real episodes.
TASKS_METADATA = [
    {
        "id": task_id,
        "name": task_name,
        "description": task_description,
        "max_steps": step_limit,
        "reward_range": [0.01, 0.99],
        "grader": "tasks.grader:grade",
    }
    for task_id, task_name, task_description, step_limit in (
        ("debug-add_numbers", "debug-add_numbers", "Fix subtraction -> addition bug", 6),
        ("debug-find_max", "debug-find_max", "Fix < -> > comparison bug", 6),
        ("debug-reverse_string", "debug-reverse_string", "Fix slice -> reverse bug", 6),
        ("dummy-task-alpha", "Standard Debug Alpha", "Baseline validation task for model compliance", 3),
        ("dummy-task-beta", "Standard Debug Beta", "Secondary validation task for model compliance", 3),
        ("dummy-task-gamma", "Standard Debug Gamma", "Tertiary validation task for model compliance", 3),
    )
]
# --- Pydantic request/response schemas ---
class ResetRequest(BaseModel):
    # Optional JSON body for POST /reset.
    # Documented with comments rather than a docstring: pydantic copies a class
    # docstring into the model's JSON schema "description".
    session_id: str = ""  # client-chosen episode id; "" when the field is omitted
    task: Optional[str] = None  # task name from openenv.yaml e.g. "debug-add_numbers"; None defers to CodeDebugEnvironment.reset (default not visible here)
class StepRequest(BaseModel):
    # JSON body for POST /step: one candidate fix submitted to an episode.
    # (Comments, not a docstring, so the generated schema is unchanged.)
    proposed_fix: str  # the code the agent proposes as its fix
    session_id: str  # id of the episode to apply the fix to; required
class StateRequest(BaseModel):
    # Body model carrying only an episode id.
    # NOTE(review): GET /state in this module reads session_id from the query
    # string and does not reference this model — confirm it is still needed.
    session_id: str
# --- App lifecycle ---
# Shared environment instance; bound by ``lifespan`` when the app starts and
# read by the request handlers below.
env: CodeDebugEnvironment


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the shared environment at startup and log the shutdown.

    The single ``CodeDebugEnvironment`` created here is published through the
    module-level ``env`` name; the request handlers all talk to that instance.
    """
    global env
    instance = CodeDebugEnvironment()
    env = instance
    print("🧠 CodeSensei environment loaded")
    open_sessions = len(instance._sessions)
    print(f"📦 Bug dataset: {open_sessions} active sessions")
    yield
    print("👋 CodeSensei shutting down")
# --- FastAPI app ---
app = FastAPI(
    title="CodeSensei - CodeDebug OpenEnv",
    description="RL environment for teaching LLMs to debug Python code",
    version="1.0.0",
    lifespan=lifespan,
)

# Any origin may call the API, but without credentials. None of the endpoints
# in this module read cookies or auth headers, so cross-origin callers never
# need them. The CORS spec forbids a wildcard origin together with
# credentials. Under that combination Starlette echoes the caller's Origin
# back for cookie-bearing requests, which would let any site make
# credentialed calls. Keep the wildcard and disable credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- HTTP Endpoints (standard OpenEnv) ---
@app.post("/reset")
async def reset(request: Optional[ResetRequest] = None):
    """Start a new debugging episode.

    Args:
        request: Optional body. ``session_id`` names the episode and ``task``
            selects a task id (e.g. ``"debug-add_numbers"``); both may be
            omitted.

    Returns:
        The initial observation as a dict. It also carries a ``session_id``
        key holding the id the episode was created under; callers pass this
        to ``/step`` and ``/state``.
    """
    # Treat an empty session_id like a missing one. ResetRequest defaults the
    # field to "", and using that literally would make every client that omits
    # it share (and keep resetting) one single episode.
    requested_id = request.session_id if request else ""
    session_id = requested_id or str(uuid.uuid4())
    task = request.task if request else None
    obs = env.reset(session_id=session_id, task=task)
    response = _obs_to_dict(obs)
    # Echo the id, as the WebSocket "reset_response" does. Otherwise HTTP
    # clients that let the server pick it cannot address their episode.
    response["session_id"] = session_id
    return response
@app.post("/step")
async def step(request: StepRequest):
    """Apply the client's candidate fix to its episode and return the observation."""
    return _obs_to_dict(
        env.step(
            CodeDebugAction(
                session_id=request.session_id,
                proposed_fix=request.proposed_fix,
            )
        )
    )
@app.get("/state")
async def get_state(session_id: str):
    """Return the state of episode ``session_id`` (given as a query parameter).

    An unknown id yields an ``{"error": ..., "session_id": ...}`` payload.
    NOTE(review): that payload is sent with HTTP 200, so clients must check
    for the "error" key rather than the status code.
    """
    current = env.get_state(session_id)
    if current is not None:
        return current.model_dump()
    return {"error": "Session not found", "session_id": session_id}
@app.get("/health")
async def health():
    """Liveness probe: report that the service is up."""
    payload = dict(status="healthy", service="codesensei-env")
    return payload
@app.get("/metadata")
async def get_metadata():
    """Describe this environment and list its tasks for OpenEnv validation."""
    metadata = dict(
        name="codesensei",
        version="1.0.0",
        description="GRPO-trained LLM code debugging environment",
    )
    metadata["tasks"] = TASKS_METADATA
    return metadata
@app.get("/schema")
async def get_schema():
    """Publish the JSON Schema of the action, observation and state models."""
    models = (
        ("action", CodeDebugAction),
        ("observation", CodeDebugObservation),
        ("state", CodeDebugState),
    )
    return {label: model.model_json_schema() for label, model in models}
@app.get("/")
async def root():
    """Describe the API and enumerate its routes."""
    routes = dict([
        ("POST /reset", "Start a new episode"),
        ("POST /step", "Submit a code fix"),
        ("GET /state", "Get episode state"),
        ("GET /metadata", "Environment & task metadata"),
        ("GET /schema", "JSON schemas for models"),
        ("WS /ws", "WebSocket interface (recommended)"),
        ("GET /health", "Health check"),
    ])
    info = {"name": "CodeSensei - CodeDebug OpenEnv", "version": "1.0.0"}
    info["endpoints"] = routes
    return info
# --- WebSocket Endpoint (primary for HF Spaces) ---
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket interface for training — required for HF Spaces.

    Protocol:
    - Client sends JSON objects: ``{"type": "reset"}`` (optionally with
      ``"session_id"`` and ``"task"``), ``{"type": "step", "proposed_fix": "..."}``
      or ``{"type": "state"}``.
    - Server replies with the observation/state dict tagged
      ``"type": "reset_response" | "step_response" | "state_response"``,
      or with ``{"type": "error", "error": ...}``.

    A malformed message (invalid JSON, or JSON that is not an object) gets an
    error reply, and the connection stays open. When the handler exits for any
    reason, every session id this connection used is removed from the
    environment.
    """
    await websocket.accept()
    session_id = str(uuid.uuid4())
    # Every id this socket has used. All of them are dropped on exit, so
    # abandoned or failed connections don't leave sessions behind.
    owned_sessions = {session_id}
    try:
        while True:
            raw = await websocket.receive_text()
            try:
                msg = json.loads(raw)
            except json.JSONDecodeError:
                # Reject only this message; keep serving the connection.
                await websocket.send_json({
                    "type": "error",
                    "error": "Invalid JSON",
                })
                continue
            if not isinstance(msg, dict):
                # Valid JSON but not an object (e.g. a list or bare string):
                # msg.get below would raise AttributeError.
                await websocket.send_json({
                    "type": "error",
                    "error": "Message must be a JSON object",
                    "valid_types": ["reset", "step", "state"],
                })
                continue
            msg_type = msg.get("type", "")
            if msg_type == "reset":
                # A missing, null or empty session_id gets a fresh UUID instead
                # of being used literally; ids are kept as strings.
                requested_id = msg.get("session_id")
                session_id = str(requested_id) if requested_id else str(uuid.uuid4())
                owned_sessions.add(session_id)
                task = msg.get("task", None)
                obs = env.reset(session_id=session_id, task=task)
                response = _obs_to_dict(obs)
                response["session_id"] = session_id
                response["type"] = "reset_response"
                await websocket.send_json(response)
            elif msg_type == "step":
                # Steps always target this connection's current session.
                action = CodeDebugAction(
                    proposed_fix=msg.get("proposed_fix", ""),
                    session_id=session_id,
                )
                obs = env.step(action)
                response = _obs_to_dict(obs)
                response["session_id"] = session_id
                response["type"] = "step_response"
                await websocket.send_json(response)
            elif msg_type == "state":
                state = env.get_state(session_id)
                if state:
                    response = state.model_dump()
                    response["type"] = "state_response"
                else:
                    response = {"type": "error", "error": "No active session"}
                await websocket.send_json(response)
            else:
                await websocket.send_json({
                    "type": "error",
                    "error": f"Unknown message type: {msg_type}",
                    "valid_types": ["reset", "step", "state"],
                })
    except WebSocketDisconnect:
        # Normal client hang-up; session cleanup happens in ``finally``.
        pass
    finally:
        # Runs on disconnect and on any unexpected error. Unexpected errors
        # still propagate after the cleanup.
        for owned in owned_sessions:
            if owned in env._sessions:
                del env._sessions[owned]
# --- Helpers ---
def _obs_to_dict(obs) -> Dict[str, Any]:
"""Convert an observation to a JSON-serializable dict."""
return obs.model_dump()