AxelHolst commited on
Commit
0fd960e
·
1 Parent(s): 049ee46

feat: add high-res heatmap v3 with Hopsworks storage

Browse files

- Update app.py with v3/v2 fallback fetch from Hopsworks
- Add batch prediction to predictor.py
- Add contours.py for GeoJSON grid cell generation
- Add trip_info.py for bus stop/route info
- Update requirements.txt with scipy, shapely, matplotlib

Files changed (6) hide show
  1. app.py +294 -72
  2. contours.py +350 -0
  3. predictor.py +96 -11
  4. requirements.txt +3 -0
  5. trip_info.py +65 -0
  6. weather.py +24 -2
app.py CHANGED
@@ -6,6 +6,7 @@ bus crowding in Östergötland.
6
  """
7
 
8
  import streamlit as st
 
9
 
10
  # Page config - MUST be first Streamlit command
11
  st.set_page_config(
@@ -16,32 +17,37 @@ st.set_page_config(
16
 
17
  import os
18
  import folium
19
- from folium.plugins import HeatMap
20
  from streamlit_folium import st_folium
21
  import numpy as np
22
  from datetime import datetime, timedelta
23
 
 
 
24
  # Import prediction and data fetching modules
25
  from predictor import predict_occupancy, load_model, OCCUPANCY_LABELS
26
  from weather import get_weather_for_prediction
27
  from holidays import get_holiday_features
 
 
28
 
29
  # Constants
30
  DEFAULT_LAT = 58.4108
31
  DEFAULT_LON = 15.6214
32
- DEFAULT_ZOOM = 10
33
 
 
 
34
  BOUNDS = {
35
- "min_lat": 57.8,
36
- "max_lat": 58.9,
37
- "min_lon": 14.5,
38
- "max_lon": 16.8
39
  }
40
 
41
- # Color scheme for occupancy levels
42
  OCCUPANCY_COLORS = {
43
  0: "#22c55e", # Empty - green
44
- 1: "#22c55e", # Many seats - green
45
  2: "#eab308", # Few seats - yellow
46
  3: "#f97316", # Standing - orange
47
  4: "#ef4444", # Crushed - red
@@ -49,6 +55,25 @@ OCCUPANCY_COLORS = {
49
  6: "#6b7280", # Not accepting - gray
50
  }
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  @st.cache_resource
54
  def get_model():
@@ -60,40 +85,167 @@ def get_model():
60
  return None
61
 
62
 
63
- def generate_heatmap_data(hour, day_of_week, weather, holidays):
64
- """Generate heat map data by predicting crowding across a grid."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  model = get_model()
66
  if model is None:
67
- return []
68
 
69
- # Create grid of points across Östergötland
70
  lat_steps = 15
71
  lon_steps = 20
72
  lats = np.linspace(BOUNDS["min_lat"], BOUNDS["max_lat"], lat_steps)
73
  lons = np.linspace(BOUNDS["min_lon"], BOUNDS["max_lon"], lon_steps)
74
 
75
- heatmap_data = []
 
76
 
 
77
  for lat in lats:
78
  for lon in lons:
79
  try:
80
- pred_class, confidence, _ = predict_occupancy(
81
  lat=lat, lon=lon, hour=hour, day_of_week=day_of_week,
82
  weather=weather, holidays=holidays
83
  )
84
- # Weight by occupancy level (higher = more crowded = more intense)
85
- intensity = pred_class / 5.0 # Normalize to 0-1
86
- if intensity > 0.1: # Only show if there's some crowding
87
- heatmap_data.append([lat, lon, intensity])
88
  except Exception:
89
- pass
90
-
91
- return heatmap_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
 
94
  def create_map(selected_lat=None, selected_lon=None, show_heatmap=False,
95
- heatmap_data=None):
96
- """Create a Folium map with optional marker and heatmap."""
97
  center_lat = selected_lat if selected_lat else DEFAULT_LAT
98
  center_lon = selected_lon if selected_lon else DEFAULT_LON
99
 
@@ -103,30 +255,37 @@ def create_map(selected_lat=None, selected_lon=None, show_heatmap=False,
103
  tiles="CartoDB positron"
104
  )
105
 
106
- # Add coverage area rectangle
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  folium.Rectangle(
108
  bounds=[[BOUNDS["min_lat"], BOUNDS["min_lon"]],
109
  [BOUNDS["max_lat"], BOUNDS["max_lon"]]],
110
- color="#3388ff",
111
  fill=False,
112
- weight=2,
113
- opacity=0.5,
 
114
  ).add_to(m)
115
 
116
- # Add heatmap if enabled
117
- if show_heatmap and heatmap_data and len(heatmap_data) > 0:
118
- HeatMap(
119
- data=heatmap_data,
120
- min_opacity=0.3,
121
- radius=25,
122
- blur=15,
123
- ).add_to(m)
124
-
125
  # Add marker if location selected
126
  if selected_lat and selected_lon:
127
  folium.Marker(
128
  [selected_lat, selected_lon],
129
  tooltip=f"Selected: {selected_lat:.4f}, {selected_lon:.4f}",
 
130
  ).add_to(m)
131
 
132
  return m
@@ -146,7 +305,7 @@ def make_prediction(lat, lon, selected_datetime):
146
  weather = get_weather_for_prediction(lat, lon, selected_datetime)
147
  holidays = get_holiday_features(selected_datetime)
148
 
149
- pred_class, confidence, probs = predict_occupancy(
150
  lat=lat, lon=lon,
151
  hour=selected_datetime.hour,
152
  day_of_week=selected_datetime.weekday(),
@@ -154,10 +313,16 @@ def make_prediction(lat, lon, selected_datetime):
154
  holidays=holidays
155
  )
156
 
 
 
 
 
 
157
  return pred_class, confidence, {
158
  "weather": weather,
159
  "holidays": holidays,
160
- "datetime": selected_datetime
 
161
  }
162
  except Exception as e:
163
  return None, None, str(e)
@@ -170,13 +335,13 @@ if "selected_lon" not in st.session_state:
170
  st.session_state.selected_lon = DEFAULT_LON
171
 
172
  # Header
173
- st.title("🐟 HappySardines")
174
- st.markdown("*How packed are buses in Östergötland?*")
175
 
176
  # Check if model is available
177
  model = get_model()
178
  if model is None:
179
- st.error("⚠️ Could not load prediction model. Please check the configuration.")
180
  st.stop()
181
 
182
  # Sidebar controls
@@ -198,19 +363,20 @@ with st.sidebar:
198
 
199
  # View mode
200
  st.subheader("View Mode")
201
- show_heatmap = st.toggle("Show Crowding Forecast", value=False,
202
  help="Display predicted crowding across the region")
203
 
204
  if show_heatmap:
205
- st.info("🔥 Heat map shows predicted crowding levels. Red = busy, Green = quiet.")
206
-
207
- if st.button("Generate Heat Map", type="primary"):
208
- with st.spinner("Generating predictions across region..."):
209
- weather = get_weather_for_prediction(DEFAULT_LAT, DEFAULT_LON, selected_datetime)
210
- holidays = get_holiday_features(selected_datetime)
211
- st.session_state.heatmap_data = generate_heatmap_data(
212
- hour, selected_date.weekday(), weather, holidays
213
- )
 
214
 
215
  st.divider()
216
 
@@ -224,10 +390,10 @@ with st.sidebar:
224
  - 🕐 Time of day
