ThEyAtH commited on
Commit
a53bea5
Β·
1 Parent(s): e370e67

fix: detect signals against route edge segments, not just nodes

Browse files
Files changed (1) hide show
  1. backend/signal/signal_model.py +53 -13
backend/signal/signal_model.py CHANGED
@@ -71,7 +71,6 @@ EXPECTED IMPROVEMENT
71
  import json
72
  import os
73
  import math
74
- import pickle
75
  import logging
76
  import osmnx as ox
77
 
@@ -97,7 +96,7 @@ class SignalModel:
97
  graph,
98
  registry_file="data/signals_registry.json",
99
  cluster_radius=90,
100
- detection_radius=80,
101
  avg_wait_per_signal=75,
102
  stop_probability=0.85,
103
  ):
@@ -109,10 +108,6 @@ class SignalModel:
109
  self.stop_prob = stop_probability
110
  self.junctions = []
111
 
112
- # Cache file lives next to the registry JSON
113
- base = os.path.splitext(registry_file)[0]
114
- self._cache_file = base + "_clustered.pkl"
115
-
116
  self._load_and_cluster_signals()
117
 
118
  # ── Load + Snap + Cluster ─────────────────────────────────────────────────
@@ -173,6 +168,26 @@ class SignalModel:
173
 
174
  # ── Route Analysis ────────────────────────────────────────────────────────
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  def analyze_route(self, route):
177
  signal_count = 0
178
 
@@ -185,15 +200,24 @@ class SignalModel:
185
  for junction in self.junctions:
186
  j_lat = junction["lat"]
187
  j_lng = junction["lng"]
 
 
 
 
 
 
 
188
 
189
- for (node_lat, node_lng) in route_coords:
190
- # Early exit on latitude diff before computing full distance
191
- if abs(node_lat - j_lat) * _LAT_TO_M > self.detection_radius:
 
192
  continue
193
 
194
- dist = _fast_dist_m(node_lat, node_lng, j_lat, j_lng)
195
  if dist <= self.detection_radius:
196
  signal_count += 1
 
197
  break
198
 
199
  expected_stops = signal_count * self.stop_prob
@@ -208,11 +232,27 @@ class SignalModel:
208
  # ── Attach Signal Weights to Graph ────────────────────────────────────────
209
 
210
  def attach_signal_weights(self):
211
- # Snap each junction centroid to the nearest graph node
 
 
 
212
  node_to_junction = {}
 
 
 
 
 
213
  for jid, junction in enumerate(self.junctions):
214
- node = ox.distance.nearest_nodes(self.G, junction["lng"], junction["lat"])
215
- node_to_junction[node] = jid
 
 
 
 
 
 
 
 
216
 
217
  expected_delay_min = self.stop_prob * (self.avg_wait / 60.0)
218
 
 
71
  import json
72
  import os
73
  import math
 
74
  import logging
75
  import osmnx as ox
76
 
 
96
  graph,
97
  registry_file="data/signals_registry.json",
98
  cluster_radius=90,
99
+ detection_radius=150,
100
  avg_wait_per_signal=75,
101
  stop_probability=0.85,
102
  ):
 
108
  self.stop_prob = stop_probability
109
  self.junctions = []
110
 
 
 
 
 
111
  self._load_and_cluster_signals()
112
 
113
  # ── Load + Snap + Cluster ─────────────────────────────────────────────────
 
168
 
169
  # ── Route Analysis ────────────────────────────────────────────────────────
170
 
171
+ @staticmethod
172
+ def _point_to_segment_dist(plat, plng, alat, alng, blat, blng):
173
+ """Perpendicular distance from point P to segment A→B in metres."""
174
+ ax = (alng - plng) * _LAT_TO_M * math.cos(plat * math.pi / 180)
175
+ ay = (alat - plat) * _LAT_TO_M
176
+ bx = (blng - plng) * _LAT_TO_M * math.cos(plat * math.pi / 180)
177
+ by = (blat - plat) * _LAT_TO_M
178
+
179
+ ab_sq = ax*ax + ay*ay + bx*bx + by*by # rough check
180
+ dx, dy = bx - ax, by - ay
181
+ seg_len_sq = dx*dx + dy*dy
182
+
183
+ if seg_len_sq == 0:
184
+ return math.sqrt(ax*ax + ay*ay)
185
+
186
+ t = max(0.0, min(1.0, ((-ax)*dx + (-ay)*dy) / seg_len_sq))
187
+ cx = ax + t*dx
188
+ cy = ay + t*dy
189
+ return math.sqrt(cx*cx + cy*cy)
190
+
191
  def analyze_route(self, route):
192
  signal_count = 0
193
 
 
200
  for junction in self.junctions:
201
  j_lat = junction["lat"]
202
  j_lng = junction["lng"]
203
+ detected = False
204
+
205
+ # Check each edge segment of the route, not just nodes
206
+ # Fixes sparse-node highways where no single node is within radius
207
+ for i in range(len(route_coords) - 1):
208
+ alat, alng = route_coords[i]
209
+ blat, blng = route_coords[i + 1]
210
 
211
+ # Quick bbox reject
212
+ if (min(alat, blat) - j_lat) * _LAT_TO_M > self.detection_radius:
213
+ continue
214
+ if (j_lat - max(alat, blat)) * _LAT_TO_M > self.detection_radius:
215
  continue
216
 
217
+ dist = self._point_to_segment_dist(j_lat, j_lng, alat, alng, blat, blng)
218
  if dist <= self.detection_radius:
219
  signal_count += 1
220
+ detected = True
221
  break
222
 
223
  expected_stops = signal_count * self.stop_prob
 
232
  # ── Attach Signal Weights to Graph ────────────────────────────────────────
233
 
234
  def attach_signal_weights(self):
235
+ # Mark ALL nodes within detection_radius of each junction centroid.
236
+ # A single dot in the registry may represent a 4-way junction β€”
237
+ # all roads passing through that junction must carry the signal weight,
238
+ # not just the one road the dot happens to snap to.
239
  node_to_junction = {}
240
+ all_node_coords = [
241
+ (n, self.G.nodes[n]["y"], self.G.nodes[n]["x"])
242
+ for n in self.G.nodes
243
+ ]
244
+
245
  for jid, junction in enumerate(self.junctions):
246
+ jlat, jlng = junction["lat"], junction["lng"]
247
+ for node, nlat, nlng in all_node_coords:
248
+ if abs(nlat - jlat) * _LAT_TO_M > self.detection_radius:
249
+ continue
250
+ if _fast_dist_m(jlat, jlng, nlat, nlng) <= self.detection_radius:
251
+ # If node already tagged by a closer junction, keep that one
252
+ if node not in node_to_junction:
253
+ node_to_junction[node] = jid
254
+
255
+ logger.info(f"[SignalModel] Nodes tagged with signal: {len(node_to_junction)}")
256
 
257
  expected_delay_min = self.stop_prob * (self.avg_wait / 60.0)
258