|
|
"""Database storage layer""" |
|
|
import aiomysql |
|
|
import json |
|
|
from datetime import datetime |
|
|
from typing import Optional, List, Any |
|
|
from pathlib import Path |
|
|
from .models import Token, TokenStats, Task, RequestLog, AdminConfig, ProxyConfig, WatermarkFreeConfig, CacheConfig, GenerationConfig, TokenRefreshConfig |
|
|
from .config import config |
|
|
|
|
|
class MySQLCursorWrapper:
    """Async cursor facade returned by ``MySQLConnectionWrapper.execute``.

    ``fetchone``, ``fetchall`` and ``lastrowid`` are defined explicitly; any
    other attribute access is forwarded to the wrapped cursor.
    """

    def __init__(self, cursor):
        # The driver cursor that ran the query.
        self.cursor = cursor

    def __getattr__(self, name):
        # Only reached when normal lookup fails, i.e. for attributes this
        # wrapper does not define itself.
        target = self.cursor
        return getattr(target, name)

    async def fetchone(self):
        """Await and return the next row from the wrapped cursor."""
        row = await self.cursor.fetchone()
        return row

    async def fetchall(self):
        """Await and return the remaining rows from the wrapped cursor."""
        rows = await self.cursor.fetchall()
        return rows

    @property
    def lastrowid(self):
        """The wrapped cursor's ``lastrowid`` (DB-API: id generated by the last INSERT)."""
        inner = self.cursor
        return inner.lastrowid
|
|
|
|
|
class MySQLConnectionWrapper:
    """aiosqlite-style facade over one connection borrowed from an aiomysql pool.

    Use as an async context manager: entering acquires a connection from
    ``pool``; leaving ends any still-open transaction and hands the
    connection back to the pool. ``execute`` accepts sqlite-style ``?``
    placeholders.
    """

    def __init__(self, pool):
        self.pool = pool
        self.conn = None    # pooled connection; set by __aenter__, cleared by __aexit__
        self.cursor = None  # cursor of the most recent execute()

    async def __aenter__(self):
        self.conn = await self.pool.acquire()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """End the open transaction (if any) and return the connection to the pool.

        The pool runs with autocommit disabled, so even a plain SELECT leaves a
        transaction open. aiomysql's ``Pool.release()`` closes any connection
        that is still inside a transaction instead of reusing it. Without this
        rollback, every read-only context would discard its connection.

        Rolling back does not affect work the caller already committed. Any
        uncommitted work (including after an exception) is discarded, exactly
        as closing the connection would have done. The caller's exception, if
        any, is never suppressed.
        """
        if self.conn is None:
            return
        conn, self.conn = self.conn, None
        self.cursor = None
        try:
            await conn.rollback()
        except Exception as e:
            # Best effort: never let cleanup mask the caller's own exception.
            # Close the connection so a broken one is not put back into the pool.
            print(f"Connection cleanup failed: {e}")
            conn.close()
        finally:
            self.pool.release(conn)

    async def execute(self, query: str, args: tuple = None) -> MySQLCursorWrapper:
        """Run ``query`` on the held connection and return a cursor wrapper.

        Every ``?`` in ``query`` is rewritten to ``%s`` (the driver's
        placeholder style), including any ``?`` inside SQL string literals.
        Queries already written with ``%s`` pass through unchanged.

        Args:
            query: SQL text using ``?`` or ``%s`` placeholders.
            args: Parameter sequence for the placeholders, or None.

        Returns:
            A ``MySQLCursorWrapper`` around a new ``aiomysql.DictCursor``.

        Raises:
            Whatever the driver raises. The error, query and args are printed first.
        """
        query = query.replace('?', '%s')
        self.cursor = await self.conn.cursor(aiomysql.DictCursor)
        try:
            await self.cursor.execute(query, args)
        except Exception as e:
            # NOTE(review): args may contain secrets (e.g. update_admin_config
            # passes admin_password here); consider redacting before printing.
            print(f"SQL Error: {e}\nQuery: {query}\nArgs: {args}")
            raise
        return MySQLCursorWrapper(self.cursor)

    async def commit(self):
        """Commit the current transaction on the held connection."""
        await self.conn.commit()

    @property
    def row_factory(self):
        """aiosqlite compatibility shim: rows are already dicts, so always None."""
        return None

    @row_factory.setter
    def row_factory(self, value):
        # Accepted and ignored so aiosqlite-style callers keep working.
        pass
|
|
|
|
|
class Database: |
|
|
"""MySQL database manager (compatible interface with previous SQLite implementation)""" |
|
|
|
|
|
def __init__(self, db_path: str = None): |
|
|
self._pool = None |
|
|
|
|
|
|
|
|
async def get_pool(self):
    """Return the shared aiomysql pool, creating it on first call.

    Settings come from ``config`` (db_host, db_port, db_user, db_password,
    db_name). Connections have autocommit disabled and use
    ``aiomysql.DictCursor`` as the cursor class.
    """
    if self._pool is not None:
        return self._pool
    # NOTE(review): two coroutines reaching this await concurrently could each
    # create a pool; guard with an asyncio.Lock if that can happen.
    pool_settings = dict(
        host=config.db_host,
        port=config.db_port,
        user=config.db_user,
        password=config.db_password,
        db=config.db_name,
        autocommit=False,
        cursorclass=aiomysql.DictCursor,
    )
    self._pool = await aiomysql.create_pool(**pool_settings)
    return self._pool
