Spaces:
Running
Running
File size: 7,210 Bytes
0170ac5 999c03e 0170ac5 999c03e 0170ac5 999c03e 0170ac5 999c03e d73b136 999c03e d73b136 0170ac5 d73b136 0170ac5 d73b136 0170ac5 d73b136 0170ac5 d73b136 0170ac5 d73b136 0170ac5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import duckdb # type: ignore
import os # type: ignore
import sys # type: ignore
from pathlib import Path # type: ignore
from dotenv import load_dotenv # type: ignore
# Add parent directory to path to allow imports from api/
sys.path.insert(0, str(Path(__file__).parent.parent))
from api.update_static import GTFSSyncManager
from api.bus_cache import AsyncBusCache # type: ignore
# Configuration - always save DB in src/ directory
DB_PATH = str(Path(__file__).parent / "ttc_gtfs.duckdb")
STATIC_DIR = str(Path(__file__).parent.parent / "static")
def init_db():
    """Build or refresh the local DuckDB database from TTC GTFS static data.

    Flow:
      1. Ask GTFSSyncManager for remote feed metadata.
      2. If the remote feed is newer than the locally recorded 'last_modified'
         value (or no local record exists), run a full re-sync of the static files.
      3. Import any missing GTFS tables from the CSVs in STATIC_DIR and create
         the indexes used by the API's hot paths.

    Returns:
        An open duckdb connection to the ready database.
    """
    sync_mgr = GTFSSyncManager()
    remote_data = sync_mgr.get_remote_metadata()
    # Connect to existing or new DB
    con = duckdb.connect(sync_mgr.DB_PATH)
    # 1. Setup metadata tracking
    con.execute("CREATE TABLE IF NOT EXISTS sync_metadata (key VARCHAR PRIMARY KEY, value VARCHAR)")
    local_update = con.execute("SELECT value FROM sync_metadata WHERE key = 'last_modified'").fetchone()
    # 2. Re-sync when there is no local timestamp, or the remote feed is newer.
    #    NOTE(review): timestamps are compared as strings — assumes a sortable
    #    format like ISO-8601; confirm against the API's 'updated_at' field.
    should_sync = not local_update or (remote_data and remote_data["updated_at"] > local_update[0])
    if should_sync and remote_data:
        print(f"--- Data Stale. Remote: {remote_data['updated_at']} | Local: {local_update[0] if local_update else 'None'} ---")
        con.close()  # Close to allow file deletion
        sync_mgr.perform_full_sync(remote_data["url"])
        # Reconnect and finalize metadata
        con = duckdb.connect(sync_mgr.DB_PATH)
        con.execute("CREATE TABLE IF NOT EXISTS sync_metadata (key VARCHAR PRIMARY KEY, value VARCHAR)")
        con.execute("INSERT OR REPLACE INTO sync_metadata VALUES ('last_modified', ?)", [remote_data["updated_at"]])
    # 3. Standard Import Loop (runs if DB was nuked or is missing tables)
    required_tables = ["routes", "trips", "stops", "stop_times", "shapes"]
    tables = {t[0] for t in con.execute("SHOW TABLES").fetchall()}
    if all(t in tables for t in required_tables):
        return con
    print("--- Initializing/Updating DuckDB: Importing CSVs ---")
    for table_name in required_tables:
        file_path = Path(STATIC_DIR) / f"{table_name}.txt"
        if file_path.exists():
            print(f"Importing {file_path.name} into table '{table_name}'...")
            # Table name comes from our fixed list above; the file path is passed
            # as a bound parameter so quotes/odd characters in it can't break the SQL.
            con.execute(
                f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM read_csv_auto(?)",
                [str(file_path.resolve())],
            )
        else:
            print(f"Error: {file_path} not found!")
    print("--- Creating Indexes for Performance ---")
    # Re-check which tables actually exist: creating an index on a table whose
    # CSV was missing above would raise and abort initialization.
    tables = {t[0] for t in con.execute("SHOW TABLES").fetchall()}
    if "shapes" in tables:
        # Speeds up the /api/shapes/{shape_id} endpoint significantly
        con.execute("CREATE INDEX IF NOT EXISTS idx_shape_id ON shapes (shape_id)")
    if "stop_times" in tables:
        # Indexing trip_id in stop_times speeds up the arrival logic
        con.execute("CREATE INDEX IF NOT EXISTS idx_stop_times_trip_id ON stop_times (trip_id)")
    print("--- Database Import Complete ---")
    return con
