Spaces:
Configuration error
Configuration error
Aman Nindra
Enhance backend terminal session management and frontend command comparison features. Add new API endpoints for terminal session creation, input handling, resizing, and stopping. Update the frontend to support real-time command broadcasting and to display runtime comparisons between jobs. Improve styling and layout for terminal panes and comparison statistics.
4dcc016 | from __future__ import annotations | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
# Directory two levels above this file (the parent of this file's directory).
# It is put at the front of ``sys.path`` so that the absolute imports below
# (``backend.*`` and, lazily, ``server.*``) resolve when this file is run
# directly as a script.
ROOT = Path(__file__).resolve().parents[1]
_root_str = str(ROOT)
if _root_str not in sys.path:
    sys.path.insert(0, _root_str)
| from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from backend.terminal_manager import ALLOWED_JOBS, TerminalManager | |
# The ASGI application object. The title, description and version are
# metadata handed to FastAPI; the launcher at the bottom of this file refers
# to this object as "main:app".
app = FastAPI(
    version="0.1.0",
    title="RL Autotuning Backend",
    description="Backend API for the multi-family GPU autotuning benchmark",
)
# CORS policy: only the listed local origins may call the API from a browser.
# NOTE(review): ports 5173 / 4173 are presumably the frontend dev and preview
# servers -- confirm against the frontend configuration.
# Credentials are disabled, while every method and header is permitted.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=False,
    allow_origins=[
        "http://localhost:5173",
        "http://127.0.0.1:5173",
        "http://localhost:4173",
        "http://127.0.0.1:4173",
    ],
)
# Process-wide environment instance. It stays ``None`` until ``_get_env``
# creates it on first use.
env: Optional[Any] = None

# Single manager instance shared by every terminal-session handler below.
terminal_manager = TerminalManager()
def _get_env():
    """Return the shared environment instance, creating it on first use.

    The environment module is imported lazily, so an ``ImportError`` from the
    import or from constructing the environment surfaces only when an
    environment handler is called, not when this module is loaded.

    Returns:
        The module-level ``SoftmaxSurrogateEnvironment`` instance.

    Raises:
        HTTPException: 503 when importing or constructing the environment
            raises ``ImportError``. The original error is chained as the
            cause so the traceback keeps the failing import.
    """
    global env
    if env is None:
        try:
            from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment

            env = SoftmaxSurrogateEnvironment()
        except ImportError as exc:
            # ``exc.name`` is None when the ImportError was raised without a
            # module name; fall back to the error text instead of reporting
            # "missing dependency: None".
            missing = exc.name or str(exc)
            raise HTTPException(
                status_code=503,
                detail=f"Environment unavailable – missing dependency: {missing}",
            ) from exc
    return env
class ResetRequest(BaseModel):
    """Request body for ``reset``; both fields are optional."""

    # Forwarded unchanged as the ``task`` argument of the environment's reset.
    task: Optional[str] = None
    # Forwarded unchanged as the ``seed`` argument of the environment's reset.
    seed: Optional[int] = None
class StepRequest(BaseModel):
    """Request body for ``step``: an action given as ``config_id`` or ``x``.

    ``step`` uses ``config_id`` when it is set, otherwise ``x``; a request
    with neither is rejected there.
    """

    # Sent to the environment as {"config_id": ...}.
    config_id: Optional[int] = None
    # Sent to the environment as {"x": ...}.
    x: Optional[List[float]] = None
class SessionRequest(BaseModel):
    """Request body for creating (or re-using) a terminal session."""

    # Must be one of ALLOWED_JOBS; checked by ``create_terminal_session``.
    job_id: str
    # Forwarded as the ``restart`` flag of ``TerminalManager.ensure_session``.
    restart: bool = False
class SessionInputRequest(BaseModel):
    """Request body carrying text to write into a terminal session."""

    # Text handed to the session's ``write`` method.
    data: str
    # Forwarded as ``append_newline`` to the session's ``write`` method.
    append_newline: bool = True
class SessionResizeRequest(BaseModel):
    """Request body carrying new terminal dimensions for a session."""

    # First argument passed to the session's ``resize`` method.
    cols: int
    # Second argument passed to the session's ``resize`` method.
    rows: int
def health() -> Dict[str, str]:
    """Return a constant payload indicating the service is up.

    Returns:
        ``{"status": "ok"}``; no environment or session state is consulted.
    """
    payload = {"status": "ok"}
    return payload
def reset(payload: ResetRequest) -> Dict[str, Any]:
    """Reset the shared environment and return the result of that reset.

    Args:
        payload: Optional ``task`` and ``seed``; both are passed through
            as keyword arguments.

    Returns:
        Whatever the environment's ``reset`` returns.

    Raises:
        HTTPException: 503 (from ``_get_env``) when the environment cannot
            be loaded.
    """
    environment = _get_env()
    return environment.reset(task=payload.task, seed=payload.seed)
def step(payload: StepRequest) -> Dict[str, Any]:
    """Apply one action to the shared environment.

    ``config_id`` takes precedence over ``x`` when both are supplied.

    Args:
        payload: The action, given as ``config_id`` or ``x``.

    Returns:
        Whatever the environment's ``step`` returns for the action.

    Raises:
        HTTPException: 503 (from ``_get_env``) when the environment cannot
            be loaded; 400 when neither ``config_id`` nor ``x`` is set.
    """
    # The environment is resolved before the payload is validated, so an
    # unavailable environment is reported ahead of a malformed request.
    environment = _get_env()
    if payload.config_id is not None:
        action = {"config_id": payload.config_id}
    elif payload.x is not None:
        action = {"x": payload.x}
    else:
        raise HTTPException(status_code=400, detail="Provide config_id or x.")
    return environment.step(action)
