Spaces:
Running
Running
| from datetime import datetime # type: ignore | |
| import sys # type: ignore | |
| from pathlib import Path | |
| # Add parent directory to path to allow imports from api/ | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| # Add src directory to path to allow imports from same directory | |
| sys.path.insert(0, str(Path(__file__).parent)) | |
| from api.bus_cache import AsyncBusCache # type: ignore | |
| from api.utils import hms_to_seconds, get_service_day_start_ts, translate_occupancy # type: ignore | |
| from db_manager import init_db # type: ignore | |
| from dotenv import load_dotenv # type: ignore | |
| from fastapi import FastAPI, HTTPException # type: ignore | |
| from fastapi.middleware.cors import CORSMiddleware # type: ignore | |
# Pull settings from a local .env file before anything below (e.g. init_db)
# gets a chance to read the environment.
load_dotenv()

# Shared live-feed cache used by every endpoint; ttl=20 is handed straight to
# AsyncBusCache (presumably seconds — confirm against api.bus_cache).
ttc_cache = AsyncBusCache(ttl=20)

# One module-wide database handle reused by all request handlers.
db = init_db()

app = FastAPI(title="WheresMyBus v2.0 API")

# CORS is fully open so the React frontend can reach the API from any origin.
# NOTE: narrow allow_origins to the deployed frontend URL in production.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
async def health_check():
    """Report that the service process is up.

    Returns:
        The constant string ``"backend is running"``.
    """
    status_message = "backend is running"
    return status_message
async def get_vehicles():
    """Return every vehicle currently held in the live-feed cache.

    Returns:
        A dict with ``status`` (always ``"success"``), ``count`` (number of
        vehicles) and ``vehicles`` (the cached list, or ``[]`` when the cache
        snapshot has no ``"vehicles"`` entry).
    """
    snapshot = await ttc_cache.get_data()
    fleet = snapshot.get("vehicles", [])
    response = {"status": "success"}
    response["count"] = len(fleet)
    response["vehicles"] = fleet
    return response
async def get_all_routes():
    """
    Returns a complete list of TTC routes with their display names and colors.

    Routes whose short name is purely numeric come first, in numeric order,
    followed by every other route ordered by short name. Missing colors
    (NULL or empty string) fall back to FF0000 (background) and FFFFFF (text).

    Returns:
        On success: ``{"status": "success", "count": int, "routes": [...]}``,
        where each route is ``{"id", "number", "name", "color", "text_color"}``
        and both colors are ``#``-prefixed hex strings.
        On a database error: ``{"status": "error", "message": str}``.
    """
    # NULLIF maps empty-string colors to NULL so COALESCE can apply the default;
    # otherwise an empty GTFS field would be returned as a bare "#".
    # The ORDER BY uses a numeric/non-numeric flag as its primary key (instead
    # of a 999 sentinel) so numeric routes >= 999 never interleave with
    # non-numeric ones. The second CASE yields NULL for every non-numeric row,
    # so NULL ordering does not matter within that group.
    query = """
        SELECT
            route_id,
            route_short_name,
            route_long_name,
            COALESCE(NULLIF(CAST(route_color AS VARCHAR), ''), 'FF0000') AS route_color,
            COALESCE(NULLIF(CAST(route_text_color AS VARCHAR), ''), 'FFFFFF') AS route_text_color
        FROM routes
        ORDER BY
            CASE WHEN CAST(route_short_name AS VARCHAR) ~ '^[0-9]+$' THEN 0 ELSE 1 END,
            CASE
                WHEN CAST(route_short_name AS VARCHAR) ~ '^[0-9]+$'
                THEN CAST(route_short_name AS INTEGER)
            END,
            route_short_name;
    """
    try:
        results = db.execute(query).fetchall()
    except Exception as e:
        # Same error payload as before; only the DB call can realistically fail.
        return {"status": "error", "message": str(e)}

    route_list = [
        {
            "id": r[0],
            "number": r[1],
            "name": r[2],
            "color": f"#{r[3]}",
            "text_color": f"#{r[4]}",
        }
        for r in results
    ]
    return {
        "status": "success",
        "count": len(route_list),
        "routes": route_list,
    }
