WMB2Backened / src /app.py
42Cummer's picture
Upload 9 files
0170ac5 verified
raw
history blame
11 kB
from datetime import datetime # type: ignore
import logging
import sys # type: ignore
from pathlib import Path
# Add parent directory to path to allow imports from api/
sys.path.insert(0, str(Path(__file__).parent.parent))
# Add src directory to path to allow imports from same directory
sys.path.insert(0, str(Path(__file__).parent))
from api.bus_cache import AsyncBusCache # type: ignore
from api.utils import hms_to_seconds, get_service_day_start_ts, translate_occupancy # type: ignore
from db_manager import init_db # type: ignore
from dotenv import load_dotenv # type: ignore
from fastapi import FastAPI, HTTPException # type: ignore
from fastapi.middleware.cors import CORSMiddleware # type: ignore
# Load variables from a local .env file into the process environment. This
# runs before the cache and database are created, presumably so they can read
# configuration from the environment — TODO confirm.
load_dotenv()
# Shared cache of the live feed; every endpoint below reads it via get_data().
# NOTE(review): ttl=20 is presumably the refresh interval in seconds — confirm
# in api.bus_cache.
ttc_cache = AsyncBusCache(ttl=20)
# Initialize database connection globally
# (one module-level handle shared by all request handlers; the queries below
# treat it as DuckDB via db.execute(...).fetchall() / .fetchone()).
db = init_db()
app = FastAPI(title="WheresMyBus v2.0 API")
# Setup CORS for your React frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], # In production, use your actual React URL
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def health_check():
    """Liveness probe: reply with a fixed message so callers know the API is up."""
    message = "backend is running"
    return message
@app.get("/api/vehicles")
async def get_vehicles():
    """Return every vehicle present in the current cached feed snapshot.

    Response shape: ``{"status": "success", "count": int, "vehicles": list}``.
    """
    snapshot = await ttc_cache.get_data()
    fleet = snapshot.get("vehicles", [])
    payload = {"status": "success"}
    payload["count"] = len(fleet)
    payload["vehicles"] = fleet
    return payload
@app.get("/api/routes")
async def get_all_routes():
    """
    Returns a complete list of TTC routes with their display names and colors.

    Routes whose short name is purely numeric are ordered numerically; any
    other short name gets the sort key 999 and is then ordered by its text.
    Missing or blank colours fall back to FF0000 (route) and FFFFFF (text).

    Returns:
        ``{"status": "success", "count": int, "routes": [...]}`` where each
        route is ``{"id", "number", "name", "color", "text_color"}`` and the
        colours are prefixed with ``#``; or
        ``{"status": "error", "message": str}`` if the database query fails.
    """
    # NULLIF maps blank colour cells to NULL so COALESCE covers them too;
    # otherwise an empty string would yield an invalid "#" colour. The
    # VARCHAR cast keeps NULLIF's '' comparison type-safe.
    query = """
    SELECT
        route_id,
        route_short_name,
        route_long_name,
        COALESCE(NULLIF(CAST(route_color AS VARCHAR), ''), 'FF0000') AS route_color,
        COALESCE(NULLIF(CAST(route_text_color AS VARCHAR), ''), 'FFFFFF') AS route_text_color
    FROM routes
    ORDER BY
        CASE
            WHEN CAST(route_short_name AS VARCHAR) ~ '^[0-9]+$' THEN CAST(route_short_name AS INTEGER)
            ELSE 999
        END,
        route_short_name;
    """
    try:
        results = db.execute(query).fetchall()
    except Exception:
        # Record the full database error server-side instead of echoing
        # internal details (SQL, schema, file paths) back to the client.
        logging.getLogger(__name__).exception("Failed to load routes")
        return {"status": "error", "message": "Failed to load routes"}

    route_list = [
        {
            "id": route_id,
            "number": short_name,
            "name": long_name,
            "color": f"#{color}",
            "text_color": f"#{text_color}",
        }
        for route_id, short_name, long_name, color, text_color in results
    ]
    return {
        "status": "success",
        "count": len(route_list),
        "routes": route_list,
    }
