trafficops / server /sim /rl_controller.py
Kunalsinghh's picture
Upload folder using huggingface_hub
92107a5 verified
import os
from typing import Optional
import numpy as np
from .world import World, Intersection
STATE_DIM = 14
MIN_GREEN = 6
DECISION_INTERVAL = 5
class DQNController:
def __init__(self, weights_path: str):
data = np.load(weights_path)
self.feat_w1 = data["feature.0.weight"]
self.feat_b1 = data["feature.0.bias"]
self.feat_w2 = data["feature.2.weight"]
self.feat_b2 = data["feature.2.bias"]
self.val_w1 = data["value.0.weight"]
self.val_b1 = data["value.0.bias"]
self.val_w2 = data["value.2.weight"]
self.val_b2 = data["value.2.bias"]
self.adv_w1 = data["advantage.0.weight"]
self.adv_b1 = data["advantage.0.bias"]
self.adv_w2 = data["advantage.2.weight"]
self.adv_b2 = data["advantage.2.bias"]
def forward(self, states: np.ndarray) -> np.ndarray:
x = states @ self.feat_w1.T + self.feat_b1
x = np.maximum(x, 0)
x = x @ self.feat_w2.T + self.feat_b2
x = np.maximum(x, 0)
v = x @ self.val_w1.T + self.val_b1
v = np.maximum(v, 0)
v = v @ self.val_w2.T + self.val_b2
a = x @ self.adv_w1.T + self.adv_b1
a = np.maximum(a, 0)
a = a @ self.adv_w2.T + self.adv_b2
q = v + (a - a.mean(axis=-1, keepdims=True))
return q
def select_actions(self, states: np.ndarray) -> np.ndarray:
q = self.forward(states)
return q.argmax(axis=-1)
def _queue_norm(world: World, rid: Optional[str]) -> float:
if not rid or rid not in world.roads:
return 0.0
r = world.roads[rid]
return min(1.0, r.queue_at_tail() / max(1, r.length))
def _occ_norm(world: World, rid: Optional[str]) -> float:
if not rid or rid not in world.roads:
return 0.0
r = world.roads[rid]
return min(1.0, r.occupancy() / max(1, r.length))
def extract_state(world: World, I: Intersection) -> np.ndarray:
s = np.zeros(STATE_DIM, dtype=np.float32)
s_rid = I.incoming.get("S")
w_rid = I.incoming.get("W")
s[0] = _queue_norm(world, s_rid)
s[1] = _queue_norm(world, w_rid)
s[2] = _occ_norm(world, s_rid)
s[3] = _occ_norm(world, w_rid)
phase = I.current_phase()
s[4] = 1.0 if frozenset({"N", "S"}) == phase else 0.0
s[5] = min(1.0, I.phase_timer / 45.0)
s[6] = 1.0 if I.phase_timer >= MIN_GREEN else 0.0
s_neighbors = []
w_neighbors = []
for nid in I.neighbors:
NI = world.intersections.get(nid)
if NI is None:
continue
s_neighbors.append(_queue_norm(world, NI.incoming.get("S")))
w_neighbors.append(_queue_norm(world, NI.incoming.get("W")))
s[7] = float(np.mean(s_neighbors)) if s_neighbors else 0.0
s[8] = float(np.mean(w_neighbors)) if w_neighbors else 0.0
e_rid = I.outgoing.get("E")
n_rid = I.outgoing.get("N")
s[9] = _occ_norm(world, e_rid)
s[10] = _occ_norm(world, n_rid)
total_occ = 0.0
total_cap = 0.0
for d in ["S", "W"]:
rid = I.incoming.get(d)
if rid and rid in world.roads:
total_occ += world.roads[rid].occupancy()
total_cap += world.roads[rid].length
s[11] = total_occ / max(1, total_cap)
s_q = world.roads[s_rid].queue_at_tail() if s_rid and s_rid in world.roads else 0
w_q = world.roads[w_rid].queue_at_tail() if w_rid and w_rid in world.roads else 0
s[12] = (s_q - w_q) / max(1.0, float(s_q + w_q))
s[13] = world.tick / max(1, world.horizon)
return s
def rl_step(controller: DQNController, world: World) -> None:
inter_ids = sorted(world.intersections.keys())
states = np.stack([extract_state(world, world.intersections[iid]) for iid in inter_ids])
actions = controller.select_actions(states)
for i, iid in enumerate(inter_ids):
I = world.intersections[iid]
if I.preempt_direction is not None:
continue
if actions[i] == 1 and I.phase_timer >= MIN_GREEN:
I.current_phase_idx = (I.current_phase_idx + 1) % len(I.phases)
I.phase_timer = 0
_controller_cache: Optional[DQNController] = None
def get_controller() -> Optional[DQNController]:
global _controller_cache
if _controller_cache is not None:
return _controller_cache
for candidate in [
os.path.join(os.path.dirname(__file__), "..", "..", "dqn_weights.npz"),
os.path.join(os.path.dirname(__file__), "..", "..", "..", "dqn_weights.npz"),
"/app/env/dqn_weights.npz",
"dqn_weights.npz",
]:
if os.path.exists(candidate):
_controller_cache = DQNController(candidate)
return _controller_cache
return None