bumie-e committed on
Commit
d84d915
·
1 Parent(s): de3f783

Added support for dynamic code execution

Browse files
Files changed (1) hide show
  1. app.py +105 -251
app.py CHANGED
@@ -1,4 +1,5 @@
1
  from fastapi import FastAPI, BackgroundTasks, HTTPException, WebSocket, WebSocketDisconnect
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
  import base64
4
  import numpy as np
@@ -7,20 +8,19 @@ from pydantic import BaseModel
7
  from typing import Dict, Any, List, Optional
8
  import uuid
9
  import gymnasium as gym
10
- from stable_baselines3 import PPO
11
  from stable_baselines3.common.monitor import Monitor
12
  from stable_baselines3.common.evaluation import evaluate_policy
13
  from stable_baselines3.common.callbacks import BaseCallback
14
  from datetime import datetime
15
  import asyncio
16
  import os
 
17
  import logging
18
  from io import BytesIO
19
  from PIL import Image
20
-
21
- # Add to imports in app.py
22
- from fastapi.responses import FileResponse
23
  import imageio
 
24
 
25
  # Configure logging
26
  logging.basicConfig(level=logging.INFO)
@@ -28,7 +28,6 @@ logger = logging.getLogger(__name__)
28
 
29
  app = FastAPI()
30
 