def state() -> Dict[str, Any]:
    """Return the shared environment's current state.

    Returns:
        Whatever the environment's ``state`` method returns.

    Raises:
        HTTPException: 503 (from ``_get_env``) when the environment cannot
            be loaded.
    """
    environment = _get_env()
    return environment.state()
def terminal_jobs() -> Dict[str, Any]:
    """List the jobs known to the terminal manager.

    Returns:
        ``{"jobs": ...}`` where the value is the result of
        ``terminal_manager.list_jobs()``.
    """
    jobs = terminal_manager.list_jobs()
    return {"jobs": jobs}
async def create_terminal_session(payload: SessionRequest) -> Dict[str, Any]:
    """Ensure a terminal session exists for a job and return its snapshot.

    Args:
        payload: The ``job_id`` to run and whether to ``restart`` it.

    Returns:
        The snapshot of the session returned by
        ``terminal_manager.ensure_session``.

    Raises:
        HTTPException: 404 when ``job_id`` is not in ``ALLOWED_JOBS``.
    """
    job_is_known = payload.job_id in ALLOWED_JOBS
    if not job_is_known:
        raise HTTPException(status_code=404, detail=f"Unknown job_id: {payload.job_id}")

    ensured = await terminal_manager.ensure_session(
        payload.job_id,
        restart=payload.restart,
    )
    return ensured.snapshot()
def terminal_session_snapshot(session_id: str) -> Dict[str, Any]:
    """Return the snapshot of an existing terminal session.

    Args:
        session_id: Identifier looked up in the terminal manager.

    Returns:
        The result of the session's ``snapshot()``.

    Raises:
        HTTPException: 404 when the manager has no such session.
    """
    found = terminal_manager.get_session(session_id)
    if found is not None:
        return found.snapshot()
    raise HTTPException(status_code=404, detail="Session not found")
def terminal_session_input(session_id: str, payload: SessionInputRequest) -> Dict[str, Any]:
    """Write input text into a running terminal session.

    Args:
        session_id: Identifier looked up in the terminal manager.
        payload: The text to write and whether to append a newline.

    Returns:
        ``{"ok": True}`` once the write call has been made.

    Raises:
        HTTPException: 404 when the manager has no such session; 409 when
            the session exists but its ``is_active`` flag is falsy.
    """
    target = terminal_manager.get_session(session_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Session not found")
    if target.is_active:
        target.write(payload.data, append_newline=payload.append_newline)
        return {"ok": True}
    raise HTTPException(status_code=409, detail="Session is not running")
def terminal_session_resize(session_id: str, payload: SessionResizeRequest) -> Dict[str, Any]:
    """Forward new terminal dimensions to a session.

    Args:
        session_id: Identifier looked up in the terminal manager.
        payload: The requested ``cols`` and ``rows``.

    Returns:
        ``{"ok": True}`` once the resize call has been made.

    Raises:
        HTTPException: 404 when the manager has no such session.
    """
    target = terminal_manager.get_session(session_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Session not found")

    columns, lines = payload.cols, payload.rows
    target.resize(columns, lines)
    return {"ok": True}
def terminal_session_stop(session_id: str) -> Dict[str, Any]:
    """Interrupt a terminal session.

    Args:
        session_id: Identifier looked up in the terminal manager.

    Returns:
        ``{"ok": True}`` once ``interrupt()`` has been called on the session.

    Raises:
        HTTPException: 404 when the manager has no such session.
    """
    target = terminal_manager.get_session(session_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Session not found")

    # Unlike the input handler, no ``is_active`` check is made here; the
    # interrupt is issued whatever the session's current state.
    target.interrupt()
    return {"ok": True}
async def terminal_session_stream(websocket: WebSocket, session_id: str) -> None:
    """Stream a terminal session's events to a WebSocket client.

    After the handshake the client first receives the session snapshot, then
    every event taken from the subscription queue, each sent as JSON. The
    subscription is always released when the handler exits.

    Args:
        websocket: The incoming WebSocket connection.
        session_id: Identifier looked up in the terminal manager.

    Returns:
        None. For an unknown session the socket is accepted and then closed
        with application close code 4404.
    """
    session = terminal_manager.get_session(session_id)

    # Complete the handshake before anything else. A close sent before accept
    # is turned into an HTTP 403 handshake rejection, so the custom 4404 code
    # would never reach the client.
    await websocket.accept()
    if session is None:
        await websocket.close(code=4404)
        return

    queue = await session.subscribe()
    try:
        await websocket.send_json(session.snapshot())
        while True:
            event = await queue.get()
            await websocket.send_json(event)
    except WebSocketDisconnect:
        # The client went away; nothing to report, just clean up below.
        pass
    finally:
        session.unsubscribe(queue)
if __name__ == "__main__":
    # uvicorn is imported only on this path, so importing the module as a
    # library does not require it.
    import uvicorn

    # Serve the ``app`` object of this module on all interfaces, port 8000,
    # with auto-reload enabled.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )