Srushti-Kamble committed on
Commit
a4fe8f2
·
1 Parent(s): 0b3a3e8

Add SQLite to Supabase migration utility

Browse files
backend/requirements.txt CHANGED
@@ -8,6 +8,7 @@ python-multipart
8
  # Database
9
  sqlalchemy
10
  aiosqlite
 
11
 
12
  # Auth
13
  pyjwt
 
8
  # Database
9
  sqlalchemy
10
  aiosqlite
11
+ psycopg[binary]
12
 
13
  # Auth
14
  pyjwt
backend/scripts/migrate_sqlite_to_postgres.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Migrate SQLite app data into a Supabase/Postgres database.
2
+
3
+ The script supports both the current FastAPI SQLite schema
4
+ (`users`, `documents`, `chat_messages`) and the older legacy
5
+ `instance/users.db` schema (`user` only).
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+ import logging
11
+ import os
12
+ import sys
13
+ import uuid
14
+ from dataclasses import dataclass, field
15
+ from datetime import datetime, timezone
16
+ from pathlib import Path
17
+ from typing import Any
18
+
19
+ from sqlalchemy import (
20
+ Boolean,
21
+ Column,
22
+ DateTime,
23
+ ForeignKey,
24
+ Integer,
25
+ MetaData,
26
+ String,
27
+ Table,
28
+ Text,
29
+ create_engine,
30
+ inspect,
31
+ select,
32
+ )
33
+ from sqlalchemy.engine import Engine
34
+ from sqlalchemy.exc import IntegrityError
35
+ from sqlalchemy.orm import Session, sessionmaker
36
+
37
+ LOGGER = logging.getLogger("sqlite_to_postgres")
38
+
39
+
40
def generate_uuid() -> str:
    """Return a freshly generated random UUID (version 4) as a string."""
    fresh = uuid.uuid4()
    return str(fresh)
42
+
43
+
44
# Table definitions used both to create missing target tables
# (``metadata.create_all``) and to build the INSERT/SELECT statements below.


def _utc_now() -> datetime:
    """Column default: the current timezone-aware UTC time."""
    return datetime.now(timezone.utc)


def _uuid_primary_key() -> Column:
    """Build a new string ``id`` primary-key column defaulting to a UUID."""
    # A Column object can belong to only one Table, so build a fresh one per call.
    return Column("id", String, primary_key=True, default=generate_uuid)


metadata = MetaData()

users = Table(
    "users",
    metadata,
    _uuid_primary_key(),
    Column("username", String(80), unique=True, nullable=False, index=True),
    Column("email", String(120), unique=True, nullable=False, index=True),
    Column("hashed_password", String(255), nullable=False),
    Column("is_admin", Boolean, default=False),
    Column("created_at", DateTime, default=_utc_now),
    Column("last_login", DateTime, nullable=True, index=True),
    Column("hf_token", String(255), nullable=True),
)

api_keys = Table(
    "api_keys",
    metadata,
    _uuid_primary_key(),
    Column("user_id", String, ForeignKey("users.id"), nullable=False, index=True),
    Column("key_prefix", String(10), nullable=False),
    Column("hashed_key", String(255), nullable=False, unique=True, index=True),
    Column("created_at", DateTime, default=_utc_now),
    Column("last_used", DateTime, nullable=True),
)

documents = Table(
    "documents",
    metadata,
    _uuid_primary_key(),
    Column("user_id", String, ForeignKey("users.id"), nullable=False, index=True),
    Column("filename", String(255), nullable=False),
    Column("original_name", String(255), nullable=False),
    Column("file_size", Integer, default=0),
    Column("page_count", Integer, default=0),
    Column("chunk_count", Integer, default=0),
    Column("status", String(20), default="pending"),
    Column("error_message", Text, nullable=True),
    Column("uploaded_at", DateTime, default=_utc_now),
    Column("summary", Text, nullable=True),
)

chat_messages = Table(
    "chat_messages",
    metadata,
    _uuid_primary_key(),
    Column("user_id", String, ForeignKey("users.id"), nullable=False, index=True),
    Column("document_id", String, ForeignKey("documents.id"), nullable=True, index=True),
    Column("role", String(20), nullable=False),
    Column("content", Text, nullable=False),
    Column("sources_json", Text, nullable=True),
    Column("created_at", DateTime, default=_utc_now),
)

