|
|
from fastapi import FastAPI, BackgroundTasks, HTTPException, WebSocket, WebSocketDisconnect |
|
|
from fastapi.responses import FileResponse |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
import base64 |
|
|
import numpy as np |
|
|
from collections import deque |
|
|
from pydantic import BaseModel |
|
|
from typing import Dict, Any, List, Optional |
|
|
import uuid |
|
|
import gymnasium as gym |
|
|
from stable_baselines3 import PPO, DQN, A2C |
|
|
from stable_baselines3.common.monitor import Monitor |
|
|
from stable_baselines3.common.evaluation import evaluate_policy |
|
|
from stable_baselines3.common.callbacks import BaseCallback |
|
|
from datetime import datetime |
|
|
import asyncio |
|
|
import os |
|
|
import glob |
|
|
import logging |
|
|
from io import BytesIO |
|
|
from PIL import Image |
|
|
import imageio |
|
|
import traceback |
|
|
|
|
|
|
|
|
# Basic logging so background training jobs are observable from the console.
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)

app = FastAPI()

# CORS is wide open — fine for local development, but lock down
# `allow_origins` before exposing this server publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Saved models and replay videos are written under this directory.
os.makedirs("models", exist_ok=True)

# In-memory registry of training jobs, keyed by job_id (UUID string).
# NOTE: lost on process restart; only saved artifacts under models/ survive.
training_jobs: Dict[str, Dict[str, Any]] = {}
|
|
|
|
|
class TrainingRequest(BaseModel):
    """Request body for POST /train."""

    # Environment name; stored in the job config and used to name the
    # replay video file.
    env_name: str
    # User-supplied Python training code, executed by run_custom_code with
    # a pre-populated namespace (gym, PPO, StreamCallback, ...).
    code: str
|
|
|
|
|
|
|
|
class ConnectionManager:
    """Tracks WebSocket viewers per training job and fans out rendered frames.

    For each job_id we keep the list of connected sockets plus a one-slot
    frame buffer — only the most recent rendered frame is ever streamed.
    """

    def __init__(self):
        # job_id -> websockets currently watching that job's render stream
        self.active_connections: Dict[str, List[WebSocket]] = {}
        # job_id -> deque(maxlen=1) holding only the newest frame
        self.frames: Dict[str, deque] = {}

    async def connect(self, job_id: str, websocket: WebSocket):
        """Accept the socket and register it under job_id."""
        await websocket.accept()
        if job_id not in self.active_connections:
            self.active_connections[job_id] = []
            self.frames[job_id] = deque(maxlen=1)
        self.active_connections[job_id].append(websocket)

    def disconnect(self, job_id: str, websocket: WebSocket):
        """Unregister a socket; drop per-job state when the last viewer leaves.

        Safe to call more than once for the same socket — the websocket
        endpoint may invoke it from multiple except branches.
        """
        connections = self.active_connections.get(job_id)
        if connections is None:
            return
        try:
            connections.remove(websocket)
        except ValueError:
            # Already removed by an earlier disconnect call — not an error.
            pass
        if not connections:
            del self.active_connections[job_id]
            self.frames.pop(job_id, None)

    def add_frame(self, job_id: str, frame: np.ndarray):
        """Store the latest frame for job_id (any older frame is discarded)."""
        if job_id not in self.frames:
            self.frames[job_id] = deque(maxlen=1)
        self.frames[job_id].append(frame)

    async def broadcast_frame(self, job_id: str):
        """Encode the newest frame as base64 JPEG and send it to all viewers.

        Best-effort: encoding or per-socket send failures are logged at
        debug level and ignored, so one bad viewer cannot break the stream.
        """
        if job_id not in self.frames or not self.frames[job_id]:
            return
        frame = self.frames[job_id][-1]
        try:
            if isinstance(frame, np.ndarray):
                # Renderers may return floats in [0, 1]; JPEG needs uint8.
                if frame.dtype != np.uint8:
                    frame = np.clip(frame * 255, 0, 255).astype(np.uint8)
                img = Image.fromarray(frame)
            else:
                return

            # Downscale large renders to keep the websocket payload small.
            max_size = 512
            if img.width > max_size or img.height > max_size:
                ratio = max_size / max(img.width, img.height)
                img = img.resize(
                    (int(img.width * ratio), int(img.height * ratio)),
                    Image.Resampling.LANCZOS,
                )

            buffer = BytesIO()
            img.save(buffer, format='JPEG', quality=85)
            frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

            for connection in self.active_connections.get(job_id, []):
                try:
                    await connection.send_json(
                        {"type": "frame", "job_id": job_id, "data": frame_base64}
                    )
                except Exception:
                    # Dead socket; its own disconnect path will clean it up.
                    logger.debug("Send failed for job %s", job_id, exc_info=True)
        except Exception:
            # Never let an encoding hiccup propagate into the caller.
            logger.debug("Frame broadcast failed for job %s", job_id, exc_info=True)