225
  - 📅 Day of week
226
  - 🌡️ Weather conditions
227
- - 🎉 Holidays
228
 
229
  **Data sources:**
230
- - Bus occupancy data from Östgötatrafiken
231
  - Weather from Open-Meteo
232
  - Holidays from Svenska Dagar API
233
 
@@ -238,24 +404,34 @@ with st.sidebar:
238
  col1, col2 = st.columns([2, 1])
239
 
240
  with col1:
241
- st.subheader("📍 Click on the map to select a location")
242
 
243
- # Get heatmap data if available
244
- heatmap_data = st.session_state.get("heatmap_data", [])
 
 
 
 
 
 
 
 
 
245
 
246
  # Create and display map
247
  m = create_map(
248
  selected_lat=st.session_state.selected_lat,
249
  selected_lon=st.session_state.selected_lon,
250
  show_heatmap=show_heatmap,
251
- heatmap_data=heatmap_data
252
  )
253
 
 
254
  map_data = st_folium(
255
  m,
256
  height=500,
257
  use_container_width=True,
258
- key="map"
259
  )
260
 
261
  # Handle map clicks
@@ -266,17 +442,18 @@ with col1:
266
  st.rerun()
267
 
268
  with col2:
269
- st.subheader("🔮 Prediction")
270
 
271
  # Show selected coordinates
272
  st.markdown(f"**Location:** {st.session_state.selected_lat:.4f}, {st.session_state.selected_lon:.4f}")
273
 
274
  # Make prediction
275
- pred_class, confidence, result = make_prediction(
276
- st.session_state.selected_lat,
277
- st.session_state.selected_lon,
278
- selected_datetime
279
- )
 
280
 
281
  if pred_class is not None:
282
  label_info = OCCUPANCY_LABELS[pred_class]
@@ -307,17 +484,62 @@ with col2:
307
  if isinstance(result, dict):
308
  weather = result["weather"]
309
  holidays = result["holidays"]
 
310
 
311
- day_type = "🎉 Holiday" if holidays.get("is_red_day") else (
312
- "🏖️ Work-free day" if holidays.get("is_work_free") else "📅 Regular day"
313
  )
314
 
