Spaces:
Sleeping
Sleeping
| """ | |
| HappySardines - Bus Occupancy Predictor UI (Streamlit version) | |
| A Streamlit app with clickable map and heat map overlay for predicting | |
| bus crowding in Östergötland. | |
| """ | |
| import streamlit as st | |
| import json | |
| # Page config - MUST be first Streamlit command | |
| st.set_page_config( | |
| page_title="HappySardines", | |
| page_icon="🐟", | |
| layout="wide" | |
| ) | |
| import os | |
| import folium | |
| from streamlit_folium import st_folium | |
| import numpy as np | |
| from datetime import datetime, timedelta | |
| import hopsworks | |
| # Import prediction and data fetching modules | |
| from predictor import predict_occupancy, load_model, OCCUPANCY_LABELS | |
| from weather import get_weather_for_prediction | |
| from holidays import get_holiday_features | |
| from trip_info import ( | |
| load_static_trip_info, find_nearest_trip, load_static_stops_info | |
| ) | |
| from contours import load_contours_from_file, grid_to_cells_geojson | |
| # Constants | |
| DEFAULT_LAT = 58.4108 | |
| DEFAULT_LON = 15.6214 | |
| DEFAULT_ZOOM = 9 # Slightly zoomed out to show more of the region | |
| # Bounds derived from actual GTFS stop locations (3119 stops) | |
| # Run ui/get_boundaries.py to recalculate if needed | |
| BOUNDS = { | |
| "min_lat": 56.6414, | |
| "max_lat": 58.8654, | |
| "min_lon": 14.6144, | |
| "max_lon": 16.9578, | |
| } | |
| # Color scheme for occupancy levels (must match contours.py CLASS_COLORS) | |
| OCCUPANCY_COLORS = { | |
| 0: "#22c55e", # Empty - green | |
| 1: "#84cc16", # Many seats - lime | |
| 2: "#eab308", # Few seats - yellow | |
| 3: "#f97316", # Standing - orange | |
| 4: "#ef4444", # Crushed - red | |
| 5: "#ef4444", # Full - red | |
| 6: "#6b7280", # Not accepting - gray | |
| } | |
| # Lazy-load static data (deferred to avoid blocking app startup) | |
| def get_static_trip_df(): | |
| """Load static trip info from Hopsworks (cached).""" | |
| try: | |
| return load_static_trip_info() | |
| except Exception as e: | |
| print(f"Warning: Could not load static trip info: {e}") | |
| return None | |
| def get_static_stops_df(): | |
| """Load static stops info from Hopsworks (cached).""" | |
| try: | |
| return load_static_stops_info() | |
| except Exception as e: | |
| print(f"Warning: Could not load static stops info: {e}") | |
| return None | |
| def is_stops_data_cached(): | |
| """Check if stops data is already in cache without triggering load.""" | |
| # Check if the cache has been populated by looking at session state | |
| return "stops_data_loaded" in st.session_state and st.session_state.stops_data_loaded | |
| def get_model(): | |
| """Load model once and cache it.""" | |
| try: | |
| return load_model() | |
| except Exception as e: | |
| st.error(f"Failed to load model: {e}") | |
| return None | |
| def cached_predict_occupancy(lat, lon, hour, day_of_week, weather, holidays): | |
| return predict_occupancy(lat, lon, hour, day_of_week, weather, holidays) | |
| def fetch_trip_forecasts_from_hopsworks(): | |
| """ | |
| Fetch trip forecasts from Hopsworks forecast_fg. | |
| Returns DataFrame with columns: trip_id, hour, weekday, predicted_occupancy, confidence | |
| Returns None if forecast_fg doesn't exist or is empty. | |
| """ | |
| try: | |
| project = hopsworks.login() | |
| fs = project.get_feature_store() | |
| # Try v2 (new schema with hour/weekday), fall back to v1 | |
| for version in [2, 1]: | |
| try: | |
| forecast_fg = fs.get_feature_group("forecast_fg", version=version) | |
| df = forecast_fg.read() | |
| if df is not None and not df.empty: | |
| print(f"Loaded {len(df)} trip forecasts from Hopsworks v{version}") | |
| return df | |
| except Exception: | |
| continue | |
| return None | |
| except Exception as e: | |
| print(f"Could not load trip forecasts: {e}") | |
| return None | |
| def fetch_heatmaps_from_hopsworks(): | |
| """ | |
| Fetch all precomputed heatmaps from Hopsworks Feature Store. | |
| Tries v3 (high-res 40x50) first, falls back to v2 (low-res 20x25). | |
| Uses cache_resource to persist across reruns - only fetches once per session. | |
| Returns dict mapping (hour, weekday) -> GeoJSON FeatureCollection | |
| """ | |
| try: | |
| print("Fetching heatmaps from Hopsworks...") | |
| project = hopsworks.login() | |
| fs = project.get_feature_store() | |
| # Try v3 first (high-res), fall back to v2 (low-res) | |
| for version in [3, 2]: | |
| try: | |
| heatmap_fg = fs.get_feature_group("heatmap_geojson_fg", version=version) | |
| df = heatmap_fg.read() | |
| if df is not None and not df.empty: | |
| # Convert to dict with tuple keys | |
| heatmaps = {} | |
| for _, row in df.iterrows(): | |
| key = (int(row["hour"]), int(row["weekday"])) | |
| geojson = json.loads(row["geojson"]) | |
| heatmaps[key] = geojson | |
| print(f"Loaded {len(heatmaps)} heatmaps from Hopsworks v{version}") | |
| return heatmaps | |
| else: | |
| print(f"No data in Hopsworks v{version}, trying fallback...") | |
| except Exception as e: | |
| print(f"Could not fetch v{version}: {e}") | |
| continue | |
| print("No heatmap data found in any Hopsworks version") | |
| return {} | |
| except Exception as e: | |
| print(f"Error fetching heatmaps from Hopsworks: {e}") | |
| return {} | |
| def load_precomputed_contours(): | |
| """Load precomputed contour GeoJSON from file (not cached to pick up new files).""" | |
| script_dir = os.path.dirname(os.path.abspath(__file__)) | |
| contours_path = os.path.join(script_dir, "precomputed_contours.json") | |
| if os.path.exists(contours_path): | |
| try: | |
| contours = load_contours_from_file(contours_path) | |
| print(f"Loaded {len(contours)} precomputed time slots from {contours_path}") | |
| return contours | |
| except Exception as e: | |
| print(f"Error loading contours: {e}") | |
| return {} | |
| print(f"Contours file not found: {contours_path}") | |
| return {} | |
| def generate_contours_on_demand(hour, day_of_week, weather, holidays): | |
| """ | |
| Generate grid cell GeoJSON on-demand if precomputed data is not available. | |
| This is slower but provides a fallback. | |
| """ | |
| model = get_model() | |
| if model is None: | |
| return None | |
| # Grid for on-demand generation (smaller for speed) | |
| lat_steps = 15 | |
| lon_steps = 20 | |
| lats = np.linspace(BOUNDS["min_lat"], BOUNDS["max_lat"], lat_steps) | |
| lons = np.linspace(BOUNDS["min_lon"], BOUNDS["max_lon"], lon_steps) | |
| lat_step = (BOUNDS["max_lat"] - BOUNDS["min_lat"]) / (lat_steps - 1) | |
| lon_step = (BOUNDS["max_lon"] - BOUNDS["min_lon"]) / (lon_steps - 1) | |
| prediction_data = [] | |
| for lat in lats: | |
| for lon in lons: | |
| try: | |
| pred_class, confidence, _ = cached_predict_occupancy( | |
| lat=lat, lon=lon, hour=hour, day_of_week=day_of_week, | |
| weather=weather, holidays=holidays | |
| ) | |
| prediction_data.append([lat, lon, pred_class]) | |
| except Exception: | |
| prediction_data.append([lat, lon, 0]) | |
| # Convert to GeoJSON grid cells | |
| return grid_to_cells_geojson(prediction_data, lat_step, lon_step) | |
| def get_test_contour_geojson(): | |
| """ | |
| Return a simple hardcoded test GeoJSON to verify rendering works. | |
| Creates a small grid of cells with different colors. | |
| """ | |
| # Create a 3x3 grid of test cells around Linköping | |
| center_lat = 58.41 | |
| center_lon = 15.62 | |
| cell_size = 0.15 | |
| # Test predictions: mix of classes | |
| test_data = [ | |
| (center_lat - cell_size, center_lon - cell_size, 0), # green | |
| (center_lat - cell_size, center_lon, 0), # green | |
| (center_lat - cell_size, center_lon + cell_size, 1), # green | |
| (center_lat, center_lon - cell_size, 0), # green | |
| (center_lat, center_lon, 2), # yellow | |
| (center_lat, center_lon + cell_size, 2), # yellow | |
| (center_lat + cell_size, center_lon - cell_size, 0), # green | |
| (center_lat + cell_size, center_lon, 3), # orange | |
| (center_lat + cell_size, center_lon + cell_size, 0), # green | |
| ] | |
| return grid_to_cells_geojson(test_data, cell_size, cell_size) | |
| def get_contour_geojson(hour, day_of_week, weather=None, holidays=None): | |
| """ | |
| Get contour GeoJSON for the given hour and day of week. | |
| Tries sources in order: | |
| 1. Hopsworks Feature Store (primary) | |
| 2. Local JSON file (fallback) | |
| 3. Test contours (last resort) | |
| """ | |
| key = (hour, day_of_week) | |
| # Try Hopsworks first (cached by @st.cache_resource - only fetches once) | |
| hopsworks_heatmaps = fetch_heatmaps_from_hopsworks() | |
| if hopsworks_heatmaps and key in hopsworks_heatmaps: | |
| geojson = hopsworks_heatmaps[key] | |
| n_features = len(geojson.get("features", [])) | |
| print(f"Found heatmap in Hopsworks for {key}: {n_features} features") | |
| return geojson | |
| # Fall back to local JSON file | |
| precomputed = load_precomputed_contours() | |
| if key in precomputed: | |
| geojson = precomputed[key] | |
| n_features = len(geojson.get("features", [])) | |
| print(f"Found heatmap in local file for {key}: {n_features} features") | |
| return geojson | |
| # Last resort: test contours | |
| print(f"No heatmap for {key}, using test contours") | |
| return get_test_contour_geojson() | |
| def create_map(selected_lat=None, selected_lon=None, show_heatmap=False, | |
| contour_geojson=None): | |
| """Create a Folium map with optional marker and contour overlay.""" | |
| center_lat = selected_lat if selected_lat else DEFAULT_LAT | |
| center_lon = selected_lon if selected_lon else DEFAULT_LON | |
| m = folium.Map( | |
| location=[center_lat, center_lon], | |
| zoom_start=DEFAULT_ZOOM, | |
| tiles="CartoDB positron" | |
| ) | |
| # Add contour overlay if enabled | |
| if show_heatmap and contour_geojson and contour_geojson.get("features"): | |
| # Add each contour level as a separate GeoJSON layer | |
| folium.GeoJson( | |
| contour_geojson, | |
| style_function=lambda feature: { | |
| 'fillColor': feature['properties']['color'], | |
| 'fillOpacity': feature['properties'].get('fillOpacity', 0.35), | |
| 'color': 'none', # No border | |
| 'weight': 0 | |
| }, | |
| name="Crowding Forecast" | |
| ).add_to(m) | |
| # Add coverage area rectangle (subtle border) | |
| folium.Rectangle( | |
| bounds=[[BOUNDS["min_lat"], BOUNDS["min_lon"]], | |
| [BOUNDS["max_lat"], BOUNDS["max_lon"]]], | |
| color="#6b7280", | |
| fill=False, | |
| weight=1, | |
| opacity=0.3, | |
| dash_array="5, 5", | |
| ).add_to(m) | |
| # Add marker if location selected | |
| if selected_lat and selected_lon: | |
| folium.Marker( | |
| [selected_lat, selected_lon], | |
| tooltip=f"Selected: {selected_lat:.4f}, {selected_lon:.4f}", | |
| icon=folium.Icon(color="blue", icon="info-sign") | |
| ).add_to(m) | |
| return m | |
| def make_prediction(lat, lon, selected_datetime, skip_trip_info=False): | |
| """Make prediction and return formatted result. | |
| Args: | |
| skip_trip_info: If True, skip the slow trip info lookup | |
| """ | |
| if lat is None or lon is None: | |
| return None, None, None | |
| # Check bounds | |
| if not (BOUNDS["min_lat"] <= lat <= BOUNDS["max_lat"] and | |
| BOUNDS["min_lon"] <= lon <= BOUNDS["max_lon"]): | |
| return None, None, "Location outside coverage area" | |
| try: | |
| weather = get_weather_for_prediction(lat, lon, selected_datetime) | |
| holidays = get_holiday_features(selected_datetime) | |
| pred_class, confidence, probs = cached_predict_occupancy( | |
| lat=lat, lon=lon, | |
| hour=selected_datetime.hour, | |
| day_of_week=selected_datetime.weekday(), | |
| weather=weather, | |
| holidays=holidays | |
| ) | |
| # Find nearest trip from static data (only if not skipping) | |
| trip_info = None | |
| trip_forecast = None | |
| if not skip_trip_info: | |
| static_stops_df = get_static_stops_df() | |
| # Mark that we've loaded the data (for future quick checks) | |
| st.session_state.stops_data_loaded = True | |
| if static_stops_df is not None: | |
| trip_info = find_nearest_trip(lat, lon, selected_datetime, static_stops_df) | |
| # Try to get trip forecast if available | |
| if trip_info and trip_info.get("trip_id"): | |
| forecasts_df = fetch_trip_forecasts_from_hopsworks() | |
| if forecasts_df is not None: | |
| trip_id = trip_info["trip_id"] | |
| hour = selected_datetime.hour | |
| weekday = selected_datetime.weekday() | |
| # Find matching forecast | |
| match = forecasts_df[ | |
| (forecasts_df["trip_id"] == trip_id) & | |
| (forecasts_df["hour"] == hour) & | |
| (forecasts_df["weekday"] == weekday) | |
| ] | |
| if not match.empty: | |
| row = match.iloc[0] | |
| trip_forecast = { | |
| "predicted_occupancy": int(row.get("predicted_occupancy", 0)), | |
| "confidence": float(row.get("confidence", 0)), | |
| } | |
| return pred_class, confidence, { | |
| "weather": weather, | |
| "holidays": holidays, | |
| "datetime": selected_datetime, | |
| "trip_info": trip_info, | |
| "trip_forecast": trip_forecast | |
| } | |
| except Exception as e: | |
| return None, None, str(e) | |
| # Initialize session state | |
| if "selected_lat" not in st.session_state: | |
| st.session_state.selected_lat = DEFAULT_LAT | |
| if "selected_lon" not in st.session_state: | |
| st.session_state.selected_lon = DEFAULT_LON | |
| # Header | |
| st.title("HappySardines") | |
| st.markdown("*Predicted bus crowding in Östergötland*") | |
| # Check if model is available | |
| model = get_model() | |
| if model is None: | |
| st.error("Could not load prediction model. Please check the configuration.") | |
| st.stop() | |
| # Sidebar controls | |
| with st.sidebar: | |
| st.header("Settings") | |
| # Date/time selection | |
| st.subheader("When?") | |
| date_option = st.radio("Date", ["Today", "Tomorrow"], horizontal=True) | |
| hour = st.slider("Hour", 5, 23, 8) | |
| today = datetime.now().date() | |
| selected_date = today if date_option == "Today" else today + timedelta(days=1) | |
| selected_datetime = datetime.combine(selected_date, datetime.min.time().replace(hour=hour)) | |
| st.markdown(f"**{selected_datetime.strftime('%A, %B %d at %H:00')}**") | |
| st.divider() | |
| # View mode | |
| st.subheader("View Mode") | |
| show_heatmap = st.toggle("Show Crowding Forecast", value=False, | |
| help="Display predicted crowding across the region") | |
| if show_heatmap: | |
| st.markdown(""" | |
| **Legend:** | |
| <div style="display: flex; flex-direction: column; gap: 4px; font-size: 14px;"> | |
| <div><span style="display: inline-block; width: 16px; height: 16px; background: #22c55e; border-radius: 3px; vertical-align: middle;"></span> Empty</div> | |
| <div><span style="display: inline-block; width: 16px; height: 16px; background: #84cc16; border-radius: 3px; vertical-align: middle;"></span> Many seats</div> | |
| <div><span style="display: inline-block; width: 16px; height: 16px; background: #eab308; border-radius: 3px; vertical-align: middle;"></span> Few seats</div> | |
| <div><span style="display: inline-block; width: 16px; height: 16px; background: #f97316; border-radius: 3px; vertical-align: middle;"></span> Standing room</div> | |
| <div><span style="display: inline-block; width: 16px; height: 16px; background: #ef4444; border-radius: 3px; vertical-align: middle;"></span> Crowded</div> | |
| </div> | |
| """, unsafe_allow_html=True) | |
| st.divider() | |
| # About | |
| with st.expander("About this tool"): | |
| st.markdown(""" | |
| **How it works:** | |
| This tool predicts bus crowding levels based on: | |
| - 📍 Location | |
| - 🕐 Time of day | |
| - 📅 Day of week | |
| - 🌡️ Weather conditions | |
| - 🇸🇪 Holidays | |
| **Data sources:** | |
| - Bus occupancy data from Östgötatrafiken (KODA API) | |
| - Weather from Open-Meteo | |
| - Holidays from Svenska Dagar API | |
| **Built for KTH ID2223** | |
| """) | |
| # Main content | |
| col1, col2 = st.columns([2, 1]) | |
| with col1: | |
| st.subheader("Click on the map to select a location") | |
| # Only fetch heatmap data when toggle is ON | |
| contour_geojson = None | |
| if show_heatmap: | |
| # fetch_heatmaps_from_hopsworks() is cached - only slow on first call | |
| # weather/holidays not needed - already baked into precomputed heatmaps | |
| contour_geojson = get_contour_geojson(hour, selected_date.weekday()) | |
| # Create and display map | |
| m = create_map( | |
| selected_lat=st.session_state.selected_lat, | |
| selected_lon=st.session_state.selected_lon, | |
| show_heatmap=show_heatmap, | |
| contour_geojson=contour_geojson | |
| ) | |
| # Render the map | |
| # Use returned_objects to only trigger rerun on clicks, not zoom/pan | |
| map_data = st_folium( | |
| m, | |
| height=500, | |
| use_container_width=True, | |
| returned_objects=["last_clicked"], | |
| key="main_map" | |
| ) | |
| # Handle map clicks | |
| if map_data and map_data.get("last_clicked"): | |
| clicked = map_data["last_clicked"] | |
| st.session_state.selected_lat = clicked["lat"] | |
| st.session_state.selected_lon = clicked["lng"] | |
| st.rerun() | |
| with col2: | |
| st.subheader("Prediction") | |
| # Show selected coordinates | |
| st.markdown(f"**Location:** {st.session_state.selected_lat:.4f}, {st.session_state.selected_lon:.4f}") | |
| # Check if stops data is already cached (fast check) | |
| stops_already_loaded = st.session_state.get("stops_data_loaded", False) | |
| # Make prediction (skip trip info on first load to be fast) | |
| with st.spinner("Fetching prediction..."): | |
| pred_class, confidence, result = make_prediction( | |
| st.session_state.selected_lat, | |
| st.session_state.selected_lon, | |
| selected_datetime, | |
| skip_trip_info=not stops_already_loaded | |
| ) | |
| if pred_class is not None: | |
| label_info = OCCUPANCY_LABELS[pred_class] | |
| color = OCCUPANCY_COLORS[pred_class] | |
| # Result card | |
| st.markdown(f""" | |
| <div style=" | |
| background: linear-gradient(135deg, {color}22, {color}11); | |
| border-left: 4px solid {color}; | |
| border-radius: 12px; | |
| padding: 20px; | |
| margin: 10px 0; | |
| "> | |
| <div style="font-size: 1.3em; font-weight: 600; color: {color};"> | |
| {label_info['icon']} {label_info['label']} | |
| </div> | |
| <div style="margin-top: 8px; color: #374151;"> | |
| {label_info['message']} | |
| </div> | |
| <div style="margin-top: 12px; font-size: 0.9em; opacity: 0.8;"> | |
| Confidence: {confidence:.0%} | |
| </div> | |
| </div> | |
| """, unsafe_allow_html=True) | |
| # Context info | |
| if isinstance(result, dict): | |
| weather = result["weather"] | |
| holidays = result["holidays"] | |
| trip_info = result.get("trip_info") | |
| day_type = "Holiday" if holidays.get("is_red_day") else ( | |
| "Work-free day" if holidays.get("is_work_free") else "Regular day" | |
| ) | |
| if trip_info: | |
| info_lines = [] | |
| # Route number or name | |
| route_number = trip_info.get("route_short_name") | |
| route_long_name = trip_info.get("route_long_name") | |
| if route_number and route_long_name: | |
| info_lines.append(f"{route_number} - {route_long_name}") | |
| elif route_number: | |
| info_lines.append(f"Route: {route_number}") | |
| elif route_long_name: | |
| info_lines.append(f"Route: {route_long_name}") | |
| # Bus type / description | |
| route_desc = trip_info.get("route_desc") | |
| if route_desc: | |
| info_lines.append(f"Type: {route_desc}") | |
| # Closest stop from trip info (already computed) | |
| closest_stop = trip_info.get("closest_stop") | |
| if closest_stop: | |
| info_lines.append(f"Nearest stop: {closest_stop}") | |
| # Distance to stop | |
| distance = trip_info.get("distance_m") | |
| if distance is not None: | |
| info_lines.append(f"Distance: {distance}m") | |
| if info_lines: | |
| st.markdown("**Bus Info:**\n- " + "\n- ".join(info_lines)) | |
| elif not stops_already_loaded: | |
| # Offer to load trip info (it's slow on first load) | |
| if st.button("Load nearby bus info", help="First load takes ~1-2 minutes"): | |
| with st.spinner("Loading trip data from Hopsworks (this may take a minute)..."): | |
| # Trigger the load and rerun | |
| get_static_stops_df() | |
| st.session_state.stops_data_loaded = True | |
| st.rerun() | |
| # Show trip-specific forecast if available | |
| trip_forecast = result.get("trip_forecast") | |
| if trip_forecast: | |
| forecast_class = trip_forecast["predicted_occupancy"] | |
| forecast_conf = trip_forecast["confidence"] | |
| forecast_label = OCCUPANCY_LABELS.get(forecast_class, OCCUPANCY_LABELS[0]) | |
| forecast_color = OCCUPANCY_COLORS.get(forecast_class, "#6b7280") | |
| st.markdown(f""" | |
| <div style=" | |
| background: {forecast_color}11; | |
| border: 1px solid {forecast_color}44; | |
| border-radius: 8px; | |
| padding: 12px; | |
| margin: 8px 0; | |
| "> | |
| <div style="font-size: 0.85em; color: #6b7280; margin-bottom: 4px;"> | |
| Trip-specific forecast: | |
| </div> | |
| <div style="font-weight: 600; color: {forecast_color};"> | |
| {forecast_label['icon']} {forecast_label['label']} ({forecast_conf:.0%}) | |
| </div> | |
| </div> | |
| """, unsafe_allow_html=True) | |
| # Weather conditions | |
| conditions = [] | |
| temp = weather.get('temperature_2m') | |
| if temp is not None: | |
| conditions.append(f"{temp:.0f}°C") | |
| if weather.get('snowfall', 0) > 0: | |
| conditions.append("Snow") | |
| if weather.get('rain', 0) > 0: | |
| conditions.append("Rain") | |
| conditions.append(day_type) | |
| conditions.append(selected_datetime.strftime('%A')) | |
| st.markdown("**Conditions:** " + " | ".join(conditions)) | |
| elif isinstance(result, str): | |
| st.error(result) | |
| else: | |
| st.info("Click on the map to select a location") | |
| # Footer | |
| st.divider() | |
| st.markdown( | |
| "<div style='text-align: center; opacity: 0.6;'>Built for KTH ID2223 - Scalable Machine Learning</div>", | |
| unsafe_allow_html=True | |
| ) | |