|
|
|
|
|
# Single shared instance, used by both the training callback (to push
# frames) and the websocket endpoint (to serve them).
manager = ConnectionManager()
|
|
|
|
|
|
|
|
class MetricsCallback(BaseCallback):
    """SB3 callback that mirrors training progress into `training_jobs`.

    Each step it updates timestep/progress counters, periodically captures a
    rendered frame for the live stream and the replay buffer, and records
    per-episode rewards when an episode finishes. Returning False from
    `_on_step` tells stable-baselines3 to halt training — used to implement
    the user-initiated stop action.
    """

    def __init__(self, job_id: str, render_freq: int = 4):
        super().__init__()
        self.job_id = job_id            # key into the training_jobs registry
        self.episode_count = 0          # episodes completed so far
        self.render_freq = render_freq  # capture a frame every N steps

    def _on_step(self) -> bool:
        job = training_jobs.get(self.job_id)
        # Unknown job, or user pressed stop -> abort training.
        if not job or job["status"] == "stopped":
            return False

        metrics = job["metrics"]
        metrics["timesteps"] = self.num_timesteps

        # Progress is a best-effort percentage based on the timestep total
        # scraped from the user's code (may be wrong; clamped to 100).
        total = job.get("total_timesteps_guess", 100000)
        metrics["progress"] = min(100, int((self.num_timesteps / total) * 100))

        # Periodically capture a frame for the live stream + replay video.
        if self.num_timesteps % self.render_freq == 0:
            try:
                frame = self.model.get_env().render()
                if frame is not None and isinstance(frame, np.ndarray):
                    manager.add_frame(self.job_id, frame)
                    # Cap the replay buffer so memory stays bounded.
                    if len(job["video_buffer"]) < 2000:
                        job["video_buffer"].append(frame)
            except Exception:
                # Rendering is optional; never fail training because of it.
                logger.debug("Render failed for job %s", self.job_id, exc_info=True)

        # Episode bookkeeping: the Monitor wrapper puts an "episode" dict
        # into the info of the step that ends an episode.
        if self.locals.get("dones", [False])[0]:
            infos = self.locals.get("infos", [])
            if len(infos) > 0:
                info = infos[0]
                if "episode" in info:
                    self.episode_count += 1
                    ep_reward = float(info["episode"]["r"])
                    metrics["episodes"] = self.episode_count
                    metrics["episode_rewards"].append(ep_reward)
                    metrics["current_episode_reward"] = ep_reward

                    # Rolling statistics over the last 100 episodes.
                    recent = metrics["episode_rewards"][-100:]
                    if recent:
                        metrics["mean_reward"] = float(np.mean(recent))
                        metrics["std_reward"] = float(np.std(recent))

                    log_entry = f"[{datetime.now().strftime('%H:%M:%S')}] Episode {self.episode_count}: reward = {ep_reward:.2f}"
                    metrics["logs"].append(log_entry)
                    # Keep only the newest 100 log lines.
                    if len(metrics["logs"]) > 100:
                        metrics["logs"].pop(0)
        return True
|
|
|
|
|
def save_video_from_buffer(job_id: str, env_name="env"):
    """Write the job's buffered frames to an MP4 and clear the buffer.

    Returns the video path on success, or None when there is nothing to
    save or encoding fails. Best-effort by design: a video failure must
    never crash the training flow, but it is now logged instead of being
    silently swallowed.
    """
    job = training_jobs.get(job_id)
    if not job or not job["video_buffer"]:
        return None
    try:
        video_path = f"models/{env_name}_replay_{job_id}.mp4"
        imageio.mimsave(video_path, job['video_buffer'], fps=30)
        # Free the frame memory once the file is on disk.
        job["video_buffer"] = []
        return video_path
    except Exception:
        logger.warning("Could not save replay video for job %s", job_id, exc_info=True)
        return None
