File size: 7,210 Bytes
0170ac5
 
 
 
 
 
999c03e
0170ac5
 
 
999c03e
0170ac5
 
 
 
 
 
 
999c03e
 
0170ac5
999c03e
 
 
 
 
 
 
 
 
 
 
d73b136
999c03e
 
 
 
 
 
 
 
 
 
 
 
d73b136
0170ac5
 
d73b136
0170ac5
d73b136
 
0170ac5
 
 
 
 
 
d73b136
0170ac5
d73b136
 
 
0170ac5
d73b136
 
 
 
 
 
 
 
 
0170ac5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import duckdb # type: ignore
import os # type: ignore
import sys # type: ignore
from pathlib import Path # type: ignore
from dotenv import load_dotenv # type: ignore


# Add parent directory to path to allow imports from api/
# (must run BEFORE the `from api...` imports below, or they fail to resolve).
sys.path.insert(0, str(Path(__file__).parent.parent))

# Project-local imports; only resolvable thanks to the sys.path tweak above.
from api.update_static import GTFSSyncManager
from api.bus_cache import AsyncBusCache # type: ignore

# Configuration - always save DB in src/ directory
# NOTE(review): init_db() actually connects via sync_mgr.DB_PATH, not this
# DB_PATH constant — confirm the two resolve to the same file.
DB_PATH = str(Path(__file__).parent / "ttc_gtfs.duckdb")
STATIC_DIR = str(Path(__file__).parent.parent / "static")  # directory holding the GTFS .txt files

def init_db():
    """Create or refresh the TTC GTFS DuckDB database and return an open connection.

    Flow:
      1. Compare the remote feed's ``updated_at`` against the locally stored
         ``last_modified`` marker and run a full re-sync when the data is stale.
      2. Import the GTFS CSV files into tables whenever any required table is
         missing (e.g. after the DB file was rebuilt by the sync).
      3. Create the indexes used by the shapes/arrivals query paths.

    Returns:
        An open duckdb connection to the ready database.
    """
    sync_mgr = GTFSSyncManager()
    remote_data = sync_mgr.get_remote_metadata()

    # Connect to existing or new DB
    con = duckdb.connect(sync_mgr.DB_PATH)

    # 1. Setup metadata tracking
    con.execute("CREATE TABLE IF NOT EXISTS sync_metadata (key VARCHAR PRIMARY KEY, value VARCHAR)")
    local_update = con.execute("SELECT value FROM sync_metadata WHERE key = 'last_modified'").fetchone()

    # 2. Check if we need to sync based on API metadata.
    # NOTE(review): this compares the timestamps as strings lexically — assumes
    # the feed emits a sortable (ISO-like) format; verify against the API.
    should_sync = False
    if not local_update or (remote_data and remote_data["updated_at"] > local_update[0]):
        should_sync = True

    if should_sync and remote_data:
        print(f"--- Data Stale. Remote: {remote_data['updated_at']} | Local: {local_update[0] if local_update else 'None'} ---")
        con.close()  # Close to allow file deletion during the full sync
        sync_mgr.perform_full_sync(remote_data["url"])

        # Reconnect and record the remote timestamp so the next run can skip syncing.
        con = duckdb.connect(sync_mgr.DB_PATH)
        con.execute("CREATE TABLE IF NOT EXISTS sync_metadata (key VARCHAR PRIMARY KEY, value VARCHAR)")
        con.execute("INSERT OR REPLACE INTO sync_metadata VALUES ('last_modified', ?)", [remote_data["updated_at"]])

    # 3. Standard Import Loop (runs if DB was nuked or is missing tables)
    required = ["routes", "trips", "stops", "stop_times", "shapes"]
    tables = {t[0] for t in con.execute("SHOW TABLES").fetchall()}
    if all(t in tables for t in required):
        return con

    print("--- Initializing/Updating DuckDB: Importing CSVs ---")

    for table_name in required:
        file_path = Path(STATIC_DIR) / f"{table_name}.txt"

        if file_path.exists():
            print(f"Importing {file_path.name} into table '{table_name}'...")
            # Bind the file path as a parameter so unusual characters in the
            # directory name cannot break (or inject into) the SQL string.
            # The table name is interpolated, but it comes from the fixed
            # `required` list above, so that is safe. 'CREATE OR REPLACE'
            # overwrites existing tables without crashing.
            con.execute(
                f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM read_csv_auto(?)",
                [str(file_path.resolve())],
            )
        else:
            print(f"Error: {file_path} not found!")

    # Indexes for the hot query paths. Guard on table existence: if a CSV above
    # was missing, unconditional CREATE INDEX would raise on the absent table
    # and abort initialization.
    print("--- Creating Indexes for Performance ---")
    existing = {t[0] for t in con.execute("SHOW TABLES").fetchall()}
    if "shapes" in existing:
        # Speeds up the /api/shapes/{shape_id} endpoint significantly.
        con.execute("CREATE INDEX IF NOT EXISTS idx_shape_id ON shapes (shape_id)")
    if "stop_times" in existing:
        # Indexing trip_id in stop_times speeds up the arrival logic.
        con.execute("CREATE INDEX IF NOT EXISTS idx_stop_times_trip_id ON stop_times (trip_id)")

    print("--- Database Import Complete ---")
    return con

