AxelHolst commited on
Commit
764eb6f
·
1 Parent(s): 37b8251

feat: add monitoring dashboard and improve prediction speed

Browse files

- Add Monitoring page (pages/Monitoring.py) showing model performance metrics
- Key metrics: accuracy, F1, precision, MAE
- Accuracy trend chart over time
- Per-class performance breakdown for all 7 occupancy classes
- Model version filter with v4 as default
- Alert banner when accuracy drops below threshold

- Improve real-time prediction speed in app.py
- Make trip info loading optional (was causing 2+ min delays)
- Add "Load nearby bus info" button for on-demand loading
- Trip forecast display when available from forecast_fg

- Update trip_info.py with haversine distance functions

Files changed (3) hide show
  1. app.py +122 -22
  2. pages/Monitoring.py +400 -0
  3. trip_info.py +133 -28
app.py CHANGED
@@ -27,7 +27,9 @@ import hopsworks
27
  from predictor import predict_occupancy, load_model, OCCUPANCY_LABELS
28
  from weather import get_weather_for_prediction
29
  from holidays import get_holiday_features
30
- from trip_info import load_static_trip_info, find_nearest_trip, load_static_stops_info, find_closest_stop
 
 
31
  from contours import load_contours_from_file, grid_to_cells_geojson
32
 
33
  # Constants
@@ -75,6 +77,12 @@ def get_static_stops_df():
75
  return None
76
 
77
 
 
 
 
 
 
 
78
  @st.cache_resource
79
  def get_model():
80
  """Load model once and cache it."""
@@ -90,6 +98,33 @@ def cached_predict_occupancy(lat, lon, hour, day_of_week, weather, holidays):
90
  return predict_occupancy(lat, lon, hour, day_of_week, weather, holidays)
91
 
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  @st.cache_resource
94
  def fetch_heatmaps_from_hopsworks():
95
  """
@@ -292,8 +327,12 @@ def create_map(selected_lat=None, selected_lon=None, show_heatmap=False,
292
  return m
293
 
294
 
295
- def make_prediction(lat, lon, selected_datetime):
296
- """Make prediction and return formatted result."""
 
 
 
 
297
  if lat is None or lon is None:
298
  return None, None, None
299
 
@@ -314,16 +353,44 @@ def make_prediction(lat, lon, selected_datetime):
314
  holidays=holidays
315
  )
316
 
 
317
  trip_info = None
318
- static_trip_df = get_static_trip_df()
319
- if static_trip_df is not None:
320
- trip_info = find_nearest_trip(lat, lon, selected_datetime, static_trip_df)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
  return pred_class, confidence, {
323
  "weather": weather,
324
  "holidays": holidays,
325
  "datetime": selected_datetime,
326
- "trip_info": trip_info
 
327
  }
328
  except Exception as e:
329
  return None, None, str(e)
@@ -446,12 +513,16 @@ with col2:
446
  # Show selected coordinates
447
  st.markdown(f"**Location:** {st.session_state.selected_lat:.4f}, {st.session_state.selected_lon:.4f}")
448
 
449
- # Make prediction
 
 
 
450
  with st.spinner("Fetching prediction..."):
451
  pred_class, confidence, result = make_prediction(
452
  st.session_state.selected_lat,
453
  st.session_state.selected_lon,
454
- selected_datetime
 
455
  )
456
 
457
  if pred_class is not None:
@@ -507,22 +578,51 @@ with col2:
507
  if route_desc:
508
  info_lines.append(f"Type: {route_desc}")
509
 
510
- # Trip ID
511
- trip_id = trip_info.get("trip_id")
512
- if trip_id is not None:
513
- # Closest stop
514
- static_stops_df = get_static_stops_df()
515
- closest_stop = find_closest_stop(
516
- st.session_state.selected_lat,
517
- st.session_state.selected_lon,
518
- trip_id,
519
- static_stops_df
520
- )
521
- if closest_stop:
522
- info_lines.append(f"Nearest stop: {closest_stop}")
523
 
524
  if info_lines:
525
  st.markdown("**Bus Info:**\n- " + "\n- ".join(info_lines))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
 
527
  # Weather conditions
528
  conditions = []
 
27
  from predictor import predict_occupancy, load_model, OCCUPANCY_LABELS
28
  from weather import get_weather_for_prediction
29
  from holidays import get_holiday_features
30
+ from trip_info import (
31
+ load_static_trip_info, find_nearest_trip, load_static_stops_info
32
+ )
33
  from contours import load_contours_from_file, grid_to_cells_geojson
34
 
35
  # Constants
 
77
  return None
78
 
79
 
80
+ def is_stops_data_cached():
81
+ """Check if stops data is already in cache without triggering load."""
82
+ # Check if the cache has been populated by looking at session state
83
+ return "stops_data_loaded" in st.session_state and st.session_state.stops_data_loaded
84
+
85
+
86
  @st.cache_resource
87
  def get_model():
88
  """Load model once and cache it."""
 
98
  return predict_occupancy(lat, lon, hour, day_of_week, weather, holidays)
99
 
100
 
101
+ @st.cache_data(ttl=3600)
102
+ def fetch_trip_forecasts_from_hopsworks():
103
+ """
104
+ Fetch trip forecasts from Hopsworks forecast_fg.
105
+
106
+ Returns DataFrame with columns: trip_id, hour, weekday, predicted_occupancy, confidence
107
+ Returns None if forecast_fg doesn't exist or is empty.
108
+ """
109
+ try:
110
+ project = hopsworks.login()
111
+ fs = project.get_feature_store()
112
+ # Try v2 (new schema with hour/weekday), fall back to v1
113
+ for version in [2, 1]:
114
+ try:
115
+ forecast_fg = fs.get_feature_group("forecast_fg", version=version)
116
+ df = forecast_fg.read()
117
+ if df is not None and not df.empty:
118
+ print(f"Loaded {len(df)} trip forecasts from Hopsworks v{version}")
119
+ return df
120
+ except Exception:
121
+ continue
122
+ return None
123
+ except Exception as e:
124
+ print(f"Could not load trip forecasts: {e}")
125
+ return None
126
+
127
+
128
  @st.cache_resource
129
  def fetch_heatmaps_from_hopsworks():
130
  """
 
