| """
|
| 可插拔 DB 后端: SQLite (默认本地) / Cloudflare D1 (云原生)
|
| =======================================================
|
|
|
| 选择机制: 当同时存在以下 3 个环境变量时, 启用 D1; 否则回退到本地 sqlite3:
|
| HKDY_D1_ACCOUNT Cloudflare Account ID
|
| HKDY_D1_DB_ID D1 database UUID
|
| HKDY_D1_TOKEN API token (Workers D1 Edit 权限)
|
|
|
| 接口与 sqlite3 兼容:
|
| con = d1db.connect(path_or_ignored)
|
| con.execute(sql, params).fetchone() / .fetchall()
|
| con.commit()
|
| con.close()
|
|
|
| 异常: D1 错误统一抛 sqlite3.OperationalError, 让 app.py 的
|
| `except sqlite3.OperationalError` 无需改动.
|
| """
|
|
|
| import json
|
| import os
|
| import sqlite3
|
| import urllib.error
|
| import urllib.request
|
|
|
|
|
# Cloudflare D1 credentials, read once at import time. Each value is the
# whitespace-stripped environment variable, or "" when the variable is unset.
D1_ACCOUNT, D1_DB_ID, D1_TOKEN = (
    os.environ.get(var_name, "").strip()
    for var_name in ("HKDY_D1_ACCOUNT", "HKDY_D1_DB_ID", "HKDY_D1_TOKEN")
)

# The D1 backend is selected only when all three settings are non-empty;
# otherwise the module falls back to local sqlite3.
USE_D1 = all((D1_ACCOUNT, D1_DB_ID, D1_TOKEN))
|
|
|
|
|
class _D1Row(dict):
    """Dict-backed stand-in for ``sqlite3.Row``.

    A value can be read by column name (``row["name"]``), by position
    (``row[0]``; negative indexes count from the end) or by slice
    (``row[0:2]``, returned as a tuple of values, as ``sqlite3.Row`` does).
    Column order is the key order of the mapping passed to the constructor.
    """

    def __init__(self, mapping):
        """Copy *mapping* into the row and remember its key order."""
        super().__init__(mapping)
        # Snapshot of the column order, used for positional access.
        self._keys = list(mapping.keys())

    def __getitem__(self, key):
        """Return the value for a column name, an integer position or a slice.

        Raises:
            KeyError: unknown column name.
            IndexError: integer position out of range.
        """
        if isinstance(key, slice):
            return tuple(
                dict.__getitem__(self, name) for name in self._keys[key]
            )
        if isinstance(key, int):
            return dict.__getitem__(self, self._keys[key])
        return dict.__getitem__(self, key)

    def keys(self):
        """Return the column names as a new list, in column order."""
        # Hand out a copy: returning the internal list would let a caller
        # that sorts or edits it silently break positional indexing.
        return list(self._keys)
|
|
|
|
|
class _D1Cursor:
    """In-memory result set with the fetch interface of a sqlite3 cursor.

    All rows are held in a list; a read position is tracked so that
    ``fetchone()``, ``fetchall()`` and iteration each consume rows and never
    hand the same row out twice, as a sqlite3 cursor behaves.
    """

    def __init__(self, results):
        """Wrap *results* (an iterable of mappings, or ``None``) as rows."""
        self._rows = [_D1Row(r) for r in (results or [])]
        # Index of the next row that has not been handed out yet.
        self._pos = 0

    def fetchone(self):
        """Return the next unread row, or ``None`` when none are left."""
        if self._pos >= len(self._rows):
            return None
        row = self._rows[self._pos]
        self._pos += 1
        return row

    def fetchall(self):
        """Return every unread row as a new list and mark them all read."""
        remaining = self._rows[self._pos:]
        self._pos = len(self._rows)
        return remaining

    def __iter__(self):
        """Return the cursor itself; iteration consumes the unread rows."""
        return self

    def __next__(self):
        """Return the next unread row, raising ``StopIteration`` at the end."""
        row = self.fetchone()
        if row is None:
            raise StopIteration
        return row