|
|
|
|
|
def connect(self):
    """Return an async context manager that mimics ``aiosqlite.connect``.

    Entering it gets the pool via ``get_pool`` and yields a
    ``MySQLConnectionWrapper`` holding one pooled connection. Exiting forwards
    the exit information to that wrapper, which releases the connection.
    """

    class ConnectContext:
        """One-shot context: pooled connection in, released connection out."""

        def __init__(self, db_instance):
            self.db = db_instance

        async def __aenter__(self):
            pool = await self.db.get_pool()
            wrapper = MySQLConnectionWrapper(pool)
            self.wrapper = wrapper
            entered = await wrapper.__aenter__()
            return entered

        async def __aexit__(self, *exc_info):
            await self.wrapper.__aexit__(*exc_info)

    return ConnectContext(self)
|
|
|
|
|
async def db_exists(self) -> bool:
    """Return True if ``SELECT 1`` succeeds on a pooled connection.

    The name is kept from the SQLite implementation; here it means "the
    configured MySQL server is reachable". Any failure is printed and
    reported as False.
    """
    try:
        pool = await self.get_pool()
        async with pool.acquire() as conn, conn.cursor() as cur:
            await cur.execute("SELECT 1")
    except Exception as e:
        print(f"Database connection failed: {e}")
        return False
    return True
|
|
|
|
|
async def _table_exists(self, db, table_name: str) -> bool:
    """Return True if ``table_name`` exists in schema ``config.db_name``.

    ``db`` is the wrapper yielded by ``connect()``.
    """
    sql = "SELECT count(*) as count FROM information_schema.tables WHERE table_schema = %s AND table_name = %s"
    cursor = await db.execute(sql, (config.db_name, table_name))
    row = await cursor.fetchone()
    return row['count'] > 0
|
|
|
|
|
async def _column_exists(self, db, table_name: str, column_name: str) -> bool:
    """Return True if ``table_name.column_name`` exists in schema ``config.db_name``.

    A failed lookup is treated as "column missing" (returns False), so the
    migration code will still try to add the column. Only ``Exception``
    subclasses are caught. Task cancellation (``asyncio.CancelledError``)
    and ``KeyboardInterrupt`` now propagate instead of being swallowed.
    """
    try:
        cursor = await db.execute(
            "SELECT count(*) as count FROM information_schema.columns WHERE table_schema = %s AND table_name = %s AND column_name = %s",
            (config.db_name, table_name, column_name)
        )
        result = await cursor.fetchone()
        return result['count'] > 0
    except Exception:
        # Deliberate best effort; the wrapper's execute() has already printed
        # any SQL error.
        return False
|
|
|
|
|
async def _ensure_config_rows(self, db, config_dict: dict = None):
    """Insert the id=1 row into each singleton config table that is empty.

    Tables are checked in this order: admin_config, proxy_config,
    watermark_free_config, cache_config, generation_config,
    token_refresh_config. A table that already has any row is left alone.
    Values for a new row come from the matching section of ``config_dict``
    (parsed setting.toml) when it is non-empty, otherwise from built-in
    defaults. Empty URLs/tokens are stored as NULL. Does not commit.
    """

    def section(name: str) -> dict:
        # A falsy config_dict means "no overrides": every lookup then falls
        # back to the same default used for a missing key.
        return config_dict.get(name, {}) if config_dict else {}

    async def is_empty(table: str) -> bool:
        cursor = await db.execute(f"SELECT COUNT(*) as count FROM {table}")
        row = await cursor.fetchone()
        return row['count'] == 0

    if await is_empty("admin_config"):
        global_section = section("global")
        await db.execute("""
            INSERT INTO admin_config (id, admin_username, admin_password, error_ban_threshold)
            VALUES (1, ?, ?, ?)
        """, (
            global_section.get("admin_username", "admin"),
            global_section.get("admin_password", "admin"),
            section("admin").get("error_ban_threshold", 3),
        ))

    if await is_empty("proxy_config"):
        proxy_section = section("proxy")
        await db.execute("""
            INSERT INTO proxy_config (id, proxy_enabled, proxy_url)
            VALUES (1, ?, ?)
        """, (
            proxy_section.get("proxy_enabled", False),
            proxy_section.get("proxy_url", "") or None,
        ))

    if await is_empty("watermark_free_config"):
        watermark_section = section("watermark_free")
        await db.execute("""
            INSERT INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
            VALUES (1, ?, ?, ?, ?)
        """, (
            watermark_section.get("watermark_free_enabled", False),
            watermark_section.get("parse_method", "third_party"),
            watermark_section.get("custom_parse_url", "") or None,
            watermark_section.get("custom_parse_token", "") or None,
        ))

    if await is_empty("cache_config"):
        cache_section = section("cache")
        await db.execute("""
            INSERT INTO cache_config (id, cache_enabled, cache_timeout, cache_base_url)
            VALUES (1, ?, ?, ?)
        """, (
            cache_section.get("enabled", False),
            cache_section.get("timeout", 600),
            cache_section.get("base_url", "") or None,
        ))

    if await is_empty("generation_config"):
        generation_section = section("generation")
        await db.execute("""
            INSERT INTO generation_config (id, image_timeout, video_timeout)
            VALUES (1, ?, ?)
        """, (
            generation_section.get("image_timeout", 300),
            generation_section.get("video_timeout", 1500),
        ))

    if await is_empty("token_refresh_config"):
        refresh_section = section("token_refresh")
        await db.execute("""
            INSERT INTO token_refresh_config (id, at_auto_refresh_enabled)
            VALUES (1, ?)
        """, (refresh_section.get("at_auto_refresh_enabled", False),))
