File size: 3,451 Bytes
7feac49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
# Copyright (c) Alibaba, Inc. and its affiliates.
from functools import partial
from typing import Type
import gradio as gr
from swift.llm import ModelType
from swift.llm.model.register import get_all_models
from swift.ui.base import BaseUI
class RLHF(BaseUI):
    """Gradio UI section holding the RLHF (human-alignment) training options.

    Belongs to the ``llm_train`` group. ``locale_dict`` maps every element id
    created in ``do_build_ui`` to its Chinese/English label and optional info
    text; the widgets themselves are created by ``elem_id`` only, so presumably
    ``BaseUI`` resolves the displayed text from this table — TODO confirm.
    """

    group = 'llm_train'

    locale_dict = {
        'rlhf_tab': {
            'label': {
                'zh': '人类对齐参数设置',
                'en': 'RLHF settings'
            },
        },
        'rlhf_type': {
            'label': {
                'zh': '人类对齐算法类型',
                'en': 'RLHF type'
            },
        },
        'ref_model_type': {
            'label': {
                'zh': '选择ref模型',
                'en': 'Select ref model'
            },
            'info': {
                'zh': 'SWIFT已支持的模型名称',
                'en': 'Base model supported by SWIFT'
            }
        },
        'ref_model': {
            'label': {
                'zh': 'ref模型id或路径',
                'en': 'Ref model id or path'
            },
            'info': {
                'zh': '实际的模型id或路径',
                'en': 'The actual model id or path'
            }
        },
        'beta': {
            'label': {
                'zh': 'KL正则项系数',
                # beta weights the KL *regularization* term (as the zh label
                # states); the old English text said "regression" by mistake.
                'en': 'KL regularization ratio'
            },
        },
        'rpo_alpha': {
            'label': {
                'zh': 'DPO中混合sft交叉熵的系数',
                'en': 'DPO Cross Entropy ratio'
            },
        },
        'simpo_gamma': {
            'label': {
                'zh': 'SimPO reward margin',
                'en': 'SimPO reward margin'
            },
        },
        'desirable_weight': {
            'label': {
                'zh': 'KTO符合项系数',
                'en': 'KTO desirable ratio'
            },
        },
        'undesirable_weight': {
            'label': {
                'zh': 'KTO不符合项系数',
                'en': 'KTO undesirable ratio'
            },
        }
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Lay out the RLHF widgets inside a collapsed accordion.

        The first row holds the algorithm type and reference-model selectors.
        The second row holds the numeric coefficients as sliders.

        Args:
            base_tab: The parent tab class; not used by this method.
        """
        with gr.Accordion(elem_id='rlhf_tab', open=False):
            with gr.Blocks():
                with gr.Row():
                    # No choices are supplied here; presumably filled in
                    # elsewhere by the framework — TODO confirm.
                    gr.Dropdown(elem_id='rlhf_type', value=None)
                    # allow_custom_value lets the user type a model id or a
                    # local path that is not among the known models.
                    gr.Dropdown(
                        elem_id='ref_model', scale=20, value=None, choices=get_all_models(), allow_custom_value=True)
                    gr.Dropdown(elem_id='ref_model_type', choices=ModelType.get_model_name_list(), value=None, scale=20)
                with gr.Row():
                    gr.Slider(elem_id='beta', minimum=0., maximum=5.0, step=0.1, scale=20)
                    gr.Slider(elem_id='rpo_alpha', minimum=0., maximum=2.0, step=0.1, scale=20)
                    gr.Slider(elem_id='simpo_gamma', minimum=0., maximum=2.0, step=0.1, scale=20)
                    gr.Slider(elem_id='desirable_weight', minimum=0., maximum=2.0, step=0.1, scale=20)
                    gr.Slider(elem_id='undesirable_weight', minimum=0., maximum=2.0, step=0.1, scale=20)

    @classmethod
    def after_build_ui(cls, base_tab: Type['BaseUI']):
        """Wire event handlers once all elements have been created.

        Whenever ``ref_model`` changes, ``update_input_model`` is called with the
        new value, and its result is written to ``ref_model_type``.
        ``update_input_model`` is not defined in this class and is presumably
        inherited from ``BaseUI``. The exact meaning of ``has_record`` and
        ``is_ref_model`` is defined there — TODO confirm.

        Args:
            base_tab: The parent tab class; not used by this method.
        """
        cls.element('ref_model').change(
            partial(cls.update_input_model, allow_keys=['ref_model_type'], has_record=False, is_ref_model=True),
            inputs=[cls.element('ref_model')],
            outputs=[cls.element('ref_model_type')])
|