Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI app for the CodeCourt Hugging Face Docker Space. | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import threading | |
| import uuid | |
| import random | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional | |
| from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Request | |
| from fastapi.responses import FileResponse, JSONResponse, PlainTextResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel, Field | |
| from agents import SetterAgent, SolverAgent | |
| from env import CodeCourtEnv | |
| from env.dynamic_curriculum import build_dynamic_problem | |
| from env.problem_types import ARCHETYPES, build_problem | |
| from env.state import EpisodeState | |
# Filesystem layout, anchored at the directory that contains this module.
APP_ROOT = Path(__file__).resolve().parents[0]
DASHBOARD_DIR = APP_ROOT.joinpath("dashboard")  # served as static files under /dashboard
README_PATH = APP_ROOT.joinpath("README.md")
OPENENV_PATH = APP_ROOT.joinpath("openenv.yaml")
OUTPUTS_DIR = APP_ROOT.joinpath("outputs")  # JSON artifacts and plots, served under /outputs
class SessionCreateRequest(BaseModel):
    # Body accepted when a session is created or reset. Leaving ``archetype``
    # unset lets the environment pick the problem through ``env.reset()``.
    archetype: Optional[str] = None
    task_id: Optional[int] = Field(None, ge=0)  # upper bound is checked per archetype
    difficulty: int = Field(1, ge=1, le=3)
    seed: int = 42
class SolverRunRequest(BaseModel):
    # Options for running one episode inside an existing session.
    solver_mode: str = "brute_force"  # "custom", "reference" or "brute_force"
    setter_mode: str = "reference"  # NOTE(review): not read anywhere in the visible code
    solver_code: Optional[str] = None  # required when solver_mode == "custom"
    setter_code: Optional[str] = None  # when set, replaces the reference setter solution
class BenchmarkRequest(BaseModel):
    # Parameters for a multi-episode benchmark run.
    episodes: int = Field(6, ge=1, le=30)
    solver_mode: str = "brute_force"
    seed: int = 42  # episode i is run with seed + i