async def get_route_view(route_id: str):
    """Return the live vehicles on one route, enriched with schedule data.

    Each vehicle gets its destination (the trip headsign from the database)
    and a delay in whole minutes, computed as the predicted arrival at its next
    stop minus the scheduled arrival there.

    Args:
        route_id: Route identifier, compared as a string against the live
            feed's ``route`` field.

    Returns:
        ``{"route": route_id, "count": int, "vehicles": [...]}``. Each vehicle
        is ``{"number", "name", "location": {"lat", "lon"}, "delay_mins",
        "fullness"}``. ``delay_mins`` is positive when late, negative when
        early, and 0 when no prediction or scheduled time is available.
    """
    data = await ttc_cache.get_data()
    all_buses = data.get("vehicles", [])
    # Compare as strings, as get_vehicle_view does for vehicle ids, so a feed
    # that stores route ids as ints still matches the string path parameter.
    route_buses = [v for v in all_buses if str(v.get("route")) == route_id]
    if not route_buses:
        # Same shape as the populated response so clients can always read "count".
        return {"route": route_id, "count": 0, "vehicles": []}

    # Trip ids are cast to str so they compare equal to the VARCHAR-cast column.
    trip_ids = [str(v['trip_id']) for v in route_buses]
    # Only '?' placeholders are interpolated; the ids themselves are bound.
    placeholders = ','.join(['?'] * len(trip_ids))
    query = f"""
        SELECT
            CAST(st.trip_id AS VARCHAR),
            CAST(st.stop_id AS VARCHAR),
            st.arrival_time,
            t.trip_headsign
        FROM stop_times st
        JOIN trips t ON CAST(st.trip_id AS VARCHAR) = CAST(t.trip_id AS VARCHAR)
        WHERE CAST(st.trip_id AS VARCHAR) IN ({placeholders})
    """
    db_rows = db.execute(query, trip_ids).fetchall()
    if not db_rows:
        print(f"DEBUG: No matches in DB for Trip IDs: {trip_ids[:3]}")

    # (trip_id, stop_id) -> scheduled arrival "HH:MM:SS"; trip_id -> headsign.
    schedule_map = {(r[0], r[1]): r[2] for r in db_rows}
    name_map = {r[0]: r[3] for r in db_rows}
    service_day_ts = get_service_day_start_ts()

    enriched = []
    for bus in route_buses:
        trip_key = str(bus['trip_id'])
        delay_mins = 0
        pred_time = bus.get('predicted_time')
        stop_id = bus.get('next_stop_id')
        if pred_time and stop_id:
            sched_hms = schedule_map.get((trip_key, str(stop_id)))
            if sched_hms:
                # (predicted Unix time - scheduled Unix time) in minutes.
                plan_ts = service_day_ts + hms_to_seconds(sched_hms)
                delay_mins = round((pred_time - plan_ts) / 60)
        enriched.append({
            "number": bus['id'],
            "name": name_map.get(trip_key, "Not in Schedule"),  # destination
            "location": {"lat": bus['lat'], "lon": bus['lon']},
            "delay_mins": delay_mins,  # 5 = 5m late, -2 = 2m early
            "fullness": translate_occupancy(bus['occupancy']),
        })
    return {
        "route": route_id,
        "count": len(enriched),
        "vehicles": enriched,
    }
async def get_vehicle_view(vehicle_id: str):
    """Return live details for a single vehicle, enriched with schedule data.

    Args:
        vehicle_id: Vehicle number, compared as a string against the feed's
            ``id`` field.

    Returns:
        ``{"vehicle_number", "route_id", "name", "location": {"lat", "lon"},
        "delay_mins", "fullness", "trip_id"}``. ``name`` is the trip headsign,
        or ``"Not in Schedule"`` when the trip is unknown to the database.
        ``delay_mins`` is 0 unless both a prediction and a scheduled time for
        the next stop are available.

    Raises:
        HTTPException: 404 when the vehicle is not in the live-feed cache.
    """
    data = await ttc_cache.get_data()
    vehicles = data.get("vehicles", [])
    bus = next((v for v in vehicles if str(v['id']) == vehicle_id), None)
    if bus is None:
        raise HTTPException(status_code=404, detail="Vehicle not active or not found")

    trip_id = str(bus['trip_id'])
    next_stop_id = bus.get('next_stop_id')
    predicted_time = bus.get('predicted_time')

    destination = "Not in Schedule"
    delay_mins = 0

    # Headsign plus scheduled arrival at the vehicle's next stop.
    stop_row = None
    if next_stop_id:
        stop_query = """
            SELECT
                t.trip_headsign,
                st.arrival_time
            FROM trips t
            JOIN stop_times st ON CAST(t.trip_id AS VARCHAR) = CAST(st.trip_id AS VARCHAR)
            WHERE CAST(t.trip_id AS VARCHAR) = ?
              AND CAST(st.stop_id AS VARCHAR) = ?
            LIMIT 1
        """
        stop_row = db.execute(stop_query, [trip_id, str(next_stop_id)]).fetchone()

    if stop_row:
        destination = stop_row[0]
        if predicted_time:
            # (predicted Unix time - scheduled Unix time) in minutes.
            plan_ts = get_service_day_start_ts() + hms_to_seconds(stop_row[1])
            delay_mins = round((predicted_time - plan_ts) / 60)
    else:
        # No next stop, or that stop is not in this trip's schedule: still
        # resolve the destination from the trip itself instead of giving up.
        trip_query = """
            SELECT trip_headsign
            FROM trips
            WHERE CAST(trip_id AS VARCHAR) = ?
            LIMIT 1
        """
        trip_row = db.execute(trip_query, [trip_id]).fetchone()
        if trip_row:
            destination = trip_row[0]

    return {
        "vehicle_number": vehicle_id,
        "route_id": bus['route'],
        "name": destination,
        "location": {
            "lat": bus['lat'],
            "lon": bus['lon'],
        },
        "delay_mins": delay_mins,
        "fullness": translate_occupancy(bus['occupancy']),
        "trip_id": trip_id,
    }
