| """Shared SQLAlchemy-based backend for MySQL and PostgreSQL. |
| |
| Both dialects share the same table schema and query logic; |
| only the DDL fragments and upsert syntax differ. |
| """ |
|
|
| import json |
| import os |
| import ssl |
| from threading import Lock |
| from typing import Any |
| from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse |
|
|
| import asyncio |
|
|
| import sqlalchemy as sa |
| from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine |
|
|
| from app.platform.runtime.clock import now_ms |
| from ..commands import AccountPatch, AccountUpsert, BulkReplacePoolCommand, ListAccountsQuery |
| from ..enums import AccountStatus |
| from ..models import ( |
| AccountChangeSet, |
| AccountMutationResult, |
| AccountPage, |
| AccountRecord, |
| RuntimeSnapshot, |
| ) |
| from ..quota_defaults import default_quota_set |
|
|
# Physical table names; also interpolated into the raw-SQL column migrations.
_TBL_ACCOUNTS = "accounts"
_TBL_META = "account_meta"


# Single registry for both tables so ``metadata.create_all`` emits them together.
metadata = sa.MetaData()


# One row per account token.  ``tags``, ``ext`` and the ``quota_*`` columns
# hold JSON text (decoded by ``_row_to_record``); ``created_at`` /
# ``updated_at`` are filled from ``now_ms()`` by the repository.
accounts_table = sa.Table(
    _TBL_ACCOUNTS,
    metadata,
    # Identity and lifecycle.
    sa.Column("token", sa.String(512), primary_key=True),
    sa.Column("pool", sa.Text, nullable=False, default="basic"),
    sa.Column("status", sa.Text, nullable=False, default="active"),
    sa.Column("created_at", sa.BigInteger, nullable=False),
    sa.Column("updated_at", sa.BigInteger, nullable=False),
    sa.Column("tags", sa.Text, nullable=False, default="[]"),
    # Per-mode quota documents (JSON objects, "{}" when unset).
    sa.Column("quota_auto", sa.Text, nullable=False, default="{}"),
    sa.Column("quota_fast", sa.Text, nullable=False, default="{}"),
    sa.Column("quota_expert", sa.Text, nullable=False, default="{}"),
    sa.Column("quota_heavy", sa.Text, nullable=False, default="{}"),
    sa.Column("quota_grok_4_3", sa.Text, nullable=False, default="{}"),
    # Usage counters.
    sa.Column("usage_use_count", sa.Integer, nullable=False, default=0),
    sa.Column("usage_fail_count", sa.Integer, nullable=False, default=0),
    sa.Column("usage_sync_count", sa.Integer, nullable=False, default=0),
    # Nullable bookkeeping fields.
    sa.Column("last_use_at", sa.BigInteger),
    sa.Column("last_fail_at", sa.BigInteger),
    sa.Column("last_fail_reason", sa.Text),
    sa.Column("last_sync_at", sa.BigInteger),
    sa.Column("last_clear_at", sa.BigInteger),
    sa.Column("state_reason", sa.Text),
    # Soft-delete marker: NULL means the row is live.
    sa.Column("deleted_at", sa.BigInteger),
    sa.Column("ext", sa.Text, nullable=False, default="{}"),
    # Value of the global revision counter at the row's last mutation.
    sa.Column("revision", sa.BigInteger, nullable=False, default=0),
)