class SessionStore:
    """In-memory registry that maps short session ids to environments.

    All access to the underlying dict goes through a lock. Entries are never
    evicted, so the store grows for the lifetime of the process.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._sessions: Dict[str, CodeCourtEnv] = {}

    def put(self, env: CodeCourtEnv) -> str:
        """Register *env* and return its new 12-character hex session id."""
        new_id = uuid.uuid4().hex[:12]
        with self._lock:
            self._sessions[new_id] = env
        return new_id

    def get(self, session_id: str) -> CodeCourtEnv:
        """Return the environment stored under *session_id*.

        Raises:
            KeyError: if no environment is stored under that id.
        """
        with self._lock:
            found = self._sessions.get(session_id)
        if found is None:
            raise KeyError(session_id)
        return found
app = FastAPI(title="CodeCourt", version="1.0.0")

# StaticFiles raises RuntimeError at construction time when its directory does
# not exist, which would abort the import of this module. The artifact helpers
# in this file already treat everything under outputs/ as optional, so make
# sure the directory itself is present before mounting it.
OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)

app.mount(
    "/dashboard",
    StaticFiles(directory=str(DASHBOARD_DIR), html=True),
    name="dashboard",
)
app.mount(
    "/outputs",
    StaticFiles(directory=str(OUTPUTS_DIR), html=False),
    name="outputs",
)

# Process-wide state: live sessions and the WebSocket clients that receive
# broadcast score events.
SESSIONS = SessionStore()
active_connections: list[WebSocket] = []
def _log_terminal_score(payload: dict[str, Any]) -> None:
    """Print a short score card for a finished episode to stdout.

    Only ``session_run`` and ``benchmark_episode`` events are printed; any
    other payload is ignored. Values that are not numbers are shown as "n/a".
    """
    event = payload.get("event")
    if event not in {"session_run", "benchmark_episode"}:
        return

    def _signed(value: Any) -> str:
        return f"{value:+.2f}" if isinstance(value, (int, float)) else "n/a"

    def _percent(value: Any) -> str:
        return f"{value * 100:.1f}%" if isinstance(value, (int, float)) else "n/a"

    if event == "benchmark_episode":
        # Episode indices are zero-based; show them one-based.
        label = f"BENCH #{int(payload.get('episode', 0)) + 1}"
    else:
        label = "SESSION"

    archetype = payload.get("archetype", "random")
    task_id = payload.get("task_id", "?")
    difficulty = payload.get("difficulty", "?")

    card = [
        "",
        f"[{label}] {archetype}/{task_id} · diff {difficulty}",
        f" outcome : {payload.get('outcome', 'n/a')}",
        f" solver : {_signed(payload.get('reward'))}",
        f" setter : {_signed(payload.get('setter_reward'))}",
        f" pass : {_percent(payload.get('pass_rate'))}",
    ]
    hidden = payload.get("hidden_pass_rate")
    if hidden is not None:
        card.append(f" hidden : {_percent(hidden)}")
    print("\n".join(card), flush=True)
async def _broadcast_payload(payload: dict[str, Any]) -> None:
    """Send *payload* as JSON to every registered WebSocket.

    Delivery is best effort: a socket whose send fails for any reason is
    dropped from ``active_connections`` and the error is not re-raised.
    """
    dead: list[WebSocket] = []
    # Walk a snapshot: each await yields control, so the shared list may be
    # changed by other coroutines while this loop is in progress.
    for socket in list(active_connections):
        try:
            await socket.send_json(payload)
        except Exception:
            dead.append(socket)
    for socket in dead:
        # Another coroutine may have removed the socket in the meantime.
        if socket in active_connections:
            active_connections.remove(socket)
# Strong references to in-flight broadcast tasks. The event loop only keeps
# weak references to tasks, so an otherwise unreferenced task could be
# garbage-collected before it finishes.
_PENDING_BROADCASTS: set[asyncio.Task] = set()


def broadcast_payload(payload: dict[str, Any]) -> None:
    """Log *payload* and push it to all connected WebSocket clients.

    Synchronous entry point that can be called with or without a running
    event loop in the current thread:

    * With a running loop, the broadcast is scheduled on that loop as a task
      and this function returns immediately. The previous fallback tried to
      drive a second loop from the same thread, which asyncio rejects, and it
      left the first coroutine un-awaited.
    * Without a running loop, the broadcast runs to completion through
      ``asyncio.run`` before returning.

    NOTE(review): in the second case the sockets are driven from a temporary
    loop rather than the loop that accepted them; confirm this is acceptable
    for the server in use.
    """
    _log_terminal_score(payload)
    if not active_connections:
        return
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread, so it is safe to start one.
        asyncio.run(_broadcast_payload(payload))
        return
    task = loop.create_task(_broadcast_payload(payload))
    _PENDING_BROADCASTS.add(task)
    task.add_done_callback(_PENDING_BROADCASTS.discard)
async def websocket_endpoint(websocket: WebSocket):
    """Accept a WebSocket client and keep it registered for broadcasts.

    The socket is added to ``active_connections`` and held open until the
    client disconnects. Incoming text messages are read and discarded; the
    receive loop exists only to notice the disconnect.
    """
    await websocket.accept()
    active_connections.append(websocket)
    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    finally:
        # Deregister on every exit path, not only on a clean disconnect. The
        # broadcaster may already have pruned this socket after a failed
        # send, in which case an unconditional remove() raises ValueError.
        if websocket in active_connections:
            active_connections.remove(websocket)
async def internal_broadcast(request: Request):
    # Relay the JSON body of the request to every connected WebSocket client.
    # NOTE(review): no authentication is visible here; confirm the route is
    # not reachable by untrusted callers.
    try:
        data = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError both derive from
        # ValueError; answer with a client error instead of an unhandled 500.
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from exc
    await _broadcast_payload(data)
    return JSONResponse(content={"status": "broadcasted"})
def _new_env(seed: int, difficulty: int) -> CodeCourtEnv:
    """Create a seeded environment pinned to *difficulty*.

    Difficulty progression is switched off and the requested tier is written
    directly to the environment's ``_current_difficulty`` attribute.
    """
    created = CodeCourtEnv(difficulty_progression=False, seed=seed)
    created._current_difficulty = difficulty
    return created
def _get_session_env(session_id: str) -> CodeCourtEnv:
    """Look up a session's environment, mapping a miss to HTTP 404."""
    try:
        env = SESSIONS.get(session_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=f"Unknown session_id: {session_id}") from exc
    return env