327
  return m
328
 
329
 
330
+ def make_prediction(lat, lon, selected_datetime, skip_trip_info=False):
331
+ """Make prediction and return formatted result.
332
+
333
+ Args:
334
+ skip_trip_info: If True, skip the slow trip info lookup
335
+ """
336
  if lat is None or lon is None:
337
  return None, None, None
338
 
 
353
  holidays=holidays
354
  )
355
 
356
+ # Find nearest trip from static data (only if not skipping)
357
  trip_info = None
358
+ trip_forecast = None
359
+
360
+ if not skip_trip_info:
361
+ static_stops_df = get_static_stops_df()
362
+ # Mark that we've loaded the data (for future quick checks)
363
+ st.session_state.stops_data_loaded = True
364
+
365
+ if static_stops_df is not None:
366
+ trip_info = find_nearest_trip(lat, lon, selected_datetime, static_stops_df)
367
+
368
+ # Try to get trip forecast if available
369
+ if trip_info and trip_info.get("trip_id"):
370
+ forecasts_df = fetch_trip_forecasts_from_hopsworks()
371
+ if forecasts_df is not None:
372
+ trip_id = trip_info["trip_id"]
373
+ hour = selected_datetime.hour
374
+ weekday = selected_datetime.weekday()
375
+ # Find matching forecast
376
+ match = forecasts_df[
377
+ (forecasts_df["trip_id"] == trip_id) &
378
+ (forecasts_df["hour"] == hour) &
379
+ (forecasts_df["weekday"] == weekday)
380
+ ]
381
+ if not match.empty:
382
+ row = match.iloc[0]
383
+ trip_forecast = {
384
+ "predicted_occupancy": int(row.get("predicted_occupancy", 0)),
385
+ "confidence": float(row.get("confidence", 0)),
386
+ }
387
 
388
  return pred_class, confidence, {
389
  "weather": weather,
390
  "holidays": holidays,
391
  "datetime": selected_datetime,
392
+ "trip_info": trip_info,
393
+ "trip_forecast": trip_forecast
394
  }
395
  except Exception as e:
396
  return None, None, str(e)
 
513
  # Show selected coordinates
514
  st.markdown(f"**Location:** {st.session_state.selected_lat:.4f}, {st.session_state.selected_lon:.4f}")
515
 
516
+ # Check if stops data is already cached (fast check)
517
+ stops_already_loaded = st.session_state.get("stops_data_loaded", False)
518
+
519
+ # Make prediction (skip trip info on first load to be fast)
520
  with st.spinner("Fetching prediction..."):
521
  pred_class, confidence, result = make_prediction(
522
  st.session_state.selected_lat,
523
  st.session_state.selected_lon,
524
+ selected_datetime,
525
+ skip_trip_info=not stops_already_loaded
526
  )
527
 
528
  if pred_class is not None:
 