@app.get("/api/routes/{route_id}")
async def get_route_view(route_id: str):
    """Return the live vehicles on one route, enriched with schedule data.

    Args:
        route_id: Route identifier from the URL path. It is compared as a
            string against each cached vehicle's ``route`` field.

    Returns:
        ``{"route": str, "count": int, "vehicles": [...]}``. Each vehicle is
        ``{"number", "name", "location": {"lat", "lon"}, "delay_mins",
        "fullness"}``. ``name`` is the trip headsign, or "Not in Schedule".
        ``delay_mins`` is positive when late and negative when early. It is 0
        when the bus has no prediction or no matching scheduled stop time.
    """
    data = await ttc_cache.get_data()
    all_buses = data.get("vehicles", [])
    # Compare as strings (as get_vehicle_view does for vehicle ids) so a
    # numeric route value in the feed still matches the string path param.
    route_buses = [v for v in all_buses if str(v['route']) == route_id]
    if not route_buses:
        return {"route": route_id, "count": 0, "vehicles": []}

    # IMPORTANT: Cast Trip IDs to strings to ensure they match the DB
    trip_ids = [str(v['trip_id']) for v in route_buses]
    # Only "?" markers are interpolated into the SQL text; the trip ids are
    # bound as parameters. Both sides are cast to VARCHAR so strings match.
    placeholders = ','.join(['?'] * len(trip_ids))
    query = f"""
    SELECT
        CAST(st.trip_id AS VARCHAR),
        CAST(st.stop_id AS VARCHAR),
        st.arrival_time,
        t.trip_headsign
    FROM stop_times st
    JOIN trips t ON CAST(st.trip_id AS VARCHAR) = CAST(t.trip_id AS VARCHAR)
    WHERE CAST(st.trip_id AS VARCHAR) IN ({placeholders})
    """
    db_rows = db.execute(query, trip_ids).fetchall()
    if not db_rows:
        logging.getLogger(__name__).warning(
            "No matches in DB for Trip IDs: %s", trip_ids[:3]
        )

    # (trip_id, stop_id) -> scheduled arrival (fed to hms_to_seconds);
    # trip_id -> headsign (the destination shown to riders).
    schedule_map = {(r[0], r[1]): r[2] for r in db_rows}
    name_map = {r[0]: r[3] for r in db_rows}
    service_day_ts = get_service_day_start_ts()

    enriched = []
    for bus in route_buses:
        trip_id = str(bus['trip_id'])
        raw_delay_mins = 0  # default when no prediction / schedule match exists
        pred_time = bus.get('predicted_time')
        stop_id = bus.get('next_stop_id')
        if pred_time and stop_id:
            sched_hms = schedule_map.get((trip_id, str(stop_id)))
            if sched_hms:
                # (Reality Unix - Plan Unix) / 60
                plan_ts = service_day_ts + hms_to_seconds(sched_hms)
                raw_delay_mins = round((pred_time - plan_ts) / 60)
        enriched.append({
            "number": bus['id'],
            "name": name_map.get(trip_id, "Not in Schedule"),
            "location": {"lat": bus['lat'], "lon": bus['lon']},
            "delay_mins": raw_delay_mins,  # 5 = 5m late, -2 = 2m early
            "fullness": translate_occupancy(bus['occupancy']),
        })

    return {
        "route": route_id,
        "count": len(enriched),
        "vehicles": enriched,
    }