# Key/value store; the repository keeps the global counter under key "revision".
meta_table = sa.Table(
    _TBL_META,
    metadata,
    sa.Column("key", sa.String(128), primary_key=True),
    sa.Column("value", sa.Text, nullable=False),
)
|
|
| _SQL_SSL_MODE_PARAM_KEYS = ("sslmode", "ssl-mode", "ssl") |
| _PG_SSL_CERT_PARAM_KEYS = ("sslrootcert", "sslcert", "sslkey") |
| _PG_SSL_UNSUPPORTED_PARAM_KEYS = ( |
| "sslcrl", |
| "sslpassword", |
| "sslnegotiation", |
| "ssl_min_protocol_version", |
| "ssl_max_protocol_version", |
| ) |
| _PG_SSL_QUERY_PARAM_KEYS = ( |
| *_SQL_SSL_MODE_PARAM_KEYS, |
| *_PG_SSL_CERT_PARAM_KEYS, |
| *_PG_SSL_UNSUPPORTED_PARAM_KEYS, |
| ) |
| _MYSQL_SSL_CERT_PARAM_KEYS = ("ssl-ca", "ssl-capath", "ssl-cert", "ssl-key") |
| _MYSQL_SSL_OPTIONAL_PARAM_KEYS = ("ssl-check-hostname", "ssl-cipher") |
| _MYSQL_SSL_QUERY_PARAM_KEYS = ( |
| *_SQL_SSL_MODE_PARAM_KEYS, |
| *_MYSQL_SSL_CERT_PARAM_KEYS, |
| *_MYSQL_SSL_OPTIONAL_PARAM_KEYS, |
| ) |
| _SSL_BOOL_TRUE = {"1", "true", "yes", "on"} |
| _SSL_BOOL_FALSE = {"0", "false", "no", "off"} |
|
|
| _PG_SSL_MODE_ALIASES: dict[str, str] = { |
| "disable": "disable", |
| "disabled": "disable", |
| "false": "disable", |
| "0": "disable", |
| "no": "disable", |
| "off": "disable", |
| "prefer": "prefer", |
| "preferred": "prefer", |
| "allow": "allow", |
| "require": "require", |
| "required": "require", |
| "true": "require", |
| "1": "require", |
| "yes": "require", |
| "on": "require", |
| "verify-ca": "verify-ca", |
| "verify_ca": "verify-ca", |
| "verify-full": "verify-full", |
| "verify_full": "verify-full", |
| "verify-identity": "verify-full", |
| "verify_identity": "verify-full", |
| } |
|
|
| _MYSQL_SSL_MODE_ALIASES: dict[str, str] = { |
| "disable": "disabled", |
| "disabled": "disabled", |
| "false": "disabled", |
| "0": "disabled", |
| "no": "disabled", |
| "off": "disabled", |
| "prefer": "preferred", |
| "preferred": "preferred", |
| "allow": "preferred", |
| "require": "required", |
| "required": "required", |
| "true": "required", |
| "1": "required", |
| "yes": "required", |
| "on": "required", |
| "verify-ca": "verify_ca", |
| "verify_ca": "verify_ca", |
| "verify-full": "verify_identity", |
| "verify_full": "verify_identity", |
| "verify-identity": "verify_identity", |
| "verify_identity": "verify_identity", |
| } |
|
|
# Process-wide engine cache.  The lock guards both dictionaries below.
_ENGINE_CACHE_LOCK = Lock()
# cache key (dialect, url, connect-args fingerprint) -> shared engine.
_ENGINE_CACHE: dict[tuple[str, str, str], AsyncEngine] = {}
# id(engine) -> every cache key registered for it, so an engine can be evicted
# without scanning the whole cache.
_ENGINE_KEYS_BY_ID: dict[int, set[tuple[str, str, str]]] = {}
|
|
|
|
| def _normalize_sql_url(dialect: str, url: str) -> str: |
| """Rewrite SQL URLs to the async SQLAlchemy dialect form.""" |
| if not url or "://" not in url: |
| return url |
| if dialect == "mysql": |
| if url.startswith("mysql://"): |
| return f"mysql+aiomysql://{url[len('mysql://') :]}" |
| if url.startswith("mariadb://"): |
| return f"mysql+aiomysql://{url[len('mariadb://') :]}" |
| if url.startswith("mariadb+aiomysql://"): |
| return f"mysql+aiomysql://{url[len('mariadb+aiomysql://') :]}" |
| return url |
| if url.startswith("postgres://"): |
| return f"postgresql+asyncpg://{url[len('postgres://') :]}" |
| if url.startswith("postgresql://"): |
| return f"postgresql+asyncpg://{url[len('postgresql://') :]}" |
| if url.startswith("pgsql://"): |
| return f"postgresql+asyncpg://{url[len('pgsql://') :]}" |
| return url |
|
|
|
|
| def _get_env_int(name: str, default: int, *, minimum: int = 0) -> int: |
| raw = os.getenv(name, "").strip() |
| if not raw: |
| return default |
| try: |
| return max(minimum, int(raw)) |
| except ValueError: |
| return default |
|
|
|
|
def _normalize_ssl_mode(dialect: str, raw_mode: str) -> str:
    """Map a user-supplied SSL mode onto the dialect's canonical spelling.

    Args:
        dialect: ``"mysql"`` uses the MySQL alias table; anything else uses
            the PostgreSQL alias table.
        raw_mode: Mode string taken from the URL.

    Returns:
        The canonical mode name for the dialect.

    Raises:
        ValueError: If ``raw_mode`` is empty or not a known alias.
    """
    if not raw_mode:
        raise ValueError("SSL mode cannot be empty")

    alias_table = _MYSQL_SSL_MODE_ALIASES if dialect == "mysql" else _PG_SSL_MODE_ALIASES
    # Lookup is case-insensitive and ignores surrounding/embedded spaces.
    lookup_key = raw_mode.strip().lower().replace(" ", "")
    resolved = alias_table.get(lookup_key)
    if resolved:
        return resolved
    raise ValueError(f"Unsupported SSL mode {raw_mode!r} for SQL dialect {dialect!r}")
|
|
|
|
| def _has_ssl_options(options: dict[str, str], keys: tuple[str, ...]) -> bool: |
| return any(options.get(key) for key in keys) |
|
|
|
|
def _parse_ssl_bool(name: str, raw_value: str | None) -> bool | None:
    """Parse an optional boolean SSL option.

    Args:
        name: Option name, used only in the error message.
        raw_value: Raw text from the URL, or ``None`` when absent.

    Returns:
        ``True`` / ``False`` for a recognised spelling, ``None`` when the
        option is absent or blank.

    Raises:
        ValueError: If the value is non-blank and not a recognised boolean.
    """
    token = "" if raw_value is None else raw_value.strip().lower()
    if not token:
        return None
    if token in _SSL_BOOL_TRUE:
        return True
    if token in _SSL_BOOL_FALSE:
        return False
    raise ValueError(f"Unsupported boolean value {raw_value!r} for SQL SSL option {name!r}")
|
|
|
|
def _extract_sql_ssl_options(
    dialect: str,
    url: str,
) -> tuple[str, dict[str, str]]:
    """Split SSL-related query parameters out of a connection URL.

    Args:
        dialect: ``"postgresql"`` selects the PostgreSQL parameter list; any
            other value selects the MySQL list.
        url: Connection URL to inspect.

    Returns:
        ``(cleaned_url, ssl_options)`` where ``cleaned_url`` has the SSL
        parameters removed and ``ssl_options`` maps each lower-cased SSL
        parameter name to its stripped value.  When a parameter is repeated
        the first occurrence wins.
    """
    parts = urlparse(url)
    if dialect == "postgresql":
        recognised = _PG_SSL_QUERY_PARAM_KEYS
    else:
        recognised = _MYSQL_SSL_QUERY_PARAM_KEYS
    ssl_names = {name.lower() for name in recognised}

    ssl_options: dict[str, str] = {}
    passthrough: list[tuple[str, str]] = []
    for key, value in parse_qsl(parts.query, keep_blank_values=True):
        folded = key.lower()
        if folded not in ssl_names:
            # Not an SSL parameter: keep it on the URL untouched.
            passthrough.append((key, value))
            continue
        ssl_options.setdefault(folded, value.strip())

    remaining_query = urlencode(passthrough, doseq=True)
    return urlunparse(parts._replace(query=remaining_query)), ssl_options
