Student0809 committed on
Commit
e7a862c
·
verified ·
1 Parent(s): ac35f70

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. ms-swift/silence_overlaps/only_overlap/.ipynb_checkpoints/overlap5s_isoverlap_train-checkpoint.json +0 -0
  2. ms-swift/swift/tuners/__pycache__/lora.cpython-310.pyc +0 -0
  3. ms-swift/swift/tuners/__pycache__/mapping.cpython-310.pyc +0 -0
  4. ms-swift/swift/tuners/__pycache__/part.cpython-310.pyc +0 -0
  5. ms-swift/swift/tuners/__pycache__/restuning.cpython-310.pyc +0 -0
  6. ms-swift/swift/tuners/__pycache__/restuning_components.cpython-310.pyc +0 -0
  7. ms-swift/swift/tuners/__pycache__/side.cpython-310.pyc +0 -0
  8. ms-swift/swift/tuners/__pycache__/utils.cpython-310.pyc +0 -0
  9. ms-swift/swift/tuners/adapter.py +189 -0
  10. ms-swift/swift/tuners/longlora/__pycache__/longlora.cpython-310.pyc +0 -0
  11. ms-swift/swift/tuners/peft.py +392 -0
  12. ms-swift/swift/tuners/scetuning/__pycache__/__init__.cpython-310.pyc +0 -0
  13. ms-swift/swift/tuners/scetuning/__pycache__/scetuning.cpython-310.pyc +0 -0
  14. ms-swift/swift/tuners/scetuning/__pycache__/scetuning_components.cpython-310.pyc +0 -0
  15. ms-swift/swift/tuners/scetuning/scetuning_components.py +127 -0
  16. ms-swift/swift/tuners/side.py +245 -0
  17. ms-swift/swift/ui/app.py +92 -0
  18. ms-swift/swift/ui/base.py +388 -0
  19. ms-swift/swift/ui/llm_eval/__init__.py +1 -0
  20. ms-swift/swift/ui/llm_eval/eval.py +130 -0
  21. ms-swift/swift/ui/llm_eval/model.py +78 -0
  22. ms-swift/swift/ui/llm_export/llm_export.py +191 -0
  23. ms-swift/swift/ui/llm_export/model.py +83 -0
  24. ms-swift/swift/ui/llm_export/runtime.py +75 -0
  25. ms-swift/swift/ui/llm_infer/__init__.py +1 -0
  26. ms-swift/swift/ui/llm_infer/generate.py +65 -0
  27. ms-swift/swift/ui/llm_infer/llm_infer.py +396 -0
  28. ms-swift/swift/ui/llm_infer/model.py +126 -0
  29. ms-swift/swift/ui/llm_infer/runtime.py +285 -0
  30. ms-swift/swift/ui/llm_train/__init__.py +1 -0
  31. ms-swift/swift/ui/llm_train/advanced.py +164 -0
  32. ms-swift/swift/ui/llm_train/dataset.py +91 -0
  33. ms-swift/swift/ui/llm_train/hyper.py +129 -0
  34. ms-swift/swift/ui/llm_train/llamapro.py +40 -0
  35. ms-swift/swift/ui/llm_train/llm_train.py +420 -0
  36. ms-swift/swift/ui/llm_train/lora.py +102 -0
  37. ms-swift/swift/ui/llm_train/model.py +127 -0
  38. ms-swift/swift/ui/llm_train/quantization.py +68 -0
  39. ms-swift/swift/ui/llm_train/report_to.py +75 -0
  40. ms-swift/swift/ui/llm_train/rlhf.py +102 -0
  41. ms-swift/swift/ui/llm_train/runtime.py +571 -0
  42. ms-swift/swift/ui/llm_train/save.py +84 -0
  43. ms-swift/swift/ui/llm_train/self_cog.py +57 -0
  44. ms-swift/swift/utils/__init__.py +19 -0
  45. ms-swift/swift/utils/__pycache__/np_utils.cpython-310.pyc +0 -0
  46. ms-swift/swift/utils/constants.py +27 -0
  47. ms-swift/swift/utils/logger.py +138 -0
  48. ms-swift/swift/utils/tb_utils.py +72 -0
  49. ms-swift/swift/utils/torch_utils.py +391 -0
  50. ms-swift/tests/deploy/test_dataset.py +61 -0
ms-swift/silence_overlaps/only_overlap/.ipynb_checkpoints/overlap5s_isoverlap_train-checkpoint.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/swift/tuners/__pycache__/lora.cpython-310.pyc ADDED
Binary file (7.04 kB). View file
 
ms-swift/swift/tuners/__pycache__/mapping.cpython-310.pyc ADDED
Binary file (1.44 kB). View file
 
ms-swift/swift/tuners/__pycache__/part.cpython-310.pyc ADDED
Binary file (4.71 kB). View file
 
ms-swift/swift/tuners/__pycache__/restuning.cpython-310.pyc ADDED
Binary file (11.8 kB). View file
 
ms-swift/swift/tuners/__pycache__/restuning_components.cpython-310.pyc ADDED
Binary file (9.9 kB). View file
 
ms-swift/swift/tuners/__pycache__/side.cpython-310.pyc ADDED
Binary file (8.84 kB). View file
 
ms-swift/swift/tuners/__pycache__/utils.cpython-310.pyc ADDED
Binary file (16.4 kB). View file
 
ms-swift/swift/tuners/adapter.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import inspect
3
+ import re
4
+ import types
5
+ from dataclasses import dataclass, field
6
+ from typing import List, Union
7
+
8
+ import torch
9
+ from torch import nn
10
+ from transformers.activations import ACT2CLS
11
+
12
+ from swift.utils.torch_utils import find_sub_module, get_logger
13
+ from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
14
+
15
+ logger = get_logger()
16
+
17
+
18
@dataclass
class AdapterConfig(SwiftConfig):
    """
    The configuration class for the adapter module.

    Adapters project input tokens by an MLP layer.
    'Parameter-Efficient Transfer Learning for NLP' by Houlsby et al.(2019)
    See http://arxiv.org/abs/1902.00751

    Args:
        dim(`int`): The dimension of the hidden states
        target_modules(`Union[str, List[str]]`): The feedforward module to be replaced.
            in regex format if this argument is str, else will match with `end with` if List[str].
        hidden_pos(`Union[str, int]`): The position of the hidden state to be passed into the adapter,
            can be int (args) or str (kwargs)
        method_name(`str`): The method to be replaced, default is `forward`
        adapter_length: The length of the adapter length (intermediate length)
        act_layer: The activation layer of the adapter
    """

    dim: int = field(default=None, metadata={'help': 'The dimension of the hidden states'})

    target_modules: Union[str, List[str]] = field(
        default=None,
        metadata={
            'help':
            'The feedforward module to be replaced. in regex format if this argument is str, '
            'else will match with `end with` if List[str].'
        })

    hidden_pos: Union[str, int] = field(
        default=None,
        metadata={
            'help': 'The position of the hidden state to be passed into the adapter, can be int (args) or str (kwargs)'
        })

    method_name: str = field(default='forward', metadata={'help': 'The method to be replaced, default is `forward`'})

    adapter_length: int = field(
        default=128, metadata={'help': 'The length of the adapter length (intermediate length)'})

    act_layer: str = field(default='gelu', metadata={'help': 'The activation layer of the adapter'})

    def __post_init__(self):
        # Deferred import: .mapping imports tuner classes, so a top-level import
        # would create a circular dependency.
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.ADAPTER
64
+
65
+
66
class Adapter(SwiftAdapter):
    """Swift tuner that injects `AdapterModule`s after matched feedforward modules."""

    @staticmethod
    def prepare_model(model: nn.Module, config: AdapterConfig, adapter_name: str) -> SwiftOutput:
        """Prepare a model with `AdapterConfig`"""
        module_keys = [key for key, _ in model.named_modules()]

        for module_key in module_keys:
            # str config: full regex match; list config: suffix ("ends with") match.
            if isinstance(config.target_modules, str):
                target_module_found = re.fullmatch(config.target_modules, module_key)
            else:
                target_module_found = any(module_key.endswith(target_key) for target_key in config.target_modules)

            if target_module_found:  # noqa
                module = model.get_submodule(module_key)

                def _forward(self, *args, **kwargs):
                    # Run the saved original method, then feed the selected hidden state
                    # through the adapter. The output may be a tuple/list/dict (hidden
                    # state selected by `config.hidden_pos`) or a bare tensor.
                    args = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
                    if isinstance(args, (tuple, list, dict)):
                        if isinstance(config.hidden_pos, int):
                            # Rebuild the original sequence type after in-place update.
                            _type = type(args)
                            args = list(args)
                            args[config.hidden_pos] = getattr(self, f'adapter_{adapter_name}')(args[config.hidden_pos])
                            args = _type(args)
                        else:
                            args[config.hidden_pos] = getattr(self, f'adapter_{adapter_name}')(args[config.hidden_pos])
                    elif isinstance(args, torch.Tensor):
                        args = getattr(self, f'adapter_{adapter_name}')(args)
                    return args

                def _feed_forward_chunk(self, attention_output):
                    # One-positional-arg variant required by transformers' chunked feed-forward.
                    return _forward(self, attention_output)

                # TODO The `config.method_name` method should not be replaced twice.

                # Save the original method so _forward can call it, then patch it.
                setattr(module, f'forward_origin_{adapter_name}', getattr(module, config.method_name))
                num_args_in_forward_chunk_fn = len(
                    inspect.signature(getattr(module, f'forward_origin_{adapter_name}')).parameters)
                # `feed_forward_chunk(attention_output)` must keep its exact signature
                # for transformers' apply_chunking_to_forward to work.
                if config.method_name == 'feed_forward_chunk' and num_args_in_forward_chunk_fn == 1:
                    setattr(module, config.method_name, types.MethodType(_feed_forward_chunk, module))
                else:
                    setattr(module, config.method_name, types.MethodType(_forward, module))
                adapter_module = AdapterModule(config.dim, adapter_name, module_key, config.adapter_length,
                                               ACT2CLS[config.act_layer])
                setattr(module, f'adapter_{adapter_name}', adapter_module)
                logger.info(f'Adapter modules(module_key): {module_key}.adapter_{adapter_name}')

        def state_dict_callback(state_dict, adapter_name: str, **kwargs):
            # Persist only this adapter's parameters when saving.
            return {key: value for key, value in state_dict.items() if f'adapter_{adapter_name}' in key}

        def mark_trainable_callback(model):
            # Adapter parameters are already trainable; nothing extra to mark.
            return

        return SwiftOutput(
            config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        """Toggle the named adapter and optionally offload its weights to save memory."""
        modules = find_sub_module(module, f'adapter_{adapter_name}')
        for _module in modules:
            _module: ActivationMixin
            _module: nn.Module
            _module.set_activation(adapter_name, activate)
            SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)
130
+
131
+
132
class AdapterModule(nn.Module, ActivationMixin):
    """The implementation of adapter tuning method.

    Adapters project input tokens by an MLP layer.
    'Parameter-Efficient Transfer Learning for NLP' by Houlsby et al.(2019)
    See http://arxiv.org/abs/1902.00751

    Args:
        dim: An integer indicating the embedding dimension.
        adapter_length: An integer indicating the length of adapter tuning.
    """

    def __init__(
        self,
        dim,
        adapter_name,
        module_key,
        adapter_length=None,
        act_layer=nn.GELU,
    ):
        super(AdapterModule, self).__init__()
        # Skips nn.Module in the MRO to reach ActivationMixin.__init__(module_key).
        super(nn.Module, self).__init__(module_key)
        self.dim = dim
        self.adapter_name = adapter_name
        self.adapter_length = adapter_length
        # Bottleneck MLP: dim -> adapter_length -> dim, wrapped by a residual in forward().
        self.linear1 = nn.Linear(dim, adapter_length)
        self.act = act_layer()
        self.linear2 = nn.Linear(adapter_length, dim)
        self.init_weights()
        # Device placement is resolved lazily on the first forward call.
        self._prepared = False
        self.mark_all_sub_modules_as_plugin()

    def init_weights(self):
        """Xavier-initialize linear weights; biases get small normal noise."""

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.normal_(m.bias, std=1e-6)

        self.apply(_init_weights)

    def forward(self, x, identity=None):
        # Act as identity when this adapter is deactivated.
        if not self.is_activated(self.adapter_name):
            return x
        # Move submodules to the input's device once.
        if not self._prepared:
            self.linear1.to(x.device)
            self.act.to(x.device)
            self.linear2.to(x.device)
            self._prepared = True

        # Compute in the adapter's dtype, then cast the residual sum back to the
        # caller's dtype so mixed-precision callers see no dtype change.
        x_dtype = x.dtype
        x = x.to(self.linear1.weight.dtype)
        out = self.linear2(self.act(self.linear1(x)))
        if identity is None:
            identity = x
        identity = identity.to(out.dtype)
        out = identity + out
        return out.to(x_dtype)
ms-swift/swift/tuners/longlora/__pycache__/longlora.cpython-310.pyc ADDED
Binary file (4.2 kB). View file
 
ms-swift/swift/tuners/peft.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ # Copyright 2023-present the HuggingFace Inc. team.
3
+ import os.path
4
+ from dataclasses import asdict, dataclass, field
5
+ from functools import partial, reduce
6
+ from types import MethodType
7
+ from typing import Dict, Optional
8
+
9
+ import json
10
+ import peft
11
+ import torch
12
+ import torch.nn
13
+ import transformers
14
+ from modelscope import snapshot_download
15
+ from peft import (AdaLoraConfig, BOFTConfig, BOFTModel, LoftQConfig, LoHaConfig, LoKrConfig, LoraModel, OFTConfig,
16
+ PeftConfig, PeftModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM,
17
+ PeftModelForSequenceClassification, PeftModelForTokenClassification, PrefixTuningConfig,
18
+ PromptEncoderConfig, PromptLearningConfig, PromptTuningConfig, VeraConfig, VeraModel, get_peft_config,
19
+ get_peft_model, get_peft_model_state_dict)
20
+ from peft.config import PeftConfigMixin
21
+ from peft.tuners import lora
22
+ from peft.tuners.adalora import AdaLoraModel, RankAllocator
23
+ from peft.tuners.lora import Embedding
24
+ from transformers import Trainer
25
+
26
+ from swift.utils import get_logger
27
+
28
+ try:
29
+ from peft import FourierFTModel
30
+ except ImportError:
31
+ FourierFTModel = None
32
+
33
+ try:
34
+ from peft import BoneModel
35
+ except ImportError:
36
+ BoneModel = None
37
+
38
+ logger = get_logger()
39
+ dispatchers = []
40
+
41
+
42
@dataclass
class LoraConfig(peft.LoraConfig):
    """`peft.LoraConfig` extended with swift-specific fields: lora weight dtype and
    LoRA+ learning-rate settings. The extras are saved to `additional_config.json`
    next to the vanilla peft config."""

    lora_dtype: Optional[str] = field(
        default=None, metadata={'help': 'The lora dtype, default None means following the original layer\'s dtype'})

    lorap_lr_ratio: Optional[float] = field(default=None, metadata={'help': 'The lr ratio of lora_B in lora+'})

    lorap_emb_lr: float = field(default=1e-6, metadata={'help': 'The lr for embedding in lora+'})

    def to_peft_config(self) -> peft.LoraConfig:
        """Strip the swift-only fields and return a vanilla `peft.LoraConfig`."""
        _dict = asdict(self)
        _dict.pop('lora_dtype')
        _dict.pop('lorap_lr_ratio')
        _dict.pop('lorap_emb_lr')
        return peft.LoraConfig(**_dict)

    def save_pretrained(self, save_directory: str, **kwargs) -> None:
        # Save the vanilla peft config first, then the swift-only extras alongside it.
        self.to_peft_config().save_pretrained(save_directory, **kwargs)
        additional_args = {
            'lora_dtype': self.lora_dtype,
            'lorap_lr_ratio': self.lorap_lr_ratio,
            'lorap_emb_lr': self.lorap_emb_lr,
        }
        with open(os.path.join(save_directory, 'additional_config.json'), 'w', encoding='utf-8') as f:
            json.dump(additional_args, f)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional[str] = None, **kwargs):
        # Prefer the unpatched loader saved by hot_patch_peft_module(), otherwise
        # this classmethod would recurse through the patched PeftConfigMixin.
        if hasattr(PeftConfigMixin, 'from_pretrained_origin'):
            self = PeftConfigMixin.from_pretrained_origin(pretrained_model_name_or_path, subfolder, **kwargs)
        else:
            self = super(LoraConfig, cls).from_pretrained(pretrained_model_name_or_path, subfolder, **kwargs)

        # Upgrade a plain peft config to this subclass so the extra fields exist.
        if type(self) == peft.LoraConfig:
            self = LoraConfig(**self.to_dict())

        # Restore the swift-only extras written by save_pretrained, if present.
        if os.path.isfile(os.path.join(pretrained_model_name_or_path, 'additional_config.json')):
            with open(
                    os.path.join(pretrained_model_name_or_path, 'additional_config.json'), 'r', encoding='utf-8') as f:
                _json = json.load(f)
                for key, value in _json.items():
                    setattr(self, key, value)

        return self
86
+
87
+
88
def _create_and_replace_hook(self, peft_config, adapter_name, target, *args, **kwargs):
    """Guarded wrapper for `<TunerModel>._create_and_replace`: skip layer types the
    tuner cannot wrap, then delegate to the saved original implementation.

    The name-based filter only applies when `target_modules` is a regex (str);
    list-based matching relies on the original implementation's own filtering.
    """
    all_supported_names = ('linear', )
    all_supported_types = (torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D, lora.Linear)
    target_modules = getattr(peft_config, 'target_modules', None)
    if target is None:
        return

    # Regex target_modules can match arbitrary modules; only proceed for layer
    # types the tuner actually supports.
    if isinstance(target_modules, str) and not any(
            [name in target.__class__.__name__.lower()
             for name in all_supported_names]) and not any([isinstance(target, type_) for type_ in all_supported_types]):
        return

    # NonDynamicallyQuantizableLinear (used inside nn.MultiheadAttention) breaks
    # when wrapped; skip it explicitly.
    if target.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
        return

    return self._create_and_replace_origin(peft_config, adapter_name, target, *args, **kwargs)
104
+
105
+
106
+ def _convert_dtype(target: torch.nn.Module, adapter_name: str, lora_dtype: str):
107
+ if lora_dtype is not None:
108
+ torch_dtype = eval(f'torch.{lora_dtype}')
109
+ if hasattr(target, 'lora_A') and adapter_name in target.lora_A:
110
+ target.lora_A[adapter_name].to(torch_dtype)
111
+ target.lora_B[adapter_name].to(torch_dtype)
112
+ if hasattr(target, 'lora_embedding_A') and adapter_name in target.lora_embedding_A:
113
+ target.lora_embedding_A[adapter_name].to(torch_dtype)
114
+ target.lora_embedding_B[adapter_name].to(torch_dtype)
115
+
116
+
117
def create_optimizer_param_groups(self: PeftModel, **defaults):
    """Build LoRA+ optimizer parameter groups (lora_B gets a boosted lr).

    Returns None when the active adapter is not a swift `LoraConfig` or has no
    `lorap_lr_ratio`, letting the trainer fall back to its default grouping.
    """
    if not isinstance(self.peft_config[self.active_adapter],
                      LoraConfig) or self.peft_config[self.active_adapter].lorap_lr_ratio is None:
        return None

    def get_module(name):
        # Map a parameter name to its owning module: strip the trailing
        # 'weight'/'bias' segment, plus the '<lora_X>.<adapter>' suffix for lora params.
        parent_idx = 2 if 'lora' in name else 1
        module_names = name.split(sep='.')[:-parent_idx]
        module = reduce(getattr, module_names, self.base_model)
        return module

    param_groups = {
        'groupA': {},
        'groupB': {},
        'groupB_no_decay': {},
        'embedding': {},
    }

    # Trainer.get_decay_parameter_names only inspects the model; passing None
    # as `self` is deliberate.
    decay_parameters = Trainer.get_decay_parameter_names(None, self.base_model)
    for name, param in self.base_model.named_parameters():
        if not param.requires_grad:
            continue

        module = get_module(name)
        if isinstance(module, Embedding):
            param_groups['embedding'][name] = param
        elif 'lora_B' in name or param.ndim == 1:
            # lora_B and 1-d params (biases, norms) get the lr boosted by lorap_lr_ratio.
            if name in decay_parameters:
                param_groups['groupB'][name] = param
            else:
                param_groups['groupB_no_decay'][name] = param
        else:
            param_groups['groupA'][name] = param

    lr = defaults['lr']
    weight_decay = defaults.get('weight_decay', 0.0)

    param_groups = [
        {
            'params': list(param_groups['groupA'].values()),
            'weight_decay': weight_decay,
            'lr': lr,
        },
        {
            'params': list(param_groups['embedding'].values()),
            'weight_decay': weight_decay,
            'lr': self.peft_config[self.active_adapter].lorap_emb_lr,
        },
        {
            'params': list(param_groups['groupB'].values()),
            'weight_decay': weight_decay,
            'lr': lr * self.peft_config[self.active_adapter].lorap_lr_ratio,
        },
        {
            'params': list(param_groups['groupB_no_decay'].values()),
            'weight_decay': 0.0,
            'lr': lr * self.peft_config[self.active_adapter].lorap_lr_ratio,
        },
    ]
    return param_groups
177
+
178
+
179
def adalora_forward(self, *args, **kwargs):
    """AdaLoRA forward patched for device_map support: regularization terms are
    moved onto the device of the tensors they combine with before accumulation."""
    from peft.utils.integrations import gather_params_ctx
    outputs = self.model.forward(*args, **kwargs)

    if (getattr(outputs, 'loss', None) is not None) and isinstance(outputs.loss, torch.Tensor):
        # Calculate the orthogonal regularization
        orth_reg_weight = self.peft_config[self.trainable_adapter_name].orth_reg_weight

        if orth_reg_weight <= 0:
            raise ValueError('orth_reg_weight should be greater than 0. ')

        regu_loss = 0
        num_param = 0
        for n, p in self.model.named_parameters():
            if ('lora_A' in n or 'lora_B' in n) and self.trainable_adapter_name in n:
                # A zero-size shape indicates a sharded parameter (e.g. ZeRO-3);
                # gather it before computing the covariance.
                if p.shape == torch.Size([0]):
                    with gather_params_ctx(p, fwd_module=self):
                        para_cov = p @ p.T if 'lora_A' in n else p.T @ p
                else:
                    para_cov = p @ p.T if 'lora_A' in n else p.T @ p
                I = torch.eye(*para_cov.size(), out=torch.empty_like(para_cov))  # noqa: E741
                I.requires_grad = False
                num_param += 1
                # Parameters may live on different devices under device_map.
                if isinstance(regu_loss, torch.Tensor):
                    regu_loss = regu_loss.to(para_cov.device)
                regu_loss += torch.norm(para_cov - I, p='fro')
        if num_param > 0:
            regu_loss = regu_loss / num_param
        else:
            regu_loss = 0
        if isinstance(regu_loss, torch.Tensor) and isinstance(outputs.loss, torch.Tensor):
            regu_loss = regu_loss.to(outputs.loss.device)
        outputs.loss += orth_reg_weight * regu_loss
    return outputs
213
+
214
+
215
def adalora_mask_to_budget(self, model, budget):
    """AdaLoRA RankAllocator.mask_to_budget patched for device_map support:
    per-triplet scores may live on different devices, so they are moved onto a
    common device before ranking."""
    value_ipt = {}
    vector_ipt = {}
    triplet_ipt = {}
    # Get the importance score for A, E, B
    for n, p in model.named_parameters():
        if f'lora_A.{self.adapter_name}' in n:
            entry_ipt = self._element_score(n)
            comb_ipt = torch.mean(entry_ipt, dim=1, keepdim=True)
            name_m = n.replace('lora_A', '%s')
            if name_m not in vector_ipt:
                vector_ipt[name_m] = [comb_ipt]
            else:
                vector_ipt[name_m].append(comb_ipt)
        if f'lora_B.{self.adapter_name}' in n:
            entry_ipt = self._element_score(n)
            comb_ipt = torch.mean(entry_ipt, dim=0, keepdim=False).view(-1, 1)
            name_m = n.replace('lora_B', '%s')
            if name_m not in vector_ipt:
                vector_ipt[name_m] = [comb_ipt]
            else:
                vector_ipt[name_m].append(comb_ipt)
        if f'lora_E.{self.adapter_name}' in n:
            entry_ipt = self._element_score(n)
            name_m = n.replace('lora_E', '%s')
            value_ipt[name_m] = entry_ipt

    all_score = []
    # Calculate the score for each triplet
    for name_m in vector_ipt:
        ipt_E = value_ipt[name_m]
        ipt_AB = torch.cat(vector_ipt[name_m], dim=1)
        sum_ipt = self._combine_ipt(ipt_E, ipt_AB)
        name_E = name_m % 'lora_E'
        triplet_ipt[name_E] = sum_ipt.view(-1, 1)
        sum_ipt = sum_ipt.view(-1)
        # Under device_map, scores from different layers can sit on different
        # devices; unify them so torch.cat below succeeds.
        if all_score:
            sum_ipt = sum_ipt.to(all_score[0].device)
        all_score.append(sum_ipt)

    # Get the threshold by ranking ipt
    mask_threshold = torch.kthvalue(
        torch.cat(all_score),
        k=self.init_bgt - budget,
    )[0].item()

    rank_pattern = {}
    # Mask the unimportant triplets
    with torch.no_grad():
        for n, p in model.named_parameters():
            if f'lora_E.{self.adapter_name}' in n:
                p.masked_fill_(triplet_ipt[n] <= mask_threshold, 0.0)
                rank_pattern[n] = (~(triplet_ipt[n] <= mask_threshold)).view(-1).tolist()
    return rank_pattern
269
+
270
+
271
def keep_device_forward(self, *args, **kwargs):
    """Forward wrapper that moves the first input onto the module's weight device.

    Used to patch lora layers so they keep working when model parts live on
    different devices (device_map).
    """
    inp, rest = args[0], args[1:]
    target_device = self.weight.device
    if inp.device == target_device:
        return self.forward_origin(*args, **kwargs)
    return self.forward_origin(inp.to(target_device), *rest, **kwargs)
277
+
278
+
279
def hot_patch_peft_module():
    """Monkey-patch peft so swift features work: skip unsupported layers in
    `_create_and_replace`, lora dtype conversion, LoRA+ param groups, swift
    `LoraConfig` loading, SwiftModel-compatible adapter activation, and
    device_map support for AdaLoRA. Idempotent: returns early if already applied.
    """
    from peft.tuners.lora import LoraLayer
    # Bug fix: the guard previously called hasattr('LoraModel', ...) on the
    # *string*, which is always False, so the patch could be applied twice.
    if hasattr(LoraModel, '_create_and_replace_origin'):
        return

    # Fix Lora does not support NonDynamicallyQuantizableLinear
    LoraModel._create_and_replace_origin = LoraModel._create_and_replace
    LoraModel._create_and_replace = _create_and_replace_hook
    AdaLoraModel._create_and_replace_origin = AdaLoraModel._create_and_replace
    AdaLoraModel._create_and_replace = _create_and_replace_hook
    VeraModel._create_and_replace_origin = VeraModel._create_and_replace
    VeraModel._create_and_replace = _create_and_replace_hook
    BOFTModel._create_and_replace_origin = BOFTModel._create_and_replace
    BOFTModel._create_and_replace = _create_and_replace_hook
    if FourierFTModel is not None:
        FourierFTModel._create_and_replace_origin = FourierFTModel._create_and_replace
        FourierFTModel._create_and_replace = _create_and_replace_hook
    if BoneModel is not None:
        BoneModel._create_and_replace_origin = BoneModel._create_and_replace
        BoneModel._create_and_replace = _create_and_replace_hook

    # Support type conversion
    def __new_init__(self, model: torch.nn.Module, config: Dict[str, LoraConfig], adapter_name: str):

        self.__init_origin__(model, config, adapter_name)
        active_adapters = self.active_adapter
        if isinstance(active_adapters, str):
            active_adapters = [active_adapters]
        for active_adapter in active_adapters:
            active_config = config[active_adapter] if isinstance(config, dict) else config
            if hasattr(active_config, 'lora_dtype'):
                for name, module in model.named_modules():
                    if isinstance(module, LoraLayer):
                        _convert_dtype(module, active_adapter, active_config.lora_dtype)
                        # `lora_layer` (not `lora`) to avoid shadowing the
                        # module-level `from peft.tuners import lora` import.
                        for lora_layer in list(module.lora_A.values()) + list(module.lora_B.values()):
                            if not hasattr(lora_layer, 'forward_origin'):
                                lora_layer.forward_origin = lora_layer.forward
                                lora_layer.forward = MethodType(keep_device_forward, lora_layer)

    LoraModel.__init_origin__ = LoraModel.__init__
    LoraModel.__init__ = __new_init__

    # Support LoRA+
    PeftModel.create_optimizer_param_groups = create_optimizer_param_groups

    PeftConfigMixin.from_pretrained_origin = PeftConfigMixin.from_pretrained
    PeftConfigMixin.from_pretrained = LoraConfig.from_pretrained

    # Compatible with SwiftModel
    def dummy_function(*args, **kwargs):
        # logger.warning: `Logger.warn` is deprecated in the stdlib.
        logger.warning(f'The function {kwargs["func"]} has no effects, consider using other functions.')

    PeftModel.activate_adapter = PeftModel.set_adapter
    PeftModel.deactivate_adapter = partial(dummy_function, func='deactivate_adapter')
    PeftModel.set_active_adapters = partial(dummy_function, func='set_active_adapters')

    # Fix adalora does not support device_map
    AdaLoraModel.forward = adalora_forward
    RankAllocator.mask_to_budget = adalora_mask_to_budget