def _force_problem(
    env: CodeCourtEnv,
    *,
    archetype: Optional[str],
    task_id: Optional[int],
    difficulty: int,
) -> Dict[str, Any]:
    """Start a new episode on *env*, optionally pinning archetype and task.

    With ``archetype=None`` the choice is delegated to ``env.reset()`` and its
    observation is returned unchanged; ``task_id`` and ``difficulty`` are not
    used on that path. Otherwise the episode state is built here and an
    observation dict is assembled from it.

    Raises:
        HTTPException: 400 for an unknown archetype or an out-of-range task id.
    """
    if archetype is None:
        return env.reset()

    if archetype not in ARCHETYPES:
        raise HTTPException(status_code=400, detail=f"Unknown archetype: {archetype}")

    max_task_id = len(ARCHETYPES[archetype]["tasks"]) - 1
    selected_task = task_id if task_id is not None else 0
    if selected_task < 0 or selected_task > max_task_id:
        raise HTTPException(status_code=400, detail=f"task_id must be between 0 and {max_task_id}")

    env._current_state = EpisodeState(
        episode_id=env._episode_count,
        archetype=archetype,
        task_id=selected_task,
        difficulty=difficulty,
    )
    env._episode_count += 1

    # The variant seed is drawn from the environment's own RNG.
    variant_seed = env.rng.randint(0, 10**9)
    builder = build_dynamic_problem if env.dynamic_problems else build_problem
    problem = builder(archetype, selected_task, difficulty, seed=variant_seed)
    env._current_state.problem = problem

    observation: Dict[str, Any] = {
        "episode_id": env._current_state.episode_id,
        "archetype": archetype,
        "task_id": selected_task,
        "difficulty": difficulty,
        "problem_template": problem["description"],
        "test_cases": problem["test_cases"],
        "variant_seed": variant_seed,
        "generation_mode": problem.get("generation_mode", "static"),
        "elo": env.elo.get_stats(),
    }
    return observation
def _truncate_problem(problem: dict) -> dict:
    """Return a client-safe summary of *problem*.

    Descriptive fields are copied as they are. At most three public test
    cases are included; hidden and trap cases are reported only as counts.
    Public cases fall back to ``test_cases`` when ``public_test_cases`` is
    absent.
    """
    public_cases = problem.get("public_test_cases", problem.get("test_cases", []))
    copied_fields = (
        "title",
        "description",
        "constraints",
        "input_format",
        "output_format",
        "optimal_complexity",
        "variant_seed",
    )
    summary = {name: problem.get(name) for name in copied_fields}
    summary["generation_mode"] = problem.get("generation_mode", "static")
    summary["trap_explanation"] = problem.get("trap_explanation")
    summary["public_test_cases"] = public_cases[:3]
    summary["public_test_count"] = len(public_cases)
    summary["hidden_test_count"] = len(problem.get("hidden_test_cases", []))
    summary["dynamic_trap_count"] = len(problem.get("trap_test_cases", []))
    return summary
def _serialize_state(env: CodeCourtEnv) -> dict:
    """Build the JSON-ready view of the session's current episode.

    Starts from ``state.to_dict()``, replaces the problem with its truncated
    summary, and adds Elo stats and the rendered view. Setter and solver
    results and code are written only when they are truthy.

    Raises:
        HTTPException: 400 when the environment has no current episode.
    """
    state = env._current_state
    if state is None:
        raise HTTPException(status_code=400, detail="Session has no active episode")

    payload = state.to_dict()
    payload["problem"] = _truncate_problem(state.problem or {})
    payload["elo"] = env.elo.get_stats()
    payload["rendered"] = env.render()

    for field_name in ("setter_result", "solver_result", "setter_code", "solver_code"):
        value = getattr(state, field_name)
        if value:
            payload[field_name] = value
    return payload