|
|
|
|
def _resolve_ssl_mode(dialect: str, ssl_options: dict[str, str]) -> str | None:
    """Determine the canonical SSL mode requested by the URL, if any.

    Args:
        dialect: SQL dialect, forwarded to ``_normalize_ssl_mode``.
        ssl_options: SSL parameters extracted from the URL.

    Returns:
        The canonical mode, or ``None`` when the URL carried no SSL parameters.

    Raises:
        ValueError: If SSL parameters are present but none of the mode
            parameters has a non-empty value, or the mode is not supported.
    """
    # The first mode parameter with a non-empty value decides.
    for key in _SQL_SSL_MODE_PARAM_KEYS:
        requested = ssl_options.get(key)
        if requested:
            return _normalize_ssl_mode(dialect, requested)

    if not ssl_options:
        return None
    # Other SSL parameters without an explicit mode are ambiguous: refuse.
    if dialect == "postgresql":
        raise ValueError("PostgreSQL SSL URL parameters require sslmode to be set explicitly")
    raise ValueError("MySQL SSL URL parameters require ssl-mode to be set explicitly")
|
|
|
|
def _validate_pg_ssl_options(mode: str | None, ssl_options: dict[str, str]) -> None:
    """Reject PostgreSQL SSL parameter combinations that cannot be honoured.

    Args:
        mode: Canonical sslmode (or ``None``).
        ssl_options: SSL parameters extracted from the URL.

    Raises:
        ValueError: If an unsupported parameter has a value, or certificate
            parameters are combined with ``disable``, ``allow`` or ``prefer``.
    """
    rejected = sorted(
        key for key in _PG_SSL_UNSUPPORTED_PARAM_KEYS if ssl_options.get(key)
    )
    if rejected:
        joined = ", ".join(rejected)
        raise ValueError(f"Unsupported PostgreSQL SSL URL parameter(s): {joined}")

    # The remaining checks only concern URLs that carry certificate parameters.
    if not _has_ssl_options(ssl_options, _PG_SSL_CERT_PARAM_KEYS):
        return
    if mode == "disable":
        raise ValueError("PostgreSQL SSL certificate parameters cannot be used with sslmode=disable")
    if mode in {"allow", "prefer"}:
        raise ValueError("PostgreSQL sslmode=allow/prefer is not supported with certificate URL parameters")
|
|
|
|
| def _build_pg_ssl_context(mode: str, ssl_options: dict[str, str]) -> ssl.SSLContext: |
| sslrootcert = ssl_options.get("sslrootcert") or None |
| sslcert = ssl_options.get("sslcert") or None |
| sslkey = ssl_options.get("sslkey") or None |
|
|
| ctx = ssl.create_default_context(cafile=sslrootcert) |
| if mode == "require": |
| ctx.check_hostname = False |
| if not sslrootcert: |
| ctx.verify_mode = ssl.CERT_NONE |
| elif mode == "verify-ca": |
| ctx.check_hostname = False |
| else: |
| ctx.check_hostname = True |
|
|
| if sslcert: |
| ctx.load_cert_chain(certfile=sslcert, keyfile=sslkey or None) |
| elif sslkey: |
| raise ValueError("PostgreSQL sslkey requires sslcert") |
| return ctx |
|
|
|
|
def _build_mysql_ssl_context(mode: str, ssl_options: dict[str, str]) -> ssl.SSLContext | None:
    """Build the ``ssl.SSLContext`` for a MySQL connection.

    Args:
        mode: Canonical ssl-mode.
        ssl_options: SSL parameters extracted from the URL; ``ssl-ca``,
            ``ssl-capath``, ``ssl-cert``, ``ssl-key``, ``ssl-check-hostname``
            and ``ssl-cipher`` are read.

    Returns:
        ``None`` for ``disabled``; otherwise a context where ``required``
        skips certificate and hostname verification, ``verify_ca`` verifies
        the certificate only, and any other mode also checks the hostname.
        ``ssl-check-hostname`` overrides the hostname setting when given.

    Raises:
        ValueError: For certificate parameters combined with ``disabled``,
            for ``preferred``, for ``ssl-check-hostname=true`` combined with
            ``required``, for an unparseable ``ssl-check-hostname`` and for
            ``ssl-key`` without ``ssl-cert``.
    """
    if mode == "disabled":
        if _has_ssl_options(ssl_options, _MYSQL_SSL_CERT_PARAM_KEYS):
            raise ValueError("MySQL SSL certificate parameters cannot be used with ssl-mode=disabled")
        return None
    if mode == "preferred":
        raise ValueError("MySQL ssl-mode=allow/prefer is not supported by aiomysql")

    ca_file = ssl_options.get("ssl-ca") or None
    ca_path = ssl_options.get("ssl-capath") or None
    client_cert = ssl_options.get("ssl-cert") or None
    client_key = ssl_options.get("ssl-key") or None
    hostname_override = _parse_ssl_bool("ssl-check-hostname", ssl_options.get("ssl-check-hostname"))
    cipher_spec = ssl_options.get("ssl-cipher") or None

    context = ssl.create_default_context(cafile=ca_file, capath=ca_path)
    # "preferred" was rejected above, so only "required" disables verification.
    if mode == "required":
        # check_hostname must be cleared before verify_mode can be CERT_NONE.
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE
    else:
        context.check_hostname = mode != "verify_ca"

    if hostname_override is not None:
        if hostname_override and mode == "required":
            raise ValueError("MySQL ssl-check-hostname=true requires ssl-mode=verify_identity")
        context.check_hostname = hostname_override
    if cipher_spec:
        context.set_ciphers(cipher_spec)

    if client_cert is None:
        if client_key is not None:
            raise ValueError("MySQL ssl-key requires ssl-cert")
        return context
    context.load_cert_chain(certfile=client_cert, keyfile=client_key)
    return context
