"""Database operations for API key management. Supports PostgreSQL (via DATABASE_URL) with SQLite fallback for local dev. """ import os import sqlite3 import logging from contextlib import contextmanager from datetime import datetime, timezone logger = logging.getLogger(__name__) DATABASE_URL = os.getenv("DATABASE_URL", "") DB_PATH = os.getenv("DB_PATH", "/tmp/data/tavily_keys.db") _use_pg = bool(DATABASE_URL) if _use_pg: import psycopg2 import psycopg2.extras def _ensure_dir(): if not _use_pg: os.makedirs(os.path.dirname(DB_PATH), exist_ok=True) @contextmanager def get_db(): if _use_pg: conn = psycopg2.connect(DATABASE_URL) conn.autocommit = False try: yield conn conn.commit() except Exception: conn.rollback() raise finally: conn.close() else: _ensure_dir() conn = sqlite3.connect(DB_PATH) conn.row_factory = sqlite3.Row conn.execute("PRAGMA journal_mode=WAL") try: yield conn conn.commit() except Exception: conn.rollback() raise finally: conn.close() def _execute(conn, sql, params=None): if _use_pg: cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) else: cur = conn.cursor() cur.execute(sql, params or ()) return cur def _fetchall(conn, sql, params=None): cur = _execute(conn, sql, params) rows = cur.fetchall() return [dict(r) for r in rows] def _fetchone(conn, sql, params=None): cur = _execute(conn, sql, params) row = cur.fetchone() if row is None: return None return dict(row) def _sql(template): """Convert SQL template: replace ? with %s for PostgreSQL.""" if _use_pg: return template.replace("?", "%s") return template # ── Initialization ── def init_db(): with get_db() as conn: if _use_pg: _execute(conn, """ CREATE TABLE IF NOT EXISTS api_keys ( id SERIAL PRIMARY KEY, email TEXT NOT NULL, password TEXT DEFAULT '', api_key TEXT NOT NULL UNIQUE, service TEXT DEFAULT 'tavily', status TEXT DEFAULT 'active', created_at TEXT NOT NULL, last_checked TEXT, quota_remaining INTEGER, use_count INTEGER DEFAULT 0 ) """) _execute(conn, """ CREATE TABLE IF NOT EXISTS meta ( key TEXT PRIMARY KEY, value TEXT ) """) _execute(conn, """ CREATE TABLE IF NOT EXISTS access_tokens ( id SERIAL PRIMARY KEY, token TEXT NOT NULL UNIQUE, name TEXT DEFAULT '', quota_limit INTEGER DEFAULT 1000, quota_used INTEGER DEFAULT 0, is_admin BOOLEAN DEFAULT FALSE, status TEXT DEFAULT 'active', created_at TEXT NOT NULL, last_used TEXT, expires_at TEXT ) """) _execute(conn, """ CREATE TABLE IF NOT EXISTS config ( key TEXT PRIMARY KEY, value TEXT NOT NULL ) """) else: _execute(conn, """ CREATE TABLE IF NOT EXISTS api_keys ( id INTEGER PRIMARY KEY AUTOINCREMENT, email TEXT NOT NULL, password TEXT DEFAULT '', api_key TEXT NOT NULL UNIQUE, service TEXT DEFAULT 'tavily', status TEXT DEFAULT 'active', created_at TEXT NOT NULL, last_checked TEXT, quota_remaining INTEGER, use_count INTEGER DEFAULT 0 ) """) _execute(conn, """ CREATE TABLE IF NOT EXISTS meta ( key TEXT PRIMARY KEY, value TEXT ) """) _execute(conn, """ CREATE TABLE IF NOT EXISTS access_tokens ( id INTEGER PRIMARY KEY AUTOINCREMENT, token TEXT NOT NULL UNIQUE, name TEXT DEFAULT '', quota_limit INTEGER DEFAULT 1000, quota_used INTEGER DEFAULT 0, is_admin INTEGER DEFAULT 0, status TEXT DEFAULT 'active', created_at TEXT NOT NULL, last_used TEXT, expires_at TEXT ) """) _execute(conn, """ CREATE TABLE IF NOT EXISTS config ( key TEXT PRIMARY KEY, value TEXT NOT NULL ) """) # Migrations for col, coltype, default in [ ("use_count", "INTEGER", "0"), ("quota_remaining", "INTEGER", "NULL"), ("service", "TEXT", "'tavily'"), ]: try: with get_db() as conn: if _use_pg: _execute(conn, f"ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS {col} {coltype} DEFAULT {default}") else: _execute(conn, f"ALTER TABLE api_keys ADD COLUMN {col} {coltype} DEFAULT {default}") except Exception: pass for tbl, col, coltype, default in [("access_tokens", "expires_at", "TEXT", "NULL")]: try: with get_db() as conn: if _use_pg: _execute(conn, f"ALTER TABLE {tbl} ADD COLUMN IF NOT EXISTS {col} {coltype} DEFAULT {default}") else: _execute(conn, f"ALTER TABLE {tbl} ADD COLUMN {col} {coltype} DEFAULT {default}") except Exception: pass _seed_defaults() db_type = "PostgreSQL" if _use_pg else f"SQLite ({DB_PATH})" logger.info("Database initialized: %s", db_type) def _seed_defaults(): """Seed default config values and admin token if not present.""" defaults = { "admin_password": os.getenv("ADMIN_PASSWORD", ""), "admin_token": os.getenv("ADMIN_TOKEN", ""), "free_mode": os.getenv("FREE_MODE", "false"), "default_quota": os.getenv("DEFAULT_QUOTA", "1000"), } with get_db() as conn: for k, v in defaults.items(): existing = _fetchone(conn, _sql("SELECT value FROM config WHERE key = ?"), (k,)) if not existing: if _use_pg: _execute(conn, "INSERT INTO config (key, value) VALUES (%s, %s) ON CONFLICT (key) DO NOTHING", (k, v)) else: _execute(conn, "INSERT OR IGNORE INTO config (key, value) VALUES (?, ?)", (k, v)) # Ensure admin token exists admin_token = get_config("admin_token") if admin_token: with get_db() as conn: existing = _fetchone(conn, _sql("SELECT id FROM access_tokens WHERE token = ?"), (admin_token,)) if not existing: now = datetime.now(timezone.utc).isoformat() if _use_pg: _execute(conn, "INSERT INTO access_tokens (token, name, quota_limit, quota_used, is_admin, status, created_at) " "VALUES (%s, %s, %s, %s, %s, %s, %s) ON CONFLICT (token) DO NOTHING", (admin_token, "Admin", 0, 0, True, "active", now)) else: _execute(conn, "INSERT OR IGNORE INTO access_tokens (token, name, quota_limit, quota_used, is_admin, status, created_at) " "VALUES (?, ?, ?, ?, ?, ?, ?)", (admin_token, "Admin", 0, 0, 1, "active", now)) # ── API Key CRUD ── def add_key(email: str, password: str, api_key: str, created_at: str = "", service: str = "tavily") -> int: if not created_at: created_at = datetime.now(timezone.utc).isoformat() with get_db() as conn: if _use_pg: cur = _execute(conn, "INSERT INTO api_keys (email, password, api_key, service, created_at) " "VALUES (%s, %s, %s, %s, %s) RETURNING id", (email, password, api_key, service, created_at)) return cur.fetchone()["id"] else: cur = _execute(conn, "INSERT INTO api_keys (email, password, api_key, service, created_at) " "VALUES (?, ?, ?, ?, ?)", (email, password, api_key, service, created_at)) return cur.lastrowid def add_keys_batch(keys: list[dict]) -> int: added = 0 with get_db() as conn: for k in keys: try: created = k.get("created_at", datetime.now(timezone.utc).isoformat()) svc = k.get("service") or "tavily" if _use_pg: cur = _execute(conn, "INSERT INTO api_keys (email, password, api_key, service, created_at) " "VALUES (%s, %s, %s, %s, %s) ON CONFLICT (api_key) DO NOTHING", (k["email"], k.get("password", ""), k["api_key"], svc, created)) else: cur = _execute(conn, "INSERT OR IGNORE INTO api_keys (email, password, api_key, service, created_at) " "VALUES (?, ?, ?, ?, ?)", (k["email"], k.get("password", ""), k["api_key"], svc, created)) added += int(bool(cur.rowcount)) except Exception: pass return added def list_keys(status: str = "", service: str = "") -> list[dict]: with get_db() as conn: conditions, params = [], [] if status: conditions.append("status = ?") params.append(status) if service: conditions.append("service = ?") params.append(service) where = (" WHERE " + " AND ".join(conditions)) if conditions else "" return _fetchall(conn, _sql(f"SELECT * FROM api_keys{where} ORDER BY id DESC"), tuple(params)) def get_key(key_id: int) -> dict | None: with get_db() as conn: return _fetchone(conn, _sql("SELECT * FROM api_keys WHERE id = ?"), (key_id,)) def delete_key(key_id: int) -> bool: with get_db() as conn: cur = _execute(conn, _sql("DELETE FROM api_keys WHERE id = ?"), (key_id,)) return cur.rowcount > 0 def delete_keys_batch(key_ids: list[int]) -> int: if not key_ids: return 0 with get_db() as conn: if _use_pg: cur = _execute(conn, "DELETE FROM api_keys WHERE id = ANY(%s)", (key_ids,)) else: placeholders = ",".join("?" for _ in key_ids) cur = _execute(conn, f"DELETE FROM api_keys WHERE id IN ({placeholders})", tuple(key_ids)) return cur.rowcount def delete_keys_by_status(statuses: list[str]) -> int: if not statuses: return 0 with get_db() as conn: if _use_pg: placeholders = ",".join("%s" for _ in statuses) else: placeholders = ",".join("?" for _ in statuses) cur = _execute(conn, f"DELETE FROM api_keys WHERE status IN ({placeholders})", tuple(statuses)) return cur.rowcount def update_status(key_id: int, status: str, quota_remaining: int | None = None): now = datetime.now(timezone.utc).isoformat() with get_db() as conn: if quota_remaining is not None: _execute(conn, _sql( "UPDATE api_keys SET status = ?, last_checked = ?, quota_remaining = ? WHERE id = ?" ), (status, now, quota_remaining, key_id)) else: _execute(conn, _sql( "UPDATE api_keys SET status = ?, last_checked = ? WHERE id = ?" ), (status, now, key_id)) def update_status_batch(key_ids: list[int], status: str) -> int: if not key_ids: return 0 now = datetime.now(timezone.utc).isoformat() with get_db() as conn: if _use_pg: cur = _execute(conn, "UPDATE api_keys SET status = %s, last_checked = %s WHERE id = ANY(%s)", (status, now, key_ids)) else: placeholders = ",".join("?" for _ in key_ids) cur = _execute(conn, f"UPDATE api_keys SET status = ?, last_checked = ? WHERE id IN ({placeholders})", (status, now, *key_ids)) return cur.rowcount def get_next_active_key(service: str = "tavily") -> dict | None: """Get the least-recently-checked active key for a service (round-robin) and increment use_count.""" with get_db() as conn: row = _fetchone(conn, _sql( "SELECT * FROM api_keys WHERE status = 'active' AND service = ? " "ORDER BY last_checked ASC NULLS FIRST LIMIT 1"), (service,)) if row: now = datetime.now(timezone.utc).isoformat() _execute(conn, _sql( "UPDATE api_keys SET last_checked = ?, use_count = COALESCE(use_count, 0) + 1 WHERE id = ?" ), (now, row["id"])) return row return None def get_stats() -> dict: with get_db() as conn: total = _fetchone(conn, "SELECT COUNT(*) as cnt FROM api_keys")["cnt"] active = _fetchone(conn, "SELECT COUNT(*) as cnt FROM api_keys WHERE status='active'")["cnt"] inactive = _fetchone(conn, "SELECT COUNT(*) as cnt FROM api_keys WHERE status='inactive'")["cnt"] exhausted = _fetchone(conn, "SELECT COUNT(*) as cnt FROM api_keys WHERE status='exhausted'")["cnt"] total_usage = _fetchone(conn, "SELECT COALESCE(SUM(use_count), 0) as cnt FROM api_keys")["cnt"] quota_sum = _fetchone(conn, "SELECT COALESCE(SUM(quota_remaining), 0) as cnt FROM api_keys " "WHERE status='active' AND quota_remaining IS NOT NULL")["cnt"] last_reg = _fetchone(conn, "SELECT created_at FROM api_keys ORDER BY id DESC LIMIT 1") last_check = _fetchone(conn, "SELECT value FROM meta WHERE key='last_healthcheck'") by_service = {} for svc in ("tavily", "firecrawl", "exa"): s_total = _fetchone(conn, _sql("SELECT COUNT(*) as cnt FROM api_keys WHERE service = ?"), (svc,))["cnt"] if s_total == 0: continue s_active = _fetchone(conn, _sql("SELECT COUNT(*) as cnt FROM api_keys WHERE service = ? AND status='active'"), (svc,))["cnt"] s_inactive = _fetchone(conn, _sql("SELECT COUNT(*) as cnt FROM api_keys WHERE service = ? AND status='inactive'"), (svc,))["cnt"] s_exhausted = _fetchone(conn, _sql("SELECT COUNT(*) as cnt FROM api_keys WHERE service = ? AND status='exhausted'"), (svc,))["cnt"] s_usage = _fetchone(conn, _sql("SELECT COALESCE(SUM(use_count), 0) as cnt FROM api_keys WHERE service = ?"), (svc,))["cnt"] s_quota = _fetchone(conn, _sql( "SELECT COALESCE(SUM(quota_remaining), 0) as cnt FROM api_keys " "WHERE service = ? AND status='active' AND quota_remaining IS NOT NULL"), (svc,))["cnt"] by_service[svc] = { "total_keys": s_total, "active_keys": s_active, "inactive_keys": s_inactive, "exhausted_keys": s_exhausted, "total_usage": s_usage, "total_quota_remaining": s_quota if s_quota else None, } return { "total_keys": total, "active_keys": active, "inactive_keys": inactive, "exhausted_keys": exhausted, "total_usage": total_usage, "total_quota_remaining": quota_sum if quota_sum else None, "last_registration": last_reg["created_at"] if last_reg else None, "last_healthcheck": last_check["value"] if last_check else None, "by_service": by_service, } def set_meta(key: str, value: str): with get_db() as conn: if _use_pg: _execute(conn, "INSERT INTO meta (key, value) VALUES (%s, %s) " "ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value", (key, value)) else: _execute(conn, "INSERT OR REPLACE INTO meta (key, value) VALUES (?, ?)", (key, value)) def export_all_keys() -> list[dict]: with get_db() as conn: return _fetchall(conn, "SELECT email, password, api_key, service, status, created_at, last_checked, " "COALESCE(use_count, 0) as use_count FROM api_keys ORDER BY id") # ── Access Token CRUD ── def add_access_token(token: str, name: str = "", quota_limit: int = 1000, is_admin: bool = False, expires_at: str | None = None) -> int: now = datetime.now(timezone.utc).isoformat() with get_db() as conn: if _use_pg: cur = _execute(conn, "INSERT INTO access_tokens (token, name, quota_limit, is_admin, status, created_at, expires_at) " "VALUES (%s, %s, %s, %s, %s, %s, %s) RETURNING id", (token, name, quota_limit, is_admin, "active", now, expires_at)) return cur.fetchone()["id"] else: cur = _execute(conn, "INSERT INTO access_tokens (token, name, quota_limit, is_admin, status, created_at, expires_at) " "VALUES (?, ?, ?, ?, ?, ?, ?)", (token, name, quota_limit, 1 if is_admin else 0, "active", now, expires_at)) return cur.lastrowid def list_access_tokens() -> list[dict]: with get_db() as conn: rows = _fetchall(conn, "SELECT * FROM access_tokens ORDER BY id") for r in rows: r["is_admin"] = bool(r.get("is_admin")) return rows def get_access_token(token: str) -> dict | None: with get_db() as conn: row = _fetchone(conn, _sql("SELECT * FROM access_tokens WHERE token = ? AND status = 'active'"), (token,)) if row: row["is_admin"] = bool(row.get("is_admin")) exp = row.get("expires_at") if exp: try: if datetime.fromisoformat(exp.replace("Z", "+00:00")) < datetime.now(timezone.utc): return None except ValueError: pass return row def delete_access_token(token_id: int) -> bool: with get_db() as conn: cur = _execute(conn, _sql("DELETE FROM access_tokens WHERE id = ?"), (token_id,)) return cur.rowcount > 0 def update_access_token(token_id: int, **kwargs) -> bool: allowed = {"name", "quota_limit", "status", "expires_at"} updates = {k: v for k, v in kwargs.items() if k in allowed} if not updates: return False with get_db() as conn: if _use_pg: set_clause = ", ".join(f"{k} = %s" for k in updates) _execute(conn, f"UPDATE access_tokens SET {set_clause} WHERE id = %s", (*updates.values(), token_id)) else: set_clause = ", ".join(f"{k} = ?" for k in updates) _execute(conn, f"UPDATE access_tokens SET {set_clause} WHERE id = ?", (*updates.values(), token_id)) return True def increment_token_usage(token: str) -> bool: now = datetime.now(timezone.utc).isoformat() with get_db() as conn: cur = _execute(conn, _sql( "UPDATE access_tokens SET quota_used = quota_used + 1, last_used = ? WHERE token = ?"), (now, token)) return cur.rowcount > 0 def reset_token_usage(token_id: int) -> bool: with get_db() as conn: cur = _execute(conn, _sql("UPDATE access_tokens SET quota_used = 0 WHERE id = ?"), (token_id,)) return cur.rowcount > 0 # ── Config ── def get_config(key: str) -> str | None: with get_db() as conn: row = _fetchone(conn, _sql("SELECT value FROM config WHERE key = ?"), (key,)) return row["value"] if row else None def set_config(key: str, value: str): with get_db() as conn: if _use_pg: _execute(conn, "INSERT INTO config (key, value) VALUES (%s, %s) " "ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value", (key, value)) else: _execute(conn, "INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)", (key, value)) def get_all_config() -> dict: with get_db() as conn: rows = _fetchall(conn, "SELECT key, value FROM config ORDER BY key") return {r["key"]: r["value"] for r in rows}