File size: 7,014 Bytes
c6393bd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 | """
database.py — SQLite database layer for rita2api
Replaces JSON file storage with SQLite.
Thread-safe with connection-per-thread pattern.
"""
import sqlite3
import threading
import time
import os
from pathlib import Path
# Resolve where the SQLite file lives.
#   * RITA_DATA_DIR unset          -> /data
#   * RITA_DATA_DIR set, non-empty -> that directory
#   * RITA_DATA_DIR set but empty  -> "data" folder next to this file
# The emptiness check has to look at the raw environment string: Path("")
# normalises to Path("."), whose str() is "." and therefore always truthy,
# which would make the fallback branch unreachable.
_RAW_DATA_DIR = os.getenv("RITA_DATA_DIR", "/data")
_DEFAULT_DATA_DIR = Path(_RAW_DATA_DIR)
DATA_DIR = _DEFAULT_DATA_DIR if _RAW_DATA_DIR else (Path(__file__).parent / "data")
# Import-time side effect: make sure the directory exists before connecting.
DATA_DIR.mkdir(parents=True, exist_ok=True)
DB_PATH = DATA_DIR / "rita.db"
class DB:
    """Thread-safe SQLite wrapper using connection-per-thread.

    Every thread lazily opens its own ``sqlite3.Connection`` (kept in a
    ``threading.local``), so a connection object is never shared between
    threads. Write helpers commit on success and roll back on failure.
    """

    def __init__(self, db_path=None):
        """Remember the database location and make sure the schema exists.

        Args:
            db_path: Location of the SQLite database. When omitted (or
                falsy) the module-level ``DB_PATH`` is used.
        """
        self._db_path = str(db_path or DB_PATH)
        self._local = threading.local()
        self._init_tables()

    def _get_conn(self) -> sqlite3.Connection:
        """Return the calling thread's connection, opening it on first use."""
        conn = getattr(self._local, "conn", None)
        if conn is None:
            conn = sqlite3.connect(self._db_path)
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute("PRAGMA foreign_keys=ON")
            # Cache only once fully configured, so a failed PRAGMA cannot
            # leave a half-initialised connection behind for later calls.
            self._local.conn = conn
        return conn

    def close(self) -> None:
        """Close the calling thread's connection, if it has one open.

        A later call on the same thread transparently opens a new one.
        """
        conn = getattr(self._local, "conn", None)
        if conn is not None:
            conn.close()
            self._local.conn = None

    def execute(self, sql, params=None) -> sqlite3.Cursor:
        """Run one statement and commit it.

        Args:
            sql: SQL text, using ``?`` placeholders for values.
            params: Sequence of values for the placeholders (optional).

        Returns:
            The cursor produced by the statement.

        Raises:
            sqlite3.Error: If the statement fails. The pending transaction
                is rolled back before the exception propagates.
        """
        conn = self._get_conn()
        # The connection context manager commits on success and rolls back
        # on error, so a failed write does not leave an open transaction
        # (and its write lock) lingering on this thread's connection.
        with conn:
            return conn.execute(sql, params or ())

    def executemany(self, sql, params_list) -> sqlite3.Cursor:
        """Run one statement once per parameter set, then commit.

        Args:
            sql: SQL text, using ``?`` placeholders for values.
            params_list: Iterable of parameter sequences.

        Returns:
            The cursor produced by the batch.

        Raises:
            sqlite3.Error: If any execution fails. The whole batch is
                rolled back before the exception propagates.
        """
        conn = self._get_conn()
        with conn:
            return conn.executemany(sql, params_list)

    def fetchone(self, sql, params=None):
        """Run a query and return its first row, or ``None`` if it has none."""
        return self._get_conn().execute(sql, params or ()).fetchone()

    def fetchall(self, sql, params=None) -> list:
        """Run a query and return all of its rows as a list."""
        return self._get_conn().execute(sql, params or ()).fetchall()

    def _init_tables(self) -> None:
        """Create the tables if they are missing and seed default config."""
        conn = self._get_conn()
        conn.executescript("""
            CREATE TABLE IF NOT EXISTS accounts (
                id TEXT PRIMARY KEY,
                name TEXT NOT NULL DEFAULT '',
                token TEXT NOT NULL DEFAULT '',
                visitorid TEXT NOT NULL DEFAULT '',
                enabled INTEGER NOT NULL DEFAULT 1,
                email TEXT NOT NULL DEFAULT '',
                password TEXT NOT NULL DEFAULT '',
                mail_provider TEXT NOT NULL DEFAULT '',
                mail_api_key TEXT NOT NULL DEFAULT '',
                created_at REAL NOT NULL DEFAULT 0,
                quota_remain INTEGER NOT NULL DEFAULT 100,
                total_requests INTEGER NOT NULL DEFAULT 0,
                total_success INTEGER NOT NULL DEFAULT 0,
                total_fail INTEGER NOT NULL DEFAULT 0,
                last_used REAL NOT NULL DEFAULT 0,
                last_error TEXT NOT NULL DEFAULT '',
                token_valid INTEGER NOT NULL DEFAULT 1,
                disabled_reason TEXT NOT NULL DEFAULT ''
            );
            CREATE TABLE IF NOT EXISTS config (
                key TEXT PRIMARY KEY,
                value TEXT NOT NULL DEFAULT '',
                description TEXT NOT NULL DEFAULT ''
            );
            CREATE TABLE IF NOT EXISTS usage_log (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp REAL NOT NULL,
                account_id TEXT NOT NULL DEFAULT '',
                model TEXT NOT NULL DEFAULT '',
                tokens_approx INTEGER NOT NULL DEFAULT 0,
                success INTEGER NOT NULL DEFAULT 1
            );
        """)
        # Seed default config values (INSERT OR IGNORE = don't overwrite existing)
        defaults = [
            ("RITA_UPSTREAM", "https://api_v2.rita.ai", "Rita.ai upstream API URL"),
            ("RITA_ORIGIN", "https://www.rita.ai", "Rita.ai origin for headers"),
            ("AUTH_TOKEN", "", "Admin panel auth token (empty=no auth)"),
            ("DISABLE_SSL_VERIFY", "1", "Disable SSL verification for upstream"),
            ("HEALTH_CHECK_INTERVAL", "600", "Health check interval in seconds"),
            ("AUTO_REGISTER_ENABLED", "0", "Enable auto-registration"),
            ("AUTO_REGISTER_MIN_ACTIVE", "2", "Minimum active accounts"),
            ("AUTO_REGISTER_BATCH", "1", "Accounts per registration batch"),
            ("AUTO_REGISTER_PASSWORD", "@qazwsx123456", "Default password for new accounts"),
            ("YESCAPTCHA_KEY", "", "YesCaptcha API key"),
            ("GPTMAIL_API_KEY", "", "GPTMail API key"),
            ("GPTMAIL_API_BASE", "https://mail.chatgpt.org.uk", "GPTMail API base URL"),
            ("YYDSMAIL_API_KEY", "", "YYDS Mail API key"),
            ("YYDSMAIL_API_BASE", "https://maliapi.215.im/v1", "YYDS Mail API base URL"),
            ("CAPTCHA_PROVIDER", "yescaptcha", "Captcha provider: yescaptcha or whisper"),
            ("AUTO_REGISTER_MIN_QUOTA", "50", "Auto-register when total quota below this"),
        ]
        with conn:
            conn.executemany(
                "INSERT OR IGNORE INTO config (key, value, description) VALUES (?, ?, ?)",
                defaults,
            )
        print(f"[Database] Initialized: {self._db_path}")

    # ===================== Config helpers =====================

    def get_config(self, key, default=""):
        """Return the stored value for ``key``, or ``default`` if absent."""
        row = self.fetchone("SELECT value FROM config WHERE key=?", (key,))
        return row["value"] if row else default

    def set_config(self, key, value, description=""):
        """Insert or update a config entry.

        ``description`` is stored only when the key is inserted for the
        first time; for an existing key only ``value`` is overwritten.
        """
        self.execute(
            "INSERT INTO config (key, value, description) VALUES (?, ?, ?) "
            "ON CONFLICT(key) DO UPDATE SET value=excluded.value",
            (key, value, description),
        )

    def get_all_config(self) -> list:
        """Return every config row as a dict, ordered by key."""
        rows = self.fetchall("SELECT key, value, description FROM config ORDER BY key")
        return [dict(r) for r in rows]

    # ===================== Usage log helpers =====================

    def log_usage(self, account_id, model, tokens_approx=0, success=True):
        """Append one usage record stamped with the current time."""
        self.execute(
            "INSERT INTO usage_log (timestamp, account_id, model, tokens_approx, success) VALUES (?,?,?,?,?)",
            (time.time(), account_id, model, tokens_approx, 1 if success else 0),
        )

    def get_usage_stats(self) -> dict:
        """Summarise the usage log.

        Returns:
            A dict with ``total_requests``, ``total_tokens_approx``,
            ``requests_today`` (rows stamped since local midnight) and
            ``requests_by_model`` (request count for the 20 most used models).
        """
        # Round-trip today's date through strftime/strptime to get the epoch
        # timestamp of local midnight.
        today_start = time.mktime(time.strptime(time.strftime("%Y-%m-%d"), "%Y-%m-%d"))
        total = self.fetchone("SELECT COUNT(*) as c, SUM(tokens_approx) as t FROM usage_log")
        today = self.fetchone("SELECT COUNT(*) as c FROM usage_log WHERE timestamp >= ?", (today_start,))
        by_model = self.fetchall(
            "SELECT model, COUNT(*) as c FROM usage_log GROUP BY model ORDER BY c DESC LIMIT 20"
        )
        return {
            "total_requests": total["c"] or 0,
            # SUM() yields NULL on an empty table, hence the "or 0".
            "total_tokens_approx": total["t"] or 0,
            "requests_today": today["c"] or 0,
            "requests_by_model": {r["model"]: r["c"] for r in by_model},
        }
# Process-wide shared instance, built lazily by get_db().
_db = None


def get_db():
    """Return the shared ``DB`` instance, constructing it on the first call."""
    global _db
    if _db is not None:
        return _db
    _db = DB()
    return _db
if __name__ == "__main__":
    # Manual smoke check: build the database, then list the tables it holds.
    db = DB()
    table_rows = db.fetchall("SELECT name FROM sqlite_master WHERE type='table'")
    table_count = len(table_rows)
    print(f"Database initialized: {table_count} tables created")
    for row in table_rows:
        table_name = row['name']
        print(f" - {table_name}")
|