31
- # Add CORS middleware
32
  app.add_middleware(
33
  CORSMiddleware,
34
  allow_origins=["*"],
@@ -37,19 +36,17 @@ app.add_middleware(
37
  allow_headers=["*"],
38
  )
39
 
40
- # In-memory storage for training jobs
 
 
41
  training_jobs: Dict[str, Dict[str, Any]] = {}
42
 
43
- class TrainingJob(BaseModel):
44
- env_name: str = "CartPole-v1"
45
- total_timesteps: int = 100000
46
- learning_rate: float = 0.001
47
- n_steps: int = 2048
48
- batch_size: int = 64
49
- n_epochs: int = 10
50
 
 
51
  class ConnectionManager:
52
- """Manages WebSocket connections and frame broadcasting"""
53
  def __init__(self):
54
  self.active_connections: Dict[str, List[WebSocket]] = {}
55
  self.frames: Dict[str, deque] = {}
@@ -60,7 +57,6 @@ class ConnectionManager:
60
  self.active_connections[job_id] = []
61
  self.frames[job_id] = deque(maxlen=1)
62
  self.active_connections[job_id].append(websocket)
63
- logger.info(f"[WS] Client connected to job {job_id}")
64
 
65
  def disconnect(self, job_id: str, websocket: WebSocket):
66
  if job_id in self.active_connections:
@@ -69,328 +65,186 @@ class ConnectionManager:
69
  del self.active_connections[job_id]
70
  if job_id in self.frames:
71
  del self.frames[job_id]
72
- logger.info(f"[WS] Client disconnected from job {job_id}")
73
 
74
  def add_frame(self, job_id: str, frame: np.ndarray):
75
- """Store the latest frame for this job"""
76
  if job_id not in self.frames:
77
  self.frames[job_id] = deque(maxlen=1)
78
  self.frames[job_id].append(frame)
79
 
80
  async def broadcast_frame(self, job_id: str):
81
- """Broadcast the latest frame to all connected clients"""
82
- if job_id not in self.frames or not self.frames[job_id]:
83
- return
84
-
85
  frame = self.frames[job_id][-1]
86
-
87
  try:
88
- # Convert numpy array to PIL Image for encoding
89
  if isinstance(frame, np.ndarray):
90
- # Handle different frame formats
91
- if frame.dtype != np.uint8:
92
- frame = np.clip(frame * 255, 0, 255).astype(np.uint8)
93
-
94
- # Convert BGR to RGB if needed
95
- if len(frame.shape) == 3 and frame.shape[2] == 3:
96
- # Assuming BGR from gym, convert to RGB
97
- frame = frame[:, :, ::-1] # BGR to RGB
98
-
99
  img = Image.fromarray(frame)
100
- else:
101
- logger.error(f"[ENCODE] Unexpected frame type: {type(frame)}")
102
- return
103
 
104
- # Resize for efficient transmission
105
  max_size = 512
106
  if img.width > max_size or img.height > max_size:
107
  ratio = max_size / max(img.width, img.height)
108
- new_size = (int(img.width * ratio), int(img.height * ratio))
109
- img = img.resize(new_size, Image.Resampling.LANCZOS)
110
 
111
- # Encode to JPEG
112
  buffer = BytesIO()
113
  img.save(buffer, format='JPEG', quality=85)
114
- frame_bytes = buffer.getvalue()
115
- frame_base64 = base64.b64encode(frame_bytes).decode('utf-8')
116
 
117
- # Broadcast to all connected clients
118
  if job_id in self.active_connections:
119
- disconnected = []
120
  for connection in self.active_connections[job_id]:
121
- try:
122
- await connection.send_json({
123
- "type": "frame",
124
- "job_id": job_id,
125
- "data": frame_base64,
126
- "timestamp": datetime.now().isoformat()
127
- })
128
- except Exception as e:
129
- logger.error(f"[WS] Failed to send frame: {e}")
130
- disconnected.append(connection)
131
-
132
- # Remove disconnected clients
133
- for conn in disconnected:
134
- self.disconnect(job_id, conn)
135
-
136
- except Exception as e:
137
- logger.error(f"[ENCODE] Failed to encode frame: {e}")
138
 
139
  manager = ConnectionManager()
140
 
 
141
  class MetricsCallback(BaseCallback):
142
- """Custom callback to track training metrics in real-time"""
143
- def __init__(self, job_id: str, render_freq: int = 5):
144
  super().__init__()
145
  self.job_id = job_id
146
  self.episode_count = 0
147
- self.step_count = 0
148
  self.render_freq = render_freq
149
 
150
  def _on_step(self) -> bool:
151
  job = training_jobs.get(self.job_id)
152
- # FIX: Check if job exists OR if status is marked as stopped
153
- if not job or job["status"] == "stopped":
154
- logger.info(f"[CALLBACK] Stopping job {self.job_id}")
155
- return False # Returning False in SB3 stops the training immediately
156
-
157
- # if not job:
158
- # return False
159
-
160
- self.step_count += 1
161
-
162
- # Update timestep count
163
  job["metrics"]["timesteps"] = self.num_timesteps
164
- job["metrics"]["progress"] = int(
165
- (self.num_timesteps / job["config"]["total_timesteps"]) * 100
166
- )
167
 
168
- # Render frame periodically
169
- # if self.step_count % self.render_freq == 0:
170
- # try:
171
- # frame = self.model.get_env().render()
172
- # if frame is not None and isinstance(frame, np.ndarray):
173
- # manager.add_frame(self.job_id, frame)
174
- # except Exception as e:
175
- # logger.debug(f"[RENDER] Render not available: {e}")
176
-
177
- # RENDER & RECORD
178
- # We process frames at the render frequency
179
- if self.step_count % self.render_freq == 0:
180
  try:
181
- # Capture frame
182
  frame = self.model.get_env().render()
183
-
184
  if frame is not None and isinstance(frame, np.ndarray):
185
- # 1. Send to WebSocket for live view
186
  manager.add_frame(self.job_id, frame)
187
-
188
- # 2. Store in memory for video download
189
- # We skip every other captured frame to keep video file size manageable
190
- # (Capturing effectively at render_freq * 2)
191
- if len(job["video_buffer"]) < 2000: # Safety cap: max 2000 frames to prevent RAM overflow
192
- job["video_buffer"].append(frame)
193
- except Exception as e:
194
- logger.debug(f"[RENDER] Render error: {e}")
195
 
196
- # Check for episode completion
197
  if self.locals.get("dones", [False])[0]:
198
  if "infos" in self.locals and len(self.locals["infos"]) > 0:
199
  info = self.locals["infos"][0]
200
  if "episode" in info:
201
  self.episode_count += 1
202
  ep_reward = float(info["episode"]["r"])
203
- ep_length = int(info["episode"]["l"])
204
-
205
  job["metrics"]["episodes"] = self.episode_count
206
  job["metrics"]["episode_rewards"].append(ep_reward)
207
- job["metrics"]["episode_lengths"].append(ep_length)
208
  job["metrics"]["current_episode_reward"] = ep_reward
209
 
210
- # Calculate running average
211
  if len(job["metrics"]["episode_rewards"]) > 0:
212
- job["metrics"]["mean_reward"] = float(
213
- np.mean(job["metrics"]["episode_rewards"][-100:])
214
- )
215
- job["metrics"]["std_reward"] = float(
216
- np.std(job["metrics"]["episode_rewards"][-100:])
217
- )
218
 
219
- # Add log entry
220
- log_entry = f"[{datetime.now().strftime('%H:%M:%S')}] Episode {self.episode_count}: reward = {ep_reward:.2f}, length = {ep_length}"
221
  job["metrics"]["logs"].append(log_entry)
222
- if len(job["metrics"]["logs"]) > 100:
223
- job["metrics"]["logs"].pop(0)
224
-
225
  return True
226
- def save_video_from_buffer(job_id: str):
227
- """Helper to compile stored frames into MP4"""
228
  job = training_jobs.get(job_id)
229
- if not job or not job["video_buffer"]:
230
- return None
231
-
232
  try:
233
- env_name = job["config"]["env_name"]
234
  video_path = f"models/{env_name}_replay_{job_id}.mp4"
235
- os.makedirs("models", exist_ok=True)
236
-
237
- # Save video at 30 FPS
238
- logger.info(f"[VIDEO] Saving {len(job['video_buffer'])} frames to {video_path}")
239
  imageio.mimsave(video_path, job['video_buffer'], fps=30)
240
-
241
- # Clear buffer to free memory
242
  job["video_buffer"] = []
243
  return video_path
244
- except Exception as e:
245
- logger.error(f"[VIDEO] Failed to save video: {e}")
246
- return None
247
-
248
- def run_training(job_id: str, config: Dict[str, Any]):
249
- """Run the RL training loop with rendering"""
250
- logger.info(f"[TRAIN] Starting training for job {job_id}")
251
  training_jobs[job_id]["status"] = "training"
252
  training_jobs[job_id]["start_time"] = datetime.now()
253
 
254
- env = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  try:
256
- env_name = config.get("env_name", "CartPole-v1")
257
- total_timesteps = config.get("total_timesteps", 100000)
258
- learning_rate = config.get("learning_rate", 0.001)
259
- n_steps = config.get("n_steps", 2048)
260
- batch_size = config.get("batch_size", 64)
261
- n_epochs = config.get("n_epochs", 10)
262
 
263
- # Initialize environment with rgb_array rendering
264
- logger.info(f"[TRAIN] Creating environment: {env_name}")
265
- env = gym.make(env_name, render_mode='rgb_array')
266
- env = Monitor(env)
267
 
268
- # Initialize model
269
- logger.info(f"[TRAIN] Creating PPO model")
270
- model = PPO(
271
- "MlpPolicy",
272
- env,
273
- verbose=0,
274
- learning_rate=learning_rate,
275
- n_steps=n_steps,
276
- batch_size=batch_size,
277
- n_epochs=n_epochs,
278
- )
279
 
280
- # Add initial logs
281
- training_jobs[job_id]["metrics"]["logs"].append(
282
- f"[{datetime.now().strftime('%H:%M:%S')}] Environment: {env_name}"
283
- )
284
- training_jobs[job_id]["metrics"]["logs"].append(
285
- f"[{datetime.now().strftime('%H:%M:%S')}] Total timesteps: {total_timesteps:,}"
286
- )
287
- training_jobs[job_id]["metrics"]["logs"].append(
288
- f"[{datetime.now().strftime('%H:%M:%S')}] Starting training..."
289
- )
290
-
291
- # Train with callback
292
- logger.info(f"[TRAIN] Starting learning loop")
293
- model.learn(
294
- total_timesteps=total_timesteps,
295
- callback=MetricsCallback(job_id, render_freq=5),
296
- )
297
-
298
- # Evaluate
299
- logger.info(f"[TRAIN] Evaluating model")
300
- training_jobs[job_id]["metrics"]["logs"].append(
301
- f"[{datetime.now().strftime('%H:%M:%S')}] Training completed! Evaluating..."
302
- )
303
- mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
304
- training_jobs[job_id]["metrics"]["eval_mean_reward"] = float(mean_reward)
305
- training_jobs[job_id]["metrics"]["eval_std_reward"] = float(std_reward)
306
-
307
- # Save model
308
- model_path = f"models/{env_name}_ppo_{job_id}"
309
- os.makedirs("models", exist_ok=True)
310
- model.save(model_path)
311
-
312
- # --- NEW: SAVE VIDEO FROM BUFFER ---
313
- video_path = save_video_from_buffer(job_id)
314
- # -----------------------------------
315
- training_jobs[job_id]["metrics"]["logs"].append(
316
- f"[{datetime.now().strftime('%H:%M:%S')}] Model saved!"
317
- )
318
 
319
- # Store results
320
  training_jobs[job_id]["status"] = "completed"
321
  training_jobs[job_id]["results"] = {
322
- "mean_reward": mean_reward,
323
- "std_reward": std_reward,
324
- "model_path": f"{model_path}.zip",
325
- "video_path": video_path, # Add this to results
326
  "total_episodes": training_jobs[job_id]["metrics"]["episodes"],
327
- "total_timesteps": total_timesteps,
328
  }
329
  training_jobs[job_id]["metrics"]["progress"] = 100
330
 
331
- logger.info(f"[TRAIN] Training completed for job {job_id}")
332
-
333
  except Exception as e:
 
 
334
  training_jobs[job_id]["status"] = "failed"
335
  training_jobs[job_id]["error"] = str(e)
336
- training_jobs[job_id]["metrics"]["logs"].append(
337
- f"[{datetime.now().strftime('%H:%M:%S')}] ERROR: {str(e)}"
338
- )
339
- logger.error(f"[TRAIN] Training failed for job {job_id}: {e}", exc_info=True)
340
-
341
- finally:
342
- if env:
343
- try:
344
- env.close()
345
- except:
346
- pass
347
-
348
- # REST Endpoints
349
-
350
- @app.get("/")
351
- def read_root():
352
- return {"message": "Welcome to the RL Training API!"}
353
 
354
  @app.post("/train")
355
- def start_training(job: TrainingJob, background_tasks: BackgroundTasks):
356
- """Start a new training job"""
357
  job_id = str(uuid.uuid4())
358
 
 
 
 
 
 
 
 
 
 
359
  training_jobs[job_id] = {
360
  "status": "queued",
361
- "config": {
362
- "env_name": job.env_name,
363
- "total_timesteps": job.total_timesteps,
364
- "learning_rate": job.learning_rate,
365
- "n_steps": job.n_steps,
366
- "batch_size": job.batch_size,
367
- "n_epochs": job.n_epochs,
368
- },
369
  "metrics": {
370
- "timesteps": 0,
371
- "episodes": 0,
372
- "progress": 0,
373
- "episode_rewards": [],
374
- "episode_lengths": [],
375
- "current_episode_reward": 0,
376
- "mean_reward": 0,
377
- "std_reward": 0,
378
- "eval_mean_reward": None,
379
- "eval_std_reward": None,
380
- "logs": [],
381
  },
382
- "video_buffer": [], # <--- NEW: Initialize empty buffer
383
- "results": None,
384
- "error": None,
385
- "start_time": None,
386
  }
387
 
388
- background_tasks.add_task(run_training, job_id, training_jobs[job_id]["config"])
389
-
390
- return {
391
- "message": "Training job started successfully!",
392
- "job_id": job_id,
393
- }
394
 
395
  @app.get("/train/{job_id}/status")
396
  def get_training_status(job_id: str):
 
1
  from fastapi import FastAPI, BackgroundTasks, HTTPException, WebSocket, WebSocketDisconnect
2
+ from fastapi.responses import FileResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  import base64
5
  import numpy as np
 
8
  from typing import Dict, Any, List, Optional
9
  import uuid
10
  import gymnasium as gym
11
+ from stable_baselines3 import PPO, DQN, A2C # Added common algos
12
  from stable_baselines3.common.monitor import Monitor
13
  from stable_baselines3.common.evaluation import evaluate_policy
14
  from stable_baselines3.common.callbacks import BaseCallback
15
  from datetime import datetime
16
  import asyncio
17
  import os
18
+ import glob
19
  import logging
20
  from io import BytesIO
21
  from PIL import Image
 
 
 
22
  import imageio
23
+ import traceback
24
 
25
  # Configure logging
26
  logging.basicConfig(level=logging.INFO)
 
28
 
29
  app = FastAPI()
30
 
 
31
  app.add_middleware(
32
  CORSMiddleware,
33
  allow_origins=["*"],
 
36
  allow_headers=["*"],
37
  )
38
 
39
+ os.makedirs("models", exist_ok=True)
40
+
41
+ # In-memory storage
42
  training_jobs: Dict[str, Dict[str, Any]] = {}
43
 
44
+ class TrainingRequest(BaseModel):
45
+ env_name: str
46
+ code: str # <--- WE NOW ACCEPT RAW CODE
 
 
 
 
47
 
48
+ # --- WEBSOCKET MANAGER (Unchanged) ---
49
  class ConnectionManager:
 
50
  def __init__(self):
51
  self.active_connections: Dict[str, List[WebSocket]] = {}
52
  self.frames: Dict[str, deque] = {}
 
57
  self.active_connections[job_id] = []
58
  self.frames[job_id] = deque(maxlen=1)
59
  self.active_connections[job_id].append(websocket)
 
60
 
61
  def disconnect(self, job_id: str, websocket: WebSocket):
62
  if job_id in self.active_connections:
 
65
  del self.active_connections[job_id]
66
  if job_id in self.frames:
67
  del self.frames[job_id]
 
68
 
69
  def add_frame(self, job_id: str, frame: np.ndarray):
 
70
  if job_id not in self.frames:
71
  self.frames[job_id] = deque(maxlen=1)
72
  self.frames[job_id].append(frame)
73
 
74
  async def broadcast_frame(self, job_id: str):
75
+ if job_id not in self.frames or not self.frames[job_id]: return
 
 
 
76
  frame = self.frames[job_id][-1]
 
77
  try:
 
78
  if isinstance(frame, np.ndarray):
79
+ if frame.dtype != np.uint8: frame = np.clip(frame * 255, 0, 255).astype(np.uint8)
 
 
 
 
 
 
 
 
80
  img = Image.fromarray(frame)
81
+ else: return
 
 
82
 
 
83
  max_size = 512
84
  if img.width > max_size or img.height > max_size:
85
  ratio = max_size / max(img.width, img.height)
86
+ img = img.resize((int(img.width * ratio), int(img.height * ratio)), Image.Resampling.LANCZOS)
 
87
 
 
88
  buffer = BytesIO()
89
  img.save(buffer, format='JPEG', quality=85)
90
+ frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
 
91
 
 
92
  if job_id in self.active_connections:
 
93
  for connection in self.active_connections[job_id]:
94
+ try: await connection.send_json({"type": "frame", "job_id": job_id, "data": frame_base64})
95
+ except: pass
96
+ except Exception: pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  manager = ConnectionManager()
99
 
100
+ # --- CALLBACK (Modified for Generic Use) ---
101
  class MetricsCallback(BaseCallback):
102
+ def __init__(self, job_id: str, render_freq: int = 4):
 
103
  super().__init__()
104
  self.job_id = job_id
105
  self.episode_count = 0
 
106
  self.render_freq = render_freq
107
 
108
  def _on_step(self) -> bool:
109
  job = training_jobs.get(self.job_id)
110
+ if not job or job["status"] == "stopped": return False
111
+
112
+ # Update metrics
 
 
 
 
 
 
 
 
113
  job["metrics"]["timesteps"] = self.num_timesteps
 
 
 
114
 
115
+ # We try to guess total timesteps if user set it, otherwise just show progress
116
+ total = job.get("total_timesteps_guess", 100000)
117
+ job["metrics"]["progress"] = min(100, int((self.num_timesteps / total) * 100))
118
+
119
+ # Render
120
+ if self.num_timesteps % self.render_freq == 0:
 
 
 
 
 
 
121
  try:
 
122
  frame = self.model.get_env().render()
 
123
  if frame is not None and isinstance(frame, np.ndarray):
 
124
  manager.add_frame(self.job_id, frame)
125
+ if len(job["video_buffer"]) < 2000: job["video_buffer"].append(frame)
126
+ except: pass
 
 
 
 
 
 
127
 
128
+ # Episode handling
129
  if self.locals.get("dones", [False])[0]:
130
  if "infos" in self.locals and len(self.locals["infos"]) > 0:
131
  info = self.locals["infos"][0]
132
  if "episode" in info:
133
  self.episode_count += 1
134
  ep_reward = float(info["episode"]["r"])
 
 
135
  job["metrics"]["episodes"] = self.episode_count
136
  job["metrics"]["episode_rewards"].append(ep_reward)
 
137
  job["metrics"]["current_episode_reward"] = ep_reward
138
 
 
139
  if len(job["metrics"]["episode_rewards"]) > 0:
140
+ job["metrics"]["mean_reward"] = float(np.mean(job["metrics"]["episode_rewards"][-100:]))
141
+ job["metrics"]["std_reward"] = float(np.std(job["metrics"]["episode_rewards"][-100:]))
 
 
 
 
142
 
143
+ log_entry = f"[{datetime.now().strftime('%H:%M:%S')}] Episode {self.episode_count}: reward = {ep_reward:.2f}"
 
144
  job["metrics"]["logs"].append(log_entry)
145
+ if len(job["metrics"]["logs"]) > 100: job["metrics"]["logs"].pop(0)
 
 
146
  return True
147
+
148
+ def save_video_from_buffer(job_id: str, env_name="env"):
149
  job = training_jobs.get(job_id)
150
+ if not job or not job["video_buffer"]: return None
 
 
151
  try:
 
152
  video_path = f"models/{env_name}_replay_{job_id}.mp4"
 
 
 
 
153
  imageio.mimsave(video_path, job['video_buffer'], fps=30)
 
 
154
  job["video_buffer"] = []
155
  return video_path
156
+ except: return None
157
+
158
+ # --- DYNAMIC EXECUTION ENGINE ---
159
+ def run_custom_code(job_id: str, code: str, env_name: str):
160
+ logger.info(f"[EXEC] Starting job {job_id}")
 
 
161
  training_jobs[job_id]["status"] = "training"
162
  training_jobs[job_id]["start_time"] = datetime.now()
163
 
164
+ # 1. Define a specific Callback class for THIS job
165
+ # The user code will simply call `StreamCallback()`
166
+ class StreamCallback(MetricsCallback):
167
+ def __init__(self, render_freq=4):
168
+ super().__init__(job_id, render_freq)
169
+
170
+ # 2. Setup the execution scope (Variables available to user script)
171
+ # We inject 'StreamCallback' so the user can pass it to .learn()
172
+ local_scope = {
173
+ "gym": gym,
174
+ "PPO": PPO,
175
+ "DQN": DQN,
176
+ "A2C": A2C,
177
+ "evaluate_policy": evaluate_policy,
178
+ "Monitor": Monitor,
179
+ "np": np,
180
+ "StreamCallback": StreamCallback, # <--- CRITICAL INJECTION
181
+ "model_save_path": f"models/model_{job_id}", # User should use this path
182
+ }
183
+
184
  try:
185
+ # 3. EXECUTE USER CODE
186
+ # WARNING: This is dangerous in production (RCE).
187
+ exec(code, local_scope)
 
 
 
188
 
189
+ # 4. Post-Execution Cleanup
190
+ # We look for variables the user might have set in local_scope to save results
 
 
191
 
192
+ # Save video
193
+ video_path = save_video_from_buffer(job_id, env_name)
 
 
 
 
 
 
 
 
 
194
 
195
+ # Check if model file exists (User should have used model_save_path)
196
+ expected_model_path = f"models/model_{job_id}.zip"
197
+ # final_model_path = expected_model_path if os.path.exists(expected_model_path) else None
198
+
199
+ # Check if user put results in a 'results' variable
200
+ user_results = local_scope.get("results", {})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
 
202
  training_jobs[job_id]["status"] = "completed"
203
  training_jobs[job_id]["results"] = {
204
+ "mean_reward": user_results.get("mean_reward", 0),
205
+ "std_reward": user_results.get("std_reward", 0),
206
+ "model_path": expected_model_path, # We enforce this naming convention
207
+ "video_path": video_path,
208
  "total_episodes": training_jobs[job_id]["metrics"]["episodes"],
 
209
  }
210
  training_jobs[job_id]["metrics"]["progress"] = 100
211
 
 
 
212
  except Exception as e:
213
+ error_msg = traceback.format_exc()
214
+ logger.error(f"[EXEC] Error in job {job_id}: {error_msg}")
215
  training_jobs[job_id]["status"] = "failed"
216
  training_jobs[job_id]["error"] = str(e)
217
+ training_jobs[job_id]["metrics"]["logs"].append(f"ERROR: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
  @app.post("/train")
220
+ def start_training(request: TrainingRequest, background_tasks: BackgroundTasks):
 
221
  job_id = str(uuid.uuid4())
222
 
223
+ # Basic guess of timesteps for progress bar (parsing strings is hard, defaulting)
224
+ total_timesteps_guess = 100000
225
+ if "total_timesteps=" in request.code:
226
+ try:
227
+ # Very naive parsing to make progress bar sort of work
228
+ part = request.code.split("total_timesteps=")[1].split(")")[0].split(",")[0]
229
+ total_timesteps_guess = int(part)
230
+ except: pass
231
+
232
  training_jobs[job_id] = {
233
  "status": "queued",
234
+ "config": {"env_name": request.env_name}, # Kept for compatibility
235
+ "total_timesteps_guess": total_timesteps_guess,
 
 
 
 
 
 
236
  "metrics": {
237
+ "timesteps": 0, "episodes": 0, "progress": 0,
238
+ "episode_rewards": [], "episode_lengths": [],
239
+ "current_episode_reward": 0, "mean_reward": 0, "std_reward": 0,
240
+ "eval_mean_reward": None, "eval_std_reward": None, "logs": [],
 
 
 
 
 
 
 
241
  },
242
+ "video_buffer": [],
243
+ "results": None, "error": None, "start_time": None,
 
 
244
  }
245
 
246
+ background_tasks.add_task(run_custom_code, job_id, request.code, request.env_name)
247
+ return {"message": "Started", "job_id": job_id}
 
 
 
 
248
 
249
  @app.get("/train/{job_id}/status")
250
  def get_training_status(job_id: str):