FUCAT's picture
Deploy grok2api to HF Spaces (Docker)
7e55e53
Raw
History Blame Contribute Delete
34 kB
"""Shared SQLAlchemy-based backend for MySQL and PostgreSQL.
Both dialects share the same table schema and query logic;
only the DDL fragments and upsert syntax differ.
"""
import json
import os
import ssl
from threading import Lock
from typing import Any
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
import asyncio
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
from app.platform.runtime.clock import now_ms
from ..commands import AccountPatch, AccountUpsert, BulkReplacePoolCommand, ListAccountsQuery
from ..enums import AccountStatus
from ..models import (
AccountChangeSet,
AccountMutationResult,
AccountPage,
AccountRecord,
RuntimeSnapshot,
)
from ..quota_defaults import default_quota_set
# Physical table names used by the schema below and by the raw-SQL migrations.
_TBL_ACCOUNTS = "accounts"
_TBL_META = "account_meta"
metadata = sa.MetaData()
# One row per account token. Structured fields (tags, quota_*, ext) are stored
# as JSON-serialised text; timestamp columns are integers written from now_ms().
# Rows are soft-deleted by setting deleted_at. `revision` holds the global
# revision of the last write that touched the row.
accounts_table = sa.Table(
    _TBL_ACCOUNTS,
    metadata,
    sa.Column("token", sa.String(512), primary_key=True),
    sa.Column("pool", sa.Text, nullable=False, default="basic"),
    sa.Column("status", sa.Text, nullable=False, default="active"),
    sa.Column("created_at", sa.BigInteger, nullable=False),
    sa.Column("updated_at", sa.BigInteger, nullable=False),
    sa.Column("tags", sa.Text, nullable=False, default="[]"),
    sa.Column("quota_auto", sa.Text, nullable=False, default="{}"),
    sa.Column("quota_fast", sa.Text, nullable=False, default="{}"),
    sa.Column("quota_expert", sa.Text, nullable=False, default="{}"),
    sa.Column("quota_heavy", sa.Text, nullable=False, default="{}"),
    # Added after the initial schema; older databases get it via ALTER TABLE
    # in SqlAccountRepository._ensure_columns.
    sa.Column("quota_grok_4_3", sa.Text, nullable=False, default="{}"),
    sa.Column("usage_use_count", sa.Integer, nullable=False, default=0),
    sa.Column("usage_fail_count", sa.Integer, nullable=False, default=0),
    sa.Column("usage_sync_count", sa.Integer, nullable=False, default=0),
    sa.Column("last_use_at", sa.BigInteger),
    sa.Column("last_fail_at", sa.BigInteger),
    sa.Column("last_fail_reason", sa.Text),
    sa.Column("last_sync_at", sa.BigInteger),
    sa.Column("last_clear_at", sa.BigInteger),
    sa.Column("state_reason", sa.Text),
    sa.Column("deleted_at", sa.BigInteger),
    sa.Column("ext", sa.Text, nullable=False, default="{}"),
    sa.Column("revision", sa.BigInteger, nullable=False, default=0),
)
# Key/value table. The "revision" key stores the global revision counter as
# text; it is incremented via CAST in SqlAccountRepository._bump_revision.
meta_table = sa.Table(
    _TBL_META,
    metadata,
    sa.Column("key", sa.String(128), primary_key=True),
    sa.Column("value", sa.Text, nullable=False),
)
# URL query keys that select the SSL mode; recognised for both dialects.
# _resolve_ssl_mode uses the first non-empty one in this order.
_SQL_SSL_MODE_PARAM_KEYS = ("sslmode", "ssl-mode", "ssl")
# libpq-style certificate parameters (file paths) fed into the asyncpg SSLContext.
_PG_SSL_CERT_PARAM_KEYS = ("sslrootcert", "sslcert", "sslkey")
# libpq parameters that are stripped from the URL but then rejected as
# unsupported by _validate_pg_ssl_options.
_PG_SSL_UNSUPPORTED_PARAM_KEYS = (
    "sslcrl",
    "sslpassword",
    "sslnegotiation",
    "ssl_min_protocol_version",
    "ssl_max_protocol_version",
)
# Every query key removed from a PostgreSQL URL by _extract_sql_ssl_options.
_PG_SSL_QUERY_PARAM_KEYS = (
    *_SQL_SSL_MODE_PARAM_KEYS,
    *_PG_SSL_CERT_PARAM_KEYS,
    *_PG_SSL_UNSUPPORTED_PARAM_KEYS,
)
# MySQL certificate/CA parameters (file or directory paths).
_MYSQL_SSL_CERT_PARAM_KEYS = ("ssl-ca", "ssl-capath", "ssl-cert", "ssl-key")
# MySQL tuning parameters that do not carry certificate material.
_MYSQL_SSL_OPTIONAL_PARAM_KEYS = ("ssl-check-hostname", "ssl-cipher")
# Every query key removed from a MySQL URL by _extract_sql_ssl_options.
_MYSQL_SSL_QUERY_PARAM_KEYS = (
    *_SQL_SSL_MODE_PARAM_KEYS,
    *_MYSQL_SSL_CERT_PARAM_KEYS,
    *_MYSQL_SSL_OPTIONAL_PARAM_KEYS,
)
# Spellings accepted by _parse_ssl_bool (compared after strip().lower()).
_SSL_BOOL_TRUE = {"1", "true", "yes", "on"}
_SSL_BOOL_FALSE = {"0", "false", "no", "off"}
# User spelling -> canonical libpq sslmode. MySQL-style names and plain
# booleans are accepted too, so one URL style works for either backend.
_PG_SSL_MODE_ALIASES: dict[str, str] = {
    "disable": "disable",
    "disabled": "disable",
    "false": "disable",
    "0": "disable",
    "no": "disable",
    "off": "disable",
    "prefer": "prefer",
    "preferred": "prefer",
    "allow": "allow",
    "require": "require",
    "required": "require",
    "true": "require",
    "1": "require",
    "yes": "require",
    "on": "require",
    "verify-ca": "verify-ca",
    "verify_ca": "verify-ca",
    "verify-full": "verify-full",
    "verify_full": "verify-full",
    "verify-identity": "verify-full",
    "verify_identity": "verify-full",
}
# User spelling -> canonical MySQL ssl-mode. libpq names are accepted too;
# note "allow" has no MySQL equivalent and maps to "preferred".
_MYSQL_SSL_MODE_ALIASES: dict[str, str] = {
    "disable": "disabled",
    "disabled": "disabled",
    "false": "disabled",
    "0": "disabled",
    "no": "disabled",
    "off": "disabled",
    "prefer": "preferred",
    "preferred": "preferred",
    "allow": "preferred",
    "require": "required",
    "required": "required",
    "true": "required",
    "1": "required",
    "yes": "required",
    "on": "required",
    "verify-ca": "verify_ca",
    "verify_ca": "verify_ca",
    "verify-full": "verify_identity",
    "verify_full": "verify_identity",
    "verify-identity": "verify_identity",
    "verify_identity": "verify_identity",
}
# Process-wide engine cache shared by create_mysql_engine / create_pgsql_engine.
# A threading.Lock guards both dicts below.
_ENGINE_CACHE_LOCK = Lock()
# Cache key (dialect, url, connect-args repr) -> engine.
_ENGINE_CACHE: dict[tuple[str, str, str], AsyncEngine] = {}
# Reverse index id(engine) -> cache keys, so _evict_cached_engine can find
# every entry that points at a given engine.
_ENGINE_KEYS_BY_ID: dict[int, set[tuple[str, str, str]]] = {}
def _normalize_sql_url(dialect: str, url: str) -> str:
    """Rewrite *url* to the async SQLAlchemy driver scheme for *dialect*.

    For ``dialect == "mysql"`` the ``mysql://``, ``mariadb://`` and
    ``mariadb+aiomysql://`` schemes become ``mysql+aiomysql://``. Any other
    dialect is treated as PostgreSQL: ``postgres://``, ``postgresql://`` and
    ``pgsql://`` become ``postgresql+asyncpg://``.

    Empty strings, strings without ``://`` and unrecognised schemes are
    returned unchanged.
    """
    if not url or "://" not in url:
        return url
    if dialect == "mysql":
        target_scheme = "mysql+aiomysql://"
        source_schemes = ("mysql://", "mariadb://", "mariadb+aiomysql://")
    else:
        target_scheme = "postgresql+asyncpg://"
        source_schemes = ("postgres://", "postgresql://", "pgsql://")
    for scheme in source_schemes:
        if url.startswith(scheme):
            return target_scheme + url[len(scheme):]
    return url
