JaamCTRL-OpenEnv / src /rl_agent.py
Akshara
Add automatic model download fallback from GitHub releases
d555b05
"""
src/rl_agent.py
---------------
PPO reinforcement-learning agent for Jaam Ctrl.
Exports required by app.py:
    train_ppo(total_timesteps, learning_rate, progress_callback) -> str (model path)
    load_ppo_model()    -> model | None
    load_training_log() -> dict
    MODEL_PATH          : str
    SB3_AVAILABLE       : bool
"""
from __future__ import annotations
import os
import json
import time
import math
import numpy as np
# ---------------------------------------------------------------------------
# stable-baselines3 availability
# ---------------------------------------------------------------------------
try:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
import gymnasium as gym
SB3_AVAILABLE = True
except ImportError:
SB3_AVAILABLE = False
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# All locations are derived from this file's own position so that they do not
# depend on the process's current working directory.
_SRC_DIR = os.path.split(os.path.abspath(__file__))[0]   # directory holding this file
_ROOT_DIR = os.path.split(_SRC_DIR)[0]                   # its parent directory
_MODELS_DIR = os.path.join(_ROOT_DIR, "models")
# Model path WITHOUT the ".zip" suffix; the loader appends ".zip" when it
# checks for the file on disk.
MODEL_PATH = os.path.join(_ROOT_DIR, "models", "ppo_jaam_ctrl")
# JSON summary written at the end of training.
_LOG_PATH = os.path.join(_ROOT_DIR, "models", "training_log.json")
def _ensure_models_dir() -> None:
    """Create the models directory (and missing parents); no-op if it exists."""
    os.makedirs(name=_MODELS_DIR, exist_ok=True)
