# MARL-Gym / app.py
# Hugging Face Spaces file header (commit d84d915, author bumie-e):
# "Added support for dynamic code execution"
from fastapi import FastAPI, BackgroundTasks, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
import base64
import numpy as np
from collections import deque
from pydantic import BaseModel
from typing import Dict, Any, List, Optional
import uuid
import gymnasium as gym
from stable_baselines3 import PPO, DQN, A2C # Added common algos
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import BaseCallback
from datetime import datetime
import asyncio
import os
import glob
import logging
from io import BytesIO
from PIL import Image
import imageio
import traceback
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
# Wide-open CORS: the browser frontend is served from a different origin.
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True for credentialed requests — confirm credentials
# are actually unused by the frontend.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Saved models and replay videos are written under this directory.
os.makedirs("models", exist_ok=True)
# In-memory storage
# job_id -> mutable job record (status, config, metrics, video buffer, results).
training_jobs: Dict[str, Dict[str, Any]] = {}
class TrainingRequest(BaseModel):
    """Request payload for POST /train."""

    # Gym environment id; used for naming the replay video file.
    env_name: str
    # Raw Python training script executed server-side by run_custom_code.
    code: str  # <--- WE NOW ACCEPT RAW CODE
# --- WEBSOCKET MANAGER (Unchanged) ---
class ConnectionManager:
    """Tracks WebSocket subscribers per training job and buffers the latest
    rendered frame for broadcast.

    ``frames`` keeps only the most recent frame per job (deque maxlen=1) so
    slow consumers never receive a stale backlog.
    """

    def __init__(self):
        # job_id -> list of connected websockets
        self.active_connections: Dict[str, List["WebSocket"]] = {}
        # job_id -> deque holding at most the single latest frame
        self.frames: Dict[str, deque] = {}

    async def connect(self, job_id: str, websocket: "WebSocket"):
        """Accept the socket and register it under `job_id`."""
        await websocket.accept()
        if job_id not in self.active_connections:
            self.active_connections[job_id] = []
            self.frames[job_id] = deque(maxlen=1)
        self.active_connections[job_id].append(websocket)

    def disconnect(self, job_id: str, websocket: "WebSocket"):
        """Deregister a socket; drop job state when the last one leaves."""
        connections = self.active_connections.get(job_id)
        if connections is None:
            return
        # Guard against double-disconnect (the endpoint's error path may
        # call this after WebSocketDisconnect already removed the socket);
        # the original list.remove would raise ValueError here.
        if websocket in connections:
            connections.remove(websocket)
        if not connections:
            del self.active_connections[job_id]
            self.frames.pop(job_id, None)

    def add_frame(self, job_id: str, frame: np.ndarray):
        """Store `frame` as the latest frame for `job_id` (overwrites the previous one)."""
        if job_id not in self.frames:
            self.frames[job_id] = deque(maxlen=1)
        self.frames[job_id].append(frame)

    async def broadcast_frame(self, job_id: str):
        """Encode the latest frame as a base64 JPEG and push it to all subscribers.

        Best-effort: encoding or send failures are swallowed (and logged) so a
        bad frame or dead socket never propagates into the caller.
        """
        if job_id not in self.frames or not self.frames[job_id]:
            return
        frame = self.frames[job_id][-1]
        try:
            if not isinstance(frame, np.ndarray):
                return
            if frame.dtype != np.uint8:
                # Assumes non-uint8 frames are floats in [0, 1] — scale to 8-bit.
                frame = np.clip(frame * 255, 0, 255).astype(np.uint8)
            img = Image.fromarray(frame)
            # Downscale large renders to cap bandwidth per frame.
            max_size = 512
            if img.width > max_size or img.height > max_size:
                ratio = max_size / max(img.width, img.height)
                img = img.resize(
                    (int(img.width * ratio), int(img.height * ratio)),
                    Image.Resampling.LANCZOS,
                )
            buffer = BytesIO()
            img.save(buffer, format='JPEG', quality=85)
            frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
            for connection in self.active_connections.get(job_id, []):
                try:
                    await connection.send_json(
                        {"type": "frame", "job_id": job_id, "data": frame_base64}
                    )
                except Exception:
                    # Dead sockets are cleaned up by the websocket endpoint's
                    # disconnect handling; keep broadcasting to the others.
                    pass
        except Exception:
            logger.exception("Failed to broadcast frame for job %s", job_id)
