Spaces:
Running
Running
File size: 10,759 Bytes
3f0377e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 |
"""
優化後的數據庫調用層
整合緩存、批量操作、請求合併等優化策略
"""
import logging
from datetime import datetime
from typing import Dict, Any, Optional, List
import asyncio
from .base import (
firestore_db, users_collection, chats_collection,
messages_collection,
get_user_by_id as _get_user_by_id_original,
get_chat as _get_chat_original,
get_user_chats as _get_user_chats_original,
save_message as _save_message_original,
save_chat_message as _save_chat_message_original,
get_user_memories as _get_user_memories_original,
)
from .cache import db_cache
logger = logging.getLogger("DatabaseOptimized")
# ==================== User operations (optimized) ====================
async def get_user_by_id(user_id: str) -> Optional[Dict[str, Any]]:
    """Fetch a user profile, serving from cache when possible.

    Strategy:
        1. Return the cached profile on a cache hit.
        2. On a miss, coalesce concurrent requests for the same user so
           the database is queried only once.
        3. Write the fetched profile back into the cache.
    """
    hit = await db_cache.get_user_cached(user_id)
    if hit is not None:
        logger.debug(f"用戶緩存命中: {user_id}")
        return hit

    async def _load():
        logger.debug(f"從數據庫獲取用戶: {user_id}")
        record = await _get_user_by_id_original(user_id)
        if record:
            # Populate the cache for subsequent lookups.
            await db_cache.set_user_cache(user_id, record)
        return record

    return await db_cache.coalesce_request(f"fetch_user:{user_id}", _load)
# ==================== Chat operations (optimized) ====================
async def get_chat(chat_id: str) -> Dict[str, Any]:
    """Fetch a chat, serving from cache when possible.

    Strategy:
        1. Return the cached chat on a cache hit.
        2. On a miss, coalesce concurrent requests for the same chat.
        3. Write a successful result back into the cache.
    """
    hit = await db_cache.get_chat_cached(chat_id)
    if hit is not None:
        logger.debug(f"對話緩存命中: {chat_id}")
        return {"success": True, "chat": hit}

    async def _load():
        logger.debug(f"從數據庫獲取對話: {chat_id}")
        result = await _get_chat_original(chat_id)
        if result.get("success") and result.get("chat"):
            # Populate the cache for subsequent lookups.
            await db_cache.set_chat_cache(chat_id, result["chat"])
        return result

    return await db_cache.coalesce_request(f"fetch_chat:{chat_id}", _load)
async def get_user_chats(user_id: str) -> Dict[str, Any]:
    """Fetch a user's chat list, serving from cache when possible.

    Strategy:
        1. Return the cached list on a cache hit.
        2. On a miss, coalesce concurrent requests for the same user.
        3. Write a successful result back into the (short-lived) cache.
    """
    hit = await db_cache.get_user_chats_cached(user_id)
    if hit is not None:
        logger.debug(f"用戶對話列表緩存命中: {user_id}")
        return {"success": True, "chats": hit}

    async def _load():
        logger.debug(f"從數據庫獲取用戶對話列表: {user_id}")
        result = await _get_user_chats_original(user_id)
        if result.get("success") and result.get("chats"):
            # Populate the cache for subsequent lookups.
            await db_cache.set_user_chats_cache(user_id, result["chats"])
        return result

    return await db_cache.coalesce_request(f"fetch_user_chats:{user_id}", _load)
async def save_chat_message(chat_id: str, sender: str, content: str, buffered: bool = False) -> Dict[str, Any]:
    """Save a chat message, optionally through the write buffer.

    Strategy:
        - buffered=True: stage the message in the write buffer and let it be
          batch-written later (non-blocking; suited to high-frequency traffic).
        - buffered=False: write through immediately (blocking; suited to
          operations that need strong consistency).
        Either way, the chat's cache entry is invalidated so the next read
        reflects the new message.

    Returns a dict with at least a "success" key; buffered writes also set
    "buffered": True, and failures carry an "error" string.
    """
    try:
        if buffered:
            # Non-blocking: stage the message in the shared write buffer.
            should_flush = await db_cache.buffer_write("messages", {
                "chat_id": chat_id,
                "sender": sender,
                "content": content,
                "timestamp": datetime.now()
            })
            # Buffer full: flush in the background without awaiting it.
            # (Fix: asyncio is already imported at module level; the
            # redundant function-level import was removed.)
            if should_flush:
                asyncio.create_task(_flush_message_buffer())
            # Invalidate so the next read does not serve a stale chat.
            await db_cache.invalidate_chat_cache(chat_id)
            return {"success": True, "buffered": True}
        else:
            # Blocking write-through for consistency-critical paths.
            result = await _save_chat_message_original(chat_id, sender, content)
            if result.get("success"):
                # Invalidate so the next read does not serve a stale chat.
                await db_cache.invalidate_chat_cache(chat_id)
                logger.debug(f"消息已保存並使緩存失效: {chat_id}")
            return result
    except Exception as e:
        logger.error(f"保存消息失敗: {e}")
        return {"success": False, "error": str(e)}
