File size: 4,443 Bytes
217acfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from typing import Dict, Any, Optional, Generator

from .mongodb_cache import llm_api_cache
from .baidu_api import stream_chat_with_wenxin, wenxin_model_config
from .doubao_api import stream_chat_with_doubao, doubao_model_config
from .chat_messages import ChatMessages
from .openai_api import stream_chat_with_gpt, gpt_model_config
from .zhipuai_api import stream_chat_with_zhipuai, zhipuai_model_config

class ModelConfig(dict):
    """Validated API settings for one chat model, stored as a plain ``dict``.

    ``model`` selects which provider's credentials are required (see
    :meth:`validate`); every other keyword argument is stored verbatim as a
    key of the dict (for example ``api_key``, ``max_tokens``, ``base_url``).
    Validation runs once at construction; keys assigned afterwards are not
    re-checked until :meth:`validate` is called again.
    """

    # Largest value accepted for ``max_tokens``; subclasses may override it.
    MAX_TOKENS_LIMIT = 4_096

    def __init__(self, model: str, **options):
        super().__init__(**options)
        self['model'] = model
        self.validate()

    def _require_keys(self, provider: str, keys: list) -> None:
        """Raise ValueError unless every key in *keys* is present and non-blank.

        Args:
            provider: human-readable provider name used in the error message.
            keys: names of the settings that must be present.
        """
        for key in keys:
            if key not in self:
                raise ValueError(f"{provider}的API设置中未传入: {key}")
            value = self[key]
            # Treat None as "not configured" instead of crashing on ``.strip()``;
            # non-string values are checked through their string form.
            if value is None or not str(value).strip():
                raise ValueError(f"{provider}的API设置中未配置: {key}")

    def validate(self) -> None:
        """Check that the provider credentials and ``max_tokens`` are usable.

        The provider is chosen by looking ``self['model']`` up in each
        provider's model config; any model name not found there is treated as
        an OpenAI-compatible model, which only requires ``api_key``.

        Raises:
            ValueError: a required credential is missing or blank, or
                ``max_tokens`` was not supplied.
            AssertionError: ``max_tokens`` exceeds :attr:`MAX_TOKENS_LIMIT`.
        """
        model = self['model']
        if model in wenxin_model_config:
            self._require_keys('文心一言', ['ak', 'sk'])
        elif model in doubao_model_config:
            self._require_keys('豆包', ['api_key', 'endpoint_id'])
        elif model in zhipuai_model_config:
            self._require_keys('智谱AI', ['api_key'])
        else:
            # Models in gpt_model_config, and any other model name, are called
            # through the OpenAI-compatible interface.
            self._require_keys('OpenAI', ['api_key'])

        if 'max_tokens' not in self:
            raise ValueError('ModelConfig未传入key: max_tokens')
        if not self['max_tokens'] <= self.MAX_TOKENS_LIMIT:
            # Explicit raise rather than ``assert`` so the check survives
            # ``python -O``; AssertionError is kept for existing callers.
            raise AssertionError(f'max_tokens最大为{self.MAX_TOKENS_LIMIT}!')

    def get_api_keys(self) -> Dict[str, Any]:
        """Return a new dict of every setting except ``model``.

        Despite the name, this includes non-credential options such as
        ``max_tokens`` as well as the provider keys.
        """
        return {k: v for k, v in self.items() if k != 'model'}

@llm_api_cache()
def stream_chat(model_config: ModelConfig, messages: list, response_json=False) -> Generator:
    """Stream a chat completion from the provider that serves ``model_config['model']``.

    Yields, in order:
      1. the request wrapped as a ``ChatMessages`` object, before any provider call;
      2. every item produced by the provider's streaming function;
      3. the provider's final result, after setting ``result.finished = True``.
    The final result is also the generator's return value.

    The ``llm_api_cache`` wrapper is defined elsewhere; how it treats these
    yields (and extra keywords such as ``use_cache``) is not visible here.

    Args:
        model_config: a ``ModelConfig`` or a plain mapping with the same keys
            (``model``, ``max_tokens`` and the provider's credentials).
        messages: chat history, wrapped in ``ChatMessages``.
        response_json: forwarded unchanged to the provider function.

    Raises:
        ValueError, AssertionError: from ``ModelConfig`` validation.
        Exception: the prompt's token count exceeds ``max_tokens``.
    """
    # ModelConfig subclasses dict, so test for ModelConfig itself: only plain
    # mappings need converting (construction validates them); existing
    # instances are re-validated in case they were modified afterwards.
    if isinstance(model_config, ModelConfig):
        model_config.validate()
    else:
        model_config = ModelConfig(**model_config)

    model = model_config['model']
    max_tokens = model_config['max_tokens']
    messages = ChatMessages(messages, model=model)

    if messages.count_message_tokens() > max_tokens:
        raise Exception(f'请求的文本过长,超过最大tokens:{model_config["max_tokens"]}。')

    yield messages

    if model in wenxin_model_config:
        result = yield from stream_chat_with_wenxin(
            messages,
            model=model,
            ak=model_config['ak'],
            sk=model_config['sk'],
            max_tokens=max_tokens,
            response_json=response_json,
        )
    elif model in doubao_model_config:
        result = yield from stream_chat_with_doubao(
            messages,
            model=model,
            endpoint_id=model_config['endpoint_id'],
            api_key=model_config['api_key'],
            max_tokens=max_tokens,
            response_json=response_json,
        )
    elif model in zhipuai_model_config:
        result = yield from stream_chat_with_zhipuai(
            messages,
            model=model,
            api_key=model_config['api_key'],
            max_tokens=max_tokens,
            response_json=response_json,
        )
    else:
        # Models in gpt_model_config and any other OpenAI-compatible model.
        result = yield from stream_chat_with_gpt(
            messages,
            model=model,
            api_key=model_config['api_key'],
            base_url=model_config.get('base_url'),
            proxies=model_config.get('proxies'),
            max_tokens=max_tokens,
            response_json=response_json,
        )

    result.finished = True
    yield result

    return result

def test_stream_chat(model_config: ModelConfig):
    """Smoke-test *model_config* by asking the model a one-line arithmetic question.

    Calls ``stream_chat`` with ``use_cache=False`` and re-yields the
    ``.response`` attribute of each streamed item. The generator's return
    value is the last streamed item; if nothing is streamed, returning it
    raises UnboundLocalError.
    """
    prompt = [{"role": "user", "content": "1+1=?直接输出答案即可:"}]
    stream = stream_chat(model_config, prompt, use_cache=False)
    for last in stream:
        yield last.response

    return last

# Public API of this module: the message container, the streaming entry
# point, the per-provider model configs, and the config class.
__all__ = [
    'ChatMessages',
    'stream_chat',
    'wenxin_model_config',
    'doubao_model_config',
    'gpt_model_config',
    'zhipuai_model_config',
    'ModelConfig',
]