| |
| """ |
| 数据库工具函数。 |
| 包含数据库初始化、连接管理、以及针对数据库模型的 CRUD (创建、读取、更新、删除) 操作。 |
| """ |
| import logging |
| import aiosqlite |
| from typing import Dict, Any, List, Tuple, Optional, AsyncGenerator |
| from datetime import datetime, timezone |
| from sqlalchemy.orm import sessionmaker, Session |
| from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession |
| from app.core.database.models import Base, UserKeyAssociation, KeyScore, Setting, ApiKey |
| from sqlalchemy import select, update, delete |
| from sqlalchemy.dialects import sqlite |
| from sqlalchemy.sql import text |
| from app.core.tracking import key_scores_cache, cache_lock |
| from contextlib import asynccontextmanager |
| import sqlalchemy |
| import sqlalchemy.exc |
|
|
| |
| logger = logging.getLogger("my_logger") |
|
|
| |
| |
| import os |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| from app import config as app_config |
|
|
| |
| logger.debug(f"环境变量 HF_SPACE_ID: {os.getenv('HF_SPACE_ID')}") |
| logger.debug(f"环境变量 APP_DB_MODE: {os.getenv('APP_DB_MODE')}") |
| logger.debug(f"环境变量 KEY_STORAGE_MODE: {os.getenv('KEY_STORAGE_MODE')}") |
| logger.debug(f"从 app.config 导入的 KEY_STORAGE_MODE: {app_config.KEY_STORAGE_MODE}") |
| |
|
|
| IS_HF_ENV = bool(os.getenv("HF_SPACE_ID")) |
| use_memory_db_reason = [] |
|
|
| if IS_HF_ENV: |
| use_memory_db_reason.append("Hugging Face Environment (HF_SPACE_ID set)") |
| if os.getenv("APP_DB_MODE") == "memory": |
| use_memory_db_reason.append("APP_DB_MODE=memory") |
| |
| if app_config.KEY_STORAGE_MODE == "memory": |
| use_memory_db_reason.append(f"config.KEY_STORAGE_MODE is '{app_config.KEY_STORAGE_MODE}' (可能是默认值或环境变量设置)") |
|
|
|
|
| if use_memory_db_reason: |
| DATABASE_PATH = ":memory:" |
| logger.info(f"最终决定使用内存数据库。原因: {'; '.join(use_memory_db_reason)}") |
| else: |
| |
| logger.info(f"尝试使用文件数据库。内存条件评估: HF Env: {IS_HF_ENV}, APP_DB_MODE env: '{os.getenv('APP_DB_MODE')}', config.KEY_STORAGE_MODE: '{app_config.KEY_STORAGE_MODE}'") |
| |
| custom_db_path_from_env = app_config.CONTEXT_DB_PATH |
| |
| if custom_db_path_from_env: |
| logger.info(f"检测到环境变量 CONTEXT_DB_PATH 设置为: '{custom_db_path_from_env}'") |
| |
| |
| db_dir = os.path.dirname(custom_db_path_from_env) |
| if db_dir: |
| try: |
| os.makedirs(db_dir, exist_ok=True) |
| DATABASE_PATH = custom_db_path_from_env |
| logger.info(f"使用 CONTEXT_DB_PATH 指定的数据库路径: {DATABASE_PATH}") |
| except PermissionError as e: |
| logger.error(f"根据 CONTEXT_DB_PATH 创建目录 '{db_dir}' 时发生权限错误: {e}。将回退到默认路径或内存数据库。") |
| DATABASE_PATH = None |
| except Exception as e: |
| logger.error(f"根据 CONTEXT_DB_PATH 创建目录 '{db_dir}' 时发生未知错误: {e}。将回退到默认路径或内存数据库。") |
| DATABASE_PATH = None |
| else: |
| DATABASE_PATH = custom_db_path_from_env |
| logger.info(f"使用 CONTEXT_DB_PATH 指定的数据库文件名 (将在当前目录创建): {DATABASE_PATH}") |
| |
| |
|
|
| else: |
| DATABASE_PATH = None |
|
|
| if DATABASE_PATH is None: |
| logger.info("CONTEXT_DB_PATH 未设置或使用失败,尝试默认应用数据目录。") |
| _home_dir = os.path.expanduser("~") |
| _app_data_dir = os.path.join(_home_dir, '.gemini_api_proxy', 'data') |
| try: |
| os.makedirs(_app_data_dir, exist_ok=True) |
| DATABASE_PATH = os.path.join(_app_data_dir, 'context_store.db') |
| logger.info(f"默认数据库路径设置为: {DATABASE_PATH}") |
| except PermissionError as e: |
| logger.error(f"创建默认应用数据目录 '{_app_data_dir}' 时发生权限错误: {e}。将回退到内存数据库。") |
| DATABASE_PATH = ":memory:" |
| except Exception as e: |
| logger.error(f"创建默认应用数据目录时发生未知错误: {e}。将回退到内存数据库。") |
| DATABASE_PATH = ":memory:" |
|
|
| if DATABASE_PATH == ":memory:": |
| DATABASE_URL = "sqlite+aiosqlite:///:memory:" |
| else: |
| DATABASE_URL = f"sqlite+aiosqlite:///{DATABASE_PATH}" |
|
|
| logger.info(f"最终数据库路径: {DATABASE_PATH}") |
| logger.info(f"最终数据库 URL: {DATABASE_URL}") |
| DEFAULT_CONTEXT_TTL_DAYS = 30 |
| IS_MEMORY_DB = DATABASE_PATH == ':memory:' |
|
|
| |
| @asynccontextmanager |
| async def get_db_connection() -> AsyncGenerator[aiosqlite.Connection, None]: |
| """ |
| (可能已废弃/特定用途) 获取一个原生的 aiosqlite 数据库连接。 |
| 使用 asynccontextmanager 确保连接在使用后能被正确关闭。 |
| 注意:项目主要应使用 SQLAlchemy 的 AsyncSession。 |
| """ |
| conn = None |
| try: |
| |
| conn = await aiosqlite.connect(DATABASE_PATH) |
| logger.info("数据库连接已获取 (类型: aiosqlite.Connection)。") |
| yield conn |
| except Exception as e: |
| |
| logger.error(f"获取 aiosqlite 数据库连接失败: {e}", exc_info=True) |
| raise |
| finally: |
| |
| if conn is not None: |
| await conn.close() |
| logger.debug("aiosqlite 数据库连接已关闭。") |
|
|
| |
| async def initialize_db_tables() -> None: |
| """ |
| 初始化数据库表。如果表已存在,则不会重复创建。 |
| 此函数应在应用启动时调用。 |
| """ |
| try: |
| |
| await _initialize_db_tables() |
| except Exception as e: |
| |
| logger.error(f"初始化数据库表失败: {e}", exc_info=True) |
| raise |
|
|
| async def _initialize_db_tables() -> None: |
| """ |
| 内部函数:使用 SQLAlchemy 的异步引擎和 Base.metadata 来创建所有定义的表。 |
| """ |
| |
| |
| engine = create_async_engine(DATABASE_URL, echo=False) |
|
|
| try: |
| |
| async with engine.begin() as conn: |
| |
| |
| await conn.run_sync(Base.metadata.create_all) |
| |
| logger.info("所有通过 SQLAlchemy Base 定义的数据库表已成功初始化/验证。") |
| except Exception as e: |
| |
| logger.error(f"使用 SQLAlchemy 初始化数据库表失败: {e}", exc_info=True) |
| raise |
| finally: |
| |
| await engine.dispose() |
| logger.debug("SQLAlchemy 异步引擎已关闭。") |
|
|
|
|
| |
|
|
| async def add_api_key( |
| db: AsyncSession, |
| key_string: str, |
| description: Optional[str] = None, |
| expires_at: Optional[datetime] = None, |
| is_active: bool = True, |
| enable_context_completion: bool = True, |
| user_id: Optional[str] = None |
| ) -> Optional[ApiKey]: |
| """ |
| 向数据库异步添加一个新的 API Key 记录。 |
| |
| Args: |
| db (AsyncSession): SQLAlchemy 异步数据库会话。 |
| key_string (str): 要添加的 API Key 字符串。 |
| description (Optional[str]): Key 的描述信息。 |
| expires_at (Optional[datetime]): Key 的过期时间 (UTC)。如果提供,请确保是 aware datetime。 |
| is_active (bool): Key 是否激活。 |
| enable_context_completion (bool): 此 Key 是否启用上下文补全。 |
| user_id (Optional[str]): 与此 Key 关联的用户 ID。 |
| |
| Returns: |
| Optional[ApiKey]: 如果成功创建,返回包含数据库生成信息的 ApiKey 对象; |
| 如果 Key 已存在或发生其他错误,则返回 None。 |
| """ |
| try: |
| |
| stmt_check = select(ApiKey).where(ApiKey.key_string == key_string) |
| result_check = await db.execute(stmt_check) |
| existing_key = result_check.scalar_one_or_none() |
| if existing_key: |
| logger.warning(f"尝试添加已存在的 API Key: {key_string[:8]}...") |
| return None |
|
|
| |
| new_api_key = ApiKey( |
| key_string=key_string, |
| description=description, |
| |
| expires_at=expires_at.replace(tzinfo=timezone.utc) if expires_at else None, |
| is_active=is_active, |
| enable_context_completion=enable_context_completion, |
| user_id=user_id |
| |
| ) |
| |
| db.add(new_api_key) |
| |
| await db.commit() |
| |
| await db.refresh(new_api_key) |
| |
| logger.info(f"成功添加 API Key: {new_api_key.key_string[:8]}... (ID: {new_api_key.id})") |
| return new_api_key |
| except sqlalchemy.exc.IntegrityError as e: |
| await db.rollback() |
| logger.error(f"添加 API Key 时发生唯一约束冲突 (Key 可能已存在): {key_string[:8]}... - {e}", exc_info=False) |
| return None |
| except Exception as e: |
| await db.rollback() |
| logger.error(f"添加 API Key {key_string[:8]}... 失败: {e}", exc_info=True) |
| return None |
|
|
| async def get_all_api_keys_from_db(db: AsyncSession) -> List[ApiKey]: |
| """ |
| 从数据库异步获取所有 API 密钥对象。 |
| |
| Args: |
| db (AsyncSession): SQLAlchemy 异步数据库会话。 |
| |
| Returns: |
| List[ApiKey]: 包含所有 ApiKey 对象的列表,按创建时间排序。如果出错则返回空列表。 |
| """ |
| try: |
| |
| stmt = select(ApiKey).order_by(ApiKey.created_at.desc()) |
| |
| result = await db.execute(stmt) |
| |
| api_keys = result.scalars().all() |
| |
| logger.info(f"成功从数据库获取 {len(api_keys)} 个 API Key。") |
| return list(api_keys) |
| except Exception as e: |
| |
| logger.error(f"从数据库获取所有 API Key 失败: {e}", exc_info=True) |
| return [] |
|
|
| async def get_api_key_by_string(db: AsyncSession, key_string: str) -> Optional[ApiKey]: |
| """ |
| 根据 Key 字符串从数据库异步获取单个 API Key 对象。 |
| |
| Args: |
| db (AsyncSession): SQLAlchemy 异步数据库会话。 |
| key_string (str): 要查询的 API Key 字符串。 |
| |
| Returns: |
| Optional[ApiKey]: 如果找到匹配的 Key,则返回 ApiKey 对象;否则返回 None。 |
| """ |
| if not key_string: return None |
| try: |
| |
| stmt = select(ApiKey).where(ApiKey.key_string == key_string) |
| |
| result = await db.execute(stmt) |
| |
| api_key = result.scalar_one_or_none() |
| if api_key: |
| logger.debug(f"成功获取 API Key: {key_string[:8]}...") |
| else: |
| logger.debug(f"未找到 API Key: {key_string[:8]}...") |
| return api_key |
| except Exception as e: |
| |
| logger.error(f"获取 API Key {key_string[:8]}... 失败: {e}", exc_info=True) |
| return None |
|
|
| async def update_api_key(db: AsyncSession, key_string: str, updates: Dict[str, Any]) -> Optional[ApiKey]: |
| """ |
| 异步更新数据库中指定 API Key 的信息。 |
| |
| Args: |
| db (AsyncSession): SQLAlchemy 异步数据库会话。 |
| key_string (str): 要更新的 API Key 字符串。 |
| updates (Dict[str, Any]): 一个字典,包含要更新的字段名和对应的新值。 |
| 例如: {'description': '新描述', 'is_active': False}。 |
| |
| Returns: |
| Optional[ApiKey]: 如果成功更新,返回更新后的 ApiKey 对象; |
| 如果 Key 不存在或发生错误,则返回 None。 |
| """ |
| if not key_string or not updates: return None |
| try: |
| |
| stmt_select = select(ApiKey).where(ApiKey.key_string == key_string) |
| result_select = await db.execute(stmt_select) |
| api_key_to_update = result_select.scalar_one_or_none() |
|
|
| if not api_key_to_update: |
| logger.warning(f"尝试更新不存在的 API Key: {key_string[:8]}...") |
| return None |
|
|
| |
| |
| allowed_updates = { |
| k: v for k, v in updates.items() |
| if hasattr(ApiKey, k) and k not in ['id', 'key_string', 'created_at'] |
| } |
|
|
| |
| if 'expires_at' in allowed_updates: |
| expiry_value = allowed_updates['expires_at'] |
| if isinstance(expiry_value, datetime): |
| |
| if expiry_value.tzinfo is None: |
| logger.warning(f"更新 Key {key_string[:8]}... 的 expires_at 时提供了 naive datetime,假设为本地时间并转换为 UTC。建议提供 aware datetime。") |
| |
| |
| allowed_updates['expires_at'] = expiry_value.astimezone(timezone.utc) |
| else: |
| allowed_updates['expires_at'] = expiry_value.astimezone(timezone.utc) |
| elif expiry_value is None: |
| pass |
| else: |
| logger.warning(f"更新 Key {key_string[:8]}... 时提供的 expires_at ('{expiry_value}') 不是有效的 datetime 对象或 None,已忽略此字段。") |
| del allowed_updates['expires_at'] |
|
|
| |
| if not allowed_updates: |
| logger.warning(f"没有有效的字段需要为 Key {key_string[:8]}... 更新。") |
| return api_key_to_update |
|
|
| |
| stmt_update = ( |
| update(ApiKey) |
| .where(ApiKey.key_string == key_string) |
| .values(**allowed_updates) |
| .execution_options(synchronize_session="fetch") |
| ) |
| await db.execute(stmt_update) |
| await db.commit() |
| await db.refresh(api_key_to_update) |
| logger.info(f"成功更新 API Key: {key_string[:8]}... 更新内容: {allowed_updates}") |
| return api_key_to_update |
| except Exception as e: |
| await db.rollback() |
| logger.error(f"更新 API Key {key_string[:8]}... 失败: {e}", exc_info=True) |
| return None |
|
|
| async def delete_api_key(db: AsyncSession, key_string: str) -> bool: |
| """ |
| 从数据库异步删除指定的 API Key。 |
| |
| Args: |
| db (AsyncSession): SQLAlchemy 异步数据库会话。 |
| key_string (str): 要删除的 API Key 字符串。 |
| |
| Returns: |
| bool: 如果成功删除 Key (或 Key 原本就不存在),返回 True;如果发生错误,返回 False。 |
| 注意:即使 Key 不存在,rowcount 也可能为 0,但操作本身没有错误。 |
| 可以根据需求调整返回值逻辑,例如严格要求 rowcount > 0 才返回 True。 |
| """ |
| if not key_string: return False |
| try: |
| |
| stmt = delete(ApiKey).where(ApiKey.key_string == key_string) |
| |
| result = await db.execute(stmt) |
| |
| await db.commit() |
| |
| if result.rowcount > 0: |
| logger.info(f"成功删除 API Key: {key_string[:8]}...") |
| return True |
| else: |
| logger.warning(f"尝试删除 API Key: {key_string[:8]}... 时未找到匹配项。") |
| return True |
| except Exception as e: |
| await db.rollback() |
| logger.error(f"删除 API Key {key_string[:8]}... 失败: {e}", exc_info=True) |
| return False |
|
|
| |
|
|
| async def is_valid_proxy_key(db: AsyncSession, key: str) -> bool: |
| """ |
| 异步检查提供的代理 Key (字符串) 是否在数据库中存在且处于活动状态。 |
| |
| Args: |
| db (AsyncSession): SQLAlchemy 异步数据库会话。 |
| key (str): 要检查的代理 Key 字符串。 |
| |
| Returns: |
| bool: 如果 Key 有效且激活,返回 True;否则返回 False。 |
| """ |
| |
| api_key_obj = await get_api_key_by_string(db, key) |
| |
| is_valid = bool(api_key_obj and api_key_obj.is_active) |
| |
| logger.debug(f"检查 Key '{key[:8]}...' 有效性: {is_valid}") |
| return is_valid |
|
|
|
|
| async def get_key_id_by_cached_content_id(db: AsyncSession, cached_content_id: str) -> Optional[int]: |
| """ |
| (需要实现) 根据缓存内容 ID 获取关联的 Key ID。 |
| 目前是模拟实现。 |
| |
| Args: |
| db (AsyncSession): SQLAlchemy 异步数据库会话。 |
| cached_content_id (str): 缓存内容的唯一 ID。 |
| |
| Returns: |
| Optional[int]: 关联的 Key ID,如果找不到或出错则返回 None。 |
| """ |
| |
| from app.core.database.models import CachedContent |
| try: |
| |
| |
| if isinstance(cached_content_id, int): |
| stmt = select(CachedContent.key_id).where(CachedContent.id == cached_content_id) |
| else: |
| stmt = select(CachedContent.key_id).where(CachedContent.content_id == cached_content_id) |
| |
| result = await db.execute(stmt) |
| key_id = result.scalar_one_or_none() |
| if key_id is not None: |
| logger.info(f"成功从缓存标识符 '{cached_content_id}' 获取到 key_id: {key_id}") |
| else: |
| logger.info(f"未从缓存标识符 '{cached_content_id}' 找到关联的 key_id。") |
| return key_id |
| except Exception as e: |
| logger.error(f"根据缓存内容标识符 '{cached_content_id}' 获取 Key ID 失败: {e}", exc_info=True) |
| return None |
|
|
| async def get_key_string_by_id(db: AsyncSession, key_id: int) -> Optional[str]: |
| """ |
| 根据 Key ID 获取 Key 字符串。 |
| """ |
| |
| from app.core.database.models import ApiKey |
| try: |
| stmt = select(ApiKey.key_string).where(ApiKey.id == key_id) |
| result = await db.execute(stmt) |
| key_string = result.scalar_one_or_none() |
| if key_string: |
| logger.info(f"成功为 Key ID {key_id} 获取到 Key 字符串: {key_string[:8]}...") |
| else: |
| logger.info(f"未为 Key ID {key_id} 找到对应的 Key 字符串。") |
| return key_string |
| except Exception as e: |
| logger.error(f"根据 Key ID {key_id} 获取 Key 字符串失败: {e}", exc_info=True) |
| return None |
|
|
| async def get_user_last_used_key_id(db: AsyncSession, user_id: str) -> Optional[int]: |
| """ |
| 获取指定用户上次成功使用的 Key ID (用于粘性会话)。 |
| """ |
| |
| from app.core.database.models import UserKeyAssociation |
| try: |
| stmt = select(UserKeyAssociation.key_id)\ |
| .where(UserKeyAssociation.user_id == user_id)\ |
| .order_by(UserKeyAssociation.last_used_timestamp.desc())\ |
| .limit(1) |
| result = await db.execute(stmt) |
| key_id = result.scalar_one_or_none() |
| if key_id is not None: |
| logger.info(f"用户 '{user_id}' 上次使用的 Key ID: {key_id}") |
| else: |
| logger.info(f"未找到用户 '{user_id}' 上次使用的 Key ID。") |
| return key_id |
| except Exception as e: |
| logger.error(f"获取用户 '{user_id}' 上次成功使用的 Key ID 失败: {e}", exc_info=True) |
| return None |
|
|
| async def get_key_scores(model_name: str) -> Dict[str, float]: |
| """ |
| (需要实现/可能已废弃) 获取指定模型的 Key 分数。 |
| 目前从内存缓存 key_scores_cache 获取 (如果存在)。 |
| |
| Args: |
| model_name (str): 模型名称。 |
| |
| Returns: |
| Dict[str, float]: 包含 Key 字符串和对应分数的字典。 |
| """ |
| logger.debug("get_key_scores 函数可能依赖于内存缓存或未完全实现的数据库逻辑。") |
| try: |
| |
| |
| with cache_lock: |
| scores = key_scores_cache.get(model_name, {}) |
| return scores.copy() |
| except Exception as e: |
| logger.error(f"获取模型 '{model_name}' 的 Key 分数失败: {e}", exc_info=True) |
| return {} |
|
|
| async def update_setting(db: AsyncSession, key: str, value: str): |
| """ |
| (重复/可能已废弃) 更新或插入设置项。 |
| 此函数与 set_setting 功能重复,且使用了原生 SQL。建议统一使用 set_setting。 |
| |
| Args: |
| db (AsyncSession): SQLAlchemy 异步数据库会话。 |
| key (str): 设置项的键。 |
| value (str): 设置项的值。 |
| """ |
| logger.warning("调用了可能重复或已废弃的 update_setting 函数,建议使用 set_setting。") |
| try: |
| |
| |
| |
| |
| stmt = text("UPDATE settings SET value = :value WHERE key = :key") |
| parameters = {"key": key, "value": value} |
| result = await db.execute(stmt, parameters) |
| if result.rowcount == 0: |
| |
| stmt_insert = text("INSERT INTO settings (key, value) VALUES (:key, :value)") |
| await db.execute(stmt_insert, parameters) |
| logger.info(f"设置 '{key}' 不存在,已插入新值 '{value}'") |
| else: |
| logger.info(f"设置 '{key}' 已更新为 '{value}' (通过 update_setting)") |
| await db.commit() |
| except Exception as e: |
| logger.error(f"使用 update_setting 更新设置 '{key}' 时发生错误: {e}", exc_info=True) |
| await db.rollback() |
|
|