rust_coder / server /app.py
Parthiban007's picture
Upload folder using huggingface_hub
efe528e verified
"""
FastAPI application for the Rust Coder OpenEnv environment.
Entrypoint: server.app:app (see openenv.yaml and Dockerfile CMD)
Standard OpenEnv endpoints (via create_app):
POST /reset — start a new episode
POST /step — submit an action, receive observation + reward
GET /state — current episode state
GET /schema — action / observation JSON schemas
WS /ws — WebSocket interface
Custom endpoints:
GET /health — health check
GET /tasks — list all tasks with action schema
POST /grader?task_id=X — programmatic grader for task X
"""
import logging
import math
import os

from dotenv import load_dotenv
from fastapi import HTTPException
from openenv.core.env_server.http_server import create_app

from models import RustCoderAction, RustCoderObservation, TaskInfo
from server.rust_coder_environment import RustCoderEnvironment
# Load .env settings before anything below reads the environment.
load_dotenv()

# An unset *or empty* LOG_LEVEL falls back to "INFO"; the name is upper-cased
# so it can be looked up as an attribute of the logging module.
_LOG_LEVEL = str.upper(os.environ.get("LOG_LEVEL") or "INFO")

logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
    # Names that are not attributes of the logging module resolve to INFO.
    level=getattr(logging, _LOG_LEVEL, logging.INFO),
)

# ASGI application exposing the environment through create_app.
app = create_app(
    RustCoderEnvironment,
    RustCoderAction,
    RustCoderObservation,
    max_concurrent_envs=1,
    env_name="rust_coder",
)
# ---------------------------------------------------------------------------
# Task registry
# ---------------------------------------------------------------------------
def _task(index: int, difficulty: str, success_threshold: float, description: str) -> dict:
    """Build one registry entry with the four keys every task carries."""
    return {
        "index": index,
        "difficulty": difficulty,
        "description": description,
        "success_threshold": success_threshold,
    }


# Maps each public task id to its metadata:
#   index             - value passed as start_index when the grader resets the env
#   difficulty        - "easy" / "medium" / "hard" label reported by /tasks
#   description       - human-readable summary reported by /tasks
#   success_threshold - minimum score for the grader to report the task as passed
TASK_REGISTRY = {
    "task_1": _task(
        0, "easy", 0.7,
        "Fix enum variant mismatches and incomplete match arms in a CLI argument parser.",
    ),
    "task_2": _task(
        1, "easy", 0.7,
        "Resolve mutable/immutable borrow conflicts in a string collection processor.",
    ),
    "task_3": _task(
        2, "medium", 0.6,
        "Add correct lifetime annotations so a struct holding references compiles and works.",
    ),
    "task_4": _task(
        3, "medium", 0.6,
        "Fix off-by-one errors and logic bugs in a financial calculation module.",
    ),
    "task_5": _task(
        4, "medium", 0.6,
        "Implement a safe singly-linked list with push, pop, and peek operations.",
    ),
    "task_6": _task(
        5, "hard", 0.5,
        "Identify and fix deadlock conditions in a multi-threaded producer-consumer pattern.",
    ),
    "task_7": _task(
        6, "hard", 0.5,
        "Fix async/await borrowing conflicts in a concurrent file processor.",
    ),
    "task_8": _task(
        7, "hard", 0.5,
        "Write safe Rust wrappers around unsafe FFI calls to a C library.",
    ),
    "task_9": _task(
        8, "hard", 0.5,
        "Optimize a data pipeline using iterators and avoiding unnecessary allocations.",
    ),
    "task_10": _task(
        9, "hard", 0.4,
        "Fix memory leak patterns and ensure correct Drop implementations.",
    ),
}

# Task ids in registry (insertion) order.
TASK_IDS = list(TASK_REGISTRY)
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health")
async def health_check():
    """Health probe: responds with a constant ``{"status": "healthy"}`` body."""
    payload = {"status": "healthy"}
    return payload
