File size: 4,661 Bytes
92107a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
from typing import Optional

import numpy as np

from .world import World, Intersection

# Length of the observation vector built by extract_state (feature indices 0..13).
STATE_DIM = 14
# Minimum value I.phase_timer must reach before rl_step may advance the phase;
# the same threshold is exposed to the network as feature 6 of the state.
# NOTE(review): unit is presumably simulation ticks — confirm against World.
MIN_GREEN = 6
# NOTE(review): not referenced by the code visible in this module; presumably
# the number of ticks between rl_step calls, consumed by a caller — verify.
DECISION_INTERVAL = 5


class DQNController:
    """Greedy dueling-DQN policy evaluated with plain NumPy.

    The network is a two-layer ReLU trunk (``feature``) feeding a value head
    and an advantage head, combined as ``Q = V + (A - mean(A))``. Weight and
    bias keys follow a ``<module>.<index>.<weight|bias>`` pattern —
    presumably a PyTorch ``nn.Sequential`` ``state_dict`` exported to
    ``.npz``; verify against the export script.
    """

    def __init__(self, weights_path: str):
        """Load the twelve weight/bias arrays from the ``.npz`` at ``weights_path``.

        Raises:
            OSError: if the file cannot be opened (e.g. FileNotFoundError).
            KeyError: if the archive is missing one of the expected keys.
        """
        # np.load on an .npz returns an NpzFile that keeps the archive open.
        # Use it as a context manager so the handle is released deterministically
        # instead of whenever the garbage collector gets to it. Indexing an
        # NpzFile reads the member into a standalone ndarray, so the arrays
        # remain valid after the archive is closed.
        with np.load(weights_path) as data:
            self.feat_w1 = data["feature.0.weight"]
            self.feat_b1 = data["feature.0.bias"]
            self.feat_w2 = data["feature.2.weight"]
            self.feat_b2 = data["feature.2.bias"]
            self.val_w1 = data["value.0.weight"]
            self.val_b1 = data["value.0.bias"]
            self.val_w2 = data["value.2.weight"]
            self.val_b2 = data["value.2.bias"]
            self.adv_w1 = data["advantage.0.weight"]
            self.adv_b1 = data["advantage.0.bias"]
            self.adv_w2 = data["advantage.2.weight"]
            self.adv_b2 = data["advantage.2.bias"]

    def forward(self, states: np.ndarray) -> np.ndarray:
        """Return Q-values for ``states``.

        Each linear layer is applied as ``x @ W.T + b``, so the last axis of
        ``states`` must equal the input width of ``feature.0.weight``. Leading
        axes act as batch axes. The result keeps those leading axes and has
        one trailing entry per action (the advantage head's output width).
        The value head is assumed to have output width 1 so that it
        broadcasts across actions — confirm against the weights.
        """
        # Shared trunk.
        x = states @ self.feat_w1.T + self.feat_b1
        x = np.maximum(x, 0)
        x = x @ self.feat_w2.T + self.feat_b2
        x = np.maximum(x, 0)
        # State-value head.
        v = x @ self.val_w1.T + self.val_b1
        v = np.maximum(v, 0)
        v = v @ self.val_w2.T + self.val_b2
        # Advantage head.
        a = x @ self.adv_w1.T + self.adv_b1
        a = np.maximum(a, 0)
        a = a @ self.adv_w2.T + self.adv_b2
        # Dueling aggregation: subtracting the mean advantage makes V and A
        # identifiable.
        q = v + (a - a.mean(axis=-1, keepdims=True))
        return q

    def select_actions(self, states: np.ndarray) -> np.ndarray:
        """Return the greedy (arg-max Q) action index along the last axis."""
        q = self.forward(states)
        return q.argmax(axis=-1)


def _queue_norm(world: World, rid: Optional[str]) -> float:
    """Return road ``rid``'s tail queue as a fraction of its length.

    A falsy or unknown road id gives 0.0. The length denominator is floored
    at 1 and the ratio is capped at 1.0.
    """
    if rid and rid in world.roads:
        road = world.roads[rid]
        ratio = road.queue_at_tail() / max(1, road.length)
        return min(1.0, ratio)
    return 0.0


def _occ_norm(world: World, rid: Optional[str]) -> float:
    """Return road ``rid``'s occupancy as a fraction of its length.

    A falsy or unknown road id gives 0.0. The length denominator is floored
    at 1 and the ratio is capped at 1.0.
    """
    if rid and rid in world.roads:
        road = world.roads[rid]
        ratio = road.occupancy() / max(1, road.length)
        return min(1.0, ratio)
    return 0.0


