AxelHolst committed on
Commit
1f415af
·
1 Parent(s): 4e356c1

Switch to Streamlit with clickable map and heat map overlay

Browse files
Files changed (4) hide show
  1. README.md +14 -13
  2. app.py +252 -295
  3. predictor.py +1 -37
  4. requirements.txt +3 -3
README.md CHANGED
@@ -3,8 +3,8 @@ title: HappySardines
3
  emoji: 🐟
4
  colorFrom: blue
5
  colorTo: blue
6
- sdk: gradio
7
- sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
@@ -15,11 +15,18 @@ short_description: Predict bus crowding levels in Östergötland, Sweden
15
 
16
  **How packed are buses in Östergötland?**
17
 
18
- Drop a pin on the map, pick a time, and find out how crowded buses typically are in that area. Built with ML using historical transit data from Östgötatrafiken.
 
 
 
 
 
 
 
19
 
20
  ## How it works
21
 
22
- This tool predicts typical bus crowding levels based on:
23
  - **Location** - Different areas have different ridership patterns
24
  - **Time** - Rush hours vs. off-peak
25
  - **Day of week** - Weekdays vs. weekends
@@ -28,22 +35,16 @@ This tool predicts typical bus crowding levels based on:
28
 
29
  ## Data sources
30
 
31
- - Historical bus occupancy data from Östgötatrafiken (GTFS-RT, Nov-Dec 2025)
32
  - Weather forecasts from [Open-Meteo](https://open-meteo.com/)
33
  - Swedish holiday calendar from [Svenska Dagar API](https://sholiday.faboul.se/)
34
 
35
- ## Limitations
36
-
37
- - Predictions are based on historical patterns, not real-time data
38
- - Accuracy varies by location and time
39
- - The model predicts general area crowding, not specific bus lines
40
-
41
  ## Technical details
42
 
43
- - **Model**: XGBoost Classifier trained on ~6M trip records
44
  - **Features**: Location, time, weather, holidays
45
  - **Feature Store**: Hopsworks
46
- - **Framework**: Gradio
47
 
48
  ## Credits
49
 
 
3
  emoji: 🐟
4
  colorFrom: blue
5
  colorTo: blue
6
+ sdk: streamlit
7
+ sdk_version: 1.28.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
15
 
16
  **How packed are buses in Östergötland?**
17
 
18
+ Click on the map to select a location, pick a time, and see predicted crowding levels. Toggle the heat map to see crowding patterns across the entire region.
19
+
20
+ ## Features
21
+
22
+ - 🗺️ **Interactive map** - Click to select any location
23
+ - 🔥 **Heat map overlay** - See predicted crowding across the region
24
+ - 🌡️ **Real-time weather** - Forecasts from Open-Meteo
25
+ - 📅 **Holiday awareness** - Swedish red days and work-free days
26
 
27
  ## How it works
28
 
29
+ This tool predicts bus crowding levels based on:
30
  - **Location** - Different areas have different ridership patterns
31
  - **Time** - Rush hours vs. off-peak
32
  - **Day of week** - Weekdays vs. weekends
 
35
 
36
  ## Data sources
37
 
38
+ - Bus occupancy data from Östgötatrafiken (GTFS-RT)
39
  - Weather forecasts from [Open-Meteo](https://open-meteo.com/)
40
  - Swedish holiday calendar from [Svenska Dagar API](https://sholiday.faboul.se/)
41
 
 
 
 
 
 
 
42
  ## Technical details
43
 
44
+ - **Model**: XGBoost Classifier
45
  - **Features**: Location, time, weather, holidays
46
  - **Feature Store**: Hopsworks
47
+ - **Framework**: Streamlit + Folium
48
 
49
  ## Credits
50
 
app.py CHANGED
@@ -1,42 +1,35 @@
1
  """
2
- HappySardines - Bus Occupancy Predictor UI
3
 
4
- A Gradio app that predicts how crowded buses are in Östergötland based on
5
- location, time, weather, and holidays.
6
  """
7
 
8
  import os
9
- import gradio as gr
10
  import folium
 
 
 
11
  from datetime import datetime, timedelta
12
 
13
  # Import prediction and data fetching modules
14
- from predictor import predict_occupancy, predict_occupancy_mock, OCCUPANCY_LABELS
15
  from weather import get_weather_for_prediction
16
  from holidays import get_holiday_features
17
 
18
- # Try to load model on startup, fall back to mock
19
- USE_MOCK = os.environ.get("USE_MOCK", "false").lower() == "true"
 
 
 
 
20
 
21
- if not USE_MOCK:
22
- try:
23
- from predictor import load_model
24
- load_model()
25
- print("Model loaded successfully - using real predictions")
26
- except Exception as e:
27
- print(f"Could not load model: {e}")
28
- print("Using mock predictions for testing")
29
- USE_MOCK = True
30
-
31
- # Select predictor function
32
- _predict_fn = predict_occupancy_mock if USE_MOCK else predict_occupancy
33
-
34
- # Default map center: Linköping
35
  DEFAULT_LAT = 58.4108
36
  DEFAULT_LON = 15.6214
37
- DEFAULT_ZOOM = 11
38
 
39
- # Östergötland bounds (roughly)
40
  BOUNDS = {
41
  "min_lat": 57.8,
42
  "max_lat": 58.9,
@@ -44,335 +37,299 @@ BOUNDS = {
44
  "max_lon": 16.8
45
  }
46
 
47
- # Preset locations for quick selection
48
- PRESET_LOCATIONS = {
49
- "Linköping Central": (58.4158, 15.6253),
50
- "Norrköping Central": (58.5942, 16.1826),
51
- "Linköping University": (58.3980, 15.5762),
52
- "Mjärdevi Science Park": (58.4027, 15.5672),
53
- "Motala": (58.5375, 15.0364),
54
- "Finspång": (58.7050, 15.7700),
 
55
  }
56
 
57
 
58
- def create_map(lat=DEFAULT_LAT, lon=DEFAULT_LON):
59
- """Create a Folium map centered on location with marker."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  m = folium.Map(
61
- location=[lat, lon],
62
  zoom_start=DEFAULT_ZOOM,
63
  tiles="CartoDB positron"
64
  )
65
 
66
- # Add marker at selected location
67
- folium.Marker(
68
- [lat, lon],
69
- popup=f"Selected: {lat:.4f}, {lon:.4f}",
70
- icon=folium.Icon(color="blue", icon="bus", prefix="fa")
71
- ).add_to(m)
72
-
73
- # Add a rectangle showing the coverage area
74
  folium.Rectangle(
75
  bounds=[[BOUNDS["min_lat"], BOUNDS["min_lon"]],
76
  [BOUNDS["max_lat"], BOUNDS["max_lon"]]],
77
  color="#3388ff",
78
  fill=False,
79
- weight=1,
80
- opacity=0.3,
81
  popup="Coverage area"
82
  ).add_to(m)
83
 
84
- return m._repr_html_()
85
-
86
-
87
- def make_prediction(lat, lon, date_choice, hour):
88
- """
89
- Make occupancy prediction for given inputs.
90
-
91
- Returns formatted result HTML.
92
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  if lat is None or lon is None:
94
- return create_result_card(
95
- "Please select a location",
96
- "Use the preset buttons or enter coordinates.",
97
- "gray",
98
- None
99
- )
100
 
101
- # Validate coordinates are in Östergötland
102
  if not (BOUNDS["min_lat"] <= lat <= BOUNDS["max_lat"] and
103
  BOUNDS["min_lon"] <= lon <= BOUNDS["max_lon"]):
104
- return create_result_card(
105
- "Location outside coverage area",
106
- f"Please select a location within Östergötland. Selected: {lat:.4f}, {lon:.4f}",
107
- "gray",
108
- None
109
- )
110
-
111
- # Determine date
112
- today = datetime.now().date()
113
- if date_choice == "Today":
114
- selected_date = today
115
- else: # Tomorrow
116
- selected_date = today + timedelta(days=1)
117
-
118
- selected_datetime = datetime.combine(selected_date, datetime.min.time().replace(hour=int(hour)))
119
 
120
  try:
121
- # Get weather forecast
122
  weather = get_weather_for_prediction(lat, lon, selected_datetime)
123
-
124
- # Get holiday features
125
  holidays = get_holiday_features(selected_datetime)
126
 
127
- # Make prediction
128
- prediction, confidence, probabilities = _predict_fn(
129
- lat=lat,
130
- lon=lon,
131
- hour=int(hour),
132
- day_of_week=selected_date.weekday(),
133
  weather=weather,
134
  holidays=holidays
135
  )
136
 
137
- # Format result
138
- label_info = OCCUPANCY_LABELS[prediction]
 
 
 
 
 
139
 
140
- # Build context string
141
- day_name = selected_date.strftime("%A")
142
- day_type = "Holiday" if holidays.get("is_red_day") else ("Work-free day" if holidays.get("is_work_free") else "Regular day")
143
- temp = weather.get("temperature_2m", "?")
144
 
145
- context = f"{temp:.0f}°C • {day_name} • {day_type}"
 
 
 
 
146
 
147
- return create_result_card(
148
- label_info["label"],
149
- label_info["message"],
150
- label_info["color"],
151
- context,
152
- confidence
153
- )
154
 
155
- except Exception as e:
156
- return create_result_card(
157
- "Prediction failed",
158
- f"Error: {str(e)}",
159
- "gray",
160
- None
161
- )
162
 
 
 
 
163
 
164
- def create_result_card(title, message, color, context, confidence=None):
165
- """Create HTML result card."""
166
- color_map = {
167
- "green": "#22c55e",
168
- "yellow": "#eab308",
169
- "orange": "#f97316",
170
- "red": "#ef4444",
171
- "gray": "#6b7280"
172
- }
173
- bg_color = color_map.get(color, "#6b7280")
174
-
175
- confidence_html = ""
176
- if confidence is not None:
177
- confidence_html = f'<div style="font-size: 0.9em; opacity: 0.8;">Confidence: {confidence:.0%}</div>'
178
-
179
- context_html = ""
180
- if context:
181
- context_html = f'<div style="margin-top: 15px; font-size: 0.9em; opacity: 0.7;">{context}</div>'
182
-
183
- return f"""
184
- <div style="
185
- background: linear-gradient(135deg, {bg_color}22, {bg_color}11);
186
- border-left: 4px solid {bg_color};
187
- border-radius: 12px;
188
- padding: 24px;
189
- margin: 10px 0;
190
- ">
191
- <div style="
192
- font-size: 1.4em;
193
- font-weight: 600;
194
- color: {bg_color};
195
- margin-bottom: 8px;
196
- ">{title}</div>
197
- <div style="
198
- font-size: 1.1em;
199
- color: #374151;
200
- line-height: 1.5;
201
- ">{message}</div>
202
- {confidence_html}
203
- {context_html}
204
- </div>
205
- """
206
-
207
-
208
- # Custom CSS
209
- CUSTOM_CSS = """
210
- .main-title {
211
- text-align: center;
212
- margin-bottom: 0;
213
- }
214
- .subtitle {
215
- text-align: center;
216
- color: #6b7280;
217
- margin-top: 5px;
218
- margin-bottom: 20px;
219
- }
220
- .location-btn {
221
- margin: 2px !important;
222
- }
223
- """
224
 
225
- # Build Gradio interface
226
- with gr.Blocks(
227
- title="HappySardines",
228
- theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),
229
- css=CUSTOM_CSS
230
- ) as app:
231
-
232
- # Header
233
- gr.Markdown("# 🐟 HappySardines", elem_classes=["main-title"])
234
- gr.Markdown("*How packed are buses in Östergötland?*", elem_classes=["subtitle"])
235
-
236
- with gr.Row():
237
- # Left column: Map and location
238
- with gr.Column(scale=2):
239
- gr.Markdown("### Select Location")
240
-
241
- # Quick location buttons
242
- gr.Markdown("**Quick select:**")
243
- with gr.Row():
244
- location_buttons = []
245
- for name in list(PRESET_LOCATIONS.keys())[:3]:
246
- btn = gr.Button(name, size="sm", elem_classes=["location-btn"])
247
- location_buttons.append((name, btn))
248
- with gr.Row():
249
- for name in list(PRESET_LOCATIONS.keys())[3:]:
250
- btn = gr.Button(name, size="sm", elem_classes=["location-btn"])
251
- location_buttons.append((name, btn))
252
-
253
- # Coordinate inputs
254
- with gr.Row():
255
- lat_input = gr.Number(
256
- label="Latitude",
257
- value=DEFAULT_LAT,
258
- precision=4,
259
- minimum=BOUNDS["min_lat"],
260
- maximum=BOUNDS["max_lat"]
261
- )
262
- lon_input = gr.Number(
263
- label="Longitude",
264
- value=DEFAULT_LON,
265
- precision=4,
266
- minimum=BOUNDS["min_lon"],
267
- maximum=BOUNDS["max_lon"]
268
- )
269
 
270
- # Map display
271
- map_display = gr.HTML(value=create_map())
272
 
273
- # Right column: Time and predict
274
- with gr.Column(scale=1):
275
- gr.Markdown("### When?")
276
 
277
- date_choice = gr.Radio(
278
- choices=["Today", "Tomorrow"],
279
- value="Today",
280
- label="Date"
281
- )
282
 
283
- hour_slider = gr.Slider(
284
- minimum=5,
285
- maximum=23,
286
- value=8,
287
- step=1,
288
- label="Hour",
289
- info="Select time of day (24h format)"
290
- )
291
 
292
- time_display = gr.Markdown("**Selected: 08:00**")
 
 
 
 
 
 
293
 
294
- predict_btn = gr.Button("Predict Crowding", variant="primary", size="lg")
295
 
296
- # Result section
297
- gr.Markdown("### Prediction")
298
- result_display = gr.HTML(
299
- value=create_result_card(
300
- "Select location and time",
301
- "Then click 'Predict Crowding' to see the forecast.",
302
- "gray",
303
- None
304
- )
305
- )
306
-
307
- # About section
308
- with gr.Accordion("About this tool", open=False):
309
- gr.Markdown("""
310
  **How it works:**
311
 
312
- This tool predicts typical bus crowding levels based on:
313
- - **Location** - Different areas have different ridership patterns
314
- - **Time** - Rush hours vs. off-peak
315
- - **Day of week** - Weekdays vs. weekends
316
- - **Weather** - Temperature, precipitation, etc.
317
- - **Holidays** - Swedish red days and work-free days
318
 
319
  **Data sources:**
320
- - Historical bus occupancy data from Östgötatrafiken (GTFS-RT, Nov-Dec 2025)
321
- - Weather forecasts from Open-Meteo
322
- - Swedish holiday calendar from Svenska Dagar API
323
-
324
- **Limitations:**
325
- - Accuracy varies by location and time
326
- - The model predicts general area crowding, not specific bus lines
327
 
328
- **Built for KTH ID2223 - Scalable Machine Learning and Deep Learning**
329
  """)
330
 
331
- # Event handlers
332
- def update_time_display(hour):
333
- return f"**Selected: {int(hour):02d}:00**"
334
 
335
- def update_location(name):
336
- lat, lon = PRESET_LOCATIONS[name]
337
- return lat, lon, create_map(lat, lon)
338
 
339
- def update_map_from_coords(lat, lon):
340
- if lat is not None and lon is not None:
341
- return create_map(lat, lon)
342
- return create_map()
343
 
344
- hour_slider.change(
345
- fn=update_time_display,
346
- inputs=[hour_slider],
347
- outputs=[time_display]
 
 
348
  )
349
 
350
- # Connect location buttons
351
- for name, btn in location_buttons:
352
- btn.click(
353
- fn=lambda n=name: update_location(n),
354
- outputs=[lat_input, lon_input, map_display]
355
- )
356
-
357
- # Update map when coordinates change
358
- lat_input.change(
359
- fn=update_map_from_coords,
360
- inputs=[lat_input, lon_input],
361
- outputs=[map_display]
362
- )
363
- lon_input.change(
364
- fn=update_map_from_coords,
365
- inputs=[lat_input, lon_input],
366
- outputs=[map_display]
367
  )
368
 
369
- predict_btn.click(
370
- fn=make_prediction,
371
- inputs=[lat_input, lon_input, date_choice, hour_slider],
372
- outputs=[result_display]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  )
374
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
 
376
- # For local testing
377
- if __name__ == "__main__":
378
- app.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ HappySardines - Bus Occupancy Predictor UI (Streamlit version)
3
 
4
+ A Streamlit app with clickable map and heat map overlay for predicting
5
+ bus crowding in Östergötland.
6
  """
7
 
8
  import os
9
+ import streamlit as st
10
  import folium
11
+ from folium.plugins import HeatMap
12
+ from streamlit_folium import st_folium
13
+ import numpy as np
14
  from datetime import datetime, timedelta
15
 
16
  # Import prediction and data fetching modules
17
+ from predictor import predict_occupancy, load_model, OCCUPANCY_LABELS
18
  from weather import get_weather_for_prediction
19
  from holidays import get_holiday_features
20
 
21
+ # Page config
22
+ st.set_page_config(
23
+ page_title="HappySardines",
24
+ page_icon="🐟",
25
+ layout="wide"
26
+ )
27
 
28
+ # Constants
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  DEFAULT_LAT = 58.4108
30
  DEFAULT_LON = 15.6214
31
+ DEFAULT_ZOOM = 10
32
 
 
33
  BOUNDS = {
34
  "min_lat": 57.8,
35
  "max_lat": 58.9,
 
37
  "max_lon": 16.8
38
  }
39
 
40
+ # Color scheme for occupancy levels
41
+ OCCUPANCY_COLORS = {
42
+ 0: "#22c55e", # Empty - green
43
+ 1: "#22c55e", # Many seats - green
44
+ 2: "#eab308", # Few seats - yellow
45
+ 3: "#f97316", # Standing - orange
46
+ 4: "#ef4444", # Crushed - red
47
+ 5: "#ef4444", # Full - red
48
+ 6: "#6b7280", # Not accepting - gray
49
  }
50
 
51
 
52
+ @st.cache_resource
53
+ def get_model():
54
+ """Load model once and cache it."""
55
+ try:
56
+ return load_model()
57
+ except Exception as e:
58
+ st.error(f"Failed to load model: {e}")
59
+ return None
60
+
61
+
62
+ def generate_heatmap_data(hour, day_of_week, weather, holidays):
63
+ """Generate heat map data by predicting crowding across a grid."""
64
+ model = get_model()
65
+ if model is None:
66
+ return []
67
+
68
+ # Create grid of points across Östergötland
69
+ lat_steps = 15
70
+ lon_steps = 20
71
+ lats = np.linspace(BOUNDS["min_lat"], BOUNDS["max_lat"], lat_steps)
72
+ lons = np.linspace(BOUNDS["min_lon"], BOUNDS["max_lon"], lon_steps)
73
+
74
+ heatmap_data = []
75
+
76
+ for lat in lats:
77
+ for lon in lons:
78
+ try:
79
+ pred_class, confidence, _ = predict_occupancy(
80
+ lat=lat, lon=lon, hour=hour, day_of_week=day_of_week,
81
+ weather=weather, holidays=holidays
82
+ )
83
+ # Weight by occupancy level (higher = more crowded = more intense)
84
+ intensity = pred_class / 5.0 # Normalize to 0-1
85
+ if intensity > 0.1: # Only show if there's some crowding
86
+ heatmap_data.append([lat, lon, intensity])
87
+ except Exception:
88
+ pass
89
+
90
+ return heatmap_data
91
+
92
+
93
+ def create_map(selected_lat=None, selected_lon=None, show_heatmap=False,
94
+ heatmap_data=None):
95
+ """Create a Folium map with optional marker and heatmap."""
96
+ center_lat = selected_lat if selected_lat else DEFAULT_LAT
97
+ center_lon = selected_lon if selected_lon else DEFAULT_LON
98
+
99
  m = folium.Map(
100
+ location=[center_lat, center_lon],
101
  zoom_start=DEFAULT_ZOOM,
102
  tiles="CartoDB positron"
103
  )
104
 
105
+ # Add coverage area rectangle
 
 
 
 
 
 
 
106
  folium.Rectangle(
107
  bounds=[[BOUNDS["min_lat"], BOUNDS["min_lon"]],
108
  [BOUNDS["max_lat"], BOUNDS["max_lon"]]],
109
  color="#3388ff",
110
  fill=False,
111
+ weight=2,
112
+ opacity=0.5,
113
  popup="Coverage area"
114
  ).add_to(m)
115
 
116
+ # Add heatmap if enabled
117
+ if show_heatmap and heatmap_data:
118
+ HeatMap(
119
+ heatmap_data,
120
+ min_opacity=0.3,
121
+ radius=25,
122
+ blur=15,
123
+ gradient={0.2: 'green', 0.4: 'yellow', 0.6: 'orange', 0.8: 'red'}
124
+ ).add_to(m)
125
+
126
+ # Add marker if location selected
127
+ if selected_lat and selected_lon:
128
+ folium.Marker(
129
+ [selected_lat, selected_lon],
130
+ popup=f"Selected: {selected_lat:.4f}, {selected_lon:.4f}",
131
+ icon=folium.Icon(color="blue", icon="bus", prefix="fa")
132
+ ).add_to(m)
133
+
134
+ return m
135
+
136
+
137
+ def make_prediction(lat, lon, selected_datetime):
138
+ """Make prediction and return formatted result."""
139
  if lat is None or lon is None:
140
+ return None, None, None
 
 
 
 
 
141
 
142
+ # Check bounds
143
  if not (BOUNDS["min_lat"] <= lat <= BOUNDS["max_lat"] and
144
  BOUNDS["min_lon"] <= lon <= BOUNDS["max_lon"]):
145
+ return None, None, "Location outside coverage area"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  try:
 
148
  weather = get_weather_for_prediction(lat, lon, selected_datetime)
 
 
149
  holidays = get_holiday_features(selected_datetime)
150
 
151
+ pred_class, confidence, probs = predict_occupancy(
152
+ lat=lat, lon=lon,
153
+ hour=selected_datetime.hour,
154
+ day_of_week=selected_datetime.weekday(),
 
 
155
  weather=weather,
156
  holidays=holidays
157
  )
158
 
159
+ return pred_class, confidence, {
160
+ "weather": weather,
161
+ "holidays": holidays,
162
+ "datetime": selected_datetime
163
+ }
164
+ except Exception as e:
165
+ return None, None, str(e)
166
 
 
 
 
 
167
 
168
+ # Initialize session state
169
+ if "selected_lat" not in st.session_state:
170
+ st.session_state.selected_lat = DEFAULT_LAT
171
+ if "selected_lon" not in st.session_state:
172
+ st.session_state.selected_lon = DEFAULT_LON
173
 
174
+ # Header
175
+ st.title("🐟 HappySardines")
176
+ st.markdown("*How packed are buses in Östergötland?*")
 
 
 
 
177
 
178
+ # Check if model is available
179
+ model = get_model()
180
+ if model is None:
181
+ st.error("⚠️ Could not load prediction model. Please check the configuration.")
182
+ st.stop()
 
 
183
 
184
+ # Sidebar controls
185
+ with st.sidebar:
186
+ st.header("Settings")
187
 
188
+ # Date/time selection
189
+ st.subheader("When?")
190
+ date_option = st.radio("Date", ["Today", "Tomorrow"], horizontal=True)
191
+ hour = st.slider("Hour", 5, 23, 8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
+ today = datetime.now().date()
194
+ selected_date = today if date_option == "Today" else today + timedelta(days=1)
195
+ selected_datetime = datetime.combine(selected_date, datetime.min.time().replace(hour=hour))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
+ st.markdown(f"**{selected_datetime.strftime('%A, %B %d at %H:00')}**")
 
198
 
199
+ st.divider()
 
 
200
 
201
+ # View mode
202
+ st.subheader("View Mode")
203
+ show_heatmap = st.toggle("Show Crowding Forecast", value=False,
204
+ help="Display predicted crowding across the region")
 
205
 
206
+ if show_heatmap:
207
+ st.info("🔥 Heat map shows predicted crowding levels. Red = busy, Green = quiet.")
 
 
 
 
 
 
208
 
209
+ if st.button("Generate Heat Map", type="primary"):
210
+ with st.spinner("Generating predictions across region..."):
211
+ weather = get_weather_for_prediction(DEFAULT_LAT, DEFAULT_LON, selected_datetime)
212
+ holidays = get_holiday_features(selected_datetime)
213
+ st.session_state.heatmap_data = generate_heatmap_data(
214
+ hour, selected_date.weekday(), weather, holidays
215
+ )
216
 
217
+ st.divider()
218
 
219
+ # About
220
+ with st.expander("About this tool"):
221
+ st.markdown("""
 
 
 
 
 
 
 
 
 
 
 
222
  **How it works:**
223
 
224
+ This tool predicts bus crowding levels based on:
225
+ - 📍 Location
226
+ - 🕐 Time of day
227
+ - 📅 Day of week
228
+ - 🌡️ Weather conditions
229
+ - 🎉 Holidays
230
 
231
  **Data sources:**
232
+ - Bus occupancy data from Östgötatrafiken
233
+ - Weather from Open-Meteo
234
+ - Holidays from Svenska Dagar API
 
 
 
 
235
 
236
+ **Built for KTH ID2223**
237
  """)
238
 
239
+ # Main content
240
+ col1, col2 = st.columns([2, 1])
 
241
 
242
+ with col1:
243
+ st.subheader("📍 Click on the map to select a location")
 
244
 
245
+ # Get heatmap data if available
246
+ heatmap_data = st.session_state.get("heatmap_data", [])
 
 
247
 
248
+ # Create and display map
249
+ m = create_map(
250
+ selected_lat=st.session_state.selected_lat,
251
+ selected_lon=st.session_state.selected_lon,
252
+ show_heatmap=show_heatmap,
253
+ heatmap_data=heatmap_data
254
  )
255
 
256
+ map_data = st_folium(
257
+ m,
258
+ height=500,
259
+ width=None,
260
+ returned_objects=["last_clicked"],
261
+ key="map"
 
 
 
 
 
 
 
 
 
 
 
262
  )
263
 
264
+ # Handle map clicks
265
+ if map_data and map_data.get("last_clicked"):
266
+ clicked = map_data["last_clicked"]
267
+ st.session_state.selected_lat = clicked["lat"]
268
+ st.session_state.selected_lon = clicked["lng"]
269
+ st.rerun()
270
+
271
+ with col2:
272
+ st.subheader("🔮 Prediction")
273
+
274
+ # Show selected coordinates
275
+ st.markdown(f"**Location:** {st.session_state.selected_lat:.4f}, {st.session_state.selected_lon:.4f}")
276
+
277
+ # Make prediction
278
+ pred_class, confidence, result = make_prediction(
279
+ st.session_state.selected_lat,
280
+ st.session_state.selected_lon,
281
+ selected_datetime
282
  )
283
 
284
+ if pred_class is not None:
285
+ label_info = OCCUPANCY_LABELS[pred_class]
286
+ color = OCCUPANCY_COLORS[pred_class]
287
+
288
+ # Result card
289
+ st.markdown(f"""
290
+ <div style="
291
+ background: linear-gradient(135deg, {color}22, {color}11);
292
+ border-left: 4px solid {color};
293
+ border-radius: 12px;
294
+ padding: 20px;
295
+ margin: 10px 0;
296
+ ">
297
+ <div style="font-size: 1.3em; font-weight: 600; color: {color};">
298
+ {label_info['icon']} {label_info['label']}
299
+ </div>
300
+ <div style="margin-top: 8px; color: #374151;">
301
+ {label_info['message']}
302
+ </div>
303
+ <div style="margin-top: 12px; font-size: 0.9em; opacity: 0.8;">
304
+ Confidence: {confidence:.0%}
305
+ </div>
306
+ </div>
307
+ """, unsafe_allow_html=True)
308
+
309
+ # Context info
310
+ if isinstance(result, dict):
311
+ weather = result["weather"]
312
+ holidays = result["holidays"]
313
+
314
+ day_type = "🎉 Holiday" if holidays.get("is_red_day") else (
315
+ "🏖️ Work-free day" if holidays.get("is_work_free") else "📅 Regular day"
316
+ )
317
 
318
+ st.markdown(f"""
319
+ **Conditions:**
320
+ - 🌡️ {weather.get('temperature_2m', '?'):.0f}°C
321
+ - {day_type}
322
+ - {selected_datetime.strftime('%A')}
323
+ """)
324
+
325
+ elif isinstance(result, str):
326
+ st.error(result)
327
+ else:
328
+ st.info("Click on the map to select a location")
329
+
330
+ # Footer
331
+ st.divider()
332
+ st.markdown(
333
+ "<div style='text-align: center; opacity: 0.6;'>Built for KTH ID2223 - Scalable Machine Learning</div>",
334
+ unsafe_allow_html=True
335
+ )
predictor.py CHANGED
@@ -102,7 +102,7 @@ def load_model():
102
  # Check for API key before attempting connection
103
  api_key = os.environ.get("HOPSWORKS_API_KEY")
104
  if not api_key:
105
- raise ValueError("HOPSWORKS_API_KEY not set - using mock predictions")
106
 
107
  try:
108
  import hopsworks
@@ -189,39 +189,3 @@ def predict_occupancy(lat, lon, hour, day_of_week, weather, holidays):
189
  confidence = float(probabilities[predicted_class])
190
 
191
  return predicted_class, confidence, probabilities.tolist()
192
-
193
-
194
- # Mock prediction for testing without Hopsworks
195
- def predict_occupancy_mock(lat, lon, hour, day_of_week, weather, holidays):
196
- """
197
- Mock prediction for testing UI without model.
198
- """
199
- # Simple heuristic based on time
200
- if 7 <= hour <= 9 or 16 <= hour <= 18:
201
- # Rush hour
202
- if holidays.get("is_work_free") or holidays.get("is_red_day"):
203
- predicted_class = 1 # Holiday rush hour = many seats
204
- else:
205
- predicted_class = 2 if hour < 8 or hour > 17 else 3 # Peak = standing
206
- elif 10 <= hour <= 15:
207
- predicted_class = 1 # Midday = many seats
208
- else:
209
- predicted_class = 0 # Early/late = empty
210
-
211
- # Mock probabilities
212
- probabilities = [0.1] * 7
213
- probabilities[predicted_class] = 0.6
214
- confidence = 0.6
215
-
216
- return predicted_class, confidence, probabilities
217
-
218
-
219
- # For testing - use mock if model not available
220
- def get_predictor():
221
- """Get the appropriate predictor function."""
222
- try:
223
- load_model()
224
- return predict_occupancy
225
- except Exception as e:
226
- print(f"Using mock predictor: {e}")
227
- return predict_occupancy_mock
 
102
  # Check for API key before attempting connection
103
  api_key = os.environ.get("HOPSWORKS_API_KEY")
104
  if not api_key:
105
+ raise ValueError("HOPSWORKS_API_KEY environment variable not set. Please add it in Space settings.")
106
 
107
  try:
108
  import hopsworks
 
189
  confidence = float(probabilities[predicted_class])
190
 
191
  return predicted_class, confidence, probabilities.tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
- gradio==4.44.0
2
- huggingface_hub>=0.24.0,<1.0.0
 
3
  hopsworks==4.2.*
4
  xgboost>=2.0.0
5
  scikit-learn
@@ -7,4 +8,3 @@ pandas
7
  numpy
8
  requests
9
  python-dotenv
10
- folium>=0.15.0
 
1
+ streamlit>=1.28.0
2
+ streamlit-folium>=0.15.0
3
+ folium>=0.15.0
4
  hopsworks==4.2.*
5
  xgboost>=2.0.0
6
  scikit-learn
 
8
  numpy
9
  requests
10
  python-dotenv