File size: 5,547 Bytes
cb2428f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Copyright (c) Alibaba, Inc. and its affiliates.
from functools import partial
from typing import Literal, Optional

import gradio as gr

from swift.utils import get_file_mm_type
from ..utils import History
from .locale import locale_mapping


def clear_session():
    """Return fresh values for the input textbox, the chatbot display and the stored history.

    Returns:
        ``('', [], [])`` with newly created lists on every call.
    """
    empty_query = ''
    return empty_query, list(), list()


def modify_system_session(system: str):
    """Adopt a new system prompt and reset the conversation.

    Args:
        system: The new system prompt; any falsy value (e.g. ``None``) is normalised to ``''``.

    Returns:
        ``(system, '', [], [])``: the normalised prompt, an empty input string and two empty lists
        (chatbot display and stored history).
    """
    new_system = system if system else ''
    return new_system, '', [], []


def _history_to_messages(history: History, system: Optional[str]):
    messages = []
    if system is not None:
        messages.append({'role': 'system', 'content': system})
    content = []
    for h in history:
        assert isinstance(h, (list, tuple))
        if isinstance(h[0], tuple):
            assert h[1] is None
            file_path = h[0][0]
            try:
                mm_type = get_file_mm_type(file_path)
                content.append({'type': mm_type, mm_type: file_path})
            except ValueError:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content.append({'type': 'text', 'text': f.read()})
        else:
            content.append({'type': 'text', 'text': h[0]})
            messages.append({'role': 'user', 'content': content})
            if h[1] is not None:
                messages.append({'role': 'assistant', 'content': h[1]})
            content = []
    return messages


def _parse_text(text: str) -> str:
    mapping = {'<': '&lt;', '>': '&gt;', '*': '&ast;'}
    for k, v in mapping.items():
        text = text.replace(k, v)
    return text


async def model_chat(history: History, real_history: History, system: Optional[str], *, client, model: str,
                     request_config: Optional['RequestConfig']):
    """Send the conversation to the model and write its reply into the last turn of both histories.

    Async generator used as a Gradio event handler. ``history`` is the display copy (text escaped
    with ``_parse_text``); ``real_history`` is the raw copy that is converted into request messages.

    Args:
        history: Display history. If empty, nothing is sent and ``([], [])`` is yielded once.
        real_history: Raw history used to build the request; its last turn receives the reply.
        system: System prompt forwarded to ``_history_to_messages``.
        client: Object exposing ``infer_async(request, request_config=..., model=...)``.
        model: Model name passed to ``client.infer_async``.
        request_config: Forwarded to ``client.infer_async``; when its ``stream`` flag is set,
            partial replies are yielded as chunks arrive.

    Yields:
        ``(history, real_history)`` after each update: once per non-``None`` streamed chunk, or once
        with the complete reply when not streaming.
    """
    if not history:
        yield [], []
        return

    from swift.llm import InferRequest

    # Build the request as if the last turn had no reply yet. On "regenerate" the last turn already
    # holds the previous reply; sending it would end the request with that assistant message instead
    # of asking for a new one. A shallow copy leaves the stored history untouched until text arrives.
    request_history = list(real_history)
    if request_history:
        request_history[-1] = [request_history[-1][0], None]
    messages = _history_to_messages(request_history, system)

    resp_or_gen = await client.infer_async(
        InferRequest(messages=messages), request_config=request_config, model=model)
    if request_config and request_config.stream:
        response = ''
        async for resp in resp_or_gen:
            if resp is None:
                continue
            # A chunk whose delta carries no text (None) counts as empty rather than failing on str + None.
            response += resp.choices[0].delta.content or ''
            history[-1][1] = _parse_text(response)
            real_history[-1][-1] = response
            yield history, real_history
    else:
        # Same guard as the streaming path: a reply without text content is shown as empty.
        response = resp_or_gen.choices[0].message.content or ''
        history[-1][1] = _parse_text(response)
        real_history[-1][-1] = response
        yield history, real_history