# ---------------------------------------------------------------------------
# Gym environment (defined only when SB3 is importable; it uses a synthetic
# queue model and makes no SUMO calls)
# ---------------------------------------------------------------------------
if SB3_AVAILABLE:

    class JaamCtrlEnv(gym.Env):
        """
        Gymnasium environment modelling a 3-junction signalised corridor.

        The dynamics are a lightweight synthetic queue model (Poisson
        arrivals, phase-dependent service). Nothing in this class talks to
        SUMO.

        Observation: 18-dim float32 in [0, 1] (6 features × 3 junctions):
                     [queue_EW, queue_NS, EW_green, NS_green, phase_age, tput]
        Action:      Discrete(8) — 3-bit mask; bit j set means "request a
                     phase switch at junction j".

        Phases cycle 0 -> 1 -> 2 -> 3 -> 0. Phase 0 serves EW and phase 2
        serves NS at the full rate; phases 1 and 3 serve neither (presumably
        the amber/clearance phases).
        """

        metadata = {"render_modes": []}

        def __init__(self):
            super().__init__()
            self.observation_space = gym.spaces.Box(
                low=0.0, high=1.0, shape=(18,), dtype=np.float32
            )
            self.action_space = gym.spaces.Discrete(8)
            self._step = 0
            self._max_steps = 180  # 180 × 10s = 1800s episode
            self._queues = np.zeros((3, 2), dtype=np.float32)  # [junc, ew/ns]
            self._phases = np.zeros(3, dtype=np.int32)
            self._phase_ages = np.zeros(3, dtype=np.float32)

        def reset(self, *, seed=None, options=None):
            """Start a new episode with random initial queues; returns (obs, info)."""
            super().reset(seed=seed)
            self._step = 0
            self._queues = self.np_random.uniform(0, 5, (3, 2)).astype(np.float32)
            self._phases = np.zeros(3, dtype=np.int32)
            self._phase_ages = np.zeros(3, dtype=np.float32)
            return self._obs(), {}

        def step(self, action: int):
            """
            Advance the model by one 10 s control step.

            Returns the Gymnasium 5-tuple
            (obs, reward, terminated, truncated, info).
            """
            # Decode 3-bit action: bit j is the switch request for junction j.
            bits = [(action >> i) & 1 for i in range(3)]

            # Age every phase; honour a switch request only once the current
            # phase has run for its minimum time (15 s green, 4 s otherwise).
            for j in range(3):
                self._phase_ages[j] += 10  # 10s control step
                if bits[j] == 1:
                    cur_ph = self._phases[j]
                    min_age = 15 if cur_ph in (0, 2) else 4
                    if self._phase_ages[j] >= min_age:
                        self._phases[j] = (cur_ph + 1) % 4
                        self._phase_ages[j] = 0

            # Snapshot the queue total BEFORE the update so the delay reward
            # can measure the change caused by this step.
            queue_before = float(self._queues.sum())

            # Simulate queue evolution (Poisson arrivals, phase-dependent service)
            rng = self.np_random
            arr = rng.poisson(3, (3, 2)).astype(np.float32)  # arrivals
            for j in range(3):
                ph = self._phases[j]
                # Green phase for EW = 0, NS = 2; the other approach (and both
                # approaches in phases 1/3) only trickles at 0.5.
                service = np.array(
                    [
                        4.0 if ph == 0 else 0.5,
                        4.0 if ph == 2 else 0.5,
                    ],
                    dtype=np.float32,
                )
                self._queues[j] = np.clip(
                    self._queues[j] + arr[j] - service, 0, 25
                )

            # Reward
            queue_after = float(self._queues.sum())
            # Delay proxy: 2.0 per queued vehicle. Comparing the pre- and
            # post-update totals makes this term respond to the action;
            # deriving both values from the same snapshot would make it a
            # constant.
            prev_delay = queue_before * 2.0
            new_delay = queue_after * 2.0
            delay_r = math.tanh((prev_delay - new_delay) / max(prev_delay, 1))
            # NOTE(review): "throughput" is computed from arrivals, which do
            # not depend on the action — confirm whether departures were meant.
            throughput_r = 0.5 * min(float(arr.sum()) / 10.0, 1.0)
            # Penalise imbalance between junctions (coefficient of variation
            # of the per-junction mean queue).
            q_std = float(self._queues.mean(axis=1).std())
            q_mean = float(self._queues.mean()) + 1e-6  # avoid division by zero
            balance_r = -0.3 * q_std / q_mean
            # Fraction of junctions whose mean queue exceeds 20 (cap is 25).
            gridlock = float((self._queues.mean(axis=1) > 20).sum()) / 3.0
            gridlock_r = -0.4 * gridlock
            reward = delay_r + throughput_r + balance_r + gridlock_r

            self._step += 1
            # NOTE(review): hitting the step limit is reported as `terminated`;
            # Gymnasium's convention for a time limit is `truncated`. Left
            # unchanged because callers may unpack the third element as "done".
            done = self._step >= self._max_steps
            return self._obs(), reward, done, False, {}

        def _obs(self) -> np.ndarray:
            """Build the flat 18-dim observation (6 features per junction)."""
            obs = []
            for j in range(3):
                qew = self._queues[j, 0] / 25.0  # queues are clipped to 25
                qns = self._queues[j, 1] / 25.0
                ph = self._phases[j]
                ph_ew = 1.0 if ph == 0 else 0.0
                ph_ns = 1.0 if ph == 2 else 0.0
                age = min(self._phase_ages[j] / 60.0, 1.0)
                tput = 0.5  # constant placeholder feature
                obs.extend([qew, qns, ph_ew, ph_ns, age, tput])
            return np.array(obs, dtype=np.float32)
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
def train_ppo(
    total_timesteps: int = 3000,
    learning_rate: float = 3e-4,
    progress_callback=None,
) -> str:
    """
    Train a PPO agent on JaamCtrlEnv and save it to MODEL_PATH.

    Args:
        total_timesteps: Requested number of environment steps.
        learning_rate: Optimiser learning rate passed to PPO.
        progress_callback: Optional callable invoked after every environment
            step as ``progress_callback(steps_done, total_timesteps)``.
            ``steps_done`` never exceeds ``total_timesteps``.

    Returns:
        The model path (without .zip).

    Raises:
        RuntimeError: if stable-baselines3 is not installed.

    Side effects: writes the model archive and a JSON training log into the
    models directory.
    """
    if not SB3_AVAILABLE:
        raise RuntimeError(
            "stable-baselines3 is not installed. "
            "Run: pip install stable-baselines3 gymnasium"
        )
    _ensure_models_dir()

    episode_rewards: list[float] = []
    episode_delays: list[float] = []
    ep_reward = 0.0
    n_calls = 0

    def _on_step(locals_, globals_):
        """Per-step hook: accumulate episode reward and report progress."""
        nonlocal ep_reward, n_calls
        n_calls += 1
        # Single environment (n_envs=1), so index 0 is the only entry.
        ep_reward += float(locals_["rewards"][0])
        if locals_["dones"][0]:
            episode_rewards.append(ep_reward)
            # Heuristic delay estimate derived from the episode reward.
            episode_delays.append(max(0, 62 - ep_reward * 10))
            ep_reward = 0.0
        if progress_callback:
            # PPO collects rollouts in whole chunks of n_steps, so the number
            # of steps actually run is rounded up past total_timesteps; clamp
            # so callers never see progress above 100 %.
            progress_callback(min(n_calls, total_timesteps), total_timesteps)
        return True  # returning False would abort training

    env = make_vec_env(JaamCtrlEnv, n_envs=1)
    try:
        model = PPO(
            "MlpPolicy",
            env,
            learning_rate=learning_rate,
            n_steps=128,
            batch_size=32,
            n_epochs=5,
            gamma=0.95,
            policy_kwargs=dict(net_arch=[128, 128]),
            verbose=0,
        )
        model.learn(
            total_timesteps=total_timesteps,
            callback=_on_step,
            reset_num_timesteps=True,
        )
        model.save(MODEL_PATH)
    finally:
        # Release the vectorised environment even if training fails.
        env.close()

    # Save training log
    log = {
        "total_episodes": len(episode_rewards),
        "episode_rewards": episode_rewards,
        "episode_delays": episode_delays,
        "mean_reward": float(np.mean(episode_rewards)) if episode_rewards else 0.0,
        "best_reward": float(max(episode_rewards)) if episode_rewards else 0.0,
    }
    with open(_LOG_PATH, "w", encoding="utf-8") as f:
        json.dump(log, f)
    return MODEL_PATH