|
|
|
|
|
async def check_and_migrate_db(self, config_dict: dict = None):
    """Add missing columns to existing tables, ensure config rows, then commit.

    For each table listed below that exists, every listed column that is
    missing is added with ``ALTER TABLE``. A failed ALTER is printed and
    skipped. ``_ensure_config_rows`` then seeds empty config tables from
    ``config_dict``.
    """
    migrations = {
        "tokens": [
            ("sora2_supported", "BOOLEAN"),
            ("sora2_invite_code", "TEXT"),
            ("sora2_redeemed_count", "INT DEFAULT 0"),
            ("sora2_total_count", "INT DEFAULT 0"),
            ("sora2_remaining_count", "INT DEFAULT 0"),
            ("sora2_cooldown_until", "DATETIME"),
            ("image_enabled", "BOOLEAN DEFAULT 1"),
            ("video_enabled", "BOOLEAN DEFAULT 1"),
        ],
        "admin_config": [
            ("admin_username", "TEXT"),
            ("admin_password", "TEXT"),
        ],
        "watermark_free_config": [
            ("parse_method", "TEXT"),
            ("custom_parse_url", "TEXT"),
            ("custom_parse_token", "TEXT"),
        ],
    }

    async with self.connect() as db:
        print("Checking database integrity and performing migrations...")

        for table, columns in migrations.items():
            if not await self._table_exists(db, table):
                continue
            for col_name, col_type in columns:
                if await self._column_exists(db, table, col_name):
                    continue
                try:
                    await db.execute(f"ALTER TABLE {table} ADD COLUMN {col_name} {col_type}")
                    print(f" ✓ Added column '{col_name}' to {table} table")
                except Exception as e:
                    print(f" ✗ Failed to add column '{col_name}': {e}")

        await self._ensure_config_rows(db, config_dict)

        await db.commit()

        print("Database migration check completed.")
|
|
|
|
|
async def init_db(self):
    """Create all tables (if missing) and the secondary indexes, then commit.

    Tables: tokens, token_stats, tasks, request_logs and the singleton config
    tables (admin_config, proxy_config, watermark_free_config, cache_config,
    generation_config, token_refresh_config).

    MySQL's ``CREATE INDEX`` has no ``IF NOT EXISTS`` clause, so each index is
    first looked up in ``information_schema.statistics``. This replaces the
    previous bare ``except: pass``, which:
      * swallowed task cancellation,
      * hid genuine failures,
      * and printed an "SQL Error" for every index on every restart.
    Index creation stays best effort: an unexpected failure is printed, not raised.
    """
    async with self.connect() as db:
        await db.execute("""
            CREATE TABLE IF NOT EXISTS tokens (
                id INT PRIMARY KEY AUTO_INCREMENT,
                token VARCHAR(2000) NOT NULL,
                email VARCHAR(255) NOT NULL,
                username VARCHAR(255) NOT NULL,
                name VARCHAR(255) NOT NULL,
                st TEXT,
                rt TEXT,
                remark TEXT,
                expiry_time DATETIME,
                is_active BOOLEAN DEFAULT 1,
                cooled_until DATETIME,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_used_at DATETIME,
                use_count INT DEFAULT 0,
                plan_type VARCHAR(50),
                plan_title VARCHAR(100),
                subscription_end DATETIME,
                sora2_supported BOOLEAN,
                sora2_invite_code VARCHAR(255),
                sora2_redeemed_count INT DEFAULT 0,
                sora2_total_count INT DEFAULT 0,
                sora2_remaining_count INT DEFAULT 0,
                sora2_cooldown_until DATETIME,
                image_enabled BOOLEAN DEFAULT 1,
                video_enabled BOOLEAN DEFAULT 1
            )
        """)

        await db.execute("""
            CREATE TABLE IF NOT EXISTS token_stats (
                id INT PRIMARY KEY AUTO_INCREMENT,
                token_id INT NOT NULL,
                image_count INT DEFAULT 0,
                video_count INT DEFAULT 0,
                error_count INT DEFAULT 0,
                last_error_at DATETIME,
                FOREIGN KEY (token_id) REFERENCES tokens(id)
            )
        """)

        await db.execute("""
            CREATE TABLE IF NOT EXISTS tasks (
                id INT PRIMARY KEY AUTO_INCREMENT,
                task_id VARCHAR(255) UNIQUE NOT NULL,
                token_id INT NOT NULL,
                model VARCHAR(255) NOT NULL,
                prompt TEXT NOT NULL,
                status VARCHAR(50) NOT NULL DEFAULT 'processing',
                progress FLOAT DEFAULT 0,
                result_urls TEXT,
                error_message TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                completed_at DATETIME,
                FOREIGN KEY (token_id) REFERENCES tokens(id)
            )
        """)

        await db.execute("""
            CREATE TABLE IF NOT EXISTS request_logs (
                id INT PRIMARY KEY AUTO_INCREMENT,
                token_id INT,
                operation VARCHAR(255) NOT NULL,
                request_body TEXT,
                response_body TEXT,
                status_code INT NOT NULL,
                duration FLOAT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (token_id) REFERENCES tokens(id)
            )
        """)

        await db.execute("""
            CREATE TABLE IF NOT EXISTS admin_config (
                id INT PRIMARY KEY DEFAULT 1,
                admin_username VARCHAR(255) DEFAULT 'admin',
                admin_password VARCHAR(255) DEFAULT 'admin',
                error_ban_threshold INT DEFAULT 3,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
            )
        """)

        await db.execute("""
            CREATE TABLE IF NOT EXISTS proxy_config (
                id INT PRIMARY KEY DEFAULT 1,
                proxy_enabled BOOLEAN DEFAULT 0,
                proxy_url TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
            )
        """)

        await db.execute("""
            CREATE TABLE IF NOT EXISTS watermark_free_config (
                id INT PRIMARY KEY DEFAULT 1,
                watermark_free_enabled BOOLEAN DEFAULT 0,
                parse_method VARCHAR(50) DEFAULT 'third_party',
                custom_parse_url TEXT,
                custom_parse_token TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
            )
        """)

        await db.execute("""
            CREATE TABLE IF NOT EXISTS cache_config (
                id INT PRIMARY KEY DEFAULT 1,
                cache_enabled BOOLEAN DEFAULT 0,
                cache_timeout INT DEFAULT 600,
                cache_base_url TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
            )
        """)

        await db.execute("""
            CREATE TABLE IF NOT EXISTS generation_config (
                id INT PRIMARY KEY DEFAULT 1,
                image_timeout INT DEFAULT 300,
                video_timeout INT DEFAULT 1500,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
            )
        """)

        await db.execute("""
            CREATE TABLE IF NOT EXISTS token_refresh_config (
                id INT PRIMARY KEY DEFAULT 1,
                at_auto_refresh_enabled BOOLEAN DEFAULT 0,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
            )
        """)

        # (index name, table, column). Looked up before creation because
        # CREATE INDEX fails if the index already exists.
        indexes = (
            ("idx_task_id", "tasks", "task_id"),
            ("idx_task_status", "tasks", "status"),
            ("idx_token_active", "tokens", "is_active"),
        )
        for index_name, table, column in indexes:
            try:
                cursor = await db.execute(
                    "SELECT count(*) as count FROM information_schema.statistics "
                    "WHERE table_schema = %s AND table_name = %s AND index_name = %s",
                    (config.db_name, table, index_name),
                )
                if (await cursor.fetchone())['count'] > 0:
                    continue
                await db.execute(f"CREATE INDEX {index_name} ON {table}({column})")
            except Exception as e:
                # Best effort, as before: a missing index is not fatal.
                print(f"Warning: could not create index '{index_name}': {e}")

        await db.commit()