578
  if route_desc:
579
  info_lines.append(f"Type: {route_desc}")
580
 
581
+ # Closest stop from trip info (already computed)
582
+ closest_stop = trip_info.get("closest_stop")
583
+ if closest_stop:
584
+ info_lines.append(f"Nearest stop: {closest_stop}")
585
+
586
+ # Distance to stop
587
+ distance = trip_info.get("distance_m")
588
+ if distance is not None:
589
+ info_lines.append(f"Distance: {distance}m")
 
 
 
 
590
 
591
  if info_lines:
592
  st.markdown("**Bus Info:**\n- " + "\n- ".join(info_lines))
593
+ elif not stops_already_loaded:
594
+ # Offer to load trip info (it's slow on first load)
595
+ if st.button("Load nearby bus info", help="First load takes ~1-2 minutes"):
596
+ with st.spinner("Loading trip data from Hopsworks (this may take a minute)..."):
597
+ # Trigger the load and rerun
598
+ get_static_stops_df()
599
+ st.session_state.stops_data_loaded = True
600
+ st.rerun()
601
+
602
+ # Show trip-specific forecast if available
603
+ trip_forecast = result.get("trip_forecast")
604
+ if trip_forecast:
605
+ forecast_class = trip_forecast["predicted_occupancy"]
606
+ forecast_conf = trip_forecast["confidence"]
607
+ forecast_label = OCCUPANCY_LABELS.get(forecast_class, OCCUPANCY_LABELS[0])
608
+ forecast_color = OCCUPANCY_COLORS.get(forecast_class, "#6b7280")
609
+
610
+ st.markdown(f"""
611
+ <div style="
612
+ background: {forecast_color}11;
613
+ border: 1px solid {forecast_color}44;
614
+ border-radius: 8px;
615
+ padding: 12px;
616
+ margin: 8px 0;
617
+ ">
618
+ <div style="font-size: 0.85em; color: #6b7280; margin-bottom: 4px;">
619
+ Trip-specific forecast:
620
+ </div>
621
+ <div style="font-weight: 600; color: {forecast_color};">
622
+ {forecast_label['icon']} {forecast_label['label']} ({forecast_conf:.0%})
623
+ </div>
624
+ </div>
625
+ """, unsafe_allow_html=True)
626
 
627
  # Weather conditions
628
  conditions = []
