import time
import functools
from typing import Generator, Any
import hashlib
import json
import datetime
import random

from pymongo import MongoClient

from config import ENABLE_MONOGODB, MONOGODB_DB_NAME, ENABLE_MONOGODB_CACHE, CACHE_REPLAY_SPEED, CACHE_REPLAY_MAX_DELAY
from .chat_messages import ChatMessages
from .mongodb_cost import record_api_cost, check_cost_limits
from .mongodb_init import mongo_client as client


def create_cache_key(func_name: str, args: tuple, kwargs: dict) -> str:
    """Build a deterministic cache key for one call of a cached function.

    The function name and its arguments are serialized to JSON with sorted
    keys and hashed with MD5 (used only as a fingerprint here).

    Args:
        func_name: Name of the wrapped function.
        args: Positional arguments of the call; must be JSON-serializable.
        kwargs: Keyword arguments of the call; must be JSON-serializable.

    Returns:
        The hexadecimal MD5 digest of the serialized call.

    Raises:
        TypeError: If an argument cannot be serialized by ``json.dumps``.
    """
    # Gather everything that identifies the call into one serializable dict.
    cache_dict = {
        'func_name': func_name,
        'args': args,
        'kwargs': kwargs
    }
    # sort_keys=True makes the serialization independent of dict key order.
    cache_string = json.dumps(cache_dict, sort_keys=True)
    return hashlib.md5(cache_string.encode()).hexdigest()


def llm_api_cache():
    """Return a decorator that caches streamed generator results in MongoDB.

    The decorated function must be a generator function whose yielded values
    expose a ``response`` attribute (its length is recorded per yield) and
    whose final value is delivered through ``return``.

    When ``ENABLE_MONOGODB`` is false, the returned decorator only strips the
    ``use_cache`` keyword argument and forwards the call unchanged.

    Otherwise the decorated function becomes a generator that:

    * calls ``check_cost_limits()`` when iteration starts;
    * accepts an extra keyword argument ``use_cache`` (default ``True``),
      which is forced to ``False`` when ``ENABLE_MONOGODB_CACHE`` is false;
    * on a cache hit, replays one randomly sampled recorded run from the
      ``stream_chat`` collection, sleeping between chunks according to the
      recorded delays scaled by ``CACHE_REPLAY_SPEED`` and capped at
      ``CACHE_REPLAY_MAX_DELAY``;
    * otherwise runs the wrapped generator, records the response length and
      arrival delay of each yield, passes the final return value to
      ``record_api_cost`` and inserts the recorded run into MongoDB.

    Returns:
        A decorator to apply to the generator function.
    """
    db_name = MONOGODB_DB_NAME
    collection_name = 'stream_chat'

    def dummy_decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Drop use_cache so it is not forwarded to the wrapped function.
            kwargs.pop('use_cache', None)
            return func(*args, **kwargs)
        return wrapper

    if not ENABLE_MONOGODB:
        return dummy_decorator

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            check_cost_limits()

            # pop (not get) is essential: use_cache must reach neither the
            # wrapped function nor the cache key computed below.
            use_cache = kwargs.pop('use_cache', True)
            if not ENABLE_MONOGODB_CACHE:
                use_cache = False

            db = client[db_name]
            collection = db[collection_name]

            # The key is needed even when the cache is bypassed, because the
            # fresh run is stored under it afterwards.
            cache_key = create_cache_key(func.__name__, args, kwargs)

            if use_cache:
                # $sample picks one recorded run at random among all runs
                # stored under this key.
                cached_data = list(collection.aggregate([
                    {'$match': {'cache_key': cache_key}},
                    {'$sample': {'size': 1}}
                ]))
                cached_data = cached_data[0] if cached_data else None
                if cached_data:
                    # Cache hit: replay the recorded stream.
                    messages = ChatMessages(cached_data['return_value'])
                    # NOTE(review): assumes the model config dict is the first
                    # positional argument; confirm against callers.
                    messages.model = args[0]['model']
                    for item in cached_data['yields']:
                        # Apply the replay speed-up factor and the delay cap.
                        scaled_delay = min(item['delay'] / CACHE_REPLAY_SPEED, CACHE_REPLAY_MAX_DELAY)
                        if scaled_delay > 0:
                            time.sleep(scaled_delay)
                        else:
                            # Chunks recorded with a non-positive delay are
                            # not replayed individually.
                            continue

                        if item['index'] > 0:
                            yield messages.prompt_messages + [{'role': 'assistant', 'content': messages.response[:item['index']]}]
                        else:
                            yield messages.prompt_messages

                    messages.finished = True
                    yield messages
                    return messages

            # Cache miss (or cache bypassed): run the wrapped generator and
            # record every yield.
            yields_data = []
            last_time = time.time()
            generator = func(*args, **kwargs)
            while True:
                try:
                    value = next(generator)
                except StopIteration as e:
                    return_value = e.value
                    break
                # Take the timestamp only after the chunk has arrived, so the
                # recorded delay is the wait for THIS chunk. Sampling the
                # clock before next() shifted every delay by one chunk and
                # left the first one at ~0.
                current_time = time.time()
                yields_data.append({
                    'index': len(value.response),
                    'delay': current_time - last_time
                })
                last_time = current_time
                yield value

            # Account for the cost of the API call.
            record_api_cost(return_value)

            # Persist the recorded run to MongoDB.
            # NOTE(review): created_at is a naive local-time datetime, which
            # PyMongo stores as if it were UTC; confirm this is intended.
            cache_data = {
                'created_at': datetime.datetime.now(),
                'return_value': return_value,
                'func_name': func.__name__,
                'args': args,
                'kwargs': kwargs,
                'yields': yields_data,
                'cache_key': cache_key,
            }
            collection.insert_one(cache_data)
            return return_value

        return wrapper
    return decorator