def _build_demo_snapshot() -> dict:
    """Return a snapshot of a freshly reset episode (seed 42, difficulty 1)."""
    env = _new_env(seed=42, difficulty=1)
    obs = env.reset()

    copied = ("episode_id", "archetype", "task_id", "difficulty", "problem_template")
    episode = {key: obs[key] for key in copied}
    episode["problem"] = _truncate_problem(env._current_state.problem)

    return {
        "environment": CodeCourtEnv.ENV_NAME,
        "version": CodeCourtEnv.VERSION,
        "episode": episode,
        "elo": obs["elo"],
    }
def _build_overview() -> dict:
    """Return the static overview document describing the environment.

    Only the archetype and task counts are computed (from ``ARCHETYPES``);
    everything else is fixed descriptive text.
    """
    total_tasks = sum(len(entry["tasks"]) for entry in ARCHETYPES.values())
    environment = {
        "name": CodeCourtEnv.ENV_NAME,
        "version": CodeCourtEnv.VERSION,
        "agent_count": 2,
        "archetype_count": len(ARCHETYPES),
        "task_count": total_tasks,
        "difficulty_tiers": 3,
    }
    setter_agent = {
        "name": "Setter",
        "role": "Problem generator",
        "goal": "Produce valid tasks and adversarial hidden tests that expose solver weaknesses.",
    }
    solver_agent = {
        "name": "Solver",
        "role": "Code generator",
        "goal": "Write efficient Python that passes all hidden and public tests.",
    }
    story = {
        "problem": "LLMs often look strong on known coding tasks but fail on adversarial hidden edge cases.",
        "intervention": "CodeCourt trains against those failures using dynamic hidden tests, reward shaping, and sandboxed execution.",
        "result": "The system can show baseline weakness, training-time improvement, and before-vs-after metrics in one demo.",
        "why_hard_to_game": "Seeded task variation and hidden tests reduce memorization and punish public-only overfitting.",
    }
    training_pipeline = [
        "Generate a task template and seeded dynamic test set with adversarial hidden cases.",
        "Collect solver completions from the policy being trained.",
        "Execute code in the Oracle sandbox and measure correctness, robustness, and pass rate.",
        "Transform execution outcomes into shaped GRPO rewards and update the solver policy.",
        "Save baseline, manifests, summaries, and before/after plots for judge-facing proof.",
    ]
    return {
        "environment": environment,
        "agents": [setter_agent, solver_agent],
        "story": story,
        "training_pipeline": training_pipeline,
    }
def _build_walkthrough() -> dict:
    """Run one randomly seeded problem with two solvers and report both runs.

    A brute-force solver is stepped in the first environment. A second
    environment, created with the same seed and forced to the same archetype,
    task and difficulty, is then stepped with the reference solver's code.

    NOTE(review): the reference code is generated from the first
    environment's problem, while the replay environment builds its own
    problem object; confirm both resolve to the same variant.
    """
    seed = random.randint(0, 10**9)

    env = _new_env(seed=seed, difficulty=1)
    obs = _force_problem(env, archetype=None, task_id=None, difficulty=1)
    problem = env._current_state.problem

    setter = SetterAgent(use_reference=True)
    baseline_agent = SolverAgent(use_brute_force=True)
    reference_agent = SolverAgent(use_reference=True)
    setter_code = setter.generate_solution(problem)
    baseline_code = baseline_agent.solve(problem)
    reference_code = reference_agent.solve(problem)

    _, baseline_reward, _, baseline_info = env.step(setter_code, baseline_code)

    replay_env = _new_env(seed=seed, difficulty=1)
    _force_problem(
        replay_env,
        archetype=obs["archetype"],
        task_id=obs["task_id"],
        difficulty=obs["difficulty"],
    )
    replay_setter_code = setter.generate_solution(replay_env._current_state.problem)
    _, reference_reward, _, reference_info = replay_env.step(replay_setter_code, reference_code)

    def _run_entry(label: str, mode: str, reward_info: dict, info: dict) -> dict:
        # One entry of the "runs" list; both runs share this shape.
        return {
            "label": label,
            "solver_mode": mode,
            "outcome": info["outcome"],
            "solver_reward": reward_info["reward"],
            "solver_pass_rate": info["solver_pass_rate"],
            "setter_valid": info["setter_valid"],
            "validation_errors": info["validation_errors"],
        }

    problem_view = {
        "archetype": obs["archetype"],
        "task_id": obs["task_id"],
        "difficulty": obs["difficulty"],
        **_truncate_problem(problem),
    }
    return {
        "problem": problem_view,
        "runs": [
            _run_entry("Baseline Solver", "brute_force", baseline_reward, baseline_info),
            _run_entry("Reference Solver", "optimal_reference", reference_reward, reference_info),
        ],
    }