async def _flush_message_buffer():
    """Drain the message write buffer, batch-writing its contents."""
    try:
        result = await db_cache.flush_write_buffer("messages")
        # Only log when something was actually flushed.
        if result.get("messages", 0) > 0:
            logger.info(f"✅ 批量寫入 {result['messages']} 條消息")
    except Exception as e:
        logger.error(f"❌ 批量寫入失敗: {e}")
# ==================== Memory operations (optimized) ====================
async def get_user_memories(
    user_id: str,
    memory_type: Optional[str] = None,
    limit: int = 10,
    min_importance: float = 0.0
) -> Dict[str, Any]:
    """Fetch user memories, caching only the common query shape.

    Strategy:
        1. Serve the default query (limit=10, min_importance=0.0) from cache
           when possible — other parameter combinations are never cached, to
           avoid fragmenting the cache.
        2. On a miss, coalesce concurrent identical requests.
        3. Write cacheable successful results back into the cache.
    """
    # Only the standard query shape participates in caching.
    cacheable = limit == 10 and min_importance == 0.0
    if cacheable:
        hit = await db_cache.get_memories_cached(user_id, memory_type)
        if hit is not None:
            logger.debug(f"記憶緩存命中: {user_id}, type={memory_type}")
            return {"success": True, "memories": hit}

    async def _load():
        logger.debug(f"從數據庫獲取記憶: {user_id}")
        result = await _get_user_memories_original(user_id, memory_type, limit, min_importance)
        if cacheable and result.get("success") and result.get("memories"):
            await db_cache.set_memories_cache(user_id, result["memories"], memory_type)
        return result

    key = f"fetch_memories:{user_id}:{memory_type}:{limit}:{min_importance}"
    return await db_cache.coalesce_request(key, _load)
# ==================== Batch write optimization ====================
class BatchWriter:
    """Accumulates writes and commits them to Firestore in batches."""

    def __init__(self):
        # Per-collection queues of pending documents.
        self.batch_queue: Dict[str, List[Dict[str, Any]]] = {
            "messages": [],
            "memories": [],
        }
        # Flush threshold: commit once this many messages are queued.
        self.batch_size = 50
        self.lock = asyncio.Lock()

    async def add_message(self, user_id: str, content: str, is_bot: bool = False):
        """Queue one message; flush automatically once the batch fills up."""
        async with self.lock:
            self.batch_queue["messages"].append({
                "user_id": user_id,
                "content": content,
                "is_bot": is_bot,
                "timestamp": datetime.now(),
            })
            if len(self.batch_queue["messages"]) >= self.batch_size:
                await self._flush_messages()

    async def _flush_messages(self):
        """Commit all queued messages in a single Firestore batch."""
        if not self.batch_queue["messages"]:
            return
        # Take ownership of the queue, then reset it for new arrivals.
        messages = self.batch_queue["messages"]
        self.batch_queue["messages"] = []
        logger.info(f"批量寫入 {len(messages)} 條消息到 Firestore")
        try:
            batch = firestore_db.batch()
            for msg in messages:
                batch.set(messages_collection.document(), msg)
            # batch.commit is a blocking SDK call; run it off the event loop.
            await asyncio.to_thread(batch.commit)
            logger.info(f"✅ 批量寫入成功: {len(messages)} 條消息")
        except Exception as e:
            logger.error(f"批量寫入失敗: {e}")
            # Requeue on failure so the data is retried at a later flush.
            self.batch_queue["messages"].extend(messages)

    async def flush_all(self):
        """Flush every queue under the lock."""
        async with self.lock:
            await self._flush_messages()

# Module-level singleton batch writer.
batch_writer = BatchWriter()
# ==================== Query optimization tools ====================
class QueryOptimizer:
    """Tracks per-query-type statistics on cache versus database hits."""

    def __init__(self):
        # query_type -> {"total": n, "cached": n, "db": n}
        self.query_stats: Dict[str, Dict[str, int]] = {}
        self.lock = asyncio.Lock()

    async def record_query(self, query_type: str, from_cache: bool):
        """Count one query, tagged by whether it was served from cache."""
        async with self.lock:
            counters = self.query_stats.setdefault(
                query_type, {"total": 0, "cached": 0, "db": 0}
            )
            counters["total"] += 1
            counters["cached" if from_cache else "db"] += 1

    def get_stats(self) -> Dict[str, Any]:
        """Return a per-query-type summary including the cache hit rate."""
        summary = {}
        for query_type, counts in self.query_stats.items():
            total = counts["total"]
            cached = counts["cached"]
            hit_rate = (cached / total * 100) if total > 0 else 0
            summary[query_type] = {
                "total": total,
                "from_cache": cached,
                "from_db": counts["db"],
                "cache_hit_rate": f"{hit_rate:.2f}%",
            }
        return summary

    async def reset_stats(self):
        """Discard all recorded statistics."""
        async with self.lock:
            self.query_stats.clear()

# Module-level singleton query optimizer.
query_optimizer = QueryOptimizer()
# ==================== Exported interface ====================
# Public API of this module: the cache-aware wrappers (same names as the
# originals in .base) plus the module-level batch writer and query optimizer.
__all__ = [
    "get_user_by_id",
    "get_chat",
    "get_user_chats",
    "save_chat_message",
    "get_user_memories",
    "batch_writer",
    "query_optimizer",
]
|