from __future__ import annotations

import argparse
import csv
import json
import math
import subprocess
import sys
import tempfile
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Optional
from uuid import uuid4

import uvicorn
from fastapi import Body, FastAPI, HTTPException, Query, Request
from fastapi.responses import HTMLResponse, RedirectResponse, Response
from pydantic import BaseModel

from env.adapt_env import AdaptEnvironment
from env.test_cases import load_problem_bank
from models import AdaptAction, AdaptObservation, AdaptState
from server.runtime import SpaceTrainingManager
ENV_NAME = "adapt-dsa-tutor"
ENV_DESCRIPTION = (
"Adversarial DSA Programming Tutor - RL environment for training LLMs to solve "
"algorithmic problems through adaptive curriculum and self-repair."
)
ENV_VERSION = "0.4.0"
SESSION_TTL = timedelta(minutes=30)
SESSIONS: dict[str, AdaptEnvironment] = {}
SESSION_LAST_ACCESSED: dict[str, datetime] = {}
TRAINING_MANAGER = SpaceTrainingManager()
TASKS = [
{
"name": problem["problem_id"],
"difficulty": problem["difficulty"],
"description": problem["problem"],
}
for problem in load_problem_bank()
]
app = FastAPI(title="ADAPT DSA Tutor OpenEnv", version=ENV_VERSION)
class ResetRequest(BaseModel):
session_id: Optional[str] = None
seed: Optional[int] = None
episode_id: Optional[str] = None
problem_id: Optional[str] = None
difficulty: Optional[str] = None
class TrainRequest(BaseModel):
preset: str = "overnight"
model_name: Optional[str] = None
output_dir: Optional[str] = None
dataset_size: Optional[int] = None
max_steps: Optional[int] = None
batch_size: Optional[int] = None
gradient_accumulation_steps: Optional[int] = None
num_generations: Optional[int] = None
max_seq_length: Optional[int] = None
max_prompt_length: Optional[int] = None
max_completion_length: Optional[int] = None
learning_rate: Optional[float] = None
lora_rank: Optional[int] = None
lora_alpha: Optional[int] = None
load_in_4bit: Optional[bool] = None
gradient_checkpointing: Optional[bool] = None
bf16: Optional[bool] = None
evaluation_episodes: Optional[int] = None
eval_max_new_tokens: Optional[int] = None
baseline_eval: Optional[bool] = None
wandb_project: Optional[str] = None
wandb_run_name: Optional[str] = None
generator_mode: Optional[str] = None
non_deterministic_generator: Optional[bool] = None
use_dataset: bool = False
dataset_name: str = "deepmind/code_contests"
dataset_max_problems: int = 5000
disable_wandb: Optional[bool] = None
trace_logging_enabled: Optional[bool] = None
checkpoint_log_interval_steps: Optional[int] = None
save_steps: Optional[int] = None
save_total_limit: Optional[int] = None
upload_checkpoints_to_hub: Optional[bool] = None
save_merged_model: Optional[bool] = None
class RunTrainedPolicyRequest(BaseModel):
problem_id: Optional[str] = None
difficulty: Optional[str] = None
max_new_tokens: int = 512
class GenerateCodeRequest(BaseModel):
problem: str
input_format: str
constraints: str
feedback: Optional[str] = None
problem_id: str = "custom_problem"
problem_type: str = "custom"
difficulty: str = "custom"
attempt_number: int = 1
max_steps: int = 1
max_new_tokens: int = 512
class RunCodeRequest(BaseModel):
code: str
stdin: str = ""
DEMO_PAGE_HTML = """
ADAPT Judge Demo
ADAPT Judge Demo
Explore the model, run generated Python against custom stdin, and inspect the latest training results
from the same Space deployment.
Ready
Section 1 - Model Playground
Paste a DSA problem statement, generate a solution, then run it with your own input.
# Generated code will appear here.
# Stdout and stderr will appear here.
Section 2 - Training Results
Live status, charted reward data, and rollout metrics fetched directly from this Space.
No reward curve is available yet. Start a run or refresh after logs are written.
| Loading | Fetching training metrics... |
"""
def _metadata() -> dict[str, Any]:
return {
"name": ENV_NAME,
"description": ENV_DESCRIPTION,
"version": ENV_VERSION,
"tasks": TASKS,
"mode": "simulation",
}
def _utc_now() -> datetime:
return datetime.now(timezone.utc)
def _cleanup_sessions() -> None:
now = _utc_now()
expired = [
session_id
for session_id, last_seen in SESSION_LAST_ACCESSED.items()
if now - last_seen > SESSION_TTL
]
for session_id in expired:
SESSIONS.pop(session_id, None)
SESSION_LAST_ACCESSED.pop(session_id, None)
def _touch_session(session_id: str) -> None:
SESSION_LAST_ACCESSED[session_id] = _utc_now()
def _require_session(session_id: str) -> AdaptEnvironment:
_cleanup_sessions()
env = SESSIONS.get(session_id)
if env is None:
raise HTTPException(status_code=404, detail=f"Unknown or expired session_id: {session_id}")
_touch_session(session_id)
return env
def _float_or_none(value: Any) -> float | None:
if value is None or value == "":
return None
try:
return float(value)
except (TypeError, ValueError):
return None
def _int_or_none(value: Any) -> int | None:
if value is None or value == "":
return None
try:
return int(float(value))
except (TypeError, ValueError):
return None
def _read_json_dict(path_value: Any) -> dict[str, Any]:
if not path_value:
return {}
try:
path = Path(str(path_value))
if not path.exists() or not path.is_file():
return {}
payload = json.loads(path.read_text(encoding="utf-8"))
except (OSError, json.JSONDecodeError, TypeError, ValueError):
return {}
return payload if isinstance(payload, dict) else {}
def _parse_reward_curve(csv_path_value: Any) -> list[dict[str, Any]]:
if not csv_path_value:
return []
try:
csv_path = Path(str(csv_path_value))
if not csv_path.exists() or not csv_path.is_file():
return []
with csv_path.open("r", encoding="utf-8", newline="") as handle:
reader = csv.DictReader(handle)
reward_curve: list[dict[str, Any]] = []
for row in reader:
step = _int_or_none(row.get("step"))
if step is None:
continue
reward_curve.append(
{
"step": step,
"episode_reward": _float_or_none(row.get("episode_reward")),
"pass_rate": _float_or_none(row.get("pass_rate")),
"visible_pass_rate": _float_or_none(row.get("visible_pass_rate")),
}
)
return reward_curve
except (OSError, csv.Error):
return []
def _stringify_subprocess_output(value: str | bytes | None) -> str:
if value is None:
return ""
if isinstance(value, bytes):
return value.decode("utf-8", errors="replace")
return value
def _enriched_train_status() -> dict[str, Any]:
payload = TRAINING_MANAGER.status_payload()
run_summary = _read_json_dict(payload.get("run_summary_path"))
if run_summary:
rolling_metrics = run_summary.get("rolling_metrics")
if isinstance(rolling_metrics, dict) and not payload.get("rolling_metrics"):
payload["rolling_metrics"] = rolling_metrics
final_metrics = run_summary.get("final_metrics")
if isinstance(final_metrics, dict):
if not payload.get("baseline_summary") and isinstance(final_metrics.get("baseline_summary"), dict):
payload["baseline_summary"] = final_metrics["baseline_summary"]
if not payload.get("trained_summary") and isinstance(final_metrics.get("trained_summary"), dict):
payload["trained_summary"] = final_metrics["trained_summary"]
if not payload.get("timing_summary") and isinstance(final_metrics.get("timing_summary"), dict):
payload["timing_summary"] = final_metrics["timing_summary"]
payload["reward_curve"] = _parse_reward_curve(payload.get("reward_curve_csv"))
config = payload.get("config") if isinstance(payload.get("config"), dict) else {}
trained_summary = payload.get("trained_summary") if isinstance(payload.get("trained_summary"), dict) else {}
rolling_metrics = payload.get("rolling_metrics") if isinstance(payload.get("rolling_metrics"), dict) else {}
overall_accuracy = None
if config.get("baseline_eval"):
overall_accuracy = _float_or_none(trained_summary.get("overall"))
if overall_accuracy is None:
overall_accuracy = _float_or_none(rolling_metrics.get("avg_pass_rate"))
if overall_accuracy is None:
overall_accuracy = _float_or_none(payload.get("last_pass_rate"))
payload["overall_accuracy"] = overall_accuracy
return payload
@app.on_event("startup")
def startup() -> None:
TRAINING_MANAGER.load_latest_model()
@app.get("/")
def root() -> HTMLResponse:
_cleanup_sessions()
return HTMLResponse(content=DEMO_PAGE_HTML)
@app.get("/web", include_in_schema=False)
def web_root() -> RedirectResponse:
return RedirectResponse(url="/", status_code=307)
@app.get("/web/", include_in_schema=False)
def web_root_slash() -> RedirectResponse:
return RedirectResponse(url="/", status_code=307)
@app.get("/favicon.ico", include_in_schema=False)
def favicon() -> Response:
return Response(status_code=204)
@app.get("/health")
def health() -> dict[str, Any]:
_cleanup_sessions()
return {
"status": "healthy",
"active_sessions": len(SESSIONS),
"training": TRAINING_MANAGER.status_payload()["status"],
"model_loaded": TRAINING_MANAGER.model_status_payload()["loaded"],
}
@app.get("/metadata")
def metadata() -> dict[str, Any]:
_cleanup_sessions()
return _metadata()
@app.get("/tasks")
def list_tasks() -> dict[str, Any]:
_cleanup_sessions()
return {"tasks": TASKS}
@app.get("/schema")
def schema() -> dict[str, Any]:
_cleanup_sessions()
return {
"action": AdaptAction.model_json_schema(),
"observation": AdaptObservation.model_json_schema(),
"state": AdaptState.model_json_schema(),
}
@app.get("/train/status")
def train_status() -> dict[str, Any]:
return _enriched_train_status()
@app.get("/model/status")
def model_status() -> dict[str, Any]:
return TRAINING_MANAGER.model_status_payload()
@app.post("/train")
def train(request: Optional[TrainRequest] = None) -> dict[str, Any]:
try:
return TRAINING_MANAGER.start_training((request or TrainRequest()).model_dump())
except RuntimeError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
@app.post("/run-trained-policy")
def run_trained_policy(request: Optional[RunTrainedPolicyRequest] = None) -> dict[str, Any]:
effective_request = request or RunTrainedPolicyRequest()
try:
return TRAINING_MANAGER.run_trained_policy(
problem_id=effective_request.problem_id,
difficulty=effective_request.difficulty,
max_new_tokens=effective_request.max_new_tokens,
)
except RuntimeError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
except Exception as exc:
raise HTTPException(status_code=500, detail=f"run-trained-policy failed: {exc}") from exc
@app.post("/run-code")
def run_code(request: RunCodeRequest) -> dict[str, str]:
temp_path: Path | None = None
try:
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False, encoding="utf-8") as handle:
handle.write(request.code)
temp_path = Path(handle.name)
completed = subprocess.run(
[sys.executable, str(temp_path)],
input=request.stdin,
text=True,
capture_output=True,
timeout=5,
check=False,
)
return {
"stdout": completed.stdout,
"stderr": completed.stderr,
}
except subprocess.TimeoutExpired as exc:
stderr = _stringify_subprocess_output(exc.stderr)
timeout_message = "Execution timed out after 5 seconds."
stderr = f"{stderr.rstrip()}\n{timeout_message}".strip() if stderr else timeout_message
return {
"stdout": _stringify_subprocess_output(exc.stdout),
"stderr": stderr,
}
except OSError as exc:
raise HTTPException(status_code=500, detail=f"run-code failed: {exc}") from exc
except Exception as exc:
raise HTTPException(status_code=500, detail=f"run-code failed: {exc}") from exc
finally:
if temp_path is not None:
try:
temp_path.unlink(missing_ok=True)
except OSError:
pass
@app.post("/generate-code")
def generate_code(request: GenerateCodeRequest) -> dict[str, Any]:
try:
return TRAINING_MANAGER.generate_code(
problem=request.problem,
input_format=request.input_format,
constraints=request.constraints,
feedback=request.feedback,
problem_id=request.problem_id,
problem_type=request.problem_type,
difficulty=request.difficulty,
attempt_number=request.attempt_number,
max_steps=request.max_steps,
max_new_tokens=request.max_new_tokens,
)
except RuntimeError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
except Exception as exc:
raise HTTPException(status_code=500, detail=f"generate-code failed: {exc}") from exc
@app.post("/mcp")
def mcp(payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
_cleanup_sessions()
return {
"jsonrpc": "2.0",
"id": payload.get("id"),
"error": {
"code": -32601,
"message": "MCP methods are not implemented for this environment.",
},
}
@app.post("/reset")
def reset(request: Optional[ResetRequest] = None) -> dict[str, Any]:
_cleanup_sessions()
effective_request = request or ResetRequest()
session_id = effective_request.session_id or str(uuid4())
env = AdaptEnvironment(session_id=session_id)
SESSIONS[session_id] = env
_touch_session(session_id)
observation = env.reset(
session_id=session_id,
seed=effective_request.seed,
episode_id=effective_request.episode_id,
problem_id=effective_request.problem_id,
difficulty=effective_request.difficulty,
)
return observation.model_dump()
@app.post("/step")
async def step(request: Request) -> dict[str, Any]:
_cleanup_sessions()
payload = await request.json()
if not isinstance(payload, dict):
raise HTTPException(status_code=422, detail="Request body must be a JSON object.")
raw_action = payload.get("action", payload)
try:
effective_action = AdaptAction.model_validate(raw_action)
except Exception as exc:
raise HTTPException(status_code=422, detail=f"Invalid action payload: {exc}") from exc
if not effective_action.session_id:
raise HTTPException(status_code=422, detail="`session_id` is required in the /step request body.")
env = _require_session(effective_action.session_id)
observation = env.step(effective_action)
return {
"observation": observation.model_dump(),
"reward": float(observation.reward),
"done": bool(observation.done),
"info": {
"session_id": observation.session_id,
"feedback": observation.feedback,
"pass_rate": observation.pass_rate,
"visible_pass_rate": observation.visible_pass_rate,
"execution_status": observation.execution_status,
},
}
@app.get("/state")
def state(session_id: str = Query(..., description="Session id returned from /reset.")) -> dict[str, Any]:
env = _require_session(session_id)
if not env.problem:
env.reset(session_id=session_id)
return env.state.model_dump()
def main(host: Optional[str] = None, port: Optional[int] = None) -> None:
if host is None or port is None:
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
args = parser.parse_args()
host = args.host if host is None else host
port = args.port if port is None else port
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
main()