+
339
+
340
def get_wrapped_class(module_class):
    """Get a custom wrapper class for peft classes to download the models from the ModelScope hub

    Args:
        module_class: The actual module class

    Returns:
        The wrapper
    """

    class PeftWrapper(module_class):

        @classmethod
        def from_pretrained(cls, model, model_id, *args, revision: Optional[str] = None, **kwargs):
            # Anything that is not an existing local path is treated as a
            # ModelScope model id and resolved to a local snapshot first.
            local_path = model_id if os.path.exists(model_id) else snapshot_download(model_id, revision=revision)
            return module_class.from_pretrained(model, local_path, *args, **kwargs)

    # Make the wrapper indistinguishable from the wrapped class by name.
    PeftWrapper.__name__ = module_class.__name__
    PeftWrapper.__qualname__ = module_class.__qualname__
    return PeftWrapper
361
+
362
+
363
def wrap_module(module):
    """Wrap classes exposing `from_pretrained` so ModelScope ids resolve; pass
    everything else through unchanged."""
    return get_wrapped_class(module) if hasattr(module, 'from_pretrained') else module
368
+
369
+
370
# Apply the peft monkey-patches once at import time, then re-export the peft
# classes wrapped so `from_pretrained` can resolve ModelScope model ids.
hot_patch_peft_module()
PeftModel = wrap_module(PeftModel)
PeftConfig = wrap_module(PeftConfig)
PeftModelForSeq2SeqLM = wrap_module(PeftModelForSeq2SeqLM)
PeftModelForSequenceClassification = wrap_module(PeftModelForSequenceClassification)
PeftModelForTokenClassification = wrap_module(PeftModelForTokenClassification)
PeftModelForCausalLM = wrap_module(PeftModelForCausalLM)
PromptEncoderConfig = wrap_module(PromptEncoderConfig)
PromptTuningConfig = wrap_module(PromptTuningConfig)
PrefixTuningConfig = wrap_module(PrefixTuningConfig)
PromptLearningConfig = wrap_module(PromptLearningConfig)
LoraConfig = wrap_module(LoraConfig)
AdaLoraConfig = wrap_module(AdaLoraConfig)
LoHaConfig = wrap_module(LoHaConfig)
LoKrConfig = wrap_module(LoKrConfig)
LoftQConfig = wrap_module(LoftQConfig)
OFTConfig = wrap_module(OFTConfig)  # previously wrapped twice; once suffices
BOFTConfig = wrap_module(BOFTConfig)
VeraConfig = wrap_module(VeraConfig)
# Plain re-exports so callers can import these helpers from this module.
get_peft_config = get_peft_config
get_peft_model_state_dict = get_peft_model_state_dict
get_peft_model = get_peft_model
ms-swift/swift/tuners/scetuning/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (246 Bytes). View file
 
ms-swift/swift/tuners/scetuning/__pycache__/scetuning.cpython-310.pyc ADDED
Binary file (8.37 kB). View file
 
ms-swift/swift/tuners/scetuning/__pycache__/scetuning_components.cpython-310.pyc ADDED
Binary file (4.21 kB). View file
 
ms-swift/swift/tuners/scetuning/scetuning_components.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from swift.utils.logger import get_logger
8
+
9
+ logger = get_logger()
10
+
11
+
12
def detach_tensors(feats):
    """Recursively detach every tensor inside *feats*.

    Lists/tuples are rebuilt as lists (None entries preserved), dicts are
    rebuilt value-wise, tensors are detached, and any other leaf is returned
    unchanged. The previous fallback called `.detach()` on arbitrary objects,
    which raised AttributeError for non-tensor leaves (ints, strings, ...).
    """
    if type(feats) in [list, tuple]:
        feats = [detach_tensors(feat) if feat is not None else None for feat in feats]
    elif isinstance(feats, dict):
        feats = {key: detach_tensors(val) for key, val in feats.items()}
    elif isinstance(feats, torch.Tensor):
        feats = feats.detach()
    # Non-tensor leaves pass through as-is.
    return feats
22
+
23
+
24
def probe_tensors(module, feats, name):
    """Detach *feats* and stash the result on *module* as attribute *name*."""
    detached = detach_tensors(feats)
    setattr(module, name, detached)
27
+
28
+
29
def probe_input_pre_hook(self, args):
    """Forward pre-hook: snapshot the first positional input on the module."""
    probe_tensors(self, args[0], 'probe_input_data')
    return args
33
+
34
+
35
def probe_output_hook(self, args, result):
    """Forward hook: snapshot the module output and pass it through unchanged."""
    probe_tensors(self, result, 'probe_output_data')
    return result
39
+
40
+
41
def choose_weight_type(weight_type, dim):
    """Build the scaling object for the requested weighting scheme.

    Returns an `nn.Linear` for 'gate', a 1-filled `nn.Parameter` (scalar or
    per-channel) for 'scale'/'scale_channel', a float parsed from the suffix
    for 'scalar_<v>', and None otherwise.
    """
    if weight_type == 'gate':
        return nn.Linear(dim, 1)
    if weight_type == 'scale':
        param = nn.Parameter(torch.Tensor(1))
        param.data.fill_(1)
        return param
    if weight_type == 'scale_channel':
        param = nn.Parameter(torch.Tensor(dim))
        param.data.fill_(1)
        return param
    if weight_type and weight_type.startswith('scalar'):
        return float(weight_type.split('_')[-1])
    return None
55
+
56
+
57
def get_weight_value(weight_type, scaling, x):
    """Resolve the scaling value for the configured weighting scheme.

    Args:
        weight_type: One of 'gate', 'scale', 'scale_channel', 'scalar_<v>' or None.
        scaling: The object produced by `choose_weight_type`.
        x: The adapter output, used only by the 'gate' scheme.

    Returns:
        A broadcastable tensor for 'gate', `scaling` unchanged for the static
        schemes, or None when no weighting applies.
    """
    # Bug fix: weight_type may be None (no weighting configured); previously
    # this fell through to `None.startswith('scalar')` and raised AttributeError.
    if weight_type is None:
        return None
    if weight_type == 'gate':
        # Per-sample sigmoid gate averaged over the sequence, shaped (b, 1, 1)
        # so it broadcasts over (batch, seq, dim).
        return torch.mean(torch.sigmoid(scaling(x)), dim=1).view(-1, 1, 1)
    if weight_type in ['scale', 'scale_channel'] or weight_type.startswith('scalar'):
        return scaling
    return None
65
+
66
+
67
class SCEAdapter(nn.Module):
    """SCEdit adapter: a bottleneck MLP with an optional learned/static output
    scaling and a residual shortcut.

    Args:
        dim: Embedding/channel dimension of the input.
        adapter_length: Bottleneck (intermediate) dimension.
        adapter_type: Unused here; kept for config compatibility.
        adapter_weight: Weighting scheme passed to `choose_weight_type`
            ('gate', 'scale', 'scale_channel', 'scalar_<v>') or None.
        act_layer: Activation class for the bottleneck.
        zero_init_last: Zero-init the output projection so the adapter starts
            as an identity mapping.
        use_bias: Whether the linear layers carry biases.
    """

    def __init__(self,
                 dim,
                 adapter_length,
                 adapter_type=None,
                 adapter_weight=None,
                 act_layer=nn.GELU,
                 zero_init_last=True,
                 use_bias=True):
        super(SCEAdapter, self).__init__()
        self.dim = dim
        self.adapter_length = adapter_length
        self.adapter_type = adapter_type
        self.adapter_weight = adapter_weight
        self.zero_init_last = zero_init_last
        # Bottleneck MLP: dim -> adapter_length -> dim.
        self.ln1 = nn.Linear(dim, adapter_length, bias=use_bias)
        self.activate = act_layer()
        self.ln2 = nn.Linear(adapter_length, dim, bias=use_bias)
        self.init_weights()
        self.init_scaling()

    def _zero_init_weights(self, m):
        # Zero weights AND bias so the layer outputs exactly 0 at step 0.
        if isinstance(m, nn.Linear):
            nn.init.zeros_(m.weight)
            nn.init.zeros_(m.bias)

    def _kaiming_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))

    def init_weights(self):
        # Zero-initializing ln2 makes the residual branch contribute nothing
        # initially, so training starts from the frozen model's behavior.
        self._kaiming_init_weights(self.ln1)
        if self.zero_init_last:
            self._zero_init_weights(self.ln2)
        else:
            self._kaiming_init_weights(self.ln2)

    def init_scaling(self):
        if self.adapter_weight:
            self.scaling = choose_weight_type(self.adapter_weight, self.dim)
        else:
            self.scaling = None

    def forward(self, x, x_shortcut=None, use_shortcut=True, **kwargs):
        if x_shortcut is None:
            x_shortcut = x
        x_shape = x.shape
        # 4-d (b, c, h, w) feature maps are flattened to (b, h*w, c) for the MLP
        # and restored afterwards.
        if len(x_shape) == 4:
            b, d, h, w = x_shape
            x = x.permute(0, 2, 3, 1).reshape(b, h * w, d)
        out = self.ln2(self.activate(self.ln1(x)))
        if self.adapter_weight:
            scaling = get_weight_value(self.adapter_weight, self.scaling, out)
            out = out * scaling if scaling is not None else out
        if len(x_shape) == 4:
            b, d, h, w = x_shape
            out = out.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
        if use_shortcut:
            out = x_shortcut + out
        return out
ms-swift/swift/tuners/side.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import copy
3
+ import re
4
+ import types
5
+ from collections import OrderedDict
6
+ from dataclasses import dataclass, field
7
+ from functools import partial
8
+ from itertools import repeat
9
+ from typing import Union
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+ from swift.utils.logger import get_logger
15
+ from swift.utils.torch_utils import find_sub_module
16
+ from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
17
+
18
+ logger = get_logger()
19
+
20
+
21
@dataclass
class SideConfig(SwiftConfig):
    """
    The configuration class for the side module.

    Side-Tuning only needs to train one side network and
    weights the output of pre-trained model and side network.
    'Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks'
    by Zhang et al.(2019)
    See https://arxiv.org/abs/1912.13503

    Args:
        target_modules: The feedforward module to be replaced, in regex format
    """

    # Hidden-state dimension of the wrapped model; used to size the side network.
    dim: int = field(default=None, metadata={'help': 'The dimension of the hidden states'})

    target_modules: str = field(
        default=None, metadata={'help': 'The target module to be replaced, in full match format'})

    # Which side-network architecture to instantiate; see SideModule
    # ('fcn4', 'mlp' or 'alexnet').
    side_module_name: str = field(default='fcn4', metadata={'help': 'The name of the additive side networks'})

    # int -> positional index into *args; str -> key into **kwargs.
    source_hidden_pos: Union[str, int] = field(
        default=0,
        metadata={
            'help': 'The position of the hidden state input to the target module, can be int (args) or str (kwargs)'
        })

    # int -> positional index into the original forward's output; str -> dict key.
    target_hidden_pos: Union[str, int] = field(
        default=0,
        metadata={
            'help': 'The position of the hidden state output from the target module, can be int (args) or str (kwargs)'
        })

    def __post_init__(self):
        # Imported locally to avoid a circular import with the tuner mapping module.
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.SIDE
58
+
59
+
60
class Side(SwiftAdapter):
    """Swift adapter that patches matched modules with an additive side network."""

    @staticmethod
    def prepare_model(model: nn.Module, config: SideConfig, adapter_name: str) -> SwiftOutput:
        """Prepare a model with `SideConfig`"""
        module_keys = [key for key, _ in model.named_modules()]

        for module_key in module_keys:
            if re.fullmatch(config.target_modules, module_key):  # noqa
                tgt_module = model.get_submodule(module_key)
                logger.info(f'Matching target module [{module_key}] of type {type(tgt_module)}')
                if isinstance(tgt_module, (nn.ModuleList, nn.ModuleDict)):
                    raise Exception(
                        f'Type of {type(tgt_module)} may not be supported because of its customized forward')

                # Patched forward: run the original forward, then blend its output
                # with the side network inside SideModule.  Closes over `config`
                # and `adapter_name`.
                def _forward(self, *args, **kwargs):
                    args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)

                    # Pick the hidden-state input from args (int pos) or kwargs (str key).
                    if isinstance(config.source_hidden_pos, int):
                        x = args[config.source_hidden_pos]
                    else:
                        x = kwargs[config.source_hidden_pos]

                    x_main = args_main[config.target_hidden_pos] \
                        if isinstance(args_main, (tuple, list, dict)) else args_main
                    out = getattr(self, f'side_{adapter_name}')(x, x_main)
                    # Write the blended output back into the original output structure.
                    if isinstance(args_main, (tuple, list, dict)):
                        args_main[config.target_hidden_pos] = out
                    else:
                        args_main = out
                    return args_main

                if isinstance(tgt_module, nn.Sequential) and not hasattr(tgt_module, 'tgt_module_keys'):
                    # Remember the original children so the sequential pass below
                    # does not execute modules appended after patching (e.g. the
                    # side module added via setattr further down).
                    tgt_module.tgt_module_keys = copy.deepcopy(list(tgt_module._modules.keys()))

                    def forward_seq(self, input, *args, **kwargs):
                        # Re-implementation of nn.Sequential.forward restricted to
                        # the original children; closes over `tgt_module` (late
                        # binding — valid because one closure is made per match).
                        for idx, module in enumerate(self):
                            if idx >= len(tgt_module.tgt_module_keys):
                                continue
                            input = module(input)
                        return input

                    setattr(tgt_module, f'forward_origin_{adapter_name}', types.MethodType(forward_seq, tgt_module))
                else:
                    setattr(tgt_module, f'forward_origin_{adapter_name}', tgt_module.forward)
                tgt_module.forward = types.MethodType(_forward, tgt_module)
                side_module = SideModule(config.dim, adapter_name, module_key, config.side_module_name)
                setattr(tgt_module, f'side_{adapter_name}', side_module)
                logger.info(f'Side modules(module_key): {module_key}.side_{adapter_name}')

        def state_dict_callback(state_dict, adapter_name, **kwargs):
            # Persist only the side-network parameters of this adapter.
            return {key: value for key, value in state_dict.items() if f'side_{adapter_name}' in key}

        def mark_trainable_callback(model):
            # No extra parameters to mark trainable beyond the side modules.
            return

        return SwiftOutput(
            config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        """Toggle all side modules belonging to `adapter_name`, optionally offloading them."""
        modules = find_sub_module(module, f'side_{adapter_name}')
        for _module in modules:
            _module: ActivationMixin
            _module: nn.Module
            _module.set_activation(adapter_name, activate)
            SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)
127
+
128
+
129
class SideModule(nn.Module, ActivationMixin):
    """The implementation of vision side-tuning method.

    Side-Tuning only needs to train one side network and
    weights the output of pre-trained model and side network.
    'Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks'
    by Zhang et al.(2019)
    See https://arxiv.org/abs/1912.13503

    Args:
        side_module_name: The name of the additive side networks.
    """

    def __init__(self, dim, adapter_name, module_key, side_module_name='fcn4'):
        super(SideModule, self).__init__()
        # Initialize ActivationMixin: super(nn.Module, self) resolves to the
        # class after nn.Module in the MRO, i.e. ActivationMixin.
        super(nn.Module, self).__init__(module_key)
        self.adapter_name = adapter_name

        side_module_name = side_module_name.lower()
        if side_module_name == 'fcn4':
            self.side_net = FCN4(out_dims=dim)
        elif side_module_name == 'mlp':
            self.side_net = Mlp(dim)
        elif side_module_name == 'alexnet':
            import torchvision
            mm = torchvision.models.alexnet(pretrained=True)
            # Reuse AlexNet's feature extractor, re-headed to project into `dim`.
            self.side_net = nn.Sequential(
                OrderedDict([('features', mm.features), ('avgpool', mm.avgpool), ('flatten', nn.Flatten()),
                             ('fc', nn.Linear(9216, dim, bias=False))]))
        else:
            raise ValueError(f'Unsupported side_module_name: {side_module_name}')
        # Learnable blend weight; sigmoid(alpha)=0.5 at init, so the backbone and
        # side network start equally weighted.
        self.alpha = nn.Parameter(torch.tensor(0.0))
        self.mark_all_sub_modules_as_plugin()

    def forward(self, x, x_main):
        # When this adapter is deactivated, pass the backbone output through untouched.
        if not self.is_activated(self.adapter_name):
            return x_main
        alpha_squashed = torch.sigmoid(self.alpha)
        x_side = self.side_net(x)
        # Convex combination of backbone output and side-network output.
        x_out = alpha_squashed * x_main + (1 - alpha_squashed) * x_side
        return x_out
170
+
171
+
172
class FCN4(nn.Module):
    """A small four-stage convolutional network used as the default side network.

    Each stage is Conv3x3 -> GroupNorm -> ReLU; the result is global-average
    pooled and, when ``out_dims > 0``, projected to ``out_dims`` features.
    """

    def __init__(self, out_dims=-1, **kwargs):
        super(FCN4, self).__init__(**kwargs)

        def _stage(in_ch, out_ch, stride, padding):
            # One conv stage: 3x3 conv (no bias) + GroupNorm(2 groups) + ReLU.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=padding, bias=False, dilation=1),
                nn.GroupNorm(2, out_ch),
                nn.ReLU())

        # Attribute names are kept (conv1..conv4, pool, fc) so state_dict keys
        # stay compatible with previously saved adapters.
        self.conv1 = _stage(3, 16, 1, 1)
        self.conv2 = _stage(16, 16, 2, 0)
        self.conv3 = _stage(16, 32, 2, 0)
        self.conv4 = _stage(32, 64, 1, 0)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, out_dims) if out_dims > 0 else None

    def forward(self, x):
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        features = self.pool(x).view(x.size(0), -1)
        if self.fc is not None:
            features = self.fc(features)
        return features
207
+
208
+
209
class Mlp(nn.Module):
    """MLP block as used in Vision Transformer.

    Two linear (or 1x1 conv, when ``use_conv=True``) layers with an activation,
    optional normalization and dropout in between.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Scalars are broadcast to (first-layer, second-layer) pairs.
        bias_pair = tuple(repeat(bias, 2))
        drop_pair = tuple(repeat(drop, 2))
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        # Attribute names are kept so state_dict keys stay unchanged.
        self.fc1 = linear_layer(in_features, hidden_features, bias=bias_pair[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])
        self.norm = nn.Identity() if norm_layer is None else norm_layer(hidden_features)
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias_pair[1])
        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x):
        for layer in (self.fc1, self.act, self.drop1, self.norm, self.fc2, self.drop2):
            x = layer(x)
        return x
ms-swift/swift/ui/app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ from functools import partial
4
+ from typing import List, Union
5
+
6
+ import gradio as gr
7
+ from packaging import version
8
+ from transformers.utils import strtobool
9
+
10
+ import swift
11
+ from swift.llm import DeployArguments, EvalArguments, ExportArguments, RLHFArguments, SwiftPipeline, WebUIArguments
12
+ from swift.ui.llm_eval.llm_eval import LLMEval
13
+ from swift.ui.llm_export.llm_export import LLMExport
14
+ from swift.ui.llm_infer.llm_infer import LLMInfer
15
+ from swift.ui.llm_train.llm_train import LLMTrain
16
+
17
# Localized (zh/en) strings for the top-level web UI header; the active
# language selects which entry each HTML block renders.
locale_dict = {
    'title': {
        'zh': '🚀SWIFT: 轻量级大模型训练推理框架',
        'en': '🚀SWIFT: Scalable lightWeight Infrastructure for Fine-Tuning and Inference'
    },
    'sub_title': {
        'zh':
        '请查看 <a href=\"https://github.com/modelscope/swift/tree/main/docs/source\" target=\"_blank\">'
        'SWIFT 文档</a>来查看更多功能,使用SWIFT_UI_LANG=en环境变量来切换英文界面',
        'en':
        'Please check <a href=\"https://github.com/modelscope/swift/tree/main/docs/source_en\" target=\"_blank\">'
        'SWIFT Documentation</a> for more usages, Use SWIFT_UI_LANG=zh variable to switch to Chinese UI',
    },
    'star_beggar': {
        'zh':
        '喜欢<a href=\"https://github.com/modelscope/swift\" target=\"_blank\">SWIFT</a>就动动手指给我们加个star吧🥺 ',
        'en':
        'If you like <a href=\"https://github.com/modelscope/swift\" target=\"_blank\">SWIFT</a>, '
        'please take a few seconds to star us🥺 '
    },
}
+
39
+
40
class SwiftWebUI(SwiftPipeline):
    """Pipeline that assembles the train/infer/export/eval tabs into one gradio web UI."""

    args_class = WebUIArguments
    args: args_class

    def run(self):
        # Environment variables take precedence over the parsed arguments.
        lang = os.environ.get('SWIFT_UI_LANG') or self.args.lang
        share_env = os.environ.get('WEBUI_SHARE')
        share = strtobool(share_env) if share_env else self.args.share
        server = os.environ.get('WEBUI_SERVER') or self.args.server_name
        port_env = os.environ.get('WEBUI_PORT')
        port = int(port_env) if port_env else self.args.server_port
        # Propagate the chosen language to every tab before building.
        LLMTrain.set_lang(lang)
        LLMInfer.set_lang(lang)
        LLMExport.set_lang(lang)
        LLMEval.set_lang(lang)
        with gr.Blocks(title='SWIFT WebUI', theme=gr.themes.Base()) as app:
            try:
                _version = swift.__version__
            except AttributeError:
                _version = ''
            gr.HTML(f"<h1><center>{locale_dict['title'][lang]}({_version})</center></h1>")
            gr.HTML(f"<h3><center>{locale_dict['sub_title'][lang]}</center></h3>")
            with gr.Tabs():
                LLMTrain.build_ui(LLMTrain)
                LLMInfer.build_ui(LLMInfer)
                LLMExport.build_ui(LLMExport)
                LLMEval.build_ui(LLMEval)

            concurrent = {}
            # gradio < 4 requires an explicit concurrency_count on the queue.
            if version.parse(gr.__version__) < version.parse('4.0.0'):
                concurrent = {'concurrency_count': 5}
            # Pre-populate each tab's elements from its default model on page load.
            app.load(
                partial(LLMTrain.update_input_model, arg_cls=RLHFArguments),
                inputs=[LLMTrain.element('model')],
                outputs=[LLMTrain.element('train_record')] + list(LLMTrain.valid_elements().values()))
            app.load(
                partial(LLMInfer.update_input_model, arg_cls=DeployArguments, has_record=False),
                inputs=[LLMInfer.element('model')],
                outputs=list(LLMInfer.valid_elements().values()))
            app.load(
                partial(LLMExport.update_input_model, arg_cls=ExportArguments, has_record=False),
                inputs=[LLMExport.element('model')],
                outputs=list(LLMExport.valid_elements().values()))
            app.load(
                partial(LLMEval.update_input_model, arg_cls=EvalArguments, has_record=False),
                inputs=[LLMEval.element('model')],
                outputs=list(LLMEval.valid_elements().values()))
        app.queue(**concurrent).launch(server_name=server, inbrowser=True, server_port=port, height=800, share=share)
89
+
90
+
91
def webui_main(args: Union[List[str], WebUIArguments, None] = None):
    """Entry point: build the SWIFT web UI pipeline and run it."""
    pipeline = SwiftWebUI(args)
    return pipeline.main()
ms-swift/swift/ui/base.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import dataclasses
3
+ import os
4
+ import sys
5
+ import time
6
+ import typing
7
+ from collections import OrderedDict
8
+ from dataclasses import fields
9
+ from datetime import datetime
10
+ from functools import wraps
11
+ from typing import Any, Dict, List, Type
12
+
13
+ import gradio as gr
14
+ import json
15
+ from gradio import Accordion, Audio, Button, Checkbox, Dropdown, File, Image, Slider, Tab, TabItem, Textbox, Video
16
+ from modelscope.hub.utils.utils import get_cache_dir
17
+
18
+ from swift.llm import TEMPLATE_MAPPING, BaseArguments, get_matched_model_meta
19
+
20
all_langs = ['zh', 'en']
# Module-level build context: while a BaseUI subclass builds its tab these
# point at the tab currently being built (`builder`) and the root tab
# (`base_builder`), so the patched gradio constructors below can look up
# choices, defaults and locale strings for each component.
builder: Type['BaseUI'] = None
base_builder: Type['BaseUI'] = None


def update_data(fn):
    """Wrap a gradio component constructor.

    The wrapper injects, per ``elem_id``: dropdown choices, default values,
    localized label/info/value strings and the CLI argument name, all taken
    from the module-level ``builder``/``base_builder`` context, then registers
    the constructed component in ``builder.element_dict``.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        elem_id = kwargs.get('elem_id', None)
        self = args[0]

        if builder is not None:
            choices = base_builder.choice(elem_id)
            if choices:
                # gradio expects string choices; keep None as-is.
                choices = [str(choice) if choice is not None else None for choice in choices]
                kwargs['choices'] = choices

        # Layout containers stay non-interactive; everything else defaults to interactive.
        if not isinstance(self, (Tab, TabItem, Accordion)) and 'interactive' not in kwargs:  # noqa
            kwargs['interactive'] = True

        # 'is_list' is a BaseUI extension, not a gradio kwarg — pop it before calling fn.
        if 'is_list' in kwargs:
            self.is_list = kwargs.pop('is_list')

        if base_builder and base_builder.default(elem_id) is not None and not kwargs.get('value'):
            kwargs['value'] = base_builder.default(elem_id)

        if builder is not None:
            if elem_id in builder.locales(builder.lang):
                values = builder.locale(elem_id, builder.lang)
                if 'info' in values:
                    kwargs['info'] = values['info']
                if 'value' in values:
                    kwargs['value'] = values['value']
                if 'label' in values:
                    kwargs['label'] = values['label']
            if hasattr(builder, 'visible'):
                kwargs['visible'] = builder.visible
            # Append the CLI argument name to the label, e.g. "Model(--model)".
            argument = base_builder.argument(elem_id)
            if argument and 'label' in kwargs:
                kwargs['label'] = kwargs['label'] + f'({argument})'

        kwargs['elem_classes'] = 'align'
        ret = fn(self, **kwargs)
        self.constructor_args.update(kwargs)

        # Register the finished component so BaseUI.element(elem_id) can find it.
        if builder is not None:
            builder.element_dict[elem_id] = self
        return ret

    return wrapper
