File size: 6,268 Bytes
206438f e2f8b29 9e6a926 e2f8b29 9e6a926 e2f8b29 02e58fe e2f8b29 9e6a926 e2f8b29 9e6a926 206438f 9e6a926 e2f8b29 206438f 9e6a926 fc3fbaf e2f8b29 9e6a926 e2f8b29 d449338 02e58fe d449338 e2f8b29 d8eeec6 9e6a926 206438f 9e6a926 e2f8b29 4f58e42 206438f 4f58e42 6be6d8e 4f58e42 0b9b77b 6be6d8e 0b9b77b 6be6d8e 0b9b77b 6be6d8e 0b9b77b e2f8b29 9e6a926 e2f8b29 6be6d8e e2f8b29 02e58fe 6be6d8e 9e6a926 e2f8b29 9e6a926 6be6d8e 9e6a926 6be6d8e 9e6a926 e2f8b29 9e6a926 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 | """FastAPI app — openenv create_app() + custom routes."""
from __future__ import annotations
import asyncio
import json
import logging
import sys
from typing import Optional
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from openenv.core.env_server.http_server import create_app
from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation
from server._baseline_results import get_last_grader_result
from server.environment import MLTrainingEnvironment
# Structured JSON logging
class JSONFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object.

    Every line carries ``time``, ``level`` and ``msg``. Any of the optional
    context attributes listed in ``_EXTRA_FIELDS`` are copied through when the
    record has them (e.g. supplied via ``logger.info(..., extra={...})``).
    Exception and stack information attached to the record is rendered as
    text so that ``logger.exception(...)`` keeps its traceback.
    """

    # Optional record attributes copied verbatim into the JSON payload.
    _EXTRA_FIELDS = ("session_id", "task_id", "step_count", "action_type", "score")

    def format(self, record: logging.LogRecord) -> str:
        log_data = {
            "time": self.formatTime(record),
            "level": record.levelname,
            "msg": record.getMessage(),
        }
        for attr in self._EXTRA_FIELDS:
            if hasattr(record, attr):
                log_data[attr] = getattr(record, attr)
        # Without these, tracebacks from logger.exception() would be dropped.
        if record.exc_info:
            log_data["exc_info"] = self.formatException(record.exc_info)
        if record.stack_info:
            log_data["stack_info"] = self.formatStack(record.stack_info)
        # default=str: a non-JSON-serializable extra value must not make the
        # whole log line fail to emit; it is stringified instead.
        return json.dumps(log_data, default=str)
# Send every record to stdout as JSON. Any handlers already attached to the
# root logger are replaced, and the root threshold is set to INFO.
handler = logging.StreamHandler(stream=sys.stdout)
handler.setFormatter(JSONFormatter())
_root_logger = logging.getLogger()
_root_logger.handlers = [handler]
_root_logger.setLevel(logging.INFO)
logger = logging.getLogger(__name__)
# Catalogue of the seven tasks, ordered easy -> medium -> hard.
# Source rows are (task id, difficulty label, step budget); every task has a grader.
ALL_TASKS = [
    {
        "id": task_id,
        "difficulty": difficulty,
        "max_steps": max_steps,
        "has_grader": True,
    }
    for task_id, difficulty, max_steps in (
        ("task_001", "easy", 20),
        ("task_002", "easy", 20),
        ("task_003", "medium", 25),
        ("task_004", "medium", 25),
        ("task_005", "hard", 30),
        ("task_006", "hard", 30),
        ("task_007", "hard", 25),
    )
]
# create_app is handed the environment *class* (used as a factory), not an
# already-constructed instance.
app: FastAPI = create_app(
    MLTrainingEnvironment,
    MLTrainingAction,
    MLTrainingObservation,
    env_name="pytorch_training_debugger",
    max_concurrent_envs=5,
)

# Strip the framework-provided /health route(s) so this module's own /health
# handler is the one that answers. The route list is mutated in place via
# slice assignment rather than rebound.
app.routes[:] = [
    route
    for route in app.routes
    if getattr(route, "path", None) != "/health"
]
# Single-flight guard for /baseline (Fix #14). This is an asyncio.Lock: it
# serialises coroutines running on the server's event loop. It is NOT a
# threading lock and gives no protection across OS threads.
_baseline_lock: asyncio.Lock = asyncio.Lock()
@app.get("/")
def root() -> RedirectResponse:
    """Send requests for the bare root URL on to the dashboard page."""
    dashboard_url = "/dashboard"
    return RedirectResponse(url=dashboard_url)
@app.get("/health")
def health_check() -> dict:
    """Liveness probe polled by the hackathon auto-validator.

    Always reports ``healthy`` together with the number of registered tasks.
    """
    task_count = len(ALL_TASKS)
    return dict(status="healthy", tasks=task_count)
@app.get("/dashboard", response_class=HTMLResponse)
def get_dashboard() -> str:
    """Serve the live diagnostic dashboard page.

    ``dashboard.html`` is read from this module's directory on every request,
    so edits to the file are picked up without restarting the server.

    Raises:
        FileNotFoundError: if ``dashboard.html`` is missing.
    """
    import pathlib

    html_path = pathlib.Path(__file__).parent / "dashboard.html"
    # Decode explicitly as UTF-8: read_text() without an encoding falls back
    # to the platform locale (e.g. cp1252 on Windows), which can reject or
    # mangle non-ASCII characters in the page.
    return html_path.read_text(encoding="utf-8")
@app.get("/validation-report")
def get_validation_report() -> dict:
    """Serve the pre-computed simulation fidelity report.

    The report is loaded from ``validation/reports/fidelity_report.json``,
    resolved against the parent of this module's directory. If the file has
    not been generated yet, an ``{"error": ...}`` payload explaining how to
    produce it is returned instead.
    """
    import pathlib

    report_path = (
        pathlib.Path(__file__).parent.parent
        / "validation"
        / "reports"
        / "fidelity_report.json"
    )
    try:
        # Read bytes and let json.loads detect the UTF-8/16/32 encoding itself,
        # rather than decoding with the platform's default locale. Trying the
        # read directly (EAFP) also avoids an exists()/read race.
        raw_report = report_path.read_bytes()
    except FileNotFoundError:
        return {
            "error": "Validation report not yet generated. "
            "Run: python validation/run_all_validations.py"
        }
    return json.loads(raw_report)
@app.get("/curriculum")
def get_curriculum() -> dict:
    """Build the recommended training schedule.

    Tasks appear in ``ALL_TASKS`` order (easy to hard). Each task is repeated
    at difficulty levels 1, 3 and 5, carrying its step budget along.
    """
    curriculum: list[dict] = []
    for task in ALL_TASKS:
        curriculum.extend(
            {
                "task_id": task["id"],
                "difficulty": task["difficulty"],
                "difficulty_level": level,
                "max_steps": task["max_steps"],
            }
            for level in (1, 3, 5)
        )
    return {"curriculum": curriculum, "total_episodes": len(curriculum)}
@app.get("/leaderboard")
def get_leaderboard() -> dict:
    """Rank completed episode results by score, highest first.

    The ``_latest`` key and any non-dict values in the results store are
    skipped. Entries without a ``score`` rank as if they scored 0.
    """
    from server._baseline_results import _last_results

    # NOTE(review): /baseline runs its work in an executor thread. If that work
    # writes to _last_results while this iterates, dict iteration could raise
    # RuntimeError — TODO confirm who writes the store and from which thread.
    ranked = sorted(
        (
            entry
            for key, entry in _last_results.items()
            if key != "_latest" and isinstance(entry, dict)
        ),
        key=lambda entry: entry.get("score", 0),
        reverse=True,
    )
    return {"entries": ranked, "total": len(ranked)}
@app.get("/replay/{episode_id}")
def get_replay(episode_id: str) -> dict:
    """Return the full action/observation trace for a completed episode.

    ``episode_id`` is looked up in the shared baseline results store. Unknown
    ids, and ids whose stored value is not a dict, produce an
    ``{"error": ...}`` payload. The store can hold non-dict values
    (``/leaderboard`` filters them out).
    """
    from server._baseline_results import _last_results

    result = _last_results.get(episode_id)
    # Only dict entries can be merged into the response with ``**``; a missing
    # id (None) or a non-dict value would otherwise raise TypeError (HTTP 500).
    if not isinstance(result, dict):
        return {"error": f"Episode '{episode_id}' not found"}
    return {"episode_id": episode_id, **result}
@app.get("/tasks")
def get_tasks() -> list[dict]:
    """List every task with the JSON schema of ``MLTrainingAction`` attached.

    The schema is generated once per request and shared by all task entries.
    """
    action_schema = MLTrainingAction.model_json_schema()
    return [dict(task, action_schema=action_schema) for task in ALL_TASKS]
@app.post("/grader")
def post_grader(session_id: Optional[str] = None) -> dict:
    """Report the grader result for the most recently completed episode.

    ``session_id`` is forwarded to ``get_last_grader_result``. When that
    returns None, the response carries ``score: None`` and an error code.
    """
    last_result = get_last_grader_result(session_id)
    if last_result is not None:
        return last_result
    return {"score": None, "error": "no_completed_episode"}
@app.post("/baseline", response_model=None)
async def post_baseline() -> JSONResponse | dict:
    """Run the heuristic baseline over every task and return its scores.

    Only one run may be in flight at a time. A request arriving while a run is
    active is rejected immediately with HTTP 409
    ``{"error": "baseline_in_progress"}`` instead of queueing on the lock.
    """
    # Acquiring an uncontended asyncio.Lock does not suspend, so nothing on
    # this event loop can interleave between this check and the acquisition.
    if _baseline_lock.locked():
        return JSONResponse(
            status_code=409,
            content={"error": "baseline_in_progress"},
        )
    async with _baseline_lock:
        from server._heuristic import run_baseline_all_tasks

        # run_baseline_all_tasks is a plain (blocking) callable. It runs on the
        # default executor so the event loop keeps serving other requests.
        # get_running_loop() is the preferred way to reach the loop from
        # inside a coroutine.
        loop = asyncio.get_running_loop()
        scores = await loop.run_in_executor(None, run_baseline_all_tasks)
    return {"scores": scores}
def main() -> None:
    """Serve the app with uvicorn on all interfaces, port 7860."""
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)


if __name__ == "__main__":
    main()
|