def _get_env_int(name: str, default: int, *, minimum: int = 0) -> int:
    """Read integer environment variable *name*, clamped below at *minimum*.

    Returns *default* as-is (not clamped) when the variable is unset, blank
    after stripping, or not parseable by ``int()``.
    """
    text = os.getenv(name, "").strip()
    if text:
        try:
            parsed = int(text)
        except ValueError:
            return default
        return max(minimum, parsed)
    return default
def _normalize_ssl_mode(dialect: str, raw_mode: str) -> str:
    """Map a user-supplied SSL mode spelling to the dialect's canonical name.

    Matching is case-insensitive and ignores surrounding whitespace and inner
    spaces. ``dialect == "mysql"`` uses the MySQL alias table; anything else
    uses the PostgreSQL one.

    Raises:
        ValueError: if *raw_mode* is empty or not a known alias.
    """
    if not raw_mode:
        raise ValueError("SSL mode cannot be empty")
    alias_table = _MYSQL_SSL_MODE_ALIASES if dialect == "mysql" else _PG_SSL_MODE_ALIASES
    lookup_key = raw_mode.strip().lower().replace(" ", "")
    canonical = alias_table.get(lookup_key)
    if canonical:
        return canonical
    raise ValueError(f"Unsupported SSL mode {raw_mode!r} for SQL dialect {dialect!r}")
def _has_ssl_options(options: dict[str, str], keys: tuple[str, ...]) -> bool:
    """Return True if at least one of *keys* maps to a non-empty value in *options*."""
    for key in keys:
        if options.get(key):
            return True
    return False
def _parse_ssl_bool(name: str, raw_value: str | None) -> bool | None:
    """Interpret an SSL URL flag as a tri-state boolean.

    Returns None when the value is missing or blank. Otherwise returns
    True/False for the spellings in ``_SSL_BOOL_TRUE`` / ``_SSL_BOOL_FALSE``
    (case-insensitive, surrounding whitespace ignored).

    Raises:
        ValueError: for any other non-blank value; *name* is used in the message.
    """
    normalized = "" if raw_value is None else raw_value.strip().lower()
    if not normalized:
        return None
    if normalized in _SSL_BOOL_TRUE:
        return True
    if normalized in _SSL_BOOL_FALSE:
        return False
    raise ValueError(f"Unsupported boolean value {raw_value!r} for SQL SSL option {name!r}")
