File size: 4,114 Bytes
126cf9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""

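Pluggable DB backend: SQLite (local default) / Cloudflare D1 (cloud-native)
============================================================================

Selection: D1 is enabled only when all three of the following environment
variables are set (non-empty after stripping whitespace); otherwise the
module falls back to the local sqlite3 backend:

  HKDY_D1_ACCOUNT     Cloudflare Account ID
  HKDY_D1_DB_ID       D1 database UUID
  HKDY_D1_TOKEN       API token (Workers D1 Edit permission)

The interface is compatible with sqlite3:

  con = d1db.connect(path_or_ignored)
  con.execute(sql, params).fetchone() / .fetchall()
  con.commit()
  con.close()

Errors: D1 failures are always raised as sqlite3.OperationalError, so the
`except sqlite3.OperationalError` handlers in app.py need no changes.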

"""

import json
import os
import sqlite3
import urllib.error
import urllib.request


# Cloudflare D1 settings, read once at import time. Unset or
# whitespace-only variables become the empty string.
D1_ACCOUNT, D1_DB_ID, D1_TOKEN = (
    os.environ.get(var, "").strip()
    for var in ("HKDY_D1_ACCOUNT", "HKDY_D1_DB_ID", "HKDY_D1_TOKEN")
)
# The D1 backend is selected only when all three settings are non-empty.
USE_D1 = all((D1_ACCOUNT, D1_DB_ID, D1_TOKEN))


class _D1Row(dict):
    """Row object for D1 results that mimics ``sqlite3.Row``.

    It is a ``dict`` (column name -> value) that additionally offers the
    sequence behaviour of ``sqlite3.Row``: integer and slice indexing by
    column position, and iteration / tuple-unpacking over *values* in column
    order. Column order is the insertion order of the mapping it was built
    from. Membership tests (``"col" in row``) still check column names.
    """

    def __init__(self, mapping):
        super().__init__(mapping)
        # Remember column order so positional access is stable.
        self._keys = list(mapping.keys())

    def __getitem__(self, key):
        if isinstance(key, slice):
            # sqlite3.Row returns a tuple of values for slices.
            return tuple(dict.__getitem__(self, k) for k in self._keys[key])
        if isinstance(key, int):
            return dict.__getitem__(self, self._keys[key])
        return dict.__getitem__(self, key)

    def __iter__(self):
        # Iterate values like sqlite3.Row (a sequence), so `a, b = row`,
        # `tuple(row)` and `list(row)` behave the same on both backends.
        return (dict.__getitem__(self, k) for k in self._keys)

    def keys(self):
        # Return a copy so callers cannot corrupt the positional index.
        return list(self._keys)


class _D1Cursor:
    """Forward-only cursor over an already-fetched D1 result set.

    Mirrors ``sqlite3.Cursor`` consumption semantics: ``fetchone`` advances,
    ``fetchmany``/``fetchall`` return only rows not yet consumed, and
    iterating the cursor consumes rows as well. Once exhausted, ``fetchone``
    returns None and the other fetch methods return an empty list.
    """

    # Default batch size for fetchmany(), as in sqlite3.Cursor.
    arraysize = 1

    def __init__(self, results):
        self._rows = [_D1Row(r) for r in (results or [])]
        self._pos = 0  # index of the next row to hand out

    def fetchone(self):
        """Return the next row, or None when no rows remain."""
        if self._pos >= len(self._rows):
            return None
        row = self._rows[self._pos]
        self._pos += 1
        return row

    def fetchmany(self, size=None):
        """Return up to *size* (default ``arraysize``) of the remaining rows."""
        if size is None:
            size = self.arraysize
        batch = self._rows[self._pos:self._pos + size]
        self._pos += len(batch)
        return batch

    def fetchall(self):
        """Return all remaining rows as a list and exhaust the cursor."""
        rest = self._rows[self._pos:]
        self._pos = len(self._rows)
        return rest

    def __iter__(self):
        return self

    def __next__(self):
        row = self.fetchone()
        if row is None:
            raise StopIteration
        return row


class _D1Conn:
    """Minimal ``sqlite3.Connection`` stand-in backed by the Cloudflare D1 HTTP API.

    Each :meth:`execute` performs one POST to the database's ``/query``
    endpoint. Transport failures, HTTP error statuses and API replies with
    ``success: false`` are all raised as ``sqlite3.OperationalError``, so code
    written against sqlite3 can keep its existing ``except`` clauses.
    """

    # Placeholder so `con.row_factory = sqlite3.Row` does not fail; the value
    # is never read here — D1 rows are always returned as _D1Row objects.
    row_factory = None

    def __init__(self, timeout=15):
        """Prepare the endpoint URL and auth headers from the module-level D1 settings.

        :param timeout: socket timeout in seconds for each HTTP request.
        """
        self._timeout = timeout
        self._url = (
            f"https://api.cloudflare.com/client/v4/accounts/{D1_ACCOUNT}"
            f"/d1/database/{D1_DB_ID}/query"
        )
        self._headers = {
            "Authorization": f"Bearer {D1_TOKEN}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _error_message(errors):
        """Join the ``message`` fields of a Cloudflare ``errors`` list into one string.

        Tolerates a missing/empty list and entries that are not dicts, and
        never returns an empty string.
        """
        parts = [
            str(err.get("message", err)) if isinstance(err, dict) else str(err)
            for err in (errors or ())
        ]
        return "; ".join(parts) or "unknown error"

    @classmethod
    def _http_error_message(cls, err):
        """Build an error message from an ``HTTPError`` returned by the D1 API.

        Prefers the ``errors`` list of a JSON error envelope; otherwise falls
        back to the first 200 characters of the raw response body.
        """
        try:
            raw = err.read().decode("utf-8", "replace")
        except OSError:
            raw = ""
        try:
            envelope = json.loads(raw)
        except ValueError:
            envelope = None
        if isinstance(envelope, dict) and envelope.get("errors"):
            detail = cls._error_message(envelope["errors"])
        else:
            detail = raw[:200]
        return f"D1 HTTP {err.code}: {detail}"

    @staticmethod
    def _encode_params(params):
        """Convert *params* to the positional JSON list sent to the D1 API.

        :raises sqlite3.NotSupportedError: for dict (named) parameters, since
            converting a dict with ``list()`` would bind its keys, not its values.
        """
        if params is None:
            return []
        if isinstance(params, dict):
            raise sqlite3.NotSupportedError(
                "D1 backend supports only positional ('?') parameters"
            )
        return list(params)

    def execute(self, sql, params=()):
        """Run one SQL statement on D1 and return a cursor over its rows.

        :param sql: SQL text using positional ``?`` placeholders.
        :param params: sequence of values bound to the placeholders.
        :returns: a _D1Cursor over the rows of the first result block.
        :raises sqlite3.OperationalError: on network, HTTP or D1-reported errors.
        :raises sqlite3.NotSupportedError: if *params* is a dict.
        """
        payload = json.dumps(
            {"sql": sql, "params": self._encode_params(params)}
        ).encode("utf-8")
        req = urllib.request.Request(
            self._url, data=payload, method="POST", headers=self._headers,
        )
        try:
            with urllib.request.urlopen(req, timeout=self._timeout) as resp:
                data = json.load(resp)
        except urllib.error.HTTPError as e:
            raise sqlite3.OperationalError(self._http_error_message(e)) from e
        except Exception as e:
            # Timeouts, connection failures, malformed JSON responses, ...
            raise sqlite3.OperationalError(f"D1 request failed: {e}") from e

        if not isinstance(data, dict) or not data.get("success"):
            errors = data.get("errors") if isinstance(data, dict) else None
            # Raised as OperationalError so app.py's
            # `except sqlite3.OperationalError: pass` can catch expected
            # failures (e.g. ALTER TABLE ADD COLUMN on an existing column).
            raise sqlite3.OperationalError(f"D1: {self._error_message(errors)}")

        # D1 replies with result=[{meta: ..., results: [...], success: ...}].
        blocks = data.get("result") or []
        first = blocks[0] if blocks else {}
        if first.get("success") is False:
            raise sqlite3.OperationalError("D1: statement failed")
        return _D1Cursor(first.get("results") or [])

    def commit(self):
        """No-op: each D1 query is auto-committed independently."""

    def close(self):
        """No-op: every request opens and closes its own HTTP connection."""

    # Context-manager support so `with connect(...) as con:` works.
    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()


def connect(path_or_ignored, **kwargs):
    """Drop-in replacement for ``sqlite3.connect``.

    In D1 mode both *path_or_ignored* and *kwargs* are ignored and a D1
    connection object is returned. Otherwise a local sqlite3 connection is
    opened with *kwargs* forwarded to ``sqlite3.connect`` (e.g. ``timeout``,
    ``check_same_thread``), and its ``row_factory`` is set to ``sqlite3.Row``.
    """
    if USE_D1:
        return _D1Conn()
    con = sqlite3.connect(path_or_ignored, **kwargs)
    con.row_factory = sqlite3.Row
    return con


def is_d1():
    """Report whether the Cloudflare D1 backend is active.

    Returns True for D1 and False for local SQLite. The choice is fixed at
    import time from the HKDY_D1_* environment variables.
    """
    active = USE_D1
    return active


def backend_info():
    """Return a short, human-readable description of the active backend.

    For D1 only the first 8 characters of the account and database IDs are
    shown; the API token is never included.
    """
    if not USE_D1:
        return "SQLite (local)"
    acct, db = D1_ACCOUNT[:8], D1_DB_ID[:8]
    return f"D1 (account={acct}..., db={db}...)"