# Singleton shared by the training callback and the websocket endpoint.
manager = ConnectionManager()
# --- CALLBACK (Modified for Generic Use) ---
class MetricsCallback(BaseCallback):
    """SB3 callback that mirrors training progress into ``training_jobs`` and
    streams rendered frames through the websocket ``manager``.

    Returning ``False`` from ``_on_step`` asks stable-baselines3 to abort
    training — that is how the /stop endpoint takes effect.
    """

    def __init__(self, job_id: str, render_freq: int = 4):
        super().__init__()
        self.job_id = job_id
        self.episode_count = 0
        # Render every `render_freq` environment steps.
        self.render_freq = render_freq

    def _on_step(self) -> bool:
        job = training_jobs.get(self.job_id)
        # Abort if the job record vanished or the user requested a stop.
        if not job or job["status"] == "stopped":
            return False
        self._update_progress(job)
        if self.num_timesteps % self.render_freq == 0:
            self._capture_frame(job)
        self._record_episode(job)
        return True

    def _update_progress(self, job: Dict[str, Any]) -> None:
        """Refresh the timestep counter and the coarse progress percentage."""
        job["metrics"]["timesteps"] = self.num_timesteps
        # total_timesteps is only a guess parsed from the user script,
        # so clamp progress at 100.
        total = job.get("total_timesteps_guess", 100000)
        job["metrics"]["progress"] = min(100, int((self.num_timesteps / total) * 100))

    def _capture_frame(self, job: Dict[str, Any]) -> None:
        """Grab a render frame for live streaming and the replay video."""
        try:
            frame = self.model.get_env().render()
            if frame is not None and isinstance(frame, np.ndarray):
                manager.add_frame(self.job_id, frame)
                # Cap the replay buffer so memory stays bounded.
                if len(job["video_buffer"]) < 2000:
                    job["video_buffer"].append(frame)
        except Exception:
            # Rendering is best-effort; never let it break training.
            logger.debug("Render failed for job %s", self.job_id, exc_info=True)

    def _record_episode(self, job: Dict[str, Any]) -> None:
        """On episode end (first vec-env slot only), record reward statistics."""
        if not self.locals.get("dones", [False])[0]:
            return
        infos = self.locals.get("infos") or []
        if not infos or "episode" not in infos[0]:
            return
        ep_reward = float(infos[0]["episode"]["r"])
        self.episode_count += 1
        metrics = job["metrics"]
        metrics["episodes"] = self.episode_count
        metrics["episode_rewards"].append(ep_reward)
        metrics["current_episode_reward"] = ep_reward
        # Running stats over the last 100 episodes.
        recent = metrics["episode_rewards"][-100:]
        metrics["mean_reward"] = float(np.mean(recent))
        metrics["std_reward"] = float(np.std(recent))
        log_entry = f"[{datetime.now().strftime('%H:%M:%S')}] Episode {self.episode_count}: reward = {ep_reward:.2f}"
        metrics["logs"].append(log_entry)
        # Keep only the most recent 100 log lines.
        if len(metrics["logs"]) > 100:
            metrics["logs"].pop(0)
def save_video_from_buffer(job_id: str, env_name: str = "env") -> Optional[str]:
    """Encode the job's buffered frames into an MP4 replay.

    Returns the video path on success, or None when there is nothing to save
    or encoding fails. Best-effort by design: a broken video must never fail
    the job, so failures are logged and swallowed. The frame buffer is
    cleared only on success, leaving a later call free to retry.
    """
    job = training_jobs.get(job_id)
    if not job or not job["video_buffer"]:
        return None
    video_path = f"models/{env_name}_replay_{job_id}.mp4"
    try:
        imageio.mimsave(video_path, job['video_buffer'], fps=30)
    except Exception:
        logger.exception("Could not encode replay video for job %s", job_id)
        return None
    # Release the (potentially large) frame buffer once it is on disk.
    job["video_buffer"] = []
    return video_path
