File size: 5,545 Bytes
7feac49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Type
import gradio as gr
from swift.ui.base import BaseUI
class Advanced(BaseUI):
    """Collapsible "Advanced settings" section of the LLM training UI.

    ``locale_dict`` maps each component ``elem_id`` to its Chinese (``zh``) and
    English (``en``) ``label`` / ``info`` strings. ``do_build_ui`` creates the
    Gradio components that carry those ``elem_id`` values.
    """

    group = 'llm_train'

    # Keys are component elem_ids; values hold per-language label/info text.
    locale_dict = {
        'advanced_param': {
            'label': {
                'zh': '高级参数设置',
                'en': 'Advanced settings'
            },
        },
        'optim': {
            'label': {
                'zh': 'Optimizer类型',
                'en': 'The Optimizer type'
            },
            'info': {
                'zh': '设置Optimizer类型',
                'en': 'Set the Optimizer type'
            }
        },
        'weight_decay': {
            'label': {
                'zh': '权重衰减',
                'en': 'Weight decay'
            },
            'info': {
                'zh': '设置weight decay',
                'en': 'Set the weight decay'
            }
        },
        'logging_steps': {
            'label': {
                'zh': '日志打印步数',
                'en': 'Logging steps'
            },
            'info': {
                'zh': '设置日志打印的步数间隔',
                'en': 'Set the logging interval'
            }
        },
        'lr_scheduler_type': {
            'label': {
                'zh': 'LrScheduler类型',
                'en': 'The LrScheduler type'
            },
            'info': {
                'zh': '设置LrScheduler类型',
                'en': 'Set the LrScheduler type'
            }
        },
        'warmup_ratio': {
            'label': {
                'zh': '学习率warmup比例',
                'en': 'Lr warmup ratio'
            },
            'info': {
                'zh': '设置学习率warmup比例',
                'en': 'Set the warmup ratio in total steps'
            }
        },
        'more_params': {
            'label': {
                'zh': '其他高级参数',
                'en': 'Other params'
            },
            'info': {
                'zh': '以json格式或--xxx xxx命令行格式填入',
                'en': 'Fill in with json format or --xxx xxx cmd format'
            }
        },
        'truncation_strategy': {
            'label': {
                'zh': '数据集超长策略',
                'en': 'Dataset truncation strategy'
            },
            'info': {
                'zh': '如果token超长该如何处理',
                'en': 'How to deal with the rows exceed the max length'
            }
        },
        'max_steps': {
            'label': {
                'zh': '最大迭代步数',
                'en': 'Max steps',
            },
            'info': {
                'zh': '设置最大迭代步数,该值如果大于零则数据集迭代次数不生效',
                'en': 'Set the max steps, if the value > 0 then num_train_epochs has no effects',
            }
        },
        'per_device_eval_batch_size': {
            'label': {
                'zh': '验证batch size',
                'en': 'Val batch size',
            },
            'info': {
                'zh': '验证的batch size',
                'en': 'Set the val batch size',
            }
        },
        'max_grad_norm': {
            'label': {
                'zh': '梯度裁剪',
                'en': 'Max grad norm',
            },
            'info': {
                'zh': '设置梯度裁剪',
                'en': 'Set the max grad norm',
            }
        },
        # NOTE(review): no component in do_build_ui below uses this elem_id;
        # presumably it is consumed elsewhere in the UI — confirm before removing.
        'predict_with_generate': {
            'label': {
                'zh': '使用生成指标代替loss',
                'en': 'Use generate metric instead of loss',
            },
            'info': {
                'zh': '验证时使用generate/Rouge代替loss',
                'en': 'Use model.generate/Rouge instead of loss',
            }
        },
        'deepspeed': {
            'label': {
                'zh': 'deepspeed',
                'en': 'deepspeed',
            },
            'info': {
                'zh': '可以选择下拉列表,也支持传入路径',
                'en': 'Choose from the dropbox or fill in a valid path',
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Create the advanced-parameter components inside a collapsed accordion.

        Args:
            base_tab: The parent tab class. It is not referenced in this method's body.
        """
        with gr.Accordion(elem_id='advanced_param', open=False):
            with gr.Blocks():
                with gr.Row():
                    gr.Textbox(elem_id='optim', lines=1, scale=20)
                    gr.Textbox(elem_id='weight_decay', lines=1, scale=20)
                    gr.Textbox(elem_id='logging_steps', lines=1, scale=20)
                    gr.Textbox(elem_id='lr_scheduler_type', lines=1, scale=20)
                    gr.Textbox(elem_id='max_steps', lines=1, scale=20)
                    gr.Slider(elem_id='warmup_ratio', minimum=0.0, maximum=1.0, step=0.05, scale=20)
                with gr.Row():
                    gr.Dropdown(elem_id='truncation_strategy', scale=20)
                    # step=1 so every integer batch size in [1, 256] is selectable.
                    # Slider values snap to minimum + k * step, so the previous
                    # minimum=1 / step=2 combination only allowed odd values
                    # (1, 3, 5, ...), making 2, 4, 8, ... unreachable.
                    gr.Slider(elem_id='per_device_eval_batch_size', minimum=1, maximum=256, step=1, scale=20)
                    gr.Textbox(elem_id='max_grad_norm', lines=1, scale=20)
                    # Preset names are offered in the dropdown; allow_custom_value
                    # lets the user type something else (per the locale info text,
                    # a config file path).
                    gr.Dropdown(
                        elem_id='deepspeed',
                        scale=20,
                        allow_custom_value=True,
                        value=None,
                        choices=['zero0', 'zero1', 'zero2', 'zero3', 'zero2_offload', 'zero3_offload'])
                with gr.Row():
                    gr.Textbox(elem_id='more_params', lines=4, scale=20)
|