|
|
|
|
|
|
|
|
def run_custom_code(job_id: str, code: str, env_name: str):
    """Execute user-submitted training code in-process (background task).

    SECURITY: `exec` on user-supplied code gives it full access to this
    process and host. This is only acceptable for a trusted/local
    deployment — do not expose the /train endpoint to untrusted users
    without sandboxing.
    """
    logger.info(f"[EXEC] Starting job {job_id}")
    training_jobs[job_id]["status"] = "training"
    training_jobs[job_id]["start_time"] = datetime.now()

    # Subclass that pre-binds this job's id via closure, so user code can
    # simply write `callback=StreamCallback()` without knowing the job id.
    class StreamCallback(MetricsCallback):
        def __init__(self, render_freq=4):
            super().__init__(job_id, render_freq)

    # Namespace handed to the user's code. It doubles as the globals dict
    # for exec, so anything the code defines (e.g. a `results` dict) can be
    # read back from it afterwards.
    local_scope = {
        "gym": gym,
        "PPO": PPO,
        "DQN": DQN,
        "A2C": A2C,
        "evaluate_policy": evaluate_policy,
        "Monitor": Monitor,
        "np": np,
        "StreamCallback": StreamCallback,
        "model_save_path": f"models/model_{job_id}",
    }

    try:
        exec(code, local_scope)

        # Flush frames captured during training into a replay video.
        video_path = save_video_from_buffer(job_id, env_name)

        # User code is expected to save to `model_save_path`; SB3's
        # model.save() appends ".zip" to that path.
        expected_model_path = f"models/model_{job_id}.zip"

        # Optional dict the user code may populate with evaluation numbers.
        user_results = local_scope.get("results", {})

        training_jobs[job_id]["status"] = "completed"
        training_jobs[job_id]["results"] = {
            "mean_reward": user_results.get("mean_reward", 0),
            "std_reward": user_results.get("std_reward", 0),
            "model_path": expected_model_path,
            "video_path": video_path,
            "total_episodes": training_jobs[job_id]["metrics"]["episodes"],
        }
        training_jobs[job_id]["metrics"]["progress"] = 100

    except Exception as e:
        # Any failure in the user's code marks the job failed and surfaces
        # the error both in the job record and in the streamed logs.
        error_msg = traceback.format_exc()
        logger.error(f"[EXEC] Error in job {job_id}: {error_msg}")
        training_jobs[job_id]["status"] = "failed"
        training_jobs[job_id]["error"] = str(e)
        training_jobs[job_id]["metrics"]["logs"].append(f"ERROR: {str(e)}")
|
|
|
|
|
@app.post("/train")
def start_training(request: TrainingRequest, background_tasks: BackgroundTasks):
    """Create a job record and launch the user's training code in the background.

    Returns the job_id the client uses for status polling, the render
    websocket, and the download endpoints.
    """
    job_id = str(uuid.uuid4())

    # Best-effort scrape of `total_timesteps=<int>` from the submitted code
    # so the progress bar has a denominator; falls back to 100k.
    total_timesteps_guess = 100000
    if "total_timesteps=" in request.code:
        try:
            part = request.code.split("total_timesteps=")[1].split(")")[0].split(",")[0]
            total_timesteps_guess = int(part.strip())
        except (ValueError, IndexError):
            # Not a plain integer literal (e.g. a variable) — keep default.
            pass

    training_jobs[job_id] = {
        "status": "queued",
        "config": {"env_name": request.env_name},
        "total_timesteps_guess": total_timesteps_guess,
        "metrics": {
            "timesteps": 0, "episodes": 0, "progress": 0,
            "episode_rewards": [], "episode_lengths": [],
            "current_episode_reward": 0, "mean_reward": 0, "std_reward": 0,
            "eval_mean_reward": None, "eval_std_reward": None, "logs": [],
        },
        "video_buffer": [],
        "results": None, "error": None, "start_time": None,
    }

    background_tasks.add_task(run_custom_code, job_id, request.code, request.env_name)
    return {"message": "Started", "job_id": job_id}
