# rust_coder/server/app.py
# Parthiban007's picture
# Upload folder using huggingface_hub
# 8a096e2 verified
# raw
# history blame
# 5.18 kB
"""
FastAPI application for the Rust Coder OpenEnv environment.
This module is the Hugging Face Space entrypoint (see `openenv.yaml` and Docker `CMD`).
Endpoints (provided by OpenEnv `create_app`):
- POST /reset
- POST /step
- GET /state
- GET /schema
- WS /ws
Additional endpoints:
- GET /health
- GET /tasks — list all tasks with grader metadata
- POST /grade/{task_id} — grade a code submission for a specific task
"""
import os
import logging
from dotenv import load_dotenv
from fastapi import HTTPException
from pydantic import BaseModel
from openenv.core.env_server.http_server import create_app
from models import RustCoderAction, RustCoderObservation
from server.rust_coder_environment import RustCoderEnvironment
# Load variables from a .env file (python-dotenv) into the process environment
# before the os.getenv("LOG_LEVEL") lookup that follows reads it.
load_dotenv()
# Root log level, taken from the LOG_LEVEL environment variable.
# The name is matched case-insensitively and surrounding whitespace is ignored;
# an unset, empty, or blank value means INFO.
_LOG_LEVEL = (os.getenv("LOG_LEVEL") or "").strip().upper() or "INFO"
# Not every upper-case attribute of the logging module is a level (BASIC_FORMAT
# is a format string), so only integer attributes are accepted; anything else
# falls back to INFO instead of making basicConfig() raise at import time.
_resolved_level = getattr(logging, _LOG_LEVEL, None)
logging.basicConfig(
    level=_resolved_level if isinstance(_resolved_level, int) else logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
)
# Build the FastAPI application through OpenEnv's create_app factory. Per the
# module docstring, the factory supplies the /reset, /step, /state, /schema and
# /ws endpoints; the routes defined further down are added on top of it.
app = create_app(
    RustCoderEnvironment,  # environment class from server.rust_coder_environment
    RustCoderAction,  # action model from models
    RustCoderObservation,  # observation model from models
    env_name="rust_coder",
    # NOTE(review): presumably caps the server at one live environment at a
    # time — confirm against the OpenEnv create_app documentation.
    max_concurrent_envs=1,
)
# ---------------------------------------------------------------------------
# Task metadata — mirrors openenv.yaml tasks section
# ---------------------------------------------------------------------------
# One row per task, in task order: (id, title, difficulty, success_threshold).
# The zero-based "index" stored in each registry entry is the row's position.
_TASK_ROWS = (
    ("task_1", "Broken CLI Argument Parser", "easy", 0.7),
    ("task_2", "Conflicting Borrows in Collection Processing", "easy", 0.7),
    ("task_3", "Lifetime Annotations", "medium", 0.6),
    ("task_4", "Business Logic Bug", "medium", 0.6),
    ("task_5", "Linked List Management", "medium", 0.6),
    ("task_6", "Multi-threaded Deadlocks", "hard", 0.5),
    ("task_7", "Async Borrowing", "hard", 0.5),
    ("task_8", "Unsafe FFI Integration", "hard", 0.5),
    ("task_9", "Inefficient Data Pipelines", "hard", 0.5),
    ("task_10", "Memory Leak Prevention", "hard", 0.4),
)
_TASK_REGISTRY = [
    {
        "id": task_id,
        "index": position,
        "title": title,
        "difficulty": difficulty,
        "success_threshold": threshold,
    }
    for position, (task_id, title, difficulty, threshold) in enumerate(_TASK_ROWS)
]
# Lookup table from task id to its registry entry (the same dict objects).
_TASK_BY_ID = {entry["id"]: entry for entry in _TASK_REGISTRY}
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health")
async def health_check():
    """Liveness probe: unconditionally report the service as healthy."""
    status_payload = {"status": "healthy"}
    return status_payload
@app.get("/tasks")
async def list_tasks():
    """
    Describe every registered task together with its grader configuration.

    Returns a dict with ``tasks`` (one entry per item of ``_TASK_REGISTRY``, in
    registry order) and ``total`` (the number of entries).
    """
    catalogue = [
        {
            "id": entry["id"],
            "title": entry["title"],
            "difficulty": entry["difficulty"],
            "grader": {
                "type": "programmatic",
                # Each task is graded through its own POST /grade/{task_id} route.
                "endpoint": f"/grade/{entry['id']}",
                "success_threshold": entry["success_threshold"],
                "reward_range": [0.0, 1.0],
                "description": "Compilation(40%) + Correctness(20%) + Coverage(20%) + Elegance(10%) + Efficiency(10%)",
            },
        }
        for entry in _TASK_REGISTRY
    ]
    return {"tasks": catalogue, "total": len(catalogue)}
class GradeRequest(BaseModel):
    """Request body for POST /grade/{task_id}."""

    # Rust source code to grade; an omitted field is treated as an empty string.
    code: str = ""
@app.post("/grade/{task_id}")
async def grade_task(task_id: str, request: GradeRequest):
    """
    Grade a Rust code submission for a specific task.

    A fresh ``RustCoderEnvironment`` is created for each call, reset to the
    requested task's index, and stepped once with the submitted code.

    Args:
        task_id: Task identifier; must be a key of ``_TASK_BY_ID``.
        request: Request body carrying the Rust source in ``code``.

    Returns:
        A dict with ``task_id``, ``score`` (clamped to [0.0, 1.0] and rounded
        to 4 decimals), ``success`` (unrounded score >= the task's threshold),
        ``success_threshold``, and the observation's ``reward_breakdown``,
        ``compilation_success``, ``compilation_output`` and ``test_results``.

    Raises:
        HTTPException: 404 when ``task_id`` is not a registered task.

    This is the programmatic grader endpoint referenced in openenv.yaml.
    """
    import math  # function-scope import: only needed for the NaN guard below

    task_meta = _TASK_BY_ID.get(task_id)
    if task_meta is None:
        raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found.")

    # NOTE(review): reset() and step() are called synchronously inside an async
    # handler; if they block (e.g. while compiling the submission) the event
    # loop stalls for every other request — confirm this is acceptable.
    env = RustCoderEnvironment()
    # Reset to the specific task
    env.reset(start_index=task_meta["index"])
    # Submit the code
    action = RustCoderAction(code=request.code)
    obs = env.step(action)

    score = float(obs.reward) if obs.reward is not None else 0.0
    # min(1.0, nan) evaluates to 1.0, so without this guard a NaN reward would
    # be reported as a perfect score. Treat it as a failed grade instead.
    if math.isnan(score):
        score = 0.0
    score = max(0.0, min(1.0, score))
    success = score >= task_meta["success_threshold"]

    return {
        "task_id": task_id,
        "score": round(score, 4),
        "success": success,
        "success_threshold": task_meta["success_threshold"],
        "reward_breakdown": obs.reward_breakdown,
        "compilation_success": obs.compilation_success,
        "compilation_output": obs.compilation_output,
        "test_results": obs.test_results,
    }
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Serve ``app`` with uvicorn, listening on ``host``:``port``."""
    # uvicorn is imported at call time rather than at module import.
    from uvicorn import run as serve

    serve(app, host=host, port=port)


# Allow running the module directly as a script.
if __name__ == "__main__":
    main()