# ---------------------------------------------------------------------------
# Loading
# ---------------------------------------------------------------------------
def load_ppo_model():
    """
    Load the saved PPO model from disk.

    If no usable local archive exists (missing, or present but unloadable),
    attempts to download it from the project's GitHub release.

    Returns the model object, or None if stable-baselines3 is unavailable or
    the model could be neither loaded nor downloaded. Never raises for
    load/download failures; they are reported on stdout.
    """
    if not SB3_AVAILABLE:
        return None

    path = MODEL_PATH + ".zip"

    # Try to load locally first
    if os.path.exists(path):
        try:
            return PPO.load(MODEL_PATH)
        except Exception as e:
            # Fall through to a fresh download, but say why.
            print(f"[WARN] Could not load local model: {e}")

    # No usable local copy: try downloading from the GitHub release.
    # NOTE(review): the archive is deserialised by PPO.load, so this URL must
    # stay pointed at a trusted, project-controlled release.
    tmp_path = path + ".part"
    try:
        import shutil
        import urllib.request

        _ensure_models_dir()
        url = "https://github.com/Akshara2424/JaamCTRL-OpenEnv/releases/download/v1.0.0/ppo_jaam_ctrl.zip"
        print(f"[INFO] Downloading PPO model from {url}...")
        # Download to a temporary file with a timeout, then move it into
        # place atomically, so a hung or interrupted transfer never leaves a
        # truncated archive at the final path.
        with urllib.request.urlopen(url, timeout=30) as resp, open(tmp_path, "wb") as out:
            shutil.copyfileobj(resp, out)
        os.replace(tmp_path, path)
        print(f"[INFO] Downloaded to {path}")
        return PPO.load(MODEL_PATH)
    except Exception as e:
        print(f"[WARN] Could not download model: {e}")
        return None
    finally:
        # Remove a leftover partial download, if any.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass
def load_training_log() -> dict:
    """
    Read back the JSON training log produced by train_ppo().

    Best-effort: a missing, unreadable, or malformed log yields an empty
    dict instead of an exception.
    """
    try:
        with open(_LOG_PATH) as log_file:
            parsed = json.load(log_file)
    except Exception:
        # Covers the "file does not exist" case as well as read/parse errors.
        return {}
    return parsed