File size: 4,174 Bytes
217acfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import hashlib
import re
import json
import os

def count_characters(text):
    """Count the characters of ``text`` per script class.

    Returns a 3-tuple ``(chinese, english, other)``:
      * ``chinese`` - characters in the range U+4E00..U+9FFF,
      * ``english`` - ASCII letters ``a-z`` / ``A-Z``,
      * ``other``   - every remaining character (digits, whitespace,
        punctuation, other scripts).
    Each character belongs to exactly one class, so the three counts add
    up to ``len(text)``.
    """
    category_patterns = (
        re.compile(r'[\u4e00-\u9fff]+'),
        re.compile(r'[a-zA-Z]+'),
        re.compile(r'[^\u4e00-\u9fffa-zA-Z]+'),
    )
    # Total length of all maximal runs matched by each class pattern.
    return tuple(
        sum(map(len, pattern.findall(text))) for pattern in category_patterns
    )


# Model metadata registry (model name -> config dict). It starts empty and is
# filled lazily by ChatMessages.__init__ from the provider API modules.
model_config = {}


# Optional per-token price table read from model_prices.json next to this
# file. Loading is best-effort: on any failure a warning is printed and the
# table stays empty.
model_prices = {}
try:
    model_prices_path = os.path.join(os.path.dirname(__file__), 'model_prices.json')
    # JSON exchanged between systems is UTF-8 (RFC 8259); do not depend on the
    # platform's locale encoding (e.g. GBK/cp936 on Chinese Windows).
    with open(model_prices_path, 'r', encoding='utf-8') as f:
        loaded_prices = json.load(f)
    # Callers look prices up by model name, so the top level must be an object.
    if not isinstance(loaded_prices, dict):
        raise TypeError(f"expected a JSON object, got {type(loaded_prices).__name__}")
    model_prices = loaded_prices
except Exception as e:
    print(f"Warning: Failed to load model_prices.json: {e}")