pages/Monitoring.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HappySardines - Model Monitoring Dashboard
3
+
4
+ Displays model performance metrics from hindcast analysis:
5
+ - Accuracy trends over time
6
+ - Actual vs predicted occupancy comparison
7
+ - Per-class performance breakdown
8
+ - Alerts for model drift
9
+ """
10
+
11
+ import streamlit as st
12
+ import pandas as pd
13
+ import numpy as np
14
+ import hopsworks
15
+ from datetime import datetime, timedelta
16
+
17
+ # Page config
18
+ st.set_page_config(
19
+ page_title="HappySardines - Monitoring",
20
+ page_icon="📊",
21
+ layout="wide"
22
+ )
23
+
24
+ # Constants
25
+ ALERT_THRESHOLD = 0.65 # Alert if accuracy drops below this
26
+ MODEL_NAME = "occupancy_xgboost_model_new"
27
+ CURRENT_MODEL_VERSION = 4 # Current production model version
28
+
29
+ # Occupancy class labels
30
+ OCCUPANCY_LABELS = {
31
+ 0: "Empty",
32
+ 1: "Many seats",
33
+ 2: "Few seats",
34
+ 3: "Standing",
35
+ 4: "Crowded",
36
+ 5: "Full",
37
+ 6: "Not accepting",
38
+ }
39
+
40
+ # Colors for occupancy levels
41
+ OCCUPANCY_COLORS = {
42
+ 0: "#22c55e", # Green
43
+ 1: "#84cc16", # Lime
44
+ 2: "#eab308", # Yellow
45
+ 3: "#f97316", # Orange
46
+ 4: "#ef4444", # Red
47
+ 5: "#ef4444", # Red
48
+ 6: "#6b7280", # Gray
49
+ }
50
+
51
+
52
+ @st.cache_data(ttl=3600)
53
+ def fetch_monitoring_data():
54
+ """
55
+ Fetch monitoring data from Hopsworks monitor_fg.
56
+
57
+ Returns DataFrame with columns:
58
+ - window_start, trip_id
59
+ - actual_occupancy_mode, predicted_occupancy_mode
60
+ - accuracy, precision, recall, f1_weighted, mae
61
+ - model_version, generated_at
62
+ """
63
+ try:
64
+ project = hopsworks.login()
65
+ fs = project.get_feature_store()
66
+
67
+ monitor_fg = fs.get_feature_group("monitor_fg", version=1)
68
+ df = monitor_fg.read()
69
+
70
+ if df is not None and not df.empty:
71
+ # Ensure datetime columns are properly typed
72
+ if "window_start" in df.columns:
73
+ df["window_start"] = pd.to_datetime(df["window_start"])
74
+ if "generated_at" in df.columns:
75
+ df["generated_at"] = pd.to_datetime(df["generated_at"])
76
+
77
+ print(f"Loaded {len(df)} monitoring records from Hopsworks")
78
+ return df
79
+
80
+ return pd.DataFrame()
81
+
82
+ except Exception as e:
83
+ print(f"Error loading monitoring data: {e}")
84
+ return pd.DataFrame()
85
+
86
+
87
+ def get_daily_metrics(df: pd.DataFrame) -> pd.DataFrame:
88
+ """Aggregate monitoring data by day."""
89
+ if df.empty:
90
+ return pd.DataFrame()
91
+
92
+ df = df.copy()
93
+ df["date"] = df["window_start"].dt.date
94
+
95
+ # Get first record per day (metrics are already daily aggregates)
96
+ daily = df.groupby("date").agg({
97
+ "accuracy": "first",
98
+ "precision": "first",
99
+ "recall": "first",
100
+ "f1_weighted": "first",
101
+ "mae": "first",
102
+ "model_version": "first",
103
+ }).reset_index()
104
+
105
+ daily["date"] = pd.to_datetime(daily["date"])
106
+ return daily.sort_values("date")
107
+
108
+
109
+ def get_hourly_comparison(df: pd.DataFrame) -> pd.DataFrame:
110
+ """Aggregate actual vs predicted by hour."""
111
+ if df.empty:
112
+ return pd.DataFrame()
113
+
114
+ df = df.copy()
115
+ df["hour"] = df["window_start"].dt.floor("H")
116
+
117
+ hourly = df.groupby("hour").agg({
118
+ "actual_occupancy_mode": "mean",
119
+ "predicted_occupancy_mode": "mean",
120
+ }).reset_index()
121
+
122
+ return hourly.sort_values("hour")
123
+
124
+
125
+ def get_per_class_metrics(df: pd.DataFrame) -> pd.DataFrame:
126
+ """Calculate per-class accuracy and counts for all 7 occupancy classes."""
127
+ if df.empty:
128
+ return pd.DataFrame()
129
+
130
+ results = []
131
+ # Always show all 7 classes (0-6), even if some have no data
132
+ for cls in range(7):
133
+ mask = df["actual_occupancy_mode"] == cls
134
+ subset = df[mask]
135
+
136
+ if len(subset) > 0:
137
+ correct = (subset["actual_occupancy_mode"] == subset["predicted_occupancy_mode"]).sum()
138
+ total = len(subset)
139
+ accuracy = correct / total
140
+ else:
141
+ correct = 0
142
+ total = 0
143
+ accuracy = None # No data for this class
144
+
145
+ results.append({
146
+ "class": cls,
147
+ "label": OCCUPANCY_LABELS.get(cls, f"Class {cls}"),
148
+ "count": total,
149
+ "correct": correct,
150
+ "accuracy": accuracy,
151
+ })
152
+
153
+ return pd.DataFrame(results)
154
+
155
+
156
+ def render_metric_card(label: str, value: float, format_str: str = "{:.1%}",
157
+ threshold_low: float = None, threshold_high: float = None):
158
+ """Render a metric with conditional coloring."""
159
+ formatted = format_str.format(value) if value is not None else "N/A"
160
+
161
+ # Determine color
162
+ if threshold_low is not None and value < threshold_low:
163
+ color = "#ef4444" # Red
164
+ elif threshold_high is not None and value >= threshold_high:
165
+ color = "#22c55e" # Green
166
+ else:
167
+ color = "#eab308" # Yellow
168
+
169
+ st.markdown(f"""
170
+ <div style="
171
+ background: {color}11;
172
+ border: 1px solid {color}44;
173
+ border-radius: 8px;
174
+ padding: 16px;
175
+ text-align: center;
176
+ ">
177
+ <div style="font-size: 0.9em; color: #6b7280; margin-bottom: 4px;">
178
+ {label}
179
+ </div>
180
+ <div style="font-size: 1.8em; font-weight: 600; color: {color};">
181
+ {formatted}
182
+ </div>
183
+ </div>
184
+ """, unsafe_allow_html=True)
185
+
186
+
187
+ # Main page content
188
+ st.title("📊 Model Monitoring")
189
+ st.markdown("Track model performance over time using hindcast analysis.")
190
+
191
+ # Load data
192
+ with st.spinner("Loading monitoring data..."):
193
+ monitor_df = fetch_monitoring_data()
194
+
195
+ if monitor_df.empty:
196
+ st.warning("""
197
+ **No monitoring data available yet.**
198
+
199
+ Monitoring data is generated daily by the inference pipeline, which compares
200
+ yesterday's predictions to actual observed occupancy.
201
+
202
+ The inference pipeline runs at 09:00 UTC. Check back after it has run at least once.
203
+ """)
204
+ st.stop()
205
+
206
+ # Model version filter
207
+ available_versions = sorted(monitor_df["model_version"].dropna().unique())
208
+ if len(available_versions) > 1:
209
+ st.sidebar.subheader("Filter")
210
+
211
+ # Default to current model version if available
212
+ default_idx = available_versions.index(CURRENT_MODEL_VERSION) if CURRENT_MODEL_VERSION in available_versions else len(available_versions) - 1
213
+
214
+ selected_version = st.sidebar.selectbox(
215
+ "Model Version",
216
+ options=available_versions,
217
+ index=default_idx,
218
+ format_func=lambda x: f"v{int(x)}" + (" (current)" if x == CURRENT_MODEL_VERSION else "")
219
+ )
220
+
221
+ # Filter data by selected version
222
+ monitor_df = monitor_df[monitor_df["model_version"] == selected_version]
223
+
224
+ if monitor_df.empty:
225
+ st.warning(f"No monitoring data available for model v{int(selected_version)}.")
226
+ st.stop()
227
+ else:
228
+ selected_version = available_versions[0] if available_versions else None
229
+
230
+ # Show warning if viewing old model data
231
+ if selected_version is not None and selected_version != CURRENT_MODEL_VERSION:
232
+ st.info(f"""
233
+ **Viewing historical data from model v{int(selected_version)}.**
234
+
235
+ The current production model is v{CURRENT_MODEL_VERSION}.
236
+ Data for v{CURRENT_MODEL_VERSION} will appear after the inference pipeline runs.
237
+ """)
238
+
239
+ # Calculate aggregates
240
+ daily_metrics = get_daily_metrics(monitor_df)
241
+ hourly_comparison = get_hourly_comparison(monitor_df)
242
+ per_class = get_per_class_metrics(monitor_df)
243
+
244
+ # Get latest metrics
245
+ latest = daily_metrics.iloc[-1] if not daily_metrics.empty else None
246
+
247
+ # Header with model info
248
+ if latest is not None:
249
+ col1, col2 = st.columns([3, 1])
250
+ with col1:
251
+ st.markdown(f"**Model Version:** v{int(latest['model_version'])}")
252
+ with col2:
253
+ last_date = latest["date"].strftime("%Y-%m-%d")
254
+ st.markdown(f"**Last Updated:** {last_date}")
255
+
256
+ # Alert banner
257
+ if latest is not None and latest["accuracy"] < ALERT_THRESHOLD:
258
+ st.error(f"""
259
+ ⚠️ **Model Performance Alert**
260
+
261
+ Accuracy ({latest['accuracy']:.1%}) is below the threshold ({ALERT_THRESHOLD:.0%}).
262
+ Consider investigating recent data quality or retraining the model.
263
+ """)
264
+
265
+ st.divider()
266
+
267
+ # Key metrics cards
268
+ st.subheader("Latest Performance")
269
+
270
+ if latest is not None:
271
+ col1, col2, col3, col4 = st.columns(4)
272
+
273
+ with col1:
274
+ render_metric_card(
275
+ "Accuracy",
276
+ latest["accuracy"],
277
+ threshold_low=0.60,
278
+ threshold_high=0.70
279
+ )
280
+
281
+ with col2:
282
+ render_metric_card(
283
+ "F1 Score (Weighted)",
284
+ latest["f1_weighted"],
285
+ threshold_low=0.55,
286
+ threshold_high=0.65
287
+ )
288
+
289
+ with col3:
290
+ render_metric_card(
291
+ "Precision",
292
+ latest["precision"],
293
+ threshold_low=0.55,
294
+ threshold_high=0.70
295
+ )
296
+
297
+ with col4:
298
+ render_metric_card(
299
+ "MAE",
300
+ latest["mae"],
301
+ format_str="{:.2f}",
302
+ threshold_low=0.3, # Lower is better for MAE
303
+ threshold_high=0.6
304
+ )
305
+
306
+ st.divider()
307
+
308
+ # Accuracy trend chart
309
+ st.subheader("Accuracy Over Time")
310
+
311
+ if not daily_metrics.empty and len(daily_metrics) > 1:
312
+ chart_data = daily_metrics[["date", "accuracy"]].set_index("date")
313
+ st.line_chart(chart_data, use_container_width=True)
314
+
315
+ # Show trend
316
+ if len(daily_metrics) >= 2:
317
+ recent = daily_metrics["accuracy"].iloc[-1]
318
+ previous = daily_metrics["accuracy"].iloc[-2]
319
+ delta = recent - previous
320
+ trend = "📈" if delta > 0 else "📉" if delta < 0 else "➡️"
321
+ st.caption(f"{trend} Change from previous: {delta:+.1%}")
322
+ else:
323
+ st.info("Need at least 2 days of data to show trend chart.")
324
+
325
+ st.divider()
326
+
327
+ # Actual vs Predicted comparison
328
+ st.subheader("Actual vs Predicted Occupancy")
329
+
330
+ if not hourly_comparison.empty:
331
+ # Rename columns for display
332
+ chart_df = hourly_comparison.rename(columns={
333
+ "actual_occupancy_mode": "Actual",
334
+ "predicted_occupancy_mode": "Predicted"
335
+ }).set_index("hour")
336
+
337
+ st.line_chart(chart_df, use_container_width=True)
338
+ st.caption("Hourly average occupancy levels (0=Empty, 3=Standing)")
339
+ else:
340
+ st.info("No hourly comparison data available.")
341
+
342
+ st.divider()
343
+
344
+ # Per-class performance
345
+ st.subheader("Per-Class Performance")
346
+
347
+ if not per_class.empty:
348
+ # Color-coded display for all 7 classes
349
+ for _, row in per_class.iterrows():
350
+ cls = int(row["class"])
351
+ label = row["label"]
352
+ accuracy = row["accuracy"]
353
+ count = int(row["count"])
354
+ color = OCCUPANCY_COLORS.get(cls, "#6b7280")
355
+
356
+ col1, col2, col3 = st.columns([2, 1, 3])
357
+
358
+ with col1:
359
+ st.markdown(f"**{cls}** - {label}")
360
+
361
+ with col2:
362
+ if count > 0:
363
+ st.markdown(f"{count:,} samples")
364
+ else:
365
+ st.markdown("No data", help="No samples of this class in the monitoring data")
366
+
367
+ with col3:
368
+ if accuracy is not None and count > 0:
369
+ # Progress bar with color
370
+ st.progress(accuracy, text=f"{accuracy:.1%}")
371
+ else:
372
+ st.markdown("—", help="No accuracy data for this class")
373
+
374
+ # Explanation
375
+ st.caption("""
376
+ Per-class accuracy (recall) shows how well the model predicts each occupancy level.
377
+ Classes 4-6 are rare in Swedish transit data. Lower accuracy for rare classes is expected.
378
+ """)
379
+ else:
380
+ st.info("No per-class metrics available.")
381
+
382
+ st.divider()
383
+
384
+ # Raw data expander
385
+ with st.expander("View Raw Monitoring Data"):
386
+ if not monitor_df.empty:
387
+ st.dataframe(
388
+ monitor_df.sort_values("window_start", ascending=False).head(100),
389
+ use_container_width=True
390
+ )
391
+ st.caption(f"Showing latest 100 of {len(monitor_df):,} total records")
392
+ else:
393
+ st.info("No raw data available.")
394
+
395
+ # Footer
396
+ st.divider()
397
+ st.markdown(
398
+ "<div style='text-align: center; opacity: 0.6;'>Model monitoring powered by Hopsworks Feature Store</div>",
399
+ unsafe_allow_html=True
400
+ )
trip_info.py CHANGED
@@ -1,65 +1,170 @@
1
  import hopsworks
2
  import os
3
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  def load_static_trip_info():
6
  api_key = os.environ.get("HOPSWORKS_API_KEY")
7
  project_name = os.environ.get("HOPSWORKS_PROJECT")
8
  project = hopsworks.login(project=project_name, api_key_value=api_key)
9
-
10
  fs = project.get_feature_store()
11
- fg = fs.get_feature_group("static_trip_info_fg", version=1) # adjust version
12
-
13
  df = fg.read()
14
  return df
15
 
 
16
  def load_static_stops_info():
17
  api_key = os.environ.get("HOPSWORKS_API_KEY")
18
  project_name = os.environ.get("HOPSWORKS_PROJECT")
19
  project = hopsworks.login(project=project_name, api_key_value=api_key)
20
-
21
  fs = project.get_feature_store()
22
- fg = fs.get_feature_group("static_trip_and_stops_info_fg", version=1) # adjust version
23
-
24
  df = fg.read()
25
  return df
26
 
27
- def find_nearest_trip(lat, lon, datetime_obj, static_trip_df):
 
 
 
 
 
 
 
 
 
 
 
 
28
  """
