File size: 5,547 Bytes
cb2428f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
# Copyright (c) Alibaba, Inc. and its affiliates.
from functools import partial
from typing import Literal, Optional
import gradio as gr
from swift.utils import get_file_mm_type
from ..utils import History
from .locale import locale_mapping
def clear_session():
    """Reset the conversation.

    Returns:
        A ``(textbox_value, chatbot_history, history_state)`` triple: an empty
        input string and two fresh empty lists.
    """
    blank_input = ''
    return blank_input, list(), list()
def modify_system_session(system: str):
    """Adopt a new system prompt and start a fresh conversation.

    Args:
        system: The new system prompt; a falsy value (e.g. ``None``) becomes ``''``.

    Returns:
        ``(system_state, textbox_value, chatbot_history, history_state)`` where
        everything except the system prompt is reset to empty.
    """
    new_system = system if system else ''
    cleared = ('', [], [])
    return (new_system, *cleared)
def _history_to_messages(history: History, system: Optional[str]):
messages = []
if system is not None:
messages.append({'role': 'system', 'content': system})
content = []
for h in history:
assert isinstance(h, (list, tuple))
if isinstance(h[0], tuple):
assert h[1] is None
file_path = h[0][0]
try:
mm_type = get_file_mm_type(file_path)
content.append({'type': mm_type, mm_type: file_path})
except ValueError:
with open(file_path, 'r', encoding='utf-8') as f:
content.append({'type': 'text', 'text': f.read()})
else:
content.append({'type': 'text', 'text': h[0]})
messages.append({'role': 'user', 'content': content})
if h[1] is not None:
messages.append({'role': 'assistant', 'content': h[1]})
content = []
return messages
def _parse_text(text: str) -> str:
mapping = {'<': '<', '>': '>', '*': '*'}
for k, v in mapping.items():
text = text.replace(k, v)
return text
async def model_chat(history: History, real_history: History, system: Optional[str], *, client, model: str,
                     request_config: Optional['RequestConfig']):
    """Produce the assistant's reply to the latest user turn.

    ``history`` is the display copy bound to the Chatbot (text escaped with
    ``_parse_text``); ``real_history`` holds the raw turns sent to the model.
    Both are updated in place and yielded after every update so the UI can
    refresh incrementally.

    Args:
        history: Display history from the Chatbot component.
        real_history: Raw history kept in ``gr.State``.
        system: System prompt, or ``None`` to send no system message.
        client: Client exposing ``infer_async``.
        model: Model name forwarded to ``client.infer_async``.
        request_config: Optional request config; when its ``stream`` flag is
            set the reply is consumed chunk by chunk.

    Yields:
        ``(history, real_history)`` after each update, or ``([], [])`` once
        when there is no history.
    """
    if not history:
        yield [], []
        return
    from swift.llm import InferRequest

    # When triggered by "regenerate", the last turn already has a reply. Drop it
    # so the model answers the last query afresh instead of receiving its
    # previous answer as part of the conversation. After add_text/add_file the
    # reply slot is already None, so this is a no-op there.
    if real_history and real_history[-1][1] is not None:
        real_history[-1][1] = None
        history[-1][1] = None

    messages = _history_to_messages(real_history, system)
    resp_or_gen = await client.infer_async(InferRequest(messages=messages), request_config=request_config, model=model)
    if request_config and request_config.stream:
        response = ''
        async for resp in resp_or_gen:
            if resp is None:
                continue
            # A chunk may carry no text (content None); treat it as empty.
            response += resp.choices[0].delta.content or ''
            history[-1][1] = _parse_text(response)
            real_history[-1][-1] = response
            yield history, real_history
    else:
        response = resp_or_gen.choices[0].message.content or ''
        history[-1][1] = _parse_text(response)
        real_history[-1][-1] = response
        yield history, real_history
