|
|
import time
|
|
|
import functools
|
|
|
from typing import Generator, Any
|
|
|
from pymongo import MongoClient
|
|
|
import hashlib
|
|
|
import json
|
|
|
import datetime
|
|
|
import random
|
|
|
|
|
|
from config import ENABLE_MONOGODB, MONOGODB_DB_NAME, ENABLE_MONOGODB_CACHE, CACHE_REPLAY_SPEED, CACHE_REPLAY_MAX_DELAY
|
|
|
|
|
|
from .chat_messages import ChatMessages
|
|
|
from .mongodb_cost import record_api_cost, check_cost_limits
|
|
|
from .mongodb_init import mongo_client as client
|
|
|
|
|
|
def create_cache_key(func_name: str, args: tuple, kwargs: dict) -> str:
    """Build a deterministic cache key for a function call.

    The function name, positional arguments and keyword arguments are
    serialised together as JSON with sorted keys (so keyword order does not
    affect the result), and the MD5 hex digest of that JSON text is returned.

    Args:
        func_name: Name of the called function.
        args: Positional arguments of the call (must be JSON-serialisable).
        kwargs: Keyword arguments of the call (must be JSON-serialisable).

    Returns:
        A 32-character hexadecimal MD5 digest.

    Raises:
        TypeError: If any argument cannot be serialised by ``json.dumps``.
    """
    payload = json.dumps(
        {'kwargs': kwargs, 'args': args, 'func_name': func_name},
        sort_keys=True,
    )
    digest = hashlib.md5(payload.encode())
    return digest.hexdigest()
|
|
|
|
|
|
|
|
|
|
|
|
def _replay_cached_stream(cached_data, model):
    """Replay one cached streaming recording as a generator.

    Rebuilds a ``ChatMessages`` from the stored ``return_value`` and re-emits
    the recorded intermediate states. Before each state it sleeps for the
    recorded delay divided by ``CACHE_REPLAY_SPEED``, capped at
    ``CACHE_REPLAY_MAX_DELAY``.

    Args:
        cached_data: A document previously written by ``llm_api_cache``
            (reads its ``return_value`` and ``yields`` fields).
        model: Value assigned to ``messages.model`` on the rebuilt object.

    Yields:
        For each replayed chunk, the prompt messages plus a partial assistant
        message truncated to the recorded response length (or just the prompt
        messages when that length is 0). Finally, the complete
        ``ChatMessages`` with ``finished`` set to True.

    Returns:
        The rebuilt ``ChatMessages`` (as the generator's return value).
    """
    messages = ChatMessages(cached_data['return_value'])
    messages.model = model

    for item in cached_data['yields']:
        scaled_delay = min(item['delay'] / CACHE_REPLAY_SPEED, CACHE_REPLAY_MAX_DELAY)
        if scaled_delay <= 0:
            # Chunks with no (scaled) delay are dropped instead of being
            # emitted back-to-back; the final complete message yielded
            # below still contains the full response.
            continue
        time.sleep(scaled_delay)

        # 'index' is len(response) at the moment the chunk was recorded.
        if item['index'] > 0:
            partial = messages.response[:item['index']]
            yield messages.prompt_messages + [{'role': 'assistant', 'content': partial}]
        else:
            yield messages.prompt_messages

    messages.finished = True
    yield messages
    return messages


def llm_api_cache(db_name=None, collection_name='stream_chat'):
    """Decorator factory that records and replays streaming LLM calls in MongoDB.

    The decorated function must be a generator. Each yielded value must expose
    a ``response`` attribute, and the generator's return value is passed to
    ``record_api_cost`` and stored in MongoDB.

    The returned decorator behaves as follows:

    * If ``ENABLE_MONOGODB`` is false, calls pass straight through to the
      wrapped function, and the ``use_cache`` keyword is discarded.
    * Otherwise, each call first runs ``check_cost_limits()``.
    * When caching is requested (``use_cache=True``, the default) and
      ``ENABLE_MONOGODB_CACHE`` is enabled, a stored recording with the same
      cache key is picked at random and replayed.
    * Otherwise, the wrapped generator runs live and each yielded value is
      re-yielded. The final result is passed to ``record_api_cost`` and
      inserted into the collection. This recording happens even when
      ``use_cache=False``.

    Args:
        db_name: MongoDB database name. ``None`` (the default) means
            ``MONOGODB_DB_NAME``.
        collection_name: Name of the collection holding the recordings.

    Returns:
        A decorator for generator functions.
    """
    def dummy_decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Discard the cache flag so the wrapped function never sees it.
            kwargs.pop('use_cache', None)
            return func(*args, **kwargs)

        return wrapper

    if not ENABLE_MONOGODB:
        return dummy_decorator

    if db_name is None:
        db_name = MONOGODB_DB_NAME

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            check_cost_limits()

            use_cache = kwargs.pop('use_cache', True)
            if not ENABLE_MONOGODB_CACHE:
                use_cache = False

            collection = client[db_name][collection_name]
            # Computed after 'use_cache' is popped, so the flag never affects the key.
            cache_key = create_cache_key(func.__name__, args, kwargs)

            if use_cache:
                # Every live call inserts a recording, so several may share a
                # key; pick one of them at random.
                hits = list(collection.aggregate([
                    {'$match': {'cache_key': cache_key}},
                    {'$sample': {'size': 1}},
                ]))
                if hits:
                    # NOTE(review): presumably args[0] is the request dict carrying 'model' - confirm.
                    return (yield from _replay_cached_stream(hits[0], args[0]['model']))

            yields_data = []
            generator = func(*args, **kwargs)
            try:
                while True:
                    started = time.monotonic()
                    try:
                        value = next(generator)
                    except StopIteration as stop:
                        return_value = stop.value
                        break
                    # Record only how long this chunk took to arrive; time the
                    # consumer spends between chunks is deliberately excluded.
                    yields_data.append({
                        'index': len(value.response),
                        'delay': time.monotonic() - started,
                    })
                    yield value
            finally:
                # Release the underlying stream promptly if the consumer stops
                # early (no-op when the generator is already exhausted).
                generator.close()

            record_api_cost(return_value)

            collection.insert_one({
                'created_at': datetime.datetime.now(),
                'return_value': return_value,
                'func_name': func.__name__,
                'args': args,
                'kwargs': kwargs,
                'yields': yields_data,
                'cache_key': cache_key,
            })

            return return_value

        return wrapper

    return decorator
|
|
|
|