29
- Return the trip closest to the requested location/time.
30
- Currently just filters static trips; could be enhanced with stops & routing.
 
 
 
 
 
 
 
 
 
 
31
  """
32
- # For static data, we can only match by service_id/date/time if available
33
- # Here we just pick a random trip as placeholder
34
- if static_trip_df is None or len(static_trip_df) == 0:
 
 
 
 
 
 
 
35
  return None
36
-
37
- trip = static_trip_df.sample(1).iloc[0] # pick 1 random trip for demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  return {
39
- "trip_id": trip["trip_id"],
40
- "route_short_name": trip["route_short_name"],
41
- "route_long_name": trip["route_long_name"],
42
- "trip_headsign": trip.get("trip_headsign", None)
 
 
 
43
  }
44
 
45
 
46
  def find_closest_stop(lat, lon, trip_id, stops_df):
47
  """
48
  Returns the closest stop to a given lat/lon for the specified trip_id.
 
 
49
  """
50
- if stops_df is None:
51
- return None
 
52
  # Filter stops for this trip
53
  trip_stops = stops_df[stops_df["trip_id"] == trip_id]
54
  if trip_stops.empty:
55
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- # Compute distances
58
- lat_array = trip_stops["stop_lat"].to_numpy()
59
- lon_array = trip_stops["stop_lon"].to_numpy()
 
60
 
61
- distances = np.sqrt((lat_array - lat)**2 + (lon_array - lon)**2)
62
- idx_min = distances.argmin()
63
- closest_stop = trip_stops.iloc[idx_min]
64
 
65
- return closest_stop["stop_name"]
 
 
 
1
  import hopsworks
2
  import os
3
  import numpy as np
4
+ from math import radians, sin, cos, sqrt, atan2
5
+
6
+
7
+ def haversine_distance(lat1, lon1, lat2, lon2):
8
+ """Calculate distance in meters between two points using haversine formula."""
9
+ R = 6371000 # Earth's radius in meters
10
+ dlat = radians(lat2 - lat1)
11
+ dlon = radians(lon2 - lon1)
12
+
13
+ a = sin(dlat/2)**2 + cos(radians(lat1)) * cos(radians(lat2)) * sin(dlon/2)**2
14
+ c = 2 * atan2(sqrt(a), sqrt(1-a))
15
+ return R * c
16
+
17
 
18
  def load_static_trip_info():
19
  api_key = os.environ.get("HOPSWORKS_API_KEY")
20
  project_name = os.environ.get("HOPSWORKS_PROJECT")
21
  project = hopsworks.login(project=project_name, api_key_value=api_key)
22
+
23
  fs = project.get_feature_store()
24
+ fg = fs.get_feature_group("static_trip_info_fg", version=1)
25
+
26
  df = fg.read()
27
  return df
28
 
29
+
30
  def load_static_stops_info():
31
  api_key = os.environ.get("HOPSWORKS_API_KEY")
32
  project_name = os.environ.get("HOPSWORKS_PROJECT")
33
  project = hopsworks.login(project=project_name, api_key_value=api_key)
34
+
35
  fs = project.get_feature_store()
36
+ fg = fs.get_feature_group("static_trip_and_stops_info_fg", version=1)
37
+
38
  df = fg.read()
39
  return df
40
 
41
+
42
+ def time_to_seconds(t):
43
+ """Convert time string (HH:MM:SS) to seconds since midnight."""
44
+ if t is None:
45
+ return None
46
+ try:
47
+ h, m, s = map(int, str(t).split(":"))
48
+ return h * 3600 + m * 60 + s
49
+ except (ValueError, AttributeError):
50
+ return None
51
+
52
+
53
+ def find_nearest_trip(lat, lon, datetime_obj, static_trip_and_stops_df, max_radius_m=500):
54
  """