# --- DYNAMIC EXECUTION ENGINE ---
def run_custom_code(job_id: str, code: str, env_name: str):
    """Execute a user-submitted training script in-process (background task).

    The script runs against a curated namespace (gym, SB3 algorithms, numpy)
    plus two injected hooks:
      * ``StreamCallback`` — a MetricsCallback pre-bound to this job; the
        script should pass an instance to ``model.learn(callback=...)`` so
        metrics and frames stream out.
      * ``model_save_path`` — the path the script is expected to save its
        model to (``models/model_{job_id}``).

    WARNING: ``exec`` of untrusted code is remote code execution. This is
    acceptable only in a sandboxed demo environment.
    """
    logger.info(f"[EXEC] Starting job {job_id}")
    training_jobs[job_id]["status"] = "training"
    training_jobs[job_id]["start_time"] = datetime.now()

    # Callback class pre-bound to this job so user code can simply call
    # StreamCallback() without knowing the job id.
    class StreamCallback(MetricsCallback):
        def __init__(self, render_freq=4):
            super().__init__(job_id, render_freq)

    # Names made available to the user script.
    local_scope = {
        "gym": gym,
        "PPO": PPO,
        "DQN": DQN,
        "A2C": A2C,
        "evaluate_policy": evaluate_policy,
        "Monitor": Monitor,
        "np": np,
        "StreamCallback": StreamCallback,  # <--- CRITICAL INJECTION
        "model_save_path": f"models/model_{job_id}",  # User should use this path
    }
    try:
        # EXECUTE USER CODE.
        # A single dict serves as both globals and locals, so top-level
        # names the script defines (e.g. `results`) land in local_scope.
        # WARNING: This is dangerous in production (RCE).
        exec(code, local_scope)

        # Persist whatever frames were buffered during training.
        video_path = save_video_from_buffer(job_id, env_name)

        # We enforce the models/model_{job_id}.zip naming convention, but
        # only report the path if the script actually saved a model there
        # (the download endpoint treats None as "not found").
        expected_model_path = f"models/model_{job_id}.zip"
        final_model_path = expected_model_path if os.path.exists(expected_model_path) else None

        # Scripts may publish evaluation numbers via a top-level `results` dict.
        user_results = local_scope.get("results", {})
        training_jobs[job_id]["status"] = "completed"
        training_jobs[job_id]["results"] = {
            "mean_reward": user_results.get("mean_reward", 0),
            "std_reward": user_results.get("std_reward", 0),
            "model_path": final_model_path,
            "video_path": video_path,
            "total_episodes": training_jobs[job_id]["metrics"]["episodes"],
        }
        training_jobs[job_id]["metrics"]["progress"] = 100
    except Exception as e:
        error_msg = traceback.format_exc()
        logger.error(f"[EXEC] Error in job {job_id}: {error_msg}")
        training_jobs[job_id]["status"] = "failed"
        training_jobs[job_id]["error"] = str(e)
        training_jobs[job_id]["metrics"]["logs"].append(f"ERROR: {str(e)}")
def _guess_total_timesteps(code: str, default: int = 100000) -> int:
    """Best-effort extraction of `total_timesteps=<int>` from a user script.

    Deliberately naive string parsing (no AST) so it tolerates arbitrary or
    broken code; on any failure the default is returned and only the progress
    bar is affected.
    """
    if "total_timesteps=" not in code:
        return default
    try:
        part = code.split("total_timesteps=")[1].split(")")[0].split(",")[0]
        return int(part)
    except (ValueError, IndexError):
        # Not an integer literal (e.g. a variable name): keep the default.
        return default


@app.post("/train")
def start_training(request: TrainingRequest, background_tasks: BackgroundTasks):
    """Queue a training job that executes the submitted script in the background.

    Returns the job_id used by the status/metrics/stop/download endpoints.
    """
    job_id = str(uuid.uuid4())
    training_jobs[job_id] = {
        "status": "queued",
        "config": {"env_name": request.env_name},  # Kept for compatibility
        "total_timesteps_guess": _guess_total_timesteps(request.code),
        "metrics": {
            "timesteps": 0, "episodes": 0, "progress": 0,
            "episode_rewards": [], "episode_lengths": [],
            "current_episode_reward": 0, "mean_reward": 0, "std_reward": 0,
            "eval_mean_reward": None, "eval_std_reward": None, "logs": [],
        },
        "video_buffer": [],
        "results": None, "error": None, "start_time": None,
    }
    background_tasks.add_task(run_custom_code, job_id, request.code, request.env_name)
    return {"message": "Started", "job_id": job_id}
