File size: 4,179 Bytes
7feac49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Copyright (c) Alibaba, Inc. and its affiliates.
from functools import partial
from typing import Type

import gradio as gr

from swift.llm import TEMPLATE_MAPPING, ModelType, RLHFArguments
from swift.llm.model.register import get_all_models
from swift.ui.base import BaseUI


class Model(BaseUI):
    """Model-settings panel of the ``llm_train`` UI group.

    Renders dropdowns for the model id/path, model type, prompt template and saved
    train records, a button that deletes the train records of the selected model,
    and a one-line textbox for the ``system`` field.
    """

    group = 'llm_train'

    # Per-element UI text, keyed by elem_id; every text entry carries a Chinese ('zh')
    # and an English ('en') variant.
    locale_dict = {
        'model_type': {
            'label': {'zh': '模型类型', 'en': 'Select Model Type'},
            'info': {'zh': 'SWIFT已支持的模型类型', 'en': 'Base model type supported by SWIFT'},
        },
        'model': {
            'label': {'zh': '模型id或路径', 'en': 'Model id or path'},
            'info': {'zh': '实际的模型id', 'en': 'The actual model id or model path'},
        },
        'template': {
            'label': {'zh': '模型Prompt模板类型', 'en': 'Prompt template type'},
            'info': {'zh': '选择匹配模型的Prompt模板', 'en': 'Choose the template type of the model'},
        },
        'system': {
            'label': {'zh': 'system字段', 'en': 'system'},
            'info': {'zh': '选择system字段的内容', 'en': 'Choose the content of the system field'},
        },
        'reset': {
            'value': {'zh': '恢复模型初始值', 'en': 'Reset model default'},
        },
        'train_record': {
            'label': {'zh': '训练记录', 'en': 'Train record'},
            'info': {'zh': '展示使用web-ui的历史训练及参数', 'en': 'Show the training history and parameters'},
        },
        'clear_cache': {
            'value': {'zh': '删除训练记录', 'en': 'Delete train records'},
        },
        'model_param': {
            'label': {'zh': '模型设置', 'en': 'Model settings'},
        },
        'checkpoint': {
            'value': {'zh': '训练后的模型', 'en': 'Trained model'},
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Create the model-settings components and hook up the clear-records button.

        Components are created in on-screen order inside an accordion that starts open.
        Each component's ``elem_id`` matches an entry of ``locale_dict``.
        ``base_tab`` is accepted for interface compatibility and is not used here.
        """
        with gr.Accordion(elem_id='model_param', open=True):
            with gr.Row():
                # Free-text entry is allowed so a local path or unlisted id can be typed in.
                model_dropdown = gr.Dropdown(
                    elem_id='model',
                    scale=20,
                    choices=get_all_models(),
                    value='Qwen/Qwen2.5-7B-Instruct',
                    allow_custom_value=True,
                )
                gr.Dropdown(elem_id='model_type', scale=20, choices=ModelType.get_model_name_list())
                gr.Dropdown(elem_id='template', scale=20, choices=list(TEMPLATE_MAPPING.keys()))
                record_dropdown = gr.Dropdown(elem_id='train_record', scale=20, choices=[])
                clear_button = gr.Button(elem_id='clear_cache', scale=2)
            with gr.Row():
                gr.Textbox(elem_id='system', scale=20, lines=1)

        # The function and parameter names are kept as-is: Gradio derives the
        # event's API endpoint name from the function name.
        def clear_record(model):
            # With no model selected there is nothing to delete; leave the dropdown alone.
            if not model:
                return gr.update()
            cls.clear_cache(model)
            # The records were just deleted, so the dropdown has nothing left to offer.
            return gr.update(choices=[])

        clear_button.click(clear_record, inputs=[model_dropdown], outputs=[record_dropdown])

    @classmethod
    def after_build_ui(cls, base_tab: Type['BaseUI']):
        """Register change handlers once every component of ``base_tab`` exists.

        - Changing the model calls ``base_tab.update_input_model`` with
          ``arg_cls=RLHFArguments``. Its outputs are the train-record dropdown followed
          by every element in ``base_tab.valid_elements()``.
        - Changing the train record calls ``base_tab.update_all_settings`` with the
          current model and record. Its outputs are every element in
          ``base_tab.valid_elements()``.
        """
        model_box = cls.element('model')
        record_box = cls.element('train_record')

        # NOTE(review): RLHFArguments is presumably used because it accepts the widest
        # set of training arguments — confirm against update_input_model.
        on_model_change = partial(base_tab.update_input_model, arg_cls=RLHFArguments)
        model_box.change(
            on_model_change,
            inputs=[model_box],
            outputs=[record_box] + list(base_tab.valid_elements().values()),
        )

        on_record_change = partial(base_tab.update_all_settings, base_tab=base_tab)
        record_box.change(
            on_record_change,
            inputs=[model_box, record_box],
            outputs=list(base_tab.valid_elements().values()),
        )