File size: 1,845 Bytes
7feac49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Type

import gradio as gr

from swift.ui.base import BaseUI


class Galore(BaseUI):
    """UI section that groups the GaLore-related options.

    The class declares bilingual (zh/en) texts for its widgets and builds a
    collapsed accordion holding a single row with two checkboxes and two
    sliders.
    """

    group = 'llm_train'

    # Bilingual texts; the keys match the ``elem_id`` values passed to the
    # widgets created in ``do_build_ui``. How these texts are attached to the
    # widgets is handled outside this class (presumably by ``BaseUI`` -- confirm).
    locale_dict = {
        'galore_tab': {
            'label': {'zh': 'Galore参数设置', 'en': 'Galore Settings'},
        },
        'use_galore': {
            'label': {'zh': '使用GaLore', 'en': 'Use GaLore'},
            'info': {
                'zh': '使用Galore来减少全参数训练的显存消耗',
                'en': 'Use Galore to reduce GPU memory usage in full parameter training',
            },
        },
        'galore_rank': {
            'label': {'zh': 'Galore的秩', 'en': 'The rank of Galore'},
        },
        'galore_update_proj_gap': {
            'label': {'zh': 'Galore project matrix更新频率', 'en': 'The updating gap of the project matrix'},
        },
        'galore_optim_per_parameter': {
            'label': {
                'zh': '为每个Galore Parameter创建单独的optimizer',
                'en': 'Create unique optimizer for per Galore parameter',
            },
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Create the GaLore widgets inside an initially closed accordion.

        Args:
            base_tab: Supplied by the caller; this method does not read it.
        """
        # A single compound ``with`` enters Accordion -> Blocks -> Row in that
        # order, exactly as three nested ``with`` statements would.
        with gr.Accordion(elem_id='galore_tab', open=False), gr.Blocks(), gr.Row():
            gr.Checkbox(elem_id='use_galore', scale=4)
            gr.Slider(
                elem_id='galore_rank',
                minimum=8,
                maximum=256,
                step=8,
                scale=4,
            )
            gr.Slider(
                elem_id='galore_update_proj_gap',
                minimum=10,
                maximum=1000,
                step=50,
                scale=4,
            )
            gr.Checkbox(elem_id='galore_optim_per_parameter', scale=4)