def _extract_sql_ssl_options(
    dialect: str,
    url: str,
) -> tuple[str, dict[str, str]]:
    """Split SSL-related query parameters out of *url*.

    Returns ``(cleaned_url, ssl_options)``. ``cleaned_url`` is *url* with every
    recognised SSL key removed from its query string (other parameters keep
    their order). ``ssl_options`` maps the lower-cased key to its stripped
    value; if a key repeats, the first occurrence wins. PostgreSQL keys are
    used when ``dialect == "postgresql"``, MySQL keys otherwise.
    """
    parsed = urlparse(url)
    known_keys = _PG_SSL_QUERY_PARAM_KEYS if dialect == "postgresql" else _MYSQL_SSL_QUERY_PARAM_KEYS
    recognised = {key.lower() for key in known_keys}
    ssl_options: dict[str, str] = {}
    passthrough: list[tuple[str, str]] = []
    for key, value in parse_qsl(parsed.query, keep_blank_values=True):
        lowered = key.lower()
        if lowered not in recognised:
            passthrough.append((key, value))
            continue
        ssl_options.setdefault(lowered, value.strip())
    cleaned_url = urlunparse(parsed._replace(query=urlencode(passthrough, doseq=True)))
    return cleaned_url, ssl_options
def _resolve_ssl_mode(dialect: str, ssl_options: dict[str, str]) -> str | None:
    """Return the canonical SSL mode selected by *ssl_options*, or None.

    The first non-empty value among ``_SQL_SSL_MODE_PARAM_KEYS`` is normalised
    via ``_normalize_ssl_mode``. Returns None only when *ssl_options* is empty.

    Raises:
        ValueError: if other SSL parameters are present without a mode, or
            the mode spelling is unsupported.
    """
    for key in _SQL_SSL_MODE_PARAM_KEYS:
        candidate = ssl_options.get(key)
        if candidate:
            return _normalize_ssl_mode(dialect, candidate)
    if not ssl_options:
        return None
    if dialect == "postgresql":
        raise ValueError("PostgreSQL SSL URL parameters require sslmode to be set explicitly")
    raise ValueError("MySQL SSL URL parameters require ssl-mode to be set explicitly")
def _validate_pg_ssl_options(mode: str | None, ssl_options: dict[str, str]) -> None:
    """Reject PostgreSQL SSL parameter combinations this backend cannot honour.

    Raises:
        ValueError: if any parameter from ``_PG_SSL_UNSUPPORTED_PARAM_KEYS`` is
            set (listed alphabetically in the message), or if certificate
            parameters are combined with ``disable`` / ``allow`` / ``prefer``.
    """
    unsupported = sorted(key for key in _PG_SSL_UNSUPPORTED_PARAM_KEYS if ssl_options.get(key))
    if unsupported:
        raise ValueError(f"Unsupported PostgreSQL SSL URL parameter(s): {', '.join(unsupported)}")
    if not _has_ssl_options(ssl_options, _PG_SSL_CERT_PARAM_KEYS):
        return
    if mode == "disable":
        raise ValueError("PostgreSQL SSL certificate parameters cannot be used with sslmode=disable")
    if mode in {"allow", "prefer"}:
        raise ValueError("PostgreSQL sslmode=allow/prefer is not supported with certificate URL parameters")
def _build_pg_ssl_context(mode: str, ssl_options: dict[str, str]) -> ssl.SSLContext:
    """Build the ``ssl.SSLContext`` handed to asyncpg for a libpq ``sslmode``.

    Verification follows libpq's meaning of each mode:

    * ``allow`` / ``prefer``: encrypt, but never verify the server
      certificate or hostname.
    * ``require``: encrypt without hostname checks. The certificate chain is
      verified only when ``sslrootcert`` is supplied (libpq then behaves like
      ``verify-ca``).
    * ``verify-ca``: verify the chain against ``sslrootcert`` (or the default
      CA store), but skip hostname matching.
    * anything else (``verify-full``): verify both the chain and the hostname.

    NOTE(review): an SSLContext cannot express libpq's plaintext fallback for
    allow/prefer. Whether the connection falls back is up to the driver.

    Args:
        mode: canonical PostgreSQL SSL mode (a value of ``_PG_SSL_MODE_ALIASES``).
        ssl_options: lower-cased SSL query parameters. ``sslrootcert``,
            ``sslcert`` and ``sslkey`` are file paths.

    Raises:
        ValueError: if ``sslkey`` is given without ``sslcert``.
    """
    sslrootcert = ssl_options.get("sslrootcert") or None
    sslcert = ssl_options.get("sslcert") or None
    sslkey = ssl_options.get("sslkey") or None
    ctx = ssl.create_default_context(cafile=sslrootcert)
    # check_hostname must be cleared before verify_mode can drop to CERT_NONE.
    if mode in ("allow", "prefer"):
        # libpq never verifies the server certificate in these modes. Without
        # this branch they fell through to full verification (verify-full).
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
    elif mode == "require":
        ctx.check_hostname = False
        if not sslrootcert:
            ctx.verify_mode = ssl.CERT_NONE
    elif mode == "verify-ca":
        ctx.check_hostname = False
    else:
        ctx.check_hostname = True
    if sslcert:
        ctx.load_cert_chain(certfile=sslcert, keyfile=sslkey)
    elif sslkey:
        raise ValueError("PostgreSQL sslkey requires sslcert")
    return ctx
