File size: 3,827 Bytes
ac5551d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
database/connection.py — Async SQLite connection & migration bootstrap.
Single module responsible for the DB lifecycle. All queries share the one connection returned by get_db().
"""
from __future__ import annotations

import asyncio
import logging
import sqlite3
from pathlib import Path

import aiosqlite

from config import settings

# Serializes lazy creation and teardown of the shared connection (get_db/close_db).
_lock: asyncio.Lock = asyncio.Lock()
# Process-wide connection: None until get_db() opens it, and again after close_db().
_db: aiosqlite.Connection | None = None


async def get_db() -> aiosqlite.Connection:
    """Return the process-wide connection, opening it on first use.

    The lock is always taken, so concurrent first callers cannot each open
    their own connection, and a caller never races a close_db() in progress.
    """
    global _db
    async with _lock:
        if _db is not None:
            return _db
        _db = await _open_connection()
        return _db


async def _open_connection() -> aiosqlite.Connection:
    """Open a new connection to ``settings.db_path``, configure it, and migrate it.

    Ensures the configured directories exist, installs ``aiosqlite.Row`` as
    the row factory, applies the connection PRAGMAs, runs the schema
    migrations and commits.

    If any step after connecting fails, the half-initialised connection is
    closed before the error propagates. Otherwise each failed attempt would
    leak a connection, because get_db() leaves ``_db`` unset and retries on
    its next call.
    """
    settings.ensure_dirs()
    conn = await aiosqlite.connect(settings.db_path, check_same_thread=False)
    try:
        conn.row_factory = aiosqlite.Row
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute("PRAGMA foreign_keys=ON")
        await conn.execute("PRAGMA synchronous=NORMAL")
        await conn.execute("PRAGMA cache_size=-65536")   # 64 MB page cache
        await _run_migrations(conn)
        await conn.commit()
    except BaseException:
        # BaseException so that task cancellation also releases the connection.
        await conn.close()
        raise
    return conn


async def _run_migrations(conn: aiosqlite.Connection) -> None:
    """Bring the schema up to date; safe to run on every startup.

    Steps, followed by a single commit:
      1. Execute each bundled schema file that exists next to this module
         (expected to use CREATE ... IF NOT EXISTS).
      2. Add columns introduced after the original schema to ``models`` and
         ``datasets`` when they are missing.
      3. Drop leftover ``*_old`` tables from earlier, interrupted migrations.
    """
    await _apply_schema_files(conn)

    await _add_missing_columns(conn, "models", (
        ("download_url", "TEXT"),
        ("active_version", "TEXT"),
        ("metrics", "TEXT NOT NULL DEFAULT '{}'"),
        ("project_id", "TEXT REFERENCES projects(id) ON DELETE CASCADE"),
    ))
    await _add_missing_columns(conn, "datasets", (
        ("active_version", "TEXT NOT NULL DEFAULT 'v1'"),
        ("roboflow_id", "TEXT"),
        ("health_score", "INTEGER NOT NULL DEFAULT 0"),
    ))

    await _drop_stale_tables(conn, ("datasets_old", "dataset_jobs_old"))

    # Commit so that other users of the database see the cleaned-up schema immediately.
    await conn.commit()


async def _apply_schema_files(conn: aiosqlite.Connection) -> None:
    """Run each bundled schema script that exists, in the listed order."""
    base = Path(__file__).parent
    for schema_file in ("schema.sql", "dataset_schema.sql", "benchmark_schema.sql"):
        path = base / schema_file
        if path.exists():
            await conn.executescript(path.read_text(encoding="utf-8"))


async def _add_missing_columns(
    conn: aiosqlite.Connection,
    table: str,
    columns: tuple[tuple[str, str], ...],
) -> None:
    """Add each ``(name, ddl)`` column that *table* lacks; no-op if *table* does not exist."""
    # Identifiers come from hard-coded constants (never user input); PRAGMA and
    # DDL statements cannot take bound parameters, hence the f-strings.
    async with conn.execute(f"PRAGMA table_info({table})") as cur:
        existing = {row[1] for row in await cur.fetchall()}
    if not existing:  # table_info yields no rows for a missing table
        return
    for name, ddl in columns:
        if name not in existing:
            await conn.execute(f"ALTER TABLE {table} ADD COLUMN {name} {ddl}")


async def _drop_stale_tables(conn: aiosqlite.Connection, tables: tuple[str, ...]) -> None:
    """Best-effort DROP of leftover tables; a failure is logged, not raised.

    ``IF EXISTS`` already covers a table that is already gone, so only real
    SQLite errors are caught here. Cancellation and other exceptions propagate.
    """
    for table in tables:
        try:
            await conn.execute(f"DROP TABLE IF EXISTS {table}")
        except sqlite3.Error:
            logging.getLogger(__name__).warning(
                "Could not drop leftover table %s", table, exc_info=True
            )

async def close_db() -> None:
    """Close the shared connection, if one is open.

    The global is cleared *before* awaiting ``close()``. If closing raises,
    a later get_db() therefore opens a fresh connection instead of handing
    out the half-closed one. Calling this when nothing is open is a no-op.
    """
    global _db
    async with _lock:
        conn, _db = _db, None
        if conn is not None:
            await conn.close()