@app.get("/api/vehicles/{vehicle_id}")
async def get_vehicle_view(vehicle_id: str):
    """Return live position, destination, delay and occupancy for one vehicle.

    Args:
        vehicle_id: Vehicle number from the URL path, compared as a string
            against each cached vehicle's ``id``.

    Returns:
        A dict with ``vehicle_number``, ``route_id``, ``name`` (trip headsign
        or "Not in Schedule"), ``location``, ``delay_mins`` (0 when it cannot
        be computed), ``fullness`` and ``trip_id``.

    Raises:
        HTTPException: 404 when the vehicle is not in the current feed.
    """
    data = await ttc_cache.get_data()
    vehicles = data.get("vehicles", [])
    bus = next((v for v in vehicles if str(v['id']) == vehicle_id), None)
    if not bus:
        raise HTTPException(status_code=404, detail="Vehicle not active or not found")

    trip_id = str(bus['trip_id'])
    next_stop_id = bus.get('next_stop_id')
    predicted_time = bus.get('predicted_time')

    destination = "Not in Schedule"
    delay_mins = 0

    # Preferred lookup: headsign plus the scheduled arrival at the next stop,
    # which together give destination and delay in one query.
    row = None
    if next_stop_id:
        stop_query = """
        SELECT
            t.trip_headsign,
            st.arrival_time
        FROM trips t
        JOIN stop_times st ON CAST(t.trip_id AS VARCHAR) = CAST(st.trip_id AS VARCHAR)
        WHERE CAST(t.trip_id AS VARCHAR) = ?
          AND CAST(st.stop_id AS VARCHAR) = ?
        LIMIT 1
        """
        row = db.execute(stop_query, [trip_id, str(next_stop_id)]).fetchone()

    if row:
        destination, scheduled_hms = row[0], row[1]
        # Reality (Unix time) - Plan (service day start + scheduled seconds)
        if predicted_time and scheduled_hms:
            plan_ts = get_service_day_start_ts() + hms_to_seconds(scheduled_hms)
            delay_mins = round((predicted_time - plan_ts) / 60)
    else:
        # No next stop, or the stop is not in this trip's schedule: the
        # destination only depends on the trip, so look it up directly.
        trip_query = """
        SELECT trip_headsign
        FROM trips
        WHERE CAST(trip_id AS VARCHAR) = ?
        LIMIT 1
        """
        trip_row = db.execute(trip_query, [trip_id]).fetchone()
        if trip_row:
            destination = trip_row[0]

    return {
        "vehicle_number": vehicle_id,
        "route_id": bus['route'],
        "name": destination,
        "location": {
            "lat": bus['lat'],
            "lon": bus['lon'],
        },
        "delay_mins": delay_mins,
        "fullness": translate_occupancy(bus['occupancy']),
        "trip_id": trip_id,
    }
