long / core /frontend_setting.py
deeme's picture
Upload 111 files
217acfe verified
import gradio as gr
from enum import Enum, auto
from llm_api import ModelConfig, wenxin_model_config, doubao_model_config, gpt_model_config, zhipuai_model_config, test_stream_chat
from config import API_SETTINGS, RENDER_SETTING_API_TEST_BTN, ENABLE_SETTING_SELECT_SUB_MODEL
class Provider:
GPT = "GPT(OpenAI)"
WENXIN = "文心(百度)"
DOUBAO = "豆包(字节跳动)"
ZHIPUAI = "GLM(智谱)"
OTHERS = '其他'
def deep_update(d, u):
"""Recursively update dictionary d with values from dictionary u"""
for k, v in u.items():
if isinstance(v, dict) and k in d and isinstance(d[k], dict):
deep_update(d[k], v)
else:
d[k] = v
def new_setting():
model_config = API_SETTINGS.pop('model')
sub_model_config = API_SETTINGS.pop('sub_model')
new_setting = dict(
model=ModelConfig(**model_config),
sub_model=ModelConfig(**sub_model_config),
render_count=0,
provider_name=Provider.GPT,
wenxin={
'ak': '',
'sk': '',
'default_model': 'ERNIE-Novel-8K',
'default_sub_model': 'ERNIE-3.5-8K',
'available_models': list(wenxin_model_config.keys())
},
doubao={
'api_key': '',
'main_endpoint_id': '',
'sub_endpoint_id': '',
'default_model': 'doubao-pro-32k',
'default_sub_model': 'doubao-lite-32k',
'available_models': list(doubao_model_config.keys())
},
gpt={
'api_key': '',
'base_url': '',
'proxies': '',
'default_model': 'gpt-4o',
'default_sub_model': 'gpt-4o-mini',
'available_models': list(gpt_model_config.keys())
},
zhipuai={
'api_key': '',
'default_model': 'glm-4-plus',
'default_sub_model': 'glm-4-flashx',
'available_models': list(zhipuai_model_config.keys())
},
others={
'api_key': '',
'base_url': '',
'default_model': '',
'default_sub_model': '',
'available_models': []
}
)
deep_update(new_setting, API_SETTINGS)
return new_setting
# @gr.render(inputs=setting_state)
def render_setting(setting, setting_state):
with gr.Accordion("API 设置"):
with gr.Row():
provider_name = gr.Dropdown(
choices=[Provider.GPT, Provider.WENXIN, Provider.DOUBAO, Provider.ZHIPUAI, Provider.OTHERS],
value=setting['provider_name'],
label="模型提供商",
scale=1
)
def on_select_provider(provider_name):
setting['provider_name'] = provider_name
return setting
provider_name.select(fn=on_select_provider, inputs=provider_name, outputs=[setting_state])
match setting['provider_name']:
case Provider.WENXIN:
provider_config = setting['wenxin']
case Provider.DOUBAO:
provider_config = setting['doubao']
case Provider.GPT:
provider_config = setting['gpt']
case Provider.ZHIPUAI:
provider_config = setting['zhipuai']
case Provider.OTHERS:
provider_config = setting['others']
main_model = gr.Dropdown(
choices=provider_config['available_models'],
value=provider_config['default_model'],
label="主模型",
scale=1,
allow_custom_value=setting['provider_name'] == Provider.OTHERS
)
sub_model = gr.Dropdown(
choices=provider_config['available_models'],
value=provider_config['default_sub_model'],
label="辅助模型",
scale=1,
allow_custom_value=setting['provider_name'] == Provider.OTHERS,
interactive=ENABLE_SETTING_SELECT_SUB_MODEL
)
with gr.Row():
if setting['provider_name'] == Provider.WENXIN:
baidu_access_key = gr.Textbox(
value=provider_config['ak'],
label='Baidu Access Key',
lines=1,
placeholder='Enter your Baidu access key here',
interactive=True,
scale=10,
type='password'
)
baidu_secret_key = gr.Textbox(
value=provider_config['sk'],
label='Baidu Secret Key',
lines=1,
placeholder='Enter your Baidu secret key here',
interactive=True,
scale=10,
type='password'
)
elif setting['provider_name'] == Provider.DOUBAO:
doubao_api_key = gr.Textbox(
value=provider_config['api_key'],
label='Doubao API Key',
lines=1,
placeholder='Enter your Doubao API key here',
interactive=True,
scale=10,
type='password'
)
main_endpoint_id = gr.Textbox(
value=provider_config['main_endpoint_id'],
label='Main Endpoint ID',
lines=1,
placeholder='Enter your main endpoint ID here',
interactive=True,
scale=10,
type='password'
)
sub_endpoint_id = gr.Textbox(
value=provider_config['sub_endpoint_id'],
label='Sub Endpoint ID',
lines=1,
placeholder='Enter your sub endpoint ID here',
interactive=True,
scale=10,
type='password'
)
elif setting['provider_name'] in [Provider.GPT, Provider.OTHERS]:
gpt_api_key = gr.Textbox(
value=provider_config['api_key'],
label='OpenAI API Key',
lines=1,
placeholder='Enter your OpenAI API key here',
interactive=True,
scale=10,
type='password'
)
base_url = gr.Textbox(
value=provider_config['base_url'],
label='API Base URL',
lines=1,
placeholder='Enter API base URL here',
interactive=True,
scale=10,
type='password'
)
elif setting['provider_name'] == Provider.ZHIPUAI:
zhipuai_api_key = gr.Textbox(
value=provider_config['api_key'],
label='ZhipuAI API Key',
lines=1,
placeholder='Enter your ZhipuAI API key here',
interactive=True,
scale=10,
type='password'
)
with gr.Row():
if setting['provider_name'] == Provider.WENXIN:
def on_submit(main_model, sub_model, baidu_access_key, baidu_secret_key):
provider_config['ak'] = baidu_access_key
provider_config['sk'] = baidu_secret_key
setting['model'] = ModelConfig(
model=main_model,
ak=baidu_access_key,
sk=baidu_secret_key,
max_tokens=4096
)
setting['sub_model'] = ModelConfig(
model=sub_model,
ak=baidu_access_key,
sk=baidu_secret_key,
max_tokens=4096
)
submit_event = dict(
fn=on_submit,
inputs=[main_model, sub_model, baidu_access_key, baidu_secret_key],
)
on_submit(main_model.value, sub_model.value, baidu_access_key.value, baidu_secret_key.value)
main_model.change(**submit_event)
sub_model.change(**submit_event)
baidu_access_key.change(**submit_event)
baidu_secret_key.change(**submit_event)
elif setting['provider_name'] == Provider.DOUBAO:
def on_submit(main_model, sub_model, doubao_api_key, main_endpoint_id, sub_endpoint_id):
provider_config['api_key'] = doubao_api_key
provider_config['main_endpoint_id'] = main_endpoint_id
provider_config['sub_endpoint_id'] = sub_endpoint_id
setting['model'] = ModelConfig(
model=main_model,
api_key=doubao_api_key,
endpoint_id=main_endpoint_id,
max_tokens=4096
)
setting['sub_model'] = ModelConfig(
model=sub_model,
api_key=doubao_api_key,
endpoint_id=sub_endpoint_id,
max_tokens=4096
)
submit_event = dict(
fn=on_submit,
inputs=[main_model, sub_model, doubao_api_key, main_endpoint_id, sub_endpoint_id],
)
on_submit(main_model.value, sub_model.value, doubao_api_key.value, main_endpoint_id.value, sub_endpoint_id.value)
main_model.change(**submit_event)
sub_model.change(**submit_event)
doubao_api_key.change(**submit_event)
main_endpoint_id.change(**submit_event)
sub_endpoint_id.change(**submit_event)
elif setting['provider_name'] in [Provider.GPT, Provider.OTHERS]:
def on_submit(main_model, sub_model, gpt_api_key, base_url):
provider_config['api_key'] = gpt_api_key
provider_config['base_url'] = base_url.strip()
setting['model'] = ModelConfig(
model=main_model,
api_key=provider_config['api_key'],
base_url=provider_config['base_url'],
max_tokens=4096,
proxies=provider_config.get('proxies', None),
)
setting['sub_model'] = ModelConfig(
model=sub_model,
api_key=provider_config['api_key'],
base_url=provider_config['base_url'],
max_tokens=4096,
proxies=provider_config.get('proxies', None),
)
submit_event = dict(
fn=on_submit,
inputs=[main_model, sub_model, gpt_api_key, base_url],
)
on_submit(main_model.value, sub_model.value, gpt_api_key.value, base_url.value)
main_model.change(**submit_event)
sub_model.change(**submit_event)
gpt_api_key.change(**submit_event)
base_url.change(**submit_event)
elif setting['provider_name'] == Provider.ZHIPUAI:
def on_submit(main_model, sub_model, zhipuai_api_key):
provider_config['api_key'] = zhipuai_api_key
setting['model'] = ModelConfig(
model=main_model,
api_key=zhipuai_api_key,
max_tokens=4096
)
setting['sub_model'] = ModelConfig(
model=sub_model,
api_key=zhipuai_api_key,
max_tokens=4096
)
submit_event = dict(
fn=on_submit,
inputs=[main_model, sub_model, zhipuai_api_key],
)
on_submit(main_model.value, sub_model.value, zhipuai_api_key.value)
main_model.change(**submit_event)
sub_model.change(**submit_event)
zhipuai_api_key.change(**submit_event)
if RENDER_SETTING_API_TEST_BTN:
test_btn = gr.Button("测试")
test_report = gr.Textbox(show_label=False, container=False, value='', interactive=False, scale=10)
def on_test_llm_api():
if not setting['model']['model'].strip():
return gr.Info('主模型名不能为空')
if not setting['sub_model']['model'].strip():
return gr.Info('辅助模型名不能为空')
try:
response1 = yield from test_stream_chat(setting['model'])
response2 = yield from test_stream_chat(setting['sub_model'])
report_text = f"User:1+1=?\n主模型 :{response1.response}({response1.cost_info})\n辅助模型:{response2.response}({response2.cost_info})\n测试通过!"
yield report_text
except Exception as e:
yield f"测试失败:{str(e)}"
if RENDER_SETTING_API_TEST_BTN:
test_btn.click(
on_test_llm_api,
outputs=[test_report]
)