# WMB2Backened/src/app.py
# Uploaded by 42Cummer ("Upload app.py"), commit 00f5c2a, 13.8 kB.
from datetime import datetime # type: ignore
import sys # type: ignore
from pathlib import Path
# Add parent directory to path to allow imports from api/
sys.path.insert(0, str(Path(__file__).parent.parent))
# Add src directory to path to allow imports from same directory
sys.path.insert(0, str(Path(__file__).parent))
from api.bus_cache import AsyncBusCache # type: ignore
from api.utils import hms_to_seconds, get_service_day_start_ts, translate_occupancy # type: ignore
from db_manager import init_db # type: ignore
from dotenv import load_dotenv # type: ignore
from fastapi import FastAPI, HTTPException # type: ignore
from fastapi.middleware.cors import CORSMiddleware # type: ignore
# Load variables from a .env file (if present) before the cache and the
# database connection below are created.
load_dotenv()
# Shared realtime vehicle/prediction/alert cache used by every endpoint.
# NOTE(review): ttl=20 is presumably seconds between feed refreshes — confirm in api.bus_cache.
ttc_cache = AsyncBusCache(ttl=20)
# Initialize database connection globally
# NOTE(review): endpoints call db.execute() synchronously from async handlers,
# so each query blocks the event loop while it runs.
db = init_db()
app = FastAPI(title="WheresMyBus v2.0 API")
# Setup CORS for your React frontend
# NOTE(review): "*" accepts requests from any origin; restrict before production.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, use your actual React URL
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def health_check():
    """Liveness probe: confirm the API process is up and answering requests."""
    status_message = "backend is running"
    return status_message
@app.get("/api/vehicles")
async def get_vehicles():
    """Return every vehicle currently held in the realtime cache.

    Response shape: ``{"status": "success", "count": <int>, "vehicles": [...]}``;
    ``vehicles`` defaults to an empty list when the cache has none.
    """
    snapshot = await ttc_cache.get_data()
    fleet = snapshot.get("vehicles", [])
    response = {"status": "success"}
    response["count"] = len(fleet)
    response["vehicles"] = fleet
    return response
@app.get("/api/routes")
async def get_all_routes():
    """
    Returns a complete list of TTC routes with their display names and colors.

    Routes whose short name is all digits are ordered numerically; any other
    short name goes into sort bucket 999 and is then ordered by the name.
    Missing colors default to ``#FF0000`` (background) and ``#FFFFFF`` (text).

    On a database error the exception is logged with its traceback and the
    response is ``{"status": "error", "message": <error text>}`` (same shape
    and HTTP status as before, so existing clients are unaffected).
    """
    import logging

    # TRY_CAST rather than CAST: the regex already limits the cast to
    # digit-only names, but an all-digit value too large for INTEGER now
    # yields a NULL sort key instead of aborting the whole query.
    query = """
        SELECT
            route_id,
            route_short_name,
            route_long_name,
            COALESCE(route_color, 'FF0000') AS route_color,
            COALESCE(route_text_color, 'FFFFFF') AS route_text_color
        FROM routes
        ORDER BY
            CASE
                WHEN CAST(route_short_name AS VARCHAR) ~ '^[0-9]+$'
                    THEN TRY_CAST(route_short_name AS INTEGER)
                ELSE 999
            END,
            route_short_name;
    """
    try:
        results = db.execute(query).fetchall()
    except Exception as e:
        # Request boundary: keep the existing error payload, but make the
        # failure visible in server logs instead of silently swallowing it.
        logging.getLogger(__name__).exception("Failed to load routes")
        return {"status": "error", "message": str(e)}

    # Convert rows to JSON-friendly dicts with CSS-style hex colors.
    route_list = [
        {
            "id": route_id,
            "number": short_name,
            "name": long_name,
            "color": f"#{color}",
            "text_color": f"#{text_color}",
        }
        for route_id, short_name, long_name, color, text_color in results
    ]
    return {
        "status": "success",
        "count": len(route_list),
        "routes": route_list,
    }