|
|
|
|
def _build_sql_connect_args(
    dialect: str,
    ssl_options: dict[str, str],
) -> dict[str, Any] | None:
    """Translate extracted SSL options into driver ``connect_args``.

    Args:
        dialect: ``"mysql"`` builds a MySQL context; anything else is treated
            as PostgreSQL.
        ssl_options: SSL parameters extracted from the URL.

    Returns:
        ``{"ssl": <SSLContext>}`` when SSL is requested, otherwise ``None``
        (no SSL parameters, or the mode disables SSL).

    Raises:
        ValueError: Propagated from mode resolution, validation or context
            construction.
    """
    mode = _resolve_ssl_mode(dialect, ssl_options)
    if not mode:
        return None

    if dialect != "mysql":
        # Validation runs even for "disable" so bad combinations are reported.
        _validate_pg_ssl_options(mode, ssl_options)
        if mode == "disable":
            return None
        return {"ssl": _build_pg_ssl_context(mode, ssl_options)}

    context = _build_mysql_ssl_context(mode, ssl_options)
    if context is None:
        return None
    return {"ssl": context}
|
|
|
|
def _prepare_sql_url_and_connect_args(
    dialect: str,
    url: str,
) -> tuple[str, dict[str, Any] | None]:
    """Produce the URL and ``connect_args`` to hand to ``create_async_engine``.

    The scheme is rewritten to the async driver form, SSL query parameters are
    removed from the URL and converted into ``connect_args``.

    Args:
        dialect: SQL dialect (``"mysql"`` or ``"postgresql"``).
        url: Connection URL as supplied by the caller.

    Returns:
        ``(url, connect_args)``; ``connect_args`` is ``None`` when no SSL
        configuration applies or the value is not a URL (no ``://``).
    """
    async_url = _normalize_sql_url(dialect, url)
    if "://" in async_url:
        stripped_url, ssl_options = _extract_sql_ssl_options(dialect, async_url)
        return stripped_url, _build_sql_connect_args(dialect, ssl_options)
    # Not URL-shaped: pass it through and let the engine factory complain.
    return async_url, None
|
|
|
|
| def _is_serverless() -> bool: |
| """Detect common serverless environments (Vercel, AWS Lambda, etc.).""" |
| return bool( |
| os.getenv("VERCEL") |
| or os.getenv("AWS_LAMBDA_FUNCTION_NAME") |
| or os.getenv("FUNCTIONS_WORKER_RUNTIME") |
| ) |
|
|
|
|
def _sql_engine_kwargs(connect_args: dict[str, Any] | None) -> dict[str, Any]:
    """Assemble the keyword arguments for ``create_async_engine``.

    Pool sizing comes from ``ACCOUNT_SQL_POOL_*`` / ``ACCOUNT_SQL_MAX_OVERFLOW``
    environment variables; the defaults are smaller when a serverless
    environment is detected.

    Args:
        connect_args: Driver connect arguments, included only when non-empty.

    Returns:
        Keyword arguments for the engine factory.
    """
    if _is_serverless():
        default_pool_size, default_overflow = 1, 2
    else:
        default_pool_size, default_overflow = 5, 10

    engine_kwargs: dict[str, Any] = {
        "pool_size": _get_env_int("ACCOUNT_SQL_POOL_SIZE", default_pool_size, minimum=1),
        "max_overflow": _get_env_int("ACCOUNT_SQL_MAX_OVERFLOW", default_overflow, minimum=0),
        "pool_timeout": _get_env_int("ACCOUNT_SQL_POOL_TIMEOUT", 30, minimum=1),
        "pool_recycle": _get_env_int("ACCOUNT_SQL_POOL_RECYCLE", 1800, minimum=0),
        "pool_pre_ping": True,
        "pool_use_lifo": True,
    }
    if connect_args:
        engine_kwargs["connect_args"] = connect_args
    return engine_kwargs
|
|
|
|
def _get_or_create_engine(
    cache_key: tuple[str, str, str],
    normalized_url: str,
    connect_args: dict[str, Any] | None,
) -> AsyncEngine:
    """Return the cached engine for ``cache_key``, creating it on first use.

    Args:
        cache_key: Key identifying the engine configuration.
        normalized_url: URL passed to ``create_async_engine`` on a cache miss.
        connect_args: Driver connect arguments used on a cache miss.

    Returns:
        The shared engine registered under ``cache_key``.
    """
    # Lookup and creation happen under one lock so two callers cannot both
    # create an engine for the same key.
    with _ENGINE_CACHE_LOCK:
        cached = _ENGINE_CACHE.get(cache_key)
        if cached is None:
            cached = create_async_engine(normalized_url, **_sql_engine_kwargs(connect_args))
            _ENGINE_CACHE[cache_key] = cached
            # Reverse index used by _evict_cached_engine.
            _ENGINE_KEYS_BY_ID.setdefault(id(cached), set()).add(cache_key)
        return cached
|
|
|
|
def _evict_cached_engine(engine: AsyncEngine) -> None:
    """Remove every cache entry that still points at ``engine``.

    Args:
        engine: Engine about to be disposed.
    """
    with _ENGINE_CACHE_LOCK:
        registered_keys = _ENGINE_KEYS_BY_ID.pop(id(engine), set())
        for cache_key in registered_keys:
            # Only drop the slot if it has not been taken over by another engine.
            if _ENGINE_CACHE.get(cache_key) is engine:
                del _ENGINE_CACHE[cache_key]
