Spaces:
Running
Running
| """ | |
| 數據庫緩存層 - 減少 Firestore 調用頻率 | |
| 實現多級緩存策略,大幅降低數據庫讀取次數 | |
| """ | |
| import logging | |
| import asyncio | |
| from datetime import datetime, timedelta | |
| from typing import Dict, Any, Optional, List, Tuple | |
| from collections import OrderedDict | |
| import hashlib | |
| import json | |
| logger = logging.getLogger("DatabaseCache") | |
| class LRUCache: | |
| """LRU (Least Recently Used) 緩存實現""" | |
| def __init__(self, max_size: int = 1000, ttl_seconds: int = 300): | |
| """ | |
| 初始化 LRU 緩存 | |
| Args: | |
| max_size: 最大緩存條目數 | |
| ttl_seconds: 緩存過期時間(秒) | |
| """ | |
| self.cache: OrderedDict = OrderedDict() | |
| self.max_size = max_size | |
| self.ttl = timedelta(seconds=ttl_seconds) | |
| self.hits = 0 | |
| self.misses = 0 | |
| self._lock = asyncio.Lock() | |
| async def get(self, key: str) -> Optional[Any]: | |
| """獲取緩存值""" | |
| async with self._lock: | |
| if key not in self.cache: | |
| self.misses += 1 | |
| return None | |
| value, expire_time = self.cache[key] | |
| # 檢查是否過期 | |
| if datetime.now() > expire_time: | |
| del self.cache[key] | |
| self.misses += 1 | |
| return None | |
| # 移到最後(表示最近使用) | |
| self.cache.move_to_end(key) | |
| self.hits += 1 | |
| return value | |
| async def set(self, key: str, value: Any): | |
| """設置緩存值""" | |
| async with self._lock: | |
| expire_time = datetime.now() + self.ttl | |
| if key in self.cache: | |
| # 更新現有值 | |
| self.cache[key] = (value, expire_time) | |
| self.cache.move_to_end(key) | |
| else: | |
| # 新增值 | |
| self.cache[key] = (value, expire_time) | |
| # 如果超過最大容量,移除最舊的 | |
| if len(self.cache) > self.max_size: | |
| oldest_key = next(iter(self.cache)) | |
| del self.cache[oldest_key] | |
| logger.debug(f"LRU 緩存已滿,移除最舊條目: {oldest_key}") | |
| async def delete(self, key: str): | |
| """刪除緩存值""" | |
| async with self._lock: | |
| if key in self.cache: | |
| del self.cache[key] | |
| async def clear(self): | |
| """清空緩存""" | |
| async with self._lock: | |
| self.cache.clear() | |
| self.hits = 0 | |
| self.misses = 0 | |
| def get_stats(self) -> Dict[str, Any]: | |
| """獲取緩存統計""" | |
| total = self.hits + self.misses | |
| hit_rate = (self.hits / total * 100) if total > 0 else 0 | |
| return { | |
| "size": len(self.cache), | |
| "max_size": self.max_size, | |
| "hits": self.hits, | |
| "misses": self.misses, | |
| "hit_rate": f"{hit_rate:.2f}%" | |
| } | |
| class DatabaseCache: | |
| """數據庫緩存管理器""" | |
| def __init__(self): | |
| # 不同數據類型使用不同的緩存策略 | |
| self.user_cache = LRUCache(max_size=500, ttl_seconds=600) # 用戶資料:10分鐘 | |
| self.chat_cache = LRUCache(max_size=300, ttl_seconds=300) # 對話資料:5分鐘 | |
| self.message_cache = LRUCache(max_size=1000, ttl_seconds=180) # 消息歷史:3分鐘 | |
| self.memory_cache = LRUCache(max_size=200, ttl_seconds=900) # 記憶:15分鐘 | |
| # 寫入緩衝區(批量寫入優化) | |
| self.write_buffer: Dict[str, List[Dict[str, Any]]] = { | |
| "messages": [], | |
| "memories": [], | |
| "chats": [] | |
| } | |
| self.write_buffer_lock = asyncio.Lock() | |
| self.write_buffer_max_size = 50 # 達到50條時觸發批量寫入 | |
| self.write_buffer_timeout = 10 # 10秒後強制寫入 | |
| # 請求合併(同一查詢只執行一次) | |
| self.pending_requests: Dict[str, asyncio.Future] = {} | |
| self.pending_lock = asyncio.Lock() | |
| # 其他快取:環境、反地理、路徑 | |
| self.env_ctx_cache = LRUCache(max_size=1000, ttl_seconds=600) # 使用者環境快取:10 分鐘 | |
| self.geo_cache = LRUCache(max_size=5000, ttl_seconds=604800) # 反地理快取:7 天 | |
| self.route_cache = LRUCache(max_size=5000, ttl_seconds=86400) # 路線快取:1 天 | |
| logger.info("數據庫緩存管理器初始化完成") | |
| def _generate_cache_key(self, operation: str, **kwargs) -> str: | |
| """生成緩存鍵""" | |
| # 將參數排序後生成哈希,確保相同參數生成相同鍵 | |
| params_str = json.dumps(kwargs, sort_keys=True, default=str) | |
| key_hash = hashlib.md5(f"{operation}:{params_str}".encode()).hexdigest() | |
| return f"{operation}:{key_hash}" | |
| async def get_user_cached(self, user_id: str) -> Optional[Dict[str, Any]]: | |
| """獲取緩存的用戶資料""" | |
| cache_key = self._generate_cache_key("user", user_id=user_id) | |
| return await self.user_cache.get(cache_key) | |
| async def set_user_cache(self, user_id: str, user_data: Dict[str, Any]): | |
| """設置用戶緩存""" | |
| cache_key = self._generate_cache_key("user", user_id=user_id) | |
| await self.user_cache.set(cache_key, user_data) | |
| async def get_chat_cached(self, chat_id: str) -> Optional[Dict[str, Any]]: | |
| """獲取緩存的對話資料""" | |
| cache_key = self._generate_cache_key("chat", chat_id=chat_id) | |
| return await self.chat_cache.get(cache_key) | |
| async def set_chat_cache(self, chat_id: str, chat_data: Dict[str, Any]): | |
| """設置對話緩存""" | |
| cache_key = self._generate_cache_key("chat", chat_id=chat_id) | |
| await self.chat_cache.set(cache_key, chat_data) | |
| async def invalidate_chat_cache(self, chat_id: str): | |
| """使對話緩存失效(當對話更新時調用)""" | |
| cache_key = self._generate_cache_key("chat", chat_id=chat_id) | |
| await self.chat_cache.delete(cache_key) | |
| async def get_user_chats_cached(self, user_id: str) -> Optional[List[Dict[str, Any]]]: | |
| """獲取緩存的用戶對話列表""" | |
| cache_key = self._generate_cache_key("user_chats", user_id=user_id) | |
| return await self.chat_cache.get(cache_key) | |
| async def set_user_chats_cache(self, user_id: str, chats: List[Dict[str, Any]]): | |
| """設置用戶對話列表緩存""" | |
| cache_key = self._generate_cache_key("user_chats", user_id=user_id) | |
| await self.chat_cache.set(cache_key, chats) | |
| async def invalidate_user_chats_cache(self, user_id: str): | |
| """使用戶對話列表緩存失效""" | |
| cache_key = self._generate_cache_key("user_chats", user_id=user_id) | |
| await self.chat_cache.delete(cache_key) | |
| async def get_memories_cached(self, user_id: str, memory_type: Optional[str] = None) -> Optional[List[Dict[str, Any]]]: | |
| """獲取緩存的記憶""" | |
| cache_key = self._generate_cache_key("memories", user_id=user_id, memory_type=memory_type) | |
| return await self.memory_cache.get(cache_key) | |
| async def set_memories_cache(self, user_id: str, memories: List[Dict[str, Any]], memory_type: Optional[str] = None): | |
| """設置記憶緩存""" | |
| cache_key = self._generate_cache_key("memories", user_id=user_id, memory_type=memory_type) | |
| await self.memory_cache.set(cache_key, memories) | |
| async def coalesce_request(self, cache_key: str, fetch_func): | |
| """ | |
| 請求合併:如果同一個請求正在執行,等待其結果而不是重複執行 | |
| Args: | |
| cache_key: 請求的唯一標識 | |
| fetch_func: 實際的數據獲取函數(async) | |
| Returns: | |
| 查詢結果 | |
| """ | |
| async with self.pending_lock: | |
| # 檢查是否有相同請求正在執行 | |
| if cache_key in self.pending_requests: | |
| logger.debug(f"請求合併:等待現有請求 {cache_key}") | |
| # 等待現有請求完成 | |
| future = self.pending_requests[cache_key] | |
| else: | |
| # 創建新的請求 | |
| future = asyncio.create_task(fetch_func()) | |
| self.pending_requests[cache_key] = future | |
| try: | |
| # 等待結果 | |
| result = await future | |
| return result | |
| finally: | |
| # 清理已完成的請求 | |
| async with self.pending_lock: | |
| if cache_key in self.pending_requests: | |
| del self.pending_requests[cache_key] | |
| async def buffer_write(self, collection: str, data: Dict[str, Any]) -> bool: | |
| """ | |
| 緩衝寫入:先存入緩衝區,達到一定數量或時間後批量寫入 | |
| Args: | |
| collection: 集合名稱 (messages/memories/chats) | |
| data: 要寫入的數據 | |
| Returns: | |
| True 如果已觸發批量寫入 | |
| """ | |
| async with self.write_buffer_lock: | |
| if collection not in self.write_buffer: | |
| self.write_buffer[collection] = [] | |
| # 添加時間戳 | |
| data["_buffered_at"] = datetime.now() | |
| self.write_buffer[collection].append(data) | |
| # 檢查是否需要立即寫入 | |
| if len(self.write_buffer[collection]) >= self.write_buffer_max_size: | |
| logger.info(f"寫入緩衝區已滿 ({collection}),觸發批量寫入") | |
| return True | |
| return False | |
| async def flush_write_buffer(self, collection: Optional[str] = None) -> Dict[str, int]: | |
| """ | |
| 清空寫入緩衝區,執行批量寫入 | |
| Args: | |
| collection: 指定集合名稱,None 表示清空所有 | |
| Returns: | |
| 每個集合的寫入數量 | |
| """ | |
| async with self.write_buffer_lock: | |
| collections_to_flush = [collection] if collection else list(self.write_buffer.keys()) | |
| result = {} | |
| for coll in collections_to_flush: | |
| if coll not in self.write_buffer or not self.write_buffer[coll]: | |
| result[coll] = 0 | |
| continue | |
| items = self.write_buffer[coll] | |
| self.write_buffer[coll] = [] | |
| result[coll] = len(items) | |
| logger.info(f"批量寫入 {coll}: {len(items)} 條記錄") | |
| # 這裡需要實際的批量寫入邏輯 | |
| # 將在 database.py 中實現 | |
| return result | |
| async def get_buffer_size(self) -> Dict[str, int]: | |
| """獲取各緩衝區的大小""" | |
| async with self.write_buffer_lock: | |
| return {k: len(v) for k, v in self.write_buffer.items()} | |
| def get_all_stats(self) -> Dict[str, Any]: | |
| """獲取所有緩存統計""" | |
| return { | |
| "user_cache": self.user_cache.get_stats(), | |
| "chat_cache": self.chat_cache.get_stats(), | |
| "message_cache": self.message_cache.get_stats(), | |
| "memory_cache": self.memory_cache.get_stats(), | |
| "env_ctx_cache": self.env_ctx_cache.get_stats(), | |
| "geo_cache": self.geo_cache.get_stats(), | |
| "route_cache": self.route_cache.get_stats(), | |
| "write_buffer": {k: len(v) for k, v in self.write_buffer.items()} | |
| } | |
| async def clear_all(self): | |
| """清空所有緩存""" | |
| await self.user_cache.clear() | |
| await self.chat_cache.clear() | |
| await self.message_cache.clear() | |
| await self.memory_cache.clear() | |
| logger.info("所有緩存已清空") | |
| # ===== 環境/地理/路線 快取 API ===== | |
| async def get_env_ctx_cached(self, user_id: str) -> Optional[Dict[str, Any]]: | |
| key = self._generate_cache_key("env_ctx", user_id=user_id) | |
| return await self.env_ctx_cache.get(key) | |
| async def set_env_ctx_cache(self, user_id: str, ctx: Dict[str, Any]): | |
| key = self._generate_cache_key("env_ctx", user_id=user_id) | |
| await self.env_ctx_cache.set(key, ctx) | |
| async def get_geo_cached(self, geohash7: str) -> Optional[Dict[str, Any]]: | |
| key = self._generate_cache_key("geo", geohash=geohash7) | |
| return await self.geo_cache.get(key) | |
| async def set_geo_cache(self, geohash7: str, payload: Dict[str, Any]): | |
| key = self._generate_cache_key("geo", geohash=geohash7) | |
| await self.geo_cache.set(key, payload) | |
| async def get_route_cached(self, cache_key: str) -> Optional[Dict[str, Any]]: | |
| key = self._generate_cache_key("route", key=cache_key) | |
| return await self.route_cache.get(key) | |
| async def set_route_cache(self, cache_key: str, payload: Dict[str, Any]): | |
| key = self._generate_cache_key("route", key=cache_key) | |
| await self.route_cache.set(key, payload) | |
| async def get_tdx_cached(self, cache_key: str) -> Optional[Any]: | |
| """獲取 TDX API 快取資料""" | |
| return await self.route_cache.get(cache_key) | |
| async def set_tdx_cache(self, cache_key: str, data: Any, ttl: int = 60): | |
| """設置 TDX API 快取資料(使用 route_cache,因為 TDX 也是路線相關)""" | |
| await self.route_cache.set(cache_key, data) | |
| # 全局緩存實例 | |
| db_cache = DatabaseCache() | |
| async def periodic_cache_maintenance(): | |
| """定期緩存維護任務""" | |
| while True: | |
| try: | |
| await asyncio.sleep(300) # 每5分鐘執行一次 | |
| # 輸出緩存統計 | |
| stats = db_cache.get_all_stats() | |
| logger.info(f"緩存統計: {stats}") | |
| # 檢查寫入緩衝區 | |
| buffer_size = await db_cache.get_buffer_size() | |
| if any(size > 0 for size in buffer_size.values()): | |
| logger.info(f"清空寫入緩衝區: {buffer_size}") | |
| await db_cache.flush_write_buffer() | |
| except Exception as e: | |
| logger.error(f"緩存維護任務出錯: {e}") | |