def _build_mysql_ssl_context(mode: str, ssl_options: dict[str, str]) -> ssl.SSLContext | None:
    """Translate a canonical MySQL ssl-mode plus URL options into an SSLContext.

    Returns None for ``disabled``. ``required`` encrypts without any
    certificate verification. ``verify_ca`` verifies the chain only.
    ``verify_identity`` also checks the hostname. ``ssl-check-hostname``
    overrides the hostname check, and ``ssl-cipher`` restricts the cipher list.

    Raises:
        ValueError: for certificate options combined with ``disabled``; for
            ``preferred``; for ``ssl-check-hostname=true`` with ``required``;
            for ``ssl-key`` without ``ssl-cert``; or for an unparseable
            ``ssl-check-hostname`` value.
    """
    if mode == "disabled":
        if _has_ssl_options(ssl_options, _MYSQL_SSL_CERT_PARAM_KEYS):
            raise ValueError("MySQL SSL certificate parameters cannot be used with ssl-mode=disabled")
        return None
    if mode == "preferred":
        raise ValueError("MySQL ssl-mode=allow/prefer is not supported by aiomysql")

    def option(key: str) -> str | None:
        # Blank query values are treated the same as absent ones.
        return ssl_options.get(key) or None

    cert_file = option("ssl-cert")
    key_file = option("ssl-key")
    cipher_list = option("ssl-cipher")
    hostname_override = _parse_ssl_bool("ssl-check-hostname", ssl_options.get("ssl-check-hostname"))

    ctx = ssl.create_default_context(cafile=option("ssl-ca"), capath=option("ssl-capath"))
    if mode == "verify_ca":
        ctx.check_hostname = False
    elif mode in ("preferred", "required"):
        # check_hostname has to be switched off before verify_mode may be CERT_NONE.
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
    else:
        ctx.check_hostname = True

    if hostname_override is not None:
        if hostname_override and mode == "required":
            raise ValueError("MySQL ssl-check-hostname=true requires ssl-mode=verify_identity")
        ctx.check_hostname = hostname_override
    if cipher_list:
        ctx.set_ciphers(cipher_list)
    if cert_file:
        ctx.load_cert_chain(certfile=cert_file, keyfile=key_file)
    elif key_file:
        raise ValueError("MySQL ssl-key requires ssl-cert")
    return ctx
def _build_sql_connect_args(
    dialect: str,
    ssl_options: dict[str, str],
) -> dict[str, Any] | None:
    """Turn extracted SSL URL options into driver ``connect_args``.

    Returns ``{"ssl": <SSLContext>}`` when SSL is enabled. Returns None when
    no mode is given or the mode disables SSL. PostgreSQL options are
    validated first.

    Raises:
        ValueError: for invalid or unsupported SSL option combinations.
    """
    mode = _resolve_ssl_mode(dialect, ssl_options)
    if not mode:
        return None
    if dialect == "mysql":
        context = _build_mysql_ssl_context(mode, ssl_options)
    else:
        _validate_pg_ssl_options(mode, ssl_options)
        # Always give asyncpg a concrete SSLContext so the verification
        # settings are exactly the ones built here.
        context = None if mode == "disable" else _build_pg_ssl_context(mode, ssl_options)
    return None if context is None else {"ssl": context}
def _prepare_sql_url_and_connect_args(
    dialect: str,
    url: str,
) -> tuple[str, dict[str, Any] | None]:
    """Normalise *url* to the async driver and move its SSL query params into connect_args.

    Returns ``(url, connect_args)``. For inputs without ``://`` the normalised
    string is returned with ``None`` connect_args.
    """
    async_url = _normalize_sql_url(dialect, url)
    if "://" in async_url:
        cleaned_url, ssl_options = _extract_sql_ssl_options(dialect, async_url)
        return cleaned_url, _build_sql_connect_args(dialect, ssl_options)
    return async_url, None
def _is_serverless() -> bool:
    """Return True when a known serverless marker variable is set and non-empty.

    Checks Vercel (``VERCEL``), AWS Lambda (``AWS_LAMBDA_FUNCTION_NAME``) and
    Azure Functions (``FUNCTIONS_WORKER_RUNTIME``).
    """
    marker_vars = (
        "VERCEL",
        "AWS_LAMBDA_FUNCTION_NAME",
        "FUNCTIONS_WORKER_RUNTIME",  # Azure Functions
    )
    return any(os.getenv(var) for var in marker_vars)
def _sql_engine_kwargs(connect_args: dict[str, Any] | None) -> dict[str, Any]:
    """Assemble keyword arguments for ``create_async_engine``.

    Pool settings can be overridden with ACCOUNT_SQL_POOL_SIZE,
    ACCOUNT_SQL_MAX_OVERFLOW, ACCOUNT_SQL_POOL_TIMEOUT and
    ACCOUNT_SQL_POOL_RECYCLE. The size/overflow defaults shrink under
    serverless runtimes, where many short-lived instances may each open their
    own pool. ``connect_args`` is included only when non-empty.
    """
    serverless = _is_serverless()
    default_size, default_overflow = (1, 2) if serverless else (5, 10)
    options: dict[str, Any] = {
        "pool_size": _get_env_int("ACCOUNT_SQL_POOL_SIZE", default_size, minimum=1),
        "max_overflow": _get_env_int("ACCOUNT_SQL_MAX_OVERFLOW", default_overflow, minimum=0),
        "pool_timeout": _get_env_int("ACCOUNT_SQL_POOL_TIMEOUT", 30, minimum=1),
        "pool_recycle": _get_env_int("ACCOUNT_SQL_POOL_RECYCLE", 1800, minimum=0),
        "pool_pre_ping": True,
        "pool_use_lifo": True,
    }
    return {**options, "connect_args": connect_args} if connect_args else options
def _get_or_create_engine(
    cache_key: tuple[str, str, str],
    normalized_url: str,
    connect_args: dict[str, Any] | None,
) -> AsyncEngine:
    """Return the cached engine for *cache_key*, creating and registering it on a miss.

    Lookup and creation both happen while holding ``_ENGINE_CACHE_LOCK``.
    A new engine is also recorded in ``_ENGINE_KEYS_BY_ID`` for later eviction.
    """
    with _ENGINE_CACHE_LOCK:
        cached = _ENGINE_CACHE.get(cache_key)
        if cached is None:
            cached = create_async_engine(normalized_url, **_sql_engine_kwargs(connect_args))
            _ENGINE_CACHE[cache_key] = cached
            _ENGINE_KEYS_BY_ID.setdefault(id(cached), set()).add(cache_key)
        return cached
