Spaces:
Sleeping
Sleeping
feat: add monitoring dashboard and improve prediction speed
Browse files- Add Monitoring page (pages/Monitoring.py) showing model performance metrics
- Key metrics: accuracy, F1, precision, MAE
- Accuracy trend chart over time
- Per-class performance breakdown for all 7 occupancy classes
- Model version filter with v4 as default
- Alert banner when accuracy drops below threshold
- Improve real-time prediction speed in app.py
- Make trip info loading optional (was causing 2+ min delays)
- Add "Load nearby bus info" button for on-demand loading
- Trip forecast display when available from forecast_fg
- Update trip_info.py with haversine distance functions
- app.py +122 -22
- pages/Monitoring.py +400 -0
- trip_info.py +133 -28
app.py
CHANGED
|
@@ -27,7 +27,9 @@ import hopsworks
|
|
| 27 |
from predictor import predict_occupancy, load_model, OCCUPANCY_LABELS
|
| 28 |
from weather import get_weather_for_prediction
|
| 29 |
from holidays import get_holiday_features
|
| 30 |
-
from trip_info import
|
|
|
|
|
|
|
| 31 |
from contours import load_contours_from_file, grid_to_cells_geojson
|
| 32 |
|
| 33 |
# Constants
|
|
@@ -75,6 +77,12 @@ def get_static_stops_df():
|
|
| 75 |
return None
|
| 76 |
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
@st.cache_resource
|
| 79 |
def get_model():
|
| 80 |
"""Load model once and cache it."""
|
|
@@ -90,6 +98,33 @@ def cached_predict_occupancy(lat, lon, hour, day_of_week, weather, holidays):
|
|
| 90 |
return predict_occupancy(lat, lon, hour, day_of_week, weather, holidays)
|
| 91 |
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
@st.cache_resource
|
| 94 |
def fetch_heatmaps_from_hopsworks():
|
| 95 |
"""
|
|
@@ -292,8 +327,12 @@ def create_map(selected_lat=None, selected_lon=None, show_heatmap=False,
|
|
| 292 |
return m
|
| 293 |
|
| 294 |
|
| 295 |
-
def make_prediction(lat, lon, selected_datetime):
|
| 296 |
-
"""Make prediction and return formatted result.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
if lat is None or lon is None:
|
| 298 |
return None, None, None
|
| 299 |
|
|
@@ -314,16 +353,44 @@ def make_prediction(lat, lon, selected_datetime):
|
|
| 314 |
holidays=holidays
|
| 315 |
)
|
| 316 |
|
|
|
|
| 317 |
trip_info = None
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
|
| 322 |
return pred_class, confidence, {
|
| 323 |
"weather": weather,
|
| 324 |
"holidays": holidays,
|
| 325 |
"datetime": selected_datetime,
|
| 326 |
-
"trip_info": trip_info
|
|
|
|
| 327 |
}
|
| 328 |
except Exception as e:
|
| 329 |
return None, None, str(e)
|
|
@@ -446,12 +513,16 @@ with col2:
|
|
| 446 |
# Show selected coordinates
|
| 447 |
st.markdown(f"**Location:** {st.session_state.selected_lat:.4f}, {st.session_state.selected_lon:.4f}")
|
| 448 |
|
| 449 |
-
#
|
|
|
|
|
|
|
|
|
|
| 450 |
with st.spinner("Fetching prediction..."):
|
| 451 |
pred_class, confidence, result = make_prediction(
|
| 452 |
st.session_state.selected_lat,
|
| 453 |
st.session_state.selected_lon,
|
| 454 |
-
selected_datetime
|
|
|
|
| 455 |
)
|
| 456 |
|
| 457 |
if pred_class is not None:
|
|
@@ -507,22 +578,51 @@ with col2:
|
|
| 507 |
if route_desc:
|
| 508 |
info_lines.append(f"Type: {route_desc}")
|
| 509 |
|
| 510 |
-
#
|
| 511 |
-
|
| 512 |
-
if
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
static_stops_df
|
| 520 |
-
)
|
| 521 |
-
if closest_stop:
|
| 522 |
-
info_lines.append(f"Nearest stop: {closest_stop}")
|
| 523 |
|
| 524 |
if info_lines:
|
| 525 |
st.markdown("**Bus Info:**\n- " + "\n- ".join(info_lines))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
|
| 527 |
# Weather conditions
|
| 528 |
conditions = []
|
|
|
|
| 27 |
from predictor import predict_occupancy, load_model, OCCUPANCY_LABELS
|
| 28 |
from weather import get_weather_for_prediction
|
| 29 |
from holidays import get_holiday_features
|
| 30 |
+
from trip_info import (
|
| 31 |
+
load_static_trip_info, find_nearest_trip, load_static_stops_info
|
| 32 |
+
)
|
| 33 |
from contours import load_contours_from_file, grid_to_cells_geojson
|
| 34 |
|
| 35 |
# Constants
|
|
|
|
| 77 |
return None
|
| 78 |
|
| 79 |
|
| 80 |
+
def is_stops_data_cached():
|
| 81 |
+
"""Check if stops data is already in cache without triggering load."""
|
| 82 |
+
# Check if the cache has been populated by looking at session state
|
| 83 |
+
return "stops_data_loaded" in st.session_state and st.session_state.stops_data_loaded
|
| 84 |
+
|
| 85 |
+
|
| 86 |
@st.cache_resource
|
| 87 |
def get_model():
|
| 88 |
"""Load model once and cache it."""
|
|
|
|
| 98 |
return predict_occupancy(lat, lon, hour, day_of_week, weather, holidays)
|
| 99 |
|
| 100 |
|
| 101 |
+
@st.cache_data(ttl=3600)
|
| 102 |
+
def fetch_trip_forecasts_from_hopsworks():
|
| 103 |
+
"""
|
| 104 |
+
Fetch trip forecasts from Hopsworks forecast_fg.
|
| 105 |
+
|
| 106 |
+
Returns DataFrame with columns: trip_id, hour, weekday, predicted_occupancy, confidence
|
| 107 |
+
Returns None if forecast_fg doesn't exist or is empty.
|
| 108 |
+
"""
|
| 109 |
+
try:
|
| 110 |
+
project = hopsworks.login()
|
| 111 |
+
fs = project.get_feature_store()
|
| 112 |
+
# Try v2 (new schema with hour/weekday), fall back to v1
|
| 113 |
+
for version in [2, 1]:
|
| 114 |
+
try:
|
| 115 |
+
forecast_fg = fs.get_feature_group("forecast_fg", version=version)
|
| 116 |
+
df = forecast_fg.read()
|
| 117 |
+
if df is not None and not df.empty:
|
| 118 |
+
print(f"Loaded {len(df)} trip forecasts from Hopsworks v{version}")
|
| 119 |
+
return df
|
| 120 |
+
except Exception:
|
| 121 |
+
continue
|
| 122 |
+
return None
|
| 123 |
+
except Exception as e:
|
| 124 |
+
print(f"Could not load trip forecasts: {e}")
|
| 125 |
+
return None
|
| 126 |
+
|
| 127 |
+
|
| 128 |
@st.cache_resource
|
| 129 |
def fetch_heatmaps_from_hopsworks():
|
| 130 |
"""
|
|
|
|
| 327 |
return m
|
| 328 |
|
| 329 |
|
| 330 |
+
def make_prediction(lat, lon, selected_datetime, skip_trip_info=False):
|
| 331 |
+
"""Make prediction and return formatted result.
|
| 332 |
+
|
| 333 |
+
Args:
|
| 334 |
+
skip_trip_info: If True, skip the slow trip info lookup
|
| 335 |
+
"""
|
| 336 |
if lat is None or lon is None:
|
| 337 |
return None, None, None
|
| 338 |
|
|
|
|
| 353 |
holidays=holidays
|
| 354 |
)
|
| 355 |
|
| 356 |
+
# Find nearest trip from static data (only if not skipping)
|
| 357 |
trip_info = None
|
| 358 |
+
trip_forecast = None
|
| 359 |
+
|
| 360 |
+
if not skip_trip_info:
|
| 361 |
+
static_stops_df = get_static_stops_df()
|
| 362 |
+
# Mark that we've loaded the data (for future quick checks)
|
| 363 |
+
st.session_state.stops_data_loaded = True
|
| 364 |
+
|
| 365 |
+
if static_stops_df is not None:
|
| 366 |
+
trip_info = find_nearest_trip(lat, lon, selected_datetime, static_stops_df)
|
| 367 |
+
|
| 368 |
+
# Try to get trip forecast if available
|
| 369 |
+
if trip_info and trip_info.get("trip_id"):
|
| 370 |
+
forecasts_df = fetch_trip_forecasts_from_hopsworks()
|
| 371 |
+
if forecasts_df is not None:
|
| 372 |
+
trip_id = trip_info["trip_id"]
|
| 373 |
+
hour = selected_datetime.hour
|
| 374 |
+
weekday = selected_datetime.weekday()
|
| 375 |
+
# Find matching forecast
|
| 376 |
+
match = forecasts_df[
|
| 377 |
+
(forecasts_df["trip_id"] == trip_id) &
|
| 378 |
+
(forecasts_df["hour"] == hour) &
|
| 379 |
+
(forecasts_df["weekday"] == weekday)
|
| 380 |
+
]
|
| 381 |
+
if not match.empty:
|
| 382 |
+
row = match.iloc[0]
|
| 383 |
+
trip_forecast = {
|
| 384 |
+
"predicted_occupancy": int(row.get("predicted_occupancy", 0)),
|
| 385 |
+
"confidence": float(row.get("confidence", 0)),
|
| 386 |
+
}
|
| 387 |
|
| 388 |
return pred_class, confidence, {
|
| 389 |
"weather": weather,
|
| 390 |
"holidays": holidays,
|
| 391 |
"datetime": selected_datetime,
|
| 392 |
+
"trip_info": trip_info,
|
| 393 |
+
"trip_forecast": trip_forecast
|
| 394 |
}
|
| 395 |
except Exception as e:
|
| 396 |
return None, None, str(e)
|
|
|
|
| 513 |
# Show selected coordinates
|
| 514 |
st.markdown(f"**Location:** {st.session_state.selected_lat:.4f}, {st.session_state.selected_lon:.4f}")
|
| 515 |
|
| 516 |
+
# Check if stops data is already cached (fast check)
|
| 517 |
+
stops_already_loaded = st.session_state.get("stops_data_loaded", False)
|
| 518 |
+
|
| 519 |
+
# Make prediction (skip trip info on first load to be fast)
|
| 520 |
with st.spinner("Fetching prediction..."):
|
| 521 |
pred_class, confidence, result = make_prediction(
|
| 522 |
st.session_state.selected_lat,
|
| 523 |
st.session_state.selected_lon,
|
| 524 |
+
selected_datetime,
|
| 525 |
+
skip_trip_info=not stops_already_loaded
|
| 526 |
)
|
| 527 |
|
| 528 |
if pred_class is not None:
|
|
|
|
| 578 |
if route_desc:
|
| 579 |
info_lines.append(f"Type: {route_desc}")
|
| 580 |
|
| 581 |
+
# Closest stop from trip info (already computed)
|
| 582 |
+
closest_stop = trip_info.get("closest_stop")
|
| 583 |
+
if closest_stop:
|
| 584 |
+
info_lines.append(f"Nearest stop: {closest_stop}")
|
| 585 |
+
|
| 586 |
+
# Distance to stop
|
| 587 |
+
distance = trip_info.get("distance_m")
|
| 588 |
+
if distance is not None:
|
| 589 |
+
info_lines.append(f"Distance: {distance}m")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 590 |
|
| 591 |
if info_lines:
|
| 592 |
st.markdown("**Bus Info:**\n- " + "\n- ".join(info_lines))
|
| 593 |
+
elif not stops_already_loaded:
|
| 594 |
+
# Offer to load trip info (it's slow on first load)
|
| 595 |
+
if st.button("Load nearby bus info", help="First load takes ~1-2 minutes"):
|
| 596 |
+
with st.spinner("Loading trip data from Hopsworks (this may take a minute)..."):
|
| 597 |
+
# Trigger the load and rerun
|
| 598 |
+
get_static_stops_df()
|
| 599 |
+
st.session_state.stops_data_loaded = True
|
| 600 |
+
st.rerun()
|
| 601 |
+
|
| 602 |
+
# Show trip-specific forecast if available
|
| 603 |
+
trip_forecast = result.get("trip_forecast")
|
| 604 |
+
if trip_forecast:
|
| 605 |
+
forecast_class = trip_forecast["predicted_occupancy"]
|
| 606 |
+
forecast_conf = trip_forecast["confidence"]
|
| 607 |
+
forecast_label = OCCUPANCY_LABELS.get(forecast_class, OCCUPANCY_LABELS[0])
|
| 608 |
+
forecast_color = OCCUPANCY_COLORS.get(forecast_class, "#6b7280")
|
| 609 |
+
|
| 610 |
+
st.markdown(f"""
|
| 611 |
+
<div style="
|
| 612 |
+
background: {forecast_color}11;
|
| 613 |
+
border: 1px solid {forecast_color}44;
|
| 614 |
+
border-radius: 8px;
|
| 615 |
+
padding: 12px;
|
| 616 |
+
margin: 8px 0;
|
| 617 |
+
">
|
| 618 |
+
<div style="font-size: 0.85em; color: #6b7280; margin-bottom: 4px;">
|
| 619 |
+
Trip-specific forecast:
|
| 620 |
+
</div>
|
| 621 |
+
<div style="font-weight: 600; color: {forecast_color};">
|
| 622 |
+
{forecast_label['icon']} {forecast_label['label']} ({forecast_conf:.0%})
|
| 623 |
+
</div>
|
| 624 |
+
</div>
|
| 625 |
+
""", unsafe_allow_html=True)
|
| 626 |
|
| 627 |
# Weather conditions
|
| 628 |
conditions = []
|
pages/Monitoring.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HappySardines - Model Monitoring Dashboard
|
| 3 |
+
|
| 4 |
+
Displays model performance metrics from hindcast analysis:
|
| 5 |
+
- Accuracy trends over time
|
| 6 |
+
- Actual vs predicted occupancy comparison
|
| 7 |
+
- Per-class performance breakdown
|
| 8 |
+
- Alerts for model drift
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import streamlit as st
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import numpy as np
|
| 14 |
+
import hopsworks
|
| 15 |
+
from datetime import datetime, timedelta
|
| 16 |
+
|
| 17 |
+
# Page config
|
| 18 |
+
st.set_page_config(
|
| 19 |
+
page_title="HappySardines - Monitoring",
|
| 20 |
+
page_icon="📊",
|
| 21 |
+
layout="wide"
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# Constants
|
| 25 |
+
ALERT_THRESHOLD = 0.65 # Alert if accuracy drops below this
|
| 26 |
+
MODEL_NAME = "occupancy_xgboost_model_new"
|
| 27 |
+
CURRENT_MODEL_VERSION = 4 # Current production model version
|
| 28 |
+
|
| 29 |
+
# Occupancy class labels
|
| 30 |
+
OCCUPANCY_LABELS = {
|
| 31 |
+
0: "Empty",
|
| 32 |
+
1: "Many seats",
|
| 33 |
+
2: "Few seats",
|
| 34 |
+
3: "Standing",
|
| 35 |
+
4: "Crowded",
|
| 36 |
+
5: "Full",
|
| 37 |
+
6: "Not accepting",
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
# Colors for occupancy levels
|
| 41 |
+
OCCUPANCY_COLORS = {
|
| 42 |
+
0: "#22c55e", # Green
|
| 43 |
+
1: "#84cc16", # Lime
|
| 44 |
+
2: "#eab308", # Yellow
|
| 45 |
+
3: "#f97316", # Orange
|
| 46 |
+
4: "#ef4444", # Red
|
| 47 |
+
5: "#ef4444", # Red
|
| 48 |
+
6: "#6b7280", # Gray
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@st.cache_data(ttl=3600)
|
| 53 |
+
def fetch_monitoring_data():
|
| 54 |
+
"""
|
| 55 |
+
Fetch monitoring data from Hopsworks monitor_fg.
|
| 56 |
+
|
| 57 |
+
Returns DataFrame with columns:
|
| 58 |
+
- window_start, trip_id
|
| 59 |
+
- actual_occupancy_mode, predicted_occupancy_mode
|
| 60 |
+
- accuracy, precision, recall, f1_weighted, mae
|
| 61 |
+
- model_version, generated_at
|
| 62 |
+
"""
|
| 63 |
+
try:
|
| 64 |
+
project = hopsworks.login()
|
| 65 |
+
fs = project.get_feature_store()
|
| 66 |
+
|
| 67 |
+
monitor_fg = fs.get_feature_group("monitor_fg", version=1)
|
| 68 |
+
df = monitor_fg.read()
|
| 69 |
+
|
| 70 |
+
if df is not None and not df.empty:
|
| 71 |
+
# Ensure datetime columns are properly typed
|
| 72 |
+
if "window_start" in df.columns:
|
| 73 |
+
df["window_start"] = pd.to_datetime(df["window_start"])
|
| 74 |
+
if "generated_at" in df.columns:
|
| 75 |
+
df["generated_at"] = pd.to_datetime(df["generated_at"])
|
| 76 |
+
|
| 77 |
+
print(f"Loaded {len(df)} monitoring records from Hopsworks")
|
| 78 |
+
return df
|
| 79 |
+
|
| 80 |
+
return pd.DataFrame()
|
| 81 |
+
|
| 82 |
+
except Exception as e:
|
| 83 |
+
print(f"Error loading monitoring data: {e}")
|
| 84 |
+
return pd.DataFrame()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_daily_metrics(df: pd.DataFrame) -> pd.DataFrame:
|
| 88 |
+
"""Aggregate monitoring data by day."""
|
| 89 |
+
if df.empty:
|
| 90 |
+
return pd.DataFrame()
|
| 91 |
+
|
| 92 |
+
df = df.copy()
|
| 93 |
+
df["date"] = df["window_start"].dt.date
|
| 94 |
+
|
| 95 |
+
# Get first record per day (metrics are already daily aggregates)
|
| 96 |
+
daily = df.groupby("date").agg({
|
| 97 |
+
"accuracy": "first",
|
| 98 |
+
"precision": "first",
|
| 99 |
+
"recall": "first",
|
| 100 |
+
"f1_weighted": "first",
|
| 101 |
+
"mae": "first",
|
| 102 |
+
"model_version": "first",
|
| 103 |
+
}).reset_index()
|
| 104 |
+
|
| 105 |
+
daily["date"] = pd.to_datetime(daily["date"])
|
| 106 |
+
return daily.sort_values("date")
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def get_hourly_comparison(df: pd.DataFrame) -> pd.DataFrame:
|
| 110 |
+
"""Aggregate actual vs predicted by hour."""
|
| 111 |
+
if df.empty:
|
| 112 |
+
return pd.DataFrame()
|
| 113 |
+
|
| 114 |
+
df = df.copy()
|
| 115 |
+
df["hour"] = df["window_start"].dt.floor("H")
|
| 116 |
+
|
| 117 |
+
hourly = df.groupby("hour").agg({
|
| 118 |
+
"actual_occupancy_mode": "mean",
|
| 119 |
+
"predicted_occupancy_mode": "mean",
|
| 120 |
+
}).reset_index()
|
| 121 |
+
|
| 122 |
+
return hourly.sort_values("hour")
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def get_per_class_metrics(df: pd.DataFrame) -> pd.DataFrame:
|
| 126 |
+
"""Calculate per-class accuracy and counts for all 7 occupancy classes."""
|
| 127 |
+
if df.empty:
|
| 128 |
+
return pd.DataFrame()
|
| 129 |
+
|
| 130 |
+
results = []
|
| 131 |
+
# Always show all 7 classes (0-6), even if some have no data
|
| 132 |
+
for cls in range(7):
|
| 133 |
+
mask = df["actual_occupancy_mode"] == cls
|
| 134 |
+
subset = df[mask]
|
| 135 |
+
|
| 136 |
+
if len(subset) > 0:
|
| 137 |
+
correct = (subset["actual_occupancy_mode"] == subset["predicted_occupancy_mode"]).sum()
|
| 138 |
+
total = len(subset)
|
| 139 |
+
accuracy = correct / total
|
| 140 |
+
else:
|
| 141 |
+
correct = 0
|
| 142 |
+
total = 0
|
| 143 |
+
accuracy = None # No data for this class
|
| 144 |
+
|
| 145 |
+
results.append({
|
| 146 |
+
"class": cls,
|
| 147 |
+
"label": OCCUPANCY_LABELS.get(cls, f"Class {cls}"),
|
| 148 |
+
"count": total,
|
| 149 |
+
"correct": correct,
|
| 150 |
+
"accuracy": accuracy,
|
| 151 |
+
})
|
| 152 |
+
|
| 153 |
+
return pd.DataFrame(results)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def render_metric_card(label: str, value: float, format_str: str = "{:.1%}",
|
| 157 |
+
threshold_low: float = None, threshold_high: float = None):
|
| 158 |
+
"""Render a metric with conditional coloring."""
|
| 159 |
+
formatted = format_str.format(value) if value is not None else "N/A"
|
| 160 |
+
|
| 161 |
+
# Determine color
|
| 162 |
+
if threshold_low is not None and value < threshold_low:
|
| 163 |
+
color = "#ef4444" # Red
|
| 164 |
+
elif threshold_high is not None and value >= threshold_high:
|
| 165 |
+
color = "#22c55e" # Green
|
| 166 |
+
else:
|
| 167 |
+
color = "#eab308" # Yellow
|
| 168 |
+
|
| 169 |
+
st.markdown(f"""
|
| 170 |
+
<div style="
|
| 171 |
+
background: {color}11;
|
| 172 |
+
border: 1px solid {color}44;
|
| 173 |
+
border-radius: 8px;
|
| 174 |
+
padding: 16px;
|
| 175 |
+
text-align: center;
|
| 176 |
+
">
|
| 177 |
+
<div style="font-size: 0.9em; color: #6b7280; margin-bottom: 4px;">
|
| 178 |
+
{label}
|
| 179 |
+
</div>
|
| 180 |
+
<div style="font-size: 1.8em; font-weight: 600; color: {color};">
|
| 181 |
+
{formatted}
|
| 182 |
+
</div>
|
| 183 |
+
</div>
|
| 184 |
+
""", unsafe_allow_html=True)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# Main page content
|
| 188 |
+
st.title("📊 Model Monitoring")
|
| 189 |
+
st.markdown("Track model performance over time using hindcast analysis.")
|
| 190 |
+
|
| 191 |
+
# Load data
|
| 192 |
+
with st.spinner("Loading monitoring data..."):
|
| 193 |
+
monitor_df = fetch_monitoring_data()
|
| 194 |
+
|
| 195 |
+
if monitor_df.empty:
|
| 196 |
+
st.warning("""
|
| 197 |
+
**No monitoring data available yet.**
|
| 198 |
+
|
| 199 |
+
Monitoring data is generated daily by the inference pipeline, which compares
|
| 200 |
+
yesterday's predictions to actual observed occupancy.
|
| 201 |
+
|
| 202 |
+
The inference pipeline runs at 09:00 UTC. Check back after it has run at least once.
|
| 203 |
+
""")
|
| 204 |
+
st.stop()
|
| 205 |
+
|
| 206 |
+
# Model version filter
|
| 207 |
+
available_versions = sorted(monitor_df["model_version"].dropna().unique())
|
| 208 |
+
if len(available_versions) > 1:
|
| 209 |
+
st.sidebar.subheader("Filter")
|
| 210 |
+
|
| 211 |
+
# Default to current model version if available
|
| 212 |
+
default_idx = available_versions.index(CURRENT_MODEL_VERSION) if CURRENT_MODEL_VERSION in available_versions else len(available_versions) - 1
|
| 213 |
+
|
| 214 |
+
selected_version = st.sidebar.selectbox(
|
| 215 |
+
"Model Version",
|
| 216 |
+
options=available_versions,
|
| 217 |
+
index=default_idx,
|
| 218 |
+
format_func=lambda x: f"v{int(x)}" + (" (current)" if x == CURRENT_MODEL_VERSION else "")
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# Filter data by selected version
|
| 222 |
+
monitor_df = monitor_df[monitor_df["model_version"] == selected_version]
|
| 223 |
+
|
| 224 |
+
if monitor_df.empty:
|
| 225 |
+
st.warning(f"No monitoring data available for model v{int(selected_version)}.")
|
| 226 |
+
st.stop()
|
| 227 |
+
else:
|
| 228 |
+
selected_version = available_versions[0] if available_versions else None
|
| 229 |
+
|
| 230 |
+
# Show warning if viewing old model data
|
| 231 |
+
if selected_version is not None and selected_version != CURRENT_MODEL_VERSION:
|
| 232 |
+
st.info(f"""
|
| 233 |
+
**Viewing historical data from model v{int(selected_version)}.**
|
| 234 |
+
|
| 235 |
+
The current production model is v{CURRENT_MODEL_VERSION}.
|
| 236 |
+
Data for v{CURRENT_MODEL_VERSION} will appear after the inference pipeline runs.
|
| 237 |
+
""")
|
| 238 |
+
|
| 239 |
+
# Calculate aggregates
|
| 240 |
+
daily_metrics = get_daily_metrics(monitor_df)
|
| 241 |
+
hourly_comparison = get_hourly_comparison(monitor_df)
|
| 242 |
+
per_class = get_per_class_metrics(monitor_df)
|
| 243 |
+
|
| 244 |
+
# Get latest metrics
|
| 245 |
+
latest = daily_metrics.iloc[-1] if not daily_metrics.empty else None
|
| 246 |
+
|
| 247 |
+
# Header with model info
|
| 248 |
+
if latest is not None:
|
| 249 |
+
col1, col2 = st.columns([3, 1])
|
| 250 |
+
with col1:
|
| 251 |
+
st.markdown(f"**Model Version:** v{int(latest['model_version'])}")
|
| 252 |
+
with col2:
|
| 253 |
+
last_date = latest["date"].strftime("%Y-%m-%d")
|
| 254 |
+
st.markdown(f"**Last Updated:** {last_date}")
|
| 255 |
+
|
| 256 |
+
# Alert banner
|
| 257 |
+
if latest is not None and latest["accuracy"] < ALERT_THRESHOLD:
|
| 258 |
+
st.error(f"""
|
| 259 |
+
⚠️ **Model Performance Alert**
|
| 260 |
+
|
| 261 |
+
Accuracy ({latest['accuracy']:.1%}) is below the threshold ({ALERT_THRESHOLD:.0%}).
|
| 262 |
+
Consider investigating recent data quality or retraining the model.
|
| 263 |
+
""")
|
| 264 |
+
|
| 265 |
+
st.divider()
|
| 266 |
+
|
| 267 |
+
# Key metrics cards
|
| 268 |
+
st.subheader("Latest Performance")
|
| 269 |
+
|
| 270 |
+
if latest is not None:
|
| 271 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 272 |
+
|
| 273 |
+
with col1:
|
| 274 |
+
render_metric_card(
|
| 275 |
+
"Accuracy",
|
| 276 |
+
latest["accuracy"],
|
| 277 |
+
threshold_low=0.60,
|
| 278 |
+
threshold_high=0.70
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
with col2:
|
| 282 |
+
render_metric_card(
|
| 283 |
+
"F1 Score (Weighted)",
|
| 284 |
+
latest["f1_weighted"],
|
| 285 |
+
threshold_low=0.55,
|
| 286 |
+
threshold_high=0.65
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
with col3:
|
| 290 |
+
render_metric_card(
|
| 291 |
+
"Precision",
|
| 292 |
+
latest["precision"],
|
| 293 |
+
threshold_low=0.55,
|
| 294 |
+
threshold_high=0.70
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
with col4:
|
| 298 |
+
render_metric_card(
|
| 299 |
+
"MAE",
|
| 300 |
+
latest["mae"],
|
| 301 |
+
format_str="{:.2f}",
|
| 302 |
+
threshold_low=0.3, # Lower is better for MAE
|
| 303 |
+
threshold_high=0.6
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
st.divider()
|
| 307 |
+
|
| 308 |
+
# Accuracy trend chart
|
| 309 |
+
st.subheader("Accuracy Over Time")
|
| 310 |
+
|
| 311 |
+
if not daily_metrics.empty and len(daily_metrics) > 1:
|
| 312 |
+
chart_data = daily_metrics[["date", "accuracy"]].set_index("date")
|
| 313 |
+
st.line_chart(chart_data, use_container_width=True)
|
| 314 |
+
|
| 315 |
+
# Show trend
|
| 316 |
+
if len(daily_metrics) >= 2:
|
| 317 |
+
recent = daily_metrics["accuracy"].iloc[-1]
|
| 318 |
+
previous = daily_metrics["accuracy"].iloc[-2]
|
| 319 |
+
delta = recent - previous
|
| 320 |
+
trend = "📈" if delta > 0 else "📉" if delta < 0 else "➡️"
|
| 321 |
+
st.caption(f"{trend} Change from previous: {delta:+.1%}")
|
| 322 |
+
else:
|
| 323 |
+
st.info("Need at least 2 days of data to show trend chart.")
|
| 324 |
+
|
| 325 |
+
st.divider()
|
| 326 |
+
|
| 327 |
+
# Actual vs Predicted comparison
|
| 328 |
+
st.subheader("Actual vs Predicted Occupancy")
|
| 329 |
+
|
| 330 |
+
if not hourly_comparison.empty:
|
| 331 |
+
# Rename columns for display
|
| 332 |
+
chart_df = hourly_comparison.rename(columns={
|
| 333 |
+
"actual_occupancy_mode": "Actual",
|
| 334 |
+
"predicted_occupancy_mode": "Predicted"
|
| 335 |
+
}).set_index("hour")
|
| 336 |
+
|
| 337 |
+
st.line_chart(chart_df, use_container_width=True)
|
| 338 |
+
st.caption("Hourly average occupancy levels (0=Empty, 3=Standing)")
|
| 339 |
+
else:
|
| 340 |
+
st.info("No hourly comparison data available.")
|
| 341 |
+
|
| 342 |
+
st.divider()
|
| 343 |
+
|
| 344 |
+
# Per-class performance
|
| 345 |
+
st.subheader("Per-Class Performance")
|
| 346 |
+
|
| 347 |
+
if not per_class.empty:
|
| 348 |
+
# Color-coded display for all 7 classes
|
| 349 |
+
for _, row in per_class.iterrows():
|
| 350 |
+
cls = int(row["class"])
|
| 351 |
+
label = row["label"]
|
| 352 |
+
accuracy = row["accuracy"]
|
| 353 |
+
count = int(row["count"])
|
| 354 |
+
color = OCCUPANCY_COLORS.get(cls, "#6b7280")
|
| 355 |
+
|
| 356 |
+
col1, col2, col3 = st.columns([2, 1, 3])
|
| 357 |
+
|
| 358 |
+
with col1:
|
| 359 |
+
st.markdown(f"**{cls}** - {label}")
|
| 360 |
+
|
| 361 |
+
with col2:
|
| 362 |
+
if count > 0:
|
| 363 |
+
st.markdown(f"{count:,} samples")
|
| 364 |
+
else:
|
| 365 |
+
st.markdown("No data", help="No samples of this class in the monitoring data")
|
| 366 |
+
|
| 367 |
+
with col3:
|
| 368 |
+
if accuracy is not None and count > 0:
|
| 369 |
+
# Progress bar with color
|
| 370 |
+
st.progress(accuracy, text=f"{accuracy:.1%}")
|
| 371 |
+
else:
|
| 372 |
+
st.markdown("—", help="No accuracy data for this class")
|
| 373 |
+
|
| 374 |
+
# Explanation
|
| 375 |
+
st.caption("""
|
| 376 |
+
Per-class accuracy (recall) shows how well the model predicts each occupancy level.
|
| 377 |
+
Classes 4-6 are rare in Swedish transit data. Lower accuracy for rare classes is expected.
|
| 378 |
+
""")
|
| 379 |
+
else:
|
| 380 |
+
st.info("No per-class metrics available.")
|
| 381 |
+
|
| 382 |
+
st.divider()
|
| 383 |
+
|
| 384 |
+
# Raw data expander
|
| 385 |
+
with st.expander("View Raw Monitoring Data"):
|
| 386 |
+
if not monitor_df.empty:
|
| 387 |
+
st.dataframe(
|
| 388 |
+
monitor_df.sort_values("window_start", ascending=False).head(100),
|
| 389 |
+
use_container_width=True
|
| 390 |
+
)
|
| 391 |
+
st.caption(f"Showing latest 100 of {len(monitor_df):,} total records")
|
| 392 |
+
else:
|
| 393 |
+
st.info("No raw data available.")
|
| 394 |
+
|
| 395 |
+
# Footer
|
| 396 |
+
st.divider()
|
| 397 |
+
st.markdown(
|
| 398 |
+
"<div style='text-align: center; opacity: 0.6;'>Model monitoring powered by Hopsworks Feature Store</div>",
|
| 399 |
+
unsafe_allow_html=True
|
| 400 |
+
)
|
trip_info.py
CHANGED
|
@@ -1,65 +1,170 @@
|
|
| 1 |
import hopsworks
|
| 2 |
import os
|
| 3 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
def load_static_trip_info():
|
| 6 |
api_key = os.environ.get("HOPSWORKS_API_KEY")
|
| 7 |
project_name = os.environ.get("HOPSWORKS_PROJECT")
|
| 8 |
project = hopsworks.login(project=project_name, api_key_value=api_key)
|
| 9 |
-
|
| 10 |
fs = project.get_feature_store()
|
| 11 |
-
fg = fs.get_feature_group("static_trip_info_fg", version=1)
|
| 12 |
-
|
| 13 |
df = fg.read()
|
| 14 |
return df
|
| 15 |
|
|
|
|
| 16 |
def load_static_stops_info():
|
| 17 |
api_key = os.environ.get("HOPSWORKS_API_KEY")
|
| 18 |
project_name = os.environ.get("HOPSWORKS_PROJECT")
|
| 19 |
project = hopsworks.login(project=project_name, api_key_value=api_key)
|
| 20 |
-
|
| 21 |
fs = project.get_feature_store()
|
| 22 |
-
fg = fs.get_feature_group("static_trip_and_stops_info_fg", version=1)
|
| 23 |
-
|
| 24 |
df = fg.read()
|
| 25 |
return df
|
| 26 |
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
"""
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
"""
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
return None
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
return {
|
| 39 |
-
"trip_id":
|
| 40 |
-
"route_short_name":
|
| 41 |
-
"route_long_name":
|
| 42 |
-
"trip_headsign":
|
|
|
|
|
|
|
|
|
|
| 43 |
}
|
| 44 |
|
| 45 |
|
| 46 |
def find_closest_stop(lat, lon, trip_id, stops_df):
|
| 47 |
"""
|
| 48 |
Returns the closest stop to a given lat/lon for the specified trip_id.
|
|
|
|
|
|
|
| 49 |
"""
|
| 50 |
-
if stops_df is None:
|
| 51 |
-
return None
|
|
|
|
| 52 |
# Filter stops for this trip
|
| 53 |
trip_stops = stops_df[stops_df["trip_id"] == trip_id]
|
| 54 |
if trip_stops.empty:
|
| 55 |
-
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
closest_stop = trip_stops.iloc[idx_min]
|
| 64 |
|
| 65 |
-
|
|
|
|
|
|
|
|
|
| 1 |
import hopsworks
|
| 2 |
import os
|
| 3 |
import numpy as np
|
| 4 |
+
from math import radians, sin, cos, sqrt, atan2
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def haversine_distance(lat1, lon1, lat2, lon2):
|
| 8 |
+
"""Calculate distance in meters between two points using haversine formula."""
|
| 9 |
+
R = 6371000 # Earth's radius in meters
|
| 10 |
+
dlat = radians(lat2 - lat1)
|
| 11 |
+
dlon = radians(lon2 - lon1)
|
| 12 |
+
|
| 13 |
+
a = sin(dlat/2)**2 + cos(radians(lat1)) * cos(radians(lat2)) * sin(dlon/2)**2
|
| 14 |
+
c = 2 * atan2(sqrt(a), sqrt(1-a))
|
| 15 |
+
return R * c
|
| 16 |
+
|
| 17 |
|
| 18 |
def load_static_trip_info():
|
| 19 |
api_key = os.environ.get("HOPSWORKS_API_KEY")
|
| 20 |
project_name = os.environ.get("HOPSWORKS_PROJECT")
|
| 21 |
project = hopsworks.login(project=project_name, api_key_value=api_key)
|
| 22 |
+
|
| 23 |
fs = project.get_feature_store()
|
| 24 |
+
fg = fs.get_feature_group("static_trip_info_fg", version=1)
|
| 25 |
+
|
| 26 |
df = fg.read()
|
| 27 |
return df
|
| 28 |
|
| 29 |
+
|
| 30 |
def load_static_stops_info():
    """Read the static_trip_and_stops_info_fg feature group from Hopsworks as a DataFrame."""
    project = hopsworks.login(
        project=os.environ.get("HOPSWORKS_PROJECT"),
        api_key_value=os.environ.get("HOPSWORKS_API_KEY"),
    )
    feature_store = project.get_feature_store()
    feature_group = feature_store.get_feature_group(
        "static_trip_and_stops_info_fg", version=1
    )
    return feature_group.read()
|
| 40 |
|
| 41 |
+
|
| 42 |
+
def time_to_seconds(t):
    """Convert a time string (HH:MM:SS) to seconds since midnight.

    Returns None for None input or any value that does not parse as
    exactly three colon-separated integers. Hours past 24 (GTFS-style
    overnight times) are allowed and converted as-is.
    """
    if t is None:
        return None
    try:
        hours, minutes, seconds = (int(part) for part in str(t).split(":"))
    except (ValueError, AttributeError):
        # Wrong field count or non-numeric field.
        return None
    return hours * 3600 + minutes * 60 + seconds
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def find_nearest_trip(lat, lon, datetime_obj, static_trip_and_stops_df, max_radius_m=500):
    """
    Find the nearest trip to a given location and time.

    Uses haversine distance and filters by time window.

    Args:
        lat, lon: Location to search near
        datetime_obj: Target datetime
        static_trip_and_stops_df: DataFrame with trip and stop info
        max_radius_m: Maximum search radius in meters (default 500m)

    Returns:
        Dict with trip info or None if no nearby trip found
    """
    if static_trip_and_stops_df is None or static_trip_and_stops_df.empty:
        return None

    # Target time as seconds since midnight (seconds component ignored —
    # the schedule window below is +/- 1 hour anyway).
    target_s = datetime_obj.hour * 3600 + datetime_obj.minute * 60

    df = static_trip_and_stops_df.copy()

    # Without stop coordinates we cannot rank candidates at all.
    if "stop_lat" not in df.columns or "stop_lon" not in df.columns:
        return None

    # Vectorized haversine (great-circle distance in meters). Same formula
    # as haversine_distance(), but computed in NumPy over the whole column
    # instead of a Python-level lambda per row — much faster on large
    # stop tables.
    lat_rad = np.radians(lat)
    stop_lat = np.radians(df["stop_lat"].to_numpy(dtype=float))
    stop_lon = np.radians(df["stop_lon"].to_numpy(dtype=float))
    dlat = stop_lat - lat_rad
    dlon = stop_lon - np.radians(lon)
    a = np.sin(dlat / 2) ** 2 + np.cos(lat_rad) * np.cos(stop_lat) * np.sin(dlon / 2) ** 2
    df["distance_m"] = 6371000 * 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))

    # Geographic filter; retry once with a doubled radius before giving up.
    nearby = df[df["distance_m"] <= max_radius_m]
    if nearby.empty:
        nearby = df[df["distance_m"] <= max_radius_m * 2]
        if nearby.empty:
            return None

    # Narrow by schedule when stop times are available: keep rows whose
    # arrival is within +/- 1 hour of the target time. If that leaves
    # nothing, fall back to the purely geographic candidates.
    if "arrival_time" in nearby.columns and "departure_time" in nearby.columns:
        nearby = nearby.copy()
        nearby["arr_s"] = nearby["arrival_time"].apply(time_to_seconds)
        nearby["dep_s"] = nearby["departure_time"].apply(time_to_seconds)

        time_filtered = nearby[
            (nearby["arr_s"].notna())
            & (nearby["arr_s"] <= target_s + 3600)
            & (nearby["arr_s"] >= target_s - 3600)
        ]
        if not time_filtered.empty:
            nearby = time_filtered

    # Choose the candidate whose stop is closest to the requested point.
    best = nearby.sort_values("distance_m").iloc[0]

    return {
        "trip_id": best.get("trip_id"),
        "route_short_name": best.get("route_short_name"),
        "route_long_name": best.get("route_long_name"),
        "trip_headsign": best.get("trip_headsign"),
        "closest_stop": best.get("stop_name"),
        "closest_stop_headsign": best.get("stop_headsign"),
        "distance_m": round(best["distance_m"]),
    }
|
| 120 |
|
| 121 |
|
| 122 |
def find_closest_stop(lat, lon, trip_id, stops_df):
    """
    Return the closest stop to a given lat/lon for the specified trip_id.

    Args:
        lat, lon: Reference location.
        trip_id: Trip whose stops should be searched.
        stops_df: DataFrame with at least trip_id, stop_lat, stop_lon columns.

    Returns:
        Tuple of (stop_name, stop_headsign), or (None, None) if not found.
    """
    if stops_df is None or stops_df.empty:
        return None, None

    # Filter stops for this trip
    trip_stops = stops_df[stops_df["trip_id"] == trip_id]
    if trip_stops.empty:
        return None, None

    # Vectorized haversine distance in meters — same formula as
    # haversine_distance(), but computed in NumPy for the whole column
    # instead of a Python-level lambda per row.
    lat_rad = np.radians(lat)
    stop_lat = np.radians(trip_stops["stop_lat"].to_numpy(dtype=float))
    dlat = stop_lat - lat_rad
    dlon = np.radians(trip_stops["stop_lon"].to_numpy(dtype=float)) - np.radians(lon)
    a = np.sin(dlat / 2) ** 2 + np.cos(lat_rad) * np.cos(stop_lat) * np.sin(dlon / 2) ** 2
    dist_m = 6371000 * 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))

    # argmin picks the first minimum, matching Series.idxmin() on ties.
    closest_stop = trip_stops.iloc[int(np.argmin(dist_m))]

    return closest_stop.get("stop_name"), closest_stop.get("stop_headsign")
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def load_trip_forecasts(fs, hour, weekday):
    """
    Load trip forecasts from forecast_fg for a specific hour and weekday.

    Returns DataFrame with trip predictions or empty DataFrame if not available.
    """
    try:
        forecasts = fs.get_feature_group("forecast_fg", version=1).read()

        if forecasts.empty:
            return forecasts

        # Derive hour/weekday from the parsed window start so we can
        # match against the requested prediction slot.
        starts = pd.to_datetime(forecasts["window_start"])
        forecasts["window_start"] = starts
        forecasts["hour"] = starts.dt.hour
        forecasts["weekday"] = starts.dt.weekday

        slot_mask = (forecasts["hour"] == hour) & (forecasts["weekday"] == weekday)
        return forecasts[slot_mask]

    except Exception as e:
        # Best-effort: the forecast feature group may not exist yet.
        print(f"Could not load trip forecasts: {e}")
        return pd.DataFrame()
|