def add_text(history: History, real_history: History, query: str):
    """Append a reply-less user query to both histories and clear the input box.

    Args:
        history: Display history; receives ``[_parse_text(query), None]``.
        real_history: Raw history; receives ``[query, None]``.
        query: Text typed by the user.

    Falsy histories (``None`` or empty) are replaced by new lists; non-empty ones are extended in place.

    Returns:
        ``(history, real_history, '')``.
    """
    if not history:
        history = []
    if not real_history:
        real_history = []
    displayed = _parse_text(query)
    history.append([displayed, None])
    real_history.append([query, None])
    return history, real_history, ''


def add_file(history: History, real_history: History, file):
    """Record an uploaded file as a reply-less ``[(file.name,), None]`` entry in both histories.

    Args:
        history: Display history.
        real_history: Raw history.
        file: Uploaded file object; only its ``name`` attribute is read.

    Falsy histories (``None`` or empty) are replaced by new lists; non-empty ones are extended in place.

    Returns:
        ``(history, real_history)``.
    """
    history = history if history else []
    real_history = real_history if real_history else []
    path = file.name
    for target in (history, real_history):
        target.append([(path, ), None])
    return history, real_history


def build_ui(base_url: str,
             model: Optional[str] = None,
             *,
             request_config: Optional['RequestConfig'] = None,
             is_multimodal: bool = True,
             studio_title: Optional[str] = None,
             lang: Literal['en', 'zh'] = 'en',
             default_system: Optional[str] = None):
    """Build the Gradio chat app backed by an ``InferClient``.

    Args:
        base_url: URL the ``InferClient`` connects to.
        model: Model to query; falls back to ``client.models[0]`` when falsy.
        request_config: Forwarded unchanged to every ``model_chat`` call.
        is_multimodal: Whether the file upload button is visible.
        studio_title: Page heading; falls back to the model name when falsy.
        lang: Key used to pick button captions from ``locale_mapping``.
        default_system: Initial value of the system textbox and of the system state.

    Returns:
        The constructed ``gr.Blocks`` app (this function does not launch it).
    """
    from swift.llm import InferClient
    client = InferClient(base_url=base_url)
    if not model:
        model = client.models[0]
    if not studio_title:
        studio_title = model

    def label(key: str) -> str:
        # Localized caption for a button.
        return locale_mapping[key][lang]

    with gr.Blocks() as demo:
        gr.Markdown(f'<center><font size=8>{studio_title}</center>')
        with gr.Row():
            with gr.Column(scale=3):
                system_input = gr.Textbox(value=default_system, lines=1, label='System')
            with gr.Column(scale=1):
                modify_system = gr.Button(label('modify_system'), scale=2)
        chatbot = gr.Chatbot(label='Chatbot')
        textbox = gr.Textbox(lines=1, label='Input')

        with gr.Row():
            upload = gr.UploadButton(label('upload'), visible=is_multimodal)
            submit = gr.Button(label('submit'))
            regenerate = gr.Button(label('regenerate'))
            clear_history = gr.Button(label('clear_history'))

        system_state = gr.State(value=default_system)
        history_state = gr.State(value=[])
        chat_fn = partial(model_chat, client=client, model=model, request_config=request_config)

        upload.upload(add_file, [chatbot, history_state, upload], [chatbot, history_state])
        # Pressing Enter in the textbox and clicking "submit" run the same pipeline:
        # append the query to both histories, then ask the model.
        for trigger in (textbox.submit, submit.click):
            trigger(add_text, [chatbot, history_state, textbox],
                    [chatbot, history_state, textbox]).then(chat_fn, [chatbot, history_state, system_state],
                                                             [chatbot, history_state])
        regenerate.click(chat_fn, [chatbot, history_state, system_state], [chatbot, history_state])
        clear_history.click(clear_session, [], [textbox, chatbot, history_state])
        modify_system.click(modify_system_session, [system_input], [system_state, textbox, chatbot, history_state])
    return demo