|
|
|
|
|
@app.get("/train/{job_id}/status")
def get_training_status(job_id: str):
    """Get full training status with metrics"""
    try:
        job = training_jobs[job_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Job not found")

    # Elapsed wall-clock time; zero until the background task has started.
    started = job.get("start_time")
    elapsed = (datetime.now() - started).total_seconds() if started else 0

    return {
        "status": job["status"],
        "metrics": job["metrics"],
        "elapsed_time": elapsed,
        "results": job["results"],
        "error": job["error"],
    }
|
|
|
|
|
@app.get("/train/{job_id}/metrics")
def get_training_metrics(job_id: str):
    """Lightweight endpoint for polling metrics"""
    job = training_jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")

    # Zero until run_custom_code stamps start_time on the job.
    start = job.get("start_time")
    elapsed_time = (datetime.now() - start).total_seconds() if start else 0

    response = {
        "status": job["status"],
        "metrics": job["metrics"],
    }
    response["elapsed_time"] = elapsed_time
    return response
|
|
|
|
|
@app.post("/train/{job_id}/stop")
def stop_training(job_id: str):
    """Stop a training job"""
    job = training_jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    if job["status"] == "training":
        # MetricsCallback polls this status and returns False on its next
        # step, which makes stable-baselines3 abort the learn() loop.
        job["status"] = "stopped"

        # Use the job's real env name so the replay filename matches videos
        # produced on normal completion (previously it defaulted to "env").
        env_name = job.get("config", {}).get("env_name", "env")
        video_path = save_video_from_buffer(job_id, env_name)

        if job["results"] is None:
            job["results"] = {}
        job["results"]["video_path"] = video_path
        job["metrics"]["logs"].append(
            f"[{datetime.now().strftime('%H:%M:%S')}] Training stopped by user"
        )
        return {"message": "Training stopped successfully!"}
    else:
        raise HTTPException(status_code=400, detail="Job is not currently training")
|
|
|
|
|
@app.get("/download/{job_id}/video")
def download_video(job_id: str):
    """Serve the replay video recorded for a job, if one exists."""
    job = training_jobs.get(job_id)
    results = job.get("results") if job else None
    video_path = results.get("video_path") if results else None

    # No path recorded yet -> video was never produced for this job.
    if not video_path:
        raise HTTPException(status_code=404, detail="Video not processed yet")

    # Path recorded but file missing on disk (deleted or failed encode).
    if not os.path.exists(video_path):
        raise HTTPException(status_code=404, detail="Video file not found")

    return FileResponse(video_path, media_type='video/mp4', filename=f"training_replay_{job_id}.mp4")
|
|
|
|
|
@app.get("/download/{job_id}/model")
def download_model(job_id: str):
    """Download model with filesystem fallback"""
    # Prefer the path recorded on the in-memory job record.
    job = training_jobs.get(job_id)
    file_path = None
    if job and job.get("results"):
        file_path = job["results"].get("model_path")

    # Filesystem fallback: the job registry is lost on restart, but the
    # saved model (models/model_<job_id>.zip) may still exist on disk.
    # (Previously advertised in this docstring but never implemented.)
    if not file_path or not os.path.exists(file_path):
        matches = glob.glob(f"models/model_{job_id}.zip")
        file_path = matches[0] if matches else None

    if not file_path or not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="Model file not found")

    return FileResponse(file_path, media_type='application/octet-stream', filename=f"model_{job_id}.zip")
|
|
|
|
|
|
|
|
@app.websocket("/ws/render/{job_id}")
async def websocket_render_endpoint(websocket: WebSocket, job_id: str):
    """
    WebSocket endpoint for real-time environment rendering.
    Connect from frontend with: ws://localhost:8000/ws/render/{job_id}
    """
    await manager.connect(job_id, websocket)

    try:
        # Client-driven loop: the frontend requests frames at its own pace
        # ("request_frame"); "ping" serves as a keepalive.
        while True:
            data = await websocket.receive_text()
            if data == "request_frame":
                await manager.broadcast_frame(job_id)
            elif data == "ping":
                await websocket.send_json({"type": "pong"})
    except WebSocketDisconnect:
        # Normal client hangup — just unregister.
        manager.disconnect(job_id, websocket)
    except Exception as e:
        # Anything else: log it and make sure the socket is unregistered.
        logger.error(f"[WS] WebSocket error for job {job_id}: {e}")
        manager.disconnect(job_id, websocket)
|
|
|
|
|
@app.get("/debug/jobs")
def debug_jobs():
    """Debug endpoint to list all jobs"""
    summaries = []
    for job_id, job in training_jobs.items():
        metrics = job["metrics"]
        summaries.append({
            "job_id": job_id,
            "status": job["status"],
            "progress": metrics["progress"],
            "episodes": metrics["episodes"],
        })
    return {"jobs": summaries}