from fastapi import FastAPI, BackgroundTasks, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
import base64
import numpy as np
from collections import deque
from pydantic import BaseModel
from typing import Dict, Any, List, Optional
import uuid
import gymnasium as gym
from stable_baselines3 import PPO, DQN, A2C  # Added common algos
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import BaseCallback
from datetime import datetime
import asyncio
import os
import glob
import logging
import re
from io import BytesIO
from PIL import Image
import imageio
import traceback

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive (Starlette then reflects the caller's origin on credentialed
# requests). This service executes submitted code, so restrict allow_origins
# before exposing it beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

os.makedirs("models", exist_ok=True)

# In-memory job registry: job_id -> job record (schema is built in start_training).
# Its contents are lost when the process exits.
training_jobs: Dict[str, Dict[str, Any]] = {}

# Progress-bar denominator used when total_timesteps cannot be read from the script.
DEFAULT_TOTAL_TIMESTEPS = 100000
# Upper bound on frames buffered per job for the replay video.
MAX_VIDEO_FRAMES = 2000
# Number of log lines kept per job.
MAX_LOG_LINES = 100
# Number of most recent episodes used for mean_reward / std_reward.
REWARD_WINDOW = 100

# Matches an integer literal such as `total_timesteps=50000` or
# `total_timesteps = 100_000`. The lookahead rejects partial matches of
# float / scientific literals like `1e5` or `1.5e5`.
_TOTAL_TIMESTEPS_RE = re.compile(r"\btotal_timesteps\s*=\s*(\d+(?:_\d+)*)(?![\w.])")


class TrainingRequest(BaseModel):
    """Request body for POST /train."""

    env_name: str  # Environment id; also used to name the replay video file.
    code: str  # Raw Python training script, executed verbatim by run_custom_code.


# --- WEBSOCKET MANAGER ---
class ConnectionManager:
    """Track render WebSocket subscribers and the latest frame for each job."""

    def __init__(self):
        self.active_connections: Dict[str, List[WebSocket]] = {}
        # Only the most recent frame per job is kept (deque with maxlen=1).
        self.frames: Dict[str, deque] = {}

    async def connect(self, job_id: str, websocket: WebSocket):
        """Accept *websocket* and subscribe it to frames of *job_id*."""
        await websocket.accept()
        self.active_connections.setdefault(job_id, []).append(websocket)
        # setdefault keeps a frame that was captured before the first subscriber arrived.
        self.frames.setdefault(job_id, deque(maxlen=1))

    def disconnect(self, job_id: str, websocket: WebSocket):
        """Unsubscribe *websocket*; drop the job's frame cache once nobody is listening."""
        connections = self.active_connections.get(job_id)
        if connections is None:
            return
        if websocket in connections:  # tolerate being called twice for one socket
            connections.remove(websocket)
        if not connections:
            del self.active_connections[job_id]
            self.frames.pop(job_id, None)

    def add_frame(self, job_id: str, frame: np.ndarray):
        """Store *frame* as the latest frame of *job_id* (called by MetricsCallback)."""
        self.frames.setdefault(job_id, deque(maxlen=1)).append(frame)

    async def broadcast_frame(self, job_id: str):
        """JPEG-encode the latest frame of *job_id* and send it to every subscriber.

        Encoding or send failures are logged at debug level and otherwise
        ignored, so one bad frame or client never breaks the render socket.
        """
        frames = self.frames.get(job_id)
        if not frames:
            return
        frame = frames[-1]
        if not isinstance(frame, np.ndarray):
            return
        try:
            frame_base64 = self._encode_frame(frame)
        except Exception:
            logger.debug("[WS] Could not encode frame for job %s", job_id, exc_info=True)
            return

        message = {"type": "frame", "job_id": job_id, "data": frame_base64}
        # Iterate over a copy: another coroutine may disconnect a client while we await.
        for connection in list(self.active_connections.get(job_id, [])):
            try:
                await connection.send_json(message)
            except Exception as exc:
                logger.debug("[WS] Failed to send frame for job %s: %s", job_id, exc)

    @staticmethod
    def _encode_frame(frame: np.ndarray, max_size: int = 512) -> str:
        """Return *frame* as a base64 JPEG string, downscaled to at most *max_size* px per side."""
        if frame.dtype != np.uint8:
            # Non-uint8 frames are treated as floats in [0, 1].
            frame = np.clip(frame * 255, 0, 255).astype(np.uint8)
        img = Image.fromarray(frame)
        if img.mode not in ("RGB", "L"):
            # JPEG cannot store alpha/palette modes (e.g. an RGBA frame).
            img = img.convert("RGB")
        if img.width > max_size or img.height > max_size:
            ratio = max_size / max(img.width, img.height)
            new_size = (max(1, int(img.width * ratio)), max(1, int(img.height * ratio)))
            img = img.resize(new_size, Image.Resampling.LANCZOS)
        buffer = BytesIO()
        img.save(buffer, format='JPEG', quality=85)
        return base64.b64encode(buffer.getvalue()).decode('utf-8')


manager = ConnectionManager()


# --- TRAINING CALLBACK ---
class MetricsCallback(BaseCallback):
    """SB3 callback that mirrors training progress, frames and episode stats into training_jobs."""

    def __init__(self, job_id: str, render_freq: int = 4):
        super().__init__()
        self.job_id = job_id
        self.episode_count = 0
        # Render every `render_freq` timesteps; a value <= 0 disables rendering.
        self.render_freq = render_freq

    def _on_step(self) -> bool:
        """Update the job record; returning False tells SB3 to stop training."""
        job = training_jobs.get(self.job_id)
        if not job or job["status"] == "stopped":
            return False

        metrics = job["metrics"]
        metrics["timesteps"] = self.num_timesteps
        # The total is only a guess parsed from the submitted code; clamp to 1
        # so a parsed 0 can never cause a division by zero.
        total = max(1, job.get("total_timesteps_guess", DEFAULT_TOTAL_TIMESTEPS))
        metrics["progress"] = min(100, int((self.num_timesteps / total) * 100))

        if self.render_freq > 0 and self.num_timesteps % self.render_freq == 0:
            self._capture_frame(job)

        self._record_finished_episode(metrics)
        return True

    def _capture_frame(self, job: Dict[str, Any]) -> None:
        """Render the training env and push the frame to the live stream and replay buffer."""
        try:
            frame = self.model.get_env().render()
        except Exception as exc:
            # Rendering is best-effort; a render failure must not abort training.
            logger.debug("[EXEC] render failed for job %s: %s", self.job_id, exc)
            return
        if isinstance(frame, np.ndarray):
            manager.add_frame(self.job_id, frame)
            if len(job["video_buffer"]) < MAX_VIDEO_FRAMES:
                job["video_buffer"].append(frame)

    def _record_finished_episode(self, metrics: Dict[str, Any]) -> None:
        """If env 0 just finished an episode with Monitor stats, record its reward and length."""
        if not self.locals.get("dones", [False])[0]:
            return
        infos = self.locals.get("infos", [])
        if len(infos) == 0 or "episode" not in infos[0]:
            return

        episode = infos[0]["episode"]
        self.episode_count += 1
        ep_reward = float(episode["r"])
        rewards = metrics["episode_rewards"]
        rewards.append(ep_reward)
        if "l" in episode:
            metrics["episode_lengths"].append(int(episode["l"]))

        recent = rewards[-REWARD_WINDOW:]
        metrics["episodes"] = self.episode_count
        metrics["current_episode_reward"] = ep_reward
        metrics["mean_reward"] = float(np.mean(recent))
        metrics["std_reward"] = float(np.std(recent))

        logs = metrics["logs"]
        logs.append(f"[{datetime.now().strftime('%H:%M:%S')}] Episode {self.episode_count}: reward = {ep_reward:.2f}")
        if len(logs) > MAX_LOG_LINES:
            logs.pop(0)


def _safe_filename_component(value: str) -> str:
    """Replace characters that are unsafe in a file name (e.g. the '/' in 'ALE/Pong-v5')."""
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("._")
    return cleaned or "env"


def save_video_from_buffer(job_id: str, env_name="env"):
    """Encode the job's buffered frames to an MP4 under models/ and clear the buffer.

    Returns the video path, or None when the job is unknown, no frames are
    buffered, or encoding fails.
    """
    job = training_jobs.get(job_id)
    if not job or not job["video_buffer"]:
        return None
    # Detach the frames first so MetricsCallback appends to a fresh list, not
    # to the one being encoded.
    frames = job["video_buffer"]
    job["video_buffer"] = []
    video_path = f"models/{_safe_filename_component(env_name)}_replay_{job_id}.mp4"
    try:
        imageio.mimsave(video_path, frames, fps=30)
    except Exception:
        logger.exception("[EXEC] Failed to write replay video for job %s", job_id)
        return None
    return video_path


# --- DYNAMIC EXECUTION ENGINE ---
def run_custom_code(job_id: str, code: str, env_name: str):
    """Execute the user-submitted training script of *job_id* and record the outcome.

    The script sees the names in `local_scope` below. It is expected to pass
    `StreamCallback()` to `.learn()`, save its model to `model_save_path`,
    and optionally set a `results` dict with `mean_reward` / `std_reward`.
    On success the job gets `results` (status "completed", or left as
    "stopped" if the user stopped it); on error, status "failed" and `error`.
    """
    job = training_jobs[job_id]
    logger.info("[EXEC] Starting job %s", job_id)
    job["status"] = "training"
    job["start_time"] = datetime.now()

    class StreamCallback(MetricsCallback):
        """MetricsCallback pre-bound to this job so user code can simply call StreamCallback()."""

        def __init__(self, render_freq=4):
            super().__init__(job_id, render_freq)

    # Names available to the user script.
    local_scope = {
        "gym": gym,
        "PPO": PPO,
        "DQN": DQN,
        "A2C": A2C,
        "evaluate_policy": evaluate_policy,
        "Monitor": Monitor,
        "np": np,
        "StreamCallback": StreamCallback,
        "model_save_path": f"models/model_{job_id}",  # SB3 appends ".zip" on save
    }

    try:
        # SECURITY WARNING: `code` is the raw POST /train request body, executed
        # in this process with full privileges. This is remote code execution
        # by design; never expose this endpoint to untrusted clients without
        # real sandboxing (separate process/container, no secrets, resource limits).
        exec(code, local_scope)
        _record_results(job_id, job, local_scope, env_name)
    except Exception as e:
        logger.error("[EXEC] Error in job %s: %s", job_id, traceback.format_exc())
        job["status"] = "failed"
        job["error"] = str(e)
        job["metrics"]["logs"].append(f"ERROR: {str(e)}")
    finally:
        # Freezes elapsed_time reported by the status endpoints.
        job["end_time"] = datetime.now()


def _record_results(job_id: str, job: Dict[str, Any], scope: Dict[str, Any], env_name: str) -> None:
    """Fill in job["results"] after the user script returned, then mark the job finished."""
    video_path = save_video_from_buffer(job_id, env_name)
    previous = job.get("results") or {}
    if video_path is None:
        # stop_training may already have flushed the buffer into a video; keep it.
        video_path = previous.get("video_path")

    user_results = scope.get("results", {})
    if not isinstance(user_results, dict):
        user_results = {}

    # Results are written before the status flips, so pollers never see
    # "completed" with empty results.
    job["results"] = {
        "mean_reward": user_results.get("mean_reward", 0),
        "std_reward": user_results.get("std_reward", 0),
        # Enforced naming convention (model_save_path + ".zip"); download_model
        # verifies the file actually exists.
        "model_path": f"models/model_{job_id}.zip",
        "video_path": video_path,
        "total_episodes": job["metrics"]["episodes"],
    }
    if job["status"] != "stopped":
        job["status"] = "completed"
        job["metrics"]["progress"] = 100


def _guess_total_timesteps(code: str, default: int = DEFAULT_TOTAL_TIMESTEPS) -> int:
    """Best-effort read of an integer `total_timesteps=` literal in *code* (progress bar only)."""
    match = _TOTAL_TIMESTEPS_RE.search(code)
    if not match:
        return default
    value = int(match.group(1))
    return value if value > 0 else default


def _elapsed_seconds(job: Dict[str, Any]):
    """Seconds since the job started; frozen at end_time once the job has finished."""
    start = job.get("start_time")
    if not start:
        return 0
    end = job.get("end_time") or datetime.now()
    return (end - start).total_seconds()


def _is_canonical_uuid(value: str) -> bool:
    """True if *value* is a UUID in canonical string form (safe to embed in a file name)."""
    try:
        return str(uuid.UUID(value)) == value
    except ValueError:
        return False


@app.post("/train")
def start_training(request: TrainingRequest, background_tasks: BackgroundTasks):
    """Register a new job and run request.code in the background; returns the job id."""
    job_id = str(uuid.uuid4())
    training_jobs[job_id] = {
        "status": "queued",
        "config": {"env_name": request.env_name},  # Kept for compatibility
        "total_timesteps_guess": _guess_total_timesteps(request.code),
        "metrics": {
            "timesteps": 0,
            "episodes": 0,
            "progress": 0,
            "episode_rewards": [],
            "episode_lengths": [],
            "current_episode_reward": 0,
            "mean_reward": 0,
            "std_reward": 0,
            "eval_mean_reward": None,
            "eval_std_reward": None,
            "logs": [],
        },
        "video_buffer": [],
        "results": None,
        "error": None,
        "start_time": None,
        "end_time": None,
    }
    background_tasks.add_task(run_custom_code, job_id, request.code, request.env_name)
    return {"message": "Started", "job_id": job_id}


@app.get("/train/{job_id}/status")
def get_training_status(job_id: str):
    """Get full training status with metrics"""
    job = training_jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    return {
        "status": job["status"],
        "metrics": job["metrics"],
        "elapsed_time": _elapsed_seconds(job),
        "results": job["results"],
        "error": job["error"],
    }


@app.get("/train/{job_id}/metrics")
def get_training_metrics(job_id: str):
    """Lightweight endpoint for polling metrics"""
    job = training_jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    return {
        "status": job["status"],
        "metrics": job["metrics"],
        "elapsed_time": _elapsed_seconds(job),
    }


@app.post("/train/{job_id}/stop")
def stop_training(job_id: str):
    """Stop a training job and save the replay frames captured so far.

    Raises 404 for an unknown job and 400 if the job is not currently training.
    """
    job = training_jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    if job["status"] != "training":
        raise HTTPException(status_code=400, detail="Job is not currently training")

    # MetricsCallback sees this on its next step and returns False, ending learn().
    job["status"] = "stopped"
    # Use the job's env name so the file matches the naming used on completion.
    video_path = save_video_from_buffer(job_id, job["config"].get("env_name", "env"))
    if job["results"] is None:
        job["results"] = {}
    job["results"]["video_path"] = video_path
    job["metrics"]["logs"].append(
        f"[{datetime.now().strftime('%H:%M:%S')}] Training stopped by user"
    )
    return {"message": "Training stopped successfully!"}


@app.get("/download/{job_id}/video")
def download_video(job_id: str):
    """Download the replay MP4 recorded for *job_id*."""
    job = training_jobs.get(job_id)
    if not job or not job.get("results") or not job["results"].get("video_path"):
        raise HTTPException(status_code=404, detail="Video not processed yet")
    path = job["results"]["video_path"]
    if not os.path.exists(path):
        raise HTTPException(status_code=404, detail="Video file not found")
    return FileResponse(path, media_type='video/mp4', filename=f"training_replay_{job_id}.mp4")


@app.get("/download/{job_id}/model")
def download_model(job_id: str):
    """Download model with filesystem fallback.

    Uses the model path recorded in the job's results. If the job is not in
    memory (e.g. after a restart), falls back to the conventional
    models/model_{job_id}.zip path, but only for canonical UUID job ids so
    the id cannot steer the path elsewhere.
    """
    job = training_jobs.get(job_id)
    file_path = None
    if job and job.get("results"):
        file_path = job["results"].get("model_path")
    elif job is None and _is_canonical_uuid(job_id):
        file_path = f"models/model_{job_id}.zip"

    if not file_path or not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="Model file not found")
    return FileResponse(file_path, media_type='application/octet-stream', filename=f"model_{job_id}.zip")


# WebSocket Endpoint
@app.websocket("/ws/render/{job_id}")
async def websocket_render_endpoint(websocket: WebSocket, job_id: str):
    """
    WebSocket endpoint for real-time environment rendering.
    Connect from frontend with: ws://localhost:8000/ws/render/{job_id}

    Client messages: "request_frame" -> latest frame broadcast to the job's
    subscribers; "ping" -> {"type": "pong"} to this client only.
    """
    await manager.connect(job_id, websocket)
    try:
        while True:
            data = await websocket.receive_text()
            if data == "request_frame":
                await manager.broadcast_frame(job_id)
            elif data == "ping":
                await websocket.send_json({"type": "pong"})
    except WebSocketDisconnect:
        manager.disconnect(job_id, websocket)
    except Exception as e:
        logger.error("[WS] WebSocket error for job %s: %s", job_id, e)
        manager.disconnect(job_id, websocket)


@app.get("/debug/jobs")
def debug_jobs():
    """Debug endpoint to list all jobs"""
    return {
        "jobs": [
            {
                "job_id": job_id,
                "status": job["status"],
                "progress": job["metrics"]["progress"],
                "episodes": job["metrics"]["episodes"],
            }
            for job_id, job in training_jobs.items()
        ]
    }