async def _resolve_sample_trip(con) -> str:
    """Return a trip_id for the integrity test.

    Prefers a trip_id read from the live GTFS-RT feed; falls back to the first
    trip stored in the database when the API yields nothing.
    """
    cache = AsyncBusCache(ttl=20)
    vehicles = await cache.get_data()

    if not vehicles:
        print("No vehicles available from API, falling back to database trip_id")
        return con.execute("SELECT trip_id FROM trips LIMIT 1").fetchone()[0]

    # The cache payload does not expose trip_id directly, so fetch the raw
    # GTFS-RT protobuf feed and read one from a vehicle entity.
    import httpx # type: ignore
    from google.transit import gtfs_realtime_pb2 # type: ignore

    load_dotenv()
    gtfs_rt_url = os.getenv("GTFS_RT_URL")
    if not gtfs_rt_url:
        raise ValueError("GTFS_RT_URL is not set")

    async with httpx.AsyncClient() as client:
        response = await client.get(gtfs_rt_url, timeout=10)
        response.raise_for_status()

    feed = gtfs_realtime_pb2.FeedMessage()
    feed.ParseFromString(response.content)

    # First vehicle entity that carries a non-empty trip_id wins.
    for entity in feed.entity:
        if entity.HasField('vehicle') and entity.vehicle.trip.trip_id:
            sample_trip = entity.vehicle.trip.trip_id
            print(f"Using trip_id from live API: {sample_trip}")
            return sample_trip

    print("No trip_id found in API response, falling back to database")
    return con.execute("SELECT trip_id FROM trips LIMIT 1").fetchone()[0]


async def test_data_integrity(con):
    """
    Runs a test join to confirm that a trip ID can be linked to a route name and stop list.
    Uses AsyncBusCache to get a real trip_id from the live API.

    Args:
        con: open duckdb connection with the GTFS tables loaded.
    """
    print("--- Running Integrity Test ---")
    try:
        sample_trip = await _resolve_sample_trip(con)

        # Parameter-bind the trip_id: it originates from an external feed, so
        # interpolating it into the SQL string would be an injection vector.
        count_query = """
            SELECT COUNT(*)
            FROM trips t
            JOIN stop_times st ON t.trip_id = st.trip_id
            WHERE t.trip_id = ?
        """
        total_count = con.execute(count_query, [sample_trip]).fetchone()[0]

        # Show all stops when the trip has <= 20, otherwise just the first 20.
        sample_size = min(20, total_count)

        query = """
            SELECT
                r.route_short_name,
                t.trip_headsign,
                st.stop_sequence,
                s.stop_name
            FROM trips t
            JOIN routes r ON t.route_id = r.route_id
            JOIN stop_times st ON t.trip_id = st.trip_id
            JOIN stops s ON st.stop_id = s.stop_id
            WHERE t.trip_id = ?
            ORDER BY st.stop_sequence
            LIMIT ?;
        """
        results = con.execute(query, [sample_trip, sample_size]).fetchall()

        print(f"\nSuccessfully joined data for Trip ID: {sample_trip}")
        print(f"Total stops in trip: {total_count}")
        if total_count > sample_size:
            print(f"Showing first {sample_size} stops (sample):\n")
        else:
            print(f"Showing all {total_count} stops:\n")

        print(f"{'Route':<8} {'Headsign':<30} {'Stop #':<8} {'Stop Name':<50}")
        print("-" * 100)
        for res in results:
            route = res[0] or "N/A"
            headsign = (res[1] or "N/A")[:28]  # Truncate if too long
            stop_seq = res[2]
            stop_name = (res[3] or "N/A")[:48]  # Truncate if too long
            print(f"{route:<8} {headsign:<30} {stop_seq:<8} {stop_name:<50}")

    except Exception as e:
        # Best-effort diagnostic script: report and continue rather than crash.
        print(f"Integrity test failed: {e}")

if __name__ == "__main__":
    import asyncio # type: ignore

    # Build/refresh the database, then run the async integrity check on it.
    connection = init_db()
    asyncio.run(test_data_integrity(connection))