def _evict_cached_engine(engine: AsyncEngine) -> None:
    """Remove every cache entry that still points at *engine*.

    The reverse-index entry for *engine* is always dropped. A cache key is
    removed only if it still maps to this exact engine object.
    """
    with _ENGINE_CACHE_LOCK:
        owned_keys = _ENGINE_KEYS_BY_ID.pop(id(engine), set())
        for key in owned_keys:
            if _ENGINE_CACHE.get(key) is engine:
                del _ENGINE_CACHE[key]
def _row_to_record(row: Any) -> AccountRecord:
    """Convert a raw ``accounts`` row into an ``AccountRecord``.

    JSON text columns are decoded, with empty/NULL values treated as
    ``[]`` / ``{}``. The ``quota_*`` columns are folded into a single
    ``quota`` dict. ``heavy`` and ``grok_4_3`` are included only when their
    decoded dict is non-empty.
    """
    data = dict(row._mapping)
    data["tags"] = json.loads(data.get("tags") or "[]")
    heavy = json.loads(data.pop("quota_heavy", None) or "{}")
    grok_4_3 = json.loads(data.pop("quota_grok_4_3", None) or "{}")
    quota: dict[str, Any] = {}
    for mode in ("auto", "fast", "expert"):
        quota[mode] = json.loads(data.pop(f"quota_{mode}", None) or "{}")
    if heavy:
        quota["heavy"] = heavy
    if grok_4_3:
        quota["grok_4_3"] = grok_4_3
    data["quota"] = quota
    data["ext"] = json.loads(data.get("ext") or "{}")
    return AccountRecord.model_validate(data)