|
|
|
|
|
async def init_config_from_toml(self, config_dict: dict, is_first_startup: bool = True):
    """Sync the singleton config tables with values parsed from setting.toml.

    First startup:
      * ``_ensure_config_rows`` seeds any empty table.
      * Each non-admin table then gets an ``INSERT IGNORE`` of the TOML
        values, which is a no-op if row id=1 already exists.

    Later startups: row id=1 of every config table, admin_config included,
    is overwritten with the TOML values.

    Empty URLs/tokens are stored as NULL. Commits once at the end.
    """
    async with self.connect() as db:
        if is_first_startup:
            await self._ensure_config_rows(db, config_dict)

        async def apply(insert_sql: str, update_sql: str, params: tuple):
            # First startup seeds; later startups overwrite.
            await db.execute(insert_sql if is_first_startup else update_sql, params)

        admin_section = config_dict.get("admin", {})
        global_section = config_dict.get("global", {})
        admin_params = (
            global_section.get("admin_username", "admin"),
            global_section.get("admin_password", "admin"),
            admin_section.get("error_ban_threshold", 3),
        )
        if not is_first_startup:
            await db.execute("""
                UPDATE admin_config
                SET admin_username = ?, admin_password = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """, admin_params)

        proxy_section = config_dict.get("proxy", {})
        await apply(
            """
                INSERT IGNORE INTO proxy_config (id, proxy_enabled, proxy_url)
                VALUES (1, ?, ?)
            """,
            """
                UPDATE proxy_config
                SET proxy_enabled = ?, proxy_url = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """,
            (
                proxy_section.get("proxy_enabled", False),
                proxy_section.get("proxy_url", "") or None,
            ),
        )

        watermark_section = config_dict.get("watermark_free", {})
        await apply(
            """
                INSERT IGNORE INTO watermark_free_config (id, watermark_free_enabled, parse_method, custom_parse_url, custom_parse_token)
                VALUES (1, ?, ?, ?, ?)
            """,
            """
                UPDATE watermark_free_config
                SET watermark_free_enabled = ?, parse_method = ?, custom_parse_url = ?,
                    custom_parse_token = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """,
            (
                watermark_section.get("watermark_free_enabled", False),
                watermark_section.get("parse_method", "third_party"),
                watermark_section.get("custom_parse_url", "") or None,
                watermark_section.get("custom_parse_token", "") or None,
            ),
        )

        cache_section = config_dict.get("cache", {})
        await apply(
            """
                INSERT IGNORE INTO cache_config (id, cache_enabled, cache_timeout, cache_base_url)
                VALUES (1, ?, ?, ?)
            """,
            """
                UPDATE cache_config
                SET cache_enabled = ?, cache_timeout = ?, cache_base_url = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """,
            (
                cache_section.get("enabled", False),
                cache_section.get("timeout", 600),
                cache_section.get("base_url", "") or None,
            ),
        )

        generation_section = config_dict.get("generation", {})
        await apply(
            """
                INSERT IGNORE INTO generation_config (id, image_timeout, video_timeout)
                VALUES (1, ?, ?)
            """,
            """
                UPDATE generation_config
                SET image_timeout = ?, video_timeout = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """,
            (
                generation_section.get("image_timeout", 300),
                generation_section.get("video_timeout", 1500),
            ),
        )

        refresh_section = config_dict.get("token_refresh", {})
        await apply(
            """
                INSERT IGNORE INTO token_refresh_config (id, at_auto_refresh_enabled)
                VALUES (1, ?)
            """,
            """
                UPDATE token_refresh_config
                SET at_auto_refresh_enabled = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """,
            (refresh_section.get("at_auto_refresh_enabled", False),),
        )

        await db.commit()