shared_messages = Table(
    "shared_messages",
    metadata,
    _uuid_primary_key(),
    Column("message_id", String, ForeignKey("chat_messages.id"), nullable=False, unique=True, index=True),
    Column("created_at", DateTime, default=_utc_now),
)
105
+
106
+
107
@dataclass
class MigrationStats:
    """Per-table counters of what happened to each source row."""

    # Each dict maps a table name to the number of rows with that outcome.
    inserted: dict[str, int] = field(default_factory=dict)
    reused: dict[str, int] = field(default_factory=dict)
    skipped: dict[str, int] = field(default_factory=dict)

    def add(self, table_name: str, action: str) -> None:
        """Increment the counter named by *action* for *table_name*.

        *action* is the name of one of the counter attributes
        (``"inserted"``, ``"reused"`` or ``"skipped"``).
        """
        counters = getattr(self, action)
        previous = counters.get(table_name, 0)
        counters[table_name] = previous + 1
115
+
116
+
117
def normalize_postgres_url(url: str) -> str:
    """Rewrite bare Postgres URLs so SQLAlchemy uses the psycopg v3 driver.

    URLs starting with ``postgres://`` or ``postgresql://`` get the
    ``postgresql+psycopg://`` scheme; any other URL is returned unchanged.
    """
    for scheme in ("postgres://", "postgresql://"):
        if url.startswith(scheme):
            remainder = url[len(scheme):]
            return "postgresql+psycopg://" + remainder
    return url
124
+
125
+
126
def sqlite_url_from_path(path: str) -> str:
    """Build a SQLAlchemy SQLite URL from a filesystem path.

    The path is resolved to an absolute path and rendered with forward slashes.
    """
    absolute = Path(path).resolve()
    return "sqlite:///" + absolute.as_posix()
128
+
129
+
130
def make_engine(url: str) -> Engine:
    """Create a SQLAlchemy engine for *url* with ``future=True``."""
    engine = create_engine(url, future=True)
    return engine
132
+
133
+
134
def make_session(engine: Engine) -> Session:
    """Open a new session bound to *engine*, with autoflush disabled."""
    factory = sessionmaker(bind=engine, autocommit=False, autoflush=False, future=True)
    return factory()
136
+
137
+
138
def reflected_table(engine: Engine, table_name: str) -> Table | None:
    """Reflect *table_name* from the database, or return ``None`` if it is absent."""
    if inspect(engine).has_table(table_name):
        # Reflect into a throwaway MetaData so it never mixes with the
        # module-level target schema.
        return Table(table_name, MetaData(), autoload_with=engine)
    return None
143
+
144
+
145
def fetch_rows(session: Session, table: Table) -> list[dict[str, Any]]:
    """Load every row of *table* as a plain dict.

    Rows are ordered by ``id`` when the table has such a column.
    """
    query = select(table).order_by(table.c.id) if "id" in table.c else select(table)
    records = session.execute(query).mappings().all()
    return [dict(record) for record in records]
150
+
151
+
152
def existing_id(session: Session, table: Table, source_id: str | None) -> str | None:
    """Return the ``id`` in *table* equal to *source_id*, if such a row exists.

    Returns ``None`` when *source_id* is falsy or no row matches.
    """
    if source_id:
        lookup = select(table.c.id).where(table.c.id == source_id)
        return session.execute(lookup).scalar_one_or_none()
    return None
156
+
157
+
158
def available_id(session: Session, table: Table, source_id: Any) -> str:
    """Pick an ``id`` that is not yet used in *table*.

    The stringified *source_id* is preferred (a new UUID when it is ``None``).
    If that value is already taken, fresh UUIDs are tried until one is free.
    """
    candidate = generate_uuid() if source_id is None else str(source_id)
    while existing_id(session, table, candidate) is not None:
        candidate = generate_uuid()
    return candidate
167
+
168
+
169
def first_existing_user(session: Session, row: dict[str, Any]) -> str | None:
    """Find a target user matching *row* by email, then by username.

    Returns the matching user's ``id``, or ``None`` when neither lookup
    finds a row or both values are empty.
    """
    email = row.get("email")
    if email:
        by_email = select(users.c.id).where(users.c.email == email)
        found = session.execute(by_email).scalar_one_or_none()
        if found:
            return found

    username = row.get("username")
    if not username:
        return None
    by_username = select(users.c.id).where(users.c.username == username)
    return session.execute(by_username).scalar_one_or_none()