@app.get("/tasks")
async def list_tasks():
    """
    List every task registered in TASK_REGISTRY.

    The competition platform enumerates this endpoint to discover tasks.

    Returns:
        One TaskInfo per registry entry, in registry order, carrying the
        task_id, difficulty, description, and the JSON schema of
        RustCoderAction as action_schema.
    """
    catalog = []
    for task_id, meta in TASK_REGISTRY.items():
        entry = TaskInfo(
            task_id=task_id,
            difficulty=meta["difficulty"],
            description=meta["description"],
            # Schema is generated once per entry, as each TaskInfo is built.
            action_schema=RustCoderAction.model_json_schema(),
        )
        catalog.append(entry)
    return catalog
@app.post("/grader")
async def grader(task_id: str, action: RustCoderAction):
    """
    Programmatic grader for a specific task.

    Usage: POST /grader?task_id=task_1
    Body: {"code": "<rust source code>"}

    Scores are strictly in the open interval (0, 1):
    - Minimum 0.01 — floor for any submission (even empty/non-compiling)
    - Maximum 0.99 — ceiling so no submission hits the theoretical perfect
    - Weighted: Compilation(40%) + Correctness(20%) + Coverage(20%) +
      Elegance(10%) + Efficiency(10%)
      (the weighting is produced by the environment's reward, not computed here)

    Args:
        task_id: Key into TASK_REGISTRY (e.g. "task_1").
        action: Submission whose ``code`` field holds the Rust source.

    Returns:
        A dict with keys task_id, score, passed (1 or 0), total (always 1),
        metric, reward_breakdown, compilation_success, compilation_output and
        test_results.

    Raises:
        HTTPException: 404 when ``task_id`` is not in TASK_REGISTRY.
    """
    task_meta = TASK_REGISTRY.get(task_id)
    if task_meta is None:
        raise HTTPException(
            status_code=404,
            detail=f"Unknown task_id '{task_id}'. Valid IDs: {TASK_IDS}",
        )

    # Fast path: empty code — skip compilation + avoid triggering auto-LLM
    if not action.code.strip():
        return {
            "task_id": task_id,
            "score": 0.01,
            "passed": 0,
            "total": 1,
            "metric": "rust_code_quality",
            "reward_breakdown": {
                "compilation": 0.0,
                "correctness": 0.0,
                "coverage": 0.0,
                "elegance": 0.0,
                "efficiency": 0.0,
            },
            "compilation_success": False,
            "compilation_output": "No code submitted.",
            "test_results": [],
        }

    # Full evaluation path: a fresh environment positioned on the requested task.
    # NOTE(review): step() runs synchronously inside this async handler; if it
    # blocks (e.g. while compiling) other requests wait — confirm this is intended.
    # NOTE(review): the environment is never explicitly released here — confirm
    # whether RustCoderEnvironment needs any cleanup.
    env = RustCoderEnvironment()
    env.reset(start_index=task_meta["index"])
    obs = env.step(action)

    # Explicit None check — 0.0 is falsy but a valid reward
    raw_score = float(obs.reward if obs.reward is not None else 0.0)

    # NaN compares False against everything, so min(0.99, nan) would yield 0.99
    # and a broken reward would be graded as near-perfect. Treat it as no reward.
    if math.isnan(raw_score):
        raw_score = 0.0

    # Enforce strictly open interval (0, 1) — never exactly 0.0 or 1.0
    score = round(max(0.01, min(0.99, raw_score)), 4)
    success = score >= task_meta["success_threshold"]

    return {
        "task_id": task_id,
        "score": score,
        "passed": 1 if success else 0,
        "total": 1,
        "metric": "rust_code_quality",
        "reward_breakdown": obs.reward_breakdown,
        "compilation_success": obs.compilation_success,
        "compilation_output": obs.compilation_output,
        "test_results": obs.test_results,
    }
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """
    Serve ``app`` with uvicorn.

    Args:
        host: Interface to bind; defaults to all interfaces ("0.0.0.0").
        port: TCP port to listen on; defaults to 8000.
    """
    # uvicorn is imported at call time rather than at module import.
    from uvicorn import run as serve

    serve(app, host=host, port=port)


# Script entry point: start the server with the default host and port.
if __name__ == "__main__":
    main()