@app.get("/api/stop/{stop_code}")
async def get_stop_view(stop_code: str):
    """Return predicted arrivals within the next two hours at one stop.

    Args:
        stop_code: The public stop code (pole number), matched as a string
            against ``stops.stop_code``.

    Returns:
        ``{"stop_name", "stop_code", "arrivals": [...]}`` sorted by
        ``eta_mins``. Each arrival has ``route``, ``destination``,
        ``eta_mins``, ``delay_mins`` (0 when no scheduled time is known),
        ``fullness`` and ``vehicle_id``. The last two are "Unknown" and
        "In Transit" when no live vehicle is matched to the trip.
        ``{"error": "Stop code not found"}`` when the stop code is unknown.
    """
    # 1. Translate the pole number to the database stop_id.
    stop_info = db.execute(
        "SELECT stop_id, stop_name FROM stops WHERE CAST(stop_code AS VARCHAR) = ? LIMIT 1",
        [str(stop_code)],
    ).fetchone()
    if not stop_info:
        return {"error": "Stop code not found"}
    target_id = str(stop_info[0])
    stop_name = stop_info[1]

    # 2. Live data: vehicles (keyed by string trip id) and per-trip
    #    itineraries of predicted stop times.
    cached_data = await ttc_cache.get_data()
    vehicles_list = cached_data.get("vehicles", [])
    predictions = cached_data.get("predictions", {})
    vehicles = {str(v['trip_id']): v for v in vehicles_list}

    now = datetime.now().timestamp()
    two_hours_out = now + 7200
    # Loop-invariant: the service day start is the same for every trip here.
    service_day_ts = get_service_day_start_ts()

    # All joins cast ids to VARCHAR, matching the rest of this module.
    query = """
    SELECT t.trip_headsign, st.arrival_time, r.route_short_name
    FROM trips t
    JOIN stop_times st ON CAST(t.trip_id AS VARCHAR) = CAST(st.trip_id AS VARCHAR)
    JOIN routes r ON CAST(t.route_id AS VARCHAR) = CAST(r.route_id AS VARCHAR)
    WHERE CAST(t.trip_id AS VARCHAR) = ? AND CAST(st.stop_id AS VARCHAR) = ?
    LIMIT 1
    """

    arrivals = []
    # 3. Search every trip's full itinerary for the target stop.
    for raw_trip_id, itinerary in predictions.items():
        if target_id not in itinerary:
            continue
        pred_time = itinerary[target_id]
        # Skip missing predictions, buses already past the stop, and buses
        # more than two hours out.
        if pred_time is None or not (now <= pred_time <= two_hours_out):
            continue

        # Normalize so the id matches the string-keyed vehicles map and the
        # VARCHAR-cast query parameter.
        trip_id = str(raw_trip_id)
        row = db.execute(query, [trip_id, target_id]).fetchone()
        if not row:
            continue

        headsign, scheduled_hms, route_short_name = row
        delay_mins = 0
        if scheduled_hms:
            plan_ts = service_day_ts + hms_to_seconds(scheduled_hms)
            delay_mins = round((pred_time - plan_ts) / 60)

        # The live bus (if on the road) supplies occupancy and vehicle id.
        bus = vehicles.get(trip_id)
        arrivals.append({
            "route": route_short_name,
            "destination": headsign,
            "eta_mins": round((pred_time - now) / 60),
            "delay_mins": delay_mins,
            "fullness": translate_occupancy(bus['occupancy']) if bus else "Unknown",
            "vehicle_id": bus['id'] if bus else "In Transit",
        })

    arrivals.sort(key=lambda x: x['eta_mins'])
    return {"stop_name": stop_name, "stop_code": stop_code, "arrivals": arrivals}
@app.get("/api/alerts")
async def get_all_alerts():
    """
    Returns every active service alert for the entire TTC network.

    Returns:
        ``{"timestamp": float, "count": int, "alerts": ...}`` where
        ``alerts`` is the cache's alert collection as-is and ``count`` is its
        ``len()``.

    NOTE(review): get_alerts_for_route treats this collection as a mapping
    keyed by route id. In that case ``count`` is the number of routes with
    alerts, not the number of alerts. Confirm that is intended.
    """
    data = await ttc_cache.get_data()
    # Default to an empty mapping, as get_alerts_for_route does, so a snapshot
    # without an "alerts" entry returns an empty result instead of a 500.
    alerts = data.get("alerts", {})
    return {
        "timestamp": datetime.now().timestamp(),
        "count": len(alerts),
        "alerts": alerts,
    }
@app.get("/api/alerts/{route_id}")
async def get_alerts_for_route(route_id: str):
    """Return the active service alerts attached to a single route.

    When the route has no alerts, ``count`` is 0 and ``alerts`` is the
    placeholder string "No alerts" rather than an empty list.
    """
    snapshot = await ttc_cache.get_data()
    matching = snapshot.get("alerts", {}).get(route_id, [])
    if matching:
        return {"route_id": route_id, "count": len(matching), "alerts": matching}
    return {"route_id": route_id, "count": 0, "alerts": "No alerts"}
if __name__ == "__main__":
    # Serve the API directly when this file is executed as a script,
    # listening on all interfaces at port 7860.
    import uvicorn  # type: ignore

    uvicorn.run(app, port=7860, host="0.0.0.0")