179
+
180
+
181
def copy_users(
    source_session: Session,
    target_session: Session,
    source_table: Table,
    stats: MigrationStats,
) -> dict[str, str]:
    """Copy user rows into the target ``users`` table.

    A source row is *reused* when the target already holds a user with the
    same id, email or username; otherwise it is inserted. Rows from the
    legacy ``user`` table always get a fresh UUID instead of their old id.

    Returns:
        A mapping from stringified source id to the id used in the target.
    """
    fallback_created_at = datetime.now(timezone.utc)
    legacy_source = source_table.name == "user"
    mapping: dict[str, str] = {}

    for record in fetch_rows(source_session, source_table):
        source_key = str(record.get("id"))

        target_key = existing_id(target_session, users, source_key)
        if not target_key:
            target_key = first_existing_user(target_session, record)
        if target_key:
            mapping[source_key] = target_key
            stats.add("users", "reused")
            continue

        preferred_id = None if legacy_source else source_key
        target_key = available_id(target_session, users, preferred_id)
        # Accept either password column name; fall back to an empty string.
        password_hash = record.get("hashed_password") or record.get("password") or ""
        insert_stmt = users.insert().values(
            id=target_key,
            username=record["username"],
            email=record["email"],
            hashed_password=password_hash,
            is_admin=bool(record.get("is_admin") or False),
            created_at=record.get("created_at") or fallback_created_at,
            last_login=record.get("last_login"),
            hf_token=record.get("hf_token"),
        )
        target_session.execute(insert_stmt)
        mapping[source_key] = target_key
        stats.add("users", "inserted")

    return mapping
215
+
216
+
217
def copy_api_keys(
    source_session: Session,
    target_session: Session,
    source_table: Table | None,
    user_id_map: dict[str, str],
    stats: MigrationStats,
) -> dict[str, str]:
    """Copy API key rows into the target ``api_keys`` table.

    Rows whose owner is missing from *user_id_map* are skipped. Rows whose
    id or ``hashed_key`` already exists in the target are reused.

    Returns:
        A mapping from stringified source id to target id; empty when
        *source_table* is ``None``.
    """
    mapping: dict[str, str] = {}
    if source_table is None:
        return mapping

    for record in fetch_rows(source_session, source_table):
        source_key = str(record.get("id"))
        owner_id = user_id_map.get(str(record.get("user_id")))
        if not owner_id:
            stats.add("api_keys", "skipped")
            continue

        target_key = existing_id(target_session, api_keys, source_key)
        if not target_key:
            by_hash = select(api_keys.c.id).where(api_keys.c.hashed_key == record.get("hashed_key"))
            target_key = target_session.execute(by_hash).scalar_one_or_none()
        if target_key:
            mapping[source_key] = target_key
            stats.add("api_keys", "reused")
            continue

        target_key = available_id(target_session, api_keys, source_key)
        insert_stmt = api_keys.insert().values(
            id=target_key,
            user_id=owner_id,
            key_prefix=record["key_prefix"],
            hashed_key=record["hashed_key"],
            created_at=record.get("created_at") or datetime.now(timezone.utc),
            last_used=record.get("last_used"),
        )
        target_session.execute(insert_stmt)
        mapping[source_key] = target_key
        stats.add("api_keys", "inserted")

    return mapping
261
+
262
+
263
def copy_documents(
    source_session: Session,
    target_session: Session,
    source_table: Table | None,
    user_id_map: dict[str, str],
    stats: MigrationStats,
) -> dict[str, str]:
    """Copy document rows into the target ``documents`` table.

    Rows whose owner is missing from *user_id_map* are skipped. Rows whose
    id already exists in the target are reused.

    Returns:
        A mapping from stringified source id to target id; empty when
        *source_table* is ``None``.
    """
    mapping: dict[str, str] = {}
    if source_table is None:
        return mapping

    for record in fetch_rows(source_session, source_table):
        source_key = str(record.get("id"))
        owner_id = user_id_map.get(str(record.get("user_id")))
        if not owner_id:
            stats.add("documents", "skipped")
            continue

        already_there = existing_id(target_session, documents, source_key)
        if already_there:
            mapping[source_key] = already_there
            stats.add("documents", "reused")
            continue

        target_key = available_id(target_session, documents, source_key)
        insert_stmt = documents.insert().values(
            id=target_key,
            user_id=owner_id,
            filename=record["filename"],
            original_name=record["original_name"],
            # Missing or NULL counters and status fall back to the schema defaults.
            file_size=record.get("file_size") or 0,
            page_count=record.get("page_count") or 0,
            chunk_count=record.get("chunk_count") or 0,
            status=record.get("status") or "pending",
            error_message=record.get("error_message"),
            uploaded_at=record.get("uploaded_at") or datetime.now(timezone.utc),
            summary=record.get("summary"),
        )
        target_session.execute(insert_stmt)
        mapping[source_key] = target_key
        stats.add("documents", "inserted")

    return mapping