|
|
|
|
|
class _D1Conn:
    """Connection look-alike that sends each statement to the D1 HTTP query API.

    Provides the subset of the ``sqlite3.Connection`` interface implemented
    here: ``execute``, ``commit``, ``close`` and context-manager support.
    Every failure is reported as ``sqlite3.OperationalError``.
    """

    # Exists so callers may assign ``con.row_factory`` as they would on a
    # sqlite3 connection; this class never reads the value.
    row_factory = None

    def __init__(self):
        """Build the query endpoint URL and auth headers from module settings."""
        self._url = (
            f"https://api.cloudflare.com/client/v4/accounts/{D1_ACCOUNT}"
            f"/d1/database/{D1_DB_ID}/query"
        )
        self._headers = {
            "Authorization": f"Bearer {D1_TOKEN}",
            "Content-Type": "application/json",
        }

    def execute(self, sql, params=()):
        """POST one SQL statement and return its rows as a ``_D1Cursor``.

        Args:
            sql: SQL text.
            params: positional parameters; sent as a JSON list.

        Returns:
            A ``_D1Cursor`` over the ``results`` of the first result block,
            or an empty cursor when the response carries no result block.

        Raises:
            sqlite3.OperationalError: HTTP error status, transport or JSON
                decoding failure, a response that is not a JSON object, or
                a response whose ``success`` flag is falsy.
        """
        payload = json.dumps({"sql": sql, "params": list(params)}).encode("utf-8")
        req = urllib.request.Request(
            self._url, data=payload, method="POST", headers=self._headers,
        )
        try:
            with urllib.request.urlopen(req, timeout=15) as resp:
                data = json.load(resp)
        except urllib.error.HTTPError as e:
            # Include the start of the error body in the message.
            detail = e.read().decode("utf-8", "replace")
            raise sqlite3.OperationalError(
                f"D1 HTTP {e.code}: {detail[:200]}"
            ) from e
        except Exception as e:
            raise sqlite3.OperationalError(f"D1 request failed: {e}") from e

        # Anything but a JSON object cannot be inspected with .get() below;
        # report it instead of letting an AttributeError escape.
        if not isinstance(data, dict):
            raise sqlite3.OperationalError(
                f"D1: unexpected response: {str(data)[:200]}"
            )

        if not data.get("success"):
            raise sqlite3.OperationalError(f"D1: {self._error_message(data)}")

        result_blocks = data.get("result") or []
        rows = result_blocks[0].get("results", []) if result_blocks else []
        return _D1Cursor(rows)

    @staticmethod
    def _error_message(data):
        """Join the entries of ``data["errors"]`` into one readable string."""
        parts = []
        # ``errors`` may be missing or null; entries may not be dicts.
        for err in data.get("errors") or []:
            text = err.get("message", err) if isinstance(err, dict) else err
            parts.append(str(text))
        return "; ".join(parts) or "unknown error"

    def commit(self):
        """Do nothing; execute() sends each statement immediately."""
        pass

    def close(self):
        """Do nothing; this object holds only a URL and headers."""
        pass

    def __enter__(self):
        """Return the connection itself for use in a ``with`` block."""
        return self

    def __exit__(self, *a):
        """Call close(); exceptions from the ``with`` body propagate."""
        self.close()
|
|
|
|
|
def connect(path_or_ignored):
    """Drop-in replacement for ``sqlite3.connect``.

    Returns a ``_D1Conn`` when the D1 backend is enabled; the path argument
    is ignored in that case. Otherwise opens the SQLite database at
    *path_or_ignored* with ``sqlite3.Row`` as its row factory.
    """
    if not USE_D1:
        local_con = sqlite3.connect(path_or_ignored)
        local_con.row_factory = sqlite3.Row
        return local_con
    return _D1Conn()
|
|
|
|
|
def is_d1():
    """Return True when the D1 backend is active, False for local SQLite."""
    return bool(USE_D1)
|
|
|
|
|
def backend_info():
    """Return a short human-readable description of the active backend.

    In D1 mode only the first 8 characters of the account ID and the
    database ID are included in the text.
    """
    if not USE_D1:
        return "SQLite (local)"
    return f"D1 (account={D1_ACCOUNT[:8]}..., db={D1_DB_ID[:8]}...)"
|
|
|