315
- st.markdown(f"""
316
- **Conditions:**
317
- - 🌡️ {weather.get('temperature_2m', '?'):.0f}°C
318
- - {day_type}
319
- - {selected_datetime.strftime('%A')}
320
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
  elif isinstance(result, str):
323
  st.error(result)
 
6
  """
7
 
8
  import streamlit as st
9
+ import json
10
 
11
  # Page config - MUST be first Streamlit command
12
  st.set_page_config(
 
17
 
18
  import os
19
  import folium
 
20
  from streamlit_folium import st_folium
21
  import numpy as np
22
  from datetime import datetime, timedelta
23
 
24
+ import hopsworks
25
+
26
  # Import prediction and data fetching modules
27
  from predictor import predict_occupancy, load_model, OCCUPANCY_LABELS
28
  from weather import get_weather_for_prediction
29
  from holidays import get_holiday_features
30
+ from trip_info import load_static_trip_info, find_nearest_trip, load_static_stops_info, find_closest_stop
31
+ from contours import load_contours_from_file, grid_to_cells_geojson
32
 
33
  # Constants
34
  DEFAULT_LAT = 58.4108
35
  DEFAULT_LON = 15.6214
36
+ DEFAULT_ZOOM = 9 # Slightly zoomed out to show more of the region
37
 
38
+ # Bounds derived from actual GTFS stop locations (3119 stops)
39
+ # Run ui/get_boundaries.py to recalculate if needed
40
  BOUNDS = {
41
+ "min_lat": 56.6414,
42
+ "max_lat": 58.8654,
43
+ "min_lon": 14.6144,
44
+ "max_lon": 16.9578,
45
  }
46
 
47
+ # Color scheme for occupancy levels (must match contours.py CLASS_COLORS)
48
  OCCUPANCY_COLORS = {
49
  0: "#22c55e", # Empty - green
50
+ 1: "#84cc16", # Many seats - lime
51
  2: "#eab308", # Few seats - yellow
52
  3: "#f97316", # Standing - orange
53
  4: "#ef4444", # Crushed - red
 
55
  6: "#6b7280", # Not accepting - gray
56
  }
57
 
58
+ # Lazy-load static data (deferred to avoid blocking app startup)
59
+ @st.cache_resource
60
+ def get_static_trip_df():
61
+ """Load static trip info from Hopsworks (cached)."""
62
+ try:
63
+ return load_static_trip_info()
64
+ except Exception as e:
65
+ print(f"Warning: Could not load static trip info: {e}")
66
+ return None
67
+
68
+ @st.cache_resource
69
+ def get_static_stops_df():
70
+ """Load static stops info from Hopsworks (cached)."""
71
+ try:
72
+ return load_static_stops_info()
73
+ except Exception as e:
74
+ print(f"Warning: Could not load static stops info: {e}")
75
+ return None
76
+
77
 
78
  @st.cache_resource
79
  def get_model():
 
85
  return None
86
 
87
 
88
+ @st.cache_data
89
+ def cached_predict_occupancy(lat, lon, hour, day_of_week, weather, holidays):
90
+ return predict_occupancy(lat, lon, hour, day_of_week, weather, holidays)
91
+
92
+
93
+ @st.cache_data(ttl=3600) # Cache for 1 hour
94
+ def fetch_heatmaps_from_hopsworks():
95
+ """
96
+ Fetch all precomputed heatmaps from Hopsworks Feature Store.
97
+
98
+ Tries v3 (high-res 40x50) first, falls back to v2 (low-res 20x25).
99
+
100
+ Returns dict mapping (hour, weekday) -> GeoJSON FeatureCollection
101
+ """
102
+ try:
103
+ print("Fetching heatmaps from Hopsworks...")
104
+ project = hopsworks.login()
105
+ fs = project.get_feature_store()
106
+
107
+ # Try v3 first (high-res), fall back to v2 (low-res)
108
+ for version in [3, 2]:
109
+ try:
110
+ heatmap_fg = fs.get_feature_group("heatmap_geojson_fg", version=version)
111
+ df = heatmap_fg.read()
112
+
113
+ if df is not None and not df.empty:
114
+ # Convert to dict with tuple keys
115
+ heatmaps = {}
116
+ for _, row in df.iterrows():
117
+ key = (int(row["hour"]), int(row["weekday"]))
118
+ geojson = json.loads(row["geojson"])
119
+ heatmaps[key] = geojson
120
+
121
+ print(f"Loaded {len(heatmaps)} heatmaps from Hopsworks v{version}")
122
+ return heatmaps
123
+ else:
124
+ print(f"No data in Hopsworks v{version}, trying fallback...")
125
+ except Exception as e:
126
+ print(f"Could not fetch v{version}: {e}")
127
+ continue
128
+
129
+ print("No heatmap data found in any Hopsworks version")
130
+ return {}
131
+
132
+ except Exception as e:
133
+ print(f"Error fetching heatmaps from Hopsworks: {e}")
134
+ return {}
135
+
136
+
137
+ def load_precomputed_contours():
138
+ """Load precomputed contour GeoJSON from file (not cached to pick up new files)."""
139
+ script_dir = os.path.dirname(os.path.abspath(__file__))
140
+ contours_path = os.path.join(script_dir, "precomputed_contours.json")
141
+
142
+ if os.path.exists(contours_path):
143
+ try:
144
+ contours = load_contours_from_file(contours_path)
145
+ print(f"Loaded {len(contours)} precomputed time slots from {contours_path}")
146
+ return contours
147
+ except Exception as e:
148
+ print(f"Error loading contours: {e}")
149
+ return {}
150
+ print(f"Contours file not found: {contours_path}")
151
+ return {}
152
+
153
+
154
+ def generate_contours_on_demand(hour, day_of_week, weather, holidays):
155
+ """
156
+ Generate grid cell GeoJSON on-demand if precomputed data is not available.
157
+ This is slower but provides a fallback.
158
+ """
159
  model = get_model()
160
  if model is None:
161
+ return None
162
 
163
+ # Grid for on-demand generation (smaller for speed)
164
  lat_steps = 15
165
  lon_steps = 20
166
  lats = np.linspace(BOUNDS["min_lat"], BOUNDS["max_lat"], lat_steps)
167
  lons = np.linspace(BOUNDS["min_lon"], BOUNDS["max_lon"], lon_steps)
168
 
169
+ lat_step = (BOUNDS["max_lat"] - BOUNDS["min_lat"]) / (lat_steps - 1)
170
+ lon_step = (BOUNDS["max_lon"] - BOUNDS["min_lon"]) / (lon_steps - 1)
171
 
172
+ prediction_data = []
173
  for lat in lats:
174
  for lon in lons:
175
  try:
176
+ pred_class, confidence, _ = cached_predict_occupancy(
177
  lat=lat, lon=lon, hour=hour, day_of_week=day_of_week,
178
  weather=weather, holidays=holidays
179
  )
180
+ prediction_data.append([lat, lon, pred_class])
 
 
 
181
  except Exception:
182
+ prediction_data.append([lat, lon, 0])
183
+
184
+ # Convert to GeoJSON grid cells
185
+ return grid_to_cells_geojson(prediction_data, lat_step, lon_step)
186
+
187
+
188
+ def get_test_contour_geojson():
189
+ """
190
+ Return a simple hardcoded test GeoJSON to verify rendering works.
191
+ Creates a small grid of cells with different colors.
192
+ """
193
+ # Create a 3x3 grid of test cells around Linköping
194
+ center_lat = 58.41
195
+ center_lon = 15.62
196
+ cell_size = 0.15
197
+
198
+ # Test predictions: mix of classes
199
+ test_data = [
200
+ (center_lat - cell_size, center_lon - cell_size, 0), # green
201
+ (center_lat - cell_size, center_lon, 0), # green
202
+ (center_lat - cell_size, center_lon + cell_size, 1), # green
203
+ (center_lat, center_lon - cell_size, 0), # green
204
+ (center_lat, center_lon, 2), # yellow
205
+ (center_lat, center_lon + cell_size, 2), # yellow
206
+ (center_lat + cell_size, center_lon - cell_size, 0), # green
207
+ (center_lat + cell_size, center_lon, 3), # orange
208
+ (center_lat + cell_size, center_lon + cell_size, 0), # green
209
+ ]
210
+
211
+ return grid_to_cells_geojson(test_data, cell_size, cell_size)
212
+
213
+
214
+ def get_contour_geojson(hour, day_of_week, weather=None, holidays=None):
215
+ """
216
+ Get contour GeoJSON for the given hour and day of week.
217
+
218
+ Tries sources in order:
219
+ 1. Hopsworks Feature Store (primary)
220
+ 2. Local JSON file (fallback)
221
+ 3. Test contours (last resort)
222
+ """
223
+ key = (hour, day_of_week)
224
+
225
+ # Try Hopsworks first
226
+ hopsworks_heatmaps = fetch_heatmaps_from_hopsworks()
227
+ if key in hopsworks_heatmaps:
228
+ geojson = hopsworks_heatmaps[key]
229
+ n_features = len(geojson.get("features", []))
230
+ print(f"Found heatmap in Hopsworks for {key}: {n_features} features")
231
+ return geojson
232
+
233
+ # Fall back to local JSON file
234
+ precomputed = load_precomputed_contours()
235
+ if key in precomputed:
236
+ geojson = precomputed[key]
237
+ n_features = len(geojson.get("features", []))
238
+ print(f"Found heatmap in local file for {key}: {n_features} features")
239
+ return geojson
240
+
241
+ # Last resort: test contours
242
+ print(f"No heatmap for {key}, using test contours")
243
+ return get_test_contour_geojson()
244
 
245
 
246
  def create_map(selected_lat=None, selected_lon=None, show_heatmap=False,
247
+ contour_geojson=None):
248
+ """Create a Folium map with optional marker and contour overlay."""
249
  center_lat = selected_lat if selected_lat else DEFAULT_LAT
250
  center_lon = selected_lon if selected_lon else DEFAULT_LON
251
 
 
255
  tiles="CartoDB positron"
256
  )
257
 
258
+ # Add contour overlay if enabled
259
+ if show_heatmap and contour_geojson and contour_geojson.get("features"):
260
+ # Add each contour level as a separate GeoJSON layer
261
+ folium.GeoJson(
262
+ contour_geojson,
263
+ style_function=lambda feature: {
264
+ 'fillColor': feature['properties']['color'],
265
+ 'fillOpacity': feature['properties'].get('fillOpacity', 0.35),
266
+ 'color': 'none', # No border
267
+ 'weight': 0
268
+ },
269
+ name="Crowding Forecast"
270
+ ).add_to(m)
271
+
272
+ # Add coverage area rectangle (subtle border)
273
  folium.Rectangle(
274
  bounds=[[BOUNDS["min_lat"], BOUNDS["min_lon"]],
275
  [BOUNDS["max_lat"], BOUNDS["max_lon"]]],
276
+ color="#6b7280",
277
  fill=False,
278
+ weight=1,
279
+ opacity=0.3,
280
+ dash_array="5, 5",
281
  ).add_to(m)
282
 
 
 
 
 
 
 
 
 
 
283
  # Add marker if location selected
284
  if selected_lat and selected_lon:
285
  folium.Marker(
286
  [selected_lat, selected_lon],
287
  tooltip=f"Selected: {selected_lat:.4f}, {selected_lon:.4f}",
288
+ icon=folium.Icon(color="blue", icon="info-sign")
289
  ).add_to(m)
290
 
291
  return m
 
305
  weather = get_weather_for_prediction(lat, lon, selected_datetime)
306
  holidays = get_holiday_features(selected_datetime)
307
 
308
+ pred_class, confidence, probs = cached_predict_occupancy(
309
  lat=lat, lon=lon,
310
  hour=selected_datetime.hour,
311
  day_of_week=selected_datetime.weekday(),
 
313
  holidays=holidays
314
  )
315
 
316
+ trip_info = None
317
+ static_trip_df = get_static_trip_df()
318
+ if static_trip_df is not None:
319
+ trip_info = find_nearest_trip(lat, lon, selected_datetime, static_trip_df)
320
+
321
  return pred_class, confidence, {
322
  "weather": weather,
323
  "holidays": holidays,
324
+ "datetime": selected_datetime,
325
+ "trip_info": trip_info
326
  }
327
  except Exception as e:
328
  return None, None, str(e)
 
335
  st.session_state.selected_lon = DEFAULT_LON
336
 
337
  # Header
338
+ st.title("HappySardines")
339
+ st.markdown("*Predicted bus crowding in Östergötland*")
340
 
341
  # Check if model is available
342
  model = get_model()
343
  if model is None:
344
+ st.error("Could not load prediction model. Please check the configuration.")
345
  st.stop()
346
 
347
  # Sidebar controls
 
363
 
364
  # View mode
365
  st.subheader("View Mode")
366
+ show_heatmap = st.toggle("Show Crowding Forecast", value=True,
367
  help="Display predicted crowding across the region")
368
 
369
  if show_heatmap:
370
+ st.markdown("""
371
+ **Legend:**
372
+ <div style="display: flex; flex-direction: column; gap: 4px; font-size: 14px;">
373
+ <div><span style="display: inline-block; width: 16px; height: 16px; background: #22c55e; border-radius: 3px; vertical-align: middle;"></span> Empty</div>
374
+ <div><span style="display: inline-block; width: 16px; height: 16px; background: #84cc16; border-radius: 3px; vertical-align: middle;"></span> Many seats</div>
375
+ <div><span style="display: inline-block; width: 16px; height: 16px; background: #eab308; border-radius: 3px; vertical-align: middle;"></span> Few seats</div>
376
+ <div><span style="display: inline-block; width: 16px; height: 16px; background: #f97316; border-radius: 3px; vertical-align: middle;"></span> Standing room</div>
377
+ <div><span style="display: inline-block; width: 16px; height: 16px; background: #ef4444; border-radius: 3px; vertical-align: middle;"></span> Crowded</div>
378
+ </div>
379
+ """, unsafe_allow_html=True)
380
 
381
  st.divider()
382
 
 
390
  - 🕐 Time of day
391
  - 📅 Day of week
392
  - 🌡️ Weather conditions
393
+ - 🇸🇪 Holidays
394
 
395
  **Data sources:**
396
+ - Bus occupancy data from Östgötatrafiken (KODA API)
397
  - Weather from Open-Meteo
398
  - Holidays from Svenska Dagar API
399
 
 
404
  col1, col2 = st.columns([2, 1])
405
 
406
  with col1:
407
+ st.subheader("Click on the map to select a location")
408
 
409
+ # Get weather/holidays for on-demand generation fallback
410
+ weather = get_weather_for_prediction(DEFAULT_LAT, DEFAULT_LON, selected_datetime)
411
+ holidays = get_holiday_features(selected_datetime)
412
+
413
+ # Get contour GeoJSON for current hour/day
414
+ contour_geojson = None
415
+ if show_heatmap:
416
+ contour_geojson = get_contour_geojson(
417
+ hour, selected_date.weekday(),
418
+ weather=weather, holidays=holidays
419
+ )
420
 
421
  # Create and display map
422
  m = create_map(
423
  selected_lat=st.session_state.selected_lat,
424
  selected_lon=st.session_state.selected_lon,
425
  show_heatmap=show_heatmap,
426
+ contour_geojson=contour_geojson
427
  )
428
 
429
+ # Render the map - key includes hour and day to force re-render on time change
430
  map_data = st_folium(
431
  m,
432
  height=500,
433
  use_container_width=True,
434
+ key=f"map_{hour}_{selected_date.weekday()}"
435
  )
436
 
437
  # Handle map clicks
 
442
  st.rerun()
443
 
444
  with col2:
445
+ st.subheader("Prediction")
446
 
447
  # Show selected coordinates
448
  st.markdown(f"**Location:** {st.session_state.selected_lat:.4f}, {st.session_state.selected_lon:.4f}")
449
 
450
  # Make prediction
451
+ with st.spinner("Fetching prediction..."):
452
+ pred_class, confidence, result = make_prediction(
453
+ st.session_state.selected_lat,
454
+ st.session_state.selected_lon,
455
+ selected_datetime
456
+ )
457
 
458
  if pred_class is not None:
459
  label_info = OCCUPANCY_LABELS[pred_class]
 
484
  if isinstance(result, dict):
485
  weather = result["weather"]
486
  holidays = result["holidays"]
487
+ trip_info = result.get("trip_info")
488
 
489
+ day_type = "Holiday" if holidays.get("is_red_day") else (
490
+ "Work-free day" if holidays.get("is_work_free") else "Regular day"
491
  )
492
 
493
+ if trip_info:
494
+ info_lines = []
495
+
496
+ # Route number or name
497
+ route_number = trip_info.get("route_short_name")
498
+ route_long_name = trip_info.get("route_long_name")
499
+ if route_number and route_long_name:
500
+ info_lines.append(f"{route_number} - {route_long_name}")
501
+ elif route_number:
502
+ info_lines.append(f"Route: {route_number}")
503
+ elif route_long_name:
504
+ info_lines.append(f"Route: {route_long_name}")
505
+
506
+ # Bus type / description
507
+ route_desc = trip_info.get("route_desc")
508
+ if route_desc:
509
+ info_lines.append(f"Type: {route_desc}")
510
+
511
+ # Trip ID
512
+ trip_id = trip_info.get("trip_id")
513
+ if trip_id is not None:
514
+ # Closest stop
515
+ static_stops_df = get_static_stops_df()
516
+ closest_stop = find_closest_stop(
517
+ st.session_state.selected_lat,
518
+ st.session_state.selected_lon,
519
+ trip_id,
520
+ static_stops_df
521
+ )
522
+ if closest_stop:
523
+ info_lines.append(f"Nearest stop: {closest_stop}")
524
+
525
+ if info_lines:
526
+ st.markdown("**Bus Info:**\n- " + "\n- ".join(info_lines))
527
+
528
+ # Weather conditions
529
+ conditions = []
530
+ temp = weather.get('temperature_2m')
531
+ if temp is not None:
532
+ conditions.append(f"{temp:.0f}°C")
533
+
534
+ if weather.get('snowfall', 0) > 0:
535
+ conditions.append("Snow")
536
+ if weather.get('rain', 0) > 0:
537
+ conditions.append("Rain")
538
+
539
+ conditions.append(day_type)
540
+ conditions.append(selected_datetime.strftime('%A'))
541
+
542
+ st.markdown("**Conditions:** " + " | ".join(conditions))
543
 
544
  elif isinstance(result, str):
545
  st.error(result)
contours.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Contour generation module for heatmap visualization.
3
+
4
+ Converts prediction grid data into GeoJSON polygons that can be rendered
5
+ as vector overlays on Folium maps. This provides zoom-independent visualization
6
+ similar to weather radar overlays.
7
+ """
8
+
9
+ import numpy as np
10
+ from scipy.interpolate import griddata
11
+ from scipy.ndimage import binary_dilation
12
+ import matplotlib
13
+ matplotlib.use('Agg') # Non-interactive backend for server use
14
+ import matplotlib.pyplot as plt
15
+ from matplotlib.path import Path
16
+ from shapely.geometry import Polygon, MultiPolygon, mapping
17
+ from shapely.ops import unary_union
18
+ from shapely.validation import make_valid
19
+ import json
20
+
21
+
22
+ # Color scheme: green -> lime -> yellow -> orange -> red
23
+ # Each class gets a distinct color for clear differentiation
24
+ CLASS_COLORS = {
25
+ 0: "#22c55e", # Empty - green
26
+ 1: "#84cc16", # Many seats - lime (green-yellow mix)
27
+ 2: "#eab308", # Few seats - yellow
28
+ 3: "#f97316", # Standing room - orange
29
+ 4: "#ef4444", # Crushed standing - red
30
+ 5: "#ef4444", # Full - red
31
+ 6: "#6b7280", # Not accepting - gray
32
+ }
33
+
34
+ # Legacy contour colors (for backwards compatibility)
35
+ CONTOUR_COLORS = [
36
+ "#22c55e", # 0.0-0.2: Green (class 0 - empty)
37
+ "#eab308", # 0.2-0.4: Yellow (class 1 - many seats)
38
+ "#f97316", # 0.4-0.6: Orange (class 2 - few seats)
39
+ "#ef4444", # 0.6-0.8: Red (class 3 - standing)
40
+ "#7f1d1d", # 0.8-1.0: Dark red (class 4+ - crowded)
41
+ ]
42
+
43
+ # Contour levels (intensity thresholds)
44
+ CONTOUR_LEVELS = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
45
+
46
+
47
+ def _extract_polygons_from_contour(contour_set, level_idx):
48
+ """
49
+ Extract polygons from a contourf result for a specific level.
50
+ Compatible with matplotlib 3.8+ which removed .collections attribute.
51
+ """
52
+ polygons = []
53
+
54
+ # Try new API first (matplotlib 3.8+)
55
+ if hasattr(contour_set, 'get_paths'):
56
+ # New API: iterate through all paths
57
+ all_paths = contour_set.get_paths()
58
+ # In new API, paths are organized differently
59
+ # We need to use allsegs instead
60
+ pass
61
+
62
+ # Use allsegs which works in both old and new matplotlib
63
+ if hasattr(contour_set, 'allsegs'):
64
+ if level_idx < len(contour_set.allsegs):
65
+ segments = contour_set.allsegs[level_idx]
66
+ for seg in segments:
67
+ if len(seg) >= 4:
68
+ try:
69
+ poly = Polygon(seg)
70
+ if not poly.is_valid:
71
+ poly = make_valid(poly)
72
+ if poly.is_valid and not poly.is_empty and poly.area > 0:
73
+ if isinstance(poly, MultiPolygon):
74
+ polygons.extend(poly.geoms)
75
+ else:
76
+ polygons.append(poly)
77
+ except Exception:
78
+ continue
79
+ return polygons
80
+
81
+
82
+ def grid_to_contour_geojson(
83
+ prediction_data: list,
84
+ bounds: dict,
85
+ interpolation_resolution: int = 100,
86
+ fill_opacity: float = 0.35,
87
+ ) -> dict:
88
+ """
89
+ Convert prediction grid data to GeoJSON FeatureCollection with filled contour polygons.
90
+
91
+ Args:
92
+ prediction_data: List of [lat, lon, intensity] where intensity is 0-1
93
+ bounds: Dict with min_lat, max_lat, min_lon, max_lon
94
+ interpolation_resolution: Number of points per axis for interpolation grid
95
+ fill_opacity: Opacity for the fill color (0-1)
96
+
97
+ Returns:
98
+ GeoJSON FeatureCollection with colored polygon features
99
+ """
100
+ if not prediction_data or len(prediction_data) < 4:
101
+ return _empty_feature_collection()
102
+
103
+ # Extract coordinates and values
104
+ points = np.array([[p[0], p[1]] for p in prediction_data]) # lat, lon
105
+ values = np.array([p[2] for p in prediction_data]) # intensity
106
+
107
+ # Create fine interpolation grid
108
+ lat_fine = np.linspace(bounds["min_lat"], bounds["max_lat"], interpolation_resolution)
109
+ lon_fine = np.linspace(bounds["min_lon"], bounds["max_lon"], interpolation_resolution)
110
+ lon_grid, lat_grid = np.meshgrid(lon_fine, lat_fine)
111
+
112
+ # Interpolate to fine grid (using lat,lon order for points)
113
+ try:
114
+ values_fine = griddata(
115
+ points, # (lat, lon) pairs
116
+ values,
117
+ (lat_grid, lon_grid), # grid in (lat, lon) format
118
+ method='cubic',
119
+ fill_value=0.0
120
+ )
121
+ except Exception:
122
+ # Fall back to linear if cubic fails
123
+ values_fine = griddata(
124
+ points,
125
+ values,
126
+ (lat_grid, lon_grid),
127
+ method='linear',
128
+ fill_value=0.0
129
+ )
130
+
131
+ # Clip values to valid range
132
+ values_fine = np.clip(values_fine, 0.0, 1.0)
133
+
134
+ # Generate contours using matplotlib (but don't display)
135
+ fig, ax = plt.subplots(figsize=(10, 10))
136
+
137
+ # contourf returns a QuadContourSet
138
+ contour_set = ax.contourf(
139
+ lon_grid, lat_grid, values_fine,
140
+ levels=CONTOUR_LEVELS,
141
+ extend='neither'
142
+ )
143
+
144
+ plt.close(fig) # Don't display, just extract the polygons
145
+
146
+ # Convert matplotlib contours to GeoJSON features
147
+ features = []
148
+
149
+ for level_idx in range(len(CONTOUR_LEVELS) - 1):
150
+ if level_idx >= len(CONTOUR_COLORS):
151
+ break
152
+
153
+ color = CONTOUR_COLORS[level_idx]
154
+ level_min = CONTOUR_LEVELS[level_idx]
155
+ level_max = CONTOUR_LEVELS[level_idx + 1]
156
+
157
+ # Extract polygons for this level
158
+ polygons = _extract_polygons_from_contour(contour_set, level_idx)
159
+
160
+ if not polygons:
161
+ continue
162
+
163
+ # Merge overlapping polygons at this level
164
+ try:
165
+ merged = unary_union(polygons)
166
+ if merged.is_empty:
167
+ continue
168
+ except Exception:
169
+ continue
170
+
171
+ # Create GeoJSON feature
172
+ feature = {
173
+ "type": "Feature",
174
+ "properties": {
175
+ "color": color,
176
+ "fillOpacity": fill_opacity,
177
+ "level_min": level_min,
178
+ "level_max": level_max,
179
+ "level_idx": level_idx,
180
+ },
181
+ "geometry": mapping(merged)
182
+ }
183
+ features.append(feature)
184
+
185
+ return {
186
+ "type": "FeatureCollection",
187
+ "features": features
188
+ }
189
+
190
+
191
+ def _empty_feature_collection() -> dict:
192
+ """Return an empty GeoJSON FeatureCollection."""
193
+ return {
194
+ "type": "FeatureCollection",
195
+ "features": []
196
+ }
197
+
198
+
199
+ def grid_to_cells_geojson(
200
+ prediction_data: list,
201
+ lat_step: float,
202
+ lon_step: float,
203
+ fill_opacity: float = 0.35,
204
+ ) -> dict:
205
+ """
206
+ Convert prediction grid to GeoJSON rectangles - one cell per prediction point.
207
+
208
+ This is simpler and more accurate than contours:
209
+ - No interpolation artifacts
210
+ - Each cell shows the exact prediction for that area
211
+ - No fake background fill
212
+
213
+ Args:
214
+ prediction_data: List of [lat, lon, pred_class] where pred_class is 0-6
215
+ lat_step: Height of each cell in degrees
216
+ lon_step: Width of each cell in degrees
217
+ fill_opacity: Opacity for the fill color (0-1)
218
+
219
+ Returns:
220
+ GeoJSON FeatureCollection with colored rectangle features
221
+ """
222
+ if not prediction_data:
223
+ return _empty_feature_collection()
224
+
225
+ features = []
226
+ half_lat = lat_step / 2
227
+ half_lon = lon_step / 2
228
+
229
+ for lat, lon, pred_class in prediction_data:
230
+ pred_class = int(pred_class)
231
+ color = CLASS_COLORS.get(pred_class, CLASS_COLORS[0])
232
+
233
+ # Create rectangle centered on the prediction point
234
+ coords = [[
235
+ [lon - half_lon, lat - half_lat], # SW
236
+ [lon + half_lon, lat - half_lat], # SE
237
+ [lon + half_lon, lat + half_lat], # NE
238
+ [lon - half_lon, lat + half_lat], # NW
239
+ [lon - half_lon, lat - half_lat], # SW (close polygon)
240
+ ]]
241
+
242
+ feature = {
243
+ "type": "Feature",
244
+ "properties": {
245
+ "color": color,
246
+ "fillOpacity": fill_opacity,
247
+ "pred_class": pred_class,
248
+ },
249
+ "geometry": {
250
+ "type": "Polygon",
251
+ "coordinates": coords
252
+ }
253
+ }
254
+ features.append(feature)
255
+
256
+ return {
257
+ "type": "FeatureCollection",
258
+ "features": features
259
+ }
260
+
261
+
262
+ def precompute_contours_for_all_times(
263
+ prediction_func,
264
+ bounds: dict,
265
+ hours: list = None,
266
+ weekdays: list = None,
267
+ lat_steps: int = 20,
268
+ lon_steps: int = 25,
269
+ ) -> dict:
270
+ """
271
+ Precompute contour GeoJSON for all hour/weekday combinations.
272
+
273
+ Args:
274
+ prediction_func: Function(lat, lon, hour, weekday) -> intensity (0-1)
275
+ bounds: Geographic bounds dict
276
+ hours: List of hours to compute (default: 5-23)
277
+ weekdays: List of weekdays to compute (default: 0-6)
278
+ lat_steps: Number of latitude grid points
279
+ lon_steps: Number of longitude grid points
280
+
281
+ Returns:
282
+ Dict mapping (hour, weekday) -> GeoJSON FeatureCollection
283
+ """
284
+ if hours is None:
285
+ hours = list(range(5, 24)) # 5:00 to 23:00
286
+ if weekdays is None:
287
+ weekdays = list(range(7)) # Monday to Sunday
288
+
289
+ # Generate grid points
290
+ lats = np.linspace(bounds["min_lat"], bounds["max_lat"], lat_steps)
291
+ lons = np.linspace(bounds["min_lon"], bounds["max_lon"], lon_steps)
292
+
293
+ results = {}
294
+ total = len(hours) * len(weekdays)
295
+ count = 0
296
+
297
+ for hour in hours:
298
+ for weekday in weekdays:
299
+ count += 1
300
+ print(f"Generating contours {count}/{total}: hour={hour}, weekday={weekday}")
301
+
302
+ # Generate predictions for this time slot
303
+ prediction_data = []
304
+ for lat in lats:
305
+ for lon in lons:
306
+ try:
307
+ intensity = prediction_func(lat, lon, hour, weekday)
308
+ prediction_data.append([lat, lon, intensity])
309
+ except Exception:
310
+ prediction_data.append([lat, lon, 0.0])
311
+
312
+ # Convert to contour GeoJSON
313
+ geojson = grid_to_contour_geojson(prediction_data, bounds)
314
+ results[(hour, weekday)] = geojson
315
+
316
+ return results
317
+
318
+
319
+ def save_contours_to_file(contours: dict, filepath: str):
320
+ """
321
+ Save precomputed contours to a JSON file.
322
+
323
+ The dict keys (hour, weekday) are converted to strings for JSON serialization.
324
+ """
325
+ # Convert tuple keys to string keys for JSON
326
+ json_compatible = {
327
+ f"{hour},{weekday}": geojson
328
+ for (hour, weekday), geojson in contours.items()
329
+ }
330
+
331
+ with open(filepath, 'w') as f:
332
+ json.dump(json_compatible, f)
333
+
334
+ print(f"Saved contours to {filepath}")
335
+
336
+
337
+ def load_contours_from_file(filepath: str) -> dict:
338
+ """
339
+ Load precomputed contours from a JSON file.
340
+
341
+ Returns dict mapping (hour, weekday) tuple -> GeoJSON FeatureCollection.
342
+ """
343
+ with open(filepath, 'r') as f:
344
+ data = json.load(f)
345
+
346
+ # Convert string keys back to tuple keys
347
+ return {
348
+ tuple(map(int, key.split(','))): geojson
349
+ for key, geojson in data.items()
350
+ }
predictor.py CHANGED
@@ -7,6 +7,9 @@ Loads the XGBoost model from Hopsworks Model Registry and makes predictions.
7
  import os
8
  import numpy as np
9
  import pandas as pd
 
 
 
10
 
11
  # Global model cache
12
  _model = None
@@ -58,21 +61,29 @@ OCCUPANCY_LABELS = {
58
  }
59
  }
60
 
61
- # Feature order expected by the model
62
- # Must match training pipeline exactly
63
  FEATURE_ORDER = [
64
- "avg_speed",
 
65
  "max_speed",
66
- "speed_std",
67
  "n_positions",
 
 
68
  "lat_mean",
 
 
69
  "lon_mean",
 
 
70
  "hour",
71
  "day_of_week",
72
  "temperature_2m",
73
  "precipitation",
74
  "cloud_cover",
75
  "wind_speed_10m",
 
 
76
  "is_work_free",
77
  "is_red_day",
78
  "is_day_before_holiday",
@@ -81,10 +92,10 @@ FEATURE_ORDER = [
81
  # Default values for vehicle features (we don't have real-time vehicle data)
82
  # These are approximate averages from the training data
83
  DEFAULT_VEHICLE_FEATURES = {
84
- "avg_speed": 20.0, # typical urban bus speed (km/h)
85
  "max_speed": 45.0, # typical max speed
86
- "speed_std": 12.0, # typical speed variation
87
  "n_positions": 30, # typical GPS points per trip window
 
 
88
  }
89
 
90
 
@@ -101,6 +112,7 @@ def load_model():
101
 
102
  # Check for API key before attempting connection
103
  api_key = os.environ.get("HOPSWORKS_API_KEY")
 
104
  if not api_key:
105
  raise ValueError("HOPSWORKS_API_KEY environment variable not set. Please add it in Space settings.")
106
 
@@ -109,11 +121,12 @@ def load_model():
109
  from xgboost import XGBClassifier
110
 
111
  print("Connecting to Hopsworks...")
112
- project = hopsworks.login(api_key_value=api_key)
113
  mr = project.get_model_registry()
114
 
115
  print("Fetching model from registry...")
116
- model_entry = mr.get_model("occupancy_xgboost_model", version=None) # Latest version
 
117
 
118
  print(f"Downloading model version {model_entry.version}...")
119
  model_dir = model_entry.download()
@@ -153,15 +166,23 @@ def predict_occupancy(lat, lon, hour, day_of_week, weather, holidays):
153
  # Assemble feature vector
154
  features = {
155
  # Vehicle features - use defaults
156
- "avg_speed": DEFAULT_VEHICLE_FEATURES["avg_speed"],
 
157
  "max_speed": DEFAULT_VEHICLE_FEATURES["max_speed"],
158
- "speed_std": DEFAULT_VEHICLE_FEATURES["speed_std"],
159
  "n_positions": DEFAULT_VEHICLE_FEATURES["n_positions"],
160
 
161
- # Location
 
 
162
  "lat_mean": lat,
 
 
163
  "lon_mean": lon,
164
 
 
 
 
 
165
  # Time
166
  "hour": hour,
167
  "day_of_week": day_of_week,
@@ -171,6 +192,8 @@ def predict_occupancy(lat, lon, hour, day_of_week, weather, holidays):
171
  "precipitation": weather.get("precipitation", 0.0),
172
  "cloud_cover": weather.get("cloud_cover", 50.0),
173
  "wind_speed_10m": weather.get("wind_speed_10m", 5.0),
 
 
174
 
175
  # Holidays (convert bool to int)
176
  "is_work_free": int(holidays.get("is_work_free", False)),
@@ -189,3 +212,65 @@ def predict_occupancy(lat, lon, hour, day_of_week, weather, holidays):
189
  confidence = float(probabilities[predicted_class])
190
 
191
  return predicted_class, confidence, probabilities.tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import os
8
  import numpy as np
9
  import pandas as pd
10
+ from dotenv import load_dotenv
11
+
12
+ load_dotenv()
13
 
14
  # Global model cache
15
  _model = None
 
61
  }
62
  }
63
 
64
+ # Feature order expected by the model (occupancy_xgboost_model_new v4)
65
+ # Must match training pipeline exactly - includes lat/lon bounds and bearing
66
  FEATURE_ORDER = [
67
+ "trip_id",
68
+ "vehicle_id",
69
  "max_speed",
 
70
  "n_positions",
71
+ "lat_min",
72
+ "lat_max",
73
  "lat_mean",
74
+ "lon_min",
75
+ "lon_max",
76
  "lon_mean",
77
+ "bearing_min",
78
+ "bearing_max",
79
  "hour",
80
  "day_of_week",
81
  "temperature_2m",
82
  "precipitation",
83
  "cloud_cover",
84
  "wind_speed_10m",
85
+ "rain",
86
+ "snowfall",
87
  "is_work_free",
88
  "is_red_day",
89
  "is_day_before_holiday",
 
92
  # Default values for vehicle features (we don't have real-time vehicle data)
93
  # These are approximate averages from the training data
94
  DEFAULT_VEHICLE_FEATURES = {
 
95
  "max_speed": 45.0, # typical max speed
 
96
  "n_positions": 30, # typical GPS points per trip window
97
+ "bearing_min": 0.0, # neutral bearing
98
+ "bearing_max": 360.0, # full range (stationary/unknown direction)
99
  }
100
 
101
 
 
112
 
113
  # Check for API key before attempting connection
114
  api_key = os.environ.get("HOPSWORKS_API_KEY")
115
+ project = os.environ.get("HOPSWORKS_PROJECT")
116
  if not api_key:
117
  raise ValueError("HOPSWORKS_API_KEY environment variable not set. Please add it in Space settings.")
118
 
 
121
  from xgboost import XGBClassifier
122
 
123
  print("Connecting to Hopsworks...")
124
+ project = hopsworks.login(project=project, api_key_value=api_key)
125
  mr = project.get_model_registry()
126
 
127
  print("Fetching model from registry...")
128
+ # Get version 4 explicitly (the model trained with 23 features)
129
+ model_entry = mr.get_model("occupancy_xgboost_model_new", version=4)
130
 
131
  print(f"Downloading model version {model_entry.version}...")
132
  model_dir = model_entry.download()
 
166
  # Assemble feature vector
167
  features = {
168
  # Vehicle features - use defaults
169
+ "trip_id": 0, # placeholder
170
+ "vehicle_id": 0, # placeholder
171
  "max_speed": DEFAULT_VEHICLE_FEATURES["max_speed"],
 
172
  "n_positions": DEFAULT_VEHICLE_FEATURES["n_positions"],
173
 
174
+ # Location bounds (set equal to point for single-location prediction)
175
+ "lat_min": lat,
176
+ "lat_max": lat,
177
  "lat_mean": lat,
178
+ "lon_min": lon,
179
+ "lon_max": lon,
180
  "lon_mean": lon,
181
 
182
+ # Bearing (neutral values for point prediction)
183
+ "bearing_min": DEFAULT_VEHICLE_FEATURES["bearing_min"],
184
+ "bearing_max": DEFAULT_VEHICLE_FEATURES["bearing_max"],
185
+
186
  # Time
187
  "hour": hour,
188
  "day_of_week": day_of_week,
 
192
  "precipitation": weather.get("precipitation", 0.0),
193
  "cloud_cover": weather.get("cloud_cover", 50.0),
194
  "wind_speed_10m": weather.get("wind_speed_10m", 5.0),
195
+ "rain": weather.get("rain", 0.0),
196
+ "snowfall": weather.get("snowfall", 0.0),
197
 
198
  # Holidays (convert bool to int)
199
  "is_work_free": int(holidays.get("is_work_free", False)),
 
212
  confidence = float(probabilities[predicted_class])
213
 
214
  return predicted_class, confidence, probabilities.tolist()
215
+
216
+
217
+ def predict_occupancy_batch(locations, hour, day_of_week, weather, holidays):
218
+ """
219
+ Predict occupancy for multiple locations in a single batch.
220
+
221
+ Much faster than calling predict_occupancy() in a loop.
222
+
223
+ Args:
224
+ locations: List of (lat, lon) tuples
225
+ hour: Hour of day (0-23)
226
+ day_of_week: Day of week (0=Monday, 6=Sunday)
227
+ weather: Dict with temperature_2m, precipitation, cloud_cover, wind_speed_10m
228
+ holidays: Dict with is_work_free, is_red_day, is_day_before_holiday
229
+
230
+ Returns:
231
+ List of (predicted_class, confidence) tuples
232
+ """
233
+ model = load_model()
234
+
235
+ # Build all feature rows at once
236
+ rows = []
237
+ for lat, lon in locations:
238
+ rows.append({
239
+ "trip_id": 0,
240
+ "vehicle_id": 0,
241
+ "max_speed": DEFAULT_VEHICLE_FEATURES["max_speed"],
242
+ "n_positions": DEFAULT_VEHICLE_FEATURES["n_positions"],
243
+ "lat_min": lat,
244
+ "lat_max": lat,
245
+ "lat_mean": lat,
246
+ "lon_min": lon,
247
+ "lon_max": lon,
248
+ "lon_mean": lon,
249
+ "bearing_min": DEFAULT_VEHICLE_FEATURES["bearing_min"],
250
+ "bearing_max": DEFAULT_VEHICLE_FEATURES["bearing_max"],
251
+ "hour": hour,
252
+ "day_of_week": day_of_week,
253
+ "temperature_2m": weather.get("temperature_2m", 10.0),
254
+ "precipitation": weather.get("precipitation", 0.0),
255
+ "cloud_cover": weather.get("cloud_cover", 50.0),
256
+ "wind_speed_10m": weather.get("wind_speed_10m", 5.0),
257
+ "rain": weather.get("rain", 0.0),
258
+ "snowfall": weather.get("snowfall", 0.0),
259
+ "is_work_free": int(holidays.get("is_work_free", False)),
260
+ "is_red_day": int(holidays.get("is_red_day", False)),
261
+ "is_day_before_holiday": int(holidays.get("is_day_before_holiday", False)),
262
+ })
263
+
264
+ # Single DataFrame, single predict call
265
+ X = pd.DataFrame(rows)[FEATURE_ORDER]
266
+ probabilities = model.predict_proba(X)
267
+
268
+ # Extract results
269
+ results = []
270
+ for i, (lat, lon) in enumerate(locations):
271
+ probs = probabilities[i]
272
+ predicted_class = int(np.argmax(probs))
273
+ confidence = float(probs[predicted_class])
274
+ results.append((predicted_class, confidence))
275
+
276
+ return results
requirements.txt CHANGED
@@ -8,3 +8,6 @@ pandas
8
  numpy
9
  requests
10
  python-dotenv
 
 
 
 
8
  numpy
9
  requests
10
  python-dotenv
11
+ scipy>=1.10.0
12
+ shapely>=2.0.0
13
+ matplotlib>=3.7.0
trip_info.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hopsworks
2
+ import os
3
+ import numpy as np
4
+
5
+ def load_static_trip_info():
6
+ api_key = os.environ.get("HOPSWORKS_API_KEY")
7
+ project_name = os.environ.get("HOPSWORKS_PROJECT")
8
+ project = hopsworks.login(project=project_name, api_key_value=api_key)
9
+
10
+ fs = project.get_feature_store()
11
+ fg = fs.get_feature_group("static_trip_info_fg", version=1) # adjust version
12
+
13
+ df = fg.read()
14
+ return df
15
+
16
+ def load_static_stops_info():
17
+ api_key = os.environ.get("HOPSWORKS_API_KEY")
18
+ project_name = os.environ.get("HOPSWORKS_PROJECT")
19
+ project = hopsworks.login(project=project_name, api_key_value=api_key)
20
+
21
+ fs = project.get_feature_store()
22
+ fg = fs.get_feature_group("static_trip_and_stops_info_fg", version=1) # adjust version
23
+
24
+ df = fg.read()
25
+ return df
26
+
27
+ def find_nearest_trip(lat, lon, datetime_obj, static_trip_df):
28
+ """
29
+ Return the trip closest to the requested location/time.
30
+ Currently just filters static trips; could be enhanced with stops & routing.
31
+ """
32
+ # For static data, we can only match by service_id/date/time if available
33
+ # Here we just pick a random trip as placeholder
34
+ if static_trip_df is None or len(static_trip_df) == 0:
35
+ return None
36
+
37
+ trip = static_trip_df.sample(1).iloc[0] # pick 1 random trip for demo
38
+ return {
39
+ "trip_id": trip["trip_id"],
40
+ "route_short_name": trip["route_short_name"],
41
+ "route_long_name": trip["route_long_name"],
42
+ "trip_headsign": trip.get("trip_headsign", None)
43
+ }
44
+
45
+
46
+ def find_closest_stop(lat, lon, trip_id, stops_df):
47
+ """
48
+ Returns the closest stop to a given lat/lon for the specified trip_id.
49
+ """
50
+ if stops_df is None:
51
+ return None
52
+ # Filter stops for this trip
53
+ trip_stops = stops_df[stops_df["trip_id"] == trip_id]
54
+ if trip_stops.empty:
55
+ return None
56
+
57
+ # Compute distances
58
+ lat_array = trip_stops["stop_lat"].to_numpy()
59
+ lon_array = trip_stops["stop_lon"].to_numpy()
60
+
61
+ distances = np.sqrt((lat_array - lat)**2 + (lon_array - lon)**2)
62
+ idx_min = distances.argmin()
63
+ closest_stop = trip_stops.iloc[idx_min]
64
+
65
+ return closest_stop["stop_name"]
weather.py CHANGED
@@ -6,6 +6,7 @@ Uses Open-Meteo API to get weather forecasts.
6
 
7
  import requests
8
  from datetime import datetime
 
9
 
10
  # Open-Meteo API
11
  OPENMETEO_FORECAST_URL = "https://api.open-meteo.com/v1/forecast"
@@ -16,6 +17,8 @@ WEATHER_VARIABLES = [
16
  "precipitation",
17
  "cloud_cover",
18
  "wind_speed_10m",
 
 
19
  ]
20
 
21
 
@@ -53,10 +56,25 @@ def get_weather_for_prediction(lat: float, lon: float, target_datetime: datetime
53
  if days_ahead <= 0:
54
  params["past_days"] = 1
55
 
56
- response = requests.get(OPENMETEO_FORECAST_URL, params=params, timeout=30)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  if response.status_code != 200:
59
- print(f"Weather API error: {response.status_code}")
60
  return _default_weather()
61
 
62
  data = response.json()
@@ -89,6 +107,8 @@ def get_weather_for_prediction(lat: float, lon: float, target_datetime: datetime
89
  "precipitation": hourly.get("precipitation", [None])[idx] or 0.0,
90
  "cloud_cover": hourly.get("cloud_cover", [None])[idx] or 50.0,
91
  "wind_speed_10m": hourly.get("wind_speed_10m", [None])[idx] or 5.0,
 
 
92
  }
93
 
94
  except Exception as e:
@@ -103,4 +123,6 @@ def _default_weather() -> dict:
103
  "precipitation": 0.0,
104
  "cloud_cover": 50.0,
105
  "wind_speed_10m": 5.0,
 
 
106
  }
 
6
 
7
  import requests
8
  from datetime import datetime
9
+ import time
10
 
11
  # Open-Meteo API
12
  OPENMETEO_FORECAST_URL = "https://api.open-meteo.com/v1/forecast"
 
17
  "precipitation",
18
  "cloud_cover",
19
  "wind_speed_10m",
20
+ "rain",
21
+ "snowfall"
22
  ]
23
 
24
 
 
56
  if days_ahead <= 0:
57
  params["past_days"] = 1
58
 
59
+ # Retry logic for transient failures
60
+ max_retries = 3
61
+ for attempt in range(max_retries):
62
+ try:
63
+ response = requests.get(OPENMETEO_FORECAST_URL, params=params, timeout=30)
64
+ if response.status_code == 200:
65
+ break
66
+ if attempt < max_retries - 1:
67
+ time.sleep(2 ** attempt) # Exponential backoff: 1s, 2s, 4s
68
+ except requests.exceptions.Timeout:
69
+ if attempt < max_retries - 1:
70
+ time.sleep(2 ** attempt)
71
+ continue
72
+ return _default_weather()
73
+ else:
74
+ print(f"Weather API error after {max_retries} retries: {response.status_code}")
75
+ return _default_weather()
76
 
77
  if response.status_code != 200:
 
78
  return _default_weather()
79
 
80
  data = response.json()
 
107
  "precipitation": hourly.get("precipitation", [None])[idx] or 0.0,
108
  "cloud_cover": hourly.get("cloud_cover", [None])[idx] or 50.0,
109
  "wind_speed_10m": hourly.get("wind_speed_10m", [None])[idx] or 5.0,
110
+ "rain": hourly.get("rain", [None])[idx] or 0.0,
111
+ "snowfall": hourly.get("snowfall", [None])[idx] or 0.0,
112
  }
113
 
114
  except Exception as e:
 
123
  "precipitation": 0.0,
124
  "cloud_cover": 50.0,
125
  "wind_speed_10m": 5.0,
126
+ "rain": 0.0,
127
+ "snowfall": 0.0
128
  }