bumie-e commited on
Commit
a91c2bd
·
1 Parent(s): 779ee0c

Updated render implementation

Browse files
Files changed (1) hide show
  1. app.py +192 -408
app.py CHANGED
@@ -1,10 +1,13 @@
1
- from fastapi import FastAPI, BackgroundTasks, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
 
 
 
 
3
  from pydantic import BaseModel
4
- from typing import Dict, Any, List
5
  import uuid
6
  import threading
7
- import numpy as np
8
  import gymnasium as gym
9
  from stable_baselines3 import PPO
10
  from stable_baselines3.common.monitor import Monitor
@@ -12,6 +15,8 @@ from stable_baselines3.common.evaluation import evaluate_policy
12
  from stable_baselines3.common.callbacks import BaseCallback
13
  from datetime import datetime
14
  import asyncio
 
 
15
 
16
  app = FastAPI()
17
 
@@ -28,7 +33,6 @@ app.add_middleware(
28
  training_jobs: Dict[str, Dict[str, Any]] = {}
29
 
30
  class TrainingJob(BaseModel):
31
- code: str
32
  env_name: str = "CartPole-v1"
33
  total_timesteps: int = 100000
34
  learning_rate: float = 0.001
@@ -36,24 +40,117 @@ class TrainingJob(BaseModel):
36
  batch_size: int = 64
37
  n_epochs: int = 10
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  class MetricsCallback(BaseCallback):
40
  """Custom callback to track training metrics in real-time"""
41
- def __init__(self, job_id: str):
42
  super().__init__()
43
  self.job_id = job_id
44
  self.episode_count = 0
45
-
 
 
 
46
  def _on_step(self) -> bool:
47
  job = training_jobs.get(self.job_id)
48
  if not job:
49
  return False
50
-
 
 
51
  # Update timestep count
52
  job["metrics"]["timesteps"] = self.num_timesteps
53
  job["metrics"]["progress"] = int(
54
  (self.num_timesteps / job["config"]["total_timesteps"]) * 100
55
  )
56
-
 
 
 
 
 
 
 
 
 
57
  # Check for episode completion
58
  if self.locals.get("dones", [False])[0]:
59
  if "infos" in self.locals and len(self.locals["infos"]) > 0:
@@ -62,12 +159,12 @@ class MetricsCallback(BaseCallback):
62
  self.episode_count += 1
63
  ep_reward = float(info["episode"]["r"])
64
  ep_length = int(info["episode"]["l"])
65
-
66
  job["metrics"]["episodes"] = self.episode_count
67
  job["metrics"]["episode_rewards"].append(ep_reward)
68
  job["metrics"]["episode_lengths"].append(ep_length)
69
  job["metrics"]["current_episode_reward"] = ep_reward
70
-
71
  # Calculate running average
72
  if len(job["metrics"]["episode_rewards"]) > 0:
73
  job["metrics"]["mean_reward"] = float(
@@ -76,23 +173,22 @@ class MetricsCallback(BaseCallback):
76
  job["metrics"]["std_reward"] = float(
77
  np.std(job["metrics"]["episode_rewards"][-100:])
78
  )
79
-
80
  # Add log entry
81
- log_entry = f"[{datetime.now().strftime('%H:%M:%S')}] Episode {self.episode_count}: reward = {ep_reward:.2f}"
82
  job["metrics"]["logs"].append(log_entry)
83
- if len(job["metrics"]["logs"]) > 50:
84
  job["metrics"]["logs"].pop(0)
85
-
86
  return True
87
 
88
  def run_training(job_id: str, config: Dict[str, Any]):
89
- """
90
- This function runs the training and updates the job status with real-time metrics.
91
- """
92
- print(f"--- Starting Training for job {job_id} ---")
93
  training_jobs[job_id]["status"] = "training"
94
  training_jobs[job_id]["start_time"] = datetime.now()
95
-
 
96
  try:
97
  env_name = config.get("env_name", "CartPole-v1")
98
  total_timesteps = config.get("total_timesteps", 100000)
@@ -100,11 +196,11 @@ def run_training(job_id: str, config: Dict[str, Any]):
100
  n_steps = config.get("n_steps", 2048)
101
  batch_size = config.get("batch_size", 64)
102
  n_epochs = config.get("n_epochs", 10)
103
-
104
- # Initialize environment
105
- env = gym.make(env_name)
106
  env = Monitor(env)
107
-
108
  # Initialize model
109
  model = PPO(
110
  "MlpPolicy",
@@ -115,58 +211,66 @@ def run_training(job_id: str, config: Dict[str, Any]):
115
  batch_size=batch_size,
116
  n_epochs=n_epochs,
117
  )
118
-
119
  # Add initial logs
120
  training_jobs[job_id]["metrics"]["logs"].append(
121
- f"[{datetime.now().strftime('%H:%M:%S')}] Initializing environment: {env_name}"
122
  )
123
  training_jobs[job_id]["metrics"]["logs"].append(
124
- f"[{datetime.now().strftime('%H:%M:%S')}] Creating PPO agent with MlpPolicy..."
125
  )
126
  training_jobs[job_id]["metrics"]["logs"].append(
127
- f"[{datetime.now().strftime('%H:%M:%S')}] Starting training for {total_timesteps:,} timesteps"
128
  )
129
-
130
  # Train with callback
131
  model.learn(
132
  total_timesteps=total_timesteps,
133
- callback=MetricsCallback(job_id),
134
  )
135
-
136
  # Evaluate
137
  training_jobs[job_id]["metrics"]["logs"].append(
138
- f"[{datetime.now().strftime('%H:%M:%S')}] Training completed! Evaluating model..."
139
  )
140
- mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100)
141
  training_jobs[job_id]["metrics"]["eval_mean_reward"] = float(mean_reward)
142
  training_jobs[job_id]["metrics"]["eval_std_reward"] = float(std_reward)
143
-
144
  # Save model
145
- model.save(f"{env_name}_ppo_{job_id}")
 
 
146
  training_jobs[job_id]["metrics"]["logs"].append(
147
- f"[{datetime.now().strftime('%H:%M:%S')}] Model saved as {env_name}_ppo_{job_id}.zip"
148
  )
149
-
150
  # Store results
151
  training_jobs[job_id]["status"] = "completed"
152
  training_jobs[job_id]["results"] = {
153
  "mean_reward": mean_reward,
154
  "std_reward": std_reward,
155
- "model_path": f"{env_name}_ppo_{job_id}.zip",
156
  "total_episodes": training_jobs[job_id]["metrics"]["episodes"],
157
  "total_timesteps": total_timesteps,
158
  }
159
  training_jobs[job_id]["metrics"]["progress"] = 100
160
-
161
- print(f"--- Training for job {job_id} Finished ---")
162
-
163
  except Exception as e:
164
  training_jobs[job_id]["status"] = "failed"
165
  training_jobs[job_id]["error"] = str(e)
166
  training_jobs[job_id]["metrics"]["logs"].append(
167
  f"[{datetime.now().strftime('%H:%M:%S')}] ERROR: {str(e)}"
168
  )
169
- print(f"--- Training for job {job_id} Failed: {e} ---")
 
 
 
 
 
 
170
 
171
  @app.get("/")
172
  def read_root():
@@ -176,8 +280,7 @@ def read_root():
176
  def start_training(job: TrainingJob, background_tasks: BackgroundTasks):
177
  """Start a new training job"""
178
  job_id = str(uuid.uuid4())
179
-
180
- # Initialize the job in our in-memory storage
181
  training_jobs[job_id] = {
182
  "status": "queued",
183
  "config": {
@@ -205,10 +308,9 @@ def start_training(job: TrainingJob, background_tasks: BackgroundTasks):
205
  "error": None,
206
  "start_time": None,
207
  }
208
-
209
- # Start the training in the background
210
  background_tasks.add_task(run_training, job_id, training_jobs[job_id]["config"])
211
-
212
  return {
213
  "message": "Training job started successfully!",
214
  "job_id": job_id,
@@ -216,18 +318,15 @@ def start_training(job: TrainingJob, background_tasks: BackgroundTasks):
216
 
217
  @app.get("/train/{job_id}/status")
218
  def get_training_status(job_id: str):
219
- """
220
- Returns the status and metrics of a training job.
221
- """
222
  job = training_jobs.get(job_id)
223
  if not job:
224
  raise HTTPException(status_code=404, detail="Job not found")
225
-
226
- # Calculate elapsed time
227
  elapsed_time = 0
228
  if job.get("start_time"):
229
  elapsed_time = (datetime.now() - job["start_time"]).total_seconds()
230
-
231
  return {
232
  "status": job["status"],
233
  "metrics": job["metrics"],
@@ -238,19 +337,15 @@ def get_training_status(job_id: str):
238
 
239
  @app.get("/train/{job_id}/metrics")
240
  def get_training_metrics(job_id: str):
241
- """
242
- Returns only the metrics of a training job (lightweight endpoint for polling).
243
- """
244
  job = training_jobs.get(job_id)
245
  if not job:
246
- print(f"DEBUG: Job {job_id} not found in training_jobs")
247
- print(f"DEBUG: Available jobs: {list(training_jobs.keys())}")
248
- raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
249
-
250
  elapsed_time = 0
251
  if job.get("start_time"):
252
  elapsed_time = (datetime.now() - job["start_time"]).total_seconds()
253
-
254
  return {
255
  "status": job["status"],
256
  "metrics": job["metrics"],
@@ -259,13 +354,11 @@ def get_training_metrics(job_id: str):
259
 
260
  @app.post("/train/{job_id}/stop")
261
  def stop_training(job_id: str):
262
- """
263
- Stop a training job.
264
- """
265
  job = training_jobs.get(job_id)
266
  if not job:
267
  raise HTTPException(status_code=404, detail="Job not found")
268
-
269
  if job["status"] == "training":
270
  job["status"] = "stopped"
271
  job["metrics"]["logs"].append(
@@ -274,351 +367,42 @@ def stop_training(job_id: str):
274
  return {"message": "Training stopped successfully!"}
275
  else:
276
  raise HTTPException(status_code=400, detail="Job is not currently training")
277
- @app.get("/debug")
278
- def debug():
279
- return {"jobs": list(training_jobs.keys())}
280
- # from fastapi import FastAPI, BackgroundTasks, HTTPException
281
- # from fastapi.middleware.cors import CORSMiddleware
282
- # from pydantic import BaseModel
283
- # from typing import Dict, Any, List
284
- # import uuid
285
- # import threading
286
- # import numpy as np
287
- # import gymnasium as gym
288
- # from stable_baselines3 import PPO
289
- # from stable_baselines3.common.monitor import Monitor
290
- # from stable_baselines3.common.evaluation import evaluate_policy
291
- # from stable_baselines3.common.callbacks import BaseCallback
292
- # from datetime import datetime
293
- # import asyncio
294
-
295
- # app = FastAPI()
296
-
297
- # # Add CORS middleware
298
- # app.add_middleware(
299
- # CORSMiddleware,
300
- # allow_origins=["*"],
301
- # allow_credentials=True,
302
- # allow_methods=["*"],
303
- # allow_headers=["*"],
304
- # )
305
-
306
- # # In-memory storage for training jobs
307
- # training_jobs: Dict[str, Dict[str, Any]] = {}
308
-
309
- # class TrainingJob(BaseModel):
310
- # code: str
311
- # env_name: str = "CartPole-v1"
312
- # total_timesteps: int = 100000
313
- # learning_rate: float = 0.001
314
- # n_steps: int = 2048
315
- # batch_size: int = 64
316
- # n_epochs: int = 10
317
-
318
- # class MetricsCallback(BaseCallback):
319
- # """Custom callback to track training metrics in real-time"""
320
- # def __init__(self, job_id: str):
321
- # super().__init__()
322
- # self.job_id = job_id
323
- # self.episode_count = 0
324
-
325
- # def _on_step(self) -> bool:
326
- # job = training_jobs.get(self.job_id)
327
- # if not job:
328
- # return False
329
-
330
- # # Update timestep count
331
- # job["metrics"]["timesteps"] = self.num_timesteps
332
- # job["metrics"]["progress"] = int(
333
- # (self.num_timesteps / job["config"]["total_timesteps"]) * 100
334
- # )
335
-
336
- # # Check for episode completion
337
- # if self.locals.get("dones", [False])[0]:
338
- # if "infos" in self.locals and len(self.locals["infos"]) > 0:
339
- # info = self.locals["infos"][0]
340
- # if "episode" in info:
341
- # self.episode_count += 1
342
- # ep_reward = float(info["episode"]["r"])
343
- # ep_length = int(info["episode"]["l"])
344
-
345
- # job["metrics"]["episodes"] = self.episode_count
346
- # job["metrics"]["episode_rewards"].append(ep_reward)
347
- # job["metrics"]["episode_lengths"].append(ep_length)
348
- # job["metrics"]["current_episode_reward"] = ep_reward
349
-
350
- # # Calculate running average
351
- # if len(job["metrics"]["episode_rewards"]) > 0:
352
- # job["metrics"]["mean_reward"] = float(
353
- # np.mean(job["metrics"]["episode_rewards"][-100:])
354
- # )
355
- # job["metrics"]["std_reward"] = float(
356
- # np.std(job["metrics"]["episode_rewards"][-100:])
357
- # )
358
-
359
- # # Add log entry
360
- # log_entry = f"[{datetime.now().strftime('%H:%M:%S')}] Episode {self.episode_count}: reward = {ep_reward:.2f}"
361
- # job["metrics"]["logs"].append(log_entry)
362
- # if len(job["metrics"]["logs"]) > 50:
363
- # job["metrics"]["logs"].pop(0)
364
-
365
- # return True
366
-
367
- # def run_training(job_id: str, config: Dict[str, Any]):
368
- # """
369
- # This function runs the training and updates the job status with real-time metrics.
370
- # """
371
- # print(f"--- Starting Training for job {job_id} ---")
372
- # training_jobs[job_id]["status"] = "training"
373
- # training_jobs[job_id]["start_time"] = datetime.now()
374
-
375
- # try:
376
- # env_name = config.get("env_name", "CartPole-v1")
377
- # total_timesteps = config.get("total_timesteps", 100000)
378
- # learning_rate = config.get("learning_rate", 0.001)
379
- # n_steps = config.get("n_steps", 2048)
380
- # batch_size = config.get("batch_size", 64)
381
- # n_epochs = config.get("n_epochs", 10)
382
-
383
- # # Initialize environment
384
- # env = gym.make(env_name)
385
- # env = Monitor(env)
386
-
387
- # # Initialize model
388
- # model = PPO(
389
- # "MlpPolicy",
390
- # env,
391
- # verbose=0,
392
- # learning_rate=learning_rate,
393
- # n_steps=n_steps,
394
- # batch_size=batch_size,
395
- # n_epochs=n_epochs,
396
- # )
397
-
398
- # # Add initial logs
399
- # training_jobs[job_id]["metrics"]["logs"].append(
400
- # f"[{datetime.now().strftime('%H:%M:%S')}] Initializing environment: {env_name}"
401
- # )
402
- # training_jobs[job_id]["metrics"]["logs"].append(
403
- # f"[{datetime.now().strftime('%H:%M:%S')}] Creating PPO agent with MlpPolicy..."
404
- # )
405
- # training_jobs[job_id]["metrics"]["logs"].append(
406
- # f"[{datetime.now().strftime('%H:%M:%S')}] Starting training for {total_timesteps:,} timesteps"
407
- # )
408
-
409
- # # Train with callback
410
- # model.learn(
411
- # total_timesteps=total_timesteps,
412
- # callback=MetricsCallback(job_id),
413
- # )
414
-
415
- # # Evaluate
416
- # training_jobs[job_id]["metrics"]["logs"].append(
417
- # f"[{datetime.now().strftime('%H:%M:%S')}] Training completed! Evaluating model..."
418
- # )
419
- # mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100)
420
- # training_jobs[job_id]["metrics"]["eval_mean_reward"] = float(mean_reward)
421
- # training_jobs[job_id]["metrics"]["eval_std_reward"] = float(std_reward)
422
-
423
- # # Save model
424
- # model.save(f"{env_name}_ppo_{job_id}")
425
- # training_jobs[job_id]["metrics"]["logs"].append(
426
- # f"[{datetime.now().strftime('%H:%M:%S')}] Model saved as {env_name}_ppo_{job_id}.zip"
427
- # )
428
-
429
- # # Store results
430
- # training_jobs[job_id]["status"] = "completed"
431
- # training_jobs[job_id]["results"] = {
432
- # "mean_reward": mean_reward,
433
- # "std_reward": std_reward,
434
- # "model_path": f"{env_name}_ppo_{job_id}.zip",
435
- # "total_episodes": training_jobs[job_id]["metrics"]["episodes"],
436
- # "total_timesteps": total_timesteps,
437
- # }
438
- # training_jobs[job_id]["metrics"]["progress"] = 100
439
-
440
- # print(f"--- Training for job {job_id} Finished ---")
441
 
442
- # except Exception as e:
443
- # training_jobs[job_id]["status"] = "failed"
444
- # training_jobs[job_id]["error"] = str(e)
445
- # training_jobs[job_id]["metrics"]["logs"].append(
446
- # f"[{datetime.now().strftime('%H:%M:%S')}] ERROR: {str(e)}"
447
- # )
448
- # print(f"--- Training for job {job_id} Failed: {e} ---")
449
 
450
- # @app.get("/")
451
- # def read_root():
452
- # return {"message": "Welcome to the RL Training API!"}
453
-
454
- # @app.post("/train")
455
- # def start_training(job: TrainingJob, background_tasks: BackgroundTasks):
456
- # """Start a new training job"""
457
- # job_id = str(uuid.uuid4())
458
-
459
- # # Initialize the job in our in-memory storage
460
- # training_jobs[job_id] = {
461
- # "status": "queued",
462
- # "config": {
463
- # "env_name": job.env_name,
464
- # "total_timesteps": job.total_timesteps,
465
- # "learning_rate": job.learning_rate,
466
- # "n_steps": job.n_steps,
467
- # "batch_size": job.batch_size,
468
- # "n_epochs": job.n_epochs,
469
- # },
470
- # "metrics": {
471
- # "timesteps": 0,
472
- # "episodes": 0,
473
- # "progress": 0,
474
- # "episode_rewards": [],
475
- # "episode_lengths": [],
476
- # "current_episode_reward": 0,
477
- # "mean_reward": 0,
478
- # "std_reward": 0,
479
- # "eval_mean_reward": None,
480
- # "eval_std_reward": None,
481
- # "logs": [],
482
- # },
483
- # "results": None,
484
- # "error": None,
485
- # "start_time": None,
486
- # }
487
-
488
- # # Start the training in the background
489
- # background_tasks.add_task(run_training, job_id, training_jobs[job_id]["config"])
490
-
491
- # return {
492
- # "message": "Training job started successfully!",
493
- # "job_id": job_id,
494
- # }
495
-
496
- # @app.get("/train/{job_id}/status")
497
- # def get_training_status(job_id: str):
498
- # """
499
- # Returns the status and metrics of a training job.
500
- # """
501
- # job = training_jobs.get(job_id)
502
- # if not job:
503
- # raise HTTPException(status_code=404, detail="Job not found")
504
-
505
- # # Calculate elapsed time
506
- # elapsed_time = 0
507
- # if job.get("start_time"):
508
- # elapsed_time = (datetime.now() - job["start_time"]).total_seconds()
509
-
510
- # return {
511
- # "status": job["status"],
512
- # "metrics": job["metrics"],
513
- # "elapsed_time": elapsed_time,
514
- # "results": job["results"],
515
- # "error": job["error"],
516
- # }
517
-
518
- # @app.get("/train/{job_id}/metrics")
519
- # def get_training_metrics(job_id: str):
520
- # """
521
- # Returns only the metrics of a training job (lightweight endpoint for polling).
522
- # """
523
- # job = training_jobs.get(job_id)
524
- # if not job:
525
- # raise HTTPException(status_code=404, detail="Job not found")
526
-
527
- # elapsed_time = 0
528
- # if job.get("start_time"):
529
- # elapsed_time = (datetime.now() - job["start_time"]).total_seconds()
530
-
531
- # return {
532
- # "status": job["status"],
533
- # "metrics": job["metrics"],
534
- # "elapsed_time": elapsed_time,
535
- # }
536
-
537
- # @app.post("/train/{job_id}/stop")
538
- # def stop_training(job_id: str):
539
- # """
540
- # Stop a training job.
541
- # """
542
- # job = training_jobs.get(job_id)
543
- # if not job:
544
- # raise HTTPException(status_code=404, detail="Job not found")
545
-
546
- # if job["status"] == "training":
547
- # job["status"] = "stopped"
548
- # job["metrics"]["logs"].append(
549
- # f"[{datetime.now().strftime('%H:%M:%S')}] Training stopped by user"
550
- # )
551
- # return {"message": "Training stopped successfully!"}
552
- # else:
553
- # raise HTTPException(status_code=400, detail="Job is not currently training")
554
-
555
- # # from fastapi import FastAPI, BackgroundTasks, HTTPException
556
- # # from pydantic import BaseModel
557
- # # import os
558
- # # import uuid
559
- # # from typing import Dict, Any
560
-
561
- # # app = FastAPI()
562
-
563
- # # # In-memory storage for training jobs
564
- # # training_jobs: Dict[str, Dict[str, Any]] = {}
565
-
566
- # # # Define the request body for the training job
567
- # # class TrainingJob(BaseModel):
568
- # # code: str
569
-
570
- # # # This is where you'll put your training logic
571
- # # def run_training(job_id: str, user_code: str):
572
- # # """
573
- # # This function runs the user's code and updates the job status.
574
- # # """
575
- # # print(f"--- Starting Training for job {job_id} ---")
576
- # # training_jobs[job_id]["status"] = "training"
577
- # # try:
578
- # # # Create a dictionary to serve as the local namespace for exec
579
- # # local_namespace = {}
580
-
581
- # # # Execute the user's code
582
- # # exec(user_code, {}, local_namespace)
583
-
584
- # # # Assume the user's code stores results in a 'results' dictionary
585
- # # results = local_namespace.get('results', {})
586
-
587
- # # # Store the results and mark the job as completed
588
- # # training_jobs[job_id]["status"] = "completed"
589
- # # training_jobs[job_id]["results"] = results
590
- # # print(f"--- Training for job {job_id} Finished ---")
591
-
592
- # # except Exception as e:
593
- # # # Mark the job as failed and store the error message
594
- # # training_jobs[job_id]["status"] = "failed"
595
- # # training_jobs[job_id]["error"] = str(e)
596
- # # print(f"--- Training for job {job_id} Failed: {e} ---")
597
-
598
- # # @app.get('/')
599
- # # def read_root():
600
- # # return {"message": "Welcome to the Training API!"}
601
-
602
- # # @app.post("/train")
603
- # # def start_training(job: TrainingJob, background_tasks: BackgroundTasks):
604
-
605
- # # # Generate a unique job ID
606
- # # job_id = str(uuid.uuid4())
607
 
608
- # # # Initialize the job in our in-memory storage
609
- # # training_jobs[job_id] = {"status": "queued"}
610
-
611
- # # # Start the training in the background
612
- # # background_tasks.add_task(run_training, job_id, job.code)
613
-
614
- # # return {"message": "Training job started successfully!", "job_id": job_id}
 
 
 
 
 
 
615
 
616
- # # @app.get("/train/{job_id}/status")
617
- # # def get_training_status(job_id: str):
618
- # # """
619
- # # Returns the status and results of a training job.
620
- # # """
621
- # # job = training_jobs.get(job_id)
622
- # # if not job:
623
- # # raise HTTPException(status_code=404, detail="Job not found")
624
- # # return job
 
 
 
 
 
 
1
+ from fastapi import FastAPI, BackgroundTasks, HTTPException, WebSocket, WebSocketDisconnect
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ import base64
4
+ import cv2
5
+ import numpy as np
6
+ from collections import deque
7
  from pydantic import BaseModel
8
+ from typing import Dict, Any, List, Optional
9
  import uuid
10
  import threading
 
11
  import gymnasium as gym
12
  from stable_baselines3 import PPO
13
  from stable_baselines3.common.monitor import Monitor
 
15
  from stable_baselines3.common.callbacks import BaseCallback
16
  from datetime import datetime
17
  import asyncio
18
+ import os
19
+ from enum import Enum
20
 
21
  app = FastAPI()
22
 
 
33
  training_jobs: Dict[str, Dict[str, Any]] = {}
34
 
35
  class TrainingJob(BaseModel):
 
36
  env_name: str = "CartPole-v1"
37
  total_timesteps: int = 100000
38
  learning_rate: float = 0.001
 
40
  batch_size: int = 64
41
  n_epochs: int = 10
42
 
43
+ class ConnectionManager:
44
+ """Manages WebSocket connections and frame broadcasting"""
45
+ def __init__(self):
46
+ self.active_connections: Dict[str, List[WebSocket]] = {}
47
+ self.frames: Dict[str, deque] = {}
48
+
49
+ async def connect(self, job_id: str, websocket: WebSocket):
50
+ await websocket.accept()
51
+ if job_id not in self.active_connections:
52
+ self.active_connections[job_id] = []
53
+ self.frames[job_id] = deque(maxlen=1)
54
+ self.active_connections[job_id].append(websocket)
55
+ print(f"[WS] Client connected to job {job_id}")
56
+
57
+ def disconnect(self, job_id: str, websocket: WebSocket):
58
+ if job_id in self.active_connections:
59
+ self.active_connections[job_id].remove(websocket)
60
+ if not self.active_connections[job_id]:
61
+ del self.active_connections[job_id]
62
+ if job_id in self.frames:
63
+ del self.frames[job_id]
64
+ print(f"[WS] Client disconnected from job {job_id}")
65
+
66
+ def add_frame(self, job_id: str, frame: np.ndarray):
67
+ """Store the latest frame for this job"""
68
+ if job_id not in self.frames:
69
+ self.frames[job_id] = deque(maxlen=1)
70
+ self.frames[job_id].append(frame)
71
+
72
+ async def broadcast_frame(self, job_id: str):
73
+ """Broadcast the latest frame to all connected clients"""
74
+ if job_id not in self.frames or not self.frames[job_id]:
75
+ return
76
+
77
+ frame = self.frames[job_id][-1]
78
+
79
+ # Ensure frame is in RGB format (not BGR from cv2)
80
+ if len(frame.shape) == 3 and frame.shape[2] == 3:
81
+ # Assume BGR from gym, convert to RGB
82
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
83
+ else:
84
+ frame_rgb = frame
85
+
86
+ # Resize for efficient transmission (optional)
87
+ height, width = frame_rgb.shape[:2]
88
+ if height > 512 or width > 512:
89
+ scale = 512 / max(height, width)
90
+ new_size = (int(width * scale), int(height * scale))
91
+ frame_rgb = cv2.resize(frame_rgb, new_size, interpolation=cv2.INTER_LINEAR)
92
+
93
+ # Encode to JPEG
94
+ success, buffer = cv2.imencode('.jpg', frame_rgb, [cv2.IMWRITE_JPEG_QUALITY, 85])
95
+ if not success:
96
+ print(f"[ERROR] Failed to encode frame for job {job_id}")
97
+ return
98
+
99
+ frame_base64 = base64.b64encode(buffer).decode('utf-8')
100
+
101
+ # Broadcast to all connected clients
102
+ if job_id in self.active_connections:
103
+ disconnected = []
104
+ for connection in self.active_connections[job_id]:
105
+ try:
106
+ await connection.send_json({
107
+ "type": "frame",
108
+ "job_id": job_id,
109
+ "data": frame_base64,
110
+ "timestamp": datetime.now().isoformat()
111
+ })
112
+ except Exception as e:
113
+ print(f"[ERROR] Failed to send frame: {e}")
114
+ disconnected.append(connection)
115
+
116
+ # Remove disconnected clients
117
+ for conn in disconnected:
118
+ self.disconnect(job_id, conn)
119
+
120
+ manager = ConnectionManager()
121
+
122
  class MetricsCallback(BaseCallback):
123
  """Custom callback to track training metrics in real-time"""
124
+ def __init__(self, job_id: str, render_freq: int = 5):
125
  super().__init__()
126
  self.job_id = job_id
127
  self.episode_count = 0
128
+ self.step_count = 0
129
+ self.render_freq = render_freq # Render every N steps
130
+ self.env = None
131
+
132
  def _on_step(self) -> bool:
133
  job = training_jobs.get(self.job_id)
134
  if not job:
135
  return False
136
+
137
+ self.step_count += 1
138
+
139
  # Update timestep count
140
  job["metrics"]["timesteps"] = self.num_timesteps
141
  job["metrics"]["progress"] = int(
142
  (self.num_timesteps / job["config"]["total_timesteps"]) * 100
143
  )
144
+
145
+ # Render frame periodically
146
+ if self.step_count % self.render_freq == 0:
147
+ try:
148
+ frame = self.model.get_env().render()
149
+ if frame is not None:
150
+ manager.add_frame(self.job_id, frame)
151
+ except Exception as e:
152
+ print(f"[ERROR] Failed to render frame: {e}")
153
+
154
  # Check for episode completion
155
  if self.locals.get("dones", [False])[0]:
156
  if "infos" in self.locals and len(self.locals["infos"]) > 0:
 
159
  self.episode_count += 1
160
  ep_reward = float(info["episode"]["r"])
161
  ep_length = int(info["episode"]["l"])
162
+
163
  job["metrics"]["episodes"] = self.episode_count
164
  job["metrics"]["episode_rewards"].append(ep_reward)
165
  job["metrics"]["episode_lengths"].append(ep_length)
166
  job["metrics"]["current_episode_reward"] = ep_reward
167
+
168
  # Calculate running average
169
  if len(job["metrics"]["episode_rewards"]) > 0:
170
  job["metrics"]["mean_reward"] = float(
 
173
  job["metrics"]["std_reward"] = float(
174
  np.std(job["metrics"]["episode_rewards"][-100:])
175
  )
176
+
177
  # Add log entry
178
+ log_entry = f"[{datetime.now().strftime('%H:%M:%S')}] Episode {self.episode_count}: reward = {ep_reward:.2f}, length = {ep_length}"
179
  job["metrics"]["logs"].append(log_entry)
180
+ if len(job["metrics"]["logs"]) > 100:
181
  job["metrics"]["logs"].pop(0)
182
+
183
  return True
184
 
185
  def run_training(job_id: str, config: Dict[str, Any]):
186
+ """Run the RL training loop with rendering"""
187
+ print(f"[TRAIN] Starting training for job {job_id}")
 
 
188
  training_jobs[job_id]["status"] = "training"
189
  training_jobs[job_id]["start_time"] = datetime.now()
190
+
191
+ env = None
192
  try:
193
  env_name = config.get("env_name", "CartPole-v1")
194
  total_timesteps = config.get("total_timesteps", 100000)
 
196
  n_steps = config.get("n_steps", 2048)
197
  batch_size = config.get("batch_size", 64)
198
  n_epochs = config.get("n_epochs", 10)
199
+
200
+ # Initialize environment with rgb_array rendering
201
+ env = gym.make(env_name, render_mode='rgb_array')
202
  env = Monitor(env)
203
+
204
  # Initialize model
205
  model = PPO(
206
  "MlpPolicy",
 
211
  batch_size=batch_size,
212
  n_epochs=n_epochs,
213
  )
214
+
215
  # Add initial logs
216
  training_jobs[job_id]["metrics"]["logs"].append(
217
+ f"[{datetime.now().strftime('%H:%M:%S')}] Environment: {env_name}"
218
  )
219
  training_jobs[job_id]["metrics"]["logs"].append(
220
+ f"[{datetime.now().strftime('%H:%M:%S')}] Total timesteps: {total_timesteps:,}"
221
  )
222
  training_jobs[job_id]["metrics"]["logs"].append(
223
+ f"[{datetime.now().strftime('%H:%M:%S')}] Starting training..."
224
  )
225
+
226
  # Train with callback
227
  model.learn(
228
  total_timesteps=total_timesteps,
229
+ callback=MetricsCallback(job_id, render_freq=5),
230
  )
231
+
232
  # Evaluate
233
  training_jobs[job_id]["metrics"]["logs"].append(
234
+ f"[{datetime.now().strftime('%H:%M:%S')}] Training completed! Evaluating..."
235
  )
236
+ mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
237
  training_jobs[job_id]["metrics"]["eval_mean_reward"] = float(mean_reward)
238
  training_jobs[job_id]["metrics"]["eval_std_reward"] = float(std_reward)
239
+
240
  # Save model
241
+ model_path = f"models/{env_name}_ppo_{job_id}"
242
+ os.makedirs("models", exist_ok=True)
243
+ model.save(model_path)
244
  training_jobs[job_id]["metrics"]["logs"].append(
245
+ f"[{datetime.now().strftime('%H:%M:%S')}] Model saved!"
246
  )
247
+
248
  # Store results
249
  training_jobs[job_id]["status"] = "completed"
250
  training_jobs[job_id]["results"] = {
251
  "mean_reward": mean_reward,
252
  "std_reward": std_reward,
253
+ "model_path": f"{model_path}.zip",
254
  "total_episodes": training_jobs[job_id]["metrics"]["episodes"],
255
  "total_timesteps": total_timesteps,
256
  }
257
  training_jobs[job_id]["metrics"]["progress"] = 100
258
+
259
+ print(f"[TRAIN] Training completed for job {job_id}")
260
+
261
  except Exception as e:
262
  training_jobs[job_id]["status"] = "failed"
263
  training_jobs[job_id]["error"] = str(e)
264
  training_jobs[job_id]["metrics"]["logs"].append(
265
  f"[{datetime.now().strftime('%H:%M:%S')}] ERROR: {str(e)}"
266
  )
267
+ print(f"[ERROR] Training failed for job {job_id}: {e}")
268
+
269
+ finally:
270
+ if env:
271
+ env.close()
272
+
273
+ # REST Endpoints
274
 
275
  @app.get("/")
276
  def read_root():
 
280
  def start_training(job: TrainingJob, background_tasks: BackgroundTasks):
281
  """Start a new training job"""
282
  job_id = str(uuid.uuid4())
283
+
 
284
  training_jobs[job_id] = {
285
  "status": "queued",
286
  "config": {
 
308
  "error": None,
309
  "start_time": None,
310
  }
311
+
 
312
  background_tasks.add_task(run_training, job_id, training_jobs[job_id]["config"])
313
+
314
  return {
315
  "message": "Training job started successfully!",
316
  "job_id": job_id,
 
318
 
319
  @app.get("/train/{job_id}/status")
320
  def get_training_status(job_id: str):
321
+ """Get full training status with metrics"""
 
 
322
  job = training_jobs.get(job_id)
323
  if not job:
324
  raise HTTPException(status_code=404, detail="Job not found")
325
+
 
326
  elapsed_time = 0
327
  if job.get("start_time"):
328
  elapsed_time = (datetime.now() - job["start_time"]).total_seconds()
329
+
330
  return {
331
  "status": job["status"],
332
  "metrics": job["metrics"],
 
337
 
338
  @app.get("/train/{job_id}/metrics")
339
  def get_training_metrics(job_id: str):
340
+ """Lightweight endpoint for polling metrics"""
 
 
341
  job = training_jobs.get(job_id)
342
  if not job:
343
+ raise HTTPException(status_code=404, detail="Job not found")
344
+
 
 
345
  elapsed_time = 0
346
  if job.get("start_time"):
347
  elapsed_time = (datetime.now() - job["start_time"]).total_seconds()
348
+
349
  return {
350
  "status": job["status"],
351
  "metrics": job["metrics"],
 
354
 
355
  @app.post("/train/{job_id}/stop")
356
  def stop_training(job_id: str):
357
+ """Stop a training job"""
 
 
358
  job = training_jobs.get(job_id)
359
  if not job:
360
  raise HTTPException(status_code=404, detail="Job not found")
361
+
362
  if job["status"] == "training":
363
  job["status"] = "stopped"
364
  job["metrics"]["logs"].append(
 
367
  return {"message": "Training stopped successfully!"}
368
  else:
369
  raise HTTPException(status_code=400, detail="Job is not currently training")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
+ # WebSocket Endpoint
 
 
 
 
 
 
372
 
373
+ @app.websocket("/ws/render/{job_id}")
374
+ async def websocket_render_endpoint(websocket: WebSocket, job_id: str):
375
+ """
376
+ WebSocket endpoint for real-time environment rendering.
377
+ Connect from frontend with: ws://localhost:8000/ws/render/{job_id}
378
+ """
379
+ await manager.connect(job_id, websocket)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
 
381
+ try:
382
+ while True:
383
+ # Keep connection alive and handle messages
384
+ data = await websocket.receive_text()
385
+ if data == "request_frame":
386
+ await manager.broadcast_frame(job_id)
387
+ elif data == "ping":
388
+ await websocket.send_json({"type": "pong"})
389
+ except WebSocketDisconnect:
390
+ manager.disconnect(job_id, websocket)
391
+ except Exception as e:
392
+ print(f"[ERROR] WebSocket error for job {job_id}: {e}")
393
+ manager.disconnect(job_id, websocket)
394
 
395
+ @app.get("/debug/jobs")
396
+ def debug_jobs():
397
+ """Debug endpoint to list all jobs"""
398
+ return {
399
+ "jobs": [
400
+ {
401
+ "job_id": job_id,
402
+ "status": job["status"],
403
+ "progress": job["metrics"]["progress"],
404
+ "episodes": job["metrics"]["episodes"],
405
+ }
406
+ for job_id, job in training_jobs.items()
407
+ ]
408
+ }