def _read_json_if_exists(path: Path):
    """Return the parsed JSON content of *path*, or ``None`` if unavailable.

    ``None`` is returned when the file is missing, cannot be read (for
    example because it is a directory or permission is denied), is not valid
    UTF-8, or does not contain valid JSON. Artifact files are optional, so
    none of these cases is treated as an error.
    """
    try:
        # Read directly instead of checking exists() first: the file could
        # disappear between the check and the read. UTF-8 is requested
        # explicitly so decoding does not depend on the process locale.
        text = path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        return None
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return None
def _summarize_training_log(training_log: Any) -> dict[str, Any] | None:
    """Condense a list of logged training rows into summary statistics.

    Only dict rows with a non-``None`` ``"step"`` are counted. Returns
    ``None`` when the log is not a non-empty list or has no such row. The
    trace is labelled ``"grpo_trace"`` when at least one counted row has a
    reward and at least one has a ``"loss"`` key; otherwise
    ``"generic_trace"``.
    """
    if not (isinstance(training_log, list) and training_log):
        return None

    stepped = [
        entry
        for entry in training_log
        if isinstance(entry, dict) and entry.get("step") is not None
    ]
    if not stepped:
        return None

    rewarded = [entry for entry in stepped if entry.get("reward") is not None]
    numeric_rewards = [
        entry["reward"] for entry in rewarded if isinstance(entry["reward"], (int, float))
    ]
    last_entry = stepped[-1]
    # An empty dict stands in when no row has a reward, so .get() gives None.
    last_rewarded = rewarded[-1] if rewarded else {}
    has_loss = any("loss" in entry for entry in stepped)

    summary: dict[str, Any] = {
        "kind": "grpo_trace" if (rewarded and has_loss) else "generic_trace",
        "logged_steps": len(stepped),
        "reward_points": len(rewarded),
        "final_step": last_entry.get("step"),
        "final_reward": last_rewarded.get("reward"),
        "best_reward": max(numeric_rewards) if numeric_rewards else None,
        "worst_reward": min(numeric_rewards) if numeric_rewards else None,
        "final_runtime_sec": last_entry.get("train_runtime"),
        "has_loss": has_loss,
        "has_completion_stats": any("completions/mean_length" in entry for entry in stepped),
        "final_clipped_ratio": last_rewarded.get("completions/clipped_ratio"),
    }
    return summary