|
|
|
|
def _row_to_record(row: Any) -> AccountRecord:
    """Convert a result row from ``accounts_table`` into an ``AccountRecord``.

    The JSON text columns are decoded and the flat ``quota_*`` columns are
    folded into a single ``quota`` mapping.  ``heavy`` and ``grok_4_3`` are
    included only when their decoded value is non-empty.

    Args:
        row: SQLAlchemy result row exposing ``_mapping``.

    Returns:
        The validated ``AccountRecord``.
    """
    fields = dict(row._mapping)
    fields["tags"] = json.loads(fields.get("tags") or "[]")

    # Optional quota windows: decoded first, kept only when non-empty.
    optional_quota: dict[str, Any] = {}
    for mode in ("heavy", "grok_4_3"):
        decoded = json.loads(fields.pop(f"quota_{mode}", "{}") or "{}")
        if decoded:
            optional_quota[mode] = decoded

    # Mandatory quota windows are always present, even when empty.
    quota: dict[str, Any] = {}
    for mode in ("auto", "fast", "expert"):
        quota[mode] = json.loads(fields.pop(f"quota_{mode}", "{}") or "{}")
    quota.update(optional_quota)

    fields["quota"] = quota
    fields["ext"] = json.loads(fields.get("ext") or "{}")
    return AccountRecord.model_validate(fields)