|
|
|
|
|
|
|
|
async def add_token(self, token: Token) -> int:
    """Insert ``token`` and a zeroed ``token_stats`` row; return the new token id.

    Both inserts are committed in one transaction. Previously the token was
    committed on its own, so a failure inserting its stats row left a token
    with no statistics. The ``username`` column is always stored as an empty
    string.
    """
    async with self.connect() as db:
        cursor = await db.execute("""
            INSERT INTO tokens (token, email, username, name, st, rt, remark, expiry_time, is_active,
                                plan_type, plan_title, subscription_end, sora2_supported, sora2_invite_code,
                                sora2_redeemed_count, sora2_total_count, sora2_remaining_count, sora2_cooldown_until,
                                image_enabled, video_enabled)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (token.token, token.email, "", token.name, token.st, token.rt,
              token.remark, token.expiry_time, token.is_active,
              token.plan_type, token.plan_title, token.subscription_end,
              token.sora2_supported, token.sora2_invite_code,
              token.sora2_redeemed_count, token.sora2_total_count,
              token.sora2_remaining_count, token.sora2_cooldown_until,
              token.image_enabled, token.video_enabled))
        # Read before the next execute() replaces the wrapper's cursor.
        token_id = cursor.lastrowid

        await db.execute("""
            INSERT INTO token_stats (token_id) VALUES (?)
        """, (token_id,))
        await db.commit()

    return token_id
|
|
|
|
|
async def get_token(self, token_id: int) -> Optional[Token]:
    """Return the tokens row with primary key ``token_id`` as a ``Token``, or None."""
    async with self.connect() as db:
        cursor = await db.execute("SELECT * FROM tokens WHERE id = ?", (token_id,))
        row = await cursor.fetchone()
    return Token(**dict(row)) if row else None
|
|
|
|
|
async def get_token_by_value(self, token: str) -> Optional[Token]:
    """Return the tokens row whose ``token`` column equals ``token``, or None.

    If several rows match, the first one the server returns is used.
    """
    async with self.connect() as db:
        cursor = await db.execute("SELECT * FROM tokens WHERE token = ?", (token,))
        row = await cursor.fetchone()
    if not row:
        return None
    return Token(**dict(row))
|
|
|
|
|
async def get_active_tokens(self) -> List[Token]:
    """Return usable tokens, least recently used first.

    A token qualifies when it is active, its ``cooled_until`` is unset or in
    the past, and its ``expiry_time`` is in the future (both compared with
    the server's CURRENT_TIMESTAMP). Rows with a NULL ``expiry_time`` are
    therefore excluded.
    """
    sql = """
        SELECT * FROM tokens
        WHERE is_active = 1
          AND (cooled_until IS NULL OR cooled_until < CURRENT_TIMESTAMP)
          AND expiry_time > CURRENT_TIMESTAMP
        ORDER BY last_used_at ASC
    """
    async with self.connect() as db:
        cursor = await db.execute(sql)
        rows = await cursor.fetchall()
    return list(map(lambda row: Token(**dict(row)), rows))
|
|
|
|
|
async def get_all_tokens(self) -> List[Token]:
    """Return every token, newest ``created_at`` first."""
    async with self.connect() as db:
        cursor = await db.execute("SELECT * FROM tokens ORDER BY created_at DESC")
        rows = await cursor.fetchall()
    tokens = []
    for row in rows:
        tokens.append(Token(**dict(row)))
    return tokens
|
|
|
|
|
async def update_token_usage(self, token_id: int):
    """Set ``last_used_at`` to the server's CURRENT_TIMESTAMP and add 1 to ``use_count``."""
    sql = """
        UPDATE tokens
        SET last_used_at = CURRENT_TIMESTAMP, use_count = use_count + 1
        WHERE id = ?
    """
    async with self.connect() as db:
        await db.execute(sql, (token_id,))
        await db.commit()
|
|
|
|
|
async def update_token_status(self, token_id: int, is_active: bool):
    """Set the ``is_active`` flag of token ``token_id`` and commit."""
    params = (is_active, token_id)
    async with self.connect() as db:
        await db.execute("""
            UPDATE tokens SET is_active = ? WHERE id = ?
        """, params)
        await db.commit()
|
|
|
|
|
async def update_token_sora2(self, token_id: int, supported: bool, invite_code: Optional[str] = None,
                             redeemed_count: int = 0, total_count: int = 0, remaining_count: int = 0):
    """Overwrite all Sora2 support columns of token ``token_id`` and commit.

    Every column is written, so omitted counts are reset to 0 and an omitted
    invite code to NULL.
    """
    sql = """
        UPDATE tokens
        SET sora2_supported = ?, sora2_invite_code = ?, sora2_redeemed_count = ?, sora2_total_count = ?, sora2_remaining_count = ?
        WHERE id = ?
    """
    params = (supported, invite_code, redeemed_count, total_count, remaining_count, token_id)
    async with self.connect() as db:
        await db.execute(sql, params)
        await db.commit()
|
|
|
|
|
async def update_token_sora2_remaining(self, token_id: int, remaining_count: int):
    """Set ``sora2_remaining_count`` of token ``token_id`` and commit."""
    async with self.connect() as db:
        sql = """
            UPDATE tokens SET sora2_remaining_count = ? WHERE id = ?
        """
        await db.execute(sql, (remaining_count, token_id))
        await db.commit()
|
|
|
|
|
async def update_token_sora2_cooldown(self, token_id: int, cooldown_until: Optional[datetime]):
    """Set (or clear, with None) ``sora2_cooldown_until`` of token ``token_id`` and commit."""
    async with self.connect() as db:
        sql = """
            UPDATE tokens SET sora2_cooldown_until = ? WHERE id = ?
        """
        await db.execute(sql, (cooldown_until, token_id))
        await db.commit()
|
|
|
|
|
async def update_token_cooldown(self, token_id: int, cooled_until: datetime):
    """Set ``cooled_until`` of token ``token_id`` and commit.

    ``get_active_tokens`` skips the token until this time has passed.
    """
    params = (cooled_until, token_id)
    async with self.connect() as db:
        await db.execute("""
            UPDATE tokens SET cooled_until = ? WHERE id = ?
        """, params)
        await db.commit()
|
|
|
|
|
async def delete_token(self, token_id: int):
    """Delete a token together with the rows that reference it, then commit.

    In ``init_db``, token_stats, tasks and request_logs all declare
    ``FOREIGN KEY (token_id) REFERENCES tokens(id)`` with no ON DELETE
    action. With InnoDB (MySQL's default engine) these are enforced, so
    deleting a token that still has tasks or request logs failed. Handling
    of the dependent rows:
      * request_logs are kept: their nullable token_id is set to NULL
        (``get_recent_logs`` uses a LEFT JOIN, so they still list).
      * tasks (token_id NOT NULL) and token_stats for the token are deleted.
    All statements run in one transaction.
    """
    async with self.connect() as db:
        await db.execute("UPDATE request_logs SET token_id = NULL WHERE token_id = ?", (token_id,))
        await db.execute("DELETE FROM tasks WHERE token_id = ?", (token_id,))
        await db.execute("DELETE FROM token_stats WHERE token_id = ?", (token_id,))
        await db.execute("DELETE FROM tokens WHERE id = ?", (token_id,))
        await db.commit()
|
|
|
|
|
async def update_token(self, token_id: int,
                       token: Optional[str] = None,
                       st: Optional[str] = None,
                       rt: Optional[str] = None,
                       remark: Optional[str] = None,
                       expiry_time: Optional[datetime] = None,
                       plan_type: Optional[str] = None,
                       plan_title: Optional[str] = None,
                       subscription_end: Optional[datetime] = None,
                       image_enabled: Optional[bool] = None,
                       video_enabled: Optional[bool] = None):
    """Partially update token ``token_id``: only columns whose argument is not None.

    Passing None means "leave unchanged", so this method cannot set a column
    to NULL. If every argument is None, no statement is executed and nothing
    is committed.
    """
    candidates = (
        ("token", token),
        ("st", st),
        ("rt", rt),
        ("remark", remark),
        ("expiry_time", expiry_time),
        ("plan_type", plan_type),
        ("plan_title", plan_title),
        ("subscription_end", subscription_end),
        ("image_enabled", image_enabled),
        ("video_enabled", video_enabled),
    )
    assignments = [(column, value) for column, value in candidates if value is not None]

    async with self.connect() as db:
        if not assignments:
            return
        set_clause = ", ".join(f"{column} = ?" for column, _ in assignments)
        values = [value for _, value in assignments]
        values.append(token_id)
        await db.execute(f"UPDATE tokens SET {set_clause} WHERE id = ?", values)
        await db.commit()
|
|
|
|
|
|
|
|
async def get_token_stats(self, token_id: int) -> Optional[TokenStats]:
    """Return the token_stats row for ``token_id`` as ``TokenStats``, or None if absent."""
    async with self.connect() as db:
        cursor = await db.execute("SELECT * FROM token_stats WHERE token_id = ?", (token_id,))
        row = await cursor.fetchone()
    if not row:
        return None
    return TokenStats(**dict(row))
|
|
|
|
|
async def increment_image_count(self, token_id: int):
    """Add 1 to ``image_count`` in the token_stats row(s) of ``token_id`` and commit."""
    sql = """
        UPDATE token_stats SET image_count = image_count + 1 WHERE token_id = ?
    """
    async with self.connect() as db:
        await db.execute(sql, (token_id,))
        await db.commit()
|
|
|
|
|
async def increment_video_count(self, token_id: int):
    """Add 1 to ``video_count`` in the token_stats row(s) of ``token_id`` and commit."""
    sql = """
        UPDATE token_stats SET video_count = video_count + 1 WHERE token_id = ?
    """
    async with self.connect() as db:
        await db.execute(sql, (token_id,))
        await db.commit()
|
|
|
|
|
async def increment_error_count(self, token_id: int):
    """Add 1 to ``error_count`` and stamp ``last_error_at`` for ``token_id``; commit."""
    sql = """
        UPDATE token_stats
        SET error_count = error_count + 1, last_error_at = CURRENT_TIMESTAMP
        WHERE token_id = ?
    """
    async with self.connect() as db:
        await db.execute(sql, (token_id,))
        await db.commit()
|
|
|
|
|
async def reset_error_count(self, token_id: int):
    """Set ``error_count`` to 0 for ``token_id`` (``last_error_at`` is kept) and commit."""
    sql = """
        UPDATE token_stats SET error_count = 0 WHERE token_id = ?
    """
    async with self.connect() as db:
        await db.execute(sql, (token_id,))
        await db.commit()
|
|
|
|
|
|
|
|
async def create_task(self, task: Task) -> int:
    """Insert ``task`` into the tasks table, commit, and return its auto-increment id."""
    values = (task.task_id, task.token_id, task.model, task.prompt, task.status, task.progress)
    async with self.connect() as db:
        cursor = await db.execute("""
            INSERT INTO tasks (task_id, token_id, model, prompt, status, progress)
            VALUES (?, ?, ?, ?, ?, ?)
        """, values)
        await db.commit()
        new_id = cursor.lastrowid
        return new_id
|
|
|
|
|
async def update_task(self, task_id: str, status: str, progress: float,
                      result_urls: Optional[str] = None, error_message: Optional[str] = None):
    """Overwrite status, progress, results and error of the task with ``task_id``; commit.

    ``completed_at`` is set to the local ``datetime.now()`` when ``status`` is
    "completed" or "failed", and reset to NULL for any other status.
    """
    async with self.connect() as db:
        finished = status in ("completed", "failed")
        completed_at = datetime.now() if finished else None
        sql = """
            UPDATE tasks
            SET status = ?, progress = ?, result_urls = ?, error_message = ?, completed_at = ?
            WHERE task_id = ?
        """
        await db.execute(sql, (status, progress, result_urls, error_message, completed_at, task_id))
        await db.commit()
|
|
|
|
|
async def get_task(self, task_id: str) -> Optional[Task]:
    """Return the task whose ``task_id`` column equals ``task_id``, or None."""
    async with self.connect() as db:
        cursor = await db.execute("SELECT * FROM tasks WHERE task_id = ?", (task_id,))
        row = await cursor.fetchone()
    return Task(**dict(row)) if row else None
|
|
|
|
|
|
|
|
async def log_request(self, log: RequestLog):
    """Append ``log`` to request_logs and commit (``created_at`` uses the column default)."""
    values = (
        log.token_id,
        log.operation,
        log.request_body,
        log.response_body,
        log.status_code,
        log.duration,
    )
    async with self.connect() as db:
        await db.execute("""
            INSERT INTO request_logs (token_id, operation, request_body, response_body, status_code, duration)
            VALUES (?, ?, ?, ?, ?, ?)
        """, values)
        await db.commit()
|
|
|
|
|
async def get_recent_logs(self, limit: int = 100) -> List[dict]:
    """Return up to ``limit`` request logs, newest first, each as a dict.

    Each dict carries the request_logs columns plus ``token_email``. It comes
    from a LEFT JOIN on tokens, so it is None when the token is missing.
    """
    sql = """
        SELECT
            rl.id,
            rl.token_id,
            rl.operation,
            rl.request_body,
            rl.response_body,
            rl.status_code,
            rl.duration,
            rl.created_at,
            t.email as token_email
        FROM request_logs rl
        LEFT JOIN tokens t ON rl.token_id = t.id
        ORDER BY rl.created_at DESC
        LIMIT ?
    """
    async with self.connect() as db:
        cursor = await db.execute(sql, (limit,))
        rows = await cursor.fetchall()
    return list(map(dict, rows))
|
|
|
|
|
|
|
|
async def get_admin_config(self) -> AdminConfig:
    """Return admin_config row id=1, or an admin/admin ``AdminConfig`` if it is missing."""
    async with self.connect() as db:
        cursor = await db.execute("SELECT * FROM admin_config WHERE id = 1")
        row = await cursor.fetchone()
    if not row:
        return AdminConfig(admin_username="admin", admin_password="admin")
    return AdminConfig(**dict(row))
|
|
|
|
|
async def update_admin_config(self, config: AdminConfig):
    """Overwrite username, password and error-ban threshold in ``admin_config`` row ``id = 1``.

    ``updated_at`` is set to ``CURRENT_TIMESTAMP``. The parameter name shadows
    the module-level ``config`` import inside this method; it is kept as-is
    for keyword-argument compatibility.
    """
    # NOTE(review): if row id = 1 does not exist, this UPDATE matches nothing
    # and the change is silently dropped.
    async with self.connect() as db:
        new_values = (
            config.admin_username,
            config.admin_password,
            config.error_ban_threshold,
        )
        await db.execute("""
            UPDATE admin_config
            SET admin_username = ?, admin_password = ?, error_ban_threshold = ?, updated_at = CURRENT_TIMESTAMP
            WHERE id = 1
        """, new_values)
        await db.commit()
|
|
|
|
|
|
|
|
async def get_proxy_config(self) -> ProxyConfig:
    """Load the proxy settings from ``proxy_config`` row ``id = 1``.

    Returns:
        A ``ProxyConfig`` built from that row, or ``ProxyConfig(proxy_enabled=False)``
        when the row is missing.
    """
    async with self.connect() as db:
        cursor = await db.execute("SELECT * FROM proxy_config WHERE id = 1")
        stored = await cursor.fetchone()
        if not stored:
            return ProxyConfig(proxy_enabled=False)
        return ProxyConfig(**dict(stored))
|
|
|
|
|
async def update_proxy_config(self, enabled: bool, proxy_url: Optional[str]):
    """Set the proxy enabled flag and URL in ``proxy_config`` row ``id = 1``.

    ``proxy_url`` is written as given (``None`` stores NULL), and
    ``updated_at`` is set to ``CURRENT_TIMESTAMP``.
    """
    sql = """
        UPDATE proxy_config
        SET proxy_enabled = ?, proxy_url = ?, updated_at = CURRENT_TIMESTAMP
        WHERE id = 1
    """
    async with self.connect() as db:
        await db.execute(sql, (enabled, proxy_url))
        await db.commit()
|
|
|
|
|
|
|
|
async def get_watermark_free_config(self) -> WatermarkFreeConfig:
    """Load the watermark-free settings from ``watermark_free_config`` row ``id = 1``.

    Returns:
        A ``WatermarkFreeConfig`` built from that row. If the row is missing,
        returns one with ``watermark_free_enabled=False`` and
        ``parse_method="third_party"``.
    """
    async with self.connect() as db:
        cursor = await db.execute("SELECT * FROM watermark_free_config WHERE id = 1")
        stored = await cursor.fetchone()
        if not stored:
            return WatermarkFreeConfig(watermark_free_enabled=False, parse_method="third_party")
        return WatermarkFreeConfig(**dict(stored))
|
|
|
|
|
async def update_watermark_free_config(self, enabled: bool, parse_method: str = None,
                                       custom_parse_url: str = None, custom_parse_token: str = None):
    """Update the watermark-free settings in ``watermark_free_config`` row ``id = 1``.

    If ``parse_method``, ``custom_parse_url`` and ``custom_parse_token`` are
    all ``None``, only ``watermark_free_enabled`` changes.

    If any of them is given, all three columns are rewritten:
    - a falsy ``parse_method`` becomes ``"third_party"``;
    - a missing URL or token is stored as NULL.

    ``updated_at`` is set to ``CURRENT_TIMESTAMP`` in both cases.
    """
    flag_only = all(
        value is None
        for value in (parse_method, custom_parse_url, custom_parse_token)
    )
    async with self.connect() as db:
        if not flag_only:
            await db.execute("""
                UPDATE watermark_free_config
                SET watermark_free_enabled = ?, parse_method = ?, custom_parse_url = ?,
                    custom_parse_token = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """, (enabled, parse_method or "third_party", custom_parse_url, custom_parse_token))
        else:
            await db.execute("""
                UPDATE watermark_free_config
                SET watermark_free_enabled = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            """, (enabled,))
        await db.commit()
|
|
|
|
|
|
|
|
async def get_cache_config(self) -> CacheConfig:
    """Load the cache settings from ``cache_config`` row ``id = 1``.

    Returns:
        A ``CacheConfig`` built from that row. If the row is missing, returns
        one with ``cache_enabled=False`` and ``cache_timeout=600``.
    """
    async with self.connect() as db:
        cursor = await db.execute("SELECT * FROM cache_config WHERE id = 1")
        stored = await cursor.fetchone()
        if not stored:
            return CacheConfig(cache_enabled=False, cache_timeout=600)
        return CacheConfig(**dict(stored))
|
|
|
|
|
async def update_cache_config(self, enabled: bool = None, timeout: int = None, base_url: Optional[str] = None):
    """Partially update the cache settings in ``cache_config`` row ``id = 1``.

    Any argument left as ``None`` keeps the currently stored value. If the row
    does not exist, the defaults are ``cache_enabled=False`` and
    ``cache_timeout=600``. A falsy resulting base URL (e.g. ``""``) is written
    as NULL, so passing ``base_url=""`` clears it. ``updated_at`` is set to
    ``CURRENT_TIMESTAMP``.
    """
    async with self.connect() as db:
        cursor = await db.execute("SELECT * FROM cache_config WHERE id = 1")
        row = await cursor.fetchone()
        # An empty mapping makes the .get() defaults below apply when the row is missing.
        stored = dict(row) if row else {}

        new_enabled = stored.get("cache_enabled", False) if enabled is None else enabled
        new_timeout = stored.get("cache_timeout", 600) if timeout is None else timeout
        new_base_url = stored.get("cache_base_url") if base_url is None else base_url
        # Collapse "" (and any other falsy value) to NULL.
        new_base_url = new_base_url or None

        await db.execute("""
            UPDATE cache_config
            SET cache_enabled = ?, cache_timeout = ?, cache_base_url = ?, updated_at = CURRENT_TIMESTAMP
            WHERE id = 1
        """, (new_enabled, new_timeout, new_base_url))
        await db.commit()
|
|
|
|
|
|
|
|
async def get_generation_config(self) -> GenerationConfig:
    """Load the generation timeouts from ``generation_config`` row ``id = 1``.

    Returns:
        A ``GenerationConfig`` built from that row. If the row is missing,
        returns one with ``image_timeout=300`` and ``video_timeout=1500``.
    """
    async with self.connect() as db:
        cursor = await db.execute("SELECT * FROM generation_config WHERE id = 1")
        stored = await cursor.fetchone()
        if not stored:
            return GenerationConfig(image_timeout=300, video_timeout=1500)
        return GenerationConfig(**dict(stored))
|
|
|
|
|
async def update_generation_config(self, image_timeout: int = None, video_timeout: int = None):
    """Partially update the generation timeouts in ``generation_config`` row ``id = 1``.

    Any argument left as ``None`` keeps the stored value. If the row does not
    exist, the defaults are ``300`` (image) and ``1500`` (video).
    ``updated_at`` is set to ``CURRENT_TIMESTAMP``.
    """
    async with self.connect() as db:
        cursor = await db.execute("SELECT * FROM generation_config WHERE id = 1")
        row = await cursor.fetchone()
        # An empty mapping makes the .get() defaults below apply when the row is missing.
        stored = dict(row) if row else {}

        new_image_timeout = stored.get("image_timeout", 300) if image_timeout is None else image_timeout
        new_video_timeout = stored.get("video_timeout", 1500) if video_timeout is None else video_timeout

        await db.execute("""
            UPDATE generation_config
            SET image_timeout = ?, video_timeout = ?, updated_at = CURRENT_TIMESTAMP
            WHERE id = 1
        """, (new_image_timeout, new_video_timeout))
        await db.commit()
|
|
|
|
|
|
|
|
async def get_token_refresh_config(self) -> TokenRefreshConfig:
    """Load the token auto-refresh settings from ``token_refresh_config`` row ``id = 1``.

    Returns:
        A ``TokenRefreshConfig`` built from that row, or one with
        ``at_auto_refresh_enabled=False`` when the row is missing.
    """
    async with self.connect() as db:
        cursor = await db.execute("SELECT * FROM token_refresh_config WHERE id = 1")
        stored = await cursor.fetchone()
        if not stored:
            return TokenRefreshConfig(at_auto_refresh_enabled=False)
        return TokenRefreshConfig(**dict(stored))
|
|
|
|
|
async def update_token_refresh_config(self, at_auto_refresh_enabled: bool):
    """Set ``at_auto_refresh_enabled`` in ``token_refresh_config`` row ``id = 1``.

    ``updated_at`` is set to ``CURRENT_TIMESTAMP``.
    """
    sql = """
        UPDATE token_refresh_config
        SET at_auto_refresh_enabled = ?, updated_at = CURRENT_TIMESTAMP
        WHERE id = 1
    """
    async with self.connect() as db:
        await db.execute(sql, (at_auto_refresh_enabled,))
        await db.commit()
|
|
|