Spaces:
Sleeping
Sleeping
feat: add high-res heatmap v3 with Hopsworks storage
Browse files- Update app.py with v3/v2 fallback fetch from Hopsworks
- Add batch prediction to predictor.py
- Add contours.py for GeoJSON grid cell generation
- Add trip_info.py for bus stop/route info
- Update requirements.txt with scipy, shapely, matplotlib
- app.py +294 -72
- contours.py +350 -0
- predictor.py +96 -11
- requirements.txt +3 -0
- trip_info.py +65 -0
- weather.py +24 -2
app.py
CHANGED
|
@@ -6,6 +6,7 @@ bus crowding in Östergötland.
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import streamlit as st
|
|
|
|
| 9 |
|
| 10 |
# Page config - MUST be first Streamlit command
|
| 11 |
st.set_page_config(
|
|
@@ -16,32 +17,37 @@ st.set_page_config(
|
|
| 16 |
|
| 17 |
import os
|
| 18 |
import folium
|
| 19 |
-
from folium.plugins import HeatMap
|
| 20 |
from streamlit_folium import st_folium
|
| 21 |
import numpy as np
|
| 22 |
from datetime import datetime, timedelta
|
| 23 |
|
|
|
|
|
|
|
| 24 |
# Import prediction and data fetching modules
|
| 25 |
from predictor import predict_occupancy, load_model, OCCUPANCY_LABELS
|
| 26 |
from weather import get_weather_for_prediction
|
| 27 |
from holidays import get_holiday_features
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# Constants
|
| 30 |
DEFAULT_LAT = 58.4108
|
| 31 |
DEFAULT_LON = 15.6214
|
| 32 |
-
DEFAULT_ZOOM =
|
| 33 |
|
|
|
|
|
|
|
| 34 |
BOUNDS = {
|
| 35 |
-
"min_lat":
|
| 36 |
-
"max_lat": 58.
|
| 37 |
-
"min_lon": 14.
|
| 38 |
-
"max_lon": 16.
|
| 39 |
}
|
| 40 |
|
| 41 |
-
# Color scheme for occupancy levels
|
| 42 |
OCCUPANCY_COLORS = {
|
| 43 |
0: "#22c55e", # Empty - green
|
| 44 |
-
1: "#
|
| 45 |
2: "#eab308", # Few seats - yellow
|
| 46 |
3: "#f97316", # Standing - orange
|
| 47 |
4: "#ef4444", # Crushed - red
|
|
@@ -49,6 +55,25 @@ OCCUPANCY_COLORS = {
|
|
| 49 |
6: "#6b7280", # Not accepting - gray
|
| 50 |
}
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
@st.cache_resource
|
| 54 |
def get_model():
|
|
@@ -60,40 +85,167 @@ def get_model():
|
|
| 60 |
return None
|
| 61 |
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
model = get_model()
|
| 66 |
if model is None:
|
| 67 |
-
return
|
| 68 |
|
| 69 |
-
#
|
| 70 |
lat_steps = 15
|
| 71 |
lon_steps = 20
|
| 72 |
lats = np.linspace(BOUNDS["min_lat"], BOUNDS["max_lat"], lat_steps)
|
| 73 |
lons = np.linspace(BOUNDS["min_lon"], BOUNDS["max_lon"], lon_steps)
|
| 74 |
|
| 75 |
-
|
|
|
|
| 76 |
|
|
|
|
| 77 |
for lat in lats:
|
| 78 |
for lon in lons:
|
| 79 |
try:
|
| 80 |
-
pred_class, confidence, _ =
|
| 81 |
lat=lat, lon=lon, hour=hour, day_of_week=day_of_week,
|
| 82 |
weather=weather, holidays=holidays
|
| 83 |
)
|
| 84 |
-
|
| 85 |
-
intensity = pred_class / 5.0 # Normalize to 0-1
|
| 86 |
-
if intensity > 0.1: # Only show if there's some crowding
|
| 87 |
-
heatmap_data.append([lat, lon, intensity])
|
| 88 |
except Exception:
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
def create_map(selected_lat=None, selected_lon=None, show_heatmap=False,
|
| 95 |
-
|
| 96 |
-
"""Create a Folium map with optional marker and
|
| 97 |
center_lat = selected_lat if selected_lat else DEFAULT_LAT
|
| 98 |
center_lon = selected_lon if selected_lon else DEFAULT_LON
|
| 99 |
|
|
@@ -103,30 +255,37 @@ def create_map(selected_lat=None, selected_lon=None, show_heatmap=False,
|
|
| 103 |
tiles="CartoDB positron"
|
| 104 |
)
|
| 105 |
|
| 106 |
-
# Add
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
folium.Rectangle(
|
| 108 |
bounds=[[BOUNDS["min_lat"], BOUNDS["min_lon"]],
|
| 109 |
[BOUNDS["max_lat"], BOUNDS["max_lon"]]],
|
| 110 |
-
color="#
|
| 111 |
fill=False,
|
| 112 |
-
weight=
|
| 113 |
-
opacity=0.
|
|
|
|
| 114 |
).add_to(m)
|
| 115 |
|
| 116 |
-
# Add heatmap if enabled
|
| 117 |
-
if show_heatmap and heatmap_data and len(heatmap_data) > 0:
|
| 118 |
-
HeatMap(
|
| 119 |
-
data=heatmap_data,
|
| 120 |
-
min_opacity=0.3,
|
| 121 |
-
radius=25,
|
| 122 |
-
blur=15,
|
| 123 |
-
).add_to(m)
|
| 124 |
-
|
| 125 |
# Add marker if location selected
|
| 126 |
if selected_lat and selected_lon:
|
| 127 |
folium.Marker(
|
| 128 |
[selected_lat, selected_lon],
|
| 129 |
tooltip=f"Selected: {selected_lat:.4f}, {selected_lon:.4f}",
|
|
|
|
| 130 |
).add_to(m)
|
| 131 |
|
| 132 |
return m
|
|
@@ -146,7 +305,7 @@ def make_prediction(lat, lon, selected_datetime):
|
|
| 146 |
weather = get_weather_for_prediction(lat, lon, selected_datetime)
|
| 147 |
holidays = get_holiday_features(selected_datetime)
|
| 148 |
|
| 149 |
-
pred_class, confidence, probs =
|
| 150 |
lat=lat, lon=lon,
|
| 151 |
hour=selected_datetime.hour,
|
| 152 |
day_of_week=selected_datetime.weekday(),
|
|
@@ -154,10 +313,16 @@ def make_prediction(lat, lon, selected_datetime):
|
|
| 154 |
holidays=holidays
|
| 155 |
)
|
| 156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
return pred_class, confidence, {
|
| 158 |
"weather": weather,
|
| 159 |
"holidays": holidays,
|
| 160 |
-
"datetime": selected_datetime
|
|
|
|
| 161 |
}
|
| 162 |
except Exception as e:
|
| 163 |
return None, None, str(e)
|
|
@@ -170,13 +335,13 @@ if "selected_lon" not in st.session_state:
|
|
| 170 |
st.session_state.selected_lon = DEFAULT_LON
|
| 171 |
|
| 172 |
# Header
|
| 173 |
-
st.title("
|
| 174 |
-
st.markdown("*
|
| 175 |
|
| 176 |
# Check if model is available
|
| 177 |
model = get_model()
|
| 178 |
if model is None:
|
| 179 |
-
st.error("
|
| 180 |
st.stop()
|
| 181 |
|
| 182 |
# Sidebar controls
|
|
@@ -198,19 +363,20 @@ with st.sidebar:
|
|
| 198 |
|
| 199 |
# View mode
|
| 200 |
st.subheader("View Mode")
|
| 201 |
-
show_heatmap = st.toggle("Show Crowding Forecast", value=
|
| 202 |
help="Display predicted crowding across the region")
|
| 203 |
|
| 204 |
if show_heatmap:
|
| 205 |
-
st.
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
|
|
|
| 214 |
|
| 215 |
st.divider()
|
| 216 |
|
|
@@ -224,10 +390,10 @@ with st.sidebar:
|
|
| 224 |
- 🕐 Time of day
|
| 225 |
- 📅 Day of week
|
| 226 |
- 🌡️ Weather conditions
|
| 227 |
-
-
|
| 228 |
|
| 229 |
**Data sources:**
|
| 230 |
-
- Bus occupancy data from Östgötatrafiken
|
| 231 |
- Weather from Open-Meteo
|
| 232 |
- Holidays from Svenska Dagar API
|
| 233 |
|
|
@@ -238,24 +404,34 @@ with st.sidebar:
|
|
| 238 |
col1, col2 = st.columns([2, 1])
|
| 239 |
|
| 240 |
with col1:
|
| 241 |
-
st.subheader("
|
| 242 |
|
| 243 |
-
# Get
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
# Create and display map
|
| 247 |
m = create_map(
|
| 248 |
selected_lat=st.session_state.selected_lat,
|
| 249 |
selected_lon=st.session_state.selected_lon,
|
| 250 |
show_heatmap=show_heatmap,
|
| 251 |
-
|
| 252 |
)
|
| 253 |
|
|
|
|
| 254 |
map_data = st_folium(
|
| 255 |
m,
|
| 256 |
height=500,
|
| 257 |
use_container_width=True,
|
| 258 |
-
key="
|
| 259 |
)
|
| 260 |
|
| 261 |
# Handle map clicks
|
|
@@ -266,17 +442,18 @@ with col1:
|
|
| 266 |
st.rerun()
|
| 267 |
|
| 268 |
with col2:
|
| 269 |
-
st.subheader("
|
| 270 |
|
| 271 |
# Show selected coordinates
|
| 272 |
st.markdown(f"**Location:** {st.session_state.selected_lat:.4f}, {st.session_state.selected_lon:.4f}")
|
| 273 |
|
| 274 |
# Make prediction
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
|
|
|
| 280 |
|
| 281 |
if pred_class is not None:
|
| 282 |
label_info = OCCUPANCY_LABELS[pred_class]
|
|
@@ -307,17 +484,62 @@ with col2:
|
|
| 307 |
if isinstance(result, dict):
|
| 308 |
weather = result["weather"]
|
| 309 |
holidays = result["holidays"]
|
|
|
|
| 310 |
|
| 311 |
-
day_type = "
|
| 312 |
-
"
|
| 313 |
)
|
| 314 |
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
|
| 322 |
elif isinstance(result, str):
|
| 323 |
st.error(result)
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import streamlit as st
|
| 9 |
+
import json
|
| 10 |
|
| 11 |
# Page config - MUST be first Streamlit command
|
| 12 |
st.set_page_config(
|
|
|
|
| 17 |
|
| 18 |
import os
|
| 19 |
import folium
|
|
|
|
| 20 |
from streamlit_folium import st_folium
|
| 21 |
import numpy as np
|
| 22 |
from datetime import datetime, timedelta
|
| 23 |
|
| 24 |
+
import hopsworks
|
| 25 |
+
|
| 26 |
# Import prediction and data fetching modules
|
| 27 |
from predictor import predict_occupancy, load_model, OCCUPANCY_LABELS
|
| 28 |
from weather import get_weather_for_prediction
|
| 29 |
from holidays import get_holiday_features
|
| 30 |
+
from trip_info import load_static_trip_info, find_nearest_trip, load_static_stops_info, find_closest_stop
|
| 31 |
+
from contours import load_contours_from_file, grid_to_cells_geojson
|
| 32 |
|
| 33 |
# Constants
|
| 34 |
DEFAULT_LAT = 58.4108
|
| 35 |
DEFAULT_LON = 15.6214
|
| 36 |
+
DEFAULT_ZOOM = 9 # Slightly zoomed out to show more of the region
|
| 37 |
|
| 38 |
+
# Bounds derived from actual GTFS stop locations (3119 stops)
|
| 39 |
+
# Run ui/get_boundaries.py to recalculate if needed
|
| 40 |
BOUNDS = {
|
| 41 |
+
"min_lat": 56.6414,
|
| 42 |
+
"max_lat": 58.8654,
|
| 43 |
+
"min_lon": 14.6144,
|
| 44 |
+
"max_lon": 16.9578,
|
| 45 |
}
|
| 46 |
|
| 47 |
+
# Color scheme for occupancy levels (must match contours.py CLASS_COLORS)
|
| 48 |
OCCUPANCY_COLORS = {
|
| 49 |
0: "#22c55e", # Empty - green
|
| 50 |
+
1: "#84cc16", # Many seats - lime
|
| 51 |
2: "#eab308", # Few seats - yellow
|
| 52 |
3: "#f97316", # Standing - orange
|
| 53 |
4: "#ef4444", # Crushed - red
|
|
|
|
| 55 |
6: "#6b7280", # Not accepting - gray
|
| 56 |
}
|
| 57 |
|
| 58 |
+
# Lazy-load static data (deferred to avoid blocking app startup)
|
| 59 |
+
@st.cache_resource
|
| 60 |
+
def get_static_trip_df():
|
| 61 |
+
"""Load static trip info from Hopsworks (cached)."""
|
| 62 |
+
try:
|
| 63 |
+
return load_static_trip_info()
|
| 64 |
+
except Exception as e:
|
| 65 |
+
print(f"Warning: Could not load static trip info: {e}")
|
| 66 |
+
return None
|
| 67 |
+
|
| 68 |
+
@st.cache_resource
|
| 69 |
+
def get_static_stops_df():
|
| 70 |
+
"""Load static stops info from Hopsworks (cached)."""
|
| 71 |
+
try:
|
| 72 |
+
return load_static_stops_info()
|
| 73 |
+
except Exception as e:
|
| 74 |
+
print(f"Warning: Could not load static stops info: {e}")
|
| 75 |
+
return None
|
| 76 |
+
|
| 77 |
|
| 78 |
@st.cache_resource
|
| 79 |
def get_model():
|
|
|
|
| 85 |
return None
|
| 86 |
|
| 87 |
|
| 88 |
+
@st.cache_data
|
| 89 |
+
def cached_predict_occupancy(lat, lon, hour, day_of_week, weather, holidays):
|
| 90 |
+
return predict_occupancy(lat, lon, hour, day_of_week, weather, holidays)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@st.cache_data(ttl=3600) # Cache for 1 hour
|
| 94 |
+
def fetch_heatmaps_from_hopsworks():
|
| 95 |
+
"""
|
| 96 |
+
Fetch all precomputed heatmaps from Hopsworks Feature Store.
|
| 97 |
+
|
| 98 |
+
Tries v3 (high-res 40x50) first, falls back to v2 (low-res 20x25).
|
| 99 |
+
|
| 100 |
+
Returns dict mapping (hour, weekday) -> GeoJSON FeatureCollection
|
| 101 |
+
"""
|
| 102 |
+
try:
|
| 103 |
+
print("Fetching heatmaps from Hopsworks...")
|
| 104 |
+
project = hopsworks.login()
|
| 105 |
+
fs = project.get_feature_store()
|
| 106 |
+
|
| 107 |
+
# Try v3 first (high-res), fall back to v2 (low-res)
|
| 108 |
+
for version in [3, 2]:
|
| 109 |
+
try:
|
| 110 |
+
heatmap_fg = fs.get_feature_group("heatmap_geojson_fg", version=version)
|
| 111 |
+
df = heatmap_fg.read()
|
| 112 |
+
|
| 113 |
+
if df is not None and not df.empty:
|
| 114 |
+
# Convert to dict with tuple keys
|
| 115 |
+
heatmaps = {}
|
| 116 |
+
for _, row in df.iterrows():
|
| 117 |
+
key = (int(row["hour"]), int(row["weekday"]))
|
| 118 |
+
geojson = json.loads(row["geojson"])
|
| 119 |
+
heatmaps[key] = geojson
|
| 120 |
+
|
| 121 |
+
print(f"Loaded {len(heatmaps)} heatmaps from Hopsworks v{version}")
|
| 122 |
+
return heatmaps
|
| 123 |
+
else:
|
| 124 |
+
print(f"No data in Hopsworks v{version}, trying fallback...")
|
| 125 |
+
except Exception as e:
|
| 126 |
+
print(f"Could not fetch v{version}: {e}")
|
| 127 |
+
continue
|
| 128 |
+
|
| 129 |
+
print("No heatmap data found in any Hopsworks version")
|
| 130 |
+
return {}
|
| 131 |
+
|
| 132 |
+
except Exception as e:
|
| 133 |
+
print(f"Error fetching heatmaps from Hopsworks: {e}")
|
| 134 |
+
return {}
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def load_precomputed_contours():
|
| 138 |
+
"""Load precomputed contour GeoJSON from file (not cached to pick up new files)."""
|
| 139 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 140 |
+
contours_path = os.path.join(script_dir, "precomputed_contours.json")
|
| 141 |
+
|
| 142 |
+
if os.path.exists(contours_path):
|
| 143 |
+
try:
|
| 144 |
+
contours = load_contours_from_file(contours_path)
|
| 145 |
+
print(f"Loaded {len(contours)} precomputed time slots from {contours_path}")
|
| 146 |
+
return contours
|
| 147 |
+
except Exception as e:
|
| 148 |
+
print(f"Error loading contours: {e}")
|
| 149 |
+
return {}
|
| 150 |
+
print(f"Contours file not found: {contours_path}")
|
| 151 |
+
return {}
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def generate_contours_on_demand(hour, day_of_week, weather, holidays):
|
| 155 |
+
"""
|
| 156 |
+
Generate grid cell GeoJSON on-demand if precomputed data is not available.
|
| 157 |
+
This is slower but provides a fallback.
|
| 158 |
+
"""
|
| 159 |
model = get_model()
|
| 160 |
if model is None:
|
| 161 |
+
return None
|
| 162 |
|
| 163 |
+
# Grid for on-demand generation (smaller for speed)
|
| 164 |
lat_steps = 15
|
| 165 |
lon_steps = 20
|
| 166 |
lats = np.linspace(BOUNDS["min_lat"], BOUNDS["max_lat"], lat_steps)
|
| 167 |
lons = np.linspace(BOUNDS["min_lon"], BOUNDS["max_lon"], lon_steps)
|
| 168 |
|
| 169 |
+
lat_step = (BOUNDS["max_lat"] - BOUNDS["min_lat"]) / (lat_steps - 1)
|
| 170 |
+
lon_step = (BOUNDS["max_lon"] - BOUNDS["min_lon"]) / (lon_steps - 1)
|
| 171 |
|
| 172 |
+
prediction_data = []
|
| 173 |
for lat in lats:
|
| 174 |
for lon in lons:
|
| 175 |
try:
|
| 176 |
+
pred_class, confidence, _ = cached_predict_occupancy(
|
| 177 |
lat=lat, lon=lon, hour=hour, day_of_week=day_of_week,
|
| 178 |
weather=weather, holidays=holidays
|
| 179 |
)
|
| 180 |
+
prediction_data.append([lat, lon, pred_class])
|
|
|
|
|
|
|
|
|
|
| 181 |
except Exception:
|
| 182 |
+
prediction_data.append([lat, lon, 0])
|
| 183 |
+
|
| 184 |
+
# Convert to GeoJSON grid cells
|
| 185 |
+
return grid_to_cells_geojson(prediction_data, lat_step, lon_step)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def get_test_contour_geojson():
|
| 189 |
+
"""
|
| 190 |
+
Return a simple hardcoded test GeoJSON to verify rendering works.
|
| 191 |
+
Creates a small grid of cells with different colors.
|
| 192 |
+
"""
|
| 193 |
+
# Create a 3x3 grid of test cells around Linköping
|
| 194 |
+
center_lat = 58.41
|
| 195 |
+
center_lon = 15.62
|
| 196 |
+
cell_size = 0.15
|
| 197 |
+
|
| 198 |
+
# Test predictions: mix of classes
|
| 199 |
+
test_data = [
|
| 200 |
+
(center_lat - cell_size, center_lon - cell_size, 0), # green
|
| 201 |
+
(center_lat - cell_size, center_lon, 0), # green
|
| 202 |
+
(center_lat - cell_size, center_lon + cell_size, 1), # green
|
| 203 |
+
(center_lat, center_lon - cell_size, 0), # green
|
| 204 |
+
(center_lat, center_lon, 2), # yellow
|
| 205 |
+
(center_lat, center_lon + cell_size, 2), # yellow
|
| 206 |
+
(center_lat + cell_size, center_lon - cell_size, 0), # green
|
| 207 |
+
(center_lat + cell_size, center_lon, 3), # orange
|
| 208 |
+
(center_lat + cell_size, center_lon + cell_size, 0), # green
|
| 209 |
+
]
|
| 210 |
+
|
| 211 |
+
return grid_to_cells_geojson(test_data, cell_size, cell_size)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def get_contour_geojson(hour, day_of_week, weather=None, holidays=None):
|
| 215 |
+
"""
|
| 216 |
+
Get contour GeoJSON for the given hour and day of week.
|
| 217 |
+
|
| 218 |
+
Tries sources in order:
|
| 219 |
+
1. Hopsworks Feature Store (primary)
|
| 220 |
+
2. Local JSON file (fallback)
|
| 221 |
+
3. Test contours (last resort)
|
| 222 |
+
"""
|
| 223 |
+
key = (hour, day_of_week)
|
| 224 |
+
|
| 225 |
+
# Try Hopsworks first
|
| 226 |
+
hopsworks_heatmaps = fetch_heatmaps_from_hopsworks()
|
| 227 |
+
if key in hopsworks_heatmaps:
|
| 228 |
+
geojson = hopsworks_heatmaps[key]
|
| 229 |
+
n_features = len(geojson.get("features", []))
|
| 230 |
+
print(f"Found heatmap in Hopsworks for {key}: {n_features} features")
|
| 231 |
+
return geojson
|
| 232 |
+
|
| 233 |
+
# Fall back to local JSON file
|
| 234 |
+
precomputed = load_precomputed_contours()
|
| 235 |
+
if key in precomputed:
|
| 236 |
+
geojson = precomputed[key]
|
| 237 |
+
n_features = len(geojson.get("features", []))
|
| 238 |
+
print(f"Found heatmap in local file for {key}: {n_features} features")
|
| 239 |
+
return geojson
|
| 240 |
+
|
| 241 |
+
# Last resort: test contours
|
| 242 |
+
print(f"No heatmap for {key}, using test contours")
|
| 243 |
+
return get_test_contour_geojson()
|
| 244 |
|
| 245 |
|
| 246 |
def create_map(selected_lat=None, selected_lon=None, show_heatmap=False,
|
| 247 |
+
contour_geojson=None):
|
| 248 |
+
"""Create a Folium map with optional marker and contour overlay."""
|
| 249 |
center_lat = selected_lat if selected_lat else DEFAULT_LAT
|
| 250 |
center_lon = selected_lon if selected_lon else DEFAULT_LON
|
| 251 |
|
|
|
|
| 255 |
tiles="CartoDB positron"
|
| 256 |
)
|
| 257 |
|
| 258 |
+
# Add contour overlay if enabled
|
| 259 |
+
if show_heatmap and contour_geojson and contour_geojson.get("features"):
|
| 260 |
+
# Add each contour level as a separate GeoJSON layer
|
| 261 |
+
folium.GeoJson(
|
| 262 |
+
contour_geojson,
|
| 263 |
+
style_function=lambda feature: {
|
| 264 |
+
'fillColor': feature['properties']['color'],
|
| 265 |
+
'fillOpacity': feature['properties'].get('fillOpacity', 0.35),
|
| 266 |
+
'color': 'none', # No border
|
| 267 |
+
'weight': 0
|
| 268 |
+
},
|
| 269 |
+
name="Crowding Forecast"
|
| 270 |
+
).add_to(m)
|
| 271 |
+
|
| 272 |
+
# Add coverage area rectangle (subtle border)
|
| 273 |
folium.Rectangle(
|
| 274 |
bounds=[[BOUNDS["min_lat"], BOUNDS["min_lon"]],
|
| 275 |
[BOUNDS["max_lat"], BOUNDS["max_lon"]]],
|
| 276 |
+
color="#6b7280",
|
| 277 |
fill=False,
|
| 278 |
+
weight=1,
|
| 279 |
+
opacity=0.3,
|
| 280 |
+
dash_array="5, 5",
|
| 281 |
).add_to(m)
|
| 282 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
# Add marker if location selected
|
| 284 |
if selected_lat and selected_lon:
|
| 285 |
folium.Marker(
|
| 286 |
[selected_lat, selected_lon],
|
| 287 |
tooltip=f"Selected: {selected_lat:.4f}, {selected_lon:.4f}",
|
| 288 |
+
icon=folium.Icon(color="blue", icon="info-sign")
|
| 289 |
).add_to(m)
|
| 290 |
|
| 291 |
return m
|
|
|
|
| 305 |
weather = get_weather_for_prediction(lat, lon, selected_datetime)
|
| 306 |
holidays = get_holiday_features(selected_datetime)
|
| 307 |
|
| 308 |
+
pred_class, confidence, probs = cached_predict_occupancy(
|
| 309 |
lat=lat, lon=lon,
|
| 310 |
hour=selected_datetime.hour,
|
| 311 |
day_of_week=selected_datetime.weekday(),
|
|
|
|
| 313 |
holidays=holidays
|
| 314 |
)
|
| 315 |
|
| 316 |
+
trip_info = None
|
| 317 |
+
static_trip_df = get_static_trip_df()
|
| 318 |
+
if static_trip_df is not None:
|
| 319 |
+
trip_info = find_nearest_trip(lat, lon, selected_datetime, static_trip_df)
|
| 320 |
+
|
| 321 |
return pred_class, confidence, {
|
| 322 |
"weather": weather,
|
| 323 |
"holidays": holidays,
|
| 324 |
+
"datetime": selected_datetime,
|
| 325 |
+
"trip_info": trip_info
|
| 326 |
}
|
| 327 |
except Exception as e:
|
| 328 |
return None, None, str(e)
|
|
|
|
| 335 |
st.session_state.selected_lon = DEFAULT_LON
|
| 336 |
|
| 337 |
# Header
|
| 338 |
+
st.title("HappySardines")
|
| 339 |
+
st.markdown("*Predicted bus crowding in Östergötland*")
|
| 340 |
|
| 341 |
# Check if model is available
|
| 342 |
model = get_model()
|
| 343 |
if model is None:
|
| 344 |
+
st.error("Could not load prediction model. Please check the configuration.")
|
| 345 |
st.stop()
|
| 346 |
|
| 347 |
# Sidebar controls
|
|
|
|
| 363 |
|
| 364 |
# View mode
|
| 365 |
st.subheader("View Mode")
|
| 366 |
+
show_heatmap = st.toggle("Show Crowding Forecast", value=True,
|
| 367 |
help="Display predicted crowding across the region")
|
| 368 |
|
| 369 |
if show_heatmap:
|
| 370 |
+
st.markdown("""
|
| 371 |
+
**Legend:**
|
| 372 |
+
<div style="display: flex; flex-direction: column; gap: 4px; font-size: 14px;">
|
| 373 |
+
<div><span style="display: inline-block; width: 16px; height: 16px; background: #22c55e; border-radius: 3px; vertical-align: middle;"></span> Empty</div>
|
| 374 |
+
<div><span style="display: inline-block; width: 16px; height: 16px; background: #84cc16; border-radius: 3px; vertical-align: middle;"></span> Many seats</div>
|
| 375 |
+
<div><span style="display: inline-block; width: 16px; height: 16px; background: #eab308; border-radius: 3px; vertical-align: middle;"></span> Few seats</div>
|
| 376 |
+
<div><span style="display: inline-block; width: 16px; height: 16px; background: #f97316; border-radius: 3px; vertical-align: middle;"></span> Standing room</div>
|
| 377 |
+
<div><span style="display: inline-block; width: 16px; height: 16px; background: #ef4444; border-radius: 3px; vertical-align: middle;"></span> Crowded</div>
|
| 378 |
+
</div>
|
| 379 |
+
""", unsafe_allow_html=True)
|
| 380 |
|
| 381 |
st.divider()
|
| 382 |
|
|
|
|
| 390 |
- 🕐 Time of day
|
| 391 |
- 📅 Day of week
|
| 392 |
- 🌡️ Weather conditions
|
| 393 |
+
- 🇸🇪 Holidays
|
| 394 |
|
| 395 |
**Data sources:**
|
| 396 |
+
- Bus occupancy data from Östgötatrafiken (KODA API)
|
| 397 |
- Weather from Open-Meteo
|
| 398 |
- Holidays from Svenska Dagar API
|
| 399 |
|
|
|
|
| 404 |
col1, col2 = st.columns([2, 1])
|
| 405 |
|
| 406 |
with col1:
|
| 407 |
+
st.subheader("Click on the map to select a location")
|
| 408 |
|
| 409 |
+
# Get weather/holidays for on-demand generation fallback
|
| 410 |
+
weather = get_weather_for_prediction(DEFAULT_LAT, DEFAULT_LON, selected_datetime)
|
| 411 |
+
holidays = get_holiday_features(selected_datetime)
|
| 412 |
+
|
| 413 |
+
# Get contour GeoJSON for current hour/day
|
| 414 |
+
contour_geojson = None
|
| 415 |
+
if show_heatmap:
|
| 416 |
+
contour_geojson = get_contour_geojson(
|
| 417 |
+
hour, selected_date.weekday(),
|
| 418 |
+
weather=weather, holidays=holidays
|
| 419 |
+
)
|
| 420 |
|
| 421 |
# Create and display map
|
| 422 |
m = create_map(
|
| 423 |
selected_lat=st.session_state.selected_lat,
|
| 424 |
selected_lon=st.session_state.selected_lon,
|
| 425 |
show_heatmap=show_heatmap,
|
| 426 |
+
contour_geojson=contour_geojson
|
| 427 |
)
|
| 428 |
|
| 429 |
+
# Render the map - key includes hour and day to force re-render on time change
|
| 430 |
map_data = st_folium(
|
| 431 |
m,
|
| 432 |
height=500,
|
| 433 |
use_container_width=True,
|
| 434 |
+
key=f"map_{hour}_{selected_date.weekday()}"
|
| 435 |
)
|
| 436 |
|
| 437 |
# Handle map clicks
|
|
|
|
| 442 |
st.rerun()
|
| 443 |
|
| 444 |
with col2:
|
| 445 |
+
st.subheader("Prediction")
|
| 446 |
|
| 447 |
# Show selected coordinates
|
| 448 |
st.markdown(f"**Location:** {st.session_state.selected_lat:.4f}, {st.session_state.selected_lon:.4f}")
|
| 449 |
|
| 450 |
# Make prediction
|
| 451 |
+
with st.spinner("Fetching prediction..."):
|
| 452 |
+
pred_class, confidence, result = make_prediction(
|
| 453 |
+
st.session_state.selected_lat,
|
| 454 |
+
st.session_state.selected_lon,
|
| 455 |
+
selected_datetime
|
| 456 |
+
)
|
| 457 |
|
| 458 |
if pred_class is not None:
|
| 459 |
label_info = OCCUPANCY_LABELS[pred_class]
|
|
|
|
| 484 |
if isinstance(result, dict):
|
| 485 |
weather = result["weather"]
|
| 486 |
holidays = result["holidays"]
|
| 487 |
+
trip_info = result.get("trip_info")
|
| 488 |
|
| 489 |
+
day_type = "Holiday" if holidays.get("is_red_day") else (
|
| 490 |
+
"Work-free day" if holidays.get("is_work_free") else "Regular day"
|
| 491 |
)
|
| 492 |
|
| 493 |
+
if trip_info:
|
| 494 |
+
info_lines = []
|
| 495 |
+
|
| 496 |
+
# Route number or name
|
| 497 |
+
route_number = trip_info.get("route_short_name")
|
| 498 |
+
route_long_name = trip_info.get("route_long_name")
|
| 499 |
+
if route_number and route_long_name:
|
| 500 |
+
info_lines.append(f"{route_number} - {route_long_name}")
|
| 501 |
+
elif route_number:
|
| 502 |
+
info_lines.append(f"Route: {route_number}")
|
| 503 |
+
elif route_long_name:
|
| 504 |
+
info_lines.append(f"Route: {route_long_name}")
|
| 505 |
+
|
| 506 |
+
# Bus type / description
|
| 507 |
+
route_desc = trip_info.get("route_desc")
|
| 508 |
+
if route_desc:
|
| 509 |
+
info_lines.append(f"Type: {route_desc}")
|
| 510 |
+
|
| 511 |
+
# Trip ID
|
| 512 |
+
trip_id = trip_info.get("trip_id")
|
| 513 |
+
if trip_id is not None:
|
| 514 |
+
# Closest stop
|
| 515 |
+
static_stops_df = get_static_stops_df()
|
| 516 |
+
closest_stop = find_closest_stop(
|
| 517 |
+
st.session_state.selected_lat,
|
| 518 |
+
st.session_state.selected_lon,
|
| 519 |
+
trip_id,
|
| 520 |
+
static_stops_df
|
| 521 |
+
)
|
| 522 |
+
if closest_stop:
|
| 523 |
+
info_lines.append(f"Nearest stop: {closest_stop}")
|
| 524 |
+
|
| 525 |
+
if info_lines:
|
| 526 |
+
st.markdown("**Bus Info:**\n- " + "\n- ".join(info_lines))
|
| 527 |
+
|
| 528 |
+
# Weather conditions
|
| 529 |
+
conditions = []
|
| 530 |
+
temp = weather.get('temperature_2m')
|
| 531 |
+
if temp is not None:
|
| 532 |
+
conditions.append(f"{temp:.0f}°C")
|
| 533 |
+
|
| 534 |
+
if weather.get('snowfall', 0) > 0:
|
| 535 |
+
conditions.append("Snow")
|
| 536 |
+
if weather.get('rain', 0) > 0:
|
| 537 |
+
conditions.append("Rain")
|
| 538 |
+
|
| 539 |
+
conditions.append(day_type)
|
| 540 |
+
conditions.append(selected_datetime.strftime('%A'))
|
| 541 |
+
|
| 542 |
+
st.markdown("**Conditions:** " + " | ".join(conditions))
|
| 543 |
|
| 544 |
elif isinstance(result, str):
|
| 545 |
st.error(result)
|
contours.py
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Contour generation module for heatmap visualization.
|
| 3 |
+
|
| 4 |
+
Converts prediction grid data into GeoJSON polygons that can be rendered
|
| 5 |
+
as vector overlays on Folium maps. This provides zoom-independent visualization
|
| 6 |
+
similar to weather radar overlays.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
from scipy.interpolate import griddata
|
| 11 |
+
from scipy.ndimage import binary_dilation
|
| 12 |
+
import matplotlib
|
| 13 |
+
matplotlib.use('Agg') # Non-interactive backend for server use
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
from matplotlib.path import Path
|
| 16 |
+
from shapely.geometry import Polygon, MultiPolygon, mapping
|
| 17 |
+
from shapely.ops import unary_union
|
| 18 |
+
from shapely.validation import make_valid
|
| 19 |
+
import json
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Color scheme: green -> lime -> yellow -> orange -> red
|
| 23 |
+
# Each class gets a distinct color for clear differentiation
|
| 24 |
+
CLASS_COLORS = {
|
| 25 |
+
0: "#22c55e", # Empty - green
|
| 26 |
+
1: "#84cc16", # Many seats - lime (green-yellow mix)
|
| 27 |
+
2: "#eab308", # Few seats - yellow
|
| 28 |
+
3: "#f97316", # Standing room - orange
|
| 29 |
+
4: "#ef4444", # Crushed standing - red
|
| 30 |
+
5: "#ef4444", # Full - red
|
| 31 |
+
6: "#6b7280", # Not accepting - gray
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
# Legacy contour colors (for backwards compatibility)
|
| 35 |
+
CONTOUR_COLORS = [
|
| 36 |
+
"#22c55e", # 0.0-0.2: Green (class 0 - empty)
|
| 37 |
+
"#eab308", # 0.2-0.4: Yellow (class 1 - many seats)
|
| 38 |
+
"#f97316", # 0.4-0.6: Orange (class 2 - few seats)
|
| 39 |
+
"#ef4444", # 0.6-0.8: Red (class 3 - standing)
|
| 40 |
+
"#7f1d1d", # 0.8-1.0: Dark red (class 4+ - crowded)
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
# Contour levels (intensity thresholds)
|
| 44 |
+
CONTOUR_LEVELS = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _extract_polygons_from_contour(contour_set, level_idx):
|
| 48 |
+
"""
|
| 49 |
+
Extract polygons from a contourf result for a specific level.
|
| 50 |
+
Compatible with matplotlib 3.8+ which removed .collections attribute.
|
| 51 |
+
"""
|
| 52 |
+
polygons = []
|
| 53 |
+
|
| 54 |
+
# Try new API first (matplotlib 3.8+)
|
| 55 |
+
if hasattr(contour_set, 'get_paths'):
|
| 56 |
+
# New API: iterate through all paths
|
| 57 |
+
all_paths = contour_set.get_paths()
|
| 58 |
+
# In new API, paths are organized differently
|
| 59 |
+
# We need to use allsegs instead
|
| 60 |
+
pass
|
| 61 |
+
|
| 62 |
+
# Use allsegs which works in both old and new matplotlib
|
| 63 |
+
if hasattr(contour_set, 'allsegs'):
|
| 64 |
+
if level_idx < len(contour_set.allsegs):
|
| 65 |
+
segments = contour_set.allsegs[level_idx]
|
| 66 |
+
for seg in segments:
|
| 67 |
+
if len(seg) >= 4:
|
| 68 |
+
try:
|
| 69 |
+
poly = Polygon(seg)
|
| 70 |
+
if not poly.is_valid:
|
| 71 |
+
poly = make_valid(poly)
|
| 72 |
+
if poly.is_valid and not poly.is_empty and poly.area > 0:
|
| 73 |
+
if isinstance(poly, MultiPolygon):
|
| 74 |
+
polygons.extend(poly.geoms)
|
| 75 |
+
else:
|
| 76 |
+
polygons.append(poly)
|
| 77 |
+
except Exception:
|
| 78 |
+
continue
|
| 79 |
+
return polygons
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def grid_to_contour_geojson(
    prediction_data: list,
    bounds: dict,
    interpolation_resolution: int = 100,
    fill_opacity: float = 0.35,
) -> dict:
    """
    Convert prediction grid data to GeoJSON FeatureCollection with filled contour polygons.

    Pipeline: scattered predictions -> scipy griddata interpolation onto a
    fine regular grid -> headless matplotlib contourf -> shapely polygons ->
    one GeoJSON feature per contour band.

    Args:
        prediction_data: List of [lat, lon, intensity] where intensity is 0-1
        bounds: Dict with min_lat, max_lat, min_lon, max_lon
        interpolation_resolution: Number of points per axis for interpolation grid
        fill_opacity: Opacity for the fill color (0-1)

    Returns:
        GeoJSON FeatureCollection with colored polygon features
    """
    # Fewer than 4 scattered points cannot support a 2-D interpolation.
    if not prediction_data or len(prediction_data) < 4:
        return _empty_feature_collection()

    # Extract coordinates and values
    points = np.array([[p[0], p[1]] for p in prediction_data])  # lat, lon
    values = np.array([p[2] for p in prediction_data])  # intensity

    # Create fine interpolation grid
    lat_fine = np.linspace(bounds["min_lat"], bounds["max_lat"], interpolation_resolution)
    lon_fine = np.linspace(bounds["min_lon"], bounds["max_lon"], interpolation_resolution)
    lon_grid, lat_grid = np.meshgrid(lon_fine, lat_fine)

    # Interpolate to fine grid (using lat,lon order for points).
    # Cubic interpolation raises (e.g. QhullError) on degenerate point sets.
    try:
        values_fine = griddata(
            points,  # (lat, lon) pairs
            values,
            (lat_grid, lon_grid),  # grid in (lat, lon) format
            method='cubic',
            fill_value=0.0
        )
    except Exception:
        # Fall back to linear if cubic fails
        values_fine = griddata(
            points,
            values,
            (lat_grid, lon_grid),
            method='linear',
            fill_value=0.0
        )

    # Clip values to valid range (cubic interpolation can overshoot 0-1)
    values_fine = np.clip(values_fine, 0.0, 1.0)

    # Generate contours using matplotlib (but don't display)
    fig, ax = plt.subplots(figsize=(10, 10))

    # contourf returns a QuadContourSet
    contour_set = ax.contourf(
        lon_grid, lat_grid, values_fine,
        levels=CONTOUR_LEVELS,
        extend='neither'
    )

    plt.close(fig)  # Don't display, just extract the polygons

    # Convert matplotlib contours to GeoJSON features
    features = []

    for level_idx in range(len(CONTOUR_LEVELS) - 1):
        # Defensive: never index past the color table.
        if level_idx >= len(CONTOUR_COLORS):
            break

        color = CONTOUR_COLORS[level_idx]
        level_min = CONTOUR_LEVELS[level_idx]
        level_max = CONTOUR_LEVELS[level_idx + 1]

        # Extract polygons for this level
        polygons = _extract_polygons_from_contour(contour_set, level_idx)

        if not polygons:
            continue

        # Merge overlapping polygons at this level into one (multi)geometry
        try:
            merged = unary_union(polygons)
            if merged.is_empty:
                continue
        except Exception:
            continue

        # Create GeoJSON feature; mapping() serializes the shapely geometry.
        # Geometry coordinates end up as (lon, lat), GeoJSON's expected order,
        # because contourf was fed lon_grid as x and lat_grid as y.
        feature = {
            "type": "Feature",
            "properties": {
                "color": color,
                "fillOpacity": fill_opacity,
                "level_min": level_min,
                "level_max": level_max,
                "level_idx": level_idx,
            },
            "geometry": mapping(merged)
        }
        features.append(feature)

    return {
        "type": "FeatureCollection",
        "features": features
    }
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _empty_feature_collection() -> dict:
|
| 192 |
+
"""Return an empty GeoJSON FeatureCollection."""
|
| 193 |
+
return {
|
| 194 |
+
"type": "FeatureCollection",
|
| 195 |
+
"features": []
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def grid_to_cells_geojson(
    prediction_data: list,
    lat_step: float,
    lon_step: float,
    fill_opacity: float = 0.35,
) -> dict:
    """
    Convert prediction grid to GeoJSON rectangles - one cell per prediction point.

    Simpler and more accurate than contours: no interpolation artifacts, each
    cell shows the exact prediction for its area, and no fake background fill.

    Args:
        prediction_data: List of [lat, lon, pred_class] where pred_class is 0-6
        lat_step: Height of each cell in degrees
        lon_step: Width of each cell in degrees
        fill_opacity: Opacity for the fill color (0-1)

    Returns:
        GeoJSON FeatureCollection with colored rectangle features
    """
    if not prediction_data:
        return _empty_feature_collection()

    dy = lat_step / 2
    dx = lon_step / 2
    features = []

    for lat, lon, pred_class in prediction_data:
        cls = int(pred_class)

        # Counter-clockwise ring, closed by repeating the SW corner,
        # centered on the prediction point. GeoJSON order is (lon, lat).
        ring = [
            [lon - dx, lat - dy],  # SW
            [lon + dx, lat - dy],  # SE
            [lon + dx, lat + dy],  # NE
            [lon - dx, lat + dy],  # NW
            [lon - dx, lat - dy],  # SW again (close polygon)
        ]

        features.append({
            "type": "Feature",
            "properties": {
                "color": CLASS_COLORS.get(cls, CLASS_COLORS[0]),
                "fillOpacity": fill_opacity,
                "pred_class": cls,
            },
            "geometry": {
                "type": "Polygon",
                "coordinates": [ring],
            },
        })

    return {
        "type": "FeatureCollection",
        "features": features,
    }
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def precompute_contours_for_all_times(
    prediction_func,
    bounds: dict,
    hours: list = None,
    weekdays: list = None,
    lat_steps: int = 20,
    lon_steps: int = 25,
) -> dict:
    """
    Precompute contour GeoJSON for every hour/weekday combination.

    Args:
        prediction_func: Function(lat, lon, hour, weekday) -> intensity (0-1)
        bounds: Geographic bounds dict
        hours: List of hours to compute (default: 5-23)
        weekdays: List of weekdays to compute (default: 0-6, Mon-Sun)
        lat_steps: Number of latitude grid points
        lon_steps: Number of longitude grid points

    Returns:
        Dict mapping (hour, weekday) -> GeoJSON FeatureCollection
    """
    hours = list(range(5, 24)) if hours is None else hours        # 5:00 to 23:00
    weekdays = list(range(7)) if weekdays is None else weekdays   # Monday-Sunday

    # Fixed sampling grid reused for every time slot.
    lat_axis = np.linspace(bounds["min_lat"], bounds["max_lat"], lat_steps)
    lon_axis = np.linspace(bounds["min_lon"], bounds["max_lon"], lon_steps)

    contours_by_time = {}
    total = len(hours) * len(weekdays)
    count = 0

    for hour in hours:
        for weekday in weekdays:
            count += 1
            print(f"Generating contours {count}/{total}: hour={hour}, weekday={weekday}")

            # Sample the model over the grid; a failed prediction counts as
            # zero intensity rather than aborting the whole slot.
            samples = []
            for grid_lat in lat_axis:
                for grid_lon in lon_axis:
                    try:
                        intensity = prediction_func(grid_lat, grid_lon, hour, weekday)
                    except Exception:
                        intensity = 0.0
                    samples.append([grid_lat, grid_lon, intensity])

            contours_by_time[(hour, weekday)] = grid_to_contour_geojson(samples, bounds)

    return contours_by_time
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def save_contours_to_file(contours: dict, filepath: str):
    """
    Save precomputed contours to a JSON file.

    JSON object keys must be strings, so each (hour, weekday) tuple key is
    serialized as the string "hour,weekday".
    """
    serializable = {}
    for (hour, weekday), geojson in contours.items():
        serializable[f"{hour},{weekday}"] = geojson

    with open(filepath, 'w') as f:
        json.dump(serializable, f)

    print(f"Saved contours to {filepath}")
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def load_contours_from_file(filepath: str) -> dict:
    """
    Load precomputed contours from a JSON file.

    Inverse of save_contours_to_file: "hour,weekday" string keys are parsed
    back into (hour, weekday) integer tuples.

    Returns:
        Dict mapping (hour, weekday) -> GeoJSON FeatureCollection.
    """
    with open(filepath, 'r') as f:
        raw = json.load(f)

    contours = {}
    for key, geojson in raw.items():
        hour_str, weekday_str = key.split(',')
        contours[(int(hour_str), int(weekday_str))] = geojson
    return contours
|
predictor.py
CHANGED
|
@@ -7,6 +7,9 @@ Loads the XGBoost model from Hopsworks Model Registry and makes predictions.
|
|
| 7 |
import os
|
| 8 |
import numpy as np
|
| 9 |
import pandas as pd
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# Global model cache
|
| 12 |
_model = None
|
|
@@ -58,21 +61,29 @@ OCCUPANCY_LABELS = {
|
|
| 58 |
}
|
| 59 |
}
|
| 60 |
|
| 61 |
-
# Feature order expected by the model
|
| 62 |
-
# Must match training pipeline exactly
|
| 63 |
FEATURE_ORDER = [
|
| 64 |
-
"
|
|
|
|
| 65 |
"max_speed",
|
| 66 |
-
"speed_std",
|
| 67 |
"n_positions",
|
|
|
|
|
|
|
| 68 |
"lat_mean",
|
|
|
|
|
|
|
| 69 |
"lon_mean",
|
|
|
|
|
|
|
| 70 |
"hour",
|
| 71 |
"day_of_week",
|
| 72 |
"temperature_2m",
|
| 73 |
"precipitation",
|
| 74 |
"cloud_cover",
|
| 75 |
"wind_speed_10m",
|
|
|
|
|
|
|
| 76 |
"is_work_free",
|
| 77 |
"is_red_day",
|
| 78 |
"is_day_before_holiday",
|
|
@@ -81,10 +92,10 @@ FEATURE_ORDER = [
|
|
| 81 |
# Default values for vehicle features (we don't have real-time vehicle data)
|
| 82 |
# These are approximate averages from the training data
|
| 83 |
DEFAULT_VEHICLE_FEATURES = {
|
| 84 |
-
"avg_speed": 20.0, # typical urban bus speed (km/h)
|
| 85 |
"max_speed": 45.0, # typical max speed
|
| 86 |
-
"speed_std": 12.0, # typical speed variation
|
| 87 |
"n_positions": 30, # typical GPS points per trip window
|
|
|
|
|
|
|
| 88 |
}
|
| 89 |
|
| 90 |
|
|
@@ -101,6 +112,7 @@ def load_model():
|
|
| 101 |
|
| 102 |
# Check for API key before attempting connection
|
| 103 |
api_key = os.environ.get("HOPSWORKS_API_KEY")
|
|
|
|
| 104 |
if not api_key:
|
| 105 |
raise ValueError("HOPSWORKS_API_KEY environment variable not set. Please add it in Space settings.")
|
| 106 |
|
|
@@ -109,11 +121,12 @@ def load_model():
|
|
| 109 |
from xgboost import XGBClassifier
|
| 110 |
|
| 111 |
print("Connecting to Hopsworks...")
|
| 112 |
-
project = hopsworks.login(api_key_value=api_key)
|
| 113 |
mr = project.get_model_registry()
|
| 114 |
|
| 115 |
print("Fetching model from registry...")
|
| 116 |
-
|
|
|
|
| 117 |
|
| 118 |
print(f"Downloading model version {model_entry.version}...")
|
| 119 |
model_dir = model_entry.download()
|
|
@@ -153,15 +166,23 @@ def predict_occupancy(lat, lon, hour, day_of_week, weather, holidays):
|
|
| 153 |
# Assemble feature vector
|
| 154 |
features = {
|
| 155 |
# Vehicle features - use defaults
|
| 156 |
-
"
|
|
|
|
| 157 |
"max_speed": DEFAULT_VEHICLE_FEATURES["max_speed"],
|
| 158 |
-
"speed_std": DEFAULT_VEHICLE_FEATURES["speed_std"],
|
| 159 |
"n_positions": DEFAULT_VEHICLE_FEATURES["n_positions"],
|
| 160 |
|
| 161 |
-
# Location
|
|
|
|
|
|
|
| 162 |
"lat_mean": lat,
|
|
|
|
|
|
|
| 163 |
"lon_mean": lon,
|
| 164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
# Time
|
| 166 |
"hour": hour,
|
| 167 |
"day_of_week": day_of_week,
|
|
@@ -171,6 +192,8 @@ def predict_occupancy(lat, lon, hour, day_of_week, weather, holidays):
|
|
| 171 |
"precipitation": weather.get("precipitation", 0.0),
|
| 172 |
"cloud_cover": weather.get("cloud_cover", 50.0),
|
| 173 |
"wind_speed_10m": weather.get("wind_speed_10m", 5.0),
|
|
|
|
|
|
|
| 174 |
|
| 175 |
# Holidays (convert bool to int)
|
| 176 |
"is_work_free": int(holidays.get("is_work_free", False)),
|
|
@@ -189,3 +212,65 @@ def predict_occupancy(lat, lon, hour, day_of_week, weather, holidays):
|
|
| 189 |
confidence = float(probabilities[predicted_class])
|
| 190 |
|
| 191 |
return predicted_class, confidence, probabilities.tolist()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
import os
|
| 8 |
import numpy as np
|
| 9 |
import pandas as pd
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
|
| 12 |
+
load_dotenv()
|
| 13 |
|
| 14 |
# Global model cache
|
| 15 |
_model = None
|
|
|
|
| 61 |
}
|
| 62 |
}
|
| 63 |
|
| 64 |
+
# Feature order expected by the model (occupancy_xgboost_model_new v4)
|
| 65 |
+
# Must match training pipeline exactly - includes lat/lon bounds and bearing
|
| 66 |
FEATURE_ORDER = [
|
| 67 |
+
"trip_id",
|
| 68 |
+
"vehicle_id",
|
| 69 |
"max_speed",
|
|
|
|
| 70 |
"n_positions",
|
| 71 |
+
"lat_min",
|
| 72 |
+
"lat_max",
|
| 73 |
"lat_mean",
|
| 74 |
+
"lon_min",
|
| 75 |
+
"lon_max",
|
| 76 |
"lon_mean",
|
| 77 |
+
"bearing_min",
|
| 78 |
+
"bearing_max",
|
| 79 |
"hour",
|
| 80 |
"day_of_week",
|
| 81 |
"temperature_2m",
|
| 82 |
"precipitation",
|
| 83 |
"cloud_cover",
|
| 84 |
"wind_speed_10m",
|
| 85 |
+
"rain",
|
| 86 |
+
"snowfall",
|
| 87 |
"is_work_free",
|
| 88 |
"is_red_day",
|
| 89 |
"is_day_before_holiday",
|
|
|
|
| 92 |
# Default values for vehicle features (we don't have real-time vehicle data)
|
| 93 |
# These are approximate averages from the training data
|
| 94 |
DEFAULT_VEHICLE_FEATURES = {
|
|
|
|
| 95 |
"max_speed": 45.0, # typical max speed
|
|
|
|
| 96 |
"n_positions": 30, # typical GPS points per trip window
|
| 97 |
+
"bearing_min": 0.0, # neutral bearing
|
| 98 |
+
"bearing_max": 360.0, # full range (stationary/unknown direction)
|
| 99 |
}
|
| 100 |
|
| 101 |
|
|
|
|
| 112 |
|
| 113 |
# Check for API key before attempting connection
|
| 114 |
api_key = os.environ.get("HOPSWORKS_API_KEY")
|
| 115 |
+
project = os.environ.get("HOPSWORKS_PROJECT")
|
| 116 |
if not api_key:
|
| 117 |
raise ValueError("HOPSWORKS_API_KEY environment variable not set. Please add it in Space settings.")
|
| 118 |
|
|
|
|
| 121 |
from xgboost import XGBClassifier
|
| 122 |
|
| 123 |
print("Connecting to Hopsworks...")
|
| 124 |
+
project = hopsworks.login(project=project, api_key_value=api_key)
|
| 125 |
mr = project.get_model_registry()
|
| 126 |
|
| 127 |
print("Fetching model from registry...")
|
| 128 |
+
# Get version 4 explicitly (the model trained with 23 features)
|
| 129 |
+
model_entry = mr.get_model("occupancy_xgboost_model_new", version=4)
|
| 130 |
|
| 131 |
print(f"Downloading model version {model_entry.version}...")
|
| 132 |
model_dir = model_entry.download()
|
|
|
|
| 166 |
# Assemble feature vector
|
| 167 |
features = {
|
| 168 |
# Vehicle features - use defaults
|
| 169 |
+
"trip_id": 0, # placeholder
|
| 170 |
+
"vehicle_id": 0, # placeholder
|
| 171 |
"max_speed": DEFAULT_VEHICLE_FEATURES["max_speed"],
|
|
|
|
| 172 |
"n_positions": DEFAULT_VEHICLE_FEATURES["n_positions"],
|
| 173 |
|
| 174 |
+
# Location bounds (set equal to point for single-location prediction)
|
| 175 |
+
"lat_min": lat,
|
| 176 |
+
"lat_max": lat,
|
| 177 |
"lat_mean": lat,
|
| 178 |
+
"lon_min": lon,
|
| 179 |
+
"lon_max": lon,
|
| 180 |
"lon_mean": lon,
|
| 181 |
|
| 182 |
+
# Bearing (neutral values for point prediction)
|
| 183 |
+
"bearing_min": DEFAULT_VEHICLE_FEATURES["bearing_min"],
|
| 184 |
+
"bearing_max": DEFAULT_VEHICLE_FEATURES["bearing_max"],
|
| 185 |
+
|
| 186 |
# Time
|
| 187 |
"hour": hour,
|
| 188 |
"day_of_week": day_of_week,
|
|
|
|
| 192 |
"precipitation": weather.get("precipitation", 0.0),
|
| 193 |
"cloud_cover": weather.get("cloud_cover", 50.0),
|
| 194 |
"wind_speed_10m": weather.get("wind_speed_10m", 5.0),
|
| 195 |
+
"rain": weather.get("rain", 0.0),
|
| 196 |
+
"snowfall": weather.get("snowfall", 0.0),
|
| 197 |
|
| 198 |
# Holidays (convert bool to int)
|
| 199 |
"is_work_free": int(holidays.get("is_work_free", False)),
|
|
|
|
| 212 |
confidence = float(probabilities[predicted_class])
|
| 213 |
|
| 214 |
return predicted_class, confidence, probabilities.tolist()
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def predict_occupancy_batch(locations, hour, day_of_week, weather, holidays):
    """
    Predict occupancy for multiple locations in a single batch.

    Builds one DataFrame and calls the model once - much faster than
    calling predict_occupancy() in a loop.

    Args:
        locations: List of (lat, lon) tuples
        hour: Hour of day (0-23)
        day_of_week: Day of week (0=Monday, 6=Sunday)
        weather: Dict with temperature_2m, precipitation, cloud_cover,
            wind_speed_10m, rain, snowfall (missing keys get defaults)
        holidays: Dict with is_work_free, is_red_day, is_day_before_holiday

    Returns:
        List of (predicted_class, confidence) tuples, one per input location.
    """
    # Guard: pd.DataFrame([])[FEATURE_ORDER] would raise KeyError on no rows.
    if not locations:
        return []

    model = load_model()

    # Features identical for every row - build once, reuse per location.
    base_row = {
        "trip_id": 0,       # placeholder - no real-time trip context
        "vehicle_id": 0,    # placeholder - no real-time vehicle context
        "max_speed": DEFAULT_VEHICLE_FEATURES["max_speed"],
        "n_positions": DEFAULT_VEHICLE_FEATURES["n_positions"],
        "bearing_min": DEFAULT_VEHICLE_FEATURES["bearing_min"],
        "bearing_max": DEFAULT_VEHICLE_FEATURES["bearing_max"],
        "hour": hour,
        "day_of_week": day_of_week,
        "temperature_2m": weather.get("temperature_2m", 10.0),
        "precipitation": weather.get("precipitation", 0.0),
        "cloud_cover": weather.get("cloud_cover", 50.0),
        "wind_speed_10m": weather.get("wind_speed_10m", 5.0),
        "rain": weather.get("rain", 0.0),
        "snowfall": weather.get("snowfall", 0.0),
        "is_work_free": int(holidays.get("is_work_free", False)),
        "is_red_day": int(holidays.get("is_red_day", False)),
        "is_day_before_holiday": int(holidays.get("is_day_before_holiday", False)),
    }

    rows = []
    for lat, lon in locations:
        row = dict(base_row)
        # Bounds collapse to the point itself for a single-location prediction.
        row.update(
            lat_min=lat, lat_max=lat, lat_mean=lat,
            lon_min=lon, lon_max=lon, lon_mean=lon,
        )
        rows.append(row)

    # Single DataFrame in the exact column order the model was trained with,
    # and a single predict_proba call for the whole batch.
    X = pd.DataFrame(rows)[FEATURE_ORDER]
    probabilities = model.predict_proba(X)

    # Row order matches input order, so iterate probabilities directly
    # (the original looped enumerate(locations) with unused lat/lon).
    results = []
    for probs in probabilities:
        predicted_class = int(np.argmax(probs))
        results.append((predicted_class, float(probs[predicted_class])))

    return results
|
requirements.txt
CHANGED
|
@@ -8,3 +8,6 @@ pandas
|
|
| 8 |
numpy
|
| 9 |
requests
|
| 10 |
python-dotenv
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
numpy
|
| 9 |
requests
|
| 10 |
python-dotenv
|
| 11 |
+
scipy>=1.10.0
|
| 12 |
+
shapely>=2.0.0
|
| 13 |
+
matplotlib>=3.7.0
|
trip_info.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hopsworks
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
def load_static_trip_info():
    """
    Load the static trip info feature group from Hopsworks.

    Reads HOPSWORKS_API_KEY and HOPSWORKS_PROJECT from the environment.

    Returns:
        pandas DataFrame with the feature group contents.

    Raises:
        ValueError: If HOPSWORKS_API_KEY is not set.
    """
    api_key = os.environ.get("HOPSWORKS_API_KEY")
    if not api_key:
        # Fail fast with the same clear message as predictor.load_model()
        # instead of surfacing an opaque login error.
        raise ValueError("HOPSWORKS_API_KEY environment variable not set. Please add it in Space settings.")
    project_name = os.environ.get("HOPSWORKS_PROJECT")
    project = hopsworks.login(project=project_name, api_key_value=api_key)

    fs = project.get_feature_store()
    fg = fs.get_feature_group("static_trip_info_fg", version=1)  # adjust version

    df = fg.read()
    return df
|
| 15 |
+
|
| 16 |
+
def load_static_stops_info():
    """
    Load the static trip-and-stops info feature group from Hopsworks.

    Reads HOPSWORKS_API_KEY and HOPSWORKS_PROJECT from the environment.

    Returns:
        pandas DataFrame with the feature group contents.

    Raises:
        ValueError: If HOPSWORKS_API_KEY is not set.
    """
    api_key = os.environ.get("HOPSWORKS_API_KEY")
    if not api_key:
        # Fail fast with the same clear message as predictor.load_model()
        # instead of surfacing an opaque login error.
        raise ValueError("HOPSWORKS_API_KEY environment variable not set. Please add it in Space settings.")
    project_name = os.environ.get("HOPSWORKS_PROJECT")
    project = hopsworks.login(project=project_name, api_key_value=api_key)

    fs = project.get_feature_store()
    fg = fs.get_feature_group("static_trip_and_stops_info_fg", version=1)  # adjust version

    df = fg.read()
    return df
|
| 26 |
+
|
| 27 |
+
def find_nearest_trip(lat, lon, datetime_obj, static_trip_df):
    """
    Return the trip closest to the requested location/time.
    Currently just filters static trips; could be enhanced with stops & routing.
    """
    # Static GTFS data carries no live positions, so a real nearest-trip
    # match is not yet possible: a random trip serves as a demo placeholder
    # (lat/lon/datetime_obj are accepted for the future implementation).
    if static_trip_df is None or len(static_trip_df) == 0:
        return None

    chosen = static_trip_df.sample(1).iloc[0]  # pick 1 random trip for demo
    return {
        "trip_id": chosen["trip_id"],
        "route_short_name": chosen["route_short_name"],
        "route_long_name": chosen["route_long_name"],
        "trip_headsign": chosen.get("trip_headsign", None),
    }
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def find_closest_stop(lat, lon, trip_id, stops_df):
    """
    Return the name of the stop closest to (lat, lon) for the given trip.

    Uses an equirectangular approximation: longitude differences are scaled
    by cos(latitude) so east-west and north-south offsets are compared in
    comparable metric units. (Raw degree distances, as originally written,
    overweight latitude by roughly 2x at Ostergotland's ~58N.)

    Args:
        lat: Query latitude in degrees.
        lon: Query longitude in degrees.
        trip_id: Trip whose stops should be searched.
        stops_df: DataFrame with trip_id, stop_lat, stop_lon, stop_name columns.

    Returns:
        The stop_name of the nearest stop, or None if stops_df is None or
        the trip has no stops.
    """
    if stops_df is None:
        return None

    # Filter stops for this trip
    trip_stops = stops_df[stops_df["trip_id"] == trip_id]
    if trip_stops.empty:
        return None

    lat_array = trip_stops["stop_lat"].to_numpy()
    lon_array = trip_stops["stop_lon"].to_numpy()

    # Scale longitude by cos(lat) so both axes are in comparable units.
    lon_scale = np.cos(np.radians(lat))
    d_lat = lat_array - lat
    d_lon = (lon_array - lon) * lon_scale

    # Squared distance is monotonic with distance, so sqrt is unnecessary.
    idx_min = np.argmin(d_lat ** 2 + d_lon ** 2)
    closest_stop = trip_stops.iloc[idx_min]

    return closest_stop["stop_name"]
|
weather.py
CHANGED
|
@@ -6,6 +6,7 @@ Uses Open-Meteo API to get weather forecasts.
|
|
| 6 |
|
| 7 |
import requests
|
| 8 |
from datetime import datetime
|
|
|
|
| 9 |
|
| 10 |
# Open-Meteo API
|
| 11 |
OPENMETEO_FORECAST_URL = "https://api.open-meteo.com/v1/forecast"
|
|
@@ -16,6 +17,8 @@ WEATHER_VARIABLES = [
|
|
| 16 |
"precipitation",
|
| 17 |
"cloud_cover",
|
| 18 |
"wind_speed_10m",
|
|
|
|
|
|
|
| 19 |
]
|
| 20 |
|
| 21 |
|
|
@@ -53,10 +56,25 @@ def get_weather_for_prediction(lat: float, lon: float, target_datetime: datetime
|
|
| 53 |
if days_ahead <= 0:
|
| 54 |
params["past_days"] = 1
|
| 55 |
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
if response.status_code != 200:
|
| 59 |
-
print(f"Weather API error: {response.status_code}")
|
| 60 |
return _default_weather()
|
| 61 |
|
| 62 |
data = response.json()
|
|
@@ -89,6 +107,8 @@ def get_weather_for_prediction(lat: float, lon: float, target_datetime: datetime
|
|
| 89 |
"precipitation": hourly.get("precipitation", [None])[idx] or 0.0,
|
| 90 |
"cloud_cover": hourly.get("cloud_cover", [None])[idx] or 50.0,
|
| 91 |
"wind_speed_10m": hourly.get("wind_speed_10m", [None])[idx] or 5.0,
|
|
|
|
|
|
|
| 92 |
}
|
| 93 |
|
| 94 |
except Exception as e:
|
|
@@ -103,4 +123,6 @@ def _default_weather() -> dict:
|
|
| 103 |
"precipitation": 0.0,
|
| 104 |
"cloud_cover": 50.0,
|
| 105 |
"wind_speed_10m": 5.0,
|
|
|
|
|
|
|
| 106 |
}
|
|
|
|
| 6 |
|
| 7 |
import requests
|
| 8 |
from datetime import datetime
|
| 9 |
+
import time
|
| 10 |
|
| 11 |
# Open-Meteo API
|
| 12 |
OPENMETEO_FORECAST_URL = "https://api.open-meteo.com/v1/forecast"
|
|
|
|
| 17 |
"precipitation",
|
| 18 |
"cloud_cover",
|
| 19 |
"wind_speed_10m",
|
| 20 |
+
"rain",
|
| 21 |
+
"snowfall"
|
| 22 |
]
|
| 23 |
|
| 24 |
|
|
|
|
| 56 |
if days_ahead <= 0:
|
| 57 |
params["past_days"] = 1
|
| 58 |
|
| 59 |
+
# Retry logic for transient failures
|
| 60 |
+
max_retries = 3
|
| 61 |
+
for attempt in range(max_retries):
|
| 62 |
+
try:
|
| 63 |
+
response = requests.get(OPENMETEO_FORECAST_URL, params=params, timeout=30)
|
| 64 |
+
if response.status_code == 200:
|
| 65 |
+
break
|
| 66 |
+
if attempt < max_retries - 1:
|
| 67 |
+
time.sleep(2 ** attempt) # Exponential backoff: 1s, 2s, 4s
|
| 68 |
+
except requests.exceptions.Timeout:
|
| 69 |
+
if attempt < max_retries - 1:
|
| 70 |
+
time.sleep(2 ** attempt)
|
| 71 |
+
continue
|
| 72 |
+
return _default_weather()
|
| 73 |
+
else:
|
| 74 |
+
print(f"Weather API error after {max_retries} retries: {response.status_code}")
|
| 75 |
+
return _default_weather()
|
| 76 |
|
| 77 |
if response.status_code != 200:
|
|
|
|
| 78 |
return _default_weather()
|
| 79 |
|
| 80 |
data = response.json()
|
|
|
|
| 107 |
"precipitation": hourly.get("precipitation", [None])[idx] or 0.0,
|
| 108 |
"cloud_cover": hourly.get("cloud_cover", [None])[idx] or 50.0,
|
| 109 |
"wind_speed_10m": hourly.get("wind_speed_10m", [None])[idx] or 5.0,
|
| 110 |
+
"rain": hourly.get("rain", [None])[idx] or 0.0,
|
| 111 |
+
"snowfall": hourly.get("snowfall", [None])[idx] or 0.0,
|
| 112 |
}
|
| 113 |
|
| 114 |
except Exception as e:
|
|
|
|
| 123 |
"precipitation": 0.0,
|
| 124 |
"cloud_cover": 50.0,
|
| 125 |
"wind_speed_10m": 5.0,
|
| 126 |
+
"rain": 0.0,
|
| 127 |
+
"snowfall": 0.0
|
| 128 |
}
|