File size: 5,545 Bytes
7feac49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Type

import gradio as gr

from swift.ui.base import BaseUI


class Advanced(BaseUI):
    """Collapsible "Advanced settings" section of the LLM training UI.

    ``locale_dict`` maps each component ``elem_id`` to its ``label``/``info``
    text in Chinese (``zh``) and English (``en``). The components created in
    ``do_build_ui`` pass no explicit label, so the text presumably comes from
    this table via ``BaseUI`` keyed on ``elem_id``. (Assumption: confirm in
    ``swift.ui.base``.)
    """

    # UI group this section belongs to.
    group = 'llm_train'

    locale_dict = {
        'advanced_param': {
            'label': {
                'zh': '高级参数设置',
                'en': 'Advanced settings'
            },
        },
        'optim': {
            'label': {
                'zh': 'Optimizer类型',
                'en': 'The Optimizer type'
            },
            'info': {
                'zh': '设置Optimizer类型',
                'en': 'Set the Optimizer type'
            }
        },
        'weight_decay': {
            'label': {
                'zh': '权重衰减',
                'en': 'Weight decay'
            },
            'info': {
                'zh': '设置weight decay',
                'en': 'Set the weight decay'
            }
        },
        'logging_steps': {
            'label': {
                'zh': '日志打印步数',
                'en': 'Logging steps'
            },
            'info': {
                'zh': '设置日志打印的步数间隔',
                'en': 'Set the logging interval'
            }
        },
        'lr_scheduler_type': {
            'label': {
                'zh': 'LrScheduler类型',
                'en': 'The LrScheduler type'
            },
            'info': {
                'zh': '设置LrScheduler类型',
                'en': 'Set the LrScheduler type'
            }
        },
        'warmup_ratio': {
            'label': {
                'zh': '学习率warmup比例',
                'en': 'Lr warmup ratio'
            },
            'info': {
                'zh': '设置学习率warmup比例',
                'en': 'Set the warmup ratio in total steps'
            }
        },
        'more_params': {
            'label': {
                'zh': '其他高级参数',
                'en': 'Other params'
            },
            'info': {
                'zh': '以json格式或--xxx xxx命令行格式填入',
                'en': 'Fill in with json format or --xxx xxx cmd format'
            }
        },
        'truncation_strategy': {
            'label': {
                'zh': '数据集超长策略',
                'en': 'Dataset truncation strategy'
            },
            'info': {
                'zh': '如果token超长该如何处理',
                'en': 'How to deal with the rows exceed the max length'
            }
        },
        'max_steps': {
            'label': {
                'zh': '最大迭代步数',
                'en': 'Max steps',
            },
            'info': {
                'zh': '设置最大迭代步数,该值如果大于零则数据集迭代次数不生效',
                'en': 'Set the max steps, if the value > 0 then num_train_epochs has no effects',
            }
        },
        'per_device_eval_batch_size': {
            'label': {
                'zh': '验证batch size',
                'en': 'Val batch size',
            },
            'info': {
                'zh': '验证的batch size',
                'en': 'Set the val batch size',
            }
        },
        'max_grad_norm': {
            'label': {
                'zh': '梯度裁剪',
                'en': 'Max grad norm',
            },
            'info': {
                'zh': '设置梯度裁剪',
                'en': 'Set the max grad norm',
            }
        },
        'predict_with_generate': {
            'label': {
                'zh': '使用生成指标代替loss',
                'en': 'Use generate metric instead of loss',
            },
            'info': {
                'zh': '验证时使用generate/Rouge代替loss',
                'en': 'Use model.generate/Rouge instead of loss',
            }
        },
        'deepspeed': {
            'label': {
                'zh': 'deepspeed',
                'en': 'deepspeed',
            },
            'info': {
                'zh': '可以选择下拉列表,也支持传入路径',
                'en': 'Choose from the dropbox or fill in a valid path',
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Create the advanced-parameter components inside the current Gradio context.

        Args:
            base_tab: Not referenced in this method. Presumably it is part of
                the ``BaseUI.do_build_ui`` hook signature (verify against
                ``swift.ui.base``).
        """
        # open=False: the section starts collapsed.
        with gr.Accordion(elem_id='advanced_param', open=False):
            with gr.Blocks():
                # Row 1: optimizer / schedule / logging knobs.
                with gr.Row():
                    gr.Textbox(elem_id='optim', lines=1, scale=20)
                    gr.Textbox(elem_id='weight_decay', lines=1, scale=20)
                    gr.Textbox(elem_id='logging_steps', lines=1, scale=20)
                    gr.Textbox(elem_id='lr_scheduler_type', lines=1, scale=20)
                    gr.Textbox(elem_id='max_steps', lines=1, scale=20)
                    gr.Slider(elem_id='warmup_ratio', minimum=0.0, maximum=1.0, step=0.05, scale=20)
                # Row 2: data truncation, evaluation and distributed settings.
                with gr.Row():
                    gr.Dropdown(elem_id='truncation_strategy', scale=20)
                    # step=1 rather than 2: slider values snap to minimum + k*step,
                    # so with minimum=1 a step of 2 only allowed odd sizes
                    # (1, 3, 5, ...) and made 2, 4, 8, ... unreachable.
                    gr.Slider(elem_id='per_device_eval_batch_size', minimum=1, maximum=256, step=1, scale=20)
                    gr.Textbox(elem_id='max_grad_norm', lines=1, scale=20)
                    # Preset names are offered; allow_custom_value=True also lets
                    # the user type an arbitrary value (e.g. a config path).
                    gr.Dropdown(
                        elem_id='deepspeed',
                        scale=20,
                        allow_custom_value=True,
                        value=None,
                        choices=['zero0', 'zero1', 'zero2', 'zero3', 'zero2_offload', 'zero3_offload'])
                # Row 3: free-form extra arguments (JSON or "--key value" form).
                with gr.Row():
                    gr.Textbox(elem_id='more_params', lines=4, scale=20)