71
+
72
+
73
# Monkey-patch the gradio component constructors so that every component
# created while a BaseUI subclass is building automatically picks up its
# choices, defaults, locale strings and registration (see update_data above).
Textbox.__init__ = update_data(Textbox.__init__)
Dropdown.__init__ = update_data(Dropdown.__init__)
Checkbox.__init__ = update_data(Checkbox.__init__)
Slider.__init__ = update_data(Slider.__init__)
TabItem.__init__ = update_data(TabItem.__init__)
Accordion.__init__ = update_data(Accordion.__init__)
Button.__init__ = update_data(Button.__init__)
File.__init__ = update_data(File.__init__)
Image.__init__ = update_data(Image.__init__)
Video.__init__ = update_data(Video.__init__)
Audio.__init__ = update_data(Audio.__init__)
84
+
85
+
86
class BaseUI:
    """Base class for all web-UI tabs.

    Subclasses declare locale strings, choices and defaults at class level and
    implement ``do_build_ui``; the patched gradio constructors (``update_data``)
    fill the registries below while the tab is being built.  Nested tabs are
    composed through ``sub_ui``.
    """

    # elem_id -> dropdown choices.
    choice_dict: Dict[str, List] = {}
    # elem_id -> default value.
    default_dict: Dict[str, Any] = {}
    # elem_id -> {label/info/value -> {lang -> text}}.
    locale_dict: Dict[str, Dict] = {}
    # elem_id -> constructed gradio component (filled during build_ui).
    element_dict: Dict[str, Dict] = {}
    # elem_id -> CLI argument name, e.g. '--model'.
    arguments: Dict[str, str] = {}
    # Nested tabs whose elements/locales are merged into this one.
    sub_ui: List[Type['BaseUI']] = []
    group: str = None
    lang: str = all_langs[0]
    int_regex = r'^[-+]?[0-9]+$'
    float_regex = r'[-+]?(?:\d*\.*\d+)'
    bool_regex = r'^(T|t)rue$|^(F|f)alse$'
    # On-disk cache of past run settings: one JSON file per run, named
    # '<key>-<unix timestamp>'.
    cache_dir = os.path.join(get_cache_dir(), 'swift-web-ui')
    os.makedirs(cache_dir, exist_ok=True)
    # Shell quote character for generated command lines.
    quote = '\'' if sys.platform != 'win32' else '"'
    visible = True
    _locale = {
        'local_dir_alert': {
            'value': {
                'zh': '无法识别model_type和template,请手动选择',
                'en': 'Cannot recognize the model_type and template, please choose manually'
            }
        },
    }

    @classmethod
    def build_ui(cls, base_tab: Type['BaseUI']):
        """Build UI"""
        # Swap the module-level build context in (and restore it afterwards) so
        # the patched gradio constructors attribute new components to this tab.
        global builder, base_builder
        cls.element_dict = {}
        old_builder = builder
        old_base_builder = base_builder
        builder = cls
        base_builder = base_tab
        cls.do_build_ui(base_tab)
        builder = old_builder
        base_builder = old_base_builder
        # Only the root tab triggers the post-build hooks of its children.
        if cls is base_tab:
            for ui in cls.sub_ui:
                ui.after_build_ui(base_tab)

    @classmethod
    def after_build_ui(cls, base_tab: Type['BaseUI']):
        """Hook run after the whole tab tree has been built; override as needed."""
        pass

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build UI"""
        pass

    @classmethod
    def save_cache(cls, key, value):
        """Persist `value` as JSON under '<key>-<timestamp>' in the cache dir."""
        timestamp = str(int(time.time()))
        key = key.replace('/', '-')
        filename = os.path.join(cls.cache_dir, key + '-' + timestamp)
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(value, f)

    @classmethod
    def list_cache(cls, key):
        """List the saved-run timestamps for `key`, newest first."""
        files = []
        key = key.replace('/', '-')
        for _, _, filenames in os.walk(cls.cache_dir):
            for filename in filenames:
                if filename.startswith(key):
                    idx = filename.rfind('-')
                    # NOTE(review): this rebinds `key` to the matched file's key
                    # part, so later iterations compare against it instead of the
                    # original search key — confirm this is intended.
                    key, ts = filename[:idx], filename[idx + 1:]
                    dt_object = datetime.fromtimestamp(int(ts))
                    formatted_time = dt_object.strftime('%Y/%m/%d %H:%M:%S')
                    files.append(formatted_time)
        return sorted(files, reverse=True)

    @classmethod
    def load_cache(cls, key, timestamp) -> BaseArguments:
        """Load the cached settings for `key` saved at the formatted `timestamp`."""
        dt_object = datetime.strptime(timestamp, '%Y/%m/%d %H:%M:%S')
        timestamp = int(dt_object.timestamp())
        key = key.replace('/', '-')
        filename = key + '-' + str(timestamp)
        with open(os.path.join(cls.cache_dir, filename), 'r', encoding='utf-8') as f:
            return json.load(f)

    @classmethod
    def clear_cache(cls, key):
        """Delete every cache file whose name starts with `key`."""
        key = key.replace('/', '-')
        for _, _, filenames in os.walk(cls.cache_dir):
            for filename in filenames:
                if filename.startswith(key):
                    os.remove(os.path.join(cls.cache_dir, filename))

    @classmethod
    def choice(cls, elem_id):
        """Get choice by elem_id"""
        # Sub-tabs win over this tab's own choice_dict.
        for sub_ui in BaseUI.sub_ui:
            _choice = sub_ui.choice(elem_id)
            if _choice:
                return _choice
        return cls.choice_dict.get(elem_id, [])

    @classmethod
    def default(cls, elem_id):
        """Get the default value for elem_id, searching this tab then its sub-tabs."""
        if elem_id in cls.default_dict:
            return cls.default_dict.get(elem_id)
        for sub_ui in BaseUI.sub_ui:
            _choice = sub_ui.default(elem_id)
            if _choice:
                return _choice
        return None

    @classmethod
    def locale(cls, elem_id, lang):
        """Get locale by elem_id"""
        return cls.locales(lang)[elem_id]

    @classmethod
    def locales(cls, lang):
        """Get locale by lang"""
        # Merge sub-tab locales first; this tab's entries override on conflict.
        locales = OrderedDict()
        for sub_ui in cls.sub_ui:
            _locales = sub_ui.locales(lang)
            locales.update(_locales)
        for key, value in cls.locale_dict.items():
            locales[key] = {k: v[lang] for k, v in value.items()}
        return locales

    @classmethod
    def elements(cls):
        """Get all elements"""
        elements = OrderedDict()
        elements.update(cls.element_dict)
        for sub_ui in cls.sub_ui:
            _elements = sub_ui.elements()
            elements.update(_elements)
        return elements

    @classmethod
    def valid_elements(cls):
        """Value-bearing components only (inputs/outputs of gradio events)."""
        valid_elements = OrderedDict()
        elements = cls.elements()
        for key, value in elements.items():
            if isinstance(value, (Textbox, Dropdown, Slider, Checkbox)) and key != 'train_record':
                valid_elements[key] = value
        return valid_elements

    @classmethod
    def element_keys(cls):
        return list(cls.elements().keys())

    @classmethod
    def valid_element_keys(cls):
        return [
            key for key, value in cls.elements().items()
            if isinstance(value, (Textbox, Dropdown, Slider, Checkbox)) and key != 'train_record'
        ]

    @classmethod
    def element(cls, elem_id):
        """Get element by elem_id"""
        elements = cls.elements()
        return elements[elem_id]

    @classmethod
    def argument(cls, elem_id):
        """Get argument by elem_id"""
        return cls.arguments.get(elem_id)

    @classmethod
    def set_lang(cls, lang):
        """Set the display language for this tab and all its sub-tabs."""
        cls.lang = lang
        for sub_ui in cls.sub_ui:
            sub_ui.lang = lang

    @staticmethod
    def get_choices_from_dataclass(dataclass):
        """Collect per-field choice lists from metadata and Literal annotations."""
        choice_dict = {}
        for f in fields(dataclass):
            default_value = f.default
            if 'MISSING_TYPE' in str(default_value):
                default_value = None
            if 'choices' in f.metadata:
                choice_dict[f.name] = list(f.metadata['choices'])
            if 'Literal' in str(f.type) and typing.get_args(f.type):
                choice_dict[f.name] = list(typing.get_args(f.type))
            # Make sure the field's default is always offered as a choice.
            if f.name in choice_dict and default_value not in choice_dict[f.name]:
                choice_dict[f.name].insert(0, default_value)
        return choice_dict

    @staticmethod
    def get_default_value_from_dataclass(dataclass):
        """Collect per-field defaults; lists are joined to space-separated strings."""
        default_dict = {}
        for f in fields(dataclass):
            if f.default.__class__ is dataclasses._MISSING_TYPE:
                default_dict[f.name] = f.default_factory()
            else:
                default_dict[f.name] = f.default
            if isinstance(default_dict[f.name], list):
                try:
                    default_dict[f.name] = ' '.join(default_dict[f.name])
                except TypeError:
                    # Non-string list elements cannot be joined; fall back to None.
                    default_dict[f.name] = None
            if not default_dict[f.name]:
                default_dict[f.name] = None
        return default_dict

    @staticmethod
    def get_argument_names(dataclass):
        """Map each dataclass field to its '--name' CLI form."""
        arguments = {}
        for f in fields(dataclass):
            arguments[f.name] = f'--{f.name}'
        return arguments

    @classmethod
    def update_input_model(cls, model, allow_keys=None, has_record=True, arg_cls=BaseArguments, is_ref_model=False):
        """Refresh the tab's components after the model selection changed.

        Returns one gr.update per valid element (plus one for the record
        dropdown when ``has_record``); a single update is unwrapped from the list.
        """
        keys = cls.valid_element_keys()
        if allow_keys:
            keys = [key for key in keys if key in allow_keys]

        # No model chosen: emit no-op updates.
        if not model:
            ret = [gr.update()] * (len(keys) + int(has_record))
            if len(ret) == 1:
                return ret[0]
            else:
                return ret

        model_meta = get_matched_model_meta(model)
        local_args_path = os.path.join(model, 'args.json')
        # Neither a known model id nor a local checkpoint dir: warn and no-op.
        if model_meta is None and not os.path.exists(local_args_path):
            gr.Info(cls._locale['local_dir_alert']['value'][cls.lang])
            ret = [gr.update()] * (len(keys) + int(has_record))
            if len(ret) == 1:
                return ret[0]
            else:
                return ret

        if os.path.exists(local_args_path):
            # Local checkpoint dir: reload its recorded arguments.
            try:
                if hasattr(arg_cls, 'resume_from_checkpoint'):
                    try:
                        args = arg_cls(resume_from_checkpoint=model, load_data_args=True)
                    except Exception as e:
                        if 'using `--model`' in str(e):  # TODO a dirty fix
                            args = arg_cls(model=model, load_data_args=True)
                        else:
                            raise e
                else:
                    args = arg_cls(ckpt_dir=model, load_data_args=True)
            except ValueError:
                return [gr.update()] * (len(keys) + int(has_record))
            values = []
            for key in keys:
                arg_value = getattr(args, key, None)
                if arg_value and key != 'model':
                    if key in ('torch_dtype', 'bnb_4bit_compute_dtype'):
                        # 'torch.float16' -> 'float16' for display.
                        arg_value = str(arg_value).split('.')[1]
                    if isinstance(arg_value, list) and key != 'dataset':
                        try:
                            arg_value = ' '.join(arg_value)
                        except Exception:
                            arg_value = None
                    values.append(gr.update(value=arg_value))
                else:
                    values.append(gr.update())
            ret = [gr.update(choices=[])] * int(has_record) + values
            if len(ret) == 1:
                return ret[0]
            else:
                return ret
        else:
            # Known model id: fill template/model_type/system from its metadata.
            values = []
            for key in keys:
                if key not in ('template', 'model_type', 'ref_model_type', 'system'):
                    values.append(gr.update())
                elif key in ('template', 'model_type', 'ref_model_type'):
                    if key == 'ref_model_type':
                        if is_ref_model:
                            values.append(gr.update(value=getattr(model_meta, 'model_type')))
                        else:
                            values.append(gr.update())
                    else:
                        values.append(gr.update(value=getattr(model_meta, key)))
                else:
                    values.append(gr.update(value=TEMPLATE_MAPPING[model_meta.template].default_system))

            if has_record:
                return [gr.update(choices=cls.list_cache(model))] + values
            else:
                if len(values) == 1:
                    return values[0]
                return values

    @classmethod
    def update_all_settings(cls, model, train_record, base_tab):
        """Restore every valid element from the cached run selected in `train_record`."""
        if not train_record:
            return [gr.update()] * len(cls.elements())
        cache = cls.load_cache(model, train_record)
        updates = []
        for key, value in base_tab.valid_elements().items():
            if key in cache:
                updates.append(gr.update(value=cache[key]))
            else:
                updates.append(gr.update())
        return updates
ms-swift/swift/ui/llm_eval/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
ms-swift/swift/ui/llm_eval/eval.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Type
3
+
4
+ import gradio as gr
5
+
6
+ from swift.ui.base import BaseUI
7
+ from swift.utils import get_logger
8
+
9
+ logger = get_logger()
10
+
11
+
12
class Eval(BaseUI):
    """Evaluation-settings sub-tab of the LLM eval page."""

    group = 'llm_eval'

    locale_dict = {
        'eval_backend': {
            'label': {
                'zh': '评测后端',
                'en': 'Eval backend'
            },
            'info': {
                'zh': '选择评测后端',
                'en': 'Select eval backend'
            }
        },
        'eval_dataset': {
            'label': {
                'zh': '评测数据集',
                'en': 'Evaluation dataset'
            },
            'info': {
                'zh': '选择评测数据集,支持多选 (先选择评测后端)',
                'en': 'Select eval dataset, multiple datasets supported (select eval backend first)'
            }
        },
        'eval_limit': {
            'label': {
                'zh': '评测数据个数',
                'en': 'Eval numbers for each dataset'
            },
            'info': {
                'zh': '每个评测集的取样数',
                'en': 'Number of rows sampled from each dataset'
            }
        },
        'eval_output_dir': {
            'label': {
                'zh': '评测输出目录',
                'en': 'Eval output dir'
            },
            'info': {
                'zh': '评测结果的输出目录',
                'en': 'The dir to save the eval results'
            }
        },
        'custom_eval_config': {
            'label': {
                'zh': '自定义数据集评测配置',
                'en': 'Custom eval config'
            },
            'info': {
                'zh': '可以使用该配置评测自己的数据集,详见github文档的评测部分',
                'en': 'Use this config to eval your own datasets, check the docs in github for details'
            }
        },
        'eval_url': {
            'label': {
                'zh': '评测链接',
                'en': 'The eval url'
            },
            'info': {
                'zh':
                'OpenAI样式的评测链接(如:http://localhost:8080/v1/chat/completions),用于评测接口(模型类型输入为实际模型类型)',
                'en':
                'The OpenAI style link(like: http://localhost:8080/v1/chat/completions) for '
                'evaluation(Input actual model type into model_type)'
            }
        },
        'api_key': {
            'label': {
                'zh': '接口token',
                'en': 'The url token'
            },
            'info': {
                'zh': 'eval_url的token',
                'en': 'The token used with eval_url'
            }
        },
        'infer_backend': {
            'label': {
                'zh': '推理框架',
                'en': 'Infer backend'
            },
        }
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        # Eval dataset choices depend on the selected backend; fall back to empty
        # mappings if the eval arguments cannot be imported.
        try:
            from swift.llm.argument.eval_args import EvalArguments
            eval_dataset_dict = EvalArguments.list_eval_dataset()
            default_backend = EvalArguments.eval_backend
        except Exception as e:
            logger.warn(e)
            eval_dataset_dict = {}
            default_backend = None

        with gr.Row():
            gr.Dropdown(elem_id='eval_backend', choices=list(eval_dataset_dict.keys()), value=default_backend, scale=20)
            gr.Dropdown(
                elem_id='eval_dataset',
                is_list=True,
                choices=eval_dataset_dict.get(default_backend, []),
                multiselect=True,
                allow_custom_value=True,
                scale=20)
            gr.Textbox(elem_id='eval_limit', scale=20)
            gr.Dropdown(elem_id='infer_backend', scale=20)
        with gr.Row():
            gr.Textbox(elem_id='custom_eval_config', scale=20)
            gr.Textbox(elem_id='eval_output_dir', scale=20)
            gr.Textbox(elem_id='eval_url', scale=20)
            gr.Textbox(elem_id='api_key', scale=20)

        def update_eval_dataset(backend):
            # Swap the dataset choices whenever the backend changes.
            return gr.update(choices=eval_dataset_dict[backend])

        cls.element('eval_backend').change(update_eval_dataset, [cls.element('eval_backend')],
                                           [cls.element('eval_dataset')])
ms-swift/swift/ui/llm_eval/model.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from functools import partial
3
+ from typing import Type
4
+
5
+ import gradio as gr
6
+
7
+ from swift.llm import TEMPLATE_MAPPING, EvalArguments, ModelType
8
+ from swift.llm.model.register import get_all_models
9
+ from swift.ui.base import BaseUI
10
+
11
+
12
class Model(BaseUI):
    """Model-selection sub-tab of the LLM eval page."""

    group = 'llm_eval'

    locale_dict = {
        'checkpoint': {
            'value': {
                'zh': '训练后的模型',
                'en': 'Trained model'
            }
        },
        'model_type': {
            'label': {
                'zh': '选择模型类型',
                'en': 'Select Model Type'
            },
            'info': {
                'zh': 'SWIFT已支持的模型类型',
                'en': 'Base model type supported by SWIFT'
            }
        },
        'model': {
            'label': {
                'zh': '模型id或路径',
                'en': 'Model id or path'
            },
            'info': {
                'zh': '实际的模型id,如果是训练后的模型请填入checkpoint-xxx的目录',
                'en': 'The actual model id or path, if is a trained model, please fill in the checkpoint-xxx dir'
            }
        },
        'reset': {
            'value': {
                'zh': '恢复初始值',
                'en': 'Reset to default'
            },
        },
        'template': {
            'label': {
                'zh': '模型Prompt模板类型',
                'en': 'Prompt template type'
            },
            'info': {
                'zh': '选择匹配模型的Prompt模板',
                'en': 'Choose the template type of the model'
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        with gr.Row():
            gr.Dropdown(
                elem_id='model',
                scale=20,
                choices=get_all_models(),
                value='Qwen/Qwen2.5-7B-Instruct',
                allow_custom_value=True)
            gr.Dropdown(elem_id='model_type', choices=ModelType.get_model_name_list(), scale=20)
            gr.Dropdown(elem_id='template', choices=list(TEMPLATE_MAPPING.keys()), scale=20)

    @classmethod
    def after_build_ui(cls, base_tab: Type['BaseUI']):
        # Re-populate the tab's elements whenever the model selection changes.
        cls.element('model').change(
            partial(cls.update_input_model, arg_cls=EvalArguments, has_record=False),
            inputs=[cls.element('model')],
            outputs=list(cls.valid_elements().values()))
ms-swift/swift/ui/llm_export/llm_export.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ import re
4
+ import sys
5
+ import time
6
+ from datetime import datetime
7
+ from functools import partial
8
+ from typing import Type
9
+
10
+ import gradio as gr
11
+ import json
12
+ import torch
13
+ from json import JSONDecodeError
14
+ from transformers.utils import is_torch_cuda_available, is_torch_npu_available
15
+
16
+ from swift.llm import ExportArguments
17
+ from swift.ui.base import BaseUI
18
+ from swift.ui.llm_export.export import Export
19
+ from swift.ui.llm_export.model import Model
20
+ from swift.ui.llm_export.runtime import ExportRuntime
21
+ from swift.utils import get_device_count
22
+
23
+
24
+ class LLMExport(BaseUI):
25
+ group = 'llm_export'
26
+
27
+ sub_ui = [Model, Export, ExportRuntime]
28
+
29
+ locale_dict = {
30
+ 'llm_export': {
31
+ 'label': {
32
+ 'zh': 'LLM导出',
33
+ 'en': 'LLM export',
34
+ }
35
+ },
36
+ 'more_params': {
37
+ 'label': {
38
+ 'zh': '更多参数',
39
+ 'en': 'More params'
40
+ },
41
+ 'info': {
42
+ 'zh': '以json格式或--xxx xxx命令行格式填入',
43
+ 'en': 'Fill in with json format or --xxx xxx cmd format'
44
+ }
45
+ },
46
+ 'export': {
47
+ 'value': {
48
+ 'zh': '开始导出',
49
+ 'en': 'Begin Export'
50
+ },
51
+ },
52
+ 'gpu_id': {
53
+ 'label': {
54
+ 'zh': '选择可用GPU',
55
+ 'en': 'Choose GPU'
56
+ },
57
+ 'info': {
58
+ 'zh': '选择使用的GPU号,如CUDA不可用只能选择CPU',
59
+ 'en': 'Select GPU to export'
60
+ }
61
+ },
62
+ }
63
+
64
+ choice_dict = BaseUI.get_choices_from_dataclass(ExportArguments)
65
+ default_dict = BaseUI.get_default_value_from_dataclass(ExportArguments)
66
+ arguments = BaseUI.get_argument_names(ExportArguments)
67
+
68
+ @classmethod
69
+ def do_build_ui(cls, base_tab: Type['BaseUI']):
70
+ with gr.TabItem(elem_id='llm_export', label=''):
71
+ default_device = 'cpu'
72
+ device_count = get_device_count()
73
+ if device_count > 0:
74
+ default_device = '0'
75
+ with gr.Blocks():
76
+ Model.build_ui(base_tab)
77
+ Export.build_ui(base_tab)
78
+ ExportRuntime.build_ui(base_tab)
79
+ with gr.Row():
80
+ gr.Textbox(elem_id='more_params', lines=4, scale=20)
81
+ gr.Button(elem_id='export', scale=2, variant='primary')
82
+ gr.Dropdown(
83
+ elem_id='gpu_id',
84
+ multiselect=True,
85
+ choices=[str(i) for i in range(device_count)] + ['cpu'],
86
+ value=default_device,
87
+ scale=8)
88
+
89
+ cls.element('export').click(
90
+ cls.export_model, list(base_tab.valid_elements().values()),
91
+ [cls.element('runtime_tab'), cls.element('running_tasks')])
92
+
93
+ base_tab.element('running_tasks').change(
94
+ partial(ExportRuntime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')],
95
+ list(base_tab.valid_elements().values()) + [cls.element('log')])
96
+ ExportRuntime.element('kill_task').click(
97
+ ExportRuntime.kill_task,
98
+ [ExportRuntime.element('running_tasks')],
99
+ [ExportRuntime.element('running_tasks')] + [ExportRuntime.element('log')],
100
+ )
101
+
102
+ @classmethod
103
+ def export(cls, *args):
104
+ export_args = cls.get_default_value_from_dataclass(ExportArguments)
105
+ kwargs = {}
106
+ kwargs_is_list = {}
107
+ other_kwargs = {}
108
+ more_params = {}
109
+ more_params_cmd = ''
110
+ keys = cls.valid_element_keys()
111
+ for key, value in zip(keys, args):
112
+ compare_value = export_args.get(key)
113
+ compare_value_arg = str(compare_value) if not isinstance(compare_value, (list, dict)) else compare_value
114
+ compare_value_ui = str(value) if not isinstance(value, (list, dict)) else value
115
+ if key in export_args and compare_value_ui != compare_value_arg and value:
116
+ if isinstance(value, str) and re.fullmatch(cls.int_regex, value):
117
+ value = int(value)
118
+ elif isinstance(value, str) and re.fullmatch(cls.float_regex, value):
119
+ value = float(value)
120
+ elif isinstance(value, str) and re.fullmatch(cls.bool_regex, value):
121
+ value = True if value.lower() == 'true' else False
122
+ kwargs[key] = value if not isinstance(value, list) else ' '.join(value)
123
+ kwargs_is_list[key] = isinstance(value, list) or getattr(cls.element(key), 'is_list', False)
124
+ else:
125
+ other_kwargs[key] = value
126
+ if key == 'more_params' and value:
127
+ try:
128
+ more_params = json.loads(value)
129
+ except (JSONDecodeError or TypeError):
130
+ more_params_cmd = value
131
+
132
+ kwargs.update(more_params)
133
+ model = kwargs.get('model')
134
+ if os.path.exists(model) and os.path.exists(os.path.join(model, 'args.json')):
135
+ kwargs['ckpt_dir'] = kwargs.pop('model')
136
+ export_args = ExportArguments(
137
+ **{
138
+ key: value.split(' ') if key in kwargs_is_list and kwargs_is_list[key] else value
139
+ for key, value in kwargs.items()
140
+ })
141
+ params = ''
142
+ sep = f'{cls.quote} {cls.quote}'
143
+ for e in kwargs:
144
+ if isinstance(kwargs[e], list):
145
+ params += f'--{e} {cls.quote}{sep.join(kwargs[e])}{cls.quote} '
146
+ elif e in kwargs_is_list and kwargs_is_list[e]:
147
+ all_args = [arg for arg in kwargs[e].split(' ') if arg.strip()]
148
+ params += f'--{e} {cls.quote}{sep.join(all_args)}{cls.quote} '
149
+ else:
150
+ params += f'--{e} {cls.quote}{kwargs[e]}{cls.quote} '
151
+ params += more_params_cmd + ' '
152
+ devices = other_kwargs['gpu_id']
153
+ devices = [d for d in devices if d]
154
+ assert (len(devices) == 1 or 'cpu' not in devices)
155
+ gpus = ','.join(devices)
156
+ cuda_param = ''
157
+ if gpus != 'cpu':
158
+ if is_torch_npu_available():
159
+ cuda_param = f'ASCEND_RT_VISIBLE_DEVICES={gpus}'
160
+ elif is_torch_cuda_available():
161
+ cuda_param = f'CUDA_VISIBLE_DEVICES={gpus}'
162
+ else:
163
+ cuda_param = ''
164
+ now = datetime.now()
165
+ time_str = f'{now.year}{now.month}{now.day}{now.hour}{now.minute}{now.second}'
166
+ file_path = f'output/{export_args.model_type}-{time_str}'
167
+ if not os.path.exists(file_path):
168
+ os.makedirs(file_path, exist_ok=True)
169
+ log_file = os.path.join(os.getcwd(), f'{file_path}/run_export.log')
170
+ export_args.log_file = log_file
171
+ params += f'--log_file "{log_file}" '
172
+ params += '--ignore_args_error true '
173
+ additional_param = ''
174
+ if export_args.quant_method == 'gptq':
175
+ additional_param = 'OMP_NUM_THREADS=14'
176
+ if sys.platform == 'win32':
177
+ if cuda_param:
178
+ cuda_param = f'set {cuda_param} && '
179
+ if additional_param:
180
+ additional_param = f'set {additional_param} && '
181
+ run_command = f'{cuda_param}{additional_param}start /b swift export {params} > {log_file} 2>&1'
182
+ else:
183
+ run_command = f'{cuda_param} {additional_param} nohup swift export {params} > {log_file} 2>&1 &'
184
+ return run_command, export_args, log_file
185
+
186
+ @classmethod
187
+ def export_model(cls, *args):
188
+ run_command, export_args, log_file = cls.export(*args)
189
+ os.system(run_command)
190
+ time.sleep(2)
191
+ return gr.update(open=True), ExportRuntime.refresh_tasks(log_file)
ms-swift/swift/ui/llm_export/model.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from functools import partial
3
+ from typing import Type
4
+
5
+ import gradio as gr
6
+
7
+ from swift.llm import TEMPLATE_MAPPING, ExportArguments, ModelType
8
+ from swift.llm.model.register import get_all_models
9
+ from swift.ui.base import BaseUI
10
+
11
+
12
+ class Model(BaseUI):
13
+
14
+ group = 'llm_export'
15
+
16
+ locale_dict = {
17
+ 'checkpoint': {
18
+ 'value': {
19
+ 'zh': '训练后的模型',
20
+ 'en': 'Trained model'
21
+ }
22
+ },
23
+ 'model_type': {
24
+ 'label': {
25
+ 'zh': '选择模型类型',
26
+ 'en': 'Select Model Type'
27
+ },
28
+ 'info': {
29
+ 'zh': 'SWIFT已支持的模型类型',
30
+ 'en': 'Base model type supported by SWIFT'
31
+ }
32
+ },
33
+ 'model': {
34
+ 'label': {
35
+ 'zh': '模型id或路径',
36
+ 'en': 'Model id or path'
37
+ },
38
+ 'info': {
39
+ 'zh': '实际的模型id,如果是训练后的模型请填入checkpoint-xxx的目录',
40
+ 'en': 'The actual model id or path, if is a trained model, please fill in the checkpoint-xxx dir'
41
+ }
42
+ },
43
+ 'reset': {
44
+ 'value': {
45
+ 'zh': '恢复初始值',
46
+ 'en': 'Reset to default'
47
+ },
48
+ },
49
+ 'template': {
50
+ 'label': {
51
+ 'zh': '模型Prompt模板类型',
52
+ 'en': 'Prompt template type'
53
+ },
54
+ 'info': {
55
+ 'zh': '选择匹配模型的Prompt模板',
56
+ 'en': 'Choose the template type of the model'
57
+ }
58
+ },
59
+ }
60
+
61
+ ignored_models = ['int1', 'int2', 'int4', 'int8', 'awq', 'gptq', 'bnb', 'eetq', 'aqlm', 'hqq']
62
+
63
+ @classmethod
64
+ def do_build_ui(cls, base_tab: Type['BaseUI']):
65
+ with gr.Row():
66
+ all_models = [
67
+ model for model in get_all_models() if not any([ignored in model for ignored in cls.ignored_models])
68
+ ]
69
+ gr.Dropdown(
70
+ elem_id='model',
71
+ scale=20,
72
+ choices=all_models,
73
+ value='Qwen/Qwen2.5-7B-Instruct',
74
+ allow_custom_value=True)
75
+ gr.Dropdown(elem_id='model_type', choices=ModelType.get_model_name_list(), scale=20)
76
+ gr.Dropdown(elem_id='template', choices=list(TEMPLATE_MAPPING.keys()), scale=20)
77
+
78
+ @classmethod
79
+ def after_build_ui(cls, base_tab: Type['BaseUI']):
80
+ cls.element('model').change(
81
+ partial(cls.update_input_model, arg_cls=ExportArguments, has_record=False),
82
+ inputs=[cls.element('model')],
83
+ outputs=list(cls.valid_elements().values()))
ms-swift/swift/ui/llm_export/runtime.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from swift.ui.llm_infer.runtime import Runtime
3
+ from swift.utils import get_logger
4
+
5
+ logger = get_logger()
6
+
7
+
8
+ class ExportRuntime(Runtime):
9
+
10
+ group = 'llm_export'
11
+
12
+ cmd = 'export'
13
+
14
+ locale_dict = {
15
+ 'runtime_tab': {
16
+ 'label': {
17
+ 'zh': '运行时',
18
+ 'en': 'Runtime'
19
+ },
20
+ },
21
+ 'running_cmd': {
22
+ 'label': {
23
+ 'zh': '运行命令',
24
+ 'en': 'Command line'
25
+ },
26
+ 'info': {
27
+ 'zh': '执行的实际命令',
28
+ 'en': 'The actual command'
29
+ }
30
+ },
31
+ 'show_log': {
32
+ 'value': {
33
+ 'zh': '展示导出状态',
34
+ 'en': 'Show export status'
35
+ },
36
+ },
37
+ 'stop_show_log': {
38
+ 'value': {
39
+ 'zh': '停止展示',
40
+ 'en': 'Stop showing running status'
41
+ },
42
+ },
43
+ 'log': {
44
+ 'label': {
45
+ 'zh': '日志输出',
46
+ 'en': 'Logging content'
47
+ },
48
+ 'info': {
49
+ 'zh': '如果日志无更新请再次点击"展示日志内容"',
50
+ 'en': 'Please press "Show log" if the log content is not updating'
51
+ }
52
+ },
53
+ 'running_tasks': {
54
+ 'label': {
55
+ 'zh': '运行中导出任务',
56
+ 'en': 'Running export task'
57
+ },
58
+ 'info': {
59
+ 'zh': '所有的swift export命令启动的任务',
60
+ 'en': 'All tasks started by swift export'
61
+ }
62
+ },
63
+ 'refresh_tasks': {
64
+ 'value': {
65
+ 'zh': '找回导出任务',
66
+ 'en': 'Find export'
67
+ },
68
+ },
69
+ 'kill_task': {
70
+ 'value': {
71
+ 'zh': '杀死导出任务',
72
+ 'en': 'Kill export'
73
+ },
74
+ },
75
+ }
ms-swift/swift/ui/llm_infer/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
ms-swift/swift/ui/llm_infer/generate.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Type
3
+
4
+ import gradio as gr
5
+
6
+ from swift.ui.base import BaseUI
7
+
8
+
9
+ class Generate(BaseUI):
10
+
11
+ group = 'llm_infer'
12
+
13
+ locale_dict = {
14
+ 'max_new_tokens': {
15
+ 'label': {
16
+ 'zh': '生成序列最大长度',
17
+ 'en': 'Max new tokens'
18
+ },
19
+ },
20
+ 'temperature': {
21
+ 'label': {
22
+ 'zh': 'temperature',
23
+ 'en': 'temperature'
24
+ },
25
+ },
26
+ 'top_k': {
27
+ 'label': {
28
+ 'zh': 'top_k',
29
+ 'en': 'top_k'
30
+ },
31
+ },
32
+ 'top_p': {
33
+ 'label': {
34
+ 'zh': 'top_p',
35
+ 'en': 'top_p'
36
+ },
37
+ },
38
+ 'repetition_penalty': {
39
+ 'label': {
40
+ 'zh': 'repetition_penalty',
41
+ 'en': 'repetition_penalty'
42
+ },
43
+ },
44
+ 'system': {
45
+ 'label': {
46
+ 'zh': 'system字段',
47
+ 'en': 'system'
48
+ },
49
+ 'info': {
50
+ 'zh': 'system字段支持在加载模型后修改',
51
+ 'en': 'system can be modified after the model weights loaded'
52
+ }
53
+ },
54
+ }
55
+
56
+ @classmethod
57
+ def do_build_ui(cls, base_tab: Type['BaseUI']):
58
+ with gr.Row():
59
+ gr.Textbox(elem_id='max_new_tokens', lines=1, value='2048')
60
+ gr.Slider(elem_id='temperature', minimum=0.0, maximum=10, step=0.1, value=0.3)
61
+ gr.Slider(elem_id='top_k', minimum=1, maximum=100, step=5, value=20)
62
+ gr.Slider(elem_id='top_p', minimum=0.0, maximum=1.0, step=0.05, value=0.7)
63
+ gr.Slider(elem_id='repetition_penalty', minimum=0.0, maximum=10, step=0.05, value=1.05)
64
+ with gr.Row():
65
+ gr.Textbox(elem_id='system', lines=4, scale=20)
ms-swift/swift/ui/llm_infer/llm_infer.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ import re
4
+ import signal
5
+ import sys
6
+ import time
7
+ from copy import deepcopy
8
+ from datetime import datetime
9
+ from functools import partial
10
+ from typing import List, Type
11
+
12
+ import gradio as gr
13
+ import json
14
+ import torch
15
+ from json import JSONDecodeError
16
+ from transformers.utils import is_torch_cuda_available, is_torch_npu_available
17
+
18
+ from swift.llm import DeployArguments, InferArguments, InferClient, InferRequest, RequestConfig
19
+ from swift.ui.base import BaseUI
20
+ from swift.ui.llm_infer.model import Model
21
+ from swift.ui.llm_infer.runtime import Runtime
22
+ from swift.utils import get_device_count, get_logger
23
+
24
+ logger = get_logger()
25
+
26
+
27
+ class LLMInfer(BaseUI):
28
+
29
+ group = 'llm_infer'
30
+
31
+ is_multimodal = True
32
+
33
+ sub_ui = [Model, Runtime]
34
+
35
+ locale_dict = {
36
+ 'generate_alert': {
37
+ 'value': {
38
+ 'zh': '请先部署模型',
39
+ 'en': 'Please deploy model first',
40
+ }
41
+ },
42
+ 'port': {
43
+ 'label': {
44
+ 'zh': '端口',
45
+ 'en': 'port'
46
+ },
47
+ },
48
+ 'llm_infer': {
49
+ 'label': {
50
+ 'zh': 'LLM推理',
51
+ 'en': 'LLM Inference',
52
+ }
53
+ },
54
+ 'load_alert': {
55
+ 'value': {
56
+ 'zh': '部署中,请点击"展示部署状态"查看',
57
+ 'en': 'Start to deploy model, '
58
+ 'please Click "Show running '
59
+ 'status" to view details',
60
+ }
61
+ },
62
+ 'loaded_alert': {
63
+ 'value': {
64
+ 'zh': '模型加载完成',
65
+ 'en': 'Model loaded'
66
+ }
67
+ },
68
+ 'port_alert': {
69
+ 'value': {
70
+ 'zh': '该端口已被占用',
71
+ 'en': 'The port has been occupied'
72
+ }
73
+ },
74
+ 'chatbot': {
75
+ 'value': {
76
+ 'zh': '对话框',
77
+ 'en': 'Chat bot'
78
+ },
79
+ },
80
+ 'infer_model_type': {
81
+ 'label': {
82
+ 'zh': 'Lora模块',
83
+ 'en': 'Lora module'
84
+ },
85
+ 'info': {
86
+ 'zh': '发送给server端哪个LoRA,默认为`default`',
87
+ 'en': 'Which LoRA to use on server, default value is `default`'
88
+ }
89
+ },
90
+ 'prompt': {
91
+ 'label': {
92
+ 'zh': '请输入:',
93
+ 'en': 'Input:'
94
+ },
95
+ },
96
+ 'clear_history': {
97
+ 'value': {
98
+ 'zh': '清除对话信息',
99
+ 'en': 'Clear history'
100
+ },
101
+ },
102
+ 'submit': {
103
+ 'value': {
104
+ 'zh': '🚀 发送',
105
+ 'en': '🚀 Send'
106
+ },
107
+ },
108
+ 'gpu_id': {
109
+ 'label': {
110
+ 'zh': '选择可用GPU',
111
+ 'en': 'Choose GPU'
112
+ },
113
+ 'info': {
114
+ 'zh': '选择训练使用的GPU号,如CUDA不可用只能选择CPU',
115
+ 'en': 'Select GPU to train'
116
+ }
117
+ },
118
+ }
119
+
120
+ choice_dict = BaseUI.get_choices_from_dataclass(InferArguments)
121
+ default_dict = BaseUI.get_default_value_from_dataclass(InferArguments)
122
+ arguments = BaseUI.get_argument_names(InferArguments)
123
+
124
+ @classmethod
125
+ def do_build_ui(cls, base_tab: Type['BaseUI']):
126
+ with gr.TabItem(elem_id='llm_infer', label=''):
127
+ default_device = 'cpu'
128
+ device_count = get_device_count()
129
+ if device_count > 0:
130
+ default_device = '0'
131
+ with gr.Blocks():
132
+ infer_request = gr.State(None)
133
+ Model.build_ui(base_tab)
134
+ Runtime.build_ui(base_tab)
135
+ with gr.Row():
136
+ gr.Dropdown(
137
+ elem_id='gpu_id',
138
+ multiselect=True,
139
+ choices=[str(i) for i in range(device_count)] + ['cpu'],
140
+ value=default_device,
141
+ scale=8)
142
+ infer_model_type = gr.Textbox(elem_id='infer_model_type', scale=4)
143
+ gr.Textbox(elem_id='port', lines=1, value='8000', scale=4)
144
+ chatbot = gr.Chatbot(elem_id='chatbot', elem_classes='control-height')
145
+ with gr.Row():
146
+ prompt = gr.Textbox(elem_id='prompt', lines=1, interactive=True)
147
+ with gr.Tabs(visible=cls.is_multimodal):
148
+ with gr.TabItem(label='Image'):
149
+ image = gr.Image(type='filepath')
150
+ with gr.TabItem(label='Video'):
151
+ video = gr.Video()
152
+ with gr.TabItem(label='Audio'):
153
+ audio = gr.Audio(type='filepath')
154
+
155
+ with gr.Row():
156
+ clear_history = gr.Button(elem_id='clear_history')
157
+ submit = gr.Button(elem_id='submit')
158
+
159
+ cls.element('load_checkpoint').click(
160
+ cls.deploy_model, list(base_tab.valid_elements().values()),
161
+ [cls.element('runtime_tab'), cls.element('running_tasks')])
162
+ submit.click(
163
+ cls.send_message,
164
+ inputs=[
165
+ cls.element('running_tasks'),
166
+ cls.element('template'), prompt, image, video, audio, infer_request, infer_model_type,
167
+ cls.element('system'),
168
+ cls.element('max_new_tokens'),
169
+ cls.element('temperature'),
170
+ cls.element('top_k'),
171
+ cls.element('top_p'),
172
+ cls.element('repetition_penalty')
173
+ ],
174
+ outputs=[prompt, chatbot, image, video, audio, infer_request],
175
+ queue=True)
176
+
177
+ clear_history.click(
178
+ fn=cls.clear_session, inputs=[], outputs=[prompt, chatbot, image, video, audio, infer_request])
179
+
180
+ base_tab.element('running_tasks').change(
181
+ partial(Runtime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')],
182
+ list(cls.valid_elements().values()) + [cls.element('log')])
183
+ Runtime.element('kill_task').click(
184
+ Runtime.kill_task,
185
+ [Runtime.element('running_tasks')],
186
+ [Runtime.element('running_tasks')] + [Runtime.element('log')],
187
+ )
188
+
189
+ @classmethod
190
+ def deploy(cls, *args):
191
+ deploy_args = cls.get_default_value_from_dataclass(DeployArguments)
192
+ kwargs = {}
193
+ kwargs_is_list = {}
194
+ other_kwargs = {}
195
+ more_params = {}
196
+ more_params_cmd = ''
197
+ keys = cls.valid_element_keys()
198
+ for key, value in zip(keys, args):
199
+ compare_value = deploy_args.get(key)
200
+ compare_value_arg = str(compare_value) if not isinstance(compare_value, (list, dict)) else compare_value
201
+ compare_value_ui = str(value) if not isinstance(value, (list, dict)) else value
202
+ if key in deploy_args and compare_value_ui != compare_value_arg and value:
203
+ if isinstance(value, str) and re.fullmatch(cls.int_regex, value):
204
+ value = int(value)
205
+ elif isinstance(value, str) and re.fullmatch(cls.float_regex, value):
206
+ value = float(value)
207
+ elif isinstance(value, str) and re.fullmatch(cls.bool_regex, value):
208
+ value = True if value.lower() == 'true' else False
209
+ kwargs[key] = value if not isinstance(value, list) else ' '.join(value)
210
+ kwargs_is_list[key] = isinstance(value, list) or getattr(cls.element(key), 'is_list', False)
211
+ else:
212
+ other_kwargs[key] = value
213
+ if key == 'more_params' and value:
214
+ try:
215
+ more_params = json.loads(value)
216
+ except (JSONDecodeError or TypeError):
217
+ more_params_cmd = value
218
+
219
+ kwargs.update(more_params)
220
+ model = kwargs.get('model')
221
+ if os.path.exists(model) and os.path.exists(os.path.join(model, 'args.json')):
222
+ kwargs['ckpt_dir'] = kwargs.pop('model')
223
+ with open(os.path.join(kwargs['ckpt_dir'], 'args.json'), 'r', encoding='utf-8') as f:
224
+ _json = json.load(f)
225
+ kwargs['model_type'] = _json['model_type']
226
+ kwargs['train_type'] = _json['train_type']
227
+ deploy_args = DeployArguments(
228
+ **{
229
+ key: value.split(' ') if key in kwargs_is_list and kwargs_is_list[key] else value
230
+ for key, value in kwargs.items()
231
+ })
232
+ if deploy_args.port in Runtime.get_all_ports():
233
+ raise gr.Error(cls.locale('port_alert', cls.lang)['value'])
234
+ params = ''
235
+ sep = f'{cls.quote} {cls.quote}'
236
+ for e in kwargs:
237
+ if isinstance(kwargs[e], list):
238
+ params += f'--{e} {cls.quote}{sep.join(kwargs[e])}{cls.quote} '
239
+ elif e in kwargs_is_list and kwargs_is_list[e]:
240
+ all_args = [arg for arg in kwargs[e].split(' ') if arg.strip()]
241
+ params += f'--{e} {cls.quote}{sep.join(all_args)}{cls.quote} '
242
+ else:
243
+ params += f'--{e} {cls.quote}{kwargs[e]}{cls.quote} '
244
+ if 'port' not in kwargs:
245
+ params += f'--port "{deploy_args.port}" '
246
+ params += more_params_cmd + ' '
247
+ devices = other_kwargs['gpu_id']
248
+ devices = [d for d in devices if d]
249
+ assert (len(devices) == 1 or 'cpu' not in devices)
250
+ gpus = ','.join(devices)
251
+ cuda_param = ''
252
+ if gpus != 'cpu':
253
+ if is_torch_npu_available():
254
+ cuda_param = f'ASCEND_RT_VISIBLE_DEVICES={gpus}'
255
+ elif is_torch_cuda_available():
256
+ cuda_param = f'CUDA_VISIBLE_DEVICES={gpus}'
257
+ else:
258
+ cuda_param = ''
259
+ now = datetime.now()
260
+ time_str = f'{now.year}{now.month}{now.day}{now.hour}{now.minute}{now.second}'
261
+ file_path = f'output/{deploy_args.model_type}-{time_str}'
262
+ if not os.path.exists(file_path):
263
+ os.makedirs(file_path, exist_ok=True)
264
+ log_file = os.path.join(os.getcwd(), f'{file_path}/run_deploy.log')
265
+ deploy_args.log_file = log_file
266
+ params += f'--log_file "{log_file}" '
267
+ params += '--ignore_args_error true '
268
+ if sys.platform == 'win32':
269
+ if cuda_param:
270
+ cuda_param = f'set {cuda_param} && '
271
+ run_command = f'{cuda_param}start /b swift deploy {params} > {log_file} 2>&1'
272
+ else:
273
+ run_command = f'{cuda_param} nohup swift deploy {params} > {log_file} 2>&1 &'
274
+ return run_command, deploy_args, log_file
275
+
276
+ @classmethod
277
+ def deploy_model(cls, *args):
278
+ run_command, deploy_args, log_file = cls.deploy(*args)
279
+ logger.info(f'Running deployment command: {run_command}')
280
+ os.system(run_command)
281
+ gr.Info(cls.locale('load_alert', cls.lang)['value'])
282
+ time.sleep(2)
283
+ running_task = Runtime.refresh_tasks(log_file)
284
+ return gr.update(open=True), running_task
285
+
286
+ @classmethod
287
+ def register_clean_hook(cls):
288
+ signal.signal(signal.SIGINT, LLMInfer.signal_handler)
289
+ if os.name != 'nt':
290
+ signal.signal(signal.SIGTERM, LLMInfer.signal_handler)
291
+
292
+ @staticmethod
293
+ def signal_handler(*args, **kwargs):
294
+ LLMInfer.clean_deployment()
295
+ sys.exit(0)
296
+
297
+ @classmethod
298
+ def clear_session(cls):
299
+ return '', [], gr.update(value=None), gr.update(value=None), gr.update(value=None), []
300
+
301
+ @classmethod
302
+ def _replace_tag_with_media(cls, infer_request: InferRequest):
303
+ total_history = []
304
+ messages = deepcopy(infer_request.messages)
305
+ if messages[0]['role'] == 'system':
306
+ messages.pop(0)
307
+ for i in range(0, len(messages), 2):
308
+ slices = messages[i:i + 2]
309
+ if len(slices) == 2:
310
+ user, assistant = slices
311
+ else:
312
+ user = slices[0]
313
+ assistant = {'role': 'assistant', 'content': None}
314
+ user['content'] = (user['content'] or '').replace('<image>', '').replace('<video>',
315
+ '').replace('<audio>', '').strip()
316
+ for media in user['medias']:
317
+ total_history.append([(media, ), None])
318
+ if user['content'] or assistant['content']:
319
+ total_history.append((user['content'], assistant['content']))
320
+ return total_history
321
+
322
+ @classmethod
323
+ def agent_type(cls, response):
324
+ if not response:
325
+ return None
326
+ if response.lower().endswith('observation:'):
327
+ return 'react'
328
+ if 'observation:' not in response.lower() and 'action input:' in response.lower():
329
+ return 'toolbench'
330
+ return None
331
+
332
+ @classmethod
333
+ def send_message(cls, running_task, template_type, prompt: str, image, video, audio, infer_request: InferRequest,
334
+ infer_model_type, system, max_new_tokens, temperature, top_k, top_p, repetition_penalty):
335
+
336
+ if not infer_request:
337
+ infer_request = InferRequest(messages=[])
338
+ if system:
339
+ if not infer_request.messages or infer_request.messages[0]['role'] != 'system':
340
+ infer_request.messages.insert(0, {'role': 'system', 'content': system})
341
+ else:
342
+ infer_request.messages[0]['content'] = system
343
+ if not infer_request.messages or infer_request.messages[-1]['role'] != 'user':
344
+ infer_request.messages.append({'role': 'user', 'content': '', 'medias': []})
345
+ media = image or video or audio
346
+ media_type = 'images' if image else 'videos' if video else 'audios'
347
+ if media:
348
+ _saved_medias: List = getattr(infer_request, media_type)
349
+ if not _saved_medias or _saved_medias[-1] != media:
350
+ _saved_medias.append(media)
351
+ infer_request.messages[-1]['content'] = infer_request.messages[-1]['content'] + f'<{media_type[:-1]}>'
352
+ infer_request.messages[-1]['medias'].append(media)
353
+
354
+ if not prompt:
355
+ yield '', cls._replace_tag_with_media(infer_request), gr.update(value=None), gr.update(
356
+ value=None), gr.update(value=None), infer_request
357
+ return
358
+ else:
359
+ infer_request.messages[-1]['content'] = infer_request.messages[-1]['content'] + prompt
360
+
361
+ _, args = Runtime.parse_info_from_cmdline(running_task)
362
+ request_config = RequestConfig(
363
+ temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty)
364
+ request_config.stream = True
365
+ request_config.stop = ['Observation:']
366
+ request_config.max_tokens = max_new_tokens
367
+ stream_resp_with_history = ''
368
+ response = ''
369
+ i = len(infer_request.messages) - 1
370
+ for i in range(len(infer_request.messages) - 1, -1, -1):
371
+ if infer_request.messages[i]['role'] == 'assistant':
372
+ response = infer_request.messages[i]['content']
373
+ agent_type = cls.agent_type(response)
374
+ if i != len(infer_request.messages) - 1 and agent_type == 'toolbench':
375
+ infer_request.messages[i + 1]['role'] = 'tool'
376
+
377
+ chat = not template_type.endswith('generation')
378
+ _infer_request = deepcopy(infer_request)
379
+ for m in _infer_request.messages:
380
+ if 'medias' in m:
381
+ m.pop('medias')
382
+ model_kwargs = {}
383
+ if infer_model_type:
384
+ model_kwargs = {'model': infer_model_type}
385
+ gen_list = InferClient(
386
+ port=args['port'], ).infer(
387
+ infer_requests=[_infer_request], request_config=request_config, **model_kwargs)
388
+ if infer_request.messages[-1]['role'] != 'assistant':
389
+ infer_request.messages.append({'role': 'assistant', 'content': ''})
390
+ for chunk in gen_list[0]:
391
+ if chunk is None:
392
+ continue
393
+ stream_resp_with_history += chunk.choices[0].delta.content if chat else chunk.choices[0].text
394
+ infer_request.messages[-1]['content'] = stream_resp_with_history
395
+ yield '', cls._replace_tag_with_media(infer_request), gr.update(value=None), gr.update(
396
+ value=None), gr.update(value=None), infer_request
ms-swift/swift/ui/llm_infer/model.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from functools import partial
3
+ from typing import Type
4
+
5
+ import gradio as gr
6
+
7
+ from swift.llm import TEMPLATE_MAPPING, DeployArguments, ModelType
8
+ from swift.llm.model.register import get_all_models
9
+ from swift.ui.base import BaseUI
10
+ from swift.ui.llm_infer.generate import Generate
11
+
12
+
13
+ class Model(BaseUI):
14
+
15
+ llm_train = 'llm_infer'
16
+
17
+ sub_ui = [Generate]
18
+
19
+ locale_dict = {
20
+ 'model_type': {
21
+ 'label': {
22
+ 'zh': '选择模型类型',
23
+ 'en': 'Select Model Type'
24
+ },
25
+ 'info': {
26
+ 'zh': 'SWIFT已支持的模型类型',
27
+ 'en': 'Base model type supported by SWIFT'
28
+ }
29
+ },
30
+ 'load_checkpoint': {
31
+ 'value': {
32
+ 'zh': '部署模型',
33
+ 'en': 'Deploy model',
34
+ }
35
+ },
36
+ 'model': {
37
+ 'label': {
38
+ 'zh': '模型id或路径',
39
+ 'en': 'Model id or path'
40
+ },
41
+ 'info': {
42
+ 'zh': '实际的模型id,如果是训练后的模型请填入checkpoint-xxx的目录',
43
+ 'en': 'The actual model id or path, if is a trained model, please fill in the checkpoint-xxx dir'
44
+ }
45
+ },
46
+ 'template': {
47
+ 'label': {
48
+ 'zh': '模型Prompt模板类型',
49
+ 'en': 'Prompt template type'
50
+ },
51
+ 'info': {
52
+ 'zh': '选择匹配模型的Prompt模板',
53
+ 'en': 'Choose the template type of the model'
54
+ }
55
+ },
56
+ 'merge_lora': {
57
+ 'label': {
58
+ 'zh': '合并lora',
59
+ 'en': 'merge lora'
60
+ },
61
+ 'info': {
62
+ 'zh': '仅在sft_type=lora时可用',
63
+ 'en': 'Only available when sft_type=lora'
64
+ }
65
+ },
66
+ 'lora_modules': {
67
+ 'label': {
68
+ 'zh': '外部lora模块',
69
+ 'en': 'More lora modules'
70
+ },
71
+ 'info': {
72
+ 'zh': '空格分割的name=/path1/path2键值对',
73
+ 'en': 'name=/path1/path2 split by blanks'
74
+ }
75
+ },
76
+ 'more_params': {
77
+ 'label': {
78
+ 'zh': '更多参数',
79
+ 'en': 'More params'
80
+ },
81
+ 'info': {
82
+ 'zh': '以json格式或--xxx xxx命令行格式填入',
83
+ 'en': 'Fill in with json format or --xxx xxx cmd format'
84
+ }
85
+ },
86
+ 'reset': {
87
+ 'value': {
88
+ 'zh': '恢复初始值',
89
+ 'en': 'Reset to default'
90
+ },
91
+ },
92
+ 'infer_backend': {
93
+ 'label': {
94
+ 'zh': '推理框架',
95
+ 'en': 'Infer backend'
96
+ },
97
+ },
98
+ }
99
+
100
+ @classmethod
101
+ def do_build_ui(cls, base_tab: Type['BaseUI']):
102
+ with gr.Row():
103
+ gr.Dropdown(
104
+ elem_id='model',
105
+ scale=20,
106
+ choices=get_all_models(),
107
+ value='Qwen/Qwen2.5-7B-Instruct',
108
+ allow_custom_value=True)
109
+ gr.Dropdown(elem_id='model_type', choices=ModelType.get_model_name_list(), scale=20)
110
+ gr.Dropdown(elem_id='template', choices=list(TEMPLATE_MAPPING.keys()), scale=20)
111
+ gr.Checkbox(elem_id='merge_lora', scale=4)
112
+ gr.Button(elem_id='reset', scale=2)
113
+ with gr.Row():
114
+ gr.Dropdown(elem_id='infer_backend', value='pt', scale=5)
115
+ Generate.build_ui(base_tab)
116
+ with gr.Row():
117
+ gr.Textbox(elem_id='lora_modules', lines=1, is_list=True, scale=40)
118
+ gr.Textbox(elem_id='more_params', lines=1, scale=20)
119
+ gr.Button(elem_id='load_checkpoint', scale=2, variant='primary')
120
+
121
+ @classmethod
122
+ def after_build_ui(cls, base_tab: Type['BaseUI']):
123
+ cls.element('model').change(
124
+ partial(cls.update_input_model, arg_cls=DeployArguments, has_record=False),
125
+ inputs=[cls.element('model')],
126
+ outputs=list(cls.valid_elements().values()))
ms-swift/swift/ui/llm_infer/runtime.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import collections
3
+ import os.path
4
+ import sys
5
+ import time
6
+ from datetime import datetime
7
+ from typing import Dict, List, Tuple, Type
8
+
9
+ import gradio as gr
10
+ import psutil
11
+ from packaging import version
12
+
13
+ from swift.ui.base import BaseUI
14
+ from swift.utils import get_logger
15
+ from swift.utils.utils import format_time
16
+
17
+ logger = get_logger()
18
+
19
+
20
class Runtime(BaseUI):
    """Runtime panel of the LLM-infer page.

    Discovers processes started via ``swift deploy``, lets the user tail
    their log files, re-select a running deployment and kill it.
    """

    # log_file -> (handler list, extras); populated at runtime.
    handlers: Dict[str, Tuple[List, Tuple]] = {}

    group = 'llm_infer'

    # The swift sub-command whose processes this panel manages.
    cmd = 'deploy'

    # log_file -> bool; setting True asks the `wait` generator tailing that
    # file to stop (see break_log_event / wait).
    log_event = {}

    locale_dict = {
        'runtime_tab': {
            'label': {
                'zh': '运行时',
                'en': 'Runtime'
            },
        },
        'running_cmd': {
            'label': {
                'zh': '运行命令',
                'en': 'Command line'
            },
            'info': {
                'zh': '执行的实际命令',
                'en': 'The actual command'
            }
        },
        'show_log': {
            'value': {
                'zh': '展示部署状态',
                'en': 'Show running status'
            },
        },
        'stop_show_log': {
            'value': {
                'zh': '停止展示',
                'en': 'Stop showing running status'
            },
        },
        'log': {
            'label': {
                'zh': '日志输出',
                'en': 'Logging content'
            },
            'info': {
                'zh': '如果日志无更新请再次点击"展示日志内容"',
                'en': 'Please press "Show log" if the log content is not updating'
            }
        },
        'running_tasks': {
            'label': {
                'zh': '运行中部署',
                'en': 'Running deployments'
            },
            'info': {
                'zh': '所有的swift deploy命令启动的任务',
                'en': 'Started by swift deploy'
            }
        },
        'refresh_tasks': {
            'value': {
                'zh': '找回部署',
                'en': 'Find deployments'
            },
        },
        'kill_task': {
            'value': {
                'zh': '杀死部署',
                'en': 'Kill running task'
            },
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the collapsible runtime accordion and wire its events."""
        with gr.Accordion(elem_id='runtime_tab', open=False, visible=True):
            with gr.Blocks():
                with gr.Row():
                    gr.Dropdown(elem_id='running_tasks', scale=10, allow_custom_value=True)
                    gr.Button(elem_id='refresh_tasks', scale=1, variant='primary')
                    gr.Button(elem_id='show_log', scale=1, variant='primary')
                    gr.Button(elem_id='stop_show_log', scale=1)
                    gr.Button(elem_id='kill_task', scale=1)
                with gr.Row():
                    gr.Textbox(elem_id='log', lines=6, visible=False)

        # gradio >= 4 supports a per-event concurrency limit; cap the number
        # of simultaneous log-tailing generators at 5.
        concurrency_limit = {}
        if version.parse(gr.__version__) >= version.parse('4.0.0'):
            concurrency_limit = {'concurrency_limit': 5}
        # First make the log textbox visible, then start streaming its content.
        base_tab.element('show_log').click(cls.update_log, [],
                                           [cls.element('log')]).then(cls.wait,
                                                                      [base_tab.element('running_tasks')],
                                                                      [cls.element('log')], **concurrency_limit)

        base_tab.element('stop_show_log').click(cls.break_log_event, [cls.element('running_tasks')], [])

        base_tab.element('refresh_tasks').click(
            cls.refresh_tasks,
            [base_tab.element('running_tasks')],
            [base_tab.element('running_tasks')],
        )

    @classmethod
    def break_log_event(cls, task):
        """Flag the selected task's log file so its `wait` generator stops."""
        if not task:
            return
        pid, all_args = cls.parse_info_from_cmdline(task)
        cls.log_event[all_args['log_file']] = True

    @classmethod
    def update_log(cls):
        """Make the log textbox visible."""
        return gr.update(visible=True)

    @classmethod
    def wait(cls, task):
        """Tail the selected task's log file, yielding the trailing lines.

        Generator used as a streaming gradio callback. Stops when no new data
        has arrived for ~25s (50 polls of 0.5s) or when `break_log_event`
        set the flag for this log file.
        """
        if not task:
            return [None]
        _, args = cls.parse_info_from_cmdline(task)
        log_file = args['log_file']
        cls.log_event[log_file] = False
        offset = 0
        latest_data = ''
        # Only the last MAX_LOG_LINES lines are kept in memory/shown.
        lines = collections.deque(maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
        try:
            with open(log_file, 'r', encoding='utf-8') as input:
                input.seek(offset)
                fail_cnt = 0
                while True:
                    try:
                        latest_data += input.read()
                    except UnicodeDecodeError:
                        # Partial multi-byte sequence at EOF; retry on next poll.
                        continue
                    if not latest_data:
                        time.sleep(0.5)
                        fail_cnt += 1
                        if fail_cnt > 50:
                            break

                    if cls.log_event.get(log_file, False):
                        # Stop requested via break_log_event; reset the flag.
                        cls.log_event[log_file] = False
                        break

                    if '\n' not in latest_data:
                        continue
                    latest_lines = latest_data.split('\n')
                    if latest_data[-1] != '\n':
                        # Keep the trailing partial line buffered until it is
                        # terminated by a newline.
                        latest_data = latest_lines[-1]
                        latest_lines = latest_lines[:-1]
                    else:
                        latest_data = ''
                    lines.extend(latest_lines)
                    yield '\n'.join(lines)
        except IOError:
            # Log file missing/unreadable: end the stream silently.
            pass

    @classmethod
    def get_all_ports(cls):
        """Return the set of ports used by running `swift deploy` processes.

        Falls back to port 8000 for processes that did not pass --port.
        """
        process_name = 'swift'
        cmd_name = cls.cmd
        ports = set()
        for proc in psutil.process_iter():
            try:
                cmdlines = proc.cmdline()
            except (psutil.ZombieProcess, psutil.AccessDenied, psutil.NoSuchProcess):
                cmdlines = []
            if any([process_name in cmdline for cmdline in cmdlines]) and any(  # noqa
                    [cmd_name == cmdline for cmdline in cmdlines]):  # noqa
                try:
                    ports.add(int(cls.parse_info_from_cmdline(cls.construct_running_task(proc))[1].get('port', 8000)))
                except IndexError:
                    pass
        return ports

    @classmethod
    def refresh_tasks(cls, running_task=None):
        """Scan live processes for `swift deploy` tasks and update the dropdown.

        Keeps the previously selected task selected when it is still running
        (matched by its log file), otherwise falls back to the first found.
        """
        # `running_task` may be a plain log file path (no 'pid:' prefix yet).
        log_file = running_task if not running_task or 'pid:' not in running_task else None
        process_name = 'swift'
        # Exclude the Windows launcher binary from matches.
        negative_name = 'swift.exe'
        cmd_name = cls.cmd
        process = []
        selected = None
        for proc in psutil.process_iter():
            try:
                cmdlines = proc.cmdline()
            except (psutil.ZombieProcess, psutil.AccessDenied, psutil.NoSuchProcess):
                cmdlines = []
            if any([process_name in cmdline
                    for cmdline in cmdlines]) and not any([negative_name in cmdline
                                                           for cmdline in cmdlines]) and any(  # noqa
                                                               [cmd_name == cmdline for cmdline in cmdlines]):  # noqa
                process.append(cls.construct_running_task(proc))
                if log_file is not None and any(  # noqa
                        [log_file == cmdline for cmdline in cmdlines]):  # noqa
                    selected = cls.construct_running_task(proc)
        if not selected:
            if running_task and running_task in process:
                selected = running_task
        if not selected and process:
            selected = process[0]
        return gr.update(choices=process, value=selected)

    @staticmethod
    def construct_running_task(proc):
        """Format a psutil process as the display string used in the dropdown.

        Shape: ``pid:<pid>/create:<ts>/running:<elapsed>/cmd:<full cmdline>``
        — parsed back by `parse_info_from_cmdline`.
        """
        pid = proc.pid
        ts = time.time()
        create_time = proc.create_time()
        create_time_formatted = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d, %H:%M')

        return f'pid:{pid}/create:{create_time_formatted}' \
               f'/running:{format_time(ts - create_time)}/cmd:{" ".join(proc.cmdline())}'

    @classmethod
    def parse_info_from_cmdline(cls, task):
        """Parse a `construct_running_task` string back into (pid, arg dict).

        Strips the three leading ``key:value/`` segments, then splits the
        remaining ``swift <cmd> --k v --k v`` command line into a dict.
        """
        pid = None
        for i in range(3):
            slash = task.find('/')
            if i == 0:
                pid = task[:slash].split(':')[1]
            task = task[slash + 1:]
        args = task.split(f'swift {cls.cmd}')[1]
        args = [arg.strip() for arg in args.split('--') if arg.strip()]
        all_args = {}
        for i in range(len(args)):
            space = args[i].find(' ')
            splits = args[i][:space], args[i][space + 1:]
            all_args[splits[0]] = splits[1]
        return pid, all_args

    @classmethod
    def kill_task(cls, task):
        """Kill the selected deployment process and clear the log textbox."""
        if task:
            pid, all_args = cls.parse_info_from_cmdline(task)
            log_file = all_args['log_file']
            if sys.platform == 'win32':
                os.system(f'taskkill /f /t /pid "{pid}"')
            else:
                # Match by log file: kills the process tree started with it.
                os.system(f'pkill -9 -f {log_file}')
            time.sleep(1)
        cls.break_log_event(task)
        return [cls.refresh_tasks()] + [gr.update(value=None)]

    @classmethod
    def task_changed(cls, task, base_tab):
        """Push the selected task's cmdline arguments back into the form.

        Returns one gr.update per element of `base_tab` plus a final update
        clearing the log textbox.
        """
        if task:
            _, all_args = cls.parse_info_from_cmdline(task)
        else:
            all_args = {}
        elements = list(base_tab.valid_elements().values())
        ret = []
        is_custom_path = 'ckpt_dir' in all_args
        for e in elements:
            if e.elem_id in all_args:
                if isinstance(e, gr.Dropdown) and e.multiselect:
                    arg = all_args[e.elem_id].split(' ')
                else:
                    if e.elem_id == 'model':
                        # Prefer the checkpoint dir when one was passed.
                        if is_custom_path:
                            arg = all_args['ckpt_dir']
                        else:
                            arg = all_args[e.elem_id]
                    else:
                        arg = all_args[e.elem_id]
                ret.append(gr.update(value=arg))
            else:
                ret.append(gr.update())
        cls.break_log_event(task)
        return ret + [gr.update(value=None)]
ms-swift/swift/ui/llm_train/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
ms-swift/swift/ui/llm_train/advanced.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Type
3
+
4
+ import gradio as gr
5
+
6
+ from swift.ui.base import BaseUI
7
+
8
+
9
class Advanced(BaseUI):
    """'Advanced settings' accordion of the LLM training page.

    Pure UI definition: localized labels in ``locale_dict`` (keyed by each
    element's ``elem_id``) plus the gradio layout in ``do_build_ui``.
    """

    group = 'llm_train'

    locale_dict = {
        'advanced_param': {
            'label': {
                'zh': '高级参数设置',
                'en': 'Advanced settings'
            },
        },
        'optim': {
            'label': {
                'zh': 'Optimizer类型',
                'en': 'The Optimizer type'
            },
            'info': {
                'zh': '设置Optimizer类型',
                'en': 'Set the Optimizer type'
            }
        },
        'weight_decay': {
            'label': {
                'zh': '权重衰减',
                'en': 'Weight decay'
            },
            'info': {
                'zh': '设置weight decay',
                'en': 'Set the weight decay'
            }
        },
        'logging_steps': {
            'label': {
                'zh': '日志打印步数',
                'en': 'Logging steps'
            },
            'info': {
                'zh': '设置日志打印的步数间隔',
                'en': 'Set the logging interval'
            }
        },
        'lr_scheduler_type': {
            'label': {
                'zh': 'LrScheduler类型',
                'en': 'The LrScheduler type'
            },
            'info': {
                'zh': '设置LrScheduler类型',
                'en': 'Set the LrScheduler type'
            }
        },
        'warmup_ratio': {
            'label': {
                'zh': '学习率warmup比例',
                'en': 'Lr warmup ratio'
            },
            'info': {
                'zh': '设置学习率warmup比例',
                'en': 'Set the warmup ratio in total steps'
            }
        },
        'more_params': {
            'label': {
                'zh': '其他高级参数',
                'en': 'Other params'
            },
            'info': {
                'zh': '以json格式或--xxx xxx命令行格式填入',
                'en': 'Fill in with json format or --xxx xxx cmd format'
            }
        },
        'truncation_strategy': {
            'label': {
                'zh': '数据集超长策略',
                'en': 'Dataset truncation strategy'
            },
            'info': {
                'zh': '如果token超长该如何处理',
                'en': 'How to deal with the rows exceed the max length'
            }
        },
        'max_steps': {
            'label': {
                'zh': '最大迭代步数',
                'en': 'Max steps',
            },
            'info': {
                'zh': '设置最大迭代步数,该值如果大于零则数据集迭代次数不生效',
                'en': 'Set the max steps, if the value > 0 then num_train_epochs has no effects',
            }
        },
        'per_device_eval_batch_size': {
            'label': {
                'zh': '验证batch size',
                'en': 'Val batch size',
            },
            'info': {
                'zh': '验证的batch size',
                'en': 'Set the val batch size',
            }
        },
        'max_grad_norm': {
            'label': {
                'zh': '梯度裁剪',
                'en': 'Max grad norm',
            },
            'info': {
                'zh': '设置梯度裁剪',
                'en': 'Set the max grad norm',
            }
        },
        'predict_with_generate': {
            'label': {
                'zh': '使用生成指标代替loss',
                'en': 'Use generate metric instead of loss',
            },
            'info': {
                'zh': '验证时使用generate/Rouge代替loss',
                'en': 'Use model.generate/Rouge instead of loss',
            }
        },
        'deepspeed': {
            'label': {
                'zh': 'deepspeed',
                'en': 'deepspeed',
            },
            'info': {
                'zh': '可以选择下拉列表,也支持传入路径',
                'en': 'Choose from the dropbox or fill in a valid path',
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the advanced-settings accordion (collapsed by default)."""
        with gr.Accordion(elem_id='advanced_param', open=False):
            with gr.Blocks():
                with gr.Row():
                    gr.Textbox(elem_id='optim', lines=1, scale=20)
                    gr.Textbox(elem_id='weight_decay', lines=1, scale=20)
                    gr.Textbox(elem_id='logging_steps', lines=1, scale=20)
                    gr.Textbox(elem_id='lr_scheduler_type', lines=1, scale=20)
                    gr.Textbox(elem_id='max_steps', lines=1, scale=20)
                    gr.Slider(elem_id='warmup_ratio', minimum=0.0, maximum=1.0, step=0.05, scale=20)
                with gr.Row():
                    gr.Dropdown(elem_id='truncation_strategy', scale=20)
                    gr.Slider(elem_id='per_device_eval_batch_size', minimum=1, maximum=256, step=2, scale=20)
                    gr.Textbox(elem_id='max_grad_norm', lines=1, scale=20)
                    # Presets map to the bundled deepspeed config files; a
                    # custom path is also accepted (allow_custom_value).
                    gr.Dropdown(
                        elem_id='deepspeed',
                        scale=20,
                        allow_custom_value=True,
                        value=None,
                        choices=['zero0', 'zero1', 'zero2', 'zero3', 'zero2_offload', 'zero3_offload'])
                with gr.Row():
                    gr.Textbox(elem_id='more_params', lines=4, scale=20)
ms-swift/swift/ui/llm_train/dataset.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Type
3
+
4
+ import gradio as gr
5
+
6
+ from swift.llm.dataset.register import get_dataset_list
7
+ from swift.ui.base import BaseUI
8
+
9
+
10
class Dataset(BaseUI):
    """'Dataset settings' accordion of the LLM training page.

    Pure UI definition: localized labels in ``locale_dict`` plus the gradio
    layout in ``do_build_ui``.
    """

    group = 'llm_train'

    locale_dict = {
        'dataset': {
            'label': {
                'zh': '数据集名称',
                'en': 'Dataset Code'
            },
            'info': {
                'zh': '选择训练的数据集,支持复选/本地路径',
                'en': 'The dataset(s) to train the models, support multi select and local folder/files'
            }
        },
        'max_length': {
            'label': {
                'zh': '句子最大长度',
                'en': 'The max length',
            },
            'info': {
                'zh': '设置输入模型的最大长度',
                'en': 'Set the max length input to the model',
            }
        },
        'split_dataset_ratio': {
            'label': {
                'zh': '验证集拆分比例',
                'en': 'Split ratio of eval dataset'
            },
            'info': {
                'zh': '表示将总数据的多少拆分到验证集中',
                'en': 'Split the datasets by this ratio for eval'
            }
        },
        'train_dataset_sample': {
            'label': {
                'zh': '训练集采样数量',
                'en': 'The sample size from the train dataset'
            },
            'info': {
                'zh': '从训练集中采样一定行数进行训练',
                'en': 'Train with the sample size from the dataset',
            }
        },
        'val_dataset_sample': {
            'label': {
                'zh': '验证集采样数量',
                'en': 'The sample size from the val dataset'
            },
            'info': {
                'zh': '从验证集中采样一定行数进行训练',
                'en': 'Validate with the sample size from the dataset',
            }
        },
        'custom_dataset_info': {
            'label': {
                'zh': '外部数据集配置',
                'en': 'Custom dataset config'
            },
            'info': {
                'zh': '注册外部数据集的配置文件',
                'en': 'An extra dataset config to register your own datasets'
            }
        },
        'dataset_param': {
            'label': {
                'zh': '数据集设置',
                'en': 'Dataset settings'
            },
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the dataset accordion (open by default)."""
        with gr.Accordion(elem_id='dataset_param', open=True):
            with gr.Row():
                # Registered dataset names plus free-form local paths.
                gr.Dropdown(
                    elem_id='dataset', multiselect=True, choices=get_dataset_list(), scale=20, allow_custom_value=True)
                gr.Textbox(elem_id='custom_dataset_info', is_list=False, scale=20)
                gr.Slider(elem_id='split_dataset_ratio', minimum=0.0, maximum=1.0, step=0.05, scale=10)
                gr.Slider(elem_id='max_length', minimum=32, maximum=32768, value=1024, step=1, scale=10)
ms-swift/swift/ui/llm_train/hyper.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Type
3
+
4
+ import gradio as gr
5
+
6
+ from swift.ui.base import BaseUI
7
+
8
+
9
class Hyper(BaseUI):
    """'Hyper settings' accordion of the LLM training page.

    Localized labels in ``locale_dict`` plus the gradio layout in
    ``do_build_ui``; ``update_lr`` supplies the learning-rate default that
    depends on the selected tuner type.
    """

    group = 'llm_train'

    locale_dict = {
        'hyper_param': {
            'label': {
                'zh': '超参数设置(更多参数->高级参数设置)',
                'en': 'Hyper settings(more params->Advanced settings)',
            },
        },
        'per_device_train_batch_size': {
            'label': {
                'zh': '训练batch size',
                'en': 'Train batch size',
            },
            'info': {
                'zh': '训练的batch size',
                'en': 'Set the train batch size',
            }
        },
        'learning_rate': {
            'label': {
                'zh': '学习率',
                'en': 'Learning rate',
            },
            'info': {
                'zh': '设置学习率',
                'en': 'Set the learning rate',
            }
        },
        'eval_steps': {
            'label': {
                'zh': '交叉验证步数',
                'en': 'Eval steps',
            },
            'info': {
                'zh': '设置每隔多少步数进行一次验证',
                'en': 'Set the step interval to validate',
            }
        },
        'num_train_epochs': {
            'label': {
                'zh': '数据集迭代轮次',
                'en': 'Train epoch',
            },
            'info': {
                'zh': '设置对数据集训练多少轮次',
                'en': 'Set the max train epoch',
            }
        },
        'gradient_accumulation_steps': {
            'label': {
                'zh': '梯度累计步数',
                'en': 'Gradient accumulation steps',
            },
            'info': {
                'zh': '设置梯度累计步数以减小显存占用',
                'en': 'Set the gradient accumulation steps',
            }
        },
        'attn_impl': {
            'label': {
                'zh': 'Flash Attention类型',
                'en': 'Flash Attention Type',
            },
        },
        'neftune_noise_alpha': {
            'label': {
                'zh': 'neftune_noise_alpha',
                'en': 'neftune_noise_alpha'
            },
            'info': {
                'zh': '使用neftune提升训练效果, 一般设置为5或者10',
                'en': 'Use neftune to improve performance, normally the value should be 5 or 10'
            }
        },
        'save_steps': {
            'label': {
                'zh': '存储步数',
                'en': 'save steps',
            },
            'info': {
                'zh': '设置每个多少步数进行存储',
                'en': 'Set the save steps',
            }
        },
        'output_dir': {
            'label': {
                'zh': '存储目录',
                'en': 'The output dir',
            },
            'info': {
                'zh': '设置输出模型存储在哪个文件夹下',
                'en': 'Set the output folder',
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the hyper-parameter accordion (collapsed by default)."""
        with gr.Accordion(elem_id='hyper_param', open=False):
            with gr.Blocks():
                with gr.Row():
                    gr.Slider(elem_id='per_device_train_batch_size', minimum=1, maximum=256, step=2, scale=20)
                    gr.Textbox(elem_id='learning_rate', value='1e-4', lines=1, scale=20)
                    gr.Textbox(elem_id='num_train_epochs', lines=1, scale=20)
                    gr.Dropdown(elem_id='attn_impl', scale=20, value='flash_attn')
                    gr.Slider(elem_id='gradient_accumulation_steps', minimum=1, maximum=256, step=2, value=16, scale=20)
                with gr.Row():
                    gr.Textbox(elem_id='eval_steps', lines=1, value='500', scale=20)
                    gr.Textbox(elem_id='save_steps', value='500', lines=1, scale=20)
                    gr.Textbox(elem_id='output_dir', scale=20)
                    gr.Slider(elem_id='neftune_noise_alpha', minimum=0.0, maximum=20.0, step=0.5, scale=20)

    @staticmethod
    def update_lr(sft_type):
        """Return the default learning rate for a tuner type.

        Full fine-tuning uses a smaller lr (1e-5) than parameter-efficient
        tuners (1e-4).
        """
        if sft_type == 'full':
            return 1e-5
        else:
            return 1e-4
ms-swift/swift/ui/llm_train/llamapro.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Type
3
+
4
+ import gradio as gr
5
+
6
+ from swift.ui.base import BaseUI
7
+
8
+
9
class LlamaPro(BaseUI):
    """'LLAMAPRO settings' accordion of the LLM training page.

    Localized labels in ``locale_dict`` plus the gradio layout in
    ``do_build_ui``.
    """

    group = 'llm_train'

    locale_dict = {
        'llamapro_tab': {
            'label': {
                'zh': 'LLAMAPRO参数设置',
                'en': 'LLAMAPRO Settings'
            },
        },
        'llamapro_num_new_blocks': {
            'label': {
                'zh': 'LLAMAPRO插入层数',
                'en': 'LLAMAPRO new layers'
            },
        },
        'llamapro_num_groups': {
            'label': {
                'zh': 'LLAMAPRO对原模型的分组数',
                'en': 'LLAMAPRO groups of model'
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the LLaMA-Pro accordion (collapsed by default)."""
        with gr.Accordion(elem_id='llamapro_tab', open=False):
            with gr.Blocks():
                with gr.Row():
                    gr.Textbox(elem_id='llamapro_num_new_blocks')
                    gr.Textbox(elem_id='llamapro_num_groups')
ms-swift/swift/ui/llm_train/llm_train.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import collections
3
+ import os
4
+ import re
5
+ import sys
6
+ import time
7
+ from functools import partial
8
+ from subprocess import PIPE, STDOUT, Popen
9
+ from typing import Dict, Type
10
+
11
+ import gradio as gr
12
+ import json
13
+ import torch
14
+ from json import JSONDecodeError
15
+ from transformers.utils import is_torch_cuda_available, is_torch_npu_available
16
+
17
+ from swift.llm import RLHFArguments
18
+ from swift.llm.argument.base_args.base_args import get_supported_tuners
19
+ from swift.ui.base import BaseUI
20
+ from swift.ui.llm_train.advanced import Advanced
21
+ from swift.ui.llm_train.dataset import Dataset
22
+ from swift.ui.llm_train.galore import Galore
23
+ from swift.ui.llm_train.hyper import Hyper
24
+ from swift.ui.llm_train.lisa import Lisa
25
+ from swift.ui.llm_train.llamapro import LlamaPro
26
+ from swift.ui.llm_train.lora import LoRA
27
+ from swift.ui.llm_train.model import Model
28
+ from swift.ui.llm_train.quantization import Quantization
29
+ from swift.ui.llm_train.report_to import ReportTo
30
+ from swift.ui.llm_train.rlhf import RLHF
31
+ from swift.ui.llm_train.runtime import Runtime
32
+ from swift.ui.llm_train.save import Save
33
+ from swift.ui.llm_train.self_cog import SelfCog
34
+ from swift.utils import get_device_count, get_logger
35
+
36
+ logger = get_logger()
37
+
38
+
39
class LLMTrain(BaseUI):
    """Top-level 'LLM Training' tab.

    Composes all training sub-UIs, translates the form values into a
    ``swift pt/sft/rlhf`` command line and launches it (locally or in studio
    mode).

    Fixes applied in this revision:
    * ``except (JSONDecodeError or TypeError)`` only caught JSONDecodeError
      because ``or`` short-circuits; now both exception types are caught.
    * ``os.path.exists(model)`` raised TypeError when the model field was
      empty; now guarded.
    * The Windows launch command hard-coded ``swift sft`` and ignored the
      selected train stage; it now uses the chosen sub-command like the
      POSIX branch.
    * Repaired a mojibake-damaged zh tooltip for ``ddp_num``.
    """

    group = 'llm_train'

    # Sub-panels rendered inside this tab, in display order.
    sub_ui = [
        Model,
        Dataset,
        Runtime,
        Save,
        LoRA,
        Hyper,
        Quantization,
        SelfCog,
        Advanced,
        RLHF,
        Lisa,
        Galore,
        LlamaPro,
        ReportTo,
    ]

    locale_dict: Dict[str, Dict] = {
        'llm_train': {
            'label': {
                'zh': 'LLM训练',
                'en': 'LLM Training',
            }
        },
        'train_stage': {
            'label': {
                'zh': '训练Stage',
                'en': 'Train Stage'
            },
            'info': {
                'zh': '请注意选择与此匹配的数据集,人类对齐配置在页面下方',
                'en': 'Please choose matched dataset, RLHF settings is at the bottom of the page'
            }
        },
        'submit_alert': {
            'value': {
                'zh':
                '任务已开始,请查看tensorboard或日志记录,关闭本页面不影响训练过程',
                'en':
                'Task started, please check the tensorboard or log file, '
                'closing this page does not affect training'
            }
        },
        'dataset_alert': {
            'value': {
                'zh': '请选择或填入一个数据集',
                'en': 'Please input or select a dataset'
            }
        },
        'submit': {
            'value': {
                'zh': '🚀 开始训练',
                'en': '🚀 Begin'
            }
        },
        'dry_run': {
            'label': {
                'zh': '仅生成运行命令',
                'en': 'Dry-run'
            },
            'info': {
                'zh': '仅生成运行命令,开发者自行运行',
                'en': 'Generate run command only, for manually running'
            }
        },
        'gpu_id': {
            'label': {
                'zh': '选择可用GPU',
                'en': 'Choose GPU'
            },
            'info': {
                'zh': '选择训练使用的GPU号,如CUDA不可用只能选择CPU',
                'en': 'Select GPU to train'
            }
        },
        'train_type': {
            'label': {
                'zh': '训练方式',
                'en': 'Train type'
            },
            'info': {
                'zh': '选择训练的方式',
                'en': 'Select the training type'
            }
        },
        'seed': {
            'label': {
                'zh': '随机数种子',
                'en': 'Seed'
            },
            'info': {
                'zh': '选择随机数种子',
                'en': 'Select a random seed'
            }
        },
        'torch_dtype': {
            'label': {
                'zh': '训练精度',
                'en': 'Training Precision'
            },
            'info': {
                'zh': '选择训练精度',
                'en': 'Select the training precision'
            }
        },
        'envs': {
            'label': {
                'zh': '环境变量',
                'en': 'Extra env vars'
            },
        },
        'use_ddp': {
            'label': {
                'zh': '使用DDP',
                'en': 'Use DDP'
            },
            'info': {
                'zh': '是否使用数据并行训练',
                'en': 'Use Distributed Data Parallel to train'
            }
        },
        'ddp_num': {
            'label': {
                'zh': 'DDP分片数量',
                'en': 'Number of DDP sharding'
            },
            'info': {
                # Restored from mojibake ('数据并��') in the original upload.
                'zh': '启用多少进程的数据并行',
                'en': 'The data parallel size of DDP'
            }
        },
        'tuner_backend': {
            'label': {
                'zh': 'Tuner backend',
                'en': 'Tuner backend'
            },
            'info': {
                'zh': 'tuner实现框架',
                'en': 'The tuner backend'
            }
        },
        'use_liger_kernel': {
            'label': {
                'zh': '使用Liger kernel',
                'en': 'Use Liger kernel'
            },
            'info': {
                'zh': 'Liger kernel可以有效降低显存使用',
                'en': 'Liger kernel can reduce memory usage'
            }
        },
        'train_param': {
            'label': {
                'zh': '训练参数设置',
                'en': 'Train settings'
            },
        },
    }

    # Choices/defaults/argument names derived once from RLHFArguments.
    choice_dict = BaseUI.get_choices_from_dataclass(RLHFArguments)
    default_dict = BaseUI.get_default_value_from_dataclass(RLHFArguments)
    arguments = BaseUI.get_argument_names(RLHFArguments)

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the training tab layout and wire submit/kill/record events."""
        with gr.TabItem(elem_id='llm_train', label=''):
            default_device = 'cpu'
            device_count = get_device_count()
            if device_count > 0:
                default_device = '0'
            with gr.Blocks():
                Model.build_ui(base_tab)
                Dataset.build_ui(base_tab)
                with gr.Accordion(elem_id='train_param', open=True):
                    with gr.Row():
                        gr.Dropdown(elem_id='train_stage', choices=['pt', 'sft', 'rlhf'], value='sft', scale=3)
                        gr.Dropdown(elem_id='train_type', scale=2, choices=list(get_supported_tuners()))
                        gr.Dropdown(elem_id='tuner_backend', scale=2)
                    with gr.Row():
                        gr.Textbox(elem_id='seed', scale=4)
                        gr.Dropdown(elem_id='torch_dtype', scale=4)
                        gr.Checkbox(elem_id='use_liger_kernel', scale=4)
                        gr.Checkbox(elem_id='use_ddp', value=False, scale=4)
                        gr.Textbox(elem_id='ddp_num', value='2', scale=4)
                Hyper.build_ui(base_tab)
                Runtime.build_ui(base_tab)
                with gr.Row():
                    gr.Dropdown(
                        elem_id='gpu_id',
                        multiselect=True,
                        choices=[str(i) for i in range(device_count)] + ['cpu'],
                        value=default_device,
                        scale=8)
                    gr.Textbox(elem_id='envs', scale=8)
                    gr.Checkbox(elem_id='dry_run', value=False, scale=4)
                    submit = gr.Button(elem_id='submit', scale=4, variant='primary')

                LoRA.build_ui(base_tab)
                RLHF.build_ui(base_tab)
                Quantization.build_ui(base_tab)
                Galore.build_ui(base_tab)
                Lisa.build_ui(base_tab)
                LlamaPro.build_ui(base_tab)
                SelfCog.build_ui(base_tab)
                Save.build_ui(base_tab)
                ReportTo.build_ui(base_tab)
                Advanced.build_ui(base_tab)

                # The learning-rate default follows the chosen tuner type.
                cls.element('train_type').change(
                    Hyper.update_lr, inputs=[base_tab.element('train_type')], outputs=[cls.element('learning_rate')])

                submit.click(
                    cls.train_local,
                    list(cls.valid_elements().values()), [
                        cls.element('running_cmd'),
                        cls.element('logging_dir'),
                        cls.element('runtime_tab'),
                        cls.element('running_tasks'),
                        cls.element('train_record'),
                    ],
                    queue=True)

                base_tab.element('running_tasks').change(
                    partial(Runtime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')],
                    list(base_tab.valid_elements().values()) + [cls.element('log')] + Runtime.all_plots)
                Runtime.element('kill_task').click(
                    Runtime.kill_task,
                    [Runtime.element('running_tasks')],
                    [Runtime.element('running_tasks')] + [Runtime.element('log')] + Runtime.all_plots,
                ).then(Runtime.reset, [], [Runtime.element('logging_dir')] + [Hyper.element('output_dir')])

    @classmethod
    def update_runtime(cls):
        """Open the runtime accordion and reveal the log textbox."""
        return gr.update(open=True), gr.update(visible=True)

    @classmethod
    def train(cls, *args):
        """Translate form values into a launchable `swift <stage>` command.

        Returns ``(run_command, sft_args, other_kwargs)`` where `sft_args` is
        the validated RLHFArguments instance and `other_kwargs` holds the
        UI-only fields (gpu_id, envs, dry_run, ...). Raises gr.Error when no
        dataset was chosen.
        """
        # UI-only fields never forwarded as CLI arguments.
        ignore_elements = ('logging_dir', 'more_params', 'train_stage', 'envs')
        default_args = cls.get_default_value_from_dataclass(RLHFArguments)
        kwargs = {}
        kwargs_is_list = {}
        other_kwargs = {}
        more_params = {}
        more_params_cmd = ''
        keys = cls.valid_element_keys()
        train_stage = 'sft'
        for key, value in zip(keys, args):
            compare_value = default_args.get(key)
            # Normalize textbox strings into int/float/bool where they match.
            if isinstance(value, str) and re.fullmatch(cls.int_regex, value):
                value = int(value)
            elif isinstance(value, str) and re.fullmatch(cls.float_regex, value):
                value = float(value)
            elif isinstance(value, str) and re.fullmatch(cls.bool_regex, value):
                value = True if value.lower() == 'true' else False
            # Only forward values that differ from the dataclass default.
            if key not in ignore_elements and key in default_args and compare_value != value and value:
                kwargs[key] = value if not isinstance(value, list) else ' '.join(value)
                kwargs_is_list[key] = isinstance(value, list) or getattr(cls.element(key), 'is_list', False)
            else:
                other_kwargs[key] = value
            if key == 'more_params' and value:
                # Accept either a JSON object or raw `--xxx xxx` text.
                try:
                    more_params = json.loads(value)
                except (JSONDecodeError, TypeError):
                    # Fix: `(JSONDecodeError or TypeError)` only caught the
                    # first type; both must be caught.
                    more_params_cmd = value

            if key == 'train_stage':
                train_stage = value

        kwargs.update(more_params)
        if 'dataset' not in kwargs and 'custom_train_dataset_path' not in kwargs:
            raise gr.Error(cls.locale('dataset_alert', cls.lang)['value'])

        model = kwargs.get('model')
        # A local dir containing args.json is a previous checkpoint: resume
        # from it. Guard against an empty model field (os.path.exists(None)
        # raises TypeError).
        if model and os.path.exists(model) and os.path.exists(os.path.join(model, 'args.json')):
            kwargs['resume_from_checkpoint'] = kwargs.pop('model')

        cmd = train_stage
        if kwargs.get('deepspeed'):
            more_params_cmd += f' --deepspeed {kwargs.pop("deepspeed")} '
        try:
            sft_args = RLHFArguments(
                **{
                    key: value.split(' ') if kwargs_is_list.get(key, False) and isinstance(value, str) else value
                    for key, value in kwargs.items()
                })
        except Exception as e:
            if 'using `--model`' in str(e):  # TODO a dirty fix
                kwargs['model'] = kwargs.pop('resume_from_checkpoint')
                sft_args = RLHFArguments(
                    **{
                        key: value.split(' ') if kwargs_is_list.get(key, False) and isinstance(value, str) else value
                        for key, value in kwargs.items()
                    })
            else:
                raise e
        params = ''

        # List-valued args are rendered as `--k "a" "b"`.
        sep = f'{cls.quote} {cls.quote}'
        for e in kwargs:
            if isinstance(kwargs[e], list):
                params += f'--{e} {cls.quote}{sep.join(kwargs[e])}{cls.quote} '
            elif e in kwargs_is_list and kwargs_is_list[e]:
                all_args = [arg for arg in kwargs[e].split(' ') if arg.strip()]
                params += f'--{e} {cls.quote}{sep.join(all_args)}{cls.quote} '
            else:
                params += f'--{e} {cls.quote}{kwargs[e]}{cls.quote} '
        params += more_params_cmd + ' '
        params += f'--add_version False --output_dir {sft_args.output_dir} ' \
                  f'--logging_dir {sft_args.logging_dir} --ignore_args_error True'
        ddp_param = ''
        devices = other_kwargs['gpu_id']
        envs = other_kwargs['envs'] or ''
        envs = envs.strip()
        devices = [d for d in devices if d]
        if other_kwargs['use_ddp']:
            assert int(other_kwargs['ddp_num']) > 0
            ddp_param = f'NPROC_PER_NODE={int(other_kwargs["ddp_num"])}'
        # 'cpu' cannot be mixed with GPU ids.
        assert (len(devices) == 1 or 'cpu' not in devices)
        gpus = ','.join(devices)
        cuda_param = ''
        if gpus != 'cpu':
            if is_torch_npu_available():
                cuda_param = f'ASCEND_RT_VISIBLE_DEVICES={gpus}'
            elif is_torch_cuda_available():
                cuda_param = f'CUDA_VISIBLE_DEVICES={gpus}'
            else:
                cuda_param = ''

        log_file = os.path.join(sft_args.logging_dir, 'run.log')
        if sys.platform == 'win32':
            if cuda_param:
                cuda_param = f'set {cuda_param} && '
            if ddp_param:
                ddp_param = f'set {ddp_param} && '
            if envs:
                envs = [env.strip() for env in envs.split(' ') if env.strip()]
                _envs = ''
                for env in envs:
                    _envs += f'set {env} && '
                envs = _envs
            # Fix: use the selected stage (`pt`/`sft`/`rlhf`) instead of the
            # previously hard-coded `swift sft`, matching the POSIX branch.
            run_command = f'{cuda_param}{ddp_param}{envs}start /b swift {cmd} {params} > {log_file} 2>&1'
        else:
            run_command = f'{cuda_param} {ddp_param} {envs} nohup swift {cmd} {params} > {log_file} 2>&1 &'
        logger.info(f'Run training: {run_command}')
        if model:
            # Persist the form values so the run can be restored later.
            record = {}
            for key, value in zip(keys, args):
                if key in default_args or key in ('more_params', 'train_stage', 'use_ddp', 'ddp_num', 'gpu_id', 'envs'):
                    record[key] = value or None
            cls.save_cache(model, record)
        return run_command, sft_args, other_kwargs

    @classmethod
    def train_studio(cls, *args):
        """Run training inline (studio mode), streaming log lines and plots."""
        run_command, sft_args, other_kwargs = cls.train(*args)
        if not other_kwargs['dry_run']:
            lines = collections.deque(maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
            process = Popen(run_command, shell=True, stdout=PIPE, stderr=STDOUT)
            with process.stdout:
                for line in iter(process.stdout.readline, b''):
                    line = line.decode('utf-8')
                    lines.append(line)
                    yield ['\n'.join(lines)] + Runtime.plot(run_command) + [run_command]
        else:
            yield [
                'Current is dryrun mode so you can only view the training cmd, please duplicate this space to '
                'do training or use with inference.'
            ] + [None] * len(Runtime.sft_plot) + [run_command]

    @classmethod
    def train_local(cls, *args):
        """Launch training as a detached background process on this machine."""
        run_command, sft_args, other_kwargs = cls.train(*args)
        if not other_kwargs['dry_run']:
            os.makedirs(sft_args.logging_dir, exist_ok=True)
            os.system(run_command)
            time.sleep(1)  # to make sure the log file has been created.
            gr.Info(cls.locale('submit_alert', cls.lang)['value'])
        return run_command, sft_args.logging_dir, gr.update(open=True), Runtime.refresh_tasks(
            sft_args.output_dir), gr.update(choices=cls.list_cache(sft_args.model))
ms-swift/swift/ui/llm_train/lora.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Type
3
+
4
+ import gradio as gr
5
+
6
+ from swift.ui.base import BaseUI
7
+
8
+
9
class LoRA(BaseUI):
    """LoRA hyper-parameter widgets of the `llm_train` web-UI page.

    ``locale_dict`` maps each gradio ``elem_id`` to its zh/en label/info
    texts; ``BaseUI`` uses it to localize the widgets created in
    ``do_build_ui``.
    """

    # UI group this tab belongs to.
    group = 'llm_train'

    locale_dict = {
        'lora_tab': {
            'label': {
                'zh': 'LoRA参数设置',
                'en': 'LoRA settings'
            },
        },
        'target_modules': {
            'label': {
                'zh': 'LoRA目标模块',
                'en': 'LoRA target modules'
            },
            'info': {
                'zh': '设置LoRA目标模块,如训练所有Linear请改为`all-linear`',
                'en': 'Set the LoRA target modules, fill in `all-linear` if train all Linears'
            }
        },
        'lora_rank': {
            'label': {
                'zh': 'LoRA的秩',
                'en': 'The LoRA rank'
            }
        },
        'lora_alpha': {
            'label': {
                'zh': 'LoRA的alpha',
                'en': 'The LoRA alpha'
            }
        },
        'lora_dropout': {
            'label': {
                'zh': 'LoRA的dropout',
                'en': 'The LoRA dropout'
            }
        },
        'use_rslora': {
            'label': {
                'zh': '使用rslora',
                'en': 'Use rslora'
            }
        },
        'use_dora': {
            'label': {
                'zh': '使用dora',
                'en': 'Use dora'
            }
        },
        'lora_dtype': {
            'label': {
                'zh': 'lora部分的参数类型',
                'en': 'The dtype of lora parameters'
            }
        },
        'init_weights': {
            'label': {
                'zh': 'lora初始化方法',
                'en': 'init lora weights'
            },
            'info': {
                'zh': 'gaussian/pissa/pissa_niter_[n]/olora/loftq/true/false',
                'en': 'gaussian/pissa/pissa_niter_[n]/olora/loftq/true/false',
            }
        },
        'lorap_lr_ratio': {
            'label': {
                'zh': 'Lora+学习率倍率',
                'en': 'The lr ratio of Lora+'
            },
            'info': {
                'zh': '建议值16.0',
                'en': 'Suggested value: 16.0'
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Create the collapsible LoRA settings accordion.

        Widget construction order inside each ``gr.Row`` defines the layout.
        """
        with gr.Accordion(elem_id='lora_tab', open=True):
            with gr.Blocks():
                with gr.Row():
                    # is_list=True: BaseUI treats the textbox value as a space-separated list.
                    gr.Textbox(elem_id='target_modules', lines=1, scale=5, value='all-linear', is_list=True)
                    gr.Slider(elem_id='lora_rank', value=8, minimum=1, maximum=512, step=8, scale=2)
                    gr.Slider(elem_id='lora_alpha', value=32, minimum=1, maximum=512, step=8, scale=2)
                    gr.Textbox(elem_id='lora_dropout', scale=2)
                with gr.Row():
                    gr.Dropdown(elem_id='lora_dtype', scale=2, value=None)
                    gr.Textbox(elem_id='lorap_lr_ratio', scale=2)
                    gr.Checkbox(elem_id='use_rslora', scale=2)
                    gr.Checkbox(elem_id='use_dora', scale=2)
                    gr.Textbox(elem_id='init_weights', scale=4)
ms-swift/swift/ui/llm_train/model.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from functools import partial
3
+ from typing import Type
4
+
5
+ import gradio as gr
6
+
7
+ from swift.llm import TEMPLATE_MAPPING, ModelType, RLHFArguments
8
+ from swift.llm.model.register import get_all_models
9
+ from swift.ui.base import BaseUI
10
+
11
+
12
class Model(BaseUI):
    """Model-selection widgets of the `llm_train` web-UI page.

    ``locale_dict`` maps each gradio ``elem_id`` to its zh/en label/info
    texts; ``BaseUI`` uses it to localize the widgets created in
    ``do_build_ui``.
    """

    # UI group this tab belongs to.
    group = 'llm_train'

    locale_dict = {
        'model_type': {
            'label': {
                'zh': '模型类型',
                'en': 'Select Model Type'
            },
            'info': {
                'zh': 'SWIFT已支持的模型类型',
                'en': 'Base model type supported by SWIFT'
            }
        },
        'model': {
            'label': {
                'zh': '模型id或路径',
                'en': 'Model id or path'
            },
            'info': {
                'zh': '实际的模型id',
                'en': 'The actual model id or model path'
            }
        },
        'template': {
            'label': {
                'zh': '模型Prompt模板类型',
                'en': 'Prompt template type'
            },
            'info': {
                'zh': '选择匹配模型的Prompt模板',
                'en': 'Choose the template type of the model'
            }
        },
        'system': {
            'label': {
                'zh': 'system字段',
                'en': 'system'
            },
            'info': {
                'zh': '选择system字段的内容',
                'en': 'Choose the content of the system field'
            }
        },
        'reset': {
            'value': {
                'zh': '恢复模型初始值',
                'en': 'Reset model default'
            },
        },
        'train_record': {
            'label': {
                'zh': '训练记录',
                'en': 'Train record'
            },
            'info': {
                'zh': '展示使用web-ui的历史训练及参数',
                'en': 'Show the training history and parameters'
            }
        },
        'clear_cache': {
            'value': {
                'zh': '删除训练记录',
                'en': 'Delete train records'
            },
        },
        'model_param': {
            'label': {
                'zh': '模型设置',
                'en': 'Model settings'
            },
        },
        'checkpoint': {
            'value': {
                'zh': '训练后的模型',
                'en': 'Trained model'
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Create the model-settings accordion and wire the clear-cache button."""
        with gr.Accordion(elem_id='model_param', open=True):
            with gr.Row():
                model = gr.Dropdown(
                    elem_id='model',
                    scale=20,
                    choices=get_all_models(),
                    value='Qwen/Qwen2.5-7B-Instruct',
                    allow_custom_value=True)
                gr.Dropdown(elem_id='model_type', choices=ModelType.get_model_name_list(), scale=20)
                gr.Dropdown(elem_id='template', choices=list(TEMPLATE_MAPPING.keys()), scale=20)
                train_record = gr.Dropdown(elem_id='train_record', choices=[], scale=20)
                clear_cache = gr.Button(elem_id='clear_cache', scale=2)
            with gr.Row():
                gr.Textbox(elem_id='system', lines=1, scale=20)

            def clear_record(model):
                # Delete the cached train records for the selected model and
                # empty the dropdown; no-op when no model is selected.
                if model:
                    cls.clear_cache(model)
                    return gr.update(choices=[])
                return gr.update()

            clear_cache.click(clear_record, inputs=[model], outputs=[train_record])

    @classmethod
    def after_build_ui(cls, base_tab: Type['BaseUI']):
        """Wire cross-tab events once every element of the page exists."""
        # Changing the model refreshes the train-record choices and all dependent settings.
        cls.element('model').change(
            partial(base_tab.update_input_model, arg_cls=RLHFArguments),
            inputs=[cls.element('model')],
            outputs=[cls.element('train_record')] + list(base_tab.valid_elements().values()))

        # Selecting a historical record restores all recorded settings.
        cls.element('train_record').change(
            partial(base_tab.update_all_settings, base_tab=base_tab),
            inputs=[cls.element('model'), cls.element('train_record')],
            outputs=list(base_tab.valid_elements().values()))
ms-swift/swift/ui/llm_train/quantization.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Type
3
+
4
+ import gradio as gr
5
+
6
+ from swift.ui.base import BaseUI
7
+
8
+
9
class Quantization(BaseUI):
    """Quantization (bnb) settings of the `llm_train` web-UI page.

    ``locale_dict`` maps each gradio ``elem_id`` to its zh/en label/info
    texts; ``BaseUI`` uses it to localize the widgets created in
    ``do_build_ui``.
    """

    # UI group this tab belongs to.
    group = 'llm_train'

    locale_dict = {
        'quantization_tab': {
            'label': {
                'zh': '量化参数设置',
                'en': 'Quantization settings'
            },
        },
        'quant_method': {
            'label': {
                'zh': '量化方式',
                'en': 'Quantization method'
            },
            'info': {
                # Fixed typo: '制定' (formulate) -> '指定' (specify), matching the English text.
                'zh': '如果指定了量化位数,本参数默认为bnb',
                'en': 'Default is bnb if quantization_bit is specified'
            }
        },
        'quant_bits': {
            'label': {
                'zh': '量化bit数',
                'en': 'Quantization bit'
            },
            'info': {
                'zh': '设置量化bit数, 0代表不进行量化',
                'en': 'Set the quantization bit, 0 for no quantization'
            }
        },
        'bnb_4bit_compute_dtype': {
            'label': {
                'zh': 'bnb_4bit_compute_dtype',
                'en': 'bnb_4bit_compute_dtype'
            },
        },
        'bnb_4bit_quant_type': {
            'label': {
                'zh': 'bnb_4bit_quant_type',
                'en': 'bnb_4bit_quant_type'
            },
        },
        'bnb_4bit_use_double_quant': {
            'label': {
                'zh': 'bnb_4bit_use_double_quant',
                'en': 'bnb_4bit_use_double_quant'
            },
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Create the collapsed quantization accordion (all values unset by default)."""
        with gr.Accordion(elem_id='quantization_tab', open=False):
            with gr.Row():
                gr.Dropdown(elem_id='quant_bits', value=None)
                gr.Dropdown(elem_id='quant_method', value=None)
                gr.Dropdown(elem_id='bnb_4bit_compute_dtype', value=None)
                gr.Dropdown(elem_id='bnb_4bit_quant_type', value=None)
                gr.Checkbox(elem_id='bnb_4bit_use_double_quant', value=None)
ms-swift/swift/ui/llm_train/report_to.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Type
3
+
4
+ import gradio as gr
5
+
6
+ from swift.ui.base import BaseUI
7
+
8
+
9
class ReportTo(BaseUI):
    """Training-report (tensorboard/wandb/swanlab) settings of the `llm_train` page.

    ``locale_dict`` maps each gradio ``elem_id`` to its zh/en label texts;
    ``BaseUI`` uses it to localize the widgets created in ``do_build_ui``.
    """

    # UI group this tab belongs to.
    group = 'llm_train'

    locale_dict = {
        'reporter': {
            'label': {
                'zh': '训练记录',
                'en': 'Training report'
            },
        },
        'report_to': {
            'label': {
                'zh': '训练记录方式',
                'en': 'Report to'
            },
        },
        'swanlab_token': {
            'label': {
                'zh': 'swanlab登录token',
                'en': 'The login token of swanlab'
            },
        },
        'swanlab_project': {
            'label': {
                'zh': 'swanlab项目名称',
                'en': 'Project of swanlab'
            },
        },
        'swanlab_workspace': {
            'label': {
                'zh': 'swanlab工作空间',
                'en': 'Workspace of swanlab'
            },
        },
        'swanlab_exp_name': {
            'label': {
                'zh': 'swanlab实验名称',
                'en': 'Experiment of swanlab'
            },
        },
        'swanlab_mode': {
            'label': {
                'zh': 'swanlab工作模式',
                'en': 'Work mode of swanlab'
            },
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Create the collapsed report settings accordion."""
        with gr.Accordion(elem_id='reporter', open=False):
            with gr.Blocks():
                with gr.Row():
                    # Multi-select dropdown; is_list=True makes BaseUI serialize it
                    # as a list argument on the command line.
                    gr.Dropdown(
                        elem_id='report_to',
                        multiselect=True,
                        is_list=True,
                        choices=['tensorboard', 'wandb', 'swanlab'],
                        allow_custom_value=True,
                        scale=20)
                    gr.Textbox(elem_id='swanlab_token', lines=1, scale=20)
                    gr.Textbox(elem_id='swanlab_project', lines=1, scale=20)
                with gr.Row():
                    gr.Textbox(elem_id='swanlab_workspace', lines=1, scale=20)
                    gr.Textbox(elem_id='swanlab_exp_name', lines=1, scale=20)
                    gr.Dropdown(elem_id='swanlab_mode', scale=20)
ms-swift/swift/ui/llm_train/rlhf.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from functools import partial
3
+ from typing import Type
4
+
5
+ import gradio as gr
6
+
7
+ from swift.llm import ModelType
8
+ from swift.llm.model.register import get_all_models
9
+ from swift.ui.base import BaseUI
10
+
11
+
12
class RLHF(BaseUI):
    """RLHF (human-alignment) hyper-parameter widgets of the `llm_train` page.

    ``locale_dict`` maps each gradio ``elem_id`` to its zh/en label/info
    texts; ``BaseUI`` uses it to localize the widgets created in
    ``do_build_ui``.
    """

    # UI group this tab belongs to.
    group = 'llm_train'

    locale_dict = {
        'rlhf_tab': {
            'label': {
                'zh': '人类对齐参数设置',
                'en': 'RLHF settings'
            },
        },
        'rlhf_type': {
            'label': {
                'zh': '人类对齐算法类型',
                'en': 'RLHF type'
            },
        },
        'ref_model_type': {
            'label': {
                'zh': '选择ref模型',
                'en': 'Select ref model'
            },
            'info': {
                'zh': 'SWIFT已支持的模型名称',
                'en': 'Base model supported by SWIFT'
            }
        },
        'ref_model': {
            'label': {
                'zh': 'ref模型id或路径',
                'en': 'Ref model id or path'
            },
            'info': {
                'zh': '实际的模型id或路径',
                'en': 'The actual model id or path'
            }
        },
        'beta': {
            'label': {
                'zh': 'KL正则项系数',
                'en': 'KL regression ratio'
            },
        },
        'rpo_alpha': {
            'label': {
                'zh': 'DPO中混合sft交叉熵的系数',
                'en': 'DPO Cross Entropy ratio'
            },
        },
        'simpo_gamma': {
            'label': {
                'zh': 'SimPO reward margin',
                'en': 'SimPO reward margin'
            },
        },
        'desirable_weight': {
            'label': {
                'zh': 'KTO符合项系数',
                'en': 'KTO desirable ratio'
            },
        },
        'undesirable_weight': {
            'label': {
                'zh': 'KTO不符合项系数',
                'en': 'KTO undesirable ratio'
            },
        }
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Create the collapsed RLHF accordion (algorithm choice, ref model, ratios)."""
        with gr.Accordion(elem_id='rlhf_tab', open=False):
            with gr.Blocks():
                with gr.Row():
                    gr.Dropdown(elem_id='rlhf_type', value=None)
                    gr.Dropdown(
                        elem_id='ref_model', scale=20, value=None, choices=get_all_models(), allow_custom_value=True)
                    gr.Dropdown(elem_id='ref_model_type', choices=ModelType.get_model_name_list(), value=None, scale=20)
                with gr.Row():
                    gr.Slider(elem_id='beta', minimum=0., maximum=5.0, step=0.1, scale=20)
                    gr.Slider(elem_id='rpo_alpha', minimum=0., maximum=2, step=0.1, scale=20)
                    gr.Slider(elem_id='simpo_gamma', minimum=0., maximum=2.0, step=0.1, scale=20)
                    gr.Slider(elem_id='desirable_weight', minimum=0., maximum=2.0, step=0.1, scale=20)
                    gr.Slider(elem_id='undesirable_weight', minimum=0., maximum=2.0, step=0.1, scale=20)

    @classmethod
    def after_build_ui(cls, base_tab: Type['BaseUI']):
        """Wire the ref-model dropdown so that picking a model infers its model type."""
        cls.element('ref_model').change(
            partial(cls.update_input_model, allow_keys=['ref_model_type'], has_record=False, is_ref_model=True),
            inputs=[cls.element('ref_model')],
            outputs=[cls.element('ref_model_type')])
ms-swift/swift/ui/llm_train/runtime.py ADDED
@@ -0,0 +1,571 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import collections
3
+ import os.path
4
+ import sys
5
+ import time
6
+ import webbrowser
7
+ from datetime import datetime
8
+ from typing import Dict, List, Tuple, Type
9
+
10
+ import gradio as gr
11
+ import json
12
+ import matplotlib.pyplot as plt
13
+ import psutil
14
+ from packaging import version
15
+ from transformers import is_tensorboard_available
16
+
17
+ from swift.ui.base import BaseUI
18
+ from swift.ui.llm_train.utils import close_loop, run_command_in_subprocess
19
+ from swift.utils import TB_COLOR, TB_COLOR_SMOOTH, get_logger, read_tensorboard_file, tensorboard_smoothing
20
+ from swift.utils.utils import format_time
21
+
22
+ logger = get_logger()
23
+
24
+
25
class Runtime(BaseUI):
    """Runtime panel of the `llm_train` page: log tailing, tensorboard control,
    running-task discovery/kill, and metric plotting from tensorboard event files.
    """

    # logging_dir -> (tensorboard subprocess handle, localhost URL).
    handlers: Dict[str, Tuple[List, Tuple]] = {}

    group = 'llm_train'

    # gr.Plot components created in do_build_ui, in sft_plot order.
    all_plots = None

    # logging_dir -> bool flag asking the wait() generator to stop tailing.
    log_event = {}

    # Metric descriptors per training kind: tensorboard tag name plus an
    # optional exponential-smoothing factor (None = plot raw values only).
    sft_plot = [
        {
            'name': 'train/loss',
            'smooth': 0.9,
        },
        {
            'name': 'train/acc',
            'smooth': None,
        },
        {
            'name': 'train/learning_rate',
            'smooth': None,
        },
        {
            'name': 'eval/loss',
            'smooth': 0.9,
        },
        {
            'name': 'eval/acc',
            'smooth': None,
        },
    ]

    dpo_plot = [
        {
            'name': 'train/loss',
            'smooth': 0.9,
        },
        {
            'name': 'train/rewards/accuracies',
            'smooth': None,
        },
        {
            'name': 'train/rewards/margins',
            'smooth': 0.9,
        },
        {
            'name': 'train/logps/chosen',
            'smooth': 0.9,
        },
        {
            'name': 'train/logps/rejected',
            'smooth': 0.9,
        },
    ]

    kto_plot = [
        {
            'name': 'kl',
            'smooth': None,
        },
        {
            'name': 'rewards/chosen_sum',
            'smooth': 0.9,
        },
        {
            'name': 'logps/chosen_sum',
            'smooth': 0.9,
        },
        {
            'name': 'rewards/rejected_sum',
            'smooth': 0.9,
        },
        {
            'name': 'logps/rejected_sum',
            'smooth': 0.9,
        },
    ]

    orpo_plot = [
        {
            'name': 'train/loss',
            'smooth': 0.9,
        },
        {
            'name': 'train/rewards/accuracies',
            'smooth': None,
        },
        {
            'name': 'train/rewards/margins',
            'smooth': 0.9,
        },
        {
            'name': 'train/rewards/chosen',
            'smooth': 0.9,
        },
        {
            'name': 'train/log_odds_ratio',
            'smooth': 0.9,
        },
    ]

    locale_dict = {
        'runtime_tab': {
            'label': {
                'zh': '运行时',
                'en': 'Runtime'
            },
        },
        'tb_not_found': {
            'value': {
                'zh': 'tensorboard未安装,使用pip install tensorboard进行安装',
                'en': 'tensorboard not found, install it by pip install tensorboard',
            }
        },
        'running_cmd': {
            'label': {
                'zh': '运行命令',
                'en': 'Command line'
            },
            'info': {
                'zh': '执行的实际命令',
                'en': 'The actual command'
            }
        },
        'show_log': {
            'value': {
                'zh': '展示运行状态',
                'en': 'Show running status'
            },
        },
        'stop_show_log': {
            'value': {
                'zh': '停止展示运行状态',
                'en': 'Stop showing running status'
            },
        },
        'logging_dir': {
            'label': {
                'zh': '日志路径',
                'en': 'Logging dir'
            },
            'info': {
                'zh': '支持手动传入文件路径',
                'en': 'Support fill custom path in'
            }
        },
        'log': {
            'label': {
                'zh': '日志输出',
                'en': 'Logging content'
            },
            'info': {
                'zh': '如果日志无更新请再次点击"展示日志内容"',
                'en': 'Please press "Show log" if the log content is not updating'
            }
        },
        'running_tasks': {
            'label': {
                'zh': '运行中任务',
                'en': 'Running Tasks'
            },
            'info': {
                'zh': '运行中的任务(所有的swift sft命令)',
                'en': 'All running tasks(started by swift sft)'
            }
        },
        'refresh_tasks': {
            'value': {
                'zh': '找回运行时任务',
                'en': 'Find running tasks'
            },
        },
        'kill_task': {
            'value': {
                'zh': '杀死任务',
                'en': 'Kill running task'
            },
        },
        'tb_url': {
            'label': {
                'zh': 'Tensorboard链接',
                'en': 'Tensorboard URL'
            },
            'info': {
                'zh': '仅展示,不可编辑',
                'en': 'Not editable'
            }
        },
        'start_tb': {
            'value': {
                'zh': '打开TensorBoard',
                'en': 'Start TensorBoard'
            },
        },
        'close_tb': {
            'value': {
                'zh': '关闭TensorBoard',
                'en': 'Close TensorBoard'
            },
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the runtime accordion and wire log/tensorboard/task events."""
        with gr.Accordion(elem_id='runtime_tab', open=False, visible=True):
            with gr.Blocks():
                with gr.Row():
                    gr.Textbox(elem_id='running_cmd', lines=1, scale=20, interactive=False, max_lines=1)
                    gr.Textbox(elem_id='logging_dir', lines=1, scale=20, max_lines=1)
                    gr.Button(elem_id='show_log', scale=2, variant='primary')
                    gr.Button(elem_id='stop_show_log', scale=2)
                    gr.Textbox(elem_id='tb_url', lines=1, scale=10, interactive=False, max_lines=1)
                    gr.Button(elem_id='start_tb', scale=2, variant='primary')
                    gr.Button(elem_id='close_tb', scale=2)
                with gr.Row():
                    gr.Textbox(elem_id='log', lines=6, visible=False)
                with gr.Row():
                    gr.Dropdown(elem_id='running_tasks', scale=10)
                    gr.Button(elem_id='refresh_tasks', scale=1)
                    gr.Button(elem_id='kill_task', scale=1)

                with gr.Row():
                    # One gr.Plot per sft_plot metric; labels are relabeled at
                    # runtime by update_log when the task kind differs.
                    cls.all_plots = []
                    for idx, k in enumerate(Runtime.sft_plot):
                        name = k['name']
                        cls.all_plots.append(gr.Plot(elem_id=str(idx), label=name))

                # gradio>=4 renamed the queue concurrency kwarg.
                concurrency_limit = {}
                if version.parse(gr.__version__) >= version.parse('4.0.0'):
                    concurrency_limit = {'concurrency_limit': 5}
                # show_log: first reveal the widgets, then start the tailing generator.
                base_tab.element('show_log').click(
                    Runtime.update_log, [base_tab.element('running_tasks')], [cls.element('log')] + cls.all_plots).then(
                        Runtime.wait, [base_tab.element('logging_dir'),
                                       base_tab.element('running_tasks')], [cls.element('log')] + cls.all_plots,
                        **concurrency_limit)

                base_tab.element('stop_show_log').click(cls.break_log_event, [cls.element('running_tasks')], [])

                base_tab.element('start_tb').click(
                    Runtime.start_tb,
                    [base_tab.element('logging_dir')],
                    [base_tab.element('tb_url')],
                )

                base_tab.element('close_tb').click(
                    Runtime.close_tb,
                    [base_tab.element('logging_dir')],
                    [],
                )

                base_tab.element('refresh_tasks').click(
                    Runtime.refresh_tasks,
                    [base_tab.element('running_tasks')],
                    [base_tab.element('running_tasks')],
                )

    @classmethod
    def get_plot(cls, task):
        """Return the metric descriptor list matching the task's command line.

        NOTE(review): unmatched rlhf types (anything other than dpo/cpo/simpo/
        kto/orpo) fall through and return None — callers then call len(None);
        verify against the full rlhf_type choice list.
        """
        if not task or 'swift sft' in task or 'swift pt' in task:
            return cls.sft_plot

        args: dict = cls.parse_info_from_cmdline(task)[1]
        train_type = args.get('rlhf_type', 'dpo')
        if train_type in ('dpo', 'cpo', 'simpo'):
            return cls.dpo_plot
        elif train_type == 'kto':
            return cls.kto_plot
        elif train_type == 'orpo':
            return cls.orpo_plot

    @classmethod
    def update_log(cls, task):
        """Make the log textbox and each plot visible, relabeled for the task kind."""
        ret = [gr.update(visible=True)]
        plot = Runtime.get_plot(task)
        for i in range(len(plot)):
            p = plot[i]
            ret.append(gr.update(visible=True, label=p['name']))
        return ret

    @classmethod
    def get_initial(cls, line):
        """Return the tqdm-style prefix a log line starts with, or None."""
        tqdm_starts = ['Train:', 'Map:', 'Val:', 'Filter:']
        for start in tqdm_starts:
            if line.startswith(start):
                return start
        return None

    @classmethod
    def wait(cls, logging_dir, task):
        """Generator that tails `<logging_dir>/run.log`, yielding [log, *plots].

        Stops after ~25s without new data (fail_cnt), or when break_log_event
        sets the per-directory flag in cls.log_event.
        """
        if not logging_dir:
            return [None] + Runtime.plot(task)
        log_file = os.path.join(logging_dir, 'run.log')
        cls.log_event[logging_dir] = False
        offset = 0
        latest_data = ''
        lines = collections.deque(maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
        try:
            with open(log_file, 'r', encoding='utf-8') as input:
                input.seek(offset)
                fail_cnt = 0
                while True:
                    try:
                        latest_data += input.read()
                    except UnicodeDecodeError:
                        continue
                    if not latest_data:
                        time.sleep(0.5)
                        fail_cnt += 1
                        if fail_cnt > 50:
                            break

                    if cls.log_event.get(logging_dir, False):
                        cls.log_event[logging_dir] = False
                        break

                    if '\n' not in latest_data:
                        continue
                    latest_lines = latest_data.split('\n')
                    # Keep a trailing partial line for the next read.
                    if latest_data[-1] != '\n':
                        latest_data = latest_lines[-1]
                        latest_lines = latest_lines[:-1]
                    else:
                        latest_data = ''
                    lines.extend(latest_lines)
                    # Collapse consecutive tqdm progress lines into the newest one.
                    start = cls.get_initial(lines[-1])
                    if start:
                        i = len(lines) - 2
                        while i >= 0:
                            if lines[i].startswith(start):
                                del lines[i]
                                i -= 1
                            else:
                                break
                    yield ['\n'.join(lines)] + Runtime.plot(task)
        except IOError:
            pass

    @classmethod
    def break_log_event(cls, task):
        """Ask the wait() generator tailing this task's logging_dir to stop."""
        if not task:
            return
        pid, all_args = Runtime.parse_info_from_cmdline(task)
        cls.log_event[all_args['logging_dir']] = True

    @classmethod
    def show_log(cls, logging_dir):
        """Open run.log in the system browser (new tab)."""
        webbrowser.open('file://' + os.path.join(logging_dir, 'run.log'), new=2)

    @classmethod
    def start_tb(cls, logging_dir):
        """Launch (or reuse) a tensorboard process for logging_dir; return its URL."""
        if not is_tensorboard_available():
            gr.Error(cls.locale('tb_not_found', cls.lang)['value'])
            return ''

        logging_dir = logging_dir.strip()
        logging_dir = logging_dir if not logging_dir.endswith(os.sep) else logging_dir[:-1]
        if logging_dir in cls.handlers:
            return cls.handlers[logging_dir][1]

        handler, lines = run_command_in_subprocess('tensorboard', '--logdir', logging_dir, timeout=2)
        localhost_addr = ''
        for line in lines:
            if 'http://localhost:' in line:
                line = line[line.index('http://localhost:'):]
                localhost_addr = line[:line.index(' ')]
        cls.handlers[logging_dir] = (handler, localhost_addr)
        logger.info('===========Tensorboard Log============')
        logger.info('\n'.join(lines))
        webbrowser.open(localhost_addr, new=2)
        return localhost_addr

    @staticmethod
    def close_tb(logging_dir):
        """Terminate the tensorboard process started for logging_dir, if any."""
        if logging_dir in Runtime.handlers:
            close_loop(Runtime.handlers[logging_dir][0])
            Runtime.handlers.pop(logging_dir)

    @staticmethod
    def refresh_tasks(running_task=None):
        """Scan system processes for `swift pt/sft/rlhf` runs; return dropdown update."""
        output_dir = running_task if not running_task or 'pid:' not in running_task else None
        process_name = 'swift'
        negative_name = 'swift.exe'
        cmd_name = ['pt', 'sft', 'rlhf']
        process = []
        selected = None
        for proc in psutil.process_iter():
            try:
                cmdlines = proc.cmdline()
            except (psutil.ZombieProcess, psutil.AccessDenied, psutil.NoSuchProcess):
                cmdlines = []
            if any([process_name in cmdline
                    for cmdline in cmdlines]) and not any([negative_name in cmdline
                                                           for cmdline in cmdlines]) and any(  # noqa
                                                               [cmdline in cmd_name for cmdline in cmdlines]):  # noqa
                process.append(Runtime.construct_running_task(proc))
                if output_dir is not None and any(  # noqa
                        [output_dir == cmdline for cmdline in cmdlines]):  # noqa
                    selected = Runtime.construct_running_task(proc)
        if not selected:
            if running_task and running_task in process:
                selected = running_task
        if not selected and process:
            selected = process[0]
        return gr.update(choices=process, value=selected)

    @staticmethod
    def construct_running_task(proc):
        """Format a psutil process as 'pid:…/create:…/running:…/cmd:…'."""
        pid = proc.pid
        ts = time.time()
        create_time = proc.create_time()
        create_time_formatted = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d, %H:%M')

        return f'pid:{pid}/create:{create_time_formatted}' \
               f'/running:{format_time(ts-create_time)}/cmd:{" ".join(proc.cmdline())}'

    @staticmethod
    def parse_info_from_cmdline(task):
        """Parse a running-task string back into (pid, {arg_name: value}).

        When the task's output_dir contains args.json, recorded values
        override the ones parsed from the command line.
        """
        pid = None
        if '/cmd:' in task:
            # Strip the three 'pid/create/running' prefix segments.
            for i in range(3):
                slash = task.find('/')
                if i == 0:
                    pid = task[:slash].split(':')[1]
                task = task[slash + 1:]
        if 'swift sft' in task:
            args = task.split('swift sft')[1]
        elif 'swift pt' in task:
            args = task.split('swift pt')[1]
        elif 'swift rlhf' in task:
            args = task.split('swift rlhf')[1]
        else:
            raise ValueError(f'Cannot parse cmd line: {task}')
        args = [arg.strip() for arg in args.split('--') if arg.strip()]
        all_args = {}
        for i in range(len(args)):
            space = args[i].find(' ')
            splits = args[i][:space], args[i][space + 1:]
            all_args[splits[0]] = splits[1]

        output_dir = all_args['output_dir']
        if os.path.exists(os.path.join(output_dir, 'args.json')):
            with open(os.path.join(output_dir, 'args.json'), 'r', encoding='utf-8') as f:
                _json = json.load(f)
            for key in all_args.keys():
                all_args[key] = _json.get(key)
                if isinstance(all_args[key], list):
                    # Quote list items containing spaces before re-joining.
                    if any([' ' in value for value in all_args[key]]):
                        all_args[key] = [f'"{value}"' for value in all_args[key]]
                    all_args[key] = ' '.join(all_args[key])
        return pid, all_args

    @staticmethod
    def kill_task(task):
        """Kill the selected training process and reset log/plot widgets."""
        if task:
            pid, all_args = Runtime.parse_info_from_cmdline(task)
            output_dir = all_args['output_dir']
            if sys.platform == 'win32':
                os.system(f'taskkill /f /t /pid "{pid}"')
            else:
                # Match the process by its unique output_dir argument.
                os.system(f'pkill -9 -f {output_dir}')
            time.sleep(1)
            Runtime.break_log_event(task)
        return [Runtime.refresh_tasks()] + [gr.update(value=None)] * (len(Runtime.get_plot(task)) + 1)

    @staticmethod
    def reset():
        """Default (task, output_dir) values for the UI."""
        return None, 'output'

    @staticmethod
    def task_changed(task, base_tab):
        """Restore all page settings from the selected task and clear log/plots."""
        if task:
            _, all_args = Runtime.parse_info_from_cmdline(task)
        else:
            all_args = {}
        elements = list(base_tab.valid_elements().values())
        ret = []
        for e in elements:
            if e.elem_id in all_args:
                if isinstance(e, gr.Dropdown) and e.multiselect:
                    arg = all_args[e.elem_id].split(' ')
                else:
                    arg = all_args[e.elem_id]
                ret.append(gr.update(value=arg))
            else:
                ret.append(gr.update())
        Runtime.break_log_event(task)
        return ret + [gr.update(value=None)] * (len(Runtime.get_plot(task)) + 1)

    @staticmethod
    def plot(task):
        """Render matplotlib figures for the task's tensorboard scalars.

        Returns one figure (or None) per entry of the task's plot descriptor.
        """
        plot = Runtime.get_plot(task)
        if not task:
            return [None] * len(plot)
        _, all_args = Runtime.parse_info_from_cmdline(task)
        tb_dir = all_args['logging_dir']
        if not os.path.exists(tb_dir):
            return [None] * len(plot)
        fname = [
            fname for fname in os.listdir(tb_dir)
            if os.path.isfile(os.path.join(tb_dir, fname)) and fname.startswith('events.out')
        ]
        if fname:
            fname = fname[0]
        else:
            return [None] * len(plot)
        tb_path = os.path.join(tb_dir, fname)
        data = read_tensorboard_file(tb_path)

        plots = []
        for k in plot:
            name = k['name']
            smooth = k['smooth']
            # Newer runs log token/seq accuracy under different tags; prefer them.
            if name == 'train/acc':
                if 'train/token_acc' in data:
                    name = 'train/token_acc'
                if 'train/seq_acc' in data:
                    name = 'train/seq_acc'
            if name == 'eval/acc':
                if 'eval/token_acc' in data:
                    name = 'eval/token_acc'
                if 'eval/seq_acc' in data:
                    name = 'eval/seq_acc'
            if name not in data:
                plots.append(None)
                continue
            _data = data[name]
            steps = [d['step'] for d in _data]
            values = [d['value'] for d in _data]
            if len(values) == 0:
                continue

            plt.close('all')
            fig = plt.figure()
            ax = fig.add_subplot()
            # _, ax = plt.subplots(1, 1, squeeze=True, figsize=(8, 5), dpi=100)
            ax.set_title(name)
            if len(values) == 1:
                ax.scatter(steps, values, color=TB_COLOR_SMOOTH)
            elif smooth is not None:
                ax.plot(steps, values, color=TB_COLOR)
                values_s = tensorboard_smoothing(values, smooth)
                ax.plot(steps, values_s, color=TB_COLOR_SMOOTH)
            else:
                ax.plot(steps, values, color=TB_COLOR_SMOOTH)
            plots.append(fig)
        return plots
ms-swift/swift/ui/llm_train/save.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Type
3
+
4
+ import gradio as gr
5
+
6
+ from swift.ui.base import BaseUI
7
+
8
+
9
+ class Save(BaseUI):
10
+
11
+ group = 'llm_train'
12
+
13
+ locale_dict = {
14
+ 'save_param': {
15
+ 'label': {
16
+ 'zh': '存储参数设置',
17
+ 'en': 'Saving settings'
18
+ },
19
+ },
20
+ 'push_to_hub': {
21
+ 'label': {
22
+ 'zh': '推送魔搭Hub',
23
+ 'en': 'Push to modelscope hub',
24
+ },
25
+ 'info': {
26
+ 'zh': '是否推送魔搭的模型库',
27
+ 'en': 'Whether push the output model to modelscope hub',
28
+ }
29
+ },
30
+ 'hub_model_id': {
31
+ 'label': {
32
+ 'zh': '魔搭模型id',
33
+ 'en': 'The model-id in modelscope',
34
+ },
35
+ 'info': {
36
+ 'zh': '设置魔搭的模型id',
37
+ 'en': 'Set the model-id of modelscope',
38
+ }
39
+ },
40
+ 'hub_private_repo': {
41
+ 'label': {
42
+ 'zh': '设置仓库私有',
43
+ 'en': 'Model is private',
44
+ },
45
+ 'info': {
46
+ 'zh': '以私有方式推送魔搭hub',
47
+ 'en': 'Set the model as private',
48
+ }
49
+ },
50
+ 'hub_strategy': {
51
+ 'label': {
52
+ 'zh': '推送策略',
53
+ 'en': 'Push strategy',
54
+ },
55
+ 'info': {
56
+ 'zh': '设置模型推送策略',
57
+ 'en': 'Set the push strategy',
58
+ }
59
+ },
60
+ 'hub_token': {
61
+ 'label': {
62
+ 'zh': '仓库token',
63
+ 'en': 'The hub token',
64
+ },
65
+ 'info': {
66
+ 'zh': '该token可以在www.modelscope.cn找到',
67
+ 'en': 'Find the token in www.modelscope.cn',
68
+ }
69
+ }
70
+ }
71
+
72
+ @classmethod
73
+ def do_build_ui(cls, base_tab: Type['BaseUI']):
74
+ with gr.Accordion(elem_id='save_param', open=False):
75
+ with gr.Blocks():
76
+ with gr.Row():
77
+ gr.Checkbox(elem_id='push_to_hub', scale=20)
78
+ gr.Textbox(elem_id='hub_model_id', lines=1, scale=20)
79
+ gr.Checkbox(elem_id='hub_private_repo', scale=20)
80
+ gr.Dropdown(
81
+ elem_id='hub_strategy',
82
+ scale=20,
83
+ choices=['end', 'every_save', 'checkpoint', 'all_checkpoints'])
84
+ gr.Textbox(elem_id='hub_token', lines=1, scale=20)
ms-swift/swift/ui/llm_train/self_cog.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Type
3
+
4
+ import gradio as gr
5
+
6
+ from swift.ui.base import BaseUI
7
+
8
+
9
class SelfCog(BaseUI):
    """'Self cognition' accordion of the LLM training tab.

    Exposes the self-cognition fine-tuning knobs (sample count, model name,
    model author). Each component's ``elem_id`` matches a ``locale_dict``
    key for localization by BaseUI.
    """

    # All components built here register under the 'llm_train' tab group.
    group = 'llm_train'

    # zh/en label and tooltip text, keyed by component elem_id.
    locale_dict = {
        'self_cognition': {
            'label': {
                'zh': '自我认知任务参数设置',
                'en': 'Self cognition settings'
            },
        },
        'self_cognition_sample': {
            'label': {
                'zh': '数据及采样条数',
                'en': 'Dataset sample size'
            },
            'info': {
                'zh': '设置数据集采样的条数',
                'en': 'Set the dataset sample size'
            }
        },
        'model_name': {
            'label': {
                'zh': '模型认知名称',
                'en': 'Model name'
            },
            'info': {
                'zh': '设置模型应当认知自己的名字, 格式为:中文名字 英文名字,中间以空格分隔',
                'en': 'Set the name of the model think itself of, the format is Chinesename Englishname, split by space'
            }
        },
        'model_author': {
            'label': {
                'zh': '模型作者',
                'en': 'Model author'
            },
            'info': {
                'zh': '设置模型认知的自己的作者, 格式为:中文作者 英文作者,中间以空格分隔',
                'en': 'Set the author of the model, the format is Chineseauthor Englishauthor, split by space'
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        # is_list=True: the textbox value is parsed as a space-separated list
        # (zh name + en name) by the swift-patched gradio component.
        with gr.Accordion(elem_id='self_cognition', open=False):
            with gr.Row():
                gr.Textbox(elem_id='model_name', scale=20, is_list=True)
                gr.Textbox(elem_id='model_author', scale=20, is_list=True)
ms-swift/swift/utils/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ from .env import (get_dist_setting, get_node_setting, get_pai_tensorboard_dir, is_deepspeed_enabled, is_dist,
4
+ is_dist_ta, is_local_master, is_master, is_mp, is_mp_ddp, is_pai_training_job, torchacc_trim_graph,
5
+ use_hf_hub, use_torchacc)
6
+ from .import_utils import (is_liger_available, is_lmdeploy_available, is_megatron_available, is_swanlab_available,
7
+ is_unsloth_available, is_vllm_ascend_available, is_vllm_available, is_wandb_available,
8
+ is_xtuner_available)
9
+ from .io_utils import JsonlWriter, append_to_jsonl, download_ms_file, get_file_mm_type, read_from_jsonl, write_to_jsonl
10
+ from .logger import get_logger
11
+ from .np_utils import get_seed, stat_array, transform_jsonl_to_df
12
+ from .tb_utils import TB_COLOR, TB_COLOR_SMOOTH, plot_images, read_tensorboard_file, tensorboard_smoothing
13
+ from .torch_utils import (Serializer, activate_parameters, find_all_linears, find_embedding, find_layers, find_norm,
14
+ freeze_parameters, gc_collect, get_current_device, get_device, get_device_count,
15
+ get_model_parameter_info, get_n_params_grads, init_process_group, safe_ddp_context,
16
+ set_default_ddp_config, set_device, show_layers, time_synchronize)
17
+ from .utils import (add_version_to_work_dir, check_json_format, copy_files_by_pattern, deep_getattr, find_free_port,
18
+ get_env_args, import_external_file, lower_bound, parse_args, patch_getattr, read_multi_line,
19
+ seed_everything, split_list, subprocess_run, test_time, upper_bound)
ms-swift/swift/utils/__pycache__/np_utils.cpython-310.pyc ADDED
Binary file (1.56 kB). View file
 
ms-swift/swift/utils/constants.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
# Regex patterns (used with file-name matching) that identify binary
# artifact files which should not be uploaded/diffed as text.
BIN_EXTENSIONS = [
    '.*.bin',
    '.*.ts',
    '.*.pt',
    '.*.data-00000-of-00001',
    '.*.onnx',
    '.*.meta',
    '.*.pb',
    '.*.index',
]

# Keys used to tag a saved adapter config with the tuner backend that wrote it.
PEFT_TYPE_KEY = 'peft_type'
SWIFT_TYPE_KEY = 'swift_type'
# Adapter name used when the caller does not supply one.
DEFAULT_ADAPTER = 'default'


class Invoke:
    """String constants identifying which entry point triggered a hub call."""

    KEY = 'invoked_by'
    THIRD_PARTY = 'third_party'
    PRETRAINED = 'from_pretrained'
    PIPELINE = 'pipeline'
    TRAINER = 'trainer'
    LOCAL_TRAINER = 'local_trainer'
    PREPROCESSOR = 'preprocessor'
    SWIFT = 'swift'
ms-swift/swift/utils/logger.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import importlib.util
3
+ import logging
4
+ import os
5
+ from contextlib import contextmanager
6
+ from types import MethodType
7
+ from typing import Optional
8
+
9
+ from modelscope.utils.logger import get_logger as get_ms_logger
10
+
11
+
12
+ # Avoid circular reference
13
+ def _is_local_master():
14
+ local_rank = int(os.getenv('LOCAL_RANK', -1))
15
+ return local_rank in {-1, 0}
16
+
17
+
18
+ init_loggers = {}
19
+
20
+ # old format
21
+ # formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
22
+ logger_format = logging.Formatter('[%(levelname)s:%(name)s] %(message)s')
23
+
24
# Registries of dedup keys (message text or an explicit hash_id) that have
# already been emitted through the *_once helpers below.
info_set = set()
warning_set = set()


def _log_once(seen, emit, msg, kwargs):
    """Emit *msg* via *emit* unless its dedup key was already recorded in *seen*."""
    key = kwargs.get('hash_id') or msg
    if key in seen:
        return
    seen.add(key)
    emit(msg)


def info_once(self, msg, *args, **kwargs):
    """logger.info that fires at most once per message (or per hash_id kwarg)."""
    _log_once(info_set, self.info, msg, kwargs)


def warning_once(self, msg, *args, **kwargs):
    """logger.warning that fires at most once per message (or per hash_id kwarg)."""
    _log_once(warning_set, self.warning, msg, kwargs)
42
+
43
+
44
def get_logger(log_file: Optional[str] = None, log_level: Optional[int] = None, file_mode: str = 'w'):
    """ Get logging logger

    Returns the process-wide package logger, configuring it on first call
    (later calls at most attach a file handler).

    Args:
        log_file: Log filename, if specified, file handler will be added to
            logger
        log_level: Logging level.
        file_mode: Specifies the mode to open the file, if filename is
            specified (if filemode is unspecified, it defaults to 'w').
    """
    if log_level is None:
        # Resolve e.g. 'DEBUG' -> logging.DEBUG; unknown names fall back to INFO.
        log_level = os.getenv('LOG_LEVEL', 'INFO').upper()
        log_level = getattr(logging, log_level, logging.INFO)
    logger_name = __name__.split('.')[0]  # top-level package name
    logger = logging.getLogger(logger_name)
    logger.propagate = False
    if logger_name in init_loggers:
        # Already configured once: only ensure the file handler exists.
        add_file_handler_if_needed(logger, log_file, file_mode, log_level)
        return logger

    # handle duplicate logs to the console
    # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
    # to the root logger. As logger.propagate is True by default, this root
    # level handler causes logging messages from rank>0 processes to
    # unexpectedly show up on the console, creating much unwanted clutter.
    # To fix this issue, we set the root logger's StreamHandler, if any, to log
    # at the ERROR level.
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    is_worker0 = _is_local_master()

    # Only the local master writes the log file (avoids clobbering on shared FS).
    if is_worker0 and log_file is not None:
        file_handler = logging.FileHandler(log_file, file_mode)
        handlers.append(file_handler)

    for handler in handlers:
        handler.setFormatter(logger_format)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    if is_worker0:
        logger.setLevel(log_level)
    else:
        # Non-master ranks only surface errors, keeping console output clean.
        logger.setLevel(logging.ERROR)

    init_loggers[logger_name] = True

    # Bind deduplicating helpers so call sites can log a message only once.
    logger.info_once = MethodType(info_once, logger)
    logger.warning_once = MethodType(warning_once, logger)
    return logger
99
+
100
+
101
# Module-level singletons: swift's own logger plus modelscope's, both
# reformatted to the compact '[LEVEL:name] message' style.
logger = get_logger()
ms_logger = get_ms_logger()

# NOTE(review): assumes both loggers already own at least one handler —
# an empty handler list would IndexError here; confirm get_logger() /
# get_ms_logger() always attach one.
logger.handlers[0].setFormatter(logger_format)
ms_logger.handlers[0].setFormatter(logger_format)
log_level = os.getenv('LOG_LEVEL', 'INFO').upper()
if _is_local_master():
    ms_logger.setLevel(log_level)
else:
    # Non-master ranks stay quiet to avoid duplicated console output.
    ms_logger.setLevel(logging.ERROR)
112
+
113
@contextmanager
def ms_logger_ignore_error():
    """Silence the modelscope logger (CRITICAL only) for the duration of the block.

    The previous level is restored on exit, including when the body raises.
    """
    ms_logger = get_ms_logger()
    saved_level = ms_logger.level
    ms_logger.setLevel(logging.CRITICAL)
    try:
        yield
    finally:
        ms_logger.setLevel(saved_level)
122
+
123
+
124
def add_file_handler_if_needed(logger, log_file, file_mode, log_level):
    """Attach a FileHandler for *log_file* unless the logger already has one.

    Only the local-master rank (LOCAL_RANK in {-1, 0}) writes files; when
    torch is not installed there is no DDP launcher, so every process
    qualifies as worker 0.
    """
    if any(isinstance(existing, logging.FileHandler) for existing in logger.handlers):
        return

    if importlib.util.find_spec('torch') is not None:
        is_worker0 = int(os.getenv('LOCAL_RANK', -1)) in {-1, 0}
    else:
        is_worker0 = True

    if log_file is None or not is_worker0:
        return
    file_handler = logging.FileHandler(log_file, file_mode)
    file_handler.setFormatter(logger_format)
    file_handler.setLevel(log_level)
    logger.addHandler(file_handler)
ms-swift/swift/utils/tb_utils.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ import os
4
+ from typing import Dict, List, Optional, Tuple
5
+
6
+ import matplotlib.pyplot as plt
7
+ from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
8
+
9
+ Item = Dict[str, float]
10
+ TB_COLOR, TB_COLOR_SMOOTH = '#FFE2D9', '#FF7043'
11
+
12
+
13
def read_tensorboard_file(fpath: str) -> Dict[str, List[Item]]:
    """Load every scalar series from a tensorboard event file.

    Returns a mapping ``{tag: [{'step': int, 'value': float}, ...]}``.

    Raises:
        FileNotFoundError: if *fpath* is not an existing file.
    """
    if not os.path.isfile(fpath):
        raise FileNotFoundError(f'fpath: {fpath}')
    accumulator = EventAccumulator(fpath)
    accumulator.Reload()
    return {
        tag: [{'step': event.step, 'value': event.value} for event in accumulator.Scalars(tag)]
        for tag in accumulator.Tags()['scalars']
    }
27
+
28
+
29
def tensorboard_smoothing(values: List[float], smooth: float = 0.9) -> List[float]:
    """Debias-corrected exponential moving average, as TensorBoard draws it.

    Each output is the decayed running sum divided by the decayed weight, so
    early points are not biased towards zero.
    """
    smoothed: List[float] = []
    running = 0.0
    weight = 0.0
    for value in values:
        running = running * smooth + value
        weight = weight * smooth + 1.0
        smoothed.append(running / weight)
    return smoothed
39
+
40
+
41
def plot_images(images_dir: str,
                tb_dir: str,
                smooth_key: Optional[List[str]] = None,
                smooth_val: float = 0.9,
                figsize: Tuple[int, int] = (8, 5),
                dpi: int = 100) -> None:
    """Using tensorboard's data content to plot images

    Renders every scalar tag found in the first event file under *tb_dir*
    as a PNG inside *images_dir*. Tags listed in *smooth_key* get a light
    raw curve plus a dark smoothed overlay.
    """
    smooth_key = smooth_key or []
    os.makedirs(images_dir, exist_ok=True)
    # NOTE(review): assumes tb_dir contains at least one regular file —
    # raises IndexError on an empty dir; confirm callers guarantee this.
    fname = [fname for fname in os.listdir(tb_dir) if os.path.isfile(os.path.join(tb_dir, fname))][0]
    tb_path = os.path.join(tb_dir, fname)
    data = read_tensorboard_file(tb_path)

    for k in data.keys():
        _data = data[k]
        steps = [d['step'] for d in _data]
        values = [d['value'] for d in _data]
        if len(values) == 0:
            continue
        _, ax = plt.subplots(1, 1, squeeze=True, figsize=figsize, dpi=dpi)
        ax.set_title(k)
        if len(values) == 1:
            # A single point cannot form a line; draw a marker instead.
            ax.scatter(steps, values, color=TB_COLOR_SMOOTH)
        elif k in smooth_key:
            ax.plot(steps, values, color=TB_COLOR)  # raw curve (light)
            values_s = tensorboard_smoothing(values, smooth_val)
            ax.plot(steps, values_s, color=TB_COLOR_SMOOTH)  # smoothed curve
        else:
            ax.plot(steps, values, color=TB_COLOR_SMOOTH)
        # Sanitize the tag so it forms a single valid filename component.
        fpath = os.path.join(images_dir, k.replace('/', '_').replace('.', '_'))
        plt.savefig(fpath, dpi=dpi, bbox_inches='tight')
        plt.close()
ms-swift/swift/utils/torch_utils.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import gc
3
+ import hashlib
4
+ import os
5
+ import pickle
6
+ import re
7
+ import time
8
+ import uuid
9
+ from bisect import bisect_right
10
+ from contextlib import contextmanager, nullcontext
11
+ from typing import Callable, Dict, List, Optional, Tuple, Union
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.distributed as dist
16
+ import torch.nn as nn
17
+ from datasets.utils.filelock import FileLock
18
+ from modelscope.hub.utils.utils import get_cache_dir
19
+ from transformers.integrations import is_deepspeed_zero3_enabled
20
+ from transformers.utils import is_torch_cuda_available, is_torch_mps_available, is_torch_npu_available
21
+
22
+ from .env import get_dist_setting, is_dist, is_dist_ta, is_local_master, is_master
23
+ from .logger import get_logger
24
+ from .utils import deep_getattr
25
+
26
+ logger = get_logger()
27
+
28
+
29
+ def _find_local_mac() -> str:
30
+ mac = uuid.getnode()
31
+ mac_address = ':'.join(('%012x' % mac)[i:i + 2] for i in range(0, 12, 2))
32
+ return mac_address
33
+
34
+
35
def get_n_params_grads(model) -> Tuple[List[int], List[int]]:
    """Per-parameter element counts and trainable element counts.

    Under DeepSpeed ZeRO-3 each parameter is gathered first so ``numel()``
    reflects the full (un-partitioned) size.

    Returns:
        (n_params, n_grads): parallel lists where ``n_grads[i]`` equals
        ``n_params[i]`` for trainable parameters and 0 otherwise.
    """
    n_params: List[int] = []
    n_grads: List[int] = []
    for param in model.parameters():
        if is_deepspeed_zero3_enabled():
            import deepspeed
            ctx = deepspeed.zero.GatheredParameters(param)
        else:
            ctx = nullcontext()
        with ctx:
            numel = param.numel()
            n_params.append(numel)
            n_grads.append(numel if param.requires_grad else 0)
    return n_params, n_grads
47
+
48
+
49
def get_model_parameter_info(model: nn.Module, name: Optional[str] = None) -> str:
    """One-line summary: total / trainable parameters (in millions) and buffer size.

    Args:
        model: Module to summarize.
        name: Display name; defaults to the model's class name.
    """
    params_per_tensor, grads_per_tensor = get_n_params_grads(model)
    total = sum(params_per_tensor) / 1e6
    trainable = sum(grads_per_tensor) / 1e6
    buffers = sum(buf.numel() for buf in model.buffers()) / 1e6

    if name is None:
        name = model.__class__.__name__

    return (f'{name}: '
            f'{total:.4f}M Params ({trainable:.4f}M Trainable '
            f'[{100 * trainable / total:.4f}%]), '
            f'{buffers:.4f}M Buffers.')
66
+
67
+
68
def find_sub_module(module: torch.nn.Module, module_name: str) -> List[torch.nn.Module]:
    """Collect every named submodule whose qualified name ends with *module_name*."""
    matches: List[torch.nn.Module] = []
    for qual_name, sub in module.named_modules():
        # The root module itself is registered under the empty name; skip it.
        if qual_name and qual_name.endswith(module_name):
            matches.append(sub)
    return matches
76
+
77
+
78
def show_layers(model: nn.Module, max_lines: Optional[int] = 20) -> None:
    """Log requires_grad/dtype/device for each named parameter.

    Output is truncated with '...' after *max_lines* entries; pass
    ``max_lines=None`` to log everything.
    """
    for idx, (param_name, param) in enumerate(model.named_parameters()):
        if max_lines is not None and idx >= max_lines:
            logger.info('...')
            break
        logger.info(f'[{param_name}]: requires_grad={param.requires_grad}, dtype={param.dtype}, device={param.device}')
85
+
86
+
87
def freeze_parameters(model: nn.Module,
                      freeze_parameters_ratio: float,
                      freeze_parameters: List[str],
                      freeze_parameters_regex: Optional[str] = None) -> None:
    """Freeze model parameters by leading ratio, by name prefix, and/or by regex.

    Args:
        model: Model whose parameters are frozen in place.
        freeze_parameters_ratio: Fraction (0-1) of the total parameter
            element count, taken from the front of ``model.parameters()``,
            to freeze.
        freeze_parameters: Name prefixes; any parameter whose name starts
            with one of them is frozen. (NOTE: shadows the function name
            inside this body.)
        freeze_parameters_regex: Optional regex; parameter names matching
            (via ``search``) are frozen. An invalid pattern logs a warning
            and skips the regex stage.
    """
    if freeze_parameters_ratio > 0:
        # Freeze the shortest prefix of parameters whose cumulative element
        # count stays within ratio * total.
        n_parameters = get_n_params_grads(model)[0]
        n_parameters = np.array(n_parameters, dtype=np.int64)
        n_freeze_parameters = int(np.sum(n_parameters) * freeze_parameters_ratio)
        n_parameters_cs = np.cumsum(n_parameters)
        idx = bisect_right(n_parameters_cs, n_freeze_parameters)
        for _, p in zip(range(idx), model.parameters()):
            p.requires_grad = False

    if len(freeze_parameters) > 0:
        for n, p in model.named_parameters():
            for freeze_p in freeze_parameters:
                if n.startswith(freeze_p):
                    p.requires_grad = False

    if freeze_parameters_regex is not None:
        try:
            pattern = re.compile(freeze_parameters_regex)
        except re.error as e:
            # Best-effort: a bad user-supplied pattern must not crash training.
            logger.warning(f"Invalid freeze_parameters_regex '{freeze_parameters_regex}': {e}")
            return

        for n, p in model.named_parameters():
            if pattern.search(n):
                p.requires_grad = False
116
+
117
+
118
def activate_parameters(model: nn.Module,
                        additional_trainable_parameters: List[str],
                        trainable_parameters_regex: Optional[str] = None) -> None:
    """Re-enable gradients for parameters selected by name prefix and/or regex.

    Args:
        model: Model modified in place.
        additional_trainable_parameters: Name prefixes; parameters whose
            names start with one of them become trainable.
        trainable_parameters_regex: Optional regex; names matching (via
            ``search``) become trainable. An invalid pattern logs a warning
            and skips this stage.

    Logs a warning when a selector was supplied but matched nothing.
    """
    has_activate = False
    if len(additional_trainable_parameters) > 0:
        for n, p in model.named_parameters():
            for additional_tp in additional_trainable_parameters:
                if n.startswith(additional_tp):
                    p.requires_grad = True
                    has_activate = True
        if not has_activate:
            logger.warning('len(additional_trainable_parameters) > 0 but no parameters are activated. '
                           f'additional_trainable_parameters: {additional_trainable_parameters}')

    # Reset the flag so the regex stage reports its own misses independently.
    has_activate = False
    if trainable_parameters_regex is not None:
        try:
            pattern = re.compile(trainable_parameters_regex)
        except re.error as e:
            logger.warning(f"Invalid trainable_parameters_regex '{trainable_parameters_regex}': {e}")
            return

        for n, p in model.named_parameters():
            if pattern.search(n):
                p.requires_grad = True
                has_activate = True

        if not has_activate:
            logger.warning('trainable_parameters_regex is provided but no parameters are activated. '
                           f'trainable_parameters_regex: {trainable_parameters_regex}')
148
+
149
+
150
def time_synchronize() -> float:
    """Block until all queued CUDA kernels finish, then return a perf_counter timestamp.

    Use pairs of these calls to measure GPU work with wall-clock deltas.
    """
    torch.cuda.synchronize()
    return time.perf_counter()  # second
153
+
154
+
155
def _get_max_memory(device_ids: List[int]) -> Dict[Union[int, str], int]:
    """add feat in accelerate to support MP + DDP"""
    # Build accelerate's max_memory mapping: free bytes for the GPUs this
    # rank owns, 0 for all other visible GPUs, plus available CPU RAM.
    import psutil
    # Make sure CUDA is initialized on each GPU to have the right memory info.
    for i in device_ids:
        _ = torch.tensor([0], device=i)

    device_ids_set = set(device_ids)
    max_memory = {}
    for i in range(get_device_count()):
        max_memory[i] = 0
        if i in device_ids_set:
            max_memory[i] = torch.cuda.mem_get_info(i)[0]  # free bytes on device i
    max_memory['cpu'] = psutil.virtual_memory().available
    return max_memory
170
+
171
+
172
def _sync_max_memory(max_memory: Dict[Union[int, str], int]) -> Dict[Union[int, str], int]:
    """Make sure that the model structure of MP(device_map) is the same, when using DDP."""
    # All-gather every rank's per-GPU free-memory vector and keep the
    # element-wise minimum, so each rank derives an identical device_map.
    max_memory_list = [v for k, v in max_memory.items() if (v > 0 and k != 'cpu')]
    _, local_rank, world_size, _ = get_dist_setting()
    src_tensor = torch.tensor(max_memory_list).to(local_rank)
    tgt_tensor_list = [torch.zeros_like(src_tensor) for _ in range(world_size)]
    dist.all_gather(tgt_tensor_list, src_tensor)
    tgt_tensor = torch.stack(tgt_tensor_list, dim=0)
    # Iterator over the per-GPU minima, consumed in the original key order.
    new_max_memory_iter = iter(tgt_tensor.min(dim=0)[0].tolist())
    new_max_memory = {}
    for k, v in max_memory.items():
        new_max_memory[k] = v
        if v > 0 and k != 'cpu':
            new_max_memory[k] = next(new_max_memory_iter)
    return new_max_memory
187
+
188
+
189
def find_layers(
    model: nn.Module,
    cond: Callable[[str, nn.Module], bool],
    sub_module: Optional[str] = None,
    min_name_len: Optional[int] = None,
) -> List[str]:
    """Collect minimal unambiguous module-name suffixes for modules matching *cond*.

    Numeric path components are normalized to '{}' so e.g. all
    'layers.<i>.mlp' entries collapse to one pattern; each suffix is
    extended leftwards until it no longer collides with any non-matching
    ("inner") node and, optionally, reaches *min_name_len* components.

    Args:
        model: Root model to scan.
        cond: Predicate on (normalized name, module) selecting target layers.
        sub_module: Optional dotted path restricting the scan to a submodule;
            returned names are still rooted at *model*.
        min_name_len: Minimum number of dotted components per result.
    """
    # The content of target_module_names cannot exist in inner_nodes.
    sub_module_str = sub_module
    if sub_module is None:
        sub_module = model
    else:
        sub_module = deep_getattr(model, sub_module)
    inner_nodes = set()
    for name, module in model.named_modules():
        name = re.sub(r'\d+\.', '{}.', name)
        if not cond(name, module):
            inner_nodes.add(name)
    target_module_names = set()
    for name, module in sub_module.named_modules():
        if sub_module_str:
            # Re-root the submodule-relative name at the full model.
            name = f'{sub_module_str}.{name}' if name else sub_module_str
        if cond(name, module):
            module_name_list = name.split('.')
            module_name = module_name_list.pop()
            i = 1
            for inner_node in inner_nodes:
                # Parses as (list and endswith) or (min_name_len and i < len).
                # NOTE(review): the second disjunct never re-checks
                # module_name_list — with min_name_len set and the list
                # exhausted, pop() would raise IndexError; confirm callers
                # never pass a min_name_len longer than the shortest name.
                while module_name_list and inner_node.endswith(re.sub(
                        r'\d+\.', '{}.', module_name)) or min_name_len and i < min_name_len:
                    module_name = f'{module_name_list.pop()}.{module_name}'
                    i += 1
            target_module_names.add(module_name)
    return list(target_module_names)
221
+
222
+
223
def find_norm(model: nn.Module) -> List[str]:
    """Module-name suffixes of all LayerNorm / *RMSNorm layers in *model*."""

    def _is_norm(name: str, module: nn.Module) -> bool:
        if isinstance(module, torch.nn.LayerNorm):
            return True
        # RMSNorm variants are model-specific classes; match by class name.
        return 'rmsnorm' in module.__class__.__name__.lower()

    return find_layers(model, _is_norm)
228
+
229
+
230
def find_embedding(model: nn.Module) -> List[str]:
    """Module-name suffixes of all nn.Embedding layers in *model*."""

    def _is_embedding(name: str, module: nn.Module) -> bool:
        return isinstance(module, torch.nn.Embedding)

    return find_layers(model, _is_embedding)
232
+
233
+
234
def find_all_linears(model, model_arch=None, extra_layers=None, sub_module=None):
    """Module-name suffixes of every tunable Linear layer (LoRA target_modules).

    Excludes the LM head, classification/reward heads and LoRA's own
    projection layers; module types listed in *extra_layers* are
    force-included regardless of class name.
    """
    if model_arch is None:
        from swift.llm import get_model_arch
        model_arch = get_model_arch(model.model_meta.model_arch)
    # lm_head
    if model_arch and model_arch.lm_head:
        output = model_arch.lm_head
        idx = output.rfind('.')
        lm_head_name = output[idx + 1:]  # last dotted component, e.g. 'lm_head'
    else:
        lm_head_name = 'lm_head'
    # 'score', 'classifier': classification model
    # 'v_head': reward model
    ignore_layers = [lm_head_name, 'score', 'v_head', 'classifier'] + ['lora_A', 'lora_B', 'base_layer']
    ignore_linear_cls = [
        'glulinear'  # phi4-mm
    ]

    def _cond(name, module):
        # Linear-like modules matched by class name, unless the module name
        # contains an ignored component; extra_layers types always match.
        module_name = module.__class__.__name__.lower()
        if (extra_layers and isinstance(module, tuple(extra_layers)) or
            ('linear' in module_name and all(linear_cls not in module_name
                                             for linear_cls in ignore_linear_cls))) and all(layer not in name
                                                                                            for layer in ignore_layers):
            return True
        return False

    return find_layers(model, _cond, sub_module=sub_module)
262
+
263
+
264
@contextmanager
def safe_ddp_context(hash_id: Optional[str], use_barrier: bool = False):
    """Serialize a critical section across ranks or across processes.

    With ``use_barrier`` and an initialized process group: the (local)
    master runs the body first while the other ranks wait on barriers, then
    everyone synchronizes again afterwards. Otherwise, when *hash_id* is
    given, a file lock keyed by its sha256 digest guards against concurrent
    independent processes. With neither, the body runs unguarded.
    """
    if use_barrier and dist.is_initialized():
        if is_dist() or is_dist_ta():
            if not is_master():
                dist.barrier()
            if not is_local_master():
                # Compatible with multi-machine scenarios,
                # where each machine uses different storage hardware.
                dist.barrier()
        yield
        if is_dist() or is_dist_ta():
            if is_master():
                dist.barrier()
            if is_local_master():
                dist.barrier()
    elif hash_id is not None:
        # Cross-process (not cross-rank) exclusion via a lock file in the cache dir.
        lock_dir = os.path.join(get_cache_dir(), 'lockers')
        os.makedirs(lock_dir, exist_ok=True)
        file_path = hashlib.sha256(hash_id.encode('utf-8')).hexdigest() + '.lock'
        file_path = os.path.join(lock_dir, file_path)
        with FileLock(file_path):
            yield
    else:
        yield
289
+
290
+
291
def get_device(local_rank: Optional[Union[str, int]] = None) -> str:
    """Torch device string ('npu:N' / 'mps:N' / 'cuda:N' / 'cpu') for a rank.

    When *local_rank* is omitted, the LOCAL_RANK from the dist setting is
    used (clamped to 0 for non-distributed runs).
    """
    if local_rank is None:
        local_rank = max(0, get_dist_setting()[1])
    rank_str = str(local_rank)
    # Probe accelerators in priority order: npu > mps > cuda.
    if is_torch_npu_available():
        return f'npu:{rank_str}'
    if is_torch_mps_available():
        return f'mps:{rank_str}'
    if is_torch_cuda_available():
        return f'cuda:{rank_str}'
    return 'cpu'
305
+
306
+
307
def get_current_device():
    """Current accelerator device: an index for npu/cuda, 'mps' or 'cpu' otherwise."""
    if is_torch_npu_available():
        return torch.npu.current_device()
    if is_torch_cuda_available():
        return torch.cuda.current_device()
    if is_torch_mps_available():
        return 'mps'
    return 'cpu'
317
+
318
+
319
def set_device(local_rank: Optional[Union[str, int]] = None):
    """Bind this process to its accelerator device; no-op on cpu/mps backends."""
    if local_rank is None:
        local_rank = max(0, get_dist_setting()[1])
    if is_torch_npu_available():
        torch.npu.set_device(local_rank)
    elif is_torch_cuda_available():
        torch.cuda.set_device(local_rank)
326
+
327
+
328
def get_device_count() -> int:
    """Number of visible accelerator devices; 0 when neither NPU nor CUDA exists."""
    if is_torch_npu_available():
        return torch.npu.device_count()
    if is_torch_cuda_available():
        return torch.cuda.device_count()
    return 0
335
+
336
+
337
def gc_collect() -> None:
    """Run the Python garbage collector, then release the accelerator's cached blocks."""
    gc.collect()
    # Probe backends in the original priority order: npu > mps > cuda.
    for available, empty_cache in (
        (is_torch_npu_available, lambda: torch.npu.empty_cache()),
        (is_torch_mps_available, lambda: torch.mps.empty_cache()),
        (is_torch_cuda_available, lambda: torch.cuda.empty_cache()),
    ):
        if available():
            empty_cache()
            break
345
+
346
+
347
class Serializer:
    """Pickle arbitrary objects into length-prefixed uint8 tensors.

    Useful for shipping Python objects through collective ops that only
    move tensors; the 8-byte little-int64 header records the payload
    length so trailing padding is ignored on decode.
    """

    @staticmethod
    def to_tensor(obj):
        """Serialize *obj* to a 1-D uint8 torch tensor (8-byte length header + pickle)."""
        payload = pickle.dumps(obj)
        header = np.array([len(payload)], dtype=np.int64).tobytes()
        buffer = np.frombuffer(header + payload, dtype=np.uint8).copy()
        return torch.from_numpy(buffer)

    @staticmethod
    def from_tensor(obj):
        """Inverse of :meth:`to_tensor`; accepts a torch tensor or a numpy array."""
        if isinstance(obj, torch.Tensor):
            obj = obj.cpu().numpy()
        raw = obj.tobytes()
        length = int(np.frombuffer(raw[:8], dtype=np.int64)[0])
        return pickle.loads(raw[8:8 + length])
365
+
366
+
367
def set_default_ddp_config():
    """Populate single-process DDP env vars when no launcher configured them.

    When RANK is absent (plain ``python`` invocation rather than torchrun),
    fill in a one-process, one-node world so downstream dist code works
    unchanged; an existing MASTER_PORT is respected.
    """
    rank = int(os.getenv('RANK', -1))
    if rank != -1:
        return  # a launcher already configured the process group
    defaults = {
        'NPROC_PER_NODE': '1',
        'RANK': '0',
        'LOCAL_RANK': '0',
        'WORLD_SIZE': '1',
        'LOCAL_WORLD_SIZE': '1',
        'MASTER_ADDR': '127.0.0.1',
    }
    os.environ.update(defaults)
    os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
378
+
379
+
380
def init_process_group(ddp_backend: Optional[str] = None):
    """Idempotently initialize torch.distributed, choosing a backend by hardware.

    hccl on Ascend NPU, nccl on CUDA, gloo otherwise. The device is bound
    first so NCCL communicators land on this rank's GPU.
    """
    if dist.is_initialized():
        return
    set_device()
    if ddp_backend is None:
        if is_torch_npu_available():
            ddp_backend = 'hccl'
        elif torch.cuda.is_available():
            ddp_backend = 'nccl'
        else:
            ddp_backend = 'gloo'
    dist.init_process_group(backend=ddp_backend)
ms-swift/tests/deploy/test_dataset.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def _test_client(port=8000):
    """Wait for the deploy server on *port*, then run a batch inference smoke test.

    Polls ``InferClient.models`` once per second until the server answers,
    sends the alpaca sample set and prints the number of responses.
    """
    import time
    # Fix: dropped the unused `aiohttp` and `run_deploy` imports — neither
    # was referenced, and `aiohttp` was an unnecessary hard dependency here.
    from swift.llm import InferClient, InferRequest, RequestConfig, load_dataset
    dataset = load_dataset(['AI-ModelScope/alpaca-gpt4-data-zh#1000'], num_proc=4)
    infer_client = InferClient(port=port)
    # Readiness probe: listing models succeeds only once the server is up.
    while True:
        try:
            infer_client.models
            break
        except Exception:
            time.sleep(1)
    infer_requests = [InferRequest(**data) for data in dataset[0]]
    request_config = RequestConfig(seed=42, max_tokens=256, temperature=0.8)

    resp = infer_client.infer(infer_requests, request_config=request_config, use_tqdm=False)
    print(len(resp))
21
+
22
+
23
def _test(infer_backend):
    """Deploy Qwen2-7B-Instruct on GPU 0 with *infer_backend* and smoke-test the client."""
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    from swift.llm import DeployArguments, run_deploy
    deploy_args = DeployArguments(model='Qwen/Qwen2-7B-Instruct', infer_backend=infer_backend, verbose=False)
    # run_deploy yields the bound port and tears the server down on exit.
    with run_deploy(deploy_args) as port:
        _test_client(port)
32
+
33
+
34
def test_vllm():
    # Smoke-test deployment with the vLLM backend.
    _test('vllm')


def test_lmdeploy():
    # Smoke-test deployment with the LMDeploy backend.
    _test('lmdeploy')


def test_pt():
    # Smoke-test deployment with the native PyTorch backend.
    _test('pt')
44
+
45
+
46
def test_vllm_origin():
    """Launch a vanilla vLLM OpenAI API server on the downloaded checkpoint
    and run the client smoke test against it."""
    import subprocess
    import sys
    from modelscope import snapshot_download
    model_dir = snapshot_download('Qwen/Qwen2-7B-Instruct')
    args = [sys.executable, '-m', 'vllm.entrypoints.openai.api_server', '--model', model_dir]
    process = subprocess.Popen(args)
    try:
        _test_client()
    finally:
        # Fix: terminate ran only on success before, leaking a GPU-holding
        # server process whenever _test_client raised; wait() reaps it.
        process.terminate()
        process.wait()
55
+
56
+
57
if __name__ == '__main__':
    # Only one backend is exercised per run; switch by (un)commenting.
    # test_vllm_origin()
    # test_vllm()
    test_lmdeploy()
    # test_pt()