async def get_stop_view(stop_code: str):
    """Return upcoming arrivals at a stop within the next two hours.

    Args:
        stop_code: The stop's public code, matched against ``stops.stop_code``
            as a string.

    Returns:
        ``{"stop_name", "stop_code", "arrivals": [...]}`` with arrivals sorted
        by ``eta_mins``. Each arrival is ``{"route", "destination", "eta_mins",
        "delay_mins", "fullness", "vehicle_id"}``. Only predictions between now
        and now + 7200 s that match a scheduled stop_times row are included.
        If the stop code is unknown, returns ``{"error": "Stop code not found"}``.
    """
    # 1. Translate the public stop code to the database stop_id.
    stop_info = db.execute(
        "SELECT stop_id, stop_name FROM stops WHERE CAST(stop_code AS VARCHAR) = ? LIMIT 1",
        [str(stop_code)],
    ).fetchone()
    if not stop_info:
        return {"error": "Stop code not found"}
    target_id = str(stop_info[0])
    stop_name = stop_info[1]

    # 2. Live cache: vehicles list plus per-trip itineraries of predicted times.
    cached_data = await ttc_cache.get_data()
    vehicles_list = cached_data.get("vehicles", [])
    predictions = cached_data.get("predictions", {})
    vehicles = {str(v['trip_id']): v for v in vehicles_list}

    now = datetime.now().timestamp()
    two_hours_out = now + 7200
    # Loop-invariant; computing it once also keeps every arrival in this
    # response on the same service day.
    service_day_ts = get_service_day_start_ts()

    query = """
        SELECT t.trip_headsign, st.arrival_time, r.route_short_name
        FROM trips t
        JOIN stop_times st ON CAST(t.trip_id AS VARCHAR) = CAST(st.trip_id AS VARCHAR)
        JOIN routes r ON t.route_id = r.route_id
        WHERE CAST(t.trip_id AS VARCHAR) = ? AND CAST(st.stop_id AS VARCHAR) = ?
        LIMIT 1
    """

    arrivals = []
    # 3. Scan every itinerary for the target stop.
    for trip_id, itinerary in predictions.items():
        if target_id not in itinerary:
            continue
        pred_time = itinerary[target_id]
        # Skip missing predictions, buses already past the stop, and anything
        # more than two hours out.
        if pred_time is None or not (now <= pred_time <= two_hours_out):
            continue

        # Normalize so the str-keyed vehicles map and the VARCHAR comparison
        # both match even if the cache keys trips by int.
        trip_key = str(trip_id)
        row = db.execute(query, [trip_key, target_id]).fetchone()
        if not row:
            continue

        bus = vehicles.get(trip_key)  # None when the trip has no live vehicle
        plan_ts = service_day_ts + hms_to_seconds(row[1])
        arrivals.append({
            "route": row[2],
            "destination": row[0],
            "eta_mins": round((pred_time - now) / 60),
            "delay_mins": round((pred_time - plan_ts) / 60),
            "fullness": translate_occupancy(bus['occupancy']) if bus else "Unknown",
            "vehicle_id": bus['id'] if bus else "In Transit",
        })

    arrivals.sort(key=lambda x: x['eta_mins'])
    return {"stop_name": stop_name, "stop_code": stop_code, "arrivals": arrivals}
async def get_all_alerts():
    """
    Returns every active service alert for the entire TTC network.

    Returns:
        ``{"timestamp": float, "count": int, "alerts": ...}``. ``alerts`` is the
        cache's ``"alerts"`` entry (``{}`` when the snapshot has none),
        ``count`` is its ``len()``, and ``timestamp`` is the current Unix time.
    """
    data = await ttc_cache.get_data()
    # .get() matches get_alerts_for_route; a snapshot without "alerts" used to
    # raise KeyError here.
    # NOTE(review): get_alerts_for_route indexes this by route id, so if it is a
    # route -> list mapping, "count" is the number of routes with alerts rather
    # than the number of alerts — confirm against the cache.
    alerts = data.get("alerts", {})
    return {
        "timestamp": datetime.now().timestamp(),
        "count": len(alerts),
        "alerts": alerts,
    }
async def get_alerts_for_route(route_id: str):
    """Return the cached service alerts for a single route.

    Args:
        route_id: Key looked up in the cache's ``"alerts"`` mapping.

    Returns:
        ``{"route_id", "count", "alerts"}``. When the route has no alerts,
        ``count`` is 0 and ``alerts`` is the string ``"No alerts"`` rather
        than an empty list.
    """
    snapshot = await ttc_cache.get_data()
    matching = snapshot.get("alerts", {}).get(route_id, [])
    has_alerts = bool(matching)
    return {
        "route_id": route_id,
        "count": len(matching) if has_alerts else 0,
        "alerts": matching if has_alerts else "No alerts",
    }
if __name__ == "__main__":
    # Direct-run entry point. uvicorn is imported only here, so importing this
    # module does not import uvicorn.
    import uvicorn  # type: ignore

    bind_host = "0.0.0.0"  # listen on all interfaces
    bind_port = 7860  # presumably the hosting platform's expected port — confirm
    uvicorn.run(app, host=bind_host, port=bind_port)