55
+ Find the nearest trip to a given location and time.
56
+
57
+ Uses haversine distance and filters by time window.
58
+
59
+ Args:
60
+ lat, lon: Location to search near
61
+ datetime_obj: Target datetime
62
+ static_trip_and_stops_df: DataFrame with trip and stop info
63
+ max_radius_m: Maximum search radius in meters (default 500m)
64
+
65
+ Returns:
66
+ Dict with trip info or None if no nearby trip found
67
  """
68
+ if static_trip_and_stops_df is None or static_trip_and_stops_df.empty:
69
+ return None
70
+
71
+ target_s = datetime_obj.hour * 3600 + datetime_obj.minute * 60
72
+
73
+ # Compute distance to each stop
74
+ df = static_trip_and_stops_df.copy()
75
+
76
+ # Check if required columns exist
77
+ if "stop_lat" not in df.columns or "stop_lon" not in df.columns:
78
  return None
79
+
80
+ df["distance_m"] = df.apply(
81
+ lambda r: haversine_distance(lat, lon, r["stop_lat"], r["stop_lon"]),
82
+ axis=1
83
+ )
84
+
85
+ # Geographic filter
86
+ nearby = df[df["distance_m"] <= max_radius_m]
87
+ if nearby.empty:
88
+ # Try with larger radius
89
+ nearby = df[df["distance_m"] <= max_radius_m * 2]
90
+ if nearby.empty:
91
+ return None
92
+
93
+ # Build time window check if arrival/departure times are available
94
+ if "arrival_time" in nearby.columns and "departure_time" in nearby.columns:
95
+ nearby = nearby.copy()
96
+ nearby["arr_s"] = nearby["arrival_time"].apply(time_to_seconds)
97
+ nearby["dep_s"] = nearby["departure_time"].apply(time_to_seconds)
98
+
99
+ # Keep trips where we're near a scheduled stop time
100
+ time_filtered = nearby[
101
+ (nearby["arr_s"].notna()) &
102
+ ((nearby["arr_s"] <= target_s + 3600) & (nearby["arr_s"] >= target_s - 3600))
103
+ ]
104
+
105
+ if not time_filtered.empty:
106
+ nearby = time_filtered
107
+
108
+ # Choose the one whose stop is closest to the click
109
+ best = nearby.sort_values("distance_m").iloc[0]
110
+
111
  return {
112
+ "trip_id": best.get("trip_id"),
113
+ "route_short_name": best.get("route_short_name"),
114
+ "route_long_name": best.get("route_long_name"),
115
+ "trip_headsign": best.get("trip_headsign"),
116
+ "closest_stop": best.get("stop_name"),
117
+ "closest_stop_headsign": best.get("stop_headsign"),
118
+ "distance_m": round(best["distance_m"]),
119
  }
120
 
121
 
122
  def find_closest_stop(lat, lon, trip_id, stops_df):
123
  """
