"""Initialize a DuckDB database from GTFS-Static CSVs and verify data integrity.

Imports routes/trips/stops/stop_times from the static/ directory into
src/ttc_gtfs.duckdb, then joins a trip_id (preferably taken from the live
GTFS-RT feed) against the static tables to confirm the data links up.
"""
import duckdb  # type: ignore
import os  # type: ignore
import sys  # type: ignore
from pathlib import Path  # type: ignore

from dotenv import load_dotenv  # type: ignore

# Add parent directory to path to allow imports from api/
sys.path.insert(0, str(Path(__file__).parent.parent))

from api.bus_cache import AsyncBusCache  # type: ignore

# Configuration - always save DB in src/ directory
DB_PATH = str(Path(__file__).parent / "ttc_gtfs.duckdb")
STATIC_DIR = str(Path(__file__).parent.parent / "static")


def init_db():
    """
    Connect to DuckDB and import the GTFS-Static data from the static/ directory.

    Returns:
        An open DuckDB connection. The caller owns the connection and is
        responsible for closing it.
    """
    # 1. Connect to DuckDB (creates the file if it doesn't exist)
    con = duckdb.connect(DB_PATH)

    # 2. Skip the import if the database is already populated.
    #    stop_times is one of the imported tables; its presence is used
    #    as the "already populated" marker.
    tables = con.execute("SHOW TABLES").fetchall()
    if ('stop_times',) in tables:
        print("--- Database already exists and is populated ---")
        return con

    print("--- Initializing DuckDB: Importing CSVs from /static ---")

    # Core GTFS files we need for the rework. table_name is derived from
    # this fixed whitelist, so interpolating it into DDL below is safe.
    files = ["routes.txt", "trips.txt", "stops.txt", "stop_times.txt"]
    for f in files:
        file_path = Path(STATIC_DIR) / f
        table_name = f.replace(".txt", "")
        if file_path.exists():
            print(f"Loading {f} into table '{table_name}'...")
            # 'read_csv_auto' automatically detects headers and data types.
            # Use an absolute path so DuckDB resolves it regardless of CWD.
            abs_file_path = str(file_path.resolve())
            con.execute(
                f"CREATE TABLE {table_name} AS "
                f"SELECT * FROM read_csv_auto('{abs_file_path}')"
            )
        else:
            print(f"Error: {file_path} not found! Please ensure it is in the static/ folder.")

    print("--- Database Import Complete ---")
    return con


async def _get_sample_trip_id(con):
    """
    Return a trip_id to test with: from the live GTFS-RT feed if possible,
    otherwise any trip_id from the local trips table.

    Raises:
        ValueError: if vehicles are available but GTFS_RT_URL is not set.
    """
    cache = AsyncBusCache(ttl=20)
    vehicles = await cache.get_data()
    if not vehicles:
        print("No vehicles available from API, falling back to database trip_id")
        return con.execute("SELECT trip_id FROM trips LIMIT 1").fetchone()[0]

    # The cache exposes processed vehicle data; re-fetch the raw GTFS-RT
    # protobuf feed to read trip_id off the vehicle entities directly.
    import httpx  # type: ignore
    from google.transit import gtfs_realtime_pb2  # type: ignore

    load_dotenv()
    gtfs_rt_url = os.getenv("GTFS_RT_URL")
    if not gtfs_rt_url:
        raise ValueError("GTFS_RT_URL is not set")

    async with httpx.AsyncClient() as client:
        response = await client.get(gtfs_rt_url, timeout=10)
        response.raise_for_status()

    feed = gtfs_realtime_pb2.FeedMessage()
    feed.ParseFromString(response.content)

    # First vehicle entity that carries a non-empty trip_id wins.
    for entity in feed.entity:
        if entity.HasField('vehicle') and entity.vehicle.trip.trip_id:
            sample_trip = entity.vehicle.trip.trip_id
            print(f"Using trip_id from live API: {sample_trip}")
            return sample_trip

    print("No trip_id found in API response, falling back to database")
    return con.execute("SELECT trip_id FROM trips LIMIT 1").fetchone()[0]


async def test_data_integrity(con):
    """
    Run a test join to confirm that a trip ID can be linked to a route name
    and stop list. Uses AsyncBusCache to get a real trip_id from the live API.

    Args:
        con: an open DuckDB connection with the GTFS tables loaded.
    """
    print("--- Running Integrity Test ---")
    try:
        sample_trip = await _get_sample_trip_id(con)

        # Total stops on the sampled trip. sample_trip originates from an
        # external feed, so bind it as a parameter — never interpolate
        # untrusted values into SQL strings.
        count_query = """
            SELECT COUNT(*)
            FROM trips t
            JOIN stop_times st ON t.trip_id = st.trip_id
            WHERE t.trip_id = ?
        """
        total_count = con.execute(count_query, [sample_trip]).fetchone()[0]

        # Show at most the first 20 stops.
        sample_size = min(20, total_count)

        # sample_size is a locally computed int, safe to interpolate into LIMIT.
        query = f"""
            SELECT r.route_short_name, t.trip_headsign, st.stop_sequence, s.stop_name
            FROM trips t
            JOIN routes r ON t.route_id = r.route_id
            JOIN stop_times st ON t.trip_id = st.trip_id
            JOIN stops s ON st.stop_id = s.stop_id
            WHERE t.trip_id = ?
            ORDER BY st.stop_sequence
            LIMIT {sample_size};
        """
        results = con.execute(query, [sample_trip]).fetchall()

        print(f"\nSuccessfully joined data for Trip ID: {sample_trip}")
        print(f"Total stops in trip: {total_count}")
        if total_count > sample_size:
            print(f"Showing first {sample_size} stops (sample):\n")
        else:
            print(f"Showing all {total_count} stops:\n")

        print(f"{'Route':<8} {'Headsign':<30} {'Stop #':<8} {'Stop Name':<50}")
        print("-" * 100)
        for res in results:
            route = res[0] or "N/A"
            headsign = (res[1] or "N/A")[:28]   # truncate to fit column
            stop_seq = res[2]
            stop_name = (res[3] or "N/A")[:48]  # truncate to fit column
            print(f"{route:<8} {headsign:<30} {stop_seq:<8} {stop_name:<50}")
    except Exception as e:
        # Top-level boundary for a standalone script: report and exit cleanly.
        print(f"Integrity test failed: {e}")


if __name__ == "__main__":
    import asyncio  # type: ignore

    db_con = init_db()
    try:
        asyncio.run(test_data_integrity(db_con))
    finally:
        # init_db hands ownership of the connection to the caller; close it.
        db_con.close()