Spaces:
Sleeping
Sleeping
| """ | |
| src/rl_agent.py | |
| --------------- | |
| PPO reinforcement-learning agent for Jaam Ctrl. | |
| Exports required by app.py: | |
| train_ppo(total_timesteps, learning_rate, progress_callback) -> str (model path) | |
| load_ppo_model() -> model | None | |
| load_training_log() -> dict | |
| MODEL_PATH str | |
| SB3_AVAILABLE bool | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import json | |
| import time | |
| import math | |
| import numpy as np | |
| # --------------------------------------------------------------------------- | |
| # stable-baselines3 availability | |
| # --------------------------------------------------------------------------- | |
# Probe for the optional RL stack. PPO, make_vec_env and gym are only bound
# when every import below succeeds; SB3_AVAILABLE tells the rest of the module
# (and app.py) whether they may be used.
try:
    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env
    import gymnasium as gym
except ImportError:
    SB3_AVAILABLE = False
else:
    SB3_AVAILABLE = True
| # --------------------------------------------------------------------------- | |
| # Paths | |
| # --------------------------------------------------------------------------- | |
# Filesystem layout, all derived from this file's own location:
#   <root>/<dir containing this file>/   -> _SRC_DIR
#   <root>/models/ppo_jaam_ctrl[.zip]    -> MODEL_PATH (stored without ".zip")
#   <root>/models/training_log.json      -> _LOG_PATH
_SRC_DIR, _ = os.path.split(os.path.abspath(__file__))
_ROOT_DIR, _ = os.path.split(_SRC_DIR)
_MODELS_DIR = os.path.join(_ROOT_DIR, "models")
MODEL_PATH = os.path.join(_MODELS_DIR, "ppo_jaam_ctrl")
_LOG_PATH = os.path.join(_MODELS_DIR, "training_log.json")
def _ensure_models_dir() -> None:
    """Create the models directory (and any missing parents) if it is absent."""
    os.makedirs(_MODELS_DIR, exist_ok=True)
| # --------------------------------------------------------------------------- | |
| # Gym environment (synthetic queue model; defined only when SB3 + gymnasium import) | |
| # --------------------------------------------------------------------------- | |
if SB3_AVAILABLE:

    class JaamCtrlEnv(gym.Env):
        """
        Gymnasium environment modelling a 3-junction signalised corridor.

        The dynamics are a lightweight synthetic queue model (Poisson arrivals,
        phase-dependent service); no external simulator is called from here.

        Observation: 18-dim float32 (6 features × 3 junctions). Per junction:
            [queue_EW / 25, queue_NS / 25, EW-green flag, NS-green flag,
             phase_age / 60 (capped at 1), constant 0.5 throughput placeholder]
        Action:      Discrete(8) — 3-bit binary switch request per junction
                     (bit j set = ask junction j to advance to its next phase)

        Phases cycle 0 -> 1 -> 2 -> 3 -> 0. Phase 0 serves EW, phase 2 serves
        NS; phases 1 and 3 serve neither (presumably the intergreen/yellow
        intervals). A phase only advances when the agent requests it.
        """

        metadata = {"render_modes": []}

        def __init__(self):
            super().__init__()
            self.observation_space = gym.spaces.Box(
                low=0.0, high=1.0, shape=(18,), dtype=np.float32
            )
            self.action_space = gym.spaces.Discrete(8)
            self._step = 0
            self._max_steps = 180  # 180 × 10s = 1800s episode
            self._queues = np.zeros((3, 2), dtype=np.float32)  # [junc, ew/ns]
            self._phases = np.zeros(3, dtype=np.int32)
            self._phase_ages = np.zeros(3, dtype=np.float32)

        def reset(self, *, seed=None, options=None):
            """Start a new episode with small random queues; returns (obs, info)."""
            super().reset(seed=seed)
            self._step = 0
            self._queues = self.np_random.uniform(0, 5, (3, 2)).astype(np.float32)
            self._phases = np.zeros(3, dtype=np.int32)
            self._phase_ages = np.zeros(3, dtype=np.float32)
            return self._obs(), {}

        def step(self, action: int):
            """
            Advance the model by one 10 s control step.

            Returns the Gymnasium 5-tuple (obs, reward, terminated, truncated,
            info). `terminated` becomes True once `_max_steps` steps have been
            taken; `truncated` is always False.
            """
            # Decode 3-bit action: bit j is the switch request for junction j.
            bits = [(action >> i) & 1 for i in range(3)]

            # Age every phase, then honour switch requests that satisfy the
            # minimum time in phase (15 s for phases 0/2, 4 s for phases 1/3).
            for j in range(3):
                self._phase_ages[j] += 10  # 10s control step
                if bits[j] == 1:
                    cur_ph = self._phases[j]
                    min_age = 15 if cur_ph in (0, 2) else 4
                    if self._phase_ages[j] >= min_age:
                        self._phases[j] = (cur_ph + 1) % 4
                        self._phase_ages[j] = 0

            # Snapshot the queue total BEFORE the update so the delay reward
            # compares the previous state with the new one.
            prev_total_queue = float(self._queues.sum())

            # Queue evolution: Poisson(3) arrivals per approach; the approach
            # with green is served at 4.0 per step, every other at 0.5.
            arr = self.np_random.poisson(3, (3, 2)).astype(np.float32)
            service = np.full((3, 2), 0.5, dtype=np.float32)
            service[self._phases == 0, 0] = 4.0  # EW green
            service[self._phases == 2, 1] = 4.0  # NS green
            self._queues[:] = np.clip(self._queues + arr - service, 0, 25)

            # Reward terms.
            total_queue = float(self._queues.sum())
            prev_delay = prev_total_queue * 2.0
            new_delay = total_queue * 2.0
            # Relative change in delay, squashed into (-1, 1); positive when
            # the queues shrank during this step.
            delay_r = math.tanh((prev_delay - new_delay) / max(prev_delay, 1))
            # NOTE(review): this term is driven by arrivals, not by vehicles
            # served, so it does not depend on the action — confirm intent.
            throughput_r = 0.5 * min(float(arr.sum()) / 10.0, 1.0)
            per_junction_mean = self._queues.mean(axis=1)
            q_std = float(per_junction_mean.std())
            q_mean = float(self._queues.mean()) + 1e-6  # epsilon avoids /0
            balance_r = -0.3 * q_std / q_mean
            gridlock = float((per_junction_mean > 20).sum()) / 3.0
            gridlock_r = -0.4 * gridlock
            reward = delay_r + throughput_r + balance_r + gridlock_r

            self._step += 1
            done = self._step >= self._max_steps
            return self._obs(), reward, done, False, {}

        def _obs(self) -> np.ndarray:
            """Build the flat 18-dim observation (junction-major ordering)."""
            obs = np.empty((3, 6), dtype=np.float32)
            obs[:, 0:2] = self._queues / 25.0
            obs[:, 2] = self._phases == 0
            obs[:, 3] = self._phases == 2
            obs[:, 4] = np.minimum(self._phase_ages / 60.0, 1.0)
            obs[:, 5] = 0.5  # throughput placeholder (constant)
            return obs.reshape(-1)
| # --------------------------------------------------------------------------- | |
| # Training | |
| # --------------------------------------------------------------------------- | |
def train_ppo(
    total_timesteps: int = 3000,
    learning_rate: float = 3e-4,
    progress_callback=None,
) -> str:
    """
    Train a PPO agent on JaamCtrlEnv and save it to MODEL_PATH.

    Also writes a JSON training log (per-episode rewards, a derived delay
    figure, mean and best reward) to _LOG_PATH.

    Args:
        total_timesteps: Environment steps requested from PPO.
        learning_rate: Optimiser learning rate passed to PPO.
        progress_callback: Optional callable invoked after every environment
            step as ``progress_callback(steps_done, total_timesteps)``.
            ``steps_done`` is clamped to ``total_timesteps`` because PPO
            collects whole rollouts and may overshoot the requested budget.

    Returns:
        The model path (without .zip).

    Raises:
        RuntimeError: If stable-baselines3 is not installed.
    """
    if not SB3_AVAILABLE:
        raise RuntimeError(
            "stable-baselines3 is not installed. "
            "Run: pip install stable-baselines3 gymnasium"
        )

    _ensure_models_dir()

    episode_rewards: list[float] = []
    episode_delays: list[float] = []
    running_reward = 0.0
    steps_seen = 0

    def _on_step(locals_, globals_) -> bool:
        """Per-step hook: accumulate episode reward and report progress."""
        nonlocal running_reward, steps_seen
        steps_seen += 1
        # Index 0: training uses a single environment (n_envs=1).
        running_reward += float(locals_["rewards"][0])
        if locals_["dones"][0]:
            episode_rewards.append(running_reward)
            # Heuristic delay figure derived from the episode reward.
            episode_delays.append(max(0, 62 - running_reward * 10))
            running_reward = 0.0
        if progress_callback:
            progress_callback(min(steps_seen, total_timesteps), total_timesteps)
        return True  # returning False would abort training

    env = make_vec_env(JaamCtrlEnv, n_envs=1)
    try:
        model = PPO(
            "MlpPolicy",
            env,
            learning_rate=learning_rate,
            n_steps=128,
            batch_size=32,
            n_epochs=5,
            gamma=0.95,
            policy_kwargs=dict(net_arch=[128, 128]),
            verbose=0,
        )
        model.learn(
            total_timesteps=total_timesteps,
            callback=_on_step,
            reset_num_timesteps=True,
        )
        model.save(MODEL_PATH)
    finally:
        # Release the vectorised env even if training or saving fails.
        env.close()

    # Save training log
    log = {
        "total_episodes": len(episode_rewards),
        "episode_rewards": episode_rewards,
        "episode_delays": episode_delays,
        "mean_reward": float(np.mean(episode_rewards)) if episode_rewards else 0.0,
        "best_reward": float(max(episode_rewards)) if episode_rewards else 0.0,
    }
    with open(_LOG_PATH, "w", encoding="utf-8") as f:
        json.dump(log, f)

    return MODEL_PATH
| # --------------------------------------------------------------------------- | |
| # Loading | |
| # --------------------------------------------------------------------------- | |
def load_ppo_model():
    """
    Load the saved PPO model from disk.

    Tries the local archive (MODEL_PATH + ".zip") first. If it is missing or
    cannot be loaded, attempts to download the released model from GitHub
    into the same location and load that instead.

    Returns:
        The model object, or None if SB3 is not available or no model could
        be loaded.
    """
    if not SB3_AVAILABLE:
        return None

    path = MODEL_PATH + ".zip"

    # Try to load locally first
    if os.path.exists(path):
        try:
            return PPO.load(MODEL_PATH)
        except Exception as e:
            # Best effort: report the failure, then fall back to the download.
            print(f"[WARN] Could not load local model {path}: {e}")

    # If not local (or unreadable), try downloading from GitHub release
    import urllib.request

    url = "https://github.com/Akshara2424/JaamCTRL-OpenEnv/releases/download/v1.0.0/ppo_jaam_ctrl.zip"
    tmp_path = path + ".part"
    try:
        _ensure_models_dir()
        print(f"[INFO] Downloading PPO model from {url}...")
        # Download to a temporary file with a timeout so that a stalled or
        # interrupted transfer never leaves a truncated archive at `path`.
        with urllib.request.urlopen(url, timeout=30) as resp:
            payload = resp.read()
        with open(tmp_path, "wb") as out:
            out.write(payload)
        os.replace(tmp_path, path)
        print(f"[INFO] Downloaded to {path}")
        # NOTE(security): model archives may embed pickled objects, so loading
        # one can execute code. Only keep this URL pointed at a trusted
        # release; consider verifying a checksum before loading.
        return PPO.load(MODEL_PATH)
    except Exception as e:
        print(f"[WARN] Could not download model: {e}")
        return None
    finally:
        # Remove any leftover partial download.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass
def load_training_log() -> dict:
    """
    Load the training log JSON written by train_ppo().

    Returns:
        The parsed log as a dict, or an empty dict if the file is missing,
        unreadable, not valid JSON, or does not contain a JSON object.
    """
    try:
        with open(_LOG_PATH, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file.
        # ValueError: malformed JSON or undecodable bytes.
        return {}
    # Honour the declared return type even if the file holds e.g. a list.
    return data if isinstance(data, dict) else {}