124
  Returns the closest stop to a given lat/lon for the specified trip_id.
125
+
126
+ Returns tuple of (stop_name, stop_headsign) or (None, None) if not found.
127
  """
128
+ if stops_df is None or stops_df.empty:
129
+ return None, None
130
+
131
  # Filter stops for this trip
132
  trip_stops = stops_df[stops_df["trip_id"] == trip_id]
133
  if trip_stops.empty:
134
+ return None, None
135
+
136
+ # Compute distances using haversine
137
+ distances = trip_stops.apply(
138
+ lambda r: haversine_distance(lat, lon, r["stop_lat"], r["stop_lon"]),
139
+ axis=1
140
+ )
141
+
142
+ closest_stop = trip_stops.loc[distances.idxmin()]
143
+
144
+ return closest_stop.get("stop_name"), closest_stop.get("stop_headsign")
145
+
146
+
147
+ def load_trip_forecasts(fs, hour, weekday):
148
+ """
149
+ Load trip forecasts from forecast_fg for a specific hour and weekday.
150
+
151
+ Returns DataFrame with trip predictions or empty DataFrame if not available.
152
+ """
153
+ try:
154
+ forecast_fg = fs.get_feature_group("forecast_fg", version=1)
155
+ df = forecast_fg.read()
156
+
157
+ if df.empty:
158
+ return df
159
 
160
+ # Filter to matching hour and weekday
161
+ df["window_start"] = pd.to_datetime(df["window_start"])
162
+ df["hour"] = df["window_start"].dt.hour
163
+ df["weekday"] = df["window_start"].dt.weekday
164
 
165
+ filtered = df[(df["hour"] == hour) & (df["weekday"] == weekday)]
166
+ return filtered
 
167
 
168
+ except Exception as e:
169
+ print(f"Could not load trip forecasts: {e}")
170
+ return pd.DataFrame()