def add_text(history: History, real_history: History, query: str):
    """Record a text query in both histories and clear the input box.

    The Chatbot copy receives ``_parse_text(query)`` while the raw copy keeps
    ``query`` verbatim; in both, the reply slot is ``None`` until filled.
    Falsy histories (``None`` or empty) are replaced by new lists.

    Returns:
        ``(history, real_history, '')``; the empty string clears the textbox.
    """
    if not history:
        history = []
    if not real_history:
        real_history = []
    shown_turn = [_parse_text(query), None]
    raw_turn = [query, None]
    history.append(shown_turn)
    real_history.append(raw_turn)
    return history, real_history, ''
def add_file(history: History, real_history: History, file):
    """Record an uploaded file as a user turn in both histories.

    The file is stored as a 1-tuple ``(path,)`` with a ``None`` reply slot,
    which is how ``_history_to_messages`` recognises file entries.

    Args:
        history: Display history from the Chatbot (``None``/empty -> new list).
        real_history: Raw history state (``None``/empty -> new list).
        file: The upload value: either an object exposing ``.name`` (file
            wrapper / Gradio ``NamedString``) or a plain path string.

    Returns:
        ``(history, real_history)`` with the new entry appended.
    """
    history = history or []
    real_history = real_history or []
    # A plain ``str`` path has no ``.name``; fall back to the value itself.
    file_path = getattr(file, 'name', file)
    history.append([(file_path, ), None])
    real_history.append([(file_path, ), None])
    return history, real_history
def build_ui(base_url: str,
             model: Optional[str] = None,
             *,
             request_config: Optional['RequestConfig'] = None,
             is_multimodal: bool = True,
             studio_title: Optional[str] = None,
             lang: Literal['en', 'zh'] = 'en',
             default_system: Optional[str] = None):
    """Build (but do not launch) a Gradio chat app backed by an inference server.

    Args:
        base_url: Server URL passed to ``InferClient``.
        model: Model name; defaults to the first entry of ``client.models``.
        request_config: Forwarded to every ``client.infer_async`` call.
        is_multimodal: Whether the upload button is visible.
        studio_title: Page heading; defaults to the model name.
        lang: Key into ``locale_mapping`` for button labels.
        default_system: Initial system prompt.

    Returns:
        The ``gr.Blocks`` demo.
    """
    from swift.llm import InferClient
    client = InferClient(base_url=base_url)
    model = model or client.models[0]
    studio_title = studio_title or model
    with gr.Blocks() as demo:
        gr.Markdown(f'<center><font size=8>{studio_title}</font></center>')
        with gr.Row():
            with gr.Column(scale=3):
                system_input = gr.Textbox(value=default_system, lines=1, label='System')
            with gr.Column(scale=1):
                modify_system = gr.Button(locale_mapping['modify_system'][lang], scale=2)
        chatbot = gr.Chatbot(label='Chatbot')
        textbox = gr.Textbox(lines=1, label='Input')

        with gr.Row():
            upload = gr.UploadButton(locale_mapping['upload'][lang], visible=is_multimodal)
            submit = gr.Button(locale_mapping['submit'][lang])
            regenerate = gr.Button(locale_mapping['regenerate'][lang])
            clear_history = gr.Button(locale_mapping['clear_history'][lang])

        system_state = gr.State(value=default_system)
        history_state = gr.State(value=[])
        model_chat_ = partial(model_chat, client=client, model=model, request_config=request_config)
        chat_inputs = [chatbot, history_state, system_state]
        chat_outputs = [chatbot, history_state]

        upload.upload(add_file, [chatbot, history_state, upload], chat_outputs)
        # Pressing Enter in the textbox and clicking "submit" behave identically:
        # append the query, then produce the reply.
        for send in (textbox.submit, submit.click):
            send(add_text, [chatbot, history_state, textbox],
                 [chatbot, history_state, textbox]).then(model_chat_, chat_inputs, chat_outputs)
        regenerate.click(model_chat_, chat_inputs, chat_outputs)
        clear_history.click(clear_session, [], [textbox, chatbot, history_state])
        modify_system.click(modify_system_session, [system_input], [system_state, textbox, chatbot, history_state])
    return demo
|