File size: 4,575 Bytes
217acfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import time
import functools
from typing import Generator, Any
from pymongo import MongoClient
import hashlib
import json
import datetime
import random

from config import ENABLE_MONOGODB, MONOGODB_DB_NAME, ENABLE_MONOGODB_CACHE, CACHE_REPLAY_SPEED, CACHE_REPLAY_MAX_DELAY

from .chat_messages import ChatMessages
from .mongodb_cost import record_api_cost, check_cost_limits
from .mongodb_init import mongo_client as client

def create_cache_key(func_name: str, args: tuple, kwargs: dict) -> str:
    """Return an MD5 hex digest that identifies a call by name and arguments.

    The function name, positional arguments and keyword arguments are
    serialised to a JSON string with sorted dictionary keys, so two calls
    whose keyword arguments differ only in ordering get the same key.

    Args:
        func_name: Name of the function being called.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.

    Returns:
        The 32-character hexadecimal MD5 digest of the serialised call.

    Raises:
        TypeError: If any argument cannot be serialised by ``json.dumps``.
    """
    # Sorted keys keep the serialised form independent of dict ordering.
    payload = json.dumps(
        {'func_name': func_name, 'args': args, 'kwargs': kwargs},
        sort_keys=True,
    )
    digest = hashlib.md5(payload.encode())
    return digest.hexdigest()



def llm_api_cache():
    """Build a decorator that caches a streaming generator's output in MongoDB.

    The decorated function is expected to be a generator: each value it yields
    must expose a ``response`` attribute supporting ``len()``, and its final
    result is delivered through the generator's return value.

    Behaviour of the returned decorator:

    * If ``ENABLE_MONOGODB`` is falsy, the wrapper only strips the
      ``use_cache`` keyword argument and calls the function directly.
    * Otherwise the wrapper is itself a generator. It calls
      ``check_cost_limits()`` first, then:

      - on a cache hit (``use_cache`` true and ``ENABLE_MONOGODB_CACHE``
        truthy), replays one randomly sampled cached record, sleeping between
        frames for the recorded delay divided by ``CACHE_REPLAY_SPEED`` and
        capped at ``CACHE_REPLAY_MAX_DELAY``;
      - on a miss, runs the wrapped generator, records the length of
        ``value.response`` and the elapsed time for every yielded value, calls
        ``record_api_cost()`` on the final result and inserts a new record
        into the ``stream_chat`` collection.

    The wrapper accepts an extra keyword argument ``use_cache`` (default
    ``True``) which is never forwarded to the wrapped function.

    Returns:
        A decorator to apply to the generator function.
    """
    db_name = MONOGODB_DB_NAME
    collection_name = 'stream_chat'

    def dummy_decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Drop use_cache so it is not forwarded to the wrapped function.
            kwargs.pop('use_cache', None)
            return func(*args, **kwargs)
        return wrapper

    if not ENABLE_MONOGODB:
        return dummy_decorator

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            check_cost_limits()

            # pop (not get): use_cache must reach neither the wrapped function
            # nor the cache key computed below.
            use_cache = kwargs.pop('use_cache', True)

            if not ENABLE_MONOGODB_CACHE:
                use_cache = False

            db = client[db_name]
            collection = db[collection_name]

            # Build the cache key from the function name and call arguments.
            cache_key = create_cache_key(func.__name__, args, kwargs)

            # Look for a cached record.
            if use_cache:
                # $sample picks one record at random among those sharing the key.
                cached_data = list(collection.aggregate([
                    {'$match': {'cache_key': cache_key}},
                    {'$sample': {'size': 1}}
                ]))
                cached_data = cached_data[0] if cached_data else None
                if cached_data:
                    # Cache hit: replay the recorded stream.
                    messages = ChatMessages(cached_data['return_value'])
                    # NOTE(review): assumes the first positional argument is a
                    # mapping with a 'model' entry -- confirm against callers.
                    messages.model = args[0]['model']
                    for item in cached_data['yields']:
                        # Apply the replay speed-up factor, then cap the wait.
                        scaled_delay = min(item['delay'] / CACHE_REPLAY_SPEED, CACHE_REPLAY_MAX_DELAY)
                        if scaled_delay > 0:
                            time.sleep(scaled_delay)
                        else:
                            # Frames without a positive delay are not replayed.
                            continue
                        if item['index'] > 0:
                            yield messages.prompt_messages + [{'role': 'assistant', 'content': messages.response[:item['index']]}]
                        else:
                            yield messages.prompt_messages
                    messages.finished = True
                    yield messages
                    return messages

            # Cache miss: run the wrapped generator and record its output.
            yields_data = []
            last_time = time.time()

            generator = func(*args, **kwargs)

            while True:
                # Only next() is guarded, so a StopIteration raised elsewhere
                # in this loop is not mistaken for generator exhaustion.
                try:
                    value = next(generator)
                except StopIteration as e:
                    return_value = e.value
                    break

                # Timestamp AFTER next() returns so the delay covers the time
                # spent waiting for this value. Stamping before next() would
                # attribute every wait to the following frame and lose the
                # time-to-first-value entirely.
                current_time = time.time()
                yields_data.append({
                    'index': len(value.response),
                    'delay': current_time - last_time
                })
                last_time = current_time
                yield value

            # Record the cost of this API call.
            record_api_cost(return_value)

            # Persist the recorded stream to MongoDB.
            cache_data = {
                'created_at': datetime.datetime.now(),
                'return_value': return_value,
                'func_name': func.__name__,
                'args': args,
                'kwargs': kwargs,
                'yields': yields_data,
                'cache_key': cache_key,
            }
            collection.insert_one(cache_data)

            return return_value

        return wrapper
    return decorator