def _build_artifacts() -> dict:
    """Collect the optional artifacts found under ``OUTPUTS_DIR``.

    Every file is optional: one that is missing or holds invalid JSON is
    reported as unavailable. When ``training_summary.json`` is absent, a
    dict-shaped ``plots/evaluation_summary.json`` is used in its place.
    """
    plots_dir = OUTPUTS_DIR / "plots"

    baseline = _read_json_if_exists(OUTPUTS_DIR / "baseline_results.json")
    training_log = _read_json_if_exists(OUTPUTS_DIR / "training_history.json")
    manifest = _read_json_if_exists(OUTPUTS_DIR / "artifact_manifest.json")
    evaluation_summary = _read_json_if_exists(plots_dir / "evaluation_summary.json")
    training_summary = _read_json_if_exists(OUTPUTS_DIR / "training_summary.json")
    if training_summary is None and isinstance(evaluation_summary, dict):
        training_summary = evaluation_summary
    capability_eval = _read_json_if_exists(OUTPUTS_DIR / "capability_boundary_eval.json")

    inferred = _summarize_training_log(training_log)

    # The latest metrics come from the last dict row of the raw log, whether
    # or not that row carries reward fields.
    latest_reward = None
    latest_pass_rate = None
    dict_rows = (
        [row for row in training_log if isinstance(row, dict)]
        if isinstance(training_log, list)
        else []
    )
    if dict_rows:
        last_row = dict_rows[-1]
        latest_reward = last_row.get("solver_reward", last_row.get("reward"))
        latest_pass_rate = last_row.get("solver_pass_rate", last_row.get("reward_pass_rate"))

    summary_run_type = None
    if isinstance(training_summary, dict):
        summary_run_type = training_summary.get("run_type")

    manifest_run_type = None
    if isinstance(manifest, dict):
        training_run = manifest.get("training_run") or {}
        trained_reference = manifest.get("trained_reference") or {}
        manifest_run_type = training_run.get("run_type") or trained_reference.get("run_type")

    smoke_marker = "smoke_run_reference"
    smoke_detected = summary_run_type == smoke_marker or manifest_run_type == smoke_marker
    real_grpo_detected = bool(inferred and inferred.get("kind") == "grpo_trace")

    plots_present = plots_dir.exists()
    plot_files = sorted(entry.name for entry in plots_dir.glob("*")) if plots_present else []

    return {
        "baseline_available": baseline is not None,
        "training_manifest_available": manifest is not None,
        "training_log_available": training_log is not None,
        "training_summary_available": training_summary is not None,
        "plots_available": plots_present,
        "baseline_summary": baseline.get("summary") if isinstance(baseline, dict) else None,
        "training_manifest": manifest,
        "training_summary": training_summary,
        "evaluation_summary": evaluation_summary,
        "capability_eval": capability_eval,
        "inferred_training_summary": inferred,
        "artifact_truth": {
            "smoke_detected": smoke_detected,
            "real_grpo_detected": real_grpo_detected,
            "training_summary_run_type": summary_run_type,
            "manifest_run_type": manifest_run_type,
        },
        "latest_training_metrics": {
            "reward": latest_reward,
            "reward_pass_rate": latest_pass_rate,
        },
        "plot_files": plot_files,
    }
def _select_setter(problem: dict, request: SolverRunRequest) -> str:
    """Return the setter code for this run.

    Non-empty ``request.setter_code`` is used as given; otherwise the
    reference setter generates a solution for *problem*.
    """
    supplied = request.setter_code
    if not supplied:
        return SetterAgent(use_reference=True).generate_solution(problem)
    return supplied
def _select_solver(problem: dict, request: SolverRunRequest) -> str:
    """Return the solver code chosen by ``request.solver_mode``.

    ``"custom"`` returns the caller's code; ``"reference"`` and
    ``"brute_force"`` delegate to the matching built-in solver.

    Raises:
        HTTPException: 400 for an unknown mode, or for ``"custom"`` without
            ``solver_code``.
    """
    mode = request.solver_mode
    if mode == "custom":
        if request.solver_code:
            return request.solver_code
        raise HTTPException(status_code=400, detail="solver_code is required for custom mode")

    # Constructor arguments for each built-in solver mode.
    agent_options = {
        "reference": {"use_reference": True},
        "brute_force": {"use_brute_force": True},
    }
    options = agent_options.get(mode)
    if options is None:
        raise HTTPException(status_code=400, detail=f"Unknown solver_mode: {request.solver_mode}")
    return SolverAgent(**options).solve(problem)
