# Provenance (recovered from a Hugging Face Spaces file view; Space status: Sleeping):
# commit 52fe477 by vineetshukla.work@gmail.com —
# "fix: resolve 500 error on /schema and add extra validation tasks"
"""
CodeSensei — FastAPI Server (OpenEnv Protocol).

Exposes the CodeDebugEnvironment as an HTTP + WebSocket API
following the OpenEnv standard interface pattern.
"""
from __future__ import annotations

import json
import uuid
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ValidationError

from env.models import CodeDebugAction, CodeDebugObservation, CodeDebugState
from env.server.environment import CodeDebugEnvironment


# --- Metadata Definitions ---
# Task catalogue returned by the metadata endpoint. Every task shares the same
# reward range and grader entry point, so only the varying columns are listed:
# (id, display name, description, max_steps).
# NOTE(review): the "dummy-task-*" ids are described as validation tasks;
# confirm CodeDebugEnvironment.reset actually accepts them as `task`.
TASKS_METADATA = [
    {
        "id": task_id,
        "name": display_name,
        "description": summary,
        "max_steps": step_budget,
        # A new list literal per entry, so no two tasks share one mutable list.
        "reward_range": [0.01, 0.99],
        "grader": "tasks.grader:grade",
    }
    for task_id, display_name, summary, step_budget in (
        ("debug-add_numbers", "debug-add_numbers", "Fix subtraction -> addition bug", 6),
        ("debug-find_max", "debug-find_max", "Fix < -> > comparison bug", 6),
        ("debug-reverse_string", "debug-reverse_string", "Fix slice -> reverse bug", 6),
        ("dummy-task-alpha", "Standard Debug Alpha", "Baseline validation task for model compliance", 3),
        ("dummy-task-beta", "Standard Debug Beta", "Secondary validation task for model compliance", 3),
        ("dummy-task-gamma", "Standard Debug Gamma", "Tertiary validation task for model compliance", 3),
    )
]


# --- Pydantic request/response schemas ---
# Request body for the `reset` endpoint; every field may be omitted.
# (Documented with comments rather than docstrings: pydantic copies a class
# docstring into model_json_schema(), which would alter the /schema output.)
class ResetRequest(BaseModel):
    # Session to (re)start; the empty string when the client omits it.
    session_id: str = ""
    # Task name from openenv.yaml, e.g. "debug-add_numbers"; None when omitted.
    task: Optional[str] = None


# Request body for the `step` endpoint: one proposed fix for one session.
class StepRequest(BaseModel):
    # The submitted fix (presumably corrected source code — confirm with env).
    proposed_fix: str
    # Session the fix is applied to; required.
    session_id: str


# Session lookup body.
# NOTE(review): no handler visible here uses it — `get_state` receives
# `session_id` as a plain parameter. Confirm before removing.
class StateRequest(BaseModel):
    session_id: str