@app.get("/api/routes/{route_id}")
async def get_route_view(route_id: str):
    """Return the live vehicles on one route, enriched with schedule data.

    Each vehicle entry has its number, destination (trip headsign, or
    ``"Not in Schedule"`` when the trip is not in the DB), position,
    ``delay_mins`` and an occupancy label.

    ``delay_mins`` is ``scheduled - predicted`` at the vehicle's next stop,
    rounded to whole minutes: negative = late, positive = early. It is 0 when
    the vehicle has no prediction or the stop has no scheduled time.
    """
    import logging

    data = await ttc_cache.get_data()
    all_buses = data.get("vehicles", [])
    # Path params are strings; compare as strings so numeric route ids held
    # in the cache still match (get_vehicle_view does the same for ids).
    route_buses = [v for v in all_buses if str(v['route']) == route_id]
    if not route_buses:
        return {"route": route_id, "count": 0, "vehicles": []}

    # Cast trip ids to strings so they match the VARCHAR-cast DB columns.
    trip_ids = [str(v['trip_id']) for v in route_buses]
    # Only '?' placeholders are interpolated; the values are bound by DuckDB.
    placeholders = ','.join(['?'] * len(trip_ids))
    query = f"""
        SELECT
            CAST(st.trip_id AS VARCHAR),
            CAST(st.stop_id AS VARCHAR),
            st.arrival_time AS scheduled_time,
            t.trip_headsign
        FROM stop_times st
        JOIN trips t ON CAST(st.trip_id AS VARCHAR) = CAST(t.trip_id AS VARCHAR)
        WHERE CAST(st.trip_id AS VARCHAR) IN ({placeholders})
    """
    db_rows = db.execute(query, trip_ids).fetchall()
    if not db_rows:
        logging.getLogger(__name__).warning(
            "No matches in DB for Trip IDs: %s", trip_ids[:3]
        )

    schedule_map = {(r[0], r[1]): r[2] for r in db_rows}
    name_map = {r[0]: r[3] for r in db_rows}
    service_day_ts = get_service_day_start_ts()

    enriched = []
    for bus in route_buses:
        trip_id = str(bus['trip_id'])
        delay_mins = 0  # default when there is no prediction/schedule
        pred_time = bus.get('predicted_time')
        stop_id = bus.get('next_stop_id')
        if pred_time and stop_id:
            sched_hms = schedule_map.get((trip_id, str(stop_id)))
            if sched_hms:
                # GTFS times are offsets from the start of the service day and
                # may exceed 24:00:00 for trips past midnight, so the full
                # h:m:s offset is added as-is (no separate day carry, which
                # would count the extra day twice).
                hours, minutes, seconds = map(int, sched_hms.split(':'))
                plan_ts = service_day_ts + hours * 3600 + minutes * 60 + seconds
                delay_mins = round((plan_ts - pred_time) / 60)
        enriched.append({
            "number": bus['id'],
            "name": name_map.get(trip_id, "Not in Schedule"),  # destination
            "location": {"lat": bus['lat'], "lon": bus['lon']},
            "delay_mins": delay_mins,  # e.g. -5 = 5m late, 2 = 2m early
            "fullness": translate_occupancy(bus['occupancy']),
        })
    return {
        "route": route_id,
        "count": len(enriched),
        "vehicles": enriched,
    }