def _run_episode(env: CodeCourtEnv, request: SolverRunRequest) -> dict:
    """Step the session's current episode once and broadcast the result.

    Returns the serialized session state together with the step's info dict
    and both reward-info dicts.

    Raises:
        HTTPException: 400 when the session has no episode or no problem.
    """
    state = env._current_state
    if state is None or state.problem is None:
        raise HTTPException(status_code=400, detail="Reset the session before running a solver")

    problem = state.problem
    setter_code = _select_setter(problem, request)
    solver_code = _select_solver(problem, request)
    setter_info, solver_info, _, info = env.step(setter_code, solver_code)

    response = {
        "session_state": _serialize_state(env),
        "info": info,
        "setter_reward_info": setter_info,
        "solver_reward_info": solver_info,
    }

    elo = info.get("elo", {})
    event = {
        "event": "session_run",
        # NOTE(review): this field carries the episode id, not the id the
        # session is stored under.
        "session_id": getattr(env._current_state, "episode_id", None),
        "archetype": state.archetype,
        "task_id": state.task_id,
        "difficulty": state.difficulty,
        "outcome": info["outcome"],
        "reward": solver_info["reward"],
        "pass_rate": info["solver_pass_rate"],
        "public_pass_rate": info.get("solver_public_pass_rate"),
        "hidden_pass_rate": info.get("solver_hidden_pass_rate"),
        "setter_reward": setter_info["reward"],
        "setter_elo": elo.get("setter_elo"),
        "solver_elo": elo.get("solver_elo"),
        # A random 8-character hex token, not a clock value.
        "timestamp": uuid.uuid4().hex[:8],
    }
    broadcast_payload(event)
    return response
def _benchmark(request: BenchmarkRequest) -> dict:
    """Run ``request.episodes`` independent episodes and summarize them.

    Each episode uses a fresh environment seeded with ``request.seed`` plus
    the episode index, the reference setter, and the requested solver mode.
    Every episode is broadcast as a ``benchmark_episode`` event. Episodes
    with outcome ``"invalid"`` count as a 0.0 pass rate in the averages; the
    raw value is kept under ``raw_solver_pass_rate``.
    """
    setter = SetterAgent(use_reference=True)
    solver_request = SolverRunRequest(solver_mode=request.solver_mode)
    results = []

    for index in range(request.episodes):
        env = _new_env(seed=request.seed + index, difficulty=1)
        obs = env.reset()
        problem = env._current_state.problem

        setter_code = setter.generate_solution(problem)
        solver_code = _select_solver(problem, solver_request)
        _, solver_info, _, info = env.step(setter_code, solver_code)

        outcome = info["outcome"]
        raw_pass_rate = info["solver_pass_rate"]
        pass_rate = 0.0 if outcome == "invalid" else raw_pass_rate

        results.append({
            "episode": index,
            "archetype": obs["archetype"],
            "task_id": obs["task_id"],
            "difficulty": obs["difficulty"],
            "outcome": outcome,
            "solver_reward": solver_info["reward"],
            "solver_pass_rate": pass_rate,
            "raw_solver_pass_rate": raw_pass_rate,
        })

        elo = info.get("elo", {})
        broadcast_payload({
            "event": "benchmark_episode",
            "episode": index,
            "archetype": obs["archetype"],
            "task_id": obs["task_id"],
            "difficulty": obs["difficulty"],
            "outcome": outcome,
            "reward": solver_info["reward"],
            "pass_rate": pass_rate,
            "public_pass_rate": info.get("solver_public_pass_rate"),
            "hidden_pass_rate": info.get("solver_hidden_pass_rate"),
            "setter_elo": elo.get("setter_elo"),
            "solver_elo": elo.get("solver_elo"),
            "stream": "benchmark",
        })

    total = len(results)
    avg_pass_rate = sum(ep["solver_pass_rate"] for ep in results) / total
    avg_reward = sum(ep["solver_reward"] for ep in results) / total

    def _share(outcome_name: str) -> float:
        # Fraction of episodes that ended with the given outcome.
        return sum(1 for ep in results if ep["outcome"] == outcome_name) / total

    return {
        "solver_mode": request.solver_mode,
        "episodes": results,
        "summary": {
            "episodes": request.episodes,
            "avg_solver_pass_rate": avg_pass_rate,
            "avg_solver_reward": avg_reward,
            "solver_win_rate": _share("solver_wins"),
            "setter_win_rate": _share("setter_wins"),
            "invalid_rate": _share("invalid"),
        },
    }
