|
|
from typing import Dict, Any, Optional, Generator
|
|
|
|
|
|
from .mongodb_cache import llm_api_cache
|
|
|
from .baidu_api import stream_chat_with_wenxin, wenxin_model_config
|
|
|
from .doubao_api import stream_chat_with_doubao, doubao_model_config
|
|
|
from .chat_messages import ChatMessages
|
|
|
from .openai_api import stream_chat_with_gpt, gpt_model_config
|
|
|
from .zhipuai_api import stream_chat_with_zhipuai, zhipuai_model_config
|
|
|
|
|
|
class ModelConfig(dict):
|
|
|
def __init__(self, model: str, **options):
|
|
|
super().__init__(**options)
|
|
|
self['model'] = model
|
|
|
self.validate()
|
|
|
|
|
|
def validate(self):
|
|
|
def check_key(provider, keys):
|
|
|
for key in keys:
|
|
|
if key not in self:
|
|
|
raise ValueError(f"{provider}的API设置中未传入: {key}")
|
|
|
elif not self[key].strip():
|
|
|
raise ValueError(f"{provider}的API设置中未配置: {key}")
|
|
|
|
|
|
if self['model'] in wenxin_model_config:
|
|
|
check_key('文心一言', ['ak', 'sk'])
|
|
|
elif self['model'] in doubao_model_config:
|
|
|
check_key('豆包', ['api_key', 'endpoint_id'])
|
|
|
elif self['model'] in zhipuai_model_config:
|
|
|
check_key('智谱AI', ['api_key'])
|
|
|
elif self['model'] in gpt_model_config or True:
|
|
|
|
|
|
check_key('OpenAI', ['api_key'])
|
|
|
|
|
|
if 'max_tokens' not in self:
|
|
|
raise ValueError('ModelConfig未传入key: max_tokens')
|
|
|
else:
|
|
|
assert self['max_tokens'] <= 4_096, 'max_tokens最大为4096!'
|
|
|
|
|
|
|
|
|
def get_api_keys(self) -> Dict[str, str]:
|
|
|
return {k: v for k, v in self.items() if k not in ['model']}
|
|
|
|
|
|
@llm_api_cache()
|
|
|
def stream_chat(model_config: ModelConfig, messages: list, response_json=False) -> Generator:
|
|
|
if isinstance(model_config, dict):
|
|
|
model_config = ModelConfig(**model_config)
|
|
|
|
|
|
model_config.validate()
|
|
|
|
|
|
messages = ChatMessages(messages, model=model_config['model'])
|
|
|
|
|
|
assert model_config['max_tokens'] <= 4096, 'max_tokens最大为4096!'
|
|
|
|
|
|
if messages.count_message_tokens() > model_config['max_tokens']:
|
|
|
raise Exception(f'请求的文本过长,超过最大tokens:{model_config["max_tokens"]}。')
|
|
|
|
|
|
yield messages
|
|
|
|
|
|
if model_config['model'] in wenxin_model_config:
|
|
|
result = yield from stream_chat_with_wenxin(
|
|
|
messages,
|
|
|
model=model_config['model'],
|
|
|
ak=model_config['ak'],
|
|
|
sk=model_config['sk'],
|
|
|
max_tokens=model_config['max_tokens'],
|
|
|
response_json=response_json
|
|
|
)
|
|
|
elif model_config['model'] in doubao_model_config:
|
|
|
result = yield from stream_chat_with_doubao(
|
|
|
messages,
|
|
|
model=model_config['model'],
|
|
|
endpoint_id=model_config['endpoint_id'],
|
|
|
api_key=model_config['api_key'],
|
|
|
max_tokens=model_config['max_tokens'],
|
|
|
response_json=response_json
|
|
|
)
|
|
|
elif model_config['model'] in zhipuai_model_config:
|
|
|
result = yield from stream_chat_with_zhipuai(
|
|
|
messages,
|
|
|
model=model_config['model'],
|
|
|
api_key=model_config['api_key'],
|
|
|
max_tokens=model_config['max_tokens'],
|
|
|
response_json=response_json
|
|
|
)
|
|
|
elif model_config['model'] in gpt_model_config or True:
|
|
|
result = yield from stream_chat_with_gpt(
|
|
|
messages,
|
|
|
model=model_config['model'],
|
|
|
api_key=model_config['api_key'],
|
|
|
base_url=model_config.get('base_url'),
|
|
|
proxies=model_config.get('proxies'),
|
|
|
max_tokens=model_config['max_tokens'],
|
|
|
response_json=response_json
|
|
|
)
|
|
|
|
|
|
result.finished = True
|
|
|
yield result
|
|
|
|
|
|
return result
|
|
|
|
|
|
def test_stream_chat(model_config: ModelConfig):
|
|
|
messages = [{"role": "user", "content": "1+1=?直接输出答案即可:"}]
|
|
|
for response in stream_chat(model_config, messages, use_cache=False):
|
|
|
yield response.response
|
|
|
|
|
|
return response
|
|
|
|
|
|
|
|
|
__all__ = ['ChatMessages', 'stream_chat', 'wenxin_model_config', 'doubao_model_config', 'gpt_model_config', 'zhipuai_model_config', 'ModelConfig']
|
|
|
|