| |
| """ |
| 缓存管理模块。 |
| 负责处理与 Gemini API 原生缓存相关的操作,包括: |
| - 计算内容的哈希值。 |
| - 将字典格式的内容转换为 Gemini API SDK 的 Content 对象列表。 |
| - 调用 Gemini API 创建缓存。 |
| - 在本地数据库中存储和管理缓存元数据 (CachedContent 模型)。 |
| - 根据内容哈希或用户 ID 和消息查找有效缓存。 |
| - 删除缓存(包括数据库记录和 Gemini API 端的缓存)。 |
| - 清理过期和无效的缓存条目。 |
| """ |
| import hashlib |
| import json |
| import logging |
| from datetime import datetime, timedelta, timezone |
| from typing import Dict, Any, Optional, List |
|
|
| import google.generativeai as genai |
| from google.generativeai import types |
| from google.api_core import exceptions as google_exceptions |
|
|
| |
| |
| |
| from sqlalchemy.orm import Session |
| from sqlalchemy.ext.asyncio import AsyncSession |
| from aiosqlite import Connection |
| from sqlalchemy import select, update, delete |
|
|
| from app.core.database.models import CachedContent |
| |
|
|
| |
| logger = logging.getLogger('my_logger') |
|
|
| class CacheManager: |
| """ |
| 缓存管理器类。 |
| 封装了与 Gemini API 缓存和本地数据库缓存记录交互的所有逻辑。 |
| """ |
|
|
| def _calculate_hash(self, content: dict) -> str: |
| """ |
| (内部辅助方法) 计算给定内容字典的 SHA-256 哈希值。 |
| 为了确保哈希的一致性,字典在序列化为 JSON 字符串之前会按键排序。 |
| |
| Args: |
| content (dict): 需要计算哈希的内容字典。 |
| |
| Returns: |
| str: 计算得到的十六进制哈希字符串。 |
| |
| Raises: |
| TypeError: 如果输入的内容不是字典类型。 |
| """ |
| |
| if not isinstance(content, dict): |
| logger.error(f"计算哈希时内容不是字典: {type(content)}") |
| raise TypeError("内容必须是字典类型") |
| try: |
| |
| |
| |
| content_str = json.dumps(content, sort_keys=True).encode('utf-8') |
| |
| return hashlib.sha256(content_str).hexdigest() |
| except Exception as e: |
| logger.error(f"计算内容哈希时发生错误: {e}", exc_info=True) |
| raise |
|
|
| def _convert_dict_to_gemini_content(self, content_dict: Dict[str, Any]) -> List[Dict[str, Any]]: |
| """ |
| (内部辅助方法) 将包含 parts 的字典格式内容转换为 Gemini SDK 0.8.5 版本期望的字典列表格式。 |
| 主要用于将要缓存的内容转换为 Gemini API `create_cached_content` 方法接受的格式。 |
| 支持 text 和 inline_data (假设为 base64 编码) 类型的 part。 |
| |
| Args: |
| content_dict (Dict[str, Any]): 包含 "parts" 键的字典,其值为 part 字典列表。 |
| 或者包含 "messages" 键,其值为 OpenAI 格式的消息列表。 |
| 例如: {"messages": [{"role": "user", "content": "你好"}]} |
| 或: {"parts": [{"text": "你好"}, {"inline_data": {"mime_type": "image/png", "data": "base64..."}}]} |
| |
| Returns: |
| List[Dict[str, Any]]: 转换后的 Gemini 内容字典列表。如果转换失败或无有效 parts/messages,则返回空列表。 |
| |
| Raises: |
| TypeError: 如果输入的 content_dict 不是字典类型。 |
| """ |
| |
| if not isinstance(content_dict, dict): |
| logger.error(f"转换内容时输入不是字典: {type(content_dict)}") |
| raise TypeError("内容必须是字典类型") |
|
|
| processed_gemini_contents = [] |
|
|
| |
| if "messages" in content_dict and isinstance(content_dict["messages"], list): |
| from app.core.context.converter import convert_openai_to_gemini_contents |
| try: |
| |
| gemini_dicts = convert_openai_to_gemini_contents(content_dict["messages"]) |
| |
| for gemini_dict in gemini_dicts: |
| if "role" in gemini_dict and "parts" in gemini_dict: |
| |
| for part in gemini_dict.get("parts", []): |
| if "inline_data" in part and isinstance(part["inline_data"], dict): |
| if "data" in part["inline_data"] and isinstance(part["inline_data"]["data"], bytes): |
| logger.warning("inline_data 的 data 字段是 bytes,期望是 base64 字符串。将尝试编码。") |
| import base64 |
| try: |
| part["inline_data"]["data"] = base64.b64encode(part["inline_data"]["data"]).decode('utf-8') |
| except Exception as enc_err: |
| logger.error(f"Base64 编码 inline_data 时出错: {enc_err}") |
| |
| processed_gemini_contents.append(gemini_dict) |
| else: |
| logger.warning(f"从 messages 转换的字典缺少 role 或 parts: {gemini_dict}") |
| if processed_gemini_contents: |
| return processed_gemini_contents |
| except Exception as e: |
| logger.error(f"从 messages 转换 Gemini 内容字典时出错: {e}", exc_info=True) |
| return [] |
|
|
| |
| |
| |
| |
| raw_parts_data = content_dict.get("parts") |
| current_role = content_dict.get("role", "model") |
|
|
| if raw_parts_data and isinstance(raw_parts_data, list): |
| processed_parts = [] |
| for part_data in raw_parts_data: |
| if not isinstance(part_data, dict): |
| logger.warning(f"Part 数据不是字典格式,已跳过: {part_data}") |
| continue |
|
|
| if "text" in part_data and part_data["text"] is not None: |
| processed_parts.append({"text": part_data["text"]}) |
| elif "inline_data" in part_data and isinstance(part_data["inline_data"], dict): |
| inline_data_dict = part_data["inline_data"] |
| if "mime_type" in inline_data_dict and "data" in inline_data_dict: |
| |
| |
| data_value = inline_data_dict["data"] |
| if isinstance(data_value, bytes): |
| logger.warning("inline_data 的 data 字段是 bytes,期望是 base64 字符串。将尝试编码。") |
| import base64 |
| try: |
| data_value = base64.b64encode(data_value).decode('utf-8') |
| except Exception as enc_err: |
| logger.error(f"Base64 编码 inline_data 时出错: {enc_err}") |
| continue |
| |
| processed_parts.append({ |
| "inline_data": { |
| "mime_type": inline_data_dict["mime_type"], |
| "data": data_value |
| } |
| }) |
| else: |
| logger.warning(f"inline_data 缺少 mime_type 或 data: {part_data}") |
| |
| |
| if processed_parts: |
| processed_gemini_contents.append({"role": current_role, "parts": processed_parts}) |
| return processed_gemini_contents |
|
|
| logger.warning(f"无法从字典内容转换出有效的 Gemini 内容字典列表: {content_dict}") |
| return [] |
|
|
| |
| |
| |
| |
|
|
| async def create_cache(self, db: AsyncSession, user_id: str, api_key_id: int, content: dict, ttl: int) -> Optional[int]: |
| """ |
| 异步创建缓存条目。 |
| 首先计算内容哈希,检查数据库中是否已存在有效缓存。 |
| 如果不存在,则调用 Gemini API 创建缓存,并将返回的信息存入数据库。 |
| 注意:此方法期望接收 AsyncSession。 |
| |
| Args: |
| db (AsyncSession): SQLAlchemy 异步数据库会话。 |
| user_id (str): 与缓存关联的用户 ID。 |
| api_key_id (int): 创建缓存时使用的 API Key 的数据库 ID。 |
| content (dict): 需要缓存的原始内容字典 (通常包含 "messages" 和 "model")。 |
| ttl (int): 缓存的生存时间 (秒)。 |
| |
| Returns: |
| Optional[int]: 如果成功创建或找到现有有效缓存,返回数据库中 CachedContent 条目的 ID; |
| 否则返回 None。 |
| """ |
| try: |
| |
| content_hash = self._calculate_hash(content) |
| logger.info(f"尝试为内容哈希 {content_hash[:8]}... 创建缓存 (用户: {user_id}, Key ID: {api_key_id})") |
| except Exception as hash_err: |
| logger.error(f"创建缓存时计算哈希失败: {hash_err}", exc_info=True) |
| return None |
|
|
| try: |
| |
| now_utc = datetime.utcnow().replace(tzinfo=timezone.utc) |
| stmt_check = select(CachedContent).where( |
| CachedContent.content_hash == content_hash, |
| CachedContent.expires_at > now_utc |
| ).limit(1) |
| result_check = await db.execute(stmt_check) |
| existing_cache = result_check.scalar_one_or_none() |
|
|
| if existing_cache: |
| logger.info(f"数据库中已存在有效缓存 (ID: {existing_cache.id}),跳过 Gemini API 创建。") |
| |
| update_stmt = ( |
| update(CachedContent) |
| .where(CachedContent.id == existing_cache.id) |
| .values( |
| last_used_at=now_utc, |
| usage_count=existing_cache.usage_count + 1 |
| ) |
| .execution_options(synchronize_session=False) |
| ) |
| await db.execute(update_stmt) |
| await db.commit() |
| return existing_cache.id |
|
|
| except Exception as db_check_err: |
| logger.error(f"检查数据库现有缓存时出错: {db_check_err}", exc_info=True) |
| |
| return None |
|
|
| |
| gemini_cached_content = None |
| try: |
| |
| gemini_content_list = self._convert_dict_to_gemini_content(content) |
| if not gemini_content_list: |
| logger.warning(f"转换内容为 Gemini Content 失败,无法创建 Gemini API 缓存。内容: {content}") |
| return None |
|
|
| |
| logger.debug(f"调用 Gemini API 创建缓存,内容: {gemini_content_list}, TTL: {ttl}") |
| gemini_cached_content = await genai.create_cached_content( |
| contents=gemini_content_list, |
| ttl=timedelta(seconds=ttl) |
| ) |
| |
| logger.info(f"成功创建 Gemini API 缓存: {gemini_cached_content.name} (过期时间: {gemini_cached_content.expire_time})") |
|
|
| |
| try: |
| |
| expire_time_dt = gemini_cached_content.expire_time |
| if not isinstance(expire_time_dt, datetime): |
| |
| |
| try: |
| expire_time_dt = expire_time_dt.replace(tzinfo=timezone.utc) |
| except AttributeError: |
| logger.error(f"无法处理 Gemini API 返回的过期时间类型: {type(expire_time_dt)}") |
| |
| expire_time_dt = datetime.utcnow().replace(tzinfo=timezone.utc) + timedelta(seconds=ttl) |
|
|
| |
| cached_content_db = CachedContent( |
| gemini_cache_id=gemini_cached_content.name, |
| content_hash=content_hash, |
| user_id=user_id, |
| api_key_id=api_key_id, |
| expires_at=expire_time_dt, |
| |
| content=json.dumps(content), |
| |
| last_used_at=datetime.utcnow().replace(tzinfo=timezone.utc), |
| usage_count=1 |
| |
| ) |
| db.add(cached_content_db) |
| await db.commit() |
| await db.refresh(cached_content_db) |
|
|
| logger.info(f"成功创建数据库缓存条目 (ID: {cached_content_db.id})") |
| return cached_content_db.id |
| except Exception as db_save_err: |
| logger.error(f"将 Gemini 缓存信息存入数据库时出错: {db_save_err}", exc_info=True) |
| await db.rollback() |
| |
| try: |
| logger.warning(f"因数据库保存失败,尝试删除 Gemini API 缓存: {gemini_cached_content.name}") |
| await genai.delete_cached_content(name=gemini_cached_content.name) |
| logger.info(f"已删除因数据库保存失败而创建的 Gemini API 缓存: {gemini_cached_content.name}") |
| except Exception as delete_err: |
| logger.error(f"尝试删除 Gemini API 缓存 {gemini_cached_content.name} 失败: {delete_err}") |
| return None |
|
|
| except google_exceptions.AlreadyExists as e: |
| logger.warning(f"尝试创建 Gemini API 缓存时发现已存在 (哈希: {content_hash[:8]}...): {e}") |
| |
| try: |
| stmt_find = select(CachedContent).where(CachedContent.content_hash == content_hash).limit(1) |
| result_find = await db.execute(stmt_find) |
| existing_db_cache = result_find.scalar_one_or_none() |
| if existing_db_cache: |
| logger.info(f"从数据库中找到了与已存在 Gemini 缓存对应的记录 (ID: {existing_db_cache.id})") |
| return existing_db_cache.id |
| else: |
| logger.error(f"Gemini API 报告缓存已存在,但在数据库中未找到对应记录 (哈希: {content_hash[:8]}...)。") |
| return None |
| except Exception as db_find_err: |
| logger.error(f"尝试查找已存在的 Gemini 缓存对应数据库记录时出错: {db_find_err}", exc_info=True) |
| return None |
| except google_exceptions.GoogleAPIError as e: |
| logger.error(f"调用 Gemini API 创建缓存失败: {e}", exc_info=True) |
| return None |
| except Exception as e: |
| logger.error(f"创建缓存过程中发生意外错误: {e}", exc_info=True) |
| return None |
|
|
| async def get_cache(self, db: AsyncSession, content_hash: str) -> Optional[Dict[str, Any]]: |
| """ |
| (异步方法) 根据内容哈希值从数据库获取缓存信息。 |
| """ |
| logger.info(f"尝试获取内容哈希 {content_hash[:8]}... 的缓存 (异步)") |
| try: |
| |
| stmt = select(CachedContent).where(CachedContent.content_hash == content_hash).limit(1) |
| result = await db.execute(stmt) |
| cached_content = result.scalar_one_or_none() |
|
|
| if cached_content: |
| now_utc = datetime.utcnow().replace(tzinfo=timezone.utc) |
| |
| expires_at_aware = cached_content.expires_at |
| if not expires_at_aware.tzinfo: |
| expires_at_aware = expires_at_aware.replace(tzinfo=timezone.utc) |
|
|
| if now_utc < expires_at_aware: |
| logger.info(f"找到有效缓存 (ID: {cached_content.id}, Gemini ID: {cached_content.gemini_cache_id[:8]}...) (异步)") |
| |
| |
| cached_content.last_used_at = now_utc |
| cached_content.usage_count += 1 |
| |
| await db.commit() |
|
|
| try: |
| original_content = json.loads(cached_content.content) |
| except json.JSONDecodeError: |
| logger.error(f"无法解析数据库中缓存 ID {cached_content.id} 的 content 字段。") |
| original_content = None |
| |
| return { |
| "gemini_cache_id": cached_content.gemini_cache_id, |
| "content": original_content |
| } |
| else: |
| logger.info(f"找到过期缓存 (ID: {cached_content.id}),视为未命中 (异步)。") |
| |
| |
| |
| |
| return None |
| else: |
| logger.info(f"未找到内容哈希 {content_hash[:8]}... 的缓存 (异步)。") |
| return None |
| except Exception as e: |
| logger.error(f"获取缓存 (哈希: {content_hash[:8]}...) 时出错 (异步): {e}", exc_info=True) |
| await db.rollback() |
| return None |
|
|
| async def find_cache(self, db: AsyncSession, user_id: str, messages: List[Dict[str, Any]]) -> Optional[str]: |
| """ |
| (异步方法) 根据用户 ID 和消息内容异步查找有效的缓存。 |
| 注意:此方法使用了异步 SQLAlchemy Session。 |
| |
| Args: |
| db (AsyncSession): SQLAlchemy 异步数据库会话。 |
| user_id (str): 要查找缓存的用户 ID。 |
| messages (List[Dict[str, Any]]): OpenAI 格式的消息列表,用于计算哈希。 |
| |
| Returns: |
| Optional[str]: 如果找到有效缓存,返回其 Gemini 缓存 ID (gemini_cache_id);否则返回 None。 |
| """ |
| |
| |
| content_to_hash = {"messages": messages} |
| try: |
| content_hash = self._calculate_hash(content_to_hash) |
| except TypeError as e: |
| logger.error(f"查找缓存时计算哈希失败: {e}") |
| return None |
|
|
| logger.info(f"尝试为用户 {user_id} 查找内容哈希 {content_hash[:8]}... 的有效缓存") |
|
|
| try: |
| |
| now_utc = datetime.utcnow().replace(tzinfo=timezone.utc) |
| stmt = select(CachedContent).where( |
| CachedContent.user_id == user_id, |
| CachedContent.content_hash == content_hash, |
| CachedContent.expires_at > now_utc |
| ).limit(1) |
| |
| result = await db.execute(stmt) |
| cached_content = result.scalar_one_or_none() |
|
|
| if cached_content: |
| logger.info(f"为用户 {user_id} 找到有效缓存 (ID: {cached_content.id}, Gemini ID: {cached_content.gemini_cache_id[:8]}...)") |
| |
| update_stmt = ( |
| update(CachedContent) |
| .where(CachedContent.id == cached_content.id) |
| .values( |
| last_used_at=now_utc, |
| usage_count=cached_content.usage_count + 1 |
| ) |
| .execution_options(synchronize_session=False) |
| ) |
| await db.execute(update_stmt) |
| await db.commit() |
| |
| return cached_content.gemini_cache_id |
| else: |
| logger.info(f"未找到用户 {user_id} 内容哈希 {content_hash[:8]}... 的有效缓存。") |
| return None |
| except Exception as e: |
| logger.error(f"查找缓存 (用户: {user_id}, 哈希: {content_hash[:8]}...) 时出错: {e}", exc_info=True) |
| await db.rollback() |
| return None |
|
|
|
|
| async def delete_cache(self, db: AsyncSession, cache_id: int) -> bool: |
| """ |
| (异步方法) 删除指定 ID 的缓存条目(包括数据库记录和 Gemini API 端的缓存)。 |
| 注意:此方法使用了异步 SQLAlchemy Session。 |
| |
| Args: |
| db (AsyncSession): SQLAlchemy 异步数据库会话。 |
| cache_id (int): 要删除的数据库缓存条目的 ID。 |
| |
| Returns: |
| bool: 如果成功删除数据库条目(无论 Gemini API 删除是否成功或缓存是否存在),返回 True; |
| 如果数据库条目未找到或删除过程中发生数据库错误,返回 False。 |
| """ |
| logger.info(f"尝试删除数据库缓存条目 (ID: {cache_id})") |
| try: |
| |
| stmt_select = select(CachedContent).where(CachedContent.id == cache_id) |
| result_select = await db.execute(stmt_select) |
| cached_content = result_select.scalar_one_or_none() |
|
|
| if cached_content: |
| gemini_cache_id = cached_content.gemini_cache_id |
| logger.info(f"找到数据库缓存条目 (ID: {cache_id}, Gemini ID: {gemini_cache_id[:8]}...),准备删除。") |
|
|
| |
| try: |
| |
| logger.debug(f"尝试删除 Gemini API 缓存: {gemini_cache_id}") |
| await genai.delete_cached_content(name=gemini_cache_id) |
| logger.info(f"成功删除 Gemini API 缓存: {gemini_cache_id}") |
| except google_exceptions.NotFound: |
| logger.warning(f"尝试删除 Gemini API 缓存 {gemini_cache_id} 时发现不存在。") |
| except google_exceptions.GoogleAPIError as e: |
| logger.error(f"调用 Gemini API 删除缓存 {gemini_cache_id} 失败: {e}", exc_info=True) |
| |
| except Exception as e: |
| logger.error(f"删除 Gemini API 缓存 {gemini_cache_id} 过程中发生意外错误: {e}", exc_info=True) |
| |
|
|
| |
| stmt_delete = delete(CachedContent).where(CachedContent.id == cache_id) |
| await db.execute(stmt_delete) |
| await db.commit() |
| logger.info(f"成功删除数据库缓存条目 (ID: {cache_id})") |
| return True |
| else: |
| logger.warning(f"未找到数据库缓存条目 (ID: {cache_id}),无需删除。") |
| return False |
| except Exception as e: |
| logger.error(f"删除缓存 (ID: {cache_id}) 时发生数据库错误: {e}", exc_info=True) |
| await db.rollback() |
| return False |
|
|
| async def cleanup_expired_caches(self, db: AsyncSession): |
| """ |
| (异步方法) 清理数据库中所有已过期的缓存条目。 |
| 注意:此方法目前仅删除数据库记录,未主动删除对应的 Gemini API 缓存。 |
| Gemini API 的缓存有自己的 TTL,会自动过期。如果需要强制删除,应调用 delete_cache。 |
| 此方法使用了异步 SQLAlchemy Session。 |
| |
| Args: |
| db (AsyncSession): SQLAlchemy 异步数据库会话。 |
| """ |
| logger.info("开始清理数据库中过期的缓存条目...") |
| cleaned_count = 0 |
| try: |
| |
| now_utc = datetime.utcnow().replace(tzinfo=timezone.utc) |
| |
| stmt_select = select(CachedContent.id, CachedContent.gemini_cache_id).where(CachedContent.expires_at <= now_utc) |
| result_select = await db.execute(stmt_select) |
| expired_caches = result_select.all() |
|
|
| if expired_caches: |
| expired_ids = [cache.id for cache in expired_caches] |
| logger.info(f"发现 {len(expired_ids)} 个过期的数据库缓存条目,准备删除...") |
| |
| stmt_delete = delete(CachedContent).where(CachedContent.id.in_(expired_ids)) |
| |
| result_delete = await db.execute(stmt_delete) |
| await db.commit() |
| cleaned_count = result_delete.rowcount |
| logger.info(f"成功清理了 {cleaned_count} 个过期的数据库缓存条目。") |
| |
| |
| |
| else: |
| logger.info("未发现需要清理的过期数据库缓存条目。") |
|
|
| except Exception as e: |
| logger.error(f"清理过期缓存时出错: {e}", exc_info=True) |
| await db.rollback() |
|
|
| async def cleanup_invalid_caches(self, db: AsyncSession): |
| """ |
| (异步方法) 清理数据库中无效的缓存条目(即在 Gemini API 端已不存在的缓存)。 |
| 遍历数据库中的所有缓存条目,尝试调用 Gemini API 获取对应的缓存对象。 |
| 如果 Gemini API 返回 NotFound 错误,则从数据库中删除该条目。 |
| 注意:此方法使用了异步 SQLAlchemy Session。 |
| |
| Args: |
| db (AsyncSession): SQLAlchemy 异步数据库会话。 |
| """ |
| logger.info("开始清理无效的数据库缓存条目 (与 Gemini API 同步)...") |
| cleaned_count = 0 |
| invalid_ids_to_delete = [] |
|
|
| try: |
| |
| stmt_select = select(CachedContent.id, CachedContent.gemini_cache_id) |
| result_select = await db.execute(stmt_select) |
| all_db_caches = result_select.all() |
| logger.debug(f"从数据库获取了 {len(all_db_caches)} 条缓存记录进行检查。") |
|
|
| |
| for db_cache in all_db_caches: |
| db_id = db_cache.id |
| gemini_cache_id = db_cache.gemini_cache_id |
| if not gemini_cache_id: |
| logger.debug(f"数据库缓存条目 (ID: {db_id}) 没有 Gemini Cache ID,跳过检查。") |
| continue |
| logger.debug(f"检查数据库缓存条目 (ID: {db_id}, Gemini ID: {gemini_cache_id[:8]}...)") |
| try: |
| |
| await genai.get_cached_content(name=gemini_cache_id) |
| |
| logger.debug(f"Gemini API 缓存 {gemini_cache_id[:8]}... 存在。") |
| except google_exceptions.NotFound: |
| |
| logger.warning(f"Gemini API 缓存 {gemini_cache_id[:8]}... 不存在,标记数据库条目 (ID: {db_id}) 为待删除。") |
| invalid_ids_to_delete.append(db_id) |
| except google_exceptions.GoogleAPIError as e: |
| |
| logger.error(f"检查 Gemini API 缓存 {gemini_cache_id[:8]}... 时发生 Google API 错误: {e}", exc_info=True) |
| except Exception as e: |
| |
| logger.error(f"检查 Gemini API 缓存 {gemini_cache_id[:8]}... 时发生意外错误: {e}", exc_info=True) |
|
|
| |
| if invalid_ids_to_delete: |
| logger.info(f"准备从数据库删除 {len(invalid_ids_to_delete)} 个无效缓存条目...") |
| stmt_delete = delete(CachedContent).where(CachedContent.id.in_(invalid_ids_to_delete)) |
| result_delete = await db.execute(stmt_delete) |
| await db.commit() |
| cleaned_count = result_delete.rowcount |
| logger.info(f"成功清理了 {cleaned_count} 个无效的数据库缓存条目。") |
| else: |
| logger.info("未发现需要清理的无效数据库缓存条目。") |
|
|
| except Exception as e: |
| logger.error(f"清理无效缓存时出错: {e}", exc_info=True) |
| await db.rollback() |
|
|