File size: 4,575 Bytes
217acfe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import time
import functools
from typing import Generator, Any
from pymongo import MongoClient
import hashlib
import json
import datetime
import random
from config import ENABLE_MONOGODB, MONOGODB_DB_NAME, ENABLE_MONOGODB_CACHE, CACHE_REPLAY_SPEED, CACHE_REPLAY_MAX_DELAY
from .chat_messages import ChatMessages
from .mongodb_cost import record_api_cost, check_cost_limits
from .mongodb_init import mongo_client as client
def create_cache_key(func_name: str, args: tuple, kwargs: dict) -> str:
    """Build a deterministic cache key for one function invocation.

    The function name, positional arguments and keyword arguments are
    packed into a dict, serialized to JSON with sorted keys (so the
    ordering of ``kwargs`` does not affect the result) and hashed with MD5.

    Args:
        func_name: Name of the function being cached.
        args: Positional arguments of the call; must be JSON-serializable.
        kwargs: Keyword arguments of the call; must be JSON-serializable.

    Returns:
        The 32-character hexadecimal MD5 digest of the serialized call.

    Raises:
        TypeError: If any argument cannot be serialized by ``json.dumps``.
    """
    # sort_keys=True makes the serialized form independent of dict ordering.
    payload = json.dumps(
        {'func_name': func_name, 'args': args, 'kwargs': kwargs},
        sort_keys=True,
    )
    # MD5 serves only as a compact fingerprint here, not as a security measure.
    digest = hashlib.md5(payload.encode())
    return digest.hexdigest()
def llm_api_cache():
    """Return a decorator that caches a streaming generator function in MongoDB.

    Intended to be applied as ``@llm_api_cache()`` to a generator function
    that yields intermediate values (each exposing a ``response`` attribute)
    and finally returns a value.

    Behavior of the decorated function:

    * When ``ENABLE_MONOGODB`` is false, the function is called unchanged
      apart from the ``use_cache`` keyword argument being stripped.
    * Otherwise ``check_cost_limits()`` is called first. The call is keyed
      with ``create_cache_key(func.__name__, args, kwargs)``.
    * On a cache hit (``use_cache`` true and ``ENABLE_MONOGODB_CACHE`` true)
      one randomly sampled matching record is replayed: the recorded
      inter-chunk delays are divided by ``CACHE_REPLAY_SPEED`` and capped at
      ``CACHE_REPLAY_MAX_DELAY``.
    * On a miss the wrapped generator is streamed through to the caller
      while the length of each partial response and the time spent waiting
      for it are recorded; once the generator finishes, the cost is recorded
      via ``record_api_cost`` and the whole call is inserted into the
      ``stream_chat`` collection.

    The decorated function accepts an extra keyword argument ``use_cache``
    (default ``True``) that is never forwarded to the wrapped function.

    Returns:
        A decorator taking the generator function to wrap.
    """
    db_name = MONOGODB_DB_NAME
    collection_name = 'stream_chat'

    def dummy_decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Strip use_cache so it is not passed on to the wrapped function.
            kwargs.pop('use_cache', None)
            return func(*args, **kwargs)
        return wrapper

    if not ENABLE_MONOGODB:
        return dummy_decorator

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            check_cost_limits()

            # pop (not get): use_cache must reach neither func nor the cache key.
            use_cache = kwargs.pop('use_cache', True)
            if not ENABLE_MONOGODB_CACHE:
                use_cache = False

            db = client[db_name]
            collection = db[collection_name]

            cache_key = create_cache_key(func.__name__, args, kwargs)

            if use_cache:
                # $sample picks one record at random when several recordings
                # exist for the same key.
                cached_data = list(collection.aggregate([
                    {'$match': {'cache_key': cache_key}},
                    {'$sample': {'size': 1}}
                ]))
                cached_data = cached_data[0] if cached_data else None
                if cached_data:
                    # Cache hit: replay the recorded stream.
                    messages = ChatMessages(cached_data['return_value'])
                    # NOTE(review): assumes the first positional argument is a
                    # mapping with a 'model' key - confirm against callers.
                    messages.model = args[0]['model']
                    for item in cached_data['yields']:
                        # Apply the replay speed-up factor and the delay cap.
                        scaled_delay = min(item['delay'] / CACHE_REPLAY_SPEED, CACHE_REPLAY_MAX_DELAY)
                        if scaled_delay <= 0:
                            # Chunks with no effective delay are not replayed.
                            continue
                        time.sleep(scaled_delay)
                        if item['index'] > 0:
                            yield messages.prompt_messages + [{'role': 'assistant', 'content': messages.response[:item['index']]}]
                        else:
                            yield messages.prompt_messages
                    messages.finished = True
                    yield messages
                    return messages

            # Cache miss (or cache disabled): run the wrapped generator and
            # record what it produces.
            yields_data = []
            # monotonic() is immune to system clock adjustments, which matters
            # because only differences between readings are used.
            last_time = time.monotonic()
            generator = func(*args, **kwargs)
            while True:
                # Keep the try body minimal so that only exhaustion of the
                # wrapped generator is interpreted as completion.
                try:
                    value = next(generator)
                except StopIteration as stop:
                    return_value = stop.value
                    break
                # Read the clock AFTER the chunk arrives so that the recorded
                # delay is the wait for this chunk, not for the previous one.
                current_time = time.monotonic()
                yields_data.append({
                    'index': len(value.response),
                    'delay': current_time - last_time
                })
                last_time = current_time
                yield value

            # Record the API cost of the completed call.
            record_api_cost(return_value)

            # Persist the recording to MongoDB.
            cache_data = {
                'created_at': datetime.datetime.now(),
                'return_value': return_value,
                'func_name': func.__name__,
                'args': args,
                'kwargs': kwargs,
                'yields': yields_data,
                'cache_key': cache_key,
            }
            collection.insert_one(cache_data)
            return return_value
        return wrapper
    return decorator
|