def extract_state(world: World, I: Intersection) -> np.ndarray:
    """Assemble the STATE_DIM-long float32 observation for intersection ``I``.

    Feature layout (indices must match what the DQN weights were trained on):

    * 0, 1   -- S / W incoming tail queue, via ``_queue_norm``
    * 2, 3   -- S / W incoming occupancy, via ``_occ_norm``
    * 4      -- 1.0 if the current phase equals ``frozenset({"N", "S"})``
    * 5      -- ``phase_timer / 45`` capped at 1.0
    * 6      -- 1.0 once ``phase_timer >= MIN_GREEN``
    * 7, 8   -- mean S / W incoming queue (``_queue_norm``) over neighbours
                present in ``world.intersections``; 0.0 when there are none
    * 9, 10  -- E / N outgoing occupancy, via ``_occ_norm``
    * 11     -- summed S+W incoming occupancy / max(1, summed S+W length)
    * 12     -- raw S-vs-W queue imbalance ``(s - w) / max(1, s + w)``
    * 13     -- ``world.tick / max(1, world.horizon)``

    Road ids that are falsy or absent from ``world.roads`` contribute 0.
    """
    roads = world.roads
    south_in = I.incoming.get("S")
    west_in = I.incoming.get("W")
    obs = np.zeros(STATE_DIM, dtype=np.float32)

    # Local incoming-road pressure.
    obs[0], obs[1] = _queue_norm(world, south_in), _queue_norm(world, west_in)
    obs[2], obs[3] = _occ_norm(world, south_in), _occ_norm(world, west_in)

    # Signal-phase status.
    obs[4] = 1.0 if frozenset({"N", "S"}) == I.current_phase() else 0.0
    elapsed = I.phase_timer
    obs[5] = min(1.0, elapsed / 45.0)
    obs[6] = 1.0 if elapsed >= MIN_GREEN else 0.0

    # Neighbourhood pressure: one (S, W) queue pair per known neighbour,
    # evaluated neighbour by neighbour; unknown neighbour ids are skipped.
    neighbour_pairs = [
        (
            _queue_norm(world, nb.incoming.get("S")),
            _queue_norm(world, nb.incoming.get("W")),
        )
        for nb in (world.intersections.get(nid) for nid in I.neighbors)
        if nb is not None
    ]
    if neighbour_pairs:
        south_qs, west_qs = zip(*neighbour_pairs)
        obs[7] = float(np.mean(south_qs))
        obs[8] = float(np.mean(west_qs))

    # Downstream (outgoing) occupancy.
    obs[9] = _occ_norm(world, I.outgoing.get("E"))
    obs[10] = _occ_norm(world, I.outgoing.get("N"))

    # Aggregate S+W incoming load (not capped, unlike the per-road ratios).
    occ_sum, len_sum = 0.0, 0.0
    for rid in (south_in, west_in):
        if not rid or rid not in roads:
            continue
        road = roads[rid]
        occ_sum += road.occupancy()
        len_sum += road.length
    obs[11] = occ_sum / max(1, len_sum)

    # Signed S-vs-W queue imbalance on raw (un-normalised) tail queues.
    south_q = roads[south_in].queue_at_tail() if south_in and south_in in roads else 0
    west_q = roads[west_in].queue_at_tail() if west_in and west_in in roads else 0
    obs[12] = (south_q - west_q) / max(1.0, float(south_q + west_q))

    # Episode progress.
    obs[13] = world.tick / max(1, world.horizon)
    return obs


def rl_step(controller: DQNController, world: World) -> None:
    """Apply one greedy DQN decision to every intersection in ``world``.

    States for all intersections are built in sorted-id order and evaluated
    as one batch. An intersection advances to its next phase (wrapping around
    ``I.phases``) and resets its phase timer when the chosen action is 1 and
    ``phase_timer`` has reached MIN_GREEN. Intersections whose
    ``preempt_direction`` is set are left untouched. A world with no
    intersections is a no-op.
    """
    inter_ids = sorted(world.intersections.keys())
    if not inter_ids:
        # np.stack raises ValueError on an empty sequence; with nothing to
        # control there is nothing to do.
        return
    states = np.stack(
        [extract_state(world, world.intersections[iid]) for iid in inter_ids]
    )
    actions = controller.select_actions(states)
    for iid, action in zip(inter_ids, actions):
        inter = world.intersections[iid]
        if inter.preempt_direction is not None:
            # Preemption overrides the learned policy for this tick.
            continue
        if action == 1 and inter.phase_timer >= MIN_GREEN:
            inter.current_phase_idx = (inter.current_phase_idx + 1) % len(inter.phases)
            inter.phase_timer = 0


# Process-wide memo for the loaded controller; stays None until a weights
# file is found by get_controller.
_controller_cache: Optional[DQNController] = None


def get_controller() -> Optional[DQNController]:
    """Return the shared DQN controller, loading it lazily on first success.

    Candidate weight files are probed in a fixed order. The first existing
    path is loaded and memoised in ``_controller_cache``. When no candidate
    exists, ``None`` is returned and nothing is memoised, so a later call
    probes the filesystem again.
    """
    global _controller_cache
    if _controller_cache is None:
        here = os.path.dirname(__file__)
        search_path = (
            os.path.join(here, "..", "..", "dqn_weights.npz"),
            os.path.join(here, "..", "..", "..", "dqn_weights.npz"),
            "/app/env/dqn_weights.npz",
            "dqn_weights.npz",
        )
        found = next((p for p in search_path if os.path.exists(p)), None)
        if found is not None:
            _controller_cache = DQNController(found)
    return _controller_cache