307
+
308
+
309
def copy_chat_messages(
    source_session: Session,
    target_session: Session,
    source_table: Table | None,
    user_id_map: dict[str, str],
    document_id_map: dict[str, str],
    stats: MigrationStats,
) -> dict[str, str]:
    """Copy chat message rows into the target ``chat_messages`` table.

    A row is skipped when its user is missing from *user_id_map*, or when it
    references a document that is missing from *document_id_map*. Rows whose
    id already exists in the target are reused.

    Returns:
        A mapping from stringified source id to target id; empty when
        *source_table* is ``None``.
    """
    mapping: dict[str, str] = {}
    if source_table is None:
        return mapping

    for record in fetch_rows(source_session, source_table):
        source_key = str(record.get("id"))
        owner_id = user_id_map.get(str(record.get("user_id")))

        source_document = record.get("document_id")
        if source_document:
            target_document = document_id_map.get(str(source_document))
        else:
            target_document = None

        # A message that points at a document must have that document migrated.
        dangling_document = source_document and not target_document
        if not owner_id or dangling_document:
            stats.add("chat_messages", "skipped")
            continue

        already_there = existing_id(target_session, chat_messages, source_key)
        if already_there:
            mapping[source_key] = already_there
            stats.add("chat_messages", "reused")
            continue

        target_key = available_id(target_session, chat_messages, source_key)
        insert_stmt = chat_messages.insert().values(
            id=target_key,
            user_id=owner_id,
            document_id=target_document,
            role=record["role"],
            content=record["content"],
            sources_json=record.get("sources_json"),
            created_at=record.get("created_at") or datetime.now(timezone.utc),
        )
        target_session.execute(insert_stmt)
        mapping[source_key] = target_key
        stats.add("chat_messages", "inserted")

    return mapping
352
+
353
+
354
def copy_shared_messages(
    source_session: Session,
    target_session: Session,
    source_table: Table | None,
    message_id_map: dict[str, str],
    stats: MigrationStats,
) -> None:
    """Copy shared-message rows into the target ``shared_messages`` table.

    Rows whose message is missing from *message_id_map* are skipped. A row is
    reused when the target already has the same id or already shares the
    same message. Does nothing when *source_table* is ``None``.
    """
    if source_table is None:
        return

    for record in fetch_rows(source_session, source_table):
        source_key = str(record.get("id"))
        target_message = message_id_map.get(str(record.get("message_id")))
        if not target_message:
            stats.add("shared_messages", "skipped")
            continue

        already_there = existing_id(target_session, shared_messages, source_key)
        if not already_there:
            by_message = select(shared_messages.c.id).where(shared_messages.c.message_id == target_message)
            already_there = target_session.execute(by_message).scalar_one_or_none()
        if already_there:
            stats.add("shared_messages", "reused")
            continue

        target_key = available_id(target_session, shared_messages, source_key)
        insert_stmt = shared_messages.insert().values(
            id=target_key,
            message_id=target_message,
            created_at=record.get("created_at") or datetime.now(timezone.utc),
        )
        target_session.execute(insert_stmt)
        stats.add("shared_messages", "inserted")
389
+
390
+
391
def migrate(
    sqlite_url: str,
    postgres_url: str,
    create_tables: bool,
    dry_run: bool,
) -> MigrationStats:
    """Copy all supported tables from the SQLite source into the Postgres target.

    Tables are copied in dependency order (users, API keys, documents, chat
    messages, shared messages) inside a single target transaction.

    Args:
        sqlite_url: SQLAlchemy URL of the source SQLite database.
        postgres_url: URL of the target database; bare ``postgres://`` and
            ``postgresql://`` URLs are rewritten to use psycopg.
        create_tables: Create any missing target tables before copying.
        dry_run: Roll back the target transaction instead of committing.

    Returns:
        Per-table counts of inserted, reused and skipped rows.

    Raises:
        RuntimeError: The source has neither a ``users`` nor a ``user`` table.
        Exception: Any other failure is logged, rolled back and re-raised.
    """
    source_engine = make_engine(sqlite_url)
    target_engine = make_engine(normalize_postgres_url(postgres_url))

    source_session = make_session(source_engine)
    target_session = make_session(target_engine)
    stats = MigrationStats()

    try:
        if create_tables:
            # Issue the DDL on the session's own connection so table creation
            # joins the migration transaction. Creating through the engine
            # would commit separately, leaving new tables behind after a
            # --dry-run or a failed migration.
            metadata.create_all(target_session.connection())

        current_users = reflected_table(source_engine, "users")
        legacy_users = reflected_table(source_engine, "user")
        # Prefer the current schema; fall back to the legacy single-table one.
        source_users = current_users if current_users is not None else legacy_users
        if source_users is None:
            raise RuntimeError("No users table found. Expected 'users' or legacy 'user'.")

        user_id_map = copy_users(source_session, target_session, source_users, stats)
        copy_api_keys(source_session, target_session, reflected_table(source_engine, "api_keys"), user_id_map, stats)
        document_id_map = copy_documents(
            source_session,
            target_session,
            reflected_table(source_engine, "documents"),
            user_id_map,
            stats,
        )
        message_id_map = copy_chat_messages(
            source_session,
            target_session,
            reflected_table(source_engine, "chat_messages"),
            user_id_map,
            document_id_map,
            stats,
        )
        copy_shared_messages(
            source_session,
            target_session,
            reflected_table(source_engine, "shared_messages"),
            message_id_map,
            stats,
        )

        if dry_run:
            target_session.rollback()
            LOGGER.info("Dry run complete; rolled back target transaction.")
        else:
            target_session.commit()
            LOGGER.info("Migration committed.")

        return stats
    except IntegrityError:
        target_session.rollback()
        LOGGER.exception("Migration failed because the target database rejected a row.")
        raise
    except Exception:
        target_session.rollback()
        LOGGER.exception("Migration failed; rolled back target transaction.")
        raise
    finally:
        source_session.close()
        target_session.close()
        source_engine.dispose()
        target_engine.dispose()