@app.get("/api/vehicles/{vehicle_id}")
async def get_vehicle_view(vehicle_id: str):
    """Return position, destination and schedule deviation for one live vehicle.

    Raises ``HTTPException(404)`` when the vehicle is not in the realtime cache.

    ``name`` is the trip headsign, or ``"Not in Schedule"`` when the trip is
    unknown to the database. ``delay_mins`` is ``scheduled - predicted`` at
    the vehicle's next stop, in whole minutes (negative = late, positive =
    early); it stays 0 without a prediction or a scheduled time for that stop.
    """
    # 1. Pull latest from cache
    data = await ttc_cache.get_data()
    vehicles = data.get("vehicles", [])
    # 2. Find this specific bus (ids compared as strings; path params are str)
    bus = next((v for v in vehicles if str(v['id']) == vehicle_id), None)
    if not bus:
        raise HTTPException(status_code=404, detail="Vehicle not active or not found")

    trip_id = str(bus['trip_id'])
    next_stop_id = bus.get('next_stop_id')
    predicted_time = bus.get('predicted_time')

    destination = "Not in Schedule"
    delay_mins = 0

    # 3. Headsign + scheduled arrival at the next stop. Columns are cast to
    #    VARCHAR so the string parameters match regardless of column types.
    row = None
    if next_stop_id:
        stop_query = """
            SELECT
                t.trip_headsign,
                st.arrival_time AS scheduled_time
            FROM trips t
            JOIN stop_times st ON CAST(t.trip_id AS VARCHAR) = CAST(st.trip_id AS VARCHAR)
            WHERE CAST(t.trip_id AS VARCHAR) = ?
              AND CAST(st.stop_id AS VARCHAR) = ?
            LIMIT 1
        """
        row = db.execute(stop_query, [trip_id, str(next_stop_id)]).fetchone()

    if row:
        destination = row[0]
        scheduled_hms = row[1]
        # GTFS allows an empty arrival_time at non-timepoint stops.
        if predicted_time and scheduled_hms:
            service_day_ts = get_service_day_start_ts()
            # GTFS times may exceed 24:00:00 (trips past midnight); the full
            # h:m:s offset from the service-day start is added as-is.
            hours, minutes, seconds = map(int, scheduled_hms.split(':'))
            plan_ts = service_day_ts + hours * 3600 + minutes * 60 + seconds
            delay_mins = round((plan_ts - predicted_time) / 60)
    else:
        # No next stop, or the next stop is not on this trip's schedule
        # (e.g. a detour): still resolve the destination from the trip alone.
        trip_query = """
            SELECT trip_headsign
            FROM trips
            WHERE CAST(trip_id AS VARCHAR) = ?
            LIMIT 1
        """
        trip_row = db.execute(trip_query, [trip_id]).fetchone()
        if trip_row:
            destination = trip_row[0]

    return {
        "vehicle_number": vehicle_id,
        "route_id": bus['route'],
        "name": destination,
        "location": {
            "lat": bus['lat'],
            "lon": bus['lon'],
        },
        "delay_mins": delay_mins,
        "fullness": translate_occupancy(bus['occupancy']),
        "trip_id": trip_id,
    }