|
|
|
|
class SqlAccountRepository:
    """Async SQLAlchemy-based repository for MySQL / PostgreSQL.

    Every mutating call runs inside one transaction that first bumps the
    global revision counter stored in ``account_meta`` and stamps that value
    on each touched row; ``scan_changes`` uses the stamp to return rows
    modified after a given revision.  Deletes are soft (``deleted_at`` set).

    The schema is created lazily on first use, so the repository also works
    when no explicit ``initialize()`` call is made.
    """

    # Pool names accepted by upserts; anything else is stored as "basic".
    _VALID_POOLS = ("basic", "super", "heavy")

    # AccountPatch attributes copied verbatim into the UPDATE when not None.
    _PATCH_PLAIN_FIELDS = (
        "pool",
        "state_reason",
        "last_use_at",
        "last_fail_at",
        "last_fail_reason",
        "last_sync_at",
        "last_clear_at",
    )
    # AccountPatch attributes stored as JSON text when not None.
    _PATCH_JSON_FIELDS = (
        "quota_auto",
        "quota_fast",
        "quota_expert",
        "quota_heavy",
        "quota_grok_4_3",
    )
    # (patch delta attribute, counter column / record attribute).
    _PATCH_COUNTER_FIELDS = (
        ("usage_use_delta", "usage_use_count"),
        ("usage_fail_delta", "usage_fail_count"),
        ("usage_sync_delta", "usage_sync_count"),
    )
    # ``ext`` keys dropped when a patch asks to clear failures.
    _FAILURE_EXT_KEYS = (
        "cooldown_until",
        "cooldown_reason",
        "disabled_at",
        "disabled_reason",
        "expired_at",
        "expired_reason",
        "forbidden_strikes",
    )

    def __init__(
        self,
        engine: AsyncEngine,
        *,
        dialect: str = "mysql",
        dispose_engine: bool = True,
    ) -> None:
        """Bind the repository to an engine.

        Args:
            engine: Async engine used for every statement.
            dialect: ``"postgresql"`` selects PostgreSQL syntax; any other
                value is treated as MySQL.
            dispose_engine: When true, ``close()`` evicts the engine from the
                module cache and disposes it.
        """
        self._engine = engine
        self._dialect = dialect
        # Kept for callers that want ORM sessions; the methods below use
        # connections from the engine directly.
        self._session = async_sessionmaker(engine, expire_on_commit=False)
        self._dispose_engine = dispose_engine
        self._initialized = False
        self._init_lock = asyncio.Lock()

    # ------------------------------------------------------------------
    # Revision counter
    # ------------------------------------------------------------------

    async def _bump_revision(self, conn: Any) -> int:
        """Increment the stored revision and return the new value.

        The counter is stored as text, hence the casts around the addition.
        """
        incremented = sa.cast(sa.cast(meta_table.c.value, sa.BigInteger) + 1, sa.Text)
        await conn.execute(
            meta_table.update()
            .where(meta_table.c.key == "revision")
            .values(value=incremented)
        )
        result = await conn.execute(
            sa.select(meta_table.c.value).where(meta_table.c.key == "revision")
        )
        return int(result.scalar())

    async def _get_revision(self, conn: Any) -> int:
        """Return the stored revision, or 0 when the row is missing or empty."""
        result = await conn.execute(
            sa.select(meta_table.c.value).where(meta_table.c.key == "revision")
        )
        value = result.scalar()
        return int(value) if value else 0

    # ------------------------------------------------------------------
    # Dialect-specific statements
    # ------------------------------------------------------------------

    def _build_upsert(self, row: dict[str, Any]):
        """Build an insert-or-update statement for one ``accounts`` row.

        On conflict every supplied column except ``token`` and ``created_at``
        is overwritten with the incoming value.
        """
        if self._dialect == "postgresql":
            from sqlalchemy.dialects.postgresql import insert

            stmt = insert(accounts_table).values(**row)
            update_cols = {
                name: stmt.excluded[name]
                for name in row
                if name not in ("token", "created_at")
            }
            return stmt.on_conflict_do_update(index_elements=["token"], set_=update_cols)

        from sqlalchemy.dialects.mysql import insert

        stmt = insert(accounts_table).values(**row)
        update_cols = {
            name: stmt.inserted[name]
            for name in row
            if name not in ("token", "created_at")
        }
        return stmt.on_duplicate_key_update(**update_cols)

    # ------------------------------------------------------------------
    # Initialisation / migrations
    # ------------------------------------------------------------------

    async def _ensure_initialized(self) -> None:
        """Idempotent: create tables + seed revision row if not already done.

        Safe to call on every request: it short-circuits after the first
        success, so repeated calls cost only a flag check.  This lets the
        repository self-initialise even when the ASGI lifespan is not
        executed (e.g. Vercel serverless cold-starts).
        """
        if self._initialized:
            return
        async with self._init_lock:
            # Re-check: another task may have finished while we waited.
            if self._initialized:
                return
            await self._do_initialize()
            self._initialized = True

    async def _do_initialize(self) -> None:
        """Create the schema, apply column migrations and seed the revision row."""
        async with self._engine.begin() as conn:
            await conn.run_sync(metadata.create_all)
            await self._ensure_columns(conn)

            # Seed the counter only when it does not exist yet.  An existing
            # value must survive restarts, otherwise rows stamped with earlier
            # (higher) revisions would confuse incremental scans.
            if self._dialect == "postgresql":
                from sqlalchemy.dialects.postgresql import insert as pg_insert

                await conn.execute(
                    pg_insert(meta_table)
                    .values(key="revision", value="0")
                    .on_conflict_do_nothing()
                )
            else:
                from sqlalchemy.dialects.mysql import insert as my_insert

                # Self-assignment makes the duplicate-key branch a no-op, the
                # MySQL counterpart of ON CONFLICT DO NOTHING.
                await conn.execute(
                    my_insert(meta_table)
                    .values(key="revision", value="0")
                    .on_duplicate_key_update(value=meta_table.c.value)
                )

    async def _ensure_columns(self, conn: Any) -> None:
        """Idempotent ALTER TABLE migrations for columns added after the initial schema."""
        existing = await self._table_columns(conn, _TBL_ACCOUNTS)
        if "quota_grok_4_3" in existing:
            return

        if self._dialect == "mysql":
            # Three steps: add the column as nullable, backfill '{}', then
            # tighten it to NOT NULL.
            await conn.exec_driver_sql(
                f"ALTER TABLE {_TBL_ACCOUNTS} "
                f"ADD COLUMN quota_grok_4_3 TEXT"
            )
            await conn.exec_driver_sql(
                f"UPDATE {_TBL_ACCOUNTS} "
                f"SET quota_grok_4_3 = '{{}}' "
                f"WHERE quota_grok_4_3 IS NULL"
            )
            await conn.exec_driver_sql(
                f"ALTER TABLE {_TBL_ACCOUNTS} "
                f"MODIFY COLUMN quota_grok_4_3 TEXT NOT NULL"
            )
        else:
            await conn.exec_driver_sql(
                f"ALTER TABLE {_TBL_ACCOUNTS} "
                f"ADD COLUMN quota_grok_4_3 TEXT NOT NULL DEFAULT '{{}}'"
            )

    async def _table_columns(self, conn: Any, table: str) -> set[str]:
        """Return the lower-cased column names of ``table`` in the active schema."""
        if self._dialect == "postgresql":
            # Restrict to the current schema so a same-named table elsewhere
            # in the database cannot make a missing column look present.
            query = (
                "SELECT column_name FROM information_schema.columns "
                "WHERE table_schema = current_schema() AND table_name = :t"
            )
        else:
            query = (
                "SELECT COLUMN_NAME FROM information_schema.columns "
                "WHERE table_schema = DATABASE() AND table_name = :t"
            )
        rows = await conn.execute(sa.text(query), {"t": table})
        return {str(r[0]).lower() for r in rows.fetchall()}

    # ------------------------------------------------------------------
    # Public API: reads
    # ------------------------------------------------------------------

    async def initialize(self) -> None:
        """Create the schema if needed (explicit form of the lazy initialisation)."""
        await self._ensure_initialized()

    async def get_revision(self) -> int:
        """Return the current global revision."""
        await self._ensure_initialized()
        async with self._engine.connect() as conn:
            return await self._get_revision(conn)

    async def runtime_snapshot(self) -> RuntimeSnapshot:
        """Return the current revision together with every non-deleted account."""
        await self._ensure_initialized()
        async with self._engine.connect() as conn:
            rev = await self._get_revision(conn)
            rows = (await conn.execute(
                sa.select(accounts_table).where(accounts_table.c.deleted_at.is_(None))
            )).fetchall()
        return RuntimeSnapshot(revision=rev, items=[_row_to_record(r) for r in rows])

    async def scan_changes(
        self,
        since_revision: int,
        *,
        limit: int = 5000,
    ) -> AccountChangeSet:
        """Return rows whose revision is greater than ``since_revision``.

        Args:
            since_revision: Exclusive lower bound on the row revision.
            limit: Maximum number of rows fetched.

        Returns:
            A change set with live rows in ``items``, soft-deleted tokens in
            ``deleted_tokens``, the current global revision, and ``has_more``
            set when exactly ``limit`` rows came back.
        """
        await self._ensure_initialized()
        async with self._engine.connect() as conn:
            rev = await self._get_revision(conn)
            # NOTE(review): one mutation stamps the same revision on many
            # rows, so a limit that cuts through such a group cannot be
            # resumed purely by revision -- confirm callers size ``limit``
            # accordingly.
            rows = (await conn.execute(
                sa.select(accounts_table)
                .where(accounts_table.c.revision > since_revision)
                .order_by(accounts_table.c.revision)
                .limit(limit)
            )).fetchall()

        items: list[AccountRecord] = []
        deleted: list[str] = []
        for row in rows:
            record = _row_to_record(row)
            if record.is_deleted():
                deleted.append(record.token)
            else:
                items.append(record)
        return AccountChangeSet(
            revision=rev,
            items=items,
            deleted_tokens=deleted,
            has_more=len(rows) == limit,
        )

    async def get_accounts(
        self,
        tokens: list[str],
    ) -> list[AccountRecord]:
        """Fetch the accounts with the given tokens (soft-deleted rows included)."""
        if not tokens:
            return []
        await self._ensure_initialized()
        async with self._engine.connect() as conn:
            rows = (await conn.execute(
                sa.select(accounts_table).where(accounts_table.c.token.in_(tokens))
            )).fetchall()
        return [_row_to_record(r) for r in rows]

    async def list_accounts(
        self,
        query: ListAccountsQuery,
    ) -> AccountPage:
        """Return one page of accounts matching ``query``.

        Filters on ``pool`` / ``status`` when set and hides soft-deleted rows
        unless ``include_deleted`` is true.  An unknown ``sort_by`` falls back
        to ``updated_at``; ``token`` is appended as a tie-breaker so page
        boundaries are deterministic.
        """
        await self._ensure_initialized()
        async with self._engine.connect() as conn:
            stmt = sa.select(accounts_table)
            if not query.include_deleted:
                stmt = stmt.where(accounts_table.c.deleted_at.is_(None))
            if query.pool:
                stmt = stmt.where(accounts_table.c.pool == query.pool)
            if query.status:
                stmt = stmt.where(accounts_table.c.status == query.status.value)

            total_row = (await conn.execute(
                sa.select(sa.func.count()).select_from(stmt.subquery())
            )).scalar()
            total = int(total_row or 0)

            # Look the column up by key: getattr() on the column collection
            # would also resolve collection methods such as "keys" or "get".
            sort_col = accounts_table.c.get(query.sort_by, accounts_table.c.updated_at)
            ordering = [sort_col.desc() if query.sort_desc else sort_col.asc()]
            if sort_col is not accounts_table.c.token:
                ordering.append(accounts_table.c.token.asc())
            offset = (query.page - 1) * query.page_size
            stmt = stmt.order_by(*ordering).limit(query.page_size).offset(offset)

            rows = (await conn.execute(stmt)).fetchall()
            rev = await self._get_revision(conn)
        return AccountPage(
            items=[_row_to_record(r) for r in rows],
            total=total,
            page=query.page,
            page_size=query.page_size,
            total_pages=max(1, (total + query.page_size - 1) // query.page_size),
            revision=rev,
        )

    # ------------------------------------------------------------------
    # Public API: writes
    # ------------------------------------------------------------------

    async def _upsert_rows(
        self,
        conn: Any,
        items: list[AccountUpsert],
        *,
        rev: int,
        ts: int,
    ) -> int:
        """Upsert ``items`` on ``conn`` and return how many rows were written."""
        written = 0
        for item in items:
            try:
                token = AccountRecord.model_validate(
                    {"token": item.token, "pool": item.pool}
                ).token
            except Exception:
                # Deliberate best effort: an entry that fails validation is
                # skipped so it cannot abort the rest of the batch.
                continue
            pool = item.pool if item.pool in self._VALID_POOLS else "basic"
            qs = default_quota_set(pool)
            row = {
                "token": token,
                "pool": pool,
                "status": "active",
                "created_at": ts,
                "updated_at": ts,
                "deleted_at": None,
                "tags": json.dumps(item.tags),
                "quota_auto": json.dumps(qs.auto.to_dict()),
                "quota_fast": json.dumps(qs.fast.to_dict()),
                "quota_expert": json.dumps(qs.expert.to_dict()),
                "quota_heavy": json.dumps(qs.heavy.to_dict()) if qs.heavy else "{}",
                "quota_grok_4_3": json.dumps(qs.grok_4_3.to_dict()) if qs.grok_4_3 else "{}",
                "usage_use_count": 0,
                "usage_fail_count": 0,
                "usage_sync_count": 0,
                "ext": json.dumps(item.ext),
                "revision": rev,
            }
            await conn.execute(self._build_upsert(row))
            written += 1
        return written

    async def upsert_accounts(
        self,
        items: list[AccountUpsert],
    ) -> AccountMutationResult:
        """Insert or overwrite accounts.

        Existing rows keep their ``created_at``; every other column is reset
        to the freshly built values (active status, default quotas, zeroed
        counters).  Entries that fail validation are skipped.

        Returns:
            The number of rows written and the revision they were stamped
            with; an empty result when ``items`` is empty.
        """
        if not items:
            return AccountMutationResult()
        await self._ensure_initialized()
        async with self._engine.begin() as conn:
            rev = await self._bump_revision(conn)
            count = await self._upsert_rows(conn, items, rev=rev, ts=now_ms())
        return AccountMutationResult(upserted=count, revision=rev)

    @classmethod
    def _patch_to_updates(
        cls,
        record: AccountRecord,
        patch: AccountPatch,
        *,
        ts: int,
        rev: int,
    ) -> dict[str, Any]:
        """Compute the column values an ``AccountPatch`` produces for ``record``."""
        updates: dict[str, Any] = {"updated_at": ts, "revision": rev}

        for field in cls._PATCH_PLAIN_FIELDS:
            value = getattr(patch, field)
            if value is not None:
                updates[field] = value
        if patch.status is not None:
            updates["status"] = patch.status.value
        for field in cls._PATCH_JSON_FIELDS:
            value = getattr(patch, field)
            if value is not None:
                updates[field] = json.dumps(value)
        for delta_attr, column in cls._PATCH_COUNTER_FIELDS:
            delta = getattr(patch, delta_attr)
            if delta is not None:
                # Counters never go below zero.
                updates[column] = max(0, getattr(record, column) + delta)

        # Work on a copy so ``add_tags`` never mutates the caller's patch.
        tags = list(patch.tags) if patch.tags is not None else list(record.tags)
        for tag in patch.add_tags or ():
            if tag not in tags:
                tags.append(tag)
        if patch.remove_tags:
            tags = [tag for tag in tags if tag not in patch.remove_tags]
        updates["tags"] = json.dumps(tags)

        ext = dict(record.ext)
        if patch.ext_merge:
            ext.update(patch.ext_merge)
        if patch.clear_failures:
            # Applied last, so it overrides status / failure fields set above.
            for key in cls._FAILURE_EXT_KEYS:
                ext.pop(key, None)
            updates["status"] = AccountStatus.ACTIVE.value
            updates["usage_fail_count"] = 0
            updates["last_fail_at"] = None
            updates["last_fail_reason"] = None
            updates["state_reason"] = None
        updates["ext"] = json.dumps(ext)
        return updates

    async def patch_accounts(
        self,
        patches: list[AccountPatch],
    ) -> AccountMutationResult:
        """Apply partial updates to existing accounts.

        Patches whose token does not exist are ignored.

        Returns:
            The number of rows patched and the revision they were stamped
            with; an empty result when ``patches`` is empty.
        """
        if not patches:
            return AccountMutationResult()
        await self._ensure_initialized()
        async with self._engine.begin() as conn:
            rev = await self._bump_revision(conn)
            ts = now_ms()
            count = 0
            for patch in patches:
                row = (await conn.execute(
                    sa.select(accounts_table).where(accounts_table.c.token == patch.token)
                )).fetchone()
                if row is None:
                    continue
                updates = self._patch_to_updates(_row_to_record(row), patch, ts=ts, rev=rev)
                await conn.execute(
                    accounts_table.update()
                    .where(accounts_table.c.token == patch.token)
                    .values(**updates)
                )
                count += 1
        return AccountMutationResult(patched=count, revision=rev)

    async def delete_accounts(
        self,
        tokens: list[str],
    ) -> AccountMutationResult:
        """Soft-delete the given tokens; already-deleted rows are left untouched."""
        if not tokens:
            return AccountMutationResult()
        await self._ensure_initialized()
        async with self._engine.begin() as conn:
            rev = await self._bump_revision(conn)
            ts = now_ms()
            result = await conn.execute(
                accounts_table.update()
                .where(
                    accounts_table.c.token.in_(tokens),
                    accounts_table.c.deleted_at.is_(None),
                )
                .values(deleted_at=ts, updated_at=ts, revision=rev)
            )
        return AccountMutationResult(deleted=result.rowcount, revision=rev)

    async def replace_pool(
        self,
        command: BulkReplacePoolCommand,
    ) -> AccountMutationResult:
        """Soft-delete every live account in ``command.pool``, then upsert ``command.upserts``.

        Both steps share one transaction and one revision, so a failing
        upsert rolls the deletion back and the returned revision is valid
        even when ``command.upserts`` is empty.
        """
        await self._ensure_initialized()
        async with self._engine.begin() as conn:
            rev = await self._bump_revision(conn)
            ts = now_ms()
            del_result = await conn.execute(
                accounts_table.update()
                .where(
                    accounts_table.c.pool == command.pool,
                    accounts_table.c.deleted_at.is_(None),
                )
                .values(deleted_at=ts, updated_at=ts, revision=rev)
            )
            deleted = del_result.rowcount
            upserted = await self._upsert_rows(conn, command.upserts or [], rev=rev, ts=ts)
        return AccountMutationResult(
            upserted=upserted,
            deleted=deleted,
            revision=rev,
        )

    async def close(self) -> None:
        """Dispose the SQLAlchemy connection pool (no-op unless the repository owns the engine)."""
        if self._dispose_engine:
            # Evict first so no new caller picks up an engine being disposed.
            _evict_cached_engine(self._engine)
            await self._engine.dispose()
|
|
|
|
| def _engine_cache_key(dialect: str, normalized_url: str, connect_args: dict[str, Any] | None) -> tuple[str, str, str]: |
| """Build a stable cache key from the normalized URL and connect args.""" |
| args_key = str(sorted(connect_args.items(), key=lambda kv: kv[0])) if connect_args else "" |
| return (dialect, normalized_url, args_key) |
|
|
|
|
def create_mysql_engine(url: str) -> AsyncEngine:
    """Create (or reuse) an async SQLAlchemy engine for MySQL.

    Args:
        url: MySQL connection URL; SSL settings may be given as query
            parameters and are translated into driver connect arguments.

    Returns:
        The engine cached for this URL, created on first use.

    Raises:
        ValueError: If the SSL parameters in the URL are invalid.
    """
    raw_url = (url or "").strip()
    engine_url, connect_args = _prepare_sql_url_and_connect_args("mysql", raw_url)
    # Key the cache on the URL *before* the SSL parameters are stripped.
    # connect_args carries a freshly built SSLContext whose str() embeds the
    # object's address, so fingerprinting it would give every call a new key
    # and the cache would never be hit for SSL-enabled URLs.
    cache_key = _engine_cache_key("mysql", _normalize_sql_url("mysql", raw_url), None)
    return _get_or_create_engine(cache_key, engine_url, connect_args)
|
|
|
|
def create_pgsql_engine(url: str) -> AsyncEngine:
    """Create (or reuse) an async SQLAlchemy engine for PostgreSQL.

    Args:
        url: PostgreSQL connection URL; SSL settings may be given as query
            parameters and are translated into driver connect arguments.

    Returns:
        The engine cached for this URL, created on first use.

    Raises:
        ValueError: If the SSL parameters in the URL are invalid.
    """
    raw_url = (url or "").strip()
    engine_url, connect_args = _prepare_sql_url_and_connect_args("postgresql", raw_url)
    # Key the cache on the URL *before* the SSL parameters are stripped.
    # connect_args carries a freshly built SSLContext whose str() embeds the
    # object's address, so fingerprinting it would give every call a new key
    # and the cache would never be hit for SSL-enabled URLs.
    cache_key = _engine_cache_key("postgresql", _normalize_sql_url("postgresql", raw_url), None)
    return _get_or_create_engine(cache_key, engine_url, connect_args)
|
|
|
|
# Public surface of the module: the repository and the two engine factories.
__all__ = [
    "SqlAccountRepository",
    "create_mysql_engine",
    "create_pgsql_engine",
]
|
|