async def test_data_integrity(con):
    """
    Runs a test join to confirm that a trip ID can be linked to a route name and stop list.
    Uses AsyncBusCache to get a real trip_id from the live API.

    Args:
        con: an open duckdb connection produced by init_db().

    All failures are caught and printed — this is a diagnostic, not a hard gate.
    """
    print("--- Running Integrity Test ---")
    try:
        # Get trip_id from live API using AsyncBusCache
        cache = AsyncBusCache(ttl=20)
        vehicles = await cache.get_data()
        if not vehicles:
            print("No vehicles available from API, falling back to database trip_id")
            sample_trip = con.execute("SELECT trip_id FROM trips LIMIT 1").fetchone()[0]
        else:
            # Extract trip_id from the first vehicle
            # We need to get the raw GTFS data to access trip_id
            import httpx  # type: ignore
            from google.transit import gtfs_realtime_pb2  # type: ignore
            load_dotenv()
            gtfs_rt_url = os.getenv("GTFS_RT_URL")
            if not gtfs_rt_url:
                raise ValueError("GTFS_RT_URL is not set")
            async with httpx.AsyncClient() as client:
                response = await client.get(gtfs_rt_url, timeout=10)
                response.raise_for_status()
                feed = gtfs_realtime_pb2.FeedMessage()
                feed.ParseFromString(response.content)
                # Get trip_id from first vehicle entity
                sample_trip = None
                for entity in feed.entity:
                    if entity.HasField('vehicle') and entity.vehicle.trip.trip_id:
                        sample_trip = entity.vehicle.trip.trip_id
                        break
                if not sample_trip:
                    print("No trip_id found in API response, falling back to database")
                    sample_trip = con.execute("SELECT trip_id FROM trips LIMIT 1").fetchone()[0]
                else:
                    print(f"Using trip_id from live API: {sample_trip}")
        # First, get the total count. sample_trip may come from an external feed,
        # so it is bound as a parameter — never interpolated into the SQL string.
        count_query = """
            SELECT COUNT(*)
            FROM trips t
            JOIN stop_times st ON t.trip_id = st.trip_id
            WHERE t.trip_id = ?
        """
        total_count = con.execute(count_query, [sample_trip]).fetchone()[0]
        # Determine sample size - show all if <= 20, otherwise show first 20
        sample_size = min(20, total_count)
        # LIMIT uses the locally computed int; trip_id stays parameterized.
        query = f"""
            SELECT
                r.route_short_name,
                t.trip_headsign,
                st.stop_sequence,
                s.stop_name
            FROM trips t
            JOIN routes r ON t.route_id = r.route_id
            JOIN stop_times st ON t.trip_id = st.trip_id
            JOIN stops s ON st.stop_id = s.stop_id
            WHERE t.trip_id = ?
            ORDER BY st.stop_sequence
            LIMIT {sample_size};
        """
        results = con.execute(query, [sample_trip]).fetchall()
        print(f"\nSuccessfully joined data for Trip ID: {sample_trip}")
        print(f"Total stops in trip: {total_count}")
        if total_count > sample_size:
            print(f"Showing first {sample_size} stops (sample):\n")
        else:
            print(f"Showing all {total_count} stops:\n")
        print(f"{'Route':<8} {'Headsign':<30} {'Stop #':<8} {'Stop Name':<50}")
        print("-" * 100)
        for res in results:
            route = res[0] or "N/A"
            headsign = (res[1] or "N/A")[:28]  # Truncate if too long
            stop_seq = res[2]
            stop_name = (res[3] or "N/A")[:48]  # Truncate if too long
            print(f"{route:<8} {headsign:<30} {stop_seq:<8} {stop_name:<50}")
    except Exception as e:
        print(f"Integrity test failed: {e}")
if __name__ == "__main__":
    import asyncio  # type: ignore

    # Build/refresh the database synchronously, then run the async
    # integrity check against the resulting connection.
    db_con = init_db()
    asyncio.run(test_data_integrity(db_con))