class ChatMessages(list):
    """A ``list`` of chat-message dicts bound to a model name, with cost helpers.

    Messages are dicts; this class reads their ``'role'`` and ``'content'``
    keys. Slicing, ``+`` with a list, and :meth:`copy` return new
    ``ChatMessages`` bound to the same ``model``.

    Attributes:
        model: Model name used to look up prices in ``model_config`` (priced
            per 1,000 tokens) or ``model_prices`` (priced per token);
            ``None`` if not given.
        finished: Initialised to ``False``; not read by this class
            (presumably set by callers when generation ends - confirm).
    """

    def __init__(self, *args, **kwargs):
        """Build the list from ``*args`` like ``list(...)``; ``model=`` is optional.

        On the first instantiation, the shared module-level ``model_config``
        is filled from the provider modules. They are imported lazily here,
        presumably to avoid import-time cycles or cost - confirm.
        """
        super().__init__(*args)
        self.model = kwargs.get('model')
        self.finished = False

        # currency_symbol is a read-only property derived from model_config,
        # so callers must not pass it in.
        assert 'currency_symbol' not in kwargs

        if not model_config:
            from .baidu_api import wenxin_model_config
            from .doubao_api import doubao_model_config
            from .openai_api import gpt_model_config
            from .zhipuai_api import zhipuai_model_config
            # Later dicts win if two providers define the same model name.
            model_config.update({**wenxin_model_config, **doubao_model_config, **gpt_model_config, **zhipuai_model_config})

    def __getitem__(self, index):
        """Index like a list; slices come back as ``ChatMessages`` with the same model."""
        result = super().__getitem__(index)
        if isinstance(index, slice):
            return ChatMessages(result, model=self.model)
        return result

    def __add__(self, other):
        """Concatenate with any list, keeping this instance's model."""
        if isinstance(other, list):
            return ChatMessages(super().__add__(other), model=self.model)
        return NotImplemented

    def count_message_tokens(self):
        """Alias for :meth:`get_estimated_tokens`."""
        return self.get_estimated_tokens()

    def copy(self):
        """Shallow copy (message dicts are shared) bound to the same model."""
        return ChatMessages(self, model=self.model)

    def get_estimated_tokens(self):
        """Roughly estimate the number of tokens in all messages.

        Every value of every message dict is counted, including ``'role'``.
        For each string value: 2 characters in U+4E00..U+9FFF, 5 ASCII
        letters, or 2 other characters each make one token (integer
        division, per value).

        ``None`` values (e.g. ``content=None`` on tool-call messages) are
        skipped. Other non-string values are estimated from their ``str()``
        form instead of raising ``TypeError``.
        """
        num_tokens = 0
        for message in self:
            for value in message.values():
                if value is None:
                    continue
                if not isinstance(value, str):
                    value = str(value)
                chinese_count, english_count, other_count = count_characters(value)
                num_tokens += chinese_count // 2 + english_count // 5 + other_count // 2
        return num_tokens

    def get_prompt_messages_hash(self):
        """Return the MD5 hex digest of :attr:`prompt_messages` (the response is excluded)."""
        # Serialize to JSON and hash it. sort_keys makes the digest
        # independent of dict key order.
        cache_string = json.dumps(self.prompt_messages, sort_keys=True)
        return hashlib.md5(cache_string.encode()).hexdigest()

    def _prompt_completion_tokens(self):
        """Return ``(prompt_tokens, completion_tokens)``, split at :attr:`prompt_messages`."""
        prompt = self.prompt_messages
        # Empty when there is no assistant reply, so nothing is billed as output.
        completion = self[len(prompt):]
        return prompt.count_message_tokens(), completion.count_message_tokens()

    @property
    def cost(self):
        """Estimated cost of this conversation for ``self.model``.

        The prompt messages are billed at the input rate. A trailing
        (non-empty) assistant reply is billed at the output rate. When there
        is no such reply, every message counts as input.

        Pricing sources:
          * ``model_config[model]["Pricing"]`` is ``(input, output)`` per 1,000 tokens.
          * ``model_prices[model]`` holds ``input_cost_per_token`` and
            ``output_cost_per_token``.

        Returns 0 for an empty conversation or an unknown model.
        """
        if not self:
            return 0

        if self.model in model_config:
            pricing = model_config[self.model]["Pricing"]
            input_tokens, output_tokens = self._prompt_completion_tokens()
            return pricing[0] * input_tokens / 1_000 + pricing[1] * output_tokens / 1_000
        elif self.model in model_prices:
            prices = model_prices[self.model]
            input_tokens, output_tokens = self._prompt_completion_tokens()
            return (
                prices["input_cost_per_token"] * input_tokens +
                prices["output_cost_per_token"] * output_tokens
            )
        return 0

    @property
    def response(self):
        """Content of the last message if its role is ``'assistant'``, else ``''``.

        An empty conversation also yields ``''`` (rather than ``IndexError``).
        """
        if not self or self[-1]['role'] != 'assistant':
            return ''
        return self[-1]['content']

    @property
    def prompt_messages(self):
        """All messages except a trailing assistant reply with truthy content."""
        return self[:-1] if self.response else self

    @property
    def currency_symbol(self):
        """Currency symbol from ``model_config`` for this model; ``'$'`` otherwise."""
        if self.model in model_config:
            return model_config[self.model].get("currency_symbol", '$')
        else:
            return '$'

    @property
    def cost_info(self):
        """Human-readable ``"<model>: <cost><symbol>"`` (at most 7 decimals, trailing zeros trimmed)."""
        formatted_cost = f"{self.cost:.7f}".rstrip('0').rstrip('.')
        return f"{self.model}: {formatted_cost}{self.currency_symbol}"

    def print(self):
        """Print each message: its role centred in a 100-char dashed line, then its content."""
        for message in self:
            print(f"{message['role']}".center(100, '-') + '\n')
            print(message['content'])
            print()