File size: 3,913 Bytes
353a253
a13a754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353a253
 
 
a13a754
353a253
 
a13a754
353a253
a13a754
353a253
a13a754
353a253
 
a13a754
233e17d
 
 
 
 
a13a754
 
233e17d
 
 
a13a754
 
 
233e17d
a13a754
353a253
a13a754
 
 
233e17d
 
 
 
 
 
353a253
 
 
 
 
 
a13a754
353a253
 
 
 
 
 
 
 
 
 
a13a754
 
353a253
 
 
 
 
 
 
a13a754
353a253
 
a13a754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353a253
a13a754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353a253
 
a13a754
 
 
353a253
 
a13a754
 
 
 
 
 
 
233e17d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
db.py — Supabase (PostgreSQL) metadata store.
Table: files

Before first run, create the table in Supabase SQL Editor:

    CREATE TABLE IF NOT EXISTS files (
        id          UUID DEFAULT gen_random_uuid() PRIMARY KEY,
        file_id     TEXT UNIQUE NOT NULL,
        filename    TEXT NOT NULL,
        mime_type   TEXT NOT NULL,
        size_bytes  BIGINT NOT NULL,
        tg_message_id BIGINT NOT NULL,
        tg_file_id  TEXT,
        public_url  TEXT NOT NULL,
        custom_path TEXT UNIQUE,
        uploaded_at TIMESTAMPTZ DEFAULT now()
    );
"""

import os
from datetime import datetime, timezone
from typing import Optional

from supabase import create_client, Client

# Module-level Supabase client, created lazily by _get_client() on first use
# and reused by every helper below. None until that first call succeeds.
_supabase: Optional[Client] = None

# Name of the table that stores file metadata (schema in the module docstring).
TABLE = "files"


def _get_client() -> Client:
    """
    Return the shared Supabase client, creating it on first use.

    SUPABASE_URL and SUPABASE_KEY are read at call time rather than at import
    time, so values loaded after import (e.g. by load_dotenv()) are seen, and
    the client is only cached once both values are non-empty.

    Raises:
        RuntimeError: if SUPABASE_URL or SUPABASE_KEY is unset or empty.
    """
    global _supabase
    if _supabase is not None:
        return _supabase

    supabase_url = os.getenv("SUPABASE_URL", "")
    supabase_key = os.getenv("SUPABASE_KEY", "")
    if not (supabase_url and supabase_key):
        raise RuntimeError(
            "SUPABASE_URL and SUPABASE_KEY must be set in environment / .env"
        )
    _supabase = create_client(supabase_url, supabase_key)
    return _supabase


def init_db():
    """
    Check that Supabase is reachable and the files table can be queried.

    Runs a single-row ``select file_id`` against TABLE as a probe. Any failure,
    including missing credentials reported by _get_client(), is re-raised as a
    RuntimeError chained to the original exception.

    Raises:
        RuntimeError: if the client cannot be created or the probe query fails.
    """
    try:
        probe = _get_client().table(TABLE).select("file_id").limit(1)
        probe.execute()
    except Exception as err:
        # Wrap with a clearer message so the startup caller can log it cleanly.
        raise RuntimeError(f"Supabase connection failed: {err}") from err


# ──────────────────────────────────────────────────
#  CRUD
# ──────────────────────────────────────────────────

def save_file_record(
    *,
    file_id: str,
    filename: str,
    mime_type: str,
    size: int,
    tg_message_id: int,
    tg_file_id: str | None,
    public_url: str,
    custom_path: str | None = None,
):
    """
    Insert one metadata row describing an uploaded file into TABLE.

    ``size`` is written to the ``size_bytes`` column, and ``uploaded_at`` is
    stamped with the current UTC time as an ISO-8601 string. ``custom_path`` is
    included only when it is truthy; otherwise the column is left out of the
    insert entirely.
    """
    db = _get_client()
    payload = {
        "file_id":       file_id,
        "filename":      filename,
        "mime_type":     mime_type,
        "size_bytes":    size,
        "tg_message_id": tg_message_id,
        "tg_file_id":    tg_file_id,
        "public_url":    public_url,
        "uploaded_at":   datetime.now(timezone.utc).isoformat(),
        # Omit the key (rather than sending NULL/"") when no custom path given.
        **({"custom_path": custom_path} if custom_path else {}),
    }
    db.table(TABLE).insert(payload).execute()


def get_file_record(file_id: str) -> dict | None:
    """
    Look up the row whose ``file_id`` column equals *file_id*.

    Returns the first matching row as a dict, or None when nothing matches.
    """
    query = (
        _get_client()
        .table(TABLE)
        .select("*")
        .eq("file_id", file_id)
        .limit(1)
    )
    rows = query.execute().data
    return rows[0] if rows else None


def get_file_by_custom_path(custom_path: str) -> dict | None:
    """
    Look up the row whose ``custom_path`` column equals *custom_path*.

    Returns the first matching row as a dict, or None when nothing matches.
    """
    db = _get_client()
    result = (
        db.table(TABLE)
        .select("*")
        .eq("custom_path", custom_path)
        .limit(1)
        .execute()
    )
    # At most one row was requested; take it if present.
    return next(iter(result.data or []), None)


def list_file_records(limit: int = 50, offset: int = 0) -> list[dict]:
    """
    Return one page of file rows, most recent ``uploaded_at`` first.

    Args:
        limit: maximum number of rows to return. A value <= 0 requests no rows
            and returns an empty list without contacting Supabase.
        offset: number of rows to skip before the page starts; must be >= 0.

    Returns:
        A list of row dicts (empty when no rows fall in the requested page).

    Raises:
        ValueError: if *offset* is negative.
    """
    if offset < 0:
        raise ValueError(f"offset must be >= 0, got {offset}")
    if limit <= 0:
        # .range() takes an inclusive end index (offset + limit - 1), so a
        # non-positive limit would yield end < start; short-circuit instead.
        return []

    client = _get_client()
    resp = (
        client.table(TABLE)
        .select("*")
        .order("uploaded_at", desc=True)
        .range(offset, offset + limit - 1)
        .execute()
    )
    return resp.data or []


def delete_file_record(file_id: str):
    """
    Delete the row(s) in TABLE whose ``file_id`` equals *file_id*.

    The response is not inspected, so the return value (None) does not tell
    the caller whether any row was actually removed.
    """
    (
        _get_client()
        .table(TABLE)
        .delete()
        .eq("file_id", file_id)
        .execute()
    )


def count_files() -> int:
    """
    Return the total number of rows in TABLE.

    Relies on an exact count (``count="exact"``), whose total is reported
    independently of how many rows the response body contains. The query is
    therefore limited to a single row so the whole ``file_id`` column is not
    transferred just to be counted.

    Returns:
        The row count, or 0 when the response carries no count.
    """
    client = _get_client()
    resp = (
        client.table(TABLE)
        .select("file_id", count="exact")
        # The exact total is unaffected by limit; this only trims the payload.
        .limit(1)
        .execute()
    )
    return resp.count or 0