# --- App lifecycle ---
# Shared environment instance; bound by `lifespan` at application startup and
# read by every HTTP and WebSocket handler below.
env: CodeDebugEnvironment


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared CodeDebugEnvironment for the application's lifetime.

    Decorated with ``asynccontextmanager`` because FastAPI/Starlette expect the
    ``lifespan`` argument to be an async context manager factory; passing a
    bare async generator function is deprecated.

    Args:
        app: The FastAPI application (not used; required by the lifespan API).
    """
    global env
    env = CodeDebugEnvironment()
    print("🧠 CodeSensei environment loaded")
    # NOTE(review): a freshly built environment presumably holds no sessions,
    # so this likely prints 0 — the "Bug dataset" label may be misleading.
    print(f"📦 Bug dataset: {len(env._sessions)} active sessions")
    yield
    print("👋 CodeSensei shutting down")


# --- FastAPI app ---
app = FastAPI(
    lifespan=lifespan,
    title="CodeSensei - CodeDebug OpenEnv",
    version="1.0.0",
    description="RL environment for teaching LLMs to debug Python code",
)

# Fully open CORS policy.
# NOTE(review): wildcard origins together with allow_credentials=True makes
# Starlette's CORSMiddleware reflect the caller's Origin rather than send "*",
# so any site could make credentialed requests. Nothing visible here uses
# cookies or auth, but confirm before relying on this in production.
_cors_options = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
app.add_middleware(CORSMiddleware, **_cors_options)


# --- HTTP Endpoints (standard OpenEnv) ---
@app.post("/reset")
async def reset(request: Optional[ResetRequest] = None):
    """Start a new debugging episode.

    Args:
        request: Optional JSON body with ``session_id`` and ``task``. When the
            body is omitted, or its ``session_id`` is empty, a fresh UUID4 is
            used as the session id.

    Returns:
        The initial observation as a dict, with the ``session_id`` actually
        used added so the caller can pass it to the step/state endpoints.
    """
    # ResetRequest defaults session_id to "", so a body such as
    # {"task": "..."} must not fall through to one session shared by all
    # such clients.
    session_id = (request.session_id if request else "") or str(uuid.uuid4())
    task = request.task if request else None
    obs = env.reset(session_id=session_id, task=task)
    response = _obs_to_dict(obs)
    # Mirrors the WebSocket reset reply, which also echoes the session id.
    response["session_id"] = session_id
    return response
@app.post("/step")
async def step(request: StepRequest):
    """Submit a proposed code fix for an existing session.

    Args:
        request: Body carrying ``proposed_fix`` and the target ``session_id``.

    Returns:
        The observation produced by ``env.step`` as a dict.
    """
    action = CodeDebugAction(
        proposed_fix=request.proposed_fix,
        session_id=request.session_id,
    )
    obs = env.step(action)
    return _obs_to_dict(obs)
@app.get("/state")
async def get_state(session_id: str):
    """Get the current episode state for a session.

    Args:
        session_id: Session to look up, taken from the ``session_id`` query
            parameter.

    Returns:
        ``env.get_state(session_id).model_dump()`` when the session exists.
        Otherwise it returns ``{"error": "Session not found", "session_id": ...}``.
        That error dict is sent with a normal 200 status, not an HTTP error
        status, so clients must check for the "error" key.
    """
    state = env.get_state(session_id)
    if state is None:
        return {"error": "Session not found", "session_id": session_id}
    return state.model_dump()
@app.get("/health")
async def health():
    """Liveness probe.

    Returns a constant payload and does not touch the environment.
    """
    return {"status": "healthy", "service": "codesensei-env"}
@app.get("/metadata")
async def get_metadata():
    """Return environment and task metadata for OpenEnv validation.

    Returns:
        The environment name, version and description, plus the task list
        from ``TASKS_METADATA`` under ``"tasks"``.
    """
    return {
        "name": "codesensei",
        "version": "1.0.0",
        "description": "GRPO-trained LLM code debugging environment",
        "tasks": TASKS_METADATA,
    }
@app.get("/schema")
async def get_schema():
    """Return the JSON Schemas of the action, observation and state models.

    Each schema comes from the model's pydantic ``model_json_schema()``.
    """
    return {
        "action": CodeDebugAction.model_json_schema(),
        "observation": CodeDebugObservation.model_json_schema(),
        "state": CodeDebugState.model_json_schema(),
    }
@app.get("/")
async def root():
    """Describe the service and list the endpoints it exposes.

    Returns:
        The service name and version, plus an ``"endpoints"`` map from
        ``"METHOD /path"`` to a short description.
    """
    return {
        "name": "CodeSensei - CodeDebug OpenEnv",
        "version": "1.0.0",
        "endpoints": {
            "POST /reset": "Start a new episode",
            "POST /step": "Submit a code fix",
            "GET /state": "Get episode state",
            "GET /metadata": "Environment & task metadata",
            "GET /schema": "JSON schemas for models",
            "WS /ws": "WebSocket interface (recommended)",
            "GET /health": "Health check",
        },
    }


# --- WebSocket Endpoint (primary for HF Spaces) ---
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket interface for training — required for HF Spaces.

    Protocol:
    - The client sends JSON objects: {"type": "reset"} (optionally with
      "session_id" and "task"), {"type": "step", "proposed_fix": "..."},
      or {"type": "state"}.
    - The server replies with the observation ("reset_response" /
      "step_response"), the episode state ("state_response"), or a
      {"type": "error", ...} message.
    - Invalid JSON, non-object messages and step payloads that fail model
      validation get an error reply, and the connection stays open.

    Every session id used on this connection is removed from the environment
    when the handler exits. This happens both on a normal disconnect and when
    an error escapes.
    """
    await websocket.accept()
    session_id = str(uuid.uuid4())
    # All ids this connection has used. Repeated resets would otherwise leave
    # the earlier sessions behind in env._sessions.
    used_session_ids = {session_id}
    try:
        while True:
            raw = await websocket.receive_text()
            # Parse per message: one malformed frame must not end the session.
            try:
                msg = json.loads(raw)
            except json.JSONDecodeError:
                await websocket.send_json({
                    "type": "error",
                    "error": "Invalid JSON",
                })
                continue
            if not isinstance(msg, dict):
                await websocket.send_json({
                    "type": "error",
                    "error": "Message must be a JSON object",
                })
                continue
            msg_type = msg.get("type", "")
            if msg_type == "reset":
                # Use `or`, not a .get() default, so that "" and null also get
                # a fresh id instead of colliding on one shared session.
                session_id = msg.get("session_id") or str(uuid.uuid4())
                used_session_ids.add(session_id)
                task = msg.get("task", None)
                obs = env.reset(session_id=session_id, task=task)
                response = _obs_to_dict(obs)
                response["session_id"] = session_id
                response["type"] = "reset_response"
                await websocket.send_json(response)
            elif msg_type == "step":
                try:
                    action = CodeDebugAction(
                        proposed_fix=msg.get("proposed_fix", ""),
                        session_id=session_id,
                    )
                except ValidationError as exc:
                    # e.g. a non-string "proposed_fix" in the message.
                    await websocket.send_json({
                        "type": "error",
                        "error": "Invalid step message",
                        "details": str(exc),
                    })
                    continue
                obs = env.step(action)
                response = _obs_to_dict(obs)
                response["session_id"] = session_id
                response["type"] = "step_response"
                await websocket.send_json(response)
            elif msg_type == "state":
                state = env.get_state(session_id)
                if state:
                    response = state.model_dump()
                    response["type"] = "state_response"
                else:
                    response = {"type": "error", "error": "No active session"}
                await websocket.send_json(response)
            else:
                await websocket.send_json({
                    "type": "error",
                    "error": f"Unknown message type: {msg_type}",
                    "valid_types": ["reset", "step", "state"],
                })
    except WebSocketDisconnect:
        pass  # Normal client close; session cleanup happens in `finally`.
    finally:
        # NOTE(review): this also drops sessions whose id the client supplied
        # explicitly (possibly shared with HTTP callers), as the original did
        # for the last session.
        for sid in used_session_ids:
            if sid in env._sessions:
                del env._sessions[sid]


# --- Helpers ---
def _obs_to_dict(obs) -> Dict[str, Any]:
    """Serialize an observation model into a plain, JSON-ready dict.

    Delegates entirely to the model's ``model_dump()``.
    """
    serialized = obs.model_dump()
    return serialized