# WMB2Backend / src/db_manager.py
# DuckDB manager: imports GTFS-Static CSVs and runs an integrity check.
import duckdb # type: ignore
import os # type: ignore
import sys # type: ignore
from pathlib import Path # type: ignore
from dotenv import load_dotenv # type: ignore
# Add the project root to sys.path so `api.bus_cache` resolves even when this
# module is executed directly as a script (e.g. `python src/db_manager.py`).
sys.path.insert(0, str(Path(__file__).parent.parent))
from api.bus_cache import AsyncBusCache # type: ignore
# Configuration - always save DB in src/ directory, next to this file.
DB_PATH = str(Path(__file__).parent / "ttc_gtfs.duckdb")
# GTFS-Static CSVs (routes.txt, trips.txt, stops.txt, stop_times.txt) live here.
STATIC_DIR = str(Path(__file__).parent.parent / "static")
def init_db():
    """
    Connect to DuckDB at DB_PATH and import the GTFS-Static data from the
    static/ directory.

    Returns:
        An open DuckDB connection. If the `stop_times` table already exists
        the import is skipped and the existing database is reused.
    """
    # 1. Connect to DuckDB (creates the file if it doesn't exist)
    con = duckdb.connect(DB_PATH)
    # 2. Check if the database is already populated; `stop_times` is the
    #    last/largest file, so its presence implies a completed import.
    tables = con.execute("SHOW TABLES").fetchall()
    if ('stop_times',) in tables:
        print("--- Database already exists and is populated ---")
        return con
    print("--- Initializing DuckDB: Importing CSVs from /static ---")
    # Core GTFS files we need for the rework
    files = ["routes.txt", "trips.txt", "stops.txt", "stop_times.txt"]
    missing = []
    for f in files:
        file_path = Path(STATIC_DIR) / f
        table_name = f.replace(".txt", "")
        if file_path.exists():
            print(f"Loading {f} into table '{table_name}'...")
            # 'read_csv_auto' automatically detects headers and data types.
            # Bind the absolute path as a prepared-statement parameter so a
            # path containing a quote cannot break the SQL. table_name only
            # comes from the hard-coded list above, so interpolating it is safe.
            con.execute(
                f"CREATE TABLE {table_name} AS SELECT * FROM read_csv_auto(?)",
                [str(file_path.resolve())],
            )
        else:
            missing.append(f)
            print(f"Error: {file_path} not found! Please ensure it is in the static/ folder.")
    # Only claim success when every expected file was actually loaded.
    if missing:
        print(f"--- Database Import Incomplete: missing {', '.join(missing)} ---")
    else:
        print("--- Database Import Complete ---")
    return con
async def test_data_integrity(con):
    """
    Run a test join to confirm that a trip ID can be linked to a route name
    and stop list.

    Prefers a live trip_id taken from the GTFS-RT feed (via AsyncBusCache to
    check vehicle availability first); falls back to the first trip_id in the
    local `trips` table. Prints a formatted sample of up to 20 stops.

    Args:
        con: Open DuckDB connection containing the `trips`, `routes`,
            `stops` and `stop_times` tables.
    """
    print("--- Running Integrity Test ---")
    try:
        # Get trip_id from live API using AsyncBusCache
        cache = AsyncBusCache(ttl=20)
        vehicles = await cache.get_data()
        if not vehicles:
            print("No vehicles available from API, falling back to database trip_id")
            sample_trip = con.execute("SELECT trip_id FROM trips LIMIT 1").fetchone()[0]
        else:
            # The cache does not expose trip_id, so re-fetch the raw GTFS-RT
            # feed and read it off the first vehicle entity.
            import httpx  # type: ignore
            from google.transit import gtfs_realtime_pb2  # type: ignore
            load_dotenv()
            gtfs_rt_url = os.getenv("GTFS_RT_URL")
            if not gtfs_rt_url:
                raise ValueError("GTFS_RT_URL is not set")
            async with httpx.AsyncClient() as client:
                response = await client.get(gtfs_rt_url, timeout=10)
                response.raise_for_status()
            feed = gtfs_realtime_pb2.FeedMessage()
            feed.ParseFromString(response.content)
            # Get trip_id from first vehicle entity
            sample_trip = None
            for entity in feed.entity:
                if entity.HasField('vehicle') and entity.vehicle.trip.trip_id:
                    sample_trip = entity.vehicle.trip.trip_id
                    break
            if not sample_trip:
                print("No trip_id found in API response, falling back to database")
                sample_trip = con.execute("SELECT trip_id FROM trips LIMIT 1").fetchone()[0]
            else:
                print(f"Using trip_id from live API: {sample_trip}")
        # sample_trip can originate from an external feed: bind it as a query
        # parameter rather than interpolating it into the SQL string
        # (prevents SQL injection and breakage on embedded quotes).
        count_query = """
            SELECT COUNT(*)
            FROM trips t
            JOIN stop_times st ON t.trip_id = st.trip_id
            WHERE t.trip_id = ?
        """
        total_count = con.execute(count_query, [sample_trip]).fetchone()[0]
        # Show all stops if <= 20, otherwise just the first 20.
        sample_size = min(20, total_count)
        query = """
            SELECT
                r.route_short_name,
                t.trip_headsign,
                st.stop_sequence,
                s.stop_name
            FROM trips t
            JOIN routes r ON t.route_id = r.route_id
            JOIN stop_times st ON t.trip_id = st.trip_id
            JOIN stops s ON st.stop_id = s.stop_id
            WHERE t.trip_id = ?
            ORDER BY st.stop_sequence
            LIMIT ?
        """
        results = con.execute(query, [sample_trip, sample_size]).fetchall()
        print(f"\nSuccessfully joined data for Trip ID: {sample_trip}")
        print(f"Total stops in trip: {total_count}")
        if total_count > sample_size:
            print(f"Showing first {sample_size} stops (sample):\n")
        else:
            print(f"Showing all {total_count} stops:\n")
        print(f"{'Route':<8} {'Headsign':<30} {'Stop #':<8} {'Stop Name':<50}")
        print("-" * 100)
        for res in results:
            route = res[0] or "N/A"
            headsign = (res[1] or "N/A")[:28]  # Truncate if too long
            stop_seq = res[2]
            stop_name = (res[3] or "N/A")[:48]  # Truncate if too long
            print(f"{route:<8} {headsign:<30} {stop_seq:<8} {stop_name:<50}")
    except Exception as e:
        # Best-effort diagnostic routine: report and keep the process alive.
        print(f"Integrity test failed: {e}")
if __name__ == "__main__":
    # Script entry point: open/populate the database, then drive the async
    # integrity check with asyncio.run.
    import asyncio  # type: ignore

    connection = init_db()
    asyncio.run(test_data_integrity(connection))