@app.get("/api/stop/{stop_code}")
async def get_stop_view(stop_code: str):
    """Return predicted arrivals within the next two hours at one stop.

    ``stop_code`` is the public pole number; it is translated to the GTFS
    ``stop_id`` used by the schedule and by the cached prediction itineraries.

    Returns ``{"error": "Stop code not found"}`` for an unknown code.
    Otherwise returns the stop's name, code, location and ``arrivals`` sorted
    by ``eta_mins``. Per arrival, ``delay_mins`` is ``scheduled - predicted``
    in whole minutes (negative = late, positive = early; 0 when the stop has
    no scheduled time). ``fullness``/``vehicle_id`` fall back to
    ``"Unknown"``/``"In Transit"`` when no live vehicle is running the trip.
    """
    from datetime import timezone

    # 1. Translate pole number to database ID
    stop_info = db.execute(
        "SELECT stop_id, stop_name, stop_lat, stop_lon FROM stops "
        "WHERE CAST(stop_code AS VARCHAR) = ? LIMIT 1",
        [str(stop_code)],
    ).fetchone()
    if not stop_info:
        return {"error": "Stop code not found"}
    target_id = str(stop_info[0])
    _, stop_name, stop_lat, stop_lon = stop_info

    # 2. Realtime snapshot: live vehicles keyed by trip id, and per-trip
    #    itineraries keyed by stop_id whose values are compared against
    #    epoch-second timestamps below.
    cached_data = await ttc_cache.get_data()
    vehicles = {str(v['trip_id']): v for v in cached_data.get("vehicles", [])}
    predictions = cached_data.get("predictions", {})

    now = datetime.now(timezone.utc).timestamp()
    two_hours_out = now + 7200

    # 3. Trips that have not yet passed this stop and arrive within 2 hours.
    #    Trip ids are normalized to str to match the vehicle map and the DB.
    upcoming = {}
    for trip_id, itinerary in predictions.items():
        if target_id not in itinerary:
            continue
        pred_time = itinerary[target_id]
        if pred_time is not None and now <= pred_time <= two_hours_out:
            upcoming[str(trip_id)] = pred_time

    # 4. One batched schedule lookup for every candidate trip, instead of a
    #    separate stop_times join per trip.
    schedule = {}
    if upcoming:
        placeholders = ','.join(['?'] * len(upcoming))
        query = f"""
            SELECT
                CAST(t.trip_id AS VARCHAR),
                t.trip_headsign,
                st.arrival_time AS scheduled_time,
                r.route_short_name
            FROM trips t
            JOIN stop_times st ON CAST(t.trip_id AS VARCHAR) = CAST(st.trip_id AS VARCHAR)
            JOIN routes r ON t.route_id = r.route_id
            WHERE CAST(t.trip_id AS VARCHAR) IN ({placeholders})
              AND CAST(st.stop_id AS VARCHAR) = ?
        """
        rows = db.execute(query, [*upcoming, target_id]).fetchall()
        for trip_key, headsign, scheduled_hms, route_name in rows:
            # Keep one row per trip (a loop trip can visit a stop twice),
            # mirroring the former per-trip LIMIT 1.
            schedule.setdefault(trip_key, (headsign, scheduled_hms, route_name))

    service_day_ts = get_service_day_start_ts()
    arrivals = []
    for trip_id, pred_time in upcoming.items():
        entry = schedule.get(trip_id)
        if entry is None:
            continue  # trip/stop pair not in the static schedule
        headsign, scheduled_hms, route_name = entry
        bus = vehicles.get(trip_id)
        delay_mins = 0
        # GTFS allows an empty arrival_time at non-timepoint stops.
        if scheduled_hms:
            # GTFS times may exceed 24:00:00 (trips past midnight); the full
            # h:m:s offset from the service-day start is added as-is.
            hours, minutes, seconds = map(int, scheduled_hms.split(':'))
            plan_ts = service_day_ts + hours * 3600 + minutes * 60 + seconds
            delay_mins = round((plan_ts - pred_time) / 60)
        arrivals.append({
            "route": route_name,
            "destination": headsign,
            "eta_mins": round((pred_time - now) / 60),
            "delay_mins": delay_mins,
            "fullness": translate_occupancy(bus['occupancy']) if bus else "Unknown",
            "vehicle_id": bus['id'] if bus else "In Transit",
        })

    arrivals.sort(key=lambda arrival: arrival['eta_mins'])
    return {
        "stop_name": stop_name,
        "stop_code": stop_code,
        "location": {
            "lat": stop_lat,
            "lon": stop_lon,
        },
        "arrivals": arrivals,
    }