460
+
461
+
462
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the migration script.

    The Postgres URL defaults to the first non-empty value among the
    ``SUPABASE_DB_URL``, ``POSTGRES_DATABASE_URL`` and ``DATABASE_URL``
    environment variables.
    """
    env_postgres_url = (
        os.getenv("SUPABASE_DB_URL")
        or os.getenv("POSTGRES_DATABASE_URL")
        or os.getenv("DATABASE_URL")
    )

    cli = argparse.ArgumentParser(
        description="Migrate SQLite users/documents/chat history to Supabase Postgres."
    )
    cli.add_argument(
        "--sqlite-path",
        help="Path to the SQLite database file. Defaults to instance/users.db.",
        default="instance/users.db",
    )
    cli.add_argument("--sqlite-url", help="Full SQLite SQLAlchemy URL. Overrides --sqlite-path.")
    cli.add_argument(
        "--postgres-url",
        help="Supabase/Postgres SQLAlchemy URL. Also read from SUPABASE_DB_URL, POSTGRES_DATABASE_URL, or DATABASE_URL.",
        default=env_postgres_url,
    )
    cli.add_argument(
        "--no-create-tables",
        help="Do not create missing target tables before migrating.",
        action="store_true",
    )
    cli.add_argument(
        "--dry-run",
        help="Run the migration and roll back the target transaction.",
        action="store_true",
    )
    cli.add_argument("--verbose", help="Enable debug logging.", action="store_true")
    return cli.parse_args()
490
+
491
+
492
def main() -> int:
    """Run the migration from the command line and log per-table counts.

    Returns:
        ``0`` on success; ``2`` when no usable Postgres URL was given or the
        SQLite file named by ``--sqlite-path`` does not exist.
    """
    args = parse_args()
    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(levelname)s %(message)s",
    )

    postgres_url = args.postgres_url
    if not postgres_url or postgres_url.startswith("sqlite"):
        LOGGER.error("Provide a Supabase/Postgres URL with --postgres-url or SUPABASE_DB_URL.")
        return 2

    if args.sqlite_url:
        sqlite_url = args.sqlite_url
    else:
        # SQLite creates an empty database file when the path is missing.
        # Without this check, a mistyped path leaves a stray file and fails
        # later with a misleading "No users table found" error.
        if not Path(args.sqlite_path).is_file():
            LOGGER.error("SQLite database file not found: %s", args.sqlite_path)
            return 2
        sqlite_url = sqlite_url_from_path(args.sqlite_path)

    stats = migrate(
        sqlite_url=sqlite_url,
        postgres_url=postgres_url,
        create_tables=not args.no_create_tables,
        dry_run=args.dry_run,
    )

    # Report every table that had at least one inserted, reused or skipped row.
    for table_name in sorted(set(stats.inserted) | set(stats.reused) | set(stats.skipped)):
        LOGGER.info(
            "%s: inserted=%s reused=%s skipped=%s",
            table_name,
            stats.inserted.get(table_name, 0),
            stats.reused.get(table_name, 0),
            stats.skipped.get(table_name, 0),
        )
    return 0
521
+
522
+
523
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)