class SqlAccountRepository:
    """Async SQLAlchemy-based repository for MySQL / PostgreSQL.

    Every mutating call runs in one transaction. That transaction first bumps
    the global ``revision`` counter kept in ``account_meta``, then stamps the
    new value on every row it writes. Deletions are soft (``deleted_at`` is
    set), so ``scan_changes`` reports them alongside inserts and updates.
    """

    def __init__(
        self,
        engine: AsyncEngine,
        *,
        dialect: str = "mysql",
        dispose_engine: bool = True,
    ) -> None:
        """Wrap *engine*.

        Args:
            engine: async SQLAlchemy engine to run all statements on.
            dialect: ``"postgresql"`` or ``"mysql"``. Any value other than
                ``"postgresql"`` gets the MySQL code paths.
            dispose_engine: when True, ``close()`` evicts the engine from the
                module cache and disposes its pool.
        """
        self._engine = engine
        self._dialect = dialect  # "mysql" | "postgresql"
        self._session = async_sessionmaker(engine, expire_on_commit=False)
        self._dispose_engine = dispose_engine
        self._initialized = False
        self._init_lock = asyncio.Lock()

    # ------------------------------------------------------------------
    # Revision helpers (run inside a transaction)
    # ------------------------------------------------------------------

    async def _bump_revision(self, conn: Any) -> int:
        """Increment the stored revision counter on *conn* and return the new value.

        The counter is stored as text, so the UPDATE casts it to BIGINT, adds
        one and casts it back.
        """
        await conn.execute(
            meta_table.update()
            .where(meta_table.c.key == "revision")
            .values(value=sa.cast(
                sa.cast(meta_table.c.value, sa.BigInteger) + 1, sa.Text
            ))
        )
        row = await conn.execute(
            sa.select(meta_table.c.value).where(meta_table.c.key == "revision")
        )
        return int(row.scalar())

    async def _get_revision(self, conn: Any) -> int:
        """Return the current revision counter, or 0 if the row is missing/empty."""
        row = await conn.execute(
            sa.select(meta_table.c.value).where(meta_table.c.key == "revision")
        )
        v = row.scalar()
        return int(v) if v else 0

    # ------------------------------------------------------------------
    # Upsert — dialect-specific
    # ------------------------------------------------------------------

    def _build_upsert(self, row: dict[str, Any]):
        """Build an INSERT that overwrites every column of *row* on token conflict.

        ``token`` and ``created_at`` are kept from the existing row.
        """
        if self._dialect == "postgresql":
            from sqlalchemy.dialects.postgresql import insert
            stmt = insert(accounts_table).values(**row)
            update_cols = {k: stmt.excluded[k] for k in row if k not in ("token", "created_at")}
            return stmt.on_conflict_do_update(index_elements=["token"], set_=update_cols)
        # MySQL
        from sqlalchemy.dialects.mysql import insert
        stmt = insert(accounts_table).values(**row)
        update_cols = {k: stmt.inserted[k] for k in row if k not in ("token", "created_at")}
        return stmt.on_duplicate_key_update(**update_cols)

    async def _upsert_items(
        self,
        conn: Any,
        items: list[AccountUpsert],
        *,
        rev: int,
        ts: int,
    ) -> int:
        """Upsert *items* on *conn* with revision *rev* and timestamp *ts*.

        Returns the number of rows written. Items whose token/pool fail
        ``AccountRecord`` validation are skipped (best-effort import). Pools
        other than basic/super/heavy are stored as "basic". Quotas are reset
        to ``default_quota_set(pool)``, usage counters are zeroed, and any
        soft-delete is cleared.
        """
        count = 0
        for item in items:
            try:
                token = AccountRecord.model_validate({"token": item.token, "pool": item.pool}).token
            except Exception:
                continue
            pool = item.pool if item.pool in ("basic", "super", "heavy") else "basic"
            qs = default_quota_set(pool)
            row = {
                "token": token,
                "pool": pool,
                "status": "active",
                "created_at": ts,
                "updated_at": ts,
                "deleted_at": None,  # clear soft-delete on re-import
                "tags": json.dumps(item.tags),
                "quota_auto": json.dumps(qs.auto.to_dict()),
                "quota_fast": json.dumps(qs.fast.to_dict()),
                "quota_expert": json.dumps(qs.expert.to_dict()),
                "quota_heavy": json.dumps(qs.heavy.to_dict()) if qs.heavy else "{}",
                "quota_grok_4_3": json.dumps(qs.grok_4_3.to_dict()) if qs.grok_4_3 else "{}",
                "usage_use_count": 0,
                "usage_fail_count": 0,
                "usage_sync_count": 0,
                "ext": json.dumps(item.ext),
                "revision": rev,
            }
            await conn.execute(self._build_upsert(row))
            count += 1
        return count

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def _ensure_initialized(self) -> None:
        """Create tables and seed the revision row once per repository instance.

        After the first success this costs only a flag check, so it is safe to
        call on every request. That lets the repository self-initialise even
        when the ASGI lifespan is not executed (e.g. Vercel serverless
        cold-starts).
        """
        if self._initialized:
            return
        async with self._init_lock:
            if self._initialized:
                return
            await self._do_initialize()
            self._initialized = True

    async def _do_initialize(self) -> None:
        """Create missing tables, run column migrations and seed the revision row.

        An existing revision row is left untouched.
        """
        async with self._engine.begin() as conn:
            await conn.run_sync(metadata.create_all)
            await self._ensure_columns(conn)
            if self._dialect == "postgresql":
                from sqlalchemy.dialects.postgresql import insert as pg_insert
                await conn.execute(
                    pg_insert(meta_table)
                    .values(key="revision", value="0")
                    .on_conflict_do_nothing()
                )
            else:
                from sqlalchemy.dialects.mysql import insert as my_insert
                # Self-assignment makes the duplicate-key branch a no-op. The
                # counter must survive re-initialisation: resetting it to "0"
                # would hand out revisions lower than those already stored on
                # rows, and scan_changes(since_revision) would miss them.
                await conn.execute(
                    my_insert(meta_table)
                    .values(key="revision", value="0")
                    .on_duplicate_key_update(value=meta_table.c.value)
                )

    async def _ensure_columns(self, conn: Any) -> None:
        """Idempotent ALTER TABLE migrations for columns added after the initial schema."""
        existing = await self._table_columns(conn, _TBL_ACCOUNTS)
        if "quota_grok_4_3" not in existing:
            if self._dialect == "mysql":
                # MySQL forbids DEFAULT values on TEXT/BLOB columns:
                # add as nullable, backfill, then promote to NOT NULL.
                await conn.exec_driver_sql(
                    f"ALTER TABLE {_TBL_ACCOUNTS} "
                    f"ADD COLUMN quota_grok_4_3 TEXT"
                )
                await conn.exec_driver_sql(
                    f"UPDATE {_TBL_ACCOUNTS} "
                    f"SET quota_grok_4_3 = '{{}}' "
                    f"WHERE quota_grok_4_3 IS NULL"
                )
                await conn.exec_driver_sql(
                    f"ALTER TABLE {_TBL_ACCOUNTS} "
                    f"MODIFY COLUMN quota_grok_4_3 TEXT NOT NULL"
                )
            else:
                await conn.exec_driver_sql(
                    f"ALTER TABLE {_TBL_ACCOUNTS} "
                    f"ADD COLUMN quota_grok_4_3 TEXT NOT NULL DEFAULT '{{}}'"
                )

    async def _table_columns(self, conn: Any, table: str) -> set[str]:
        """Return the lower-cased column names of *table* in the current schema/database."""
        if self._dialect == "postgresql":
            # Restrict to current_schema(): a same-named table in another schema
            # must not make a missing column look present.
            rows = await conn.execute(
                sa.text(
                    "SELECT column_name FROM information_schema.columns "
                    "WHERE table_schema = current_schema() AND table_name = :t"
                ),
                {"t": table},
            )
        else:
            rows = await conn.execute(
                sa.text(
                    "SELECT COLUMN_NAME FROM information_schema.columns "
                    "WHERE table_schema = DATABASE() AND table_name = :t"
                ),
                {"t": table},
            )
        return {str(r[0]).lower() for r in rows.fetchall()}

    async def initialize(self) -> None:
        """Create tables / seed metadata if that has not happened yet."""
        await self._ensure_initialized()

    async def get_revision(self) -> int:
        """Return the current global revision."""
        await self._ensure_initialized()
        async with self._engine.connect() as conn:
            return await self._get_revision(conn)

    async def runtime_snapshot(self) -> RuntimeSnapshot:
        """Return all non-deleted accounts together with the current revision."""
        await self._ensure_initialized()
        async with self._engine.connect() as conn:
            rev = await self._get_revision(conn)
            rows = (await conn.execute(
                sa.select(accounts_table).where(accounts_table.c.deleted_at.is_(None))
            )).fetchall()
        return RuntimeSnapshot(revision=rev, items=[_row_to_record(r) for r in rows])

    async def scan_changes(
        self,
        since_revision: int,
        *,
        limit: int = 5000,
    ) -> AccountChangeSet:
        """Return up to *limit* rows whose revision exceeds *since_revision*.

        Rows are ordered by revision. Soft-deleted rows are reported only by
        token in ``deleted_tokens``. ``has_more`` is True when the page came
        back full.
        """
        await self._ensure_initialized()
        async with self._engine.connect() as conn:
            rev = await self._get_revision(conn)
            rows = (await conn.execute(
                sa.select(accounts_table)
                .where(accounts_table.c.revision > since_revision)
                .order_by(accounts_table.c.revision)
                .limit(limit)
            )).fetchall()
        items: list[AccountRecord] = []
        deleted: list[str] = []
        for row in rows:
            r = _row_to_record(row)
            if r.is_deleted():
                deleted.append(r.token)
            else:
                items.append(r)
        return AccountChangeSet(
            revision=rev,
            items=items,
            deleted_tokens=deleted,
            has_more=len(rows) == limit,
        )

    async def upsert_accounts(
        self,
        items: list[AccountUpsert],
    ) -> AccountMutationResult:
        """Insert or overwrite *items* in one transaction under one new revision.

        See ``_upsert_items`` for what is written and which items are skipped.
        """
        if not items:
            return AccountMutationResult()
        await self._ensure_initialized()
        async with self._engine.begin() as conn:
            rev = await self._bump_revision(conn)
            count = await self._upsert_items(conn, items, rev=rev, ts=now_ms())
        return AccountMutationResult(upserted=count, revision=rev)

    async def patch_accounts(
        self,
        patches: list[AccountPatch],
    ) -> AccountMutationResult:
        """Apply partial updates to existing accounts; unknown tokens are skipped.

        Only patch fields that are not None are written. Usage deltas are added
        to the stored counters, floored at 0. Tag edits apply in order: replace
        (``tags``), then ``add_tags``, then ``remove_tags``. ``ext_merge`` is
        shallow-merged into ``ext``. ``clear_failures`` drops the
        cooldown/disabled/expired markers from ``ext``, reactivates the account
        and resets failure bookkeeping. All patched rows share one new revision.
        """
        if not patches:
            return AccountMutationResult()
        await self._ensure_initialized()
        async with self._engine.begin() as conn:
            rev = await self._bump_revision(conn)
            ts = now_ms()
            count = 0
            for patch in patches:
                row = (await conn.execute(
                    sa.select(accounts_table).where(accounts_table.c.token == patch.token)
                )).fetchone()
                if row is None:
                    continue
                record = _row_to_record(row)
                updates: dict[str, Any] = {"updated_at": ts, "revision": rev}
                if patch.pool is not None:
                    updates["pool"] = patch.pool
                if patch.status is not None:
                    updates["status"] = patch.status.value
                if patch.state_reason is not None:
                    updates["state_reason"] = patch.state_reason
                if patch.last_use_at is not None:
                    updates["last_use_at"] = patch.last_use_at
                if patch.last_fail_at is not None:
                    updates["last_fail_at"] = patch.last_fail_at
                if patch.last_fail_reason is not None:
                    updates["last_fail_reason"] = patch.last_fail_reason
                if patch.last_sync_at is not None:
                    updates["last_sync_at"] = patch.last_sync_at
                if patch.last_clear_at is not None:
                    updates["last_clear_at"] = patch.last_clear_at
                if patch.quota_auto is not None:
                    updates["quota_auto"] = json.dumps(patch.quota_auto)
                if patch.quota_fast is not None:
                    updates["quota_fast"] = json.dumps(patch.quota_fast)
                if patch.quota_expert is not None:
                    updates["quota_expert"] = json.dumps(patch.quota_expert)
                if patch.quota_heavy is not None:
                    updates["quota_heavy"] = json.dumps(patch.quota_heavy)
                if patch.quota_grok_4_3 is not None:
                    updates["quota_grok_4_3"] = json.dumps(patch.quota_grok_4_3)
                if patch.usage_use_delta is not None:
                    updates["usage_use_count"] = max(0, record.usage_use_count + patch.usage_use_delta)
                if patch.usage_fail_delta is not None:
                    updates["usage_fail_count"] = max(0, record.usage_fail_count + patch.usage_fail_delta)
                if patch.usage_sync_delta is not None:
                    updates["usage_sync_count"] = max(0, record.usage_sync_count + patch.usage_sync_delta)
                # Always work on a copy so add_tags never mutates the caller's
                # patch.tags list (or the record's list).
                tags = list(patch.tags) if patch.tags is not None else list(record.tags)
                if patch.add_tags:
                    for t in patch.add_tags:
                        if t not in tags:
                            tags.append(t)
                if patch.remove_tags:
                    tags = [t for t in tags if t not in patch.remove_tags]
                updates["tags"] = json.dumps(tags)
                ext = dict(record.ext)
                if patch.ext_merge:
                    ext.update(patch.ext_merge)
                if patch.clear_failures:
                    for k in ("cooldown_until", "cooldown_reason", "disabled_at",
                              "disabled_reason", "expired_at", "expired_reason",
                              "forbidden_strikes"):
                        ext.pop(k, None)
                    updates["status"] = AccountStatus.ACTIVE.value
                    updates["usage_fail_count"] = 0
                    updates["last_fail_at"] = None
                    updates["last_fail_reason"] = None
                    updates["state_reason"] = None
                updates["ext"] = json.dumps(ext)
                await conn.execute(
                    accounts_table.update()
                    .where(accounts_table.c.token == patch.token)
                    .values(**updates)
                )
                count += 1
        return AccountMutationResult(patched=count, revision=rev)

    async def delete_accounts(
        self,
        tokens: list[str],
    ) -> AccountMutationResult:
        """Soft-delete the live accounts among *tokens*.

        Returns the number of rows newly marked deleted. A revision is consumed
        even when nothing matches.
        """
        if not tokens:
            return AccountMutationResult()
        await self._ensure_initialized()
        async with self._engine.begin() as conn:
            rev = await self._bump_revision(conn)
            ts = now_ms()
            result = await conn.execute(
                accounts_table.update()
                .where(
                    accounts_table.c.token.in_(tokens),
                    accounts_table.c.deleted_at.is_(None),
                )
                .values(deleted_at=ts, updated_at=ts, revision=rev)
            )
        return AccountMutationResult(deleted=result.rowcount, revision=rev)

    async def get_accounts(
        self,
        tokens: list[str],
    ) -> list[AccountRecord]:
        """Return the rows for *tokens*, including soft-deleted ones; missing tokens are omitted."""
        if not tokens:
            return []
        await self._ensure_initialized()
        async with self._engine.connect() as conn:
            rows = (await conn.execute(
                sa.select(accounts_table).where(accounts_table.c.token.in_(tokens))
            )).fetchall()
        return [_row_to_record(r) for r in rows]

    async def list_accounts(
        self,
        query: ListAccountsQuery,
    ) -> AccountPage:
        """Return one page of accounts filtered by pool/status.

        Soft-deleted rows are excluded unless ``include_deleted`` is set.
        ``sort_by`` must name a column; unknown names fall back to
        ``updated_at``.
        """
        await self._ensure_initialized()
        async with self._engine.connect() as conn:
            stmt = sa.select(accounts_table)
            if not query.include_deleted:
                stmt = stmt.where(accounts_table.c.deleted_at.is_(None))
            if query.pool:
                stmt = stmt.where(accounts_table.c.pool == query.pool)
            if query.status:
                stmt = stmt.where(accounts_table.c.status == query.status.value)
            total_row = (await conn.execute(
                sa.select(sa.func.count()).select_from(stmt.subquery())
            )).scalar()
            total = int(total_row or 0)
            # ColumnCollection.get only resolves column keys. getattr() would
            # also return collection methods (e.g. "keys", "get") for such
            # sort_by values and then fail on .desc()/.asc().
            sort_col = accounts_table.c.get(query.sort_by, accounts_table.c.updated_at)
            if query.sort_desc:
                stmt = stmt.order_by(sort_col.desc())
            else:
                stmt = stmt.order_by(sort_col.asc())
            offset = (query.page - 1) * query.page_size
            stmt = stmt.limit(query.page_size).offset(offset)
            rows = (await conn.execute(stmt)).fetchall()
            rev = await self._get_revision(conn)
        return AccountPage(
            items=[_row_to_record(r) for r in rows],
            total=total,
            page=query.page,
            page_size=query.page_size,
            total_pages=max(1, (total + query.page_size - 1) // query.page_size),
            revision=rev,
        )

    async def replace_pool(
        self,
        command: BulkReplacePoolCommand,
    ) -> AccountMutationResult:
        """Replace the contents of ``command.pool`` with ``command.upserts``.

        Every live account in the pool is soft-deleted, then the upserts are
        applied. Both steps run in a single transaction under one revision, so
        readers never observe the emptied pool and a failed upsert does not
        leave the pool wiped. The returned revision is the one committed, even
        when there are no upserts.
        """
        await self._ensure_initialized()
        async with self._engine.begin() as conn:
            rev = await self._bump_revision(conn)
            ts = now_ms()
            del_result = await conn.execute(
                accounts_table.update()
                .where(
                    accounts_table.c.pool == command.pool,
                    accounts_table.c.deleted_at.is_(None),
                )
                .values(deleted_at=ts, updated_at=ts, revision=rev)
            )
            deleted = del_result.rowcount
            upserted = await self._upsert_items(conn, command.upserts or [], rev=rev, ts=ts)
        return AccountMutationResult(
            upserted=upserted,
            deleted=deleted,
            revision=rev,
        )

    async def close(self) -> None:
        """Dispose the SQLAlchemy connection pool (only when ``dispose_engine`` was True)."""
        if self._dispose_engine:
            _evict_cached_engine(self._engine)
            await self._engine.dispose()
def _engine_cache_key(dialect: str, normalized_url: str, connect_args: dict[str, Any] | None) -> tuple[str, str, str]:
"""Build a stable cache key from the normalized URL and connect args."""
args_key = str(sorted(connect_args.items(), key=lambda kv: kv[0])) if connect_args else ""
return (dialect, normalized_url, args_key)
def create_mysql_engine(url: str) -> AsyncEngine:
    """Return a cached async SQLAlchemy engine for a MySQL/MariaDB *url*.

    SSL query parameters are stripped from the URL and translated into
    ``connect_args``. Engines are cached per (stripped) input URL. The
    translated connect args can hold a freshly built ``ssl.SSLContext``
    whose repr is identity-based, so keying on them would miss the cache and
    open a new pool on every call.

    Raises:
        ValueError: for invalid or unsupported SSL URL parameters.
    """
    raw_url = (url or "").strip()
    normalized_url, connect_args = _prepare_sql_url_and_connect_args("mysql", raw_url)
    # All SSL settings come from raw_url's query string, so the URL itself is
    # a stable key for the resulting engine configuration.
    cache_key = _engine_cache_key("mysql", raw_url, None)
    return _get_or_create_engine(cache_key, normalized_url, connect_args)
def create_pgsql_engine(url: str) -> AsyncEngine:
    """Return a cached async SQLAlchemy engine for a PostgreSQL *url*.

    SSL query parameters are stripped from the URL and translated into
    ``connect_args``. Engines are cached per (stripped) input URL. The
    translated connect args hold a freshly built ``ssl.SSLContext`` whenever
    SSL is enabled, and its identity-based repr would otherwise give every
    call a distinct cache key and a new connection pool.

    Raises:
        ValueError: for invalid or unsupported SSL URL parameters.
    """
    raw_url = (url or "").strip()
    normalized_url, connect_args = _prepare_sql_url_and_connect_args("postgresql", raw_url)
    # All SSL settings come from raw_url's query string, so the URL itself is
    # a stable key for the resulting engine configuration.
    cache_key = _engine_cache_key("postgresql", raw_url, None)
    return _get_or_create_engine(cache_key, normalized_url, connect_args)
# Public names exported by ``from <this module> import *``.
__all__ = ["SqlAccountRepository", "create_mysql_engine", "create_pgsql_engine"]