@app.get("/api/alerts")
async def get_all_alerts():
    """
    Returns every active service alert for the entire TTC network.

    ``alerts`` is passed through from the cache unchanged and ``timestamp``
    is the current UTC time in epoch seconds.
    NOTE(review): /api/alerts/{route_id} treats this structure as a mapping of
    route id -> list of alerts; if so, ``count`` is the number of routes with
    alerts, not the number of individual alerts — confirm that is intended.
    """
    from datetime import timezone

    data = await ttc_cache.get_data()
    # .get() mirrors the per-route endpoint: a snapshot without an "alerts"
    # key now yields an empty result instead of a KeyError (HTTP 500).
    alerts = data.get("alerts", {})
    return {
        "timestamp": datetime.now(timezone.utc).timestamp(),
        "count": len(alerts),
        "alerts": alerts,
    }
@app.get("/api/alerts/{route_id}")
async def get_alerts_for_route(route_id: str):
    """Return the cached service alerts affecting a single route.

    When the route has no alerts, ``count`` is 0 and ``alerts`` holds the
    string ``"No alerts"`` rather than an empty list.
    """
    snapshot = await ttc_cache.get_data()
    matching = snapshot.get("alerts", {}).get(route_id, [])
    payload = {"route_id": route_id}
    if matching:
        payload["count"] = len(matching)
        payload["alerts"] = matching
    else:
        payload["count"] = 0
        payload["alerts"] = "No alerts"
    return payload
@app.get("/api/nearby")
async def get_nearby_context(lat: float, lon: float, radius_m: float = 500.0):
    """Return stops near a point and the routes that serve them.

    Uses a latitude/longitude bounding box extending ``radius_m`` metres
    (default 500) in each direction from (lat, lon). A box is used instead of
    a true circle because it is a cheap BETWEEN filter and needs no Haversine.

    Returns ``{"stops": [...], "routes": [...]}``; both lists are empty when
    no stop falls inside the box. Routes without a color get ``#FF0000``,
    matching the /api/routes default.
    """
    import math

    # 1 degree of latitude ~ 111.32 km everywhere; 1 degree of longitude is
    # that times cos(latitude) (~80.6 km at Toronto's ~43.6 degrees N, so
    # ~0.0062 degrees of longitude per 500 m).
    lat_range = radius_m / 111_320.0
    # Floor cos() so the box stays finite for points at or near the poles.
    lon_range = lat_range / max(math.cos(math.radians(lat)), 1e-6)

    # 1. Find all stops within the bounding box
    query_stops = """
        SELECT stop_id, stop_code, stop_name, stop_lat, stop_lon
        FROM stops
        WHERE stop_lat BETWEEN ? AND ?
          AND stop_lon BETWEEN ? AND ?
    """
    stops = db.execute(
        query_stops,
        [lat - lat_range, lat + lat_range, lon - lon_range, lon + lon_range],
    ).fetchall()
    stop_ids = [str(s[0]) for s in stops]
    if not stop_ids:
        return {"stops": [], "routes": []}

    # 2. Find all unique routes serving these specific stops. Only '?'
    #    placeholders are interpolated; the values are bound by DuckDB.
    placeholders = ','.join(['?'] * len(stop_ids))
    query_routes = f"""
        SELECT DISTINCT
            r.route_id,
            r.route_short_name,
            r.route_long_name,
            COALESCE(r.route_color, 'FF0000') AS route_color
        FROM routes r
        JOIN trips t ON r.route_id = t.route_id
        JOIN stop_times st ON t.trip_id = st.trip_id
        WHERE CAST(st.stop_id AS VARCHAR) IN ({placeholders})
    """
    routes = db.execute(query_routes, stop_ids).fetchall()
    return {
        "stops": [
            {"id": s[0], "code": s[1], "name": s[2], "lat": s[3], "lon": s[4]}
            for s in stops
        ],
        "routes": [
            {"id": r[0], "short_name": r[1], "long_name": r[2], "color": f"#{r[3]}"}
            for r in routes
        ],
    }
if __name__ == "__main__":
    # Running this file directly serves the app with uvicorn on all
    # interfaces, port 7860.
    import uvicorn  # type: ignore

    uvicorn.run(app, port=7860, host="0.0.0.0")