def root() -> FileResponse:
    # Serve the dashboard's entry page.
    index_page = DASHBOARD_DIR / "index.html"
    return FileResponse(index_page)
def health() -> dict:
    # Liveness probe: a fixed payload with no dependencies.
    status = {"status": "ok", "app": "codecourt-space"}
    return status
def overview() -> JSONResponse:
    # Static description of the environment, its agents and the pipeline.
    body = _build_overview()
    return JSONResponse(content=body)
def spec() -> PlainTextResponse:
    # Return openenv.yaml as text/yaml, or 404 when the file is absent.
    try:
        # Read directly rather than checking exists() first, so the file
        # cannot disappear between the check and the read. Decode as UTF-8
        # explicitly instead of relying on the process locale.
        text = OPENENV_PATH.read_text(encoding="utf-8")
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail="openenv.yaml not found") from exc
    return PlainTextResponse(text, media_type="text/yaml")
def readme() -> PlainTextResponse:
    # Return README.md as text/markdown, or 404 when the file is absent.
    try:
        # Read directly rather than checking exists() first, so the file
        # cannot disappear between the check and the read. Decode as UTF-8
        # explicitly so non-ASCII text does not depend on the process locale.
        text = README_PATH.read_text(encoding="utf-8")
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail="README.md not found") from exc
    return PlainTextResponse(text, media_type="text/markdown")
def demo() -> JSONResponse:
    # Snapshot of a freshly reset episode with fixed seed and difficulty.
    snapshot = _build_demo_snapshot()
    return JSONResponse(content=snapshot)
def walkthrough() -> JSONResponse:
    # Baseline and reference solver runs on one randomly seeded problem.
    report = _build_walkthrough()
    return JSONResponse(content=report)
def artifacts() -> JSONResponse:
    # Inventory of the optional training and evaluation artifacts on disk.
    inventory = _build_artifacts()
    return JSONResponse(content=inventory)
def create_session(request: SessionCreateRequest) -> JSONResponse:
    # Build an environment, start its first episode, then register it. The
    # session is stored only after the problem was set up successfully.
    env = _new_env(seed=request.seed, difficulty=request.difficulty)
    _force_problem(
        env,
        archetype=request.archetype,
        task_id=request.task_id,
        difficulty=request.difficulty,
    )
    new_session_id = SESSIONS.put(env)
    body = {"session_id": new_session_id, "state": _serialize_state(env)}
    return JSONResponse(content=body)
def reset_session(session_id: str, request: SessionCreateRequest) -> JSONResponse:
    # Start a new episode in an existing session. ``request.seed`` is not
    # used here; the session keeps the environment it was created with.
    env = _get_session_env(session_id)
    env._current_difficulty = request.difficulty
    _force_problem(
        env,
        archetype=request.archetype,
        task_id=request.task_id,
        difficulty=request.difficulty,
    )
    body = {"session_id": session_id, "state": _serialize_state(env)}
    return JSONResponse(content=body)
def get_session(session_id: str) -> JSONResponse:
    # Read-only view of the session's current episode state.
    state = _serialize_state(_get_session_env(session_id))
    return JSONResponse(content={"session_id": session_id, "state": state})
def run_session(session_id: str, request: SolverRunRequest) -> JSONResponse:
    # Step the session's current episode with the requested setter/solver.
    result = _run_episode(_get_session_env(session_id), request)
    return JSONResponse(content=result)
def run_benchmark(request: BenchmarkRequest) -> JSONResponse:
    # Run a multi-episode benchmark and return per-episode and summary data.
    report = _benchmark(request)
    return JSONResponse(content=report)