@app.get("/train/{job_id}/status")
def get_training_status(job_id: str):
    """Get full training status with metrics"""
    job = training_jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    started = job.get("start_time")
    # elapsed_time is 0 until the background task actually starts.
    elapsed_time = (datetime.now() - started).total_seconds() if started else 0
    return {
        "status": job["status"],
        "metrics": job["metrics"],
        "elapsed_time": elapsed_time,
        "results": job["results"],
        "error": job["error"],
    }
@app.get("/train/{job_id}/metrics")
def get_training_metrics(job_id: str):
    """Lightweight endpoint for polling metrics"""
    job = training_jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    started = job.get("start_time")
    elapsed = 0 if not started else (datetime.now() - started).total_seconds()
    # Slim payload: same as /status but without results and error.
    payload = {
        "status": job["status"],
        "metrics": job["metrics"],
        "elapsed_time": elapsed,
    }
    return payload
@app.post("/train/{job_id}/stop")
def stop_training(job_id: str):
    """Request an early stop for a running job.

    Flipping the status to "stopped" makes MetricsCallback._on_step return
    False, which asks stable-baselines3 to halt training. Raises 404 for an
    unknown job and 400 when the job is not currently training.
    """
    job = training_jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    if job["status"] != "training":
        raise HTTPException(status_code=400, detail="Job is not currently training")

    job["status"] = "stopped"
    # Salvage whatever frames were captured before the stop. Use the job's
    # real env name so the file matches the completed-run naming scheme
    # (previously this fell back to the generic "env" prefix).
    env_name = job.get("config", {}).get("env_name", "env")
    video_path = save_video_from_buffer(job_id, env_name)
    if job["results"] is None:
        job["results"] = {}
    job["results"]["video_path"] = video_path
    job["metrics"]["logs"].append(
        f"[{datetime.now().strftime('%H:%M:%S')}] Training stopped by user"
    )
    return {"message": "Training stopped successfully!"}
@app.get("/download/{job_id}/video")
def download_video(job_id: str):
    """Stream the replay video for a job, if one has been rendered."""
    job = training_jobs.get(job_id)
    results = job.get("results") if job else None
    video_path = results.get("video_path") if results else None
    if not video_path:
        raise HTTPException(status_code=404, detail="Video not processed yet")
    if not os.path.exists(video_path):
        raise HTTPException(status_code=404, detail="Video file not found")
    return FileResponse(video_path, media_type='video/mp4', filename=f"training_replay_{job_id}.mp4")
@app.get("/download/{job_id}/model")
def download_model(job_id: str):
    """Download model with filesystem fallback"""
    # Look up the recorded model path for this job, tolerating missing
    # job records and missing results alike.
    results = (training_jobs.get(job_id) or {}).get("results") or {}
    file_path = results.get("model_path")
    # The path is only advisory — verify the file is actually on disk.
    if not file_path or not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="Model file not found")
    return FileResponse(file_path, media_type='application/octet-stream', filename=f"model_{job_id}.zip")
# WebSocket Endpoint
@app.websocket("/ws/render/{job_id}")
async def websocket_render_endpoint(websocket: WebSocket, job_id: str):
    """
    WebSocket endpoint for real-time environment rendering.
    Connect from frontend with: ws://localhost:8000/ws/render/{job_id}
    """
    await manager.connect(job_id, websocket)
    try:
        # Loop forever; exit happens via an exception on disconnect.
        while True:
            message = await websocket.receive_text()
            if message == "ping":
                await websocket.send_json({"type": "pong"})
            elif message == "request_frame":
                await manager.broadcast_frame(job_id)
    except WebSocketDisconnect:
        manager.disconnect(job_id, websocket)
    except Exception as e:
        logger.error(f"[WS] WebSocket error for job {job_id}: {e}")
        manager.disconnect(job_id, websocket)
@app.get("/debug/jobs")
def debug_jobs():
    """Debug endpoint to list all jobs"""
    summaries = []
    for job_id, job in training_jobs.items():
        metrics = job["metrics"]
        summaries.append(
            {
                "job_id": job_id,
                "status": job["status"],
                "progress": metrics["progress"],
                "episodes": metrics["episodes"],
            }
        )
    return {"jobs": summaries}