Spaces:
Sleeping
Sleeping
| import os | |
| from typing import Optional | |
| import numpy as np | |
| from .world import World, Intersection | |
# Length of the per-intersection feature vector built by extract_state().
STATE_DIM = 14
# Smallest phase_timer value at which rl_step() will switch phases; also the
# threshold for the "switch allowed" feature in extract_state().
# NOTE(review): phase_timer units are presumably simulation ticks - confirm.
MIN_GREEN = 6
# Not referenced in the visible part of this file; presumably the number of
# ticks between controller decisions, applied by the caller - TODO confirm.
DECISION_INTERVAL = 5
class DQNController:
    """Inference-only dueling Q-network evaluated with plain NumPy.

    The network has a two-layer shared feature trunk followed by separate
    two-layer value and advantage heads. All weight matrices are applied as
    ``x @ W.T + b``, i.e. they are stored in ``(out_features, in_features)``
    layout. The parameter names ("feature.0.weight", ...) presumably come from
    an exported PyTorch ``state_dict`` - confirm against the training code.
    """

    def __init__(self, weights_path: str):
        """Load every layer parameter from the ``.npz`` archive at *weights_path*.

        Raises:
            OSError: if the archive cannot be opened (propagated from ``np.load``).
            KeyError: if an expected parameter name is missing from the archive.
        """
        # np.load on an .npz returns a lazily-loading NpzFile that keeps the
        # archive open. Reading all arrays inside the context manager makes
        # sure the file handle is released once the weights are in memory.
        with np.load(weights_path) as data:
            self.feat_w1 = data["feature.0.weight"]
            self.feat_b1 = data["feature.0.bias"]
            self.feat_w2 = data["feature.2.weight"]
            self.feat_b2 = data["feature.2.bias"]
            self.val_w1 = data["value.0.weight"]
            self.val_b1 = data["value.0.bias"]
            self.val_w2 = data["value.2.weight"]
            self.val_b2 = data["value.2.bias"]
            self.adv_w1 = data["advantage.0.weight"]
            self.adv_b1 = data["advantage.0.bias"]
            self.adv_w2 = data["advantage.2.weight"]
            self.adv_b2 = data["advantage.2.bias"]

    @staticmethod
    def _affine(x: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
        """Apply one fully connected layer: ``x @ weight.T + bias``."""
        return x @ weight.T + bias

    def forward(self, states: np.ndarray) -> np.ndarray:
        """Return Q-values for *states*.

        Args:
            states: array whose last axis holds the state features; leading
                axes (if any) are treated as batch dimensions.

        Returns:
            Array of Q-values with one entry per action along the last axis.
        """
        # Shared trunk: two ReLU-activated layers.
        hidden = np.maximum(self._affine(states, self.feat_w1, self.feat_b1), 0)
        hidden = np.maximum(self._affine(hidden, self.feat_w2, self.feat_b2), 0)

        # Value head (ReLU hidden layer, linear output).
        value = np.maximum(self._affine(hidden, self.val_w1, self.val_b1), 0)
        value = self._affine(value, self.val_w2, self.val_b2)

        # Advantage head (ReLU hidden layer, linear output).
        advantage = np.maximum(self._affine(hidden, self.adv_w1, self.adv_b1), 0)
        advantage = self._affine(advantage, self.adv_w2, self.adv_b2)

        # Dueling aggregation: centre the advantages so that the value head
        # carries the mean Q-value of each state.
        return value + (advantage - advantage.mean(axis=-1, keepdims=True))

    def select_actions(self, states: np.ndarray) -> np.ndarray:
        """Return the greedy (arg-max Q) action index for each state."""
        return self.forward(states).argmax(axis=-1)
def _queue_norm(world: World, rid: Optional[str]) -> float:
    """Return the tail queue of road *rid* as a fraction of its length.

    The result is capped at 1.0. An empty/``None`` id, or an id that is not
    present in ``world.roads``, yields 0.0.
    """
    if rid and rid in world.roads:
        road = world.roads[rid]
        # max(1, ...) protects against a zero-length road.
        fraction = road.queue_at_tail() / max(1, road.length)
        return min(1.0, fraction)
    return 0.0
def _occ_norm(world: World, rid: Optional[str]) -> float:
    """Return the occupancy of road *rid* as a fraction of its length.

    The result is capped at 1.0. An empty/``None`` id, or an id that is not
    present in ``world.roads``, yields 0.0.
    """
    if rid and rid in world.roads:
        road = world.roads[rid]
        # max(1, ...) protects against a zero-length road.
        fraction = road.occupancy() / max(1, road.length)
        return min(1.0, fraction)
    return 0.0
def extract_state(world: World, I: Intersection) -> np.ndarray:
    """Build the float32 feature vector of length ``STATE_DIM`` for intersection *I*.

    Feature layout (index: meaning):
        0, 1   normalised tail queue on the incoming "S" / "W" road
        2, 3   normalised occupancy on the incoming "S" / "W" road
        4      1.0 when the current phase equals {"N", "S"}, else 0.0
        5      phase timer divided by 45, capped at 1.0
        6      1.0 once the phase timer has reached MIN_GREEN, else 0.0
        7, 8   mean normalised "S" / "W" tail queue over known neighbours
        9, 10  normalised occupancy on the outgoing "E" / "N" road
        11     summed incoming occupancy divided by summed incoming length
        12     signed queue imbalance (S - W) / (S + W)
        13     world.tick divided by world.horizon

    Roads or neighbours that cannot be resolved contribute 0.0.
    """
    roads = world.roads
    south_in = I.incoming.get("S")
    west_in = I.incoming.get("W")

    state = np.zeros(STATE_DIM, dtype=np.float32)

    # Local congestion on the two controlled approaches.
    state[0] = _queue_norm(world, south_in)
    state[1] = _queue_norm(world, west_in)
    state[2] = _occ_norm(world, south_in)
    state[3] = _occ_norm(world, west_in)

    # Signal status.
    ns_phase_active = frozenset({"N", "S"}) == I.current_phase()
    state[4] = 1.0 if ns_phase_active else 0.0
    state[5] = min(1.0, I.phase_timer / 45.0)
    state[6] = 1.0 if I.phase_timer >= MIN_GREEN else 0.0

    # Average queue pressure at neighbouring intersections; unknown neighbour
    # ids are skipped, and with no usable neighbour the slots stay at 0.0.
    neighbour_queues = [
        (
            _queue_norm(world, neighbour.incoming.get("S")),
            _queue_norm(world, neighbour.incoming.get("W")),
        )
        for neighbour in (world.intersections.get(nid) for nid in I.neighbors)
        if neighbour is not None
    ]
    if neighbour_queues:
        south_queues, west_queues = zip(*neighbour_queues)
        state[7] = float(np.mean(south_queues))
        state[8] = float(np.mean(west_queues))

    # Downstream occupancy on the two exits.
    state[9] = _occ_norm(world, I.outgoing.get("E"))
    state[10] = _occ_norm(world, I.outgoing.get("N"))

    # Combined load of both incoming approaches.
    occupied = 0.0
    capacity = 0.0
    for direction in ("S", "W"):
        rid = I.incoming.get(direction)
        if rid and rid in roads:
            occupied += roads[rid].occupancy()
            capacity += roads[rid].length
    state[11] = occupied / max(1, capacity)

    # Raw (un-normalised) queue lengths drive the imbalance feature.
    south_queue = 0
    if south_in and south_in in roads:
        south_queue = roads[south_in].queue_at_tail()
    west_queue = 0
    if west_in and west_in in roads:
        west_queue = roads[west_in].queue_at_tail()
    state[12] = (south_queue - west_queue) / max(1.0, float(south_queue + west_queue))

    # Progress through the episode.
    state[13] = world.tick / max(1, world.horizon)
    return state
def rl_step(controller: DQNController, world: World) -> None:
    """Run one control decision for every intersection in *world*.

    States are extracted for all intersections (in sorted id order), evaluated
    by *controller* in a single batch, and each intersection whose selected
    action is 1 advances to its next phase - provided it is not currently
    pre-empted and its phase timer has reached ``MIN_GREEN``.

    A world without intersections is a no-op.
    """
    inter_ids = sorted(world.intersections)
    if not inter_ids:
        # np.stack raises ValueError on an empty sequence; there is nothing
        # to control, so return quietly instead of crashing the simulation.
        return

    states = np.stack(
        [extract_state(world, world.intersections[iid]) for iid in inter_ids]
    )
    actions = controller.select_actions(states)

    for iid, action in zip(inter_ids, actions):
        inter = world.intersections[iid]
        # A pre-empted intersection keeps whatever phase the pre-emption set.
        if inter.preempt_direction is not None:
            continue
        # Action 1 means "switch"; honour it only after the minimum green time.
        if action == 1 and inter.phase_timer >= MIN_GREEN:
            inter.current_phase_idx = (inter.current_phase_idx + 1) % len(inter.phases)
            inter.phase_timer = 0
# Process-wide controller instance, created by the first successful
# get_controller() call.
_controller_cache: Optional[DQNController] = None


def get_controller() -> Optional[DQNController]:
    """Return the shared DQNController, loading it on first use.

    The weights file ``dqn_weights.npz`` is searched for, in order: two and
    three directory levels above this module, at ``/app/env/``, and in the
    current working directory. The first existing file is loaded and cached.
    Returns ``None`` (and caches nothing) when no candidate exists.
    """
    global _controller_cache
    if _controller_cache is None:
        module_dir = os.path.dirname(__file__)
        search_paths = (
            os.path.join(module_dir, "..", "..", "dqn_weights.npz"),
            os.path.join(module_dir, "..", "..", "..", "dqn_weights.npz"),
            "/app/env/dqn_weights.npz",
            "dqn_weights.npz",
        )
        weights_file = next(
            (path for path in search_paths if os.path.exists(path)), None
        )
        if weights_file is not None:
            _controller_cache = DQNController(weights_file)
    return _controller_cache