Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- ms-swift/silence_overlaps/only_overlap/.ipynb_checkpoints/overlap5s_isoverlap_train-checkpoint.json +0 -0
- ms-swift/swift/tuners/__pycache__/lora.cpython-310.pyc +0 -0
- ms-swift/swift/tuners/__pycache__/mapping.cpython-310.pyc +0 -0
- ms-swift/swift/tuners/__pycache__/part.cpython-310.pyc +0 -0
- ms-swift/swift/tuners/__pycache__/restuning.cpython-310.pyc +0 -0
- ms-swift/swift/tuners/__pycache__/restuning_components.cpython-310.pyc +0 -0
- ms-swift/swift/tuners/__pycache__/side.cpython-310.pyc +0 -0
- ms-swift/swift/tuners/__pycache__/utils.cpython-310.pyc +0 -0
- ms-swift/swift/tuners/adapter.py +189 -0
- ms-swift/swift/tuners/longlora/__pycache__/longlora.cpython-310.pyc +0 -0
- ms-swift/swift/tuners/peft.py +392 -0
- ms-swift/swift/tuners/scetuning/__pycache__/__init__.cpython-310.pyc +0 -0
- ms-swift/swift/tuners/scetuning/__pycache__/scetuning.cpython-310.pyc +0 -0
- ms-swift/swift/tuners/scetuning/__pycache__/scetuning_components.cpython-310.pyc +0 -0
- ms-swift/swift/tuners/scetuning/scetuning_components.py +127 -0
- ms-swift/swift/tuners/side.py +245 -0
- ms-swift/swift/ui/app.py +92 -0
- ms-swift/swift/ui/base.py +388 -0
- ms-swift/swift/ui/llm_eval/__init__.py +1 -0
- ms-swift/swift/ui/llm_eval/eval.py +130 -0
- ms-swift/swift/ui/llm_eval/model.py +78 -0
- ms-swift/swift/ui/llm_export/llm_export.py +191 -0
- ms-swift/swift/ui/llm_export/model.py +83 -0
- ms-swift/swift/ui/llm_export/runtime.py +75 -0
- ms-swift/swift/ui/llm_infer/__init__.py +1 -0
- ms-swift/swift/ui/llm_infer/generate.py +65 -0
- ms-swift/swift/ui/llm_infer/llm_infer.py +396 -0
- ms-swift/swift/ui/llm_infer/model.py +126 -0
- ms-swift/swift/ui/llm_infer/runtime.py +285 -0
- ms-swift/swift/ui/llm_train/__init__.py +1 -0
- ms-swift/swift/ui/llm_train/advanced.py +164 -0
- ms-swift/swift/ui/llm_train/dataset.py +91 -0
- ms-swift/swift/ui/llm_train/hyper.py +129 -0
- ms-swift/swift/ui/llm_train/llamapro.py +40 -0
- ms-swift/swift/ui/llm_train/llm_train.py +420 -0
- ms-swift/swift/ui/llm_train/lora.py +102 -0
- ms-swift/swift/ui/llm_train/model.py +127 -0
- ms-swift/swift/ui/llm_train/quantization.py +68 -0
- ms-swift/swift/ui/llm_train/report_to.py +75 -0
- ms-swift/swift/ui/llm_train/rlhf.py +102 -0
- ms-swift/swift/ui/llm_train/runtime.py +571 -0
- ms-swift/swift/ui/llm_train/save.py +84 -0
- ms-swift/swift/ui/llm_train/self_cog.py +57 -0
- ms-swift/swift/utils/__init__.py +19 -0
- ms-swift/swift/utils/__pycache__/np_utils.cpython-310.pyc +0 -0
- ms-swift/swift/utils/constants.py +27 -0
- ms-swift/swift/utils/logger.py +138 -0
- ms-swift/swift/utils/tb_utils.py +72 -0
- ms-swift/swift/utils/torch_utils.py +391 -0
- ms-swift/tests/deploy/test_dataset.py +61 -0
ms-swift/silence_overlaps/only_overlap/.ipynb_checkpoints/overlap5s_isoverlap_train-checkpoint.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/swift/tuners/__pycache__/lora.cpython-310.pyc
ADDED
|
Binary file (7.04 kB). View file
|
|
|
ms-swift/swift/tuners/__pycache__/mapping.cpython-310.pyc
ADDED
|
Binary file (1.44 kB). View file
|
|
|
ms-swift/swift/tuners/__pycache__/part.cpython-310.pyc
ADDED
|
Binary file (4.71 kB). View file
|
|
|
ms-swift/swift/tuners/__pycache__/restuning.cpython-310.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|
ms-swift/swift/tuners/__pycache__/restuning_components.cpython-310.pyc
ADDED
|
Binary file (9.9 kB). View file
|
|
|
ms-swift/swift/tuners/__pycache__/side.cpython-310.pyc
ADDED
|
Binary file (8.84 kB). View file
|
|
|
ms-swift/swift/tuners/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (16.4 kB). View file
|
|
|
ms-swift/swift/tuners/adapter.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import inspect
|
| 3 |
+
import re
|
| 4 |
+
import types
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import List, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
from transformers.activations import ACT2CLS
|
| 11 |
+
|
| 12 |
+
from swift.utils.torch_utils import find_sub_module, get_logger
|
| 13 |
+
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
|
| 14 |
+
|
| 15 |
+
logger = get_logger()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class AdapterConfig(SwiftConfig):
|
| 20 |
+
"""
|
| 21 |
+
The configuration class for the adapter module.
|
| 22 |
+
|
| 23 |
+
Adapters project input tokens by an MLP layer.
|
| 24 |
+
'Parameter-Efficient Transfer Learning for NLP' by Houlsby et al.(2019)
|
| 25 |
+
See http://arxiv.org/abs/1902.00751
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
dim(`int`): The dimension of the hidden states
|
| 29 |
+
target_modules(`Union[str, List[str]]`): The feedforward module to be replaced.
|
| 30 |
+
in regex format if this argument is str, else will match with `end with` if List[str].
|
| 31 |
+
hidden_pos(`Union[str, int]`): The position of the hidden state to be passed into the adapter,
|
| 32 |
+
can be int (args) or str (kwargs)
|
| 33 |
+
method_name(`str`): The method to be replaced, default is `forward`
|
| 34 |
+
adapter_length: The length of the adapter length (intermediate length)
|
| 35 |
+
act_layer: The activation layer of the adapter
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
dim: int = field(default=None, metadata={'help': 'The dimension of the hidden states'})
|
| 39 |
+
|
| 40 |
+
target_modules: Union[str, List[str]] = field(
|
| 41 |
+
default=None,
|
| 42 |
+
metadata={
|
| 43 |
+
'help':
|
| 44 |
+
'The feedforward module to be replaced. in regex format if this argument is str, '
|
| 45 |
+
'else will match with `end with` if List[str].'
|
| 46 |
+
})
|
| 47 |
+
|
| 48 |
+
hidden_pos: Union[str, int] = field(
|
| 49 |
+
default=None,
|
| 50 |
+
metadata={
|
| 51 |
+
'help': 'The position of the hidden state to be passed into the adapter, can be int (args) or str (kwargs)'
|
| 52 |
+
})
|
| 53 |
+
|
| 54 |
+
method_name: str = field(default='forward', metadata={'help': 'The method to be replaced, default is `forward`'})
|
| 55 |
+
|
| 56 |
+
adapter_length: int = field(
|
| 57 |
+
default=128, metadata={'help': 'The length of the adapter length (intermediate length)'})
|
| 58 |
+
|
| 59 |
+
act_layer: str = field(default='gelu', metadata={'help': 'The activation layer of the adapter'})
|
| 60 |
+
|
| 61 |
+
def __post_init__(self):
|
| 62 |
+
from .mapping import SwiftTuners
|
| 63 |
+
self.swift_type = SwiftTuners.ADAPTER
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class Adapter(SwiftAdapter):
|
| 67 |
+
|
| 68 |
+
@staticmethod
|
| 69 |
+
def prepare_model(model: nn.Module, config: AdapterConfig, adapter_name: str) -> SwiftOutput:
|
| 70 |
+
"""Prepare a model with `AdapterConfig`"""
|
| 71 |
+
module_keys = [key for key, _ in model.named_modules()]
|
| 72 |
+
|
| 73 |
+
for module_key in module_keys:
|
| 74 |
+
if isinstance(config.target_modules, str):
|
| 75 |
+
target_module_found = re.fullmatch(config.target_modules, module_key)
|
| 76 |
+
else:
|
| 77 |
+
target_module_found = any(module_key.endswith(target_key) for target_key in config.target_modules)
|
| 78 |
+
|
| 79 |
+
if target_module_found: # noqa
|
| 80 |
+
module = model.get_submodule(module_key)
|
| 81 |
+
|
| 82 |
+
def _forward(self, *args, **kwargs):
|
| 83 |
+
args = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
|
| 84 |
+
if isinstance(args, (tuple, list, dict)):
|
| 85 |
+
if isinstance(config.hidden_pos, int):
|
| 86 |
+
_type = type(args)
|
| 87 |
+
args = list(args)
|
| 88 |
+
args[config.hidden_pos] = getattr(self, f'adapter_{adapter_name}')(args[config.hidden_pos])
|
| 89 |
+
args = _type(args)
|
| 90 |
+
else:
|
| 91 |
+
args[config.hidden_pos] = getattr(self, f'adapter_{adapter_name}')(args[config.hidden_pos])
|
| 92 |
+
elif isinstance(args, torch.Tensor):
|
| 93 |
+
args = getattr(self, f'adapter_{adapter_name}')(args)
|
| 94 |
+
return args
|
| 95 |
+
|
| 96 |
+
def _feed_forward_chunk(self, attention_output):
|
| 97 |
+
return _forward(self, attention_output)
|
| 98 |
+
|
| 99 |
+
# TODO The `config.method_name` method should not be replaced twice.
|
| 100 |
+
|
| 101 |
+
setattr(module, f'forward_origin_{adapter_name}', getattr(module, config.method_name))
|
| 102 |
+
num_args_in_forward_chunk_fn = len(
|
| 103 |
+
inspect.signature(getattr(module, f'forward_origin_{adapter_name}')).parameters)
|
| 104 |
+
if config.method_name == 'feed_forward_chunk' and num_args_in_forward_chunk_fn == 1:
|
| 105 |
+
setattr(module, config.method_name, types.MethodType(_feed_forward_chunk, module))
|
| 106 |
+
else:
|
| 107 |
+
setattr(module, config.method_name, types.MethodType(_forward, module))
|
| 108 |
+
adapter_module = AdapterModule(config.dim, adapter_name, module_key, config.adapter_length,
|
| 109 |
+
ACT2CLS[config.act_layer])
|
| 110 |
+
setattr(module, f'adapter_{adapter_name}', adapter_module)
|
| 111 |
+
logger.info(f'Adapter modules(module_key): {module_key}.adapter_{adapter_name}')
|
| 112 |
+
|
| 113 |
+
def state_dict_callback(state_dict, adapter_name: str, **kwargs):
|
| 114 |
+
return {key: value for key, value in state_dict.items() if f'adapter_{adapter_name}' in key}
|
| 115 |
+
|
| 116 |
+
def mark_trainable_callback(model):
|
| 117 |
+
return
|
| 118 |
+
|
| 119 |
+
return SwiftOutput(
|
| 120 |
+
config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)
|
| 121 |
+
|
| 122 |
+
@staticmethod
|
| 123 |
+
def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
|
| 124 |
+
modules = find_sub_module(module, f'adapter_{adapter_name}')
|
| 125 |
+
for _module in modules:
|
| 126 |
+
_module: ActivationMixin
|
| 127 |
+
_module: nn.Module
|
| 128 |
+
_module.set_activation(adapter_name, activate)
|
| 129 |
+
SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class AdapterModule(nn.Module, ActivationMixin):
|
| 133 |
+
"""The implementation of adapter tuning method.
|
| 134 |
+
|
| 135 |
+
Adapters project input tokens by an MLP layer.
|
| 136 |
+
'Parameter-Efficient Transfer Learning for NLP' by Houlsby et al.(2019)
|
| 137 |
+
See http://arxiv.org/abs/1902.00751
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
dim: An integer indicating the embedding dimension.
|
| 141 |
+
adapter_length: An integer indicating the length of adapter tuning.
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
def __init__(
|
| 145 |
+
self,
|
| 146 |
+
dim,
|
| 147 |
+
adapter_name,
|
| 148 |
+
module_key,
|
| 149 |
+
adapter_length=None,
|
| 150 |
+
act_layer=nn.GELU,
|
| 151 |
+
):
|
| 152 |
+
super(AdapterModule, self).__init__()
|
| 153 |
+
super(nn.Module, self).__init__(module_key)
|
| 154 |
+
self.dim = dim
|
| 155 |
+
self.adapter_name = adapter_name
|
| 156 |
+
self.adapter_length = adapter_length
|
| 157 |
+
self.linear1 = nn.Linear(dim, adapter_length)
|
| 158 |
+
self.act = act_layer()
|
| 159 |
+
self.linear2 = nn.Linear(adapter_length, dim)
|
| 160 |
+
self.init_weights()
|
| 161 |
+
self._prepared = False
|
| 162 |
+
self.mark_all_sub_modules_as_plugin()
|
| 163 |
+
|
| 164 |
+
def init_weights(self):
|
| 165 |
+
|
| 166 |
+
def _init_weights(m):
|
| 167 |
+
if isinstance(m, nn.Linear):
|
| 168 |
+
nn.init.xavier_uniform_(m.weight)
|
| 169 |
+
nn.init.normal_(m.bias, std=1e-6)
|
| 170 |
+
|
| 171 |
+
self.apply(_init_weights)
|
| 172 |
+
|
| 173 |
+
def forward(self, x, identity=None):
|
| 174 |
+
if not self.is_activated(self.adapter_name):
|
| 175 |
+
return x
|
| 176 |
+
if not self._prepared:
|
| 177 |
+
self.linear1.to(x.device)
|
| 178 |
+
self.act.to(x.device)
|
| 179 |
+
self.linear2.to(x.device)
|
| 180 |
+
self._prepared = True
|
| 181 |
+
|
| 182 |
+
x_dtype = x.dtype
|
| 183 |
+
x = x.to(self.linear1.weight.dtype)
|
| 184 |
+
out = self.linear2(self.act(self.linear1(x)))
|
| 185 |
+
if identity is None:
|
| 186 |
+
identity = x
|
| 187 |
+
identity = identity.to(out.dtype)
|
| 188 |
+
out = identity + out
|
| 189 |
+
return out.to(x_dtype)
|
ms-swift/swift/tuners/longlora/__pycache__/longlora.cpython-310.pyc
ADDED
|
Binary file (4.2 kB). View file
|
|
|
ms-swift/swift/tuners/peft.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
| 3 |
+
import os.path
|
| 4 |
+
from dataclasses import asdict, dataclass, field
|
| 5 |
+
from functools import partial, reduce
|
| 6 |
+
from types import MethodType
|
| 7 |
+
from typing import Dict, Optional
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import peft
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn
|
| 13 |
+
import transformers
|
| 14 |
+
from modelscope import snapshot_download
|
| 15 |
+
from peft import (AdaLoraConfig, BOFTConfig, BOFTModel, LoftQConfig, LoHaConfig, LoKrConfig, LoraModel, OFTConfig,
|
| 16 |
+
PeftConfig, PeftModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM,
|
| 17 |
+
PeftModelForSequenceClassification, PeftModelForTokenClassification, PrefixTuningConfig,
|
| 18 |
+
PromptEncoderConfig, PromptLearningConfig, PromptTuningConfig, VeraConfig, VeraModel, get_peft_config,
|
| 19 |
+
get_peft_model, get_peft_model_state_dict)
|
| 20 |
+
from peft.config import PeftConfigMixin
|
| 21 |
+
from peft.tuners import lora
|
| 22 |
+
from peft.tuners.adalora import AdaLoraModel, RankAllocator
|
| 23 |
+
from peft.tuners.lora import Embedding
|
| 24 |
+
from transformers import Trainer
|
| 25 |
+
|
| 26 |
+
from swift.utils import get_logger
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
from peft import FourierFTModel
|
| 30 |
+
except ImportError:
|
| 31 |
+
FourierFTModel = None
|
| 32 |
+
|
| 33 |
+
try:
|
| 34 |
+
from peft import BoneModel
|
| 35 |
+
except ImportError:
|
| 36 |
+
BoneModel = None
|
| 37 |
+
|
| 38 |
+
logger = get_logger()
|
| 39 |
+
dispatchers = []
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class LoraConfig(peft.LoraConfig):
|
| 44 |
+
lora_dtype: Optional[str] = field(
|
| 45 |
+
default=None, metadata={'help': 'The lora dtype, default None means following the original layer\'s dtype'})
|
| 46 |
+
|
| 47 |
+
lorap_lr_ratio: Optional[float] = field(default=None, metadata={'help': 'The lr ratio of lora_B in lora+'})
|
| 48 |
+
|
| 49 |
+
lorap_emb_lr: float = field(default=1e-6, metadata={'help': 'The lr for embedding in lora+'})
|
| 50 |
+
|
| 51 |
+
def to_peft_config(self) -> peft.LoraConfig:
|
| 52 |
+
_dict = asdict(self)
|
| 53 |
+
_dict.pop('lora_dtype')
|
| 54 |
+
_dict.pop('lorap_lr_ratio')
|
| 55 |
+
_dict.pop('lorap_emb_lr')
|
| 56 |
+
return peft.LoraConfig(**_dict)
|
| 57 |
+
|
| 58 |
+
def save_pretrained(self, save_directory: str, **kwargs) -> None:
|
| 59 |
+
self.to_peft_config().save_pretrained(save_directory, **kwargs)
|
| 60 |
+
additional_args = {
|
| 61 |
+
'lora_dtype': self.lora_dtype,
|
| 62 |
+
'lorap_lr_ratio': self.lorap_lr_ratio,
|
| 63 |
+
'lorap_emb_lr': self.lorap_emb_lr,
|
| 64 |
+
}
|
| 65 |
+
with open(os.path.join(save_directory, 'additional_config.json'), 'w', encoding='utf-8') as f:
|
| 66 |
+
json.dump(additional_args, f)
|
| 67 |
+
|
| 68 |
+
@classmethod
|
| 69 |
+
def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional[str] = None, **kwargs):
|
| 70 |
+
if hasattr(PeftConfigMixin, 'from_pretrained_origin'):
|
| 71 |
+
self = PeftConfigMixin.from_pretrained_origin(pretrained_model_name_or_path, subfolder, **kwargs)
|
| 72 |
+
else:
|
| 73 |
+
self = super(LoraConfig, cls).from_pretrained(pretrained_model_name_or_path, subfolder, **kwargs)
|
| 74 |
+
|
| 75 |
+
if type(self) == peft.LoraConfig:
|
| 76 |
+
self = LoraConfig(**self.to_dict())
|
| 77 |
+
|
| 78 |
+
if os.path.isfile(os.path.join(pretrained_model_name_or_path, 'additional_config.json')):
|
| 79 |
+
with open(
|
| 80 |
+
os.path.join(pretrained_model_name_or_path, 'additional_config.json'), 'r', encoding='utf-8') as f:
|
| 81 |
+
_json = json.load(f)
|
| 82 |
+
for key, value in _json.items():
|
| 83 |
+
setattr(self, key, value)
|
| 84 |
+
|
| 85 |
+
return self
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _create_and_replace_hook(self, peft_config, adapter_name, target, *args, **kwargs):
|
| 89 |
+
all_supported_names = ('linear', )
|
| 90 |
+
all_supported_types = (torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D, lora.Linear)
|
| 91 |
+
target_modules = getattr(peft_config, 'target_modules', None)
|
| 92 |
+
if target is None:
|
| 93 |
+
return
|
| 94 |
+
|
| 95 |
+
if isinstance(target_modules, str) and not any(
|
| 96 |
+
[name in target.__class__.__name__.lower()
|
| 97 |
+
for name in all_supported_names]) and not any([isinstance(target, type_) for type_ in all_supported_types]):
|
| 98 |
+
return
|
| 99 |
+
|
| 100 |
+
if target.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
|
| 101 |
+
return
|
| 102 |
+
|
| 103 |
+
return self._create_and_replace_origin(peft_config, adapter_name, target, *args, **kwargs)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _convert_dtype(target: torch.nn.Module, adapter_name: str, lora_dtype: str):
|
| 107 |
+
if lora_dtype is not None:
|
| 108 |
+
torch_dtype = eval(f'torch.{lora_dtype}')
|
| 109 |
+
if hasattr(target, 'lora_A') and adapter_name in target.lora_A:
|
| 110 |
+
target.lora_A[adapter_name].to(torch_dtype)
|
| 111 |
+
target.lora_B[adapter_name].to(torch_dtype)
|
| 112 |
+
if hasattr(target, 'lora_embedding_A') and adapter_name in target.lora_embedding_A:
|
| 113 |
+
target.lora_embedding_A[adapter_name].to(torch_dtype)
|
| 114 |
+
target.lora_embedding_B[adapter_name].to(torch_dtype)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def create_optimizer_param_groups(self: PeftModel, **defaults):
|
| 118 |
+
if not isinstance(self.peft_config[self.active_adapter],
|
| 119 |
+
LoraConfig) or self.peft_config[self.active_adapter].lorap_lr_ratio is None:
|
| 120 |
+
return None
|
| 121 |
+
|
| 122 |
+
def get_module(name):
|
| 123 |
+
parent_idx = 2 if 'lora' in name else 1
|
| 124 |
+
module_names = name.split(sep='.')[:-parent_idx]
|
| 125 |
+
module = reduce(getattr, module_names, self.base_model)
|
| 126 |
+
return module
|
| 127 |
+
|
| 128 |
+
param_groups = {
|
| 129 |
+
'groupA': {},
|
| 130 |
+
'groupB': {},
|
| 131 |
+
'groupB_no_decay': {},
|
| 132 |
+
'embedding': {},
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
decay_parameters = Trainer.get_decay_parameter_names(None, self.base_model)
|
| 136 |
+
for name, param in self.base_model.named_parameters():
|
| 137 |
+
if not param.requires_grad:
|
| 138 |
+
continue
|
| 139 |
+
|
| 140 |
+
module = get_module(name)
|
| 141 |
+
if isinstance(module, Embedding):
|
| 142 |
+
param_groups['embedding'][name] = param
|
| 143 |
+
elif 'lora_B' in name or param.ndim == 1:
|
| 144 |
+
if name in decay_parameters:
|
| 145 |
+
param_groups['groupB'][name] = param
|
| 146 |
+
else:
|
| 147 |
+
param_groups['groupB_no_decay'][name] = param
|
| 148 |
+
else:
|
| 149 |
+
param_groups['groupA'][name] = param
|
| 150 |
+
|
| 151 |
+
lr = defaults['lr']
|
| 152 |
+
weight_decay = defaults.get('weight_decay', 0.0)
|
| 153 |
+
|
| 154 |
+
param_groups = [
|
| 155 |
+
{
|
| 156 |
+
'params': list(param_groups['groupA'].values()),
|
| 157 |
+
'weight_decay': weight_decay,
|
| 158 |
+
'lr': lr,
|
| 159 |
+
},
|
| 160 |
+
{
|
| 161 |
+
'params': list(param_groups['embedding'].values()),
|
| 162 |
+
'weight_decay': weight_decay,
|
| 163 |
+
'lr': self.peft_config[self.active_adapter].lorap_emb_lr,
|
| 164 |
+
},
|
| 165 |
+
{
|
| 166 |
+
'params': list(param_groups['groupB'].values()),
|
| 167 |
+
'weight_decay': weight_decay,
|
| 168 |
+
'lr': lr * self.peft_config[self.active_adapter].lorap_lr_ratio,
|
| 169 |
+
},
|
| 170 |
+
{
|
| 171 |
+
'params': list(param_groups['groupB_no_decay'].values()),
|
| 172 |
+
'weight_decay': 0.0,
|
| 173 |
+
'lr': lr * self.peft_config[self.active_adapter].lorap_lr_ratio,
|
| 174 |
+
},
|
| 175 |
+
]
|
| 176 |
+
return param_groups
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def adalora_forward(self, *args, **kwargs):
|
| 180 |
+
from peft.utils.integrations import gather_params_ctx
|
| 181 |
+
outputs = self.model.forward(*args, **kwargs)
|
| 182 |
+
|
| 183 |
+
if (getattr(outputs, 'loss', None) is not None) and isinstance(outputs.loss, torch.Tensor):
|
| 184 |
+
# Calculate the orthogonal regularization
|
| 185 |
+
orth_reg_weight = self.peft_config[self.trainable_adapter_name].orth_reg_weight
|
| 186 |
+
|
| 187 |
+
if orth_reg_weight <= 0:
|
| 188 |
+
raise ValueError('orth_reg_weight should be greater than 0. ')
|
| 189 |
+
|
| 190 |
+
regu_loss = 0
|
| 191 |
+
num_param = 0
|
| 192 |
+
for n, p in self.model.named_parameters():
|
| 193 |
+
if ('lora_A' in n or 'lora_B' in n) and self.trainable_adapter_name in n:
|
| 194 |
+
if p.shape == torch.Size([0]):
|
| 195 |
+
with gather_params_ctx(p, fwd_module=self):
|
| 196 |
+
para_cov = p @ p.T if 'lora_A' in n else p.T @ p
|
| 197 |
+
else:
|
| 198 |
+
para_cov = p @ p.T if 'lora_A' in n else p.T @ p
|
| 199 |
+
I = torch.eye(*para_cov.size(), out=torch.empty_like(para_cov)) # noqa: E741
|
| 200 |
+
I.requires_grad = False
|
| 201 |
+
num_param += 1
|
| 202 |
+
if isinstance(regu_loss, torch.Tensor):
|
| 203 |
+
regu_loss = regu_loss.to(para_cov.device)
|
| 204 |
+
regu_loss += torch.norm(para_cov - I, p='fro')
|
| 205 |
+
if num_param > 0:
|
| 206 |
+
regu_loss = regu_loss / num_param
|
| 207 |
+
else:
|
| 208 |
+
regu_loss = 0
|
| 209 |
+
if isinstance(regu_loss, torch.Tensor) and isinstance(outputs.loss, torch.Tensor):
|
| 210 |
+
regu_loss = regu_loss.to(outputs.loss.device)
|
| 211 |
+
outputs.loss += orth_reg_weight * regu_loss
|
| 212 |
+
return outputs
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def adalora_mask_to_budget(self, model, budget):
|
| 216 |
+
value_ipt = {}
|
| 217 |
+
vector_ipt = {}
|
| 218 |
+
triplet_ipt = {}
|
| 219 |
+
# Get the importance score for A, E, B
|
| 220 |
+
for n, p in model.named_parameters():
|
| 221 |
+
if f'lora_A.{self.adapter_name}' in n:
|
| 222 |
+
entry_ipt = self._element_score(n)
|
| 223 |
+
comb_ipt = torch.mean(entry_ipt, dim=1, keepdim=True)
|
| 224 |
+
name_m = n.replace('lora_A', '%s')
|
| 225 |
+
if name_m not in vector_ipt:
|
| 226 |
+
vector_ipt[name_m] = [comb_ipt]
|
| 227 |
+
else:
|
| 228 |
+
vector_ipt[name_m].append(comb_ipt)
|
| 229 |
+
if f'lora_B.{self.adapter_name}' in n:
|
| 230 |
+
entry_ipt = self._element_score(n)
|
| 231 |
+
comb_ipt = torch.mean(entry_ipt, dim=0, keepdim=False).view(-1, 1)
|
| 232 |
+
name_m = n.replace('lora_B', '%s')
|
| 233 |
+
if name_m not in vector_ipt:
|
| 234 |
+
vector_ipt[name_m] = [comb_ipt]
|
| 235 |
+
else:
|
| 236 |
+
vector_ipt[name_m].append(comb_ipt)
|
| 237 |
+
if f'lora_E.{self.adapter_name}' in n:
|
| 238 |
+
entry_ipt = self._element_score(n)
|
| 239 |
+
name_m = n.replace('lora_E', '%s')
|
| 240 |
+
value_ipt[name_m] = entry_ipt
|
| 241 |
+
|
| 242 |
+
all_score = []
|
| 243 |
+
# Calculate the score for each triplet
|
| 244 |
+
for name_m in vector_ipt:
|
| 245 |
+
ipt_E = value_ipt[name_m]
|
| 246 |
+
ipt_AB = torch.cat(vector_ipt[name_m], dim=1)
|
| 247 |
+
sum_ipt = self._combine_ipt(ipt_E, ipt_AB)
|
| 248 |
+
name_E = name_m % 'lora_E'
|
| 249 |
+
triplet_ipt[name_E] = sum_ipt.view(-1, 1)
|
| 250 |
+
sum_ipt = sum_ipt.view(-1)
|
| 251 |
+
if all_score:
|
| 252 |
+
sum_ipt = sum_ipt.to(all_score[0].device)
|
| 253 |
+
all_score.append(sum_ipt)
|
| 254 |
+
|
| 255 |
+
# Get the threshold by ranking ipt
|
| 256 |
+
mask_threshold = torch.kthvalue(
|
| 257 |
+
torch.cat(all_score),
|
| 258 |
+
k=self.init_bgt - budget,
|
| 259 |
+
)[0].item()
|
| 260 |
+
|
| 261 |
+
rank_pattern = {}
|
| 262 |
+
# Mask the unimportant triplets
|
| 263 |
+
with torch.no_grad():
|
| 264 |
+
for n, p in model.named_parameters():
|
| 265 |
+
if f'lora_E.{self.adapter_name}' in n:
|
| 266 |
+
p.masked_fill_(triplet_ipt[n] <= mask_threshold, 0.0)
|
| 267 |
+
rank_pattern[n] = (~(triplet_ipt[n] <= mask_threshold)).view(-1).tolist()
|
| 268 |
+
return rank_pattern
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def keep_device_forward(self, *args, **kwargs):
|
| 272 |
+
x = args[0]
|
| 273 |
+
if self.weight.device != x.device:
|
| 274 |
+
return self.forward_origin(x.to(self.weight.device), *args[1:], **kwargs)
|
| 275 |
+
else:
|
| 276 |
+
return self.forward_origin(*args, **kwargs)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def hot_patch_peft_module():
|
| 280 |
+
from peft.tuners.lora import LoraLayer
|
| 281 |
+
if hasattr('LoraModel', '_create_and_replace_origin'):
|
| 282 |
+
return
|
| 283 |
+
|
| 284 |
+
# Fix Lora does not support NonDynamicallyQuantizableLinear
|
| 285 |
+
LoraModel._create_and_replace_origin = LoraModel._create_and_replace
|
| 286 |
+
LoraModel._create_and_replace = _create_and_replace_hook
|
| 287 |
+
AdaLoraModel._create_and_replace_origin = AdaLoraModel._create_and_replace
|
| 288 |
+
AdaLoraModel._create_and_replace = _create_and_replace_hook
|
| 289 |
+
VeraModel._create_and_replace_origin = VeraModel._create_and_replace
|
| 290 |
+
VeraModel._create_and_replace = _create_and_replace_hook
|
| 291 |
+
BOFTModel._create_and_replace_origin = BOFTModel._create_and_replace
|
| 292 |
+
BOFTModel._create_and_replace = _create_and_replace_hook
|
| 293 |
+
if FourierFTModel is not None:
|
| 294 |
+
FourierFTModel._create_and_replace_origin = FourierFTModel._create_and_replace
|
| 295 |
+
FourierFTModel._create_and_replace = _create_and_replace_hook
|
| 296 |
+
if BoneModel is not None:
|
| 297 |
+
BoneModel._create_and_replace_origin = BoneModel._create_and_replace
|
| 298 |
+
BoneModel._create_and_replace = _create_and_replace_hook
|
| 299 |
+
|
| 300 |
+
# Support type conversion
|
| 301 |
+
def __new_init__(self, model: torch.nn.Module, config: Dict[str, LoraConfig], adapter_name: str):
|
| 302 |
+
|
| 303 |
+
self.__init_origin__(model, config, adapter_name)
|
| 304 |
+
active_adapters = self.active_adapter
|
| 305 |
+
if isinstance(active_adapters, str):
|
| 306 |
+
active_adapters = [active_adapters]
|
| 307 |
+
for active_adapter in active_adapters:
|
| 308 |
+
active_config = config[active_adapter] if isinstance(config, dict) else config
|
| 309 |
+
if hasattr(active_config, 'lora_dtype'):
|
| 310 |
+
for name, module in model.named_modules():
|
| 311 |
+
if isinstance(module, LoraLayer):
|
| 312 |
+
_convert_dtype(module, active_adapter, active_config.lora_dtype)
|
| 313 |
+
for lora in list(module.lora_A.values()) + list(module.lora_B.values()):
|
| 314 |
+
if not hasattr(lora, 'forward_origin'):
|
| 315 |
+
lora.forward_origin = lora.forward
|
| 316 |
+
lora.forward = MethodType(keep_device_forward, lora)
|
| 317 |
+
|
| 318 |
+
LoraModel.__init_origin__ = LoraModel.__init__
|
| 319 |
+
LoraModel.__init__ = __new_init__
|
| 320 |
+
|
| 321 |
+
# Support LoRA+
|
| 322 |
+
PeftModel.create_optimizer_param_groups = create_optimizer_param_groups
|
| 323 |
+
|
| 324 |
+
PeftConfigMixin.from_pretrained_origin = PeftConfigMixin.from_pretrained
|
| 325 |
+
PeftConfigMixin.from_pretrained = LoraConfig.from_pretrained
|
| 326 |
+
|
| 327 |
+
# Compatible with SwiftModel
|
| 328 |
+
def dummy_function(*args, **kwargs):
|
| 329 |
+
logger.warn(f'The function {kwargs["func"]} has no effects, consider using other functions.')
|
| 330 |
+
|
| 331 |
+
PeftModel.activate_adapter = PeftModel.set_adapter
|
| 332 |
+
PeftModel.deactivate_adapter = partial(dummy_function, func='deactivate_adapter')
|
| 333 |
+
PeftModel.set_active_adapters = partial(dummy_function, func='set_active_adapters')
|
| 334 |
+
|
| 335 |
+
# Fix adalora does not support device_map
|
| 336 |
+
AdaLoraModel.forward = adalora_forward
|
| 337 |
+
RankAllocator.mask_to_budget = adalora_mask_to_budget
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def get_wrapped_class(module_class):
|
| 341 |
+
"""Get a custom wrapper class for peft classes to download the models from the ModelScope hub
|
| 342 |
+
|
| 343 |
+
Args:
|
| 344 |
+
module_class: The actual module class
|
| 345 |
+
|
| 346 |
+
Returns:
|
| 347 |
+
The wrapper
|
| 348 |
+
"""
|
| 349 |
+
|
| 350 |
+
class PeftWrapper(module_class):
|
| 351 |
+
|
| 352 |
+
@classmethod
|
| 353 |
+
def from_pretrained(cls, model, model_id, *args, revision: Optional[str] = None, **kwargs):
|
| 354 |
+
if not os.path.exists(model_id):
|
| 355 |
+
model_id = snapshot_download(model_id, revision=revision)
|
| 356 |
+
return module_class.from_pretrained(model, model_id, *args, **kwargs)
|
| 357 |
+
|
| 358 |
+
PeftWrapper.__name__ = module_class.__name__
|
| 359 |
+
PeftWrapper.__qualname__ = module_class.__qualname__
|
| 360 |
+
return PeftWrapper
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def wrap_module(module):
    """Wrap ``module`` with ModelScope download support when it exposes
    ``from_pretrained``; otherwise return it unchanged."""
    if hasattr(module, 'from_pretrained'):
        return get_wrapped_class(module)
    return module
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
# Apply the ModelScope-compatibility patches, then wrap every public peft
# class so that `from_pretrained` transparently downloads from the hub.
hot_patch_peft_module()
PeftModel = wrap_module(PeftModel)
PeftConfig = wrap_module(PeftConfig)
PeftModelForSeq2SeqLM = wrap_module(PeftModelForSeq2SeqLM)
PeftModelForSequenceClassification = wrap_module(PeftModelForSequenceClassification)
PeftModelForTokenClassification = wrap_module(PeftModelForTokenClassification)
PeftModelForCausalLM = wrap_module(PeftModelForCausalLM)
PromptEncoderConfig = wrap_module(PromptEncoderConfig)
PromptTuningConfig = wrap_module(PromptTuningConfig)
PrefixTuningConfig = wrap_module(PrefixTuningConfig)
PromptLearningConfig = wrap_module(PromptLearningConfig)
LoraConfig = wrap_module(LoraConfig)
AdaLoraConfig = wrap_module(AdaLoraConfig)
LoHaConfig = wrap_module(LoHaConfig)
LoKrConfig = wrap_module(LoKrConfig)
LoftQConfig = wrap_module(LoftQConfig)
OFTConfig = wrap_module(OFTConfig)  # fix: was wrapped twice, stacking two wrappers
BOFTConfig = wrap_module(BOFTConfig)
VeraConfig = wrap_module(VeraConfig)
# Self-assignments re-export these names as part of this module's public API.
get_peft_config = get_peft_config
get_peft_model_state_dict = get_peft_model_state_dict
get_peft_model = get_peft_model
|
ms-swift/swift/tuners/scetuning/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (246 Bytes). View file
|
|
|
ms-swift/swift/tuners/scetuning/__pycache__/scetuning.cpython-310.pyc
ADDED
|
Binary file (8.37 kB). View file
|
|
|
ms-swift/swift/tuners/scetuning/__pycache__/scetuning_components.cpython-310.pyc
ADDED
|
Binary file (4.21 kB). View file
|
|
|
ms-swift/swift/tuners/scetuning/scetuning_components.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from swift.utils.logger import get_logger
|
| 8 |
+
|
| 9 |
+
logger = get_logger()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def detach_tensors(feats):
    """Recursively detach every tensor in a (possibly nested) structure.

    Lists/tuples and dicts are traversed recursively; tensors are detached
    from the autograd graph. Any other leaf (int, str, None, ...) is returned
    unchanged -- previously such leaves crashed with ``AttributeError``
    because ``.detach()`` was called unconditionally in the final branch.

    Note: tuples are returned as lists, matching the original behavior.
    """
    if type(feats) in [list, tuple]:
        feats = [detach_tensors(feat) if feat is not None else None for feat in feats]
    elif isinstance(feats, dict):
        feats = {key: detach_tensors(val) for key, val in feats.items()}
    elif isinstance(feats, torch.Tensor):
        feats = feats.detach()
    # Non-tensor, non-container leaves pass through untouched.
    return feats
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def probe_tensors(module, feats, name):
    """Store a detached copy of ``feats`` on ``module`` under attribute ``name``."""
    setattr(module, name, detach_tensors(feats))


def probe_input_pre_hook(self, args):
    """Forward pre-hook: record the first positional input on the module."""
    probe_tensors(self, args[0], 'probe_input_data')
    return args


def probe_output_hook(self, args, result):
    """Forward hook: record the module's output on the module."""
    probe_tensors(self, result, 'probe_output_data')
    return result
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def choose_weight_type(weight_type, dim):
    """Create the scaling carrier for the requested ``weight_type``.

    Returns an ``nn.Linear(dim, 1)`` for 'gate', a ones-initialized
    ``nn.Parameter`` (scalar or per-channel) for 'scale'/'scale_channel',
    a plain float parsed from 'scalar_<value>', and ``None`` otherwise.
    """
    if weight_type == 'gate':
        return nn.Linear(dim, 1)
    if weight_type == 'scale':
        param = nn.Parameter(torch.Tensor(1))
        param.data.fill_(1)
        return param
    if weight_type == 'scale_channel':
        param = nn.Parameter(torch.Tensor(dim))
        param.data.fill_(1)
        return param
    if weight_type and weight_type.startswith('scalar'):
        return float(weight_type.split('_')[-1])
    return None
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_weight_value(weight_type, scaling, x):
    """Resolve the effective scaling value for ``weight_type``.

    For 'gate' a per-sample gate is computed from ``x`` via the stored linear
    layer; for 'scale'/'scale_channel'/'scalar_*' the stored ``scaling`` is
    returned unchanged; any other value (including ``None``) yields ``None``.
    Previously a ``None`` weight_type crashed on ``.startswith``.
    """
    if weight_type == 'gate':
        # Mean-pooled sigmoid gate, broadcastable over (batch, seq, dim).
        return torch.mean(torch.sigmoid(scaling(x)), dim=1).view(-1, 1, 1)
    if weight_type in ('scale', 'scale_channel') or (weight_type and weight_type.startswith('scalar')):
        return scaling
    return None
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class SCEAdapter(nn.Module):
    """Bottleneck adapter used by SCETuning.

    down-project (``ln1``) -> activation -> up-project (``ln2``), with an
    optional residual shortcut and an optional learned/fixed output scaling
    (see ``choose_weight_type`` / ``get_weight_value``). With the default
    ``zero_init_last=True`` the adapter is an identity at initialization.
    """

    def __init__(self,
                 dim,
                 adapter_length,
                 adapter_type=None,
                 adapter_weight=None,
                 act_layer=nn.GELU,
                 zero_init_last=True,
                 use_bias=True):
        super(SCEAdapter, self).__init__()
        self.dim = dim
        self.adapter_length = adapter_length
        self.adapter_type = adapter_type  # stored but not used by forward
        self.adapter_weight = adapter_weight
        self.zero_init_last = zero_init_last
        self.ln1 = nn.Linear(dim, adapter_length, bias=use_bias)
        self.activate = act_layer()
        self.ln2 = nn.Linear(adapter_length, dim, bias=use_bias)
        self.init_weights()
        self.init_scaling()

    def _zero_init_weights(self, m):
        # Zeroing the up-projection makes the adapter start as identity.
        if isinstance(m, nn.Linear):
            nn.init.zeros_(m.weight)
            nn.init.zeros_(m.bias)

    def _kaiming_init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))

    def init_weights(self):
        self._kaiming_init_weights(self.ln1)
        init_last = self._zero_init_weights if self.zero_init_last else self._kaiming_init_weights
        init_last(self.ln2)

    def init_scaling(self):
        self.scaling = choose_weight_type(self.adapter_weight, self.dim) if self.adapter_weight else None

    def forward(self, x, x_shortcut=None, use_shortcut=True, **kwargs):
        # Shortcut defaults to the (un-reshaped) input.
        shortcut = x if x_shortcut is None else x_shortcut
        is_conv_layout = len(x.shape) == 4
        if is_conv_layout:
            # (b, d, h, w) -> (b, h*w, d) so the linear layers see channels last.
            b, d, h, w = x.shape
            x = x.permute(0, 2, 3, 1).reshape(b, h * w, d)
        out = self.ln2(self.activate(self.ln1(x)))
        if self.adapter_weight:
            scaling = get_weight_value(self.adapter_weight, self.scaling, out)
            if scaling is not None:
                out = out * scaling
        if is_conv_layout:
            out = out.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
        if use_shortcut:
            out = shortcut + out
        return out
|
ms-swift/swift/tuners/side.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import copy
|
| 3 |
+
import re
|
| 4 |
+
import types
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
from functools import partial
|
| 8 |
+
from itertools import repeat
|
| 9 |
+
from typing import Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
from swift.utils.logger import get_logger
|
| 15 |
+
from swift.utils.torch_utils import find_sub_module
|
| 16 |
+
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
|
| 17 |
+
|
| 18 |
+
logger = get_logger()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class SideConfig(SwiftConfig):
    """
    The configuration class for the side module.

    Side-Tuning only needs to train one side network and
    weights the output of pre-trained model and side network.
    'Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks'
    by Zhang et al.(2019)
    See https://arxiv.org/abs/1912.13503

    Args:
        dim: The hidden-state dimension the side network must produce.
        target_modules: The target module to be replaced, matched with ``re.fullmatch``.
        side_module_name: Which side network to attach ('fcn4', 'mlp' or 'alexnet').
        source_hidden_pos: Where the hidden-state input sits in the target module's
            call: int for a positional arg, str for a keyword arg.
        target_hidden_pos: Where the hidden-state output sits in the target module's
            return value: int/str index into a tuple/list/dict result.
    """

    dim: int = field(default=None, metadata={'help': 'The dimension of the hidden states'})

    target_modules: str = field(
        default=None, metadata={'help': 'The target module to be replaced, in full match format'})

    side_module_name: str = field(default='fcn4', metadata={'help': 'The name of the additive side networks'})

    source_hidden_pos: Union[str, int] = field(
        default=0,
        metadata={
            'help': 'The position of the hidden state input to the target module, can be int (args) or str (kwargs)'
        })

    target_hidden_pos: Union[str, int] = field(
        default=0,
        metadata={
            'help': 'The position of the hidden state output from the target module, can be int (args) or str (kwargs)'
        })

    def __post_init__(self):
        # Local import avoids a circular dependency with the tuner mapping.
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.SIDE
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class Side(SwiftAdapter):
    """Side-Tuning adapter: attaches a trainable side network to matched modules
    and blends its output with the frozen module's output (see SideModule)."""

    @staticmethod
    def prepare_model(model: nn.Module, config: SideConfig, adapter_name: str) -> SwiftOutput:
        """Prepare a model with `SideConfig`.

        Monkey-patches every module whose key fullmatches
        ``config.target_modules``: the original forward is preserved as
        ``forward_origin_<adapter_name>`` and replaced by a wrapper that also
        runs the side network and writes its blended output back into the
        original result.
        """
        module_keys = [key for key, _ in model.named_modules()]

        for module_key in module_keys:
            if re.fullmatch(config.target_modules, module_key):  # noqa
                tgt_module = model.get_submodule(module_key)
                logger.info(f'Matching target module [{module_key}] of type {type(tgt_module)}')
                if isinstance(tgt_module, (nn.ModuleList, nn.ModuleDict)):
                    # These containers have no meaningful single forward to wrap.
                    raise Exception(
                        f'Type of {type(tgt_module)} may not be supported because of its customized forward')

                def _forward(self, *args, **kwargs):
                    # Run the original (preserved) forward first.
                    args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)

                    # Pick the hidden-state input from args or kwargs per config.
                    if isinstance(config.source_hidden_pos, int):
                        x = args[config.source_hidden_pos]
                    else:
                        x = kwargs[config.source_hidden_pos]

                    # Pick the hidden-state output from the (possibly structured) result.
                    x_main = args_main[config.target_hidden_pos] \
                        if isinstance(args_main, (tuple, list, dict)) else args_main
                    out = getattr(self, f'side_{adapter_name}')(x, x_main)
                    # Write the blended output back in place of the original one.
                    # NOTE(review): tuples are immutable -- this assignment assumes
                    # args_main is a list/dict when structured; confirm with callers.
                    if isinstance(args_main, (tuple, list, dict)):
                        args_main[config.target_hidden_pos] = out
                    else:
                        args_main = out
                    return args_main

                if isinstance(tgt_module, nn.Sequential) and not hasattr(tgt_module, 'tgt_module_keys'):
                    # Remember the pre-patch children so newly attached side modules
                    # (appended to _modules) are skipped by the sequential pass.
                    tgt_module.tgt_module_keys = copy.deepcopy(list(tgt_module._modules.keys()))

                    def forward_seq(self, input, *args, **kwargs):
                        for idx, module in enumerate(self):
                            if idx >= len(tgt_module.tgt_module_keys):
                                continue
                            input = module(input)
                        return input

                    setattr(tgt_module, f'forward_origin_{adapter_name}', types.MethodType(forward_seq, tgt_module))
                else:
                    setattr(tgt_module, f'forward_origin_{adapter_name}', tgt_module.forward)
                tgt_module.forward = types.MethodType(_forward, tgt_module)
                side_module = SideModule(config.dim, adapter_name, module_key, config.side_module_name)
                setattr(tgt_module, f'side_{adapter_name}', side_module)
                logger.info(f'Side modules(module_key): {module_key}.side_{adapter_name}')

        def state_dict_callback(state_dict, adapter_name, **kwargs):
            # Only the side-network weights belong to this adapter's checkpoint.
            return {key: value for key, value in state_dict.items() if f'side_{adapter_name}' in key}

        def mark_trainable_callback(model):
            # No extra parameters beyond the side modules need unfreezing.
            return

        return SwiftOutput(
            config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        """Toggle all side modules of ``adapter_name`` and (de)offload their weights."""
        modules = find_sub_module(module, f'side_{adapter_name}')
        for _module in modules:
            _module: ActivationMixin
            _module: nn.Module
            _module.set_activation(adapter_name, activate)
            SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class SideModule(nn.Module, ActivationMixin):
    """The implementation of vision side-tuning method.

    Side-Tuning only needs to train one side network and
    weights the output of pre-trained model and side network.
    'Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks'
    by Zhang et al.(2019)
    See https://arxiv.org/abs/1912.13503

    Args:
        dim: Output dimension the side network must produce.
        adapter_name: Name of the adapter this module belongs to.
        module_key: Key of the wrapped target module (used for activation/offload bookkeeping).
        side_module_name: The name of the additive side networks ('fcn4', 'mlp' or 'alexnet').
    """

    def __init__(self, dim, adapter_name, module_key, side_module_name='fcn4'):
        super(SideModule, self).__init__()
        # Initializes ActivationMixin (the next class after nn.Module in the MRO).
        super(nn.Module, self).__init__(module_key)
        self.adapter_name = adapter_name

        side_module_name = side_module_name.lower()
        if side_module_name == 'fcn4':
            self.side_net = FCN4(out_dims=dim)
        elif side_module_name == 'mlp':
            self.side_net = Mlp(dim)
        elif side_module_name == 'alexnet':
            import torchvision
            mm = torchvision.models.alexnet(pretrained=True)
            self.side_net = nn.Sequential(
                OrderedDict([('features', mm.features), ('avgpool', mm.avgpool), ('flatten', nn.Flatten()),
                             ('fc', nn.Linear(9216, dim, bias=False))]))
        else:
            raise ValueError(f'Unsupported side_module_name: {side_module_name}')
        # Learnable blend weight; sigmoid(0.0) = 0.5, i.e. an equal mix at init.
        self.alpha = nn.Parameter(torch.tensor(0.0))
        self.mark_all_sub_modules_as_plugin()

    def forward(self, x, x_main):
        # When deactivated, pass the backbone output through untouched.
        if not self.is_activated(self.adapter_name):
            return x_main
        alpha_squashed = torch.sigmoid(self.alpha)
        x_side = self.side_net(x)
        # Convex combination of backbone output and side-network output.
        x_out = alpha_squashed * x_main + (1 - alpha_squashed) * x_side
        return x_out
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class FCN4(nn.Module):
    """A small four-stage convolutional side network.

    Four conv/GroupNorm/ReLU stages followed by global average pooling; an
    optional final linear head maps the 64 pooled channels to ``out_dims``
    (disabled when ``out_dims <= 0``).
    """

    def __init__(self, out_dims=-1, **kwargs):
        super(FCN4, self).__init__(**kwargs)

        def stage(c_in, c_out, stride, padding):
            # conv -> GroupNorm(2 groups) -> ReLU, mirroring each original stage.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=padding, bias=False, dilation=1),
                nn.GroupNorm(2, c_out), nn.ReLU())

        self.conv1 = stage(3, 16, 1, 1)
        self.conv2 = stage(16, 16, 2, 0)
        self.conv3 = stage(16, 32, 2, 0)
        self.conv4 = stage(32, 64, 1, 0)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, out_dims) if out_dims > 0 else None

    def forward(self, x):
        for block in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = block(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        if self.fc is not None:
            x = self.fc(x)
        return x
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class Mlp(nn.Module):
    """MLP block as used in Vision Transformer: fc -> act -> drop -> norm -> fc -> drop.

    With ``use_conv=True`` the linear layers become 1x1 convolutions so the
    block also works on NCHW feature maps.
    """

    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # The same bias/drop setting is applied to both projections.
        bias_pair = (bias, bias)
        drop_pair = (drop, drop)
        make_linear = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = make_linear(in_features, hidden_features, bias=bias_pair[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])
        self.norm = nn.Identity() if norm_layer is None else norm_layer(hidden_features)
        self.fc2 = make_linear(hidden_features, out_features, bias=bias_pair[1])
        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x):
        return self.drop2(self.fc2(self.norm(self.drop1(self.act(self.fc1(x))))))
|
ms-swift/swift/ui/app.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import os
|
| 3 |
+
from functools import partial
|
| 4 |
+
from typing import List, Union
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
from packaging import version
|
| 8 |
+
from transformers.utils import strtobool
|
| 9 |
+
|
| 10 |
+
import swift
|
| 11 |
+
from swift.llm import DeployArguments, EvalArguments, ExportArguments, RLHFArguments, SwiftPipeline, WebUIArguments
|
| 12 |
+
from swift.ui.llm_eval.llm_eval import LLMEval
|
| 13 |
+
from swift.ui.llm_export.llm_export import LLMExport
|
| 14 |
+
from swift.ui.llm_infer.llm_infer import LLMInfer
|
| 15 |
+
from swift.ui.llm_train.llm_train import LLMTrain
|
| 16 |
+
|
| 17 |
+
# i18n strings for the top-level page chrome (title, subtitle, star banner),
# keyed first by element then by language code ('zh'/'en').
locale_dict = {
    'title': {
        'zh': '🚀SWIFT: 轻量级大模型训练推理框架',
        'en': '🚀SWIFT: Scalable lightWeight Infrastructure for Fine-Tuning and Inference'
    },
    'sub_title': {
        'zh':
        '请查看 <a href=\"https://github.com/modelscope/swift/tree/main/docs/source\" target=\"_blank\">'
        'SWIFT 文档</a>来查看更多功能,使用SWIFT_UI_LANG=en环境变量来切换英文界面',
        'en':
        'Please check <a href=\"https://github.com/modelscope/swift/tree/main/docs/source_en\" target=\"_blank\">'
        'SWIFT Documentation</a> for more usages, Use SWIFT_UI_LANG=zh variable to switch to Chinese UI',
    },
    'star_beggar': {
        'zh':
        '喜欢<a href=\"https://github.com/modelscope/swift\" target=\"_blank\">SWIFT</a>就动动手指给我们加个star吧🥺 ',
        'en':
        'If you like <a href=\"https://github.com/modelscope/swift\" target=\"_blank\">SWIFT</a>, '
        'please take a few seconds to star us🥺 '
    },
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class SwiftWebUI(SwiftPipeline):
    """Gradio web UI bundling the train / infer / export / eval tabs."""

    args_class = WebUIArguments
    args: args_class

    def run(self):
        # Environment variables take precedence over CLI arguments so a
        # deployment can be reconfigured without changing the command line.
        lang = os.environ.get('SWIFT_UI_LANG') or self.args.lang
        share_env = os.environ.get('WEBUI_SHARE')
        share = strtobool(share_env) if share_env else self.args.share
        server = os.environ.get('WEBUI_SERVER') or self.args.server_name
        port_env = os.environ.get('WEBUI_PORT')
        port = int(port_env) if port_env else self.args.server_port
        # Propagate the language to every sub-UI before any element is built.
        LLMTrain.set_lang(lang)
        LLMInfer.set_lang(lang)
        LLMExport.set_lang(lang)
        LLMEval.set_lang(lang)
        with gr.Blocks(title='SWIFT WebUI', theme=gr.themes.Base()) as app:
            try:
                _version = swift.__version__
            except AttributeError:
                # e.g. running from a source tree without package metadata
                _version = ''
            gr.HTML(f"<h1><center>{locale_dict['title'][lang]}({_version})</center></h1>")
            gr.HTML(f"<h3><center>{locale_dict['sub_title'][lang]}</center></h3>")
            with gr.Tabs():
                LLMTrain.build_ui(LLMTrain)
                LLMInfer.build_ui(LLMInfer)
                LLMExport.build_ui(LLMExport)
                LLMEval.build_ui(LLMEval)

            # gradio < 4.0 needs an explicit concurrency_count on the queue.
            concurrent = {}
            if version.parse(gr.__version__) < version.parse('4.0.0'):
                concurrent = {'concurrency_count': 5}
            # Populate each tab's widgets from the selected model on page load.
            app.load(
                partial(LLMTrain.update_input_model, arg_cls=RLHFArguments),
                inputs=[LLMTrain.element('model')],
                outputs=[LLMTrain.element('train_record')] + list(LLMTrain.valid_elements().values()))
            app.load(
                partial(LLMInfer.update_input_model, arg_cls=DeployArguments, has_record=False),
                inputs=[LLMInfer.element('model')],
                outputs=list(LLMInfer.valid_elements().values()))
            app.load(
                partial(LLMExport.update_input_model, arg_cls=ExportArguments, has_record=False),
                inputs=[LLMExport.element('model')],
                outputs=list(LLMExport.valid_elements().values()))
            app.load(
                partial(LLMEval.update_input_model, arg_cls=EvalArguments, has_record=False),
                inputs=[LLMEval.element('model')],
                outputs=list(LLMEval.valid_elements().values()))
        app.queue(**concurrent).launch(server_name=server, inbrowser=True, server_port=port, height=800, share=share)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def webui_main(args: Union[List[str], WebUIArguments, None] = None):
    """Entry point: construct the web-UI pipeline and run it."""
    pipeline = SwiftWebUI(args)
    return pipeline.main()
|
ms-swift/swift/ui/base.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import dataclasses
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import time
|
| 6 |
+
import typing
|
| 7 |
+
from collections import OrderedDict
|
| 8 |
+
from dataclasses import fields
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from functools import wraps
|
| 11 |
+
from typing import Any, Dict, List, Type
|
| 12 |
+
|
| 13 |
+
import gradio as gr
|
| 14 |
+
import json
|
| 15 |
+
from gradio import Accordion, Audio, Button, Checkbox, Dropdown, File, Image, Slider, Tab, TabItem, Textbox, Video
|
| 16 |
+
from modelscope.hub.utils.utils import get_cache_dir
|
| 17 |
+
|
| 18 |
+
from swift.llm import TEMPLATE_MAPPING, BaseArguments, get_matched_model_meta
|
| 19 |
+
|
| 20 |
+
# Supported UI languages; the first entry is the default.
all_langs = ['zh', 'en']
# Module-level build context consumed by the patched gradio constructors:
# `builder` is the BaseUI subclass currently creating elements and
# `base_builder` is the root tab being built (set/reset in BaseUI.build_ui).
builder: Type['BaseUI'] = None
base_builder: Type['BaseUI'] = None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def update_data(fn):
    """Decorator applied to gradio component constructors.

    When a component is created inside ``BaseUI.build_ui``, the wrapper
    injects choices, defaults, i18n label/info/value texts and visibility
    from the module-level build context (``builder``/``base_builder``)
    before delegating to the original constructor.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        elem_id = kwargs.get('elem_id', None)
        self = args[0]

        if builder is not None:
            # Stringify registered choices (gradio dropdowns expect strings).
            choices = base_builder.choice(elem_id)
            if choices:
                choices = [str(choice) if choice is not None else None for choice in choices]
                kwargs['choices'] = choices

        # Containers stay non-interactive; everything else defaults to editable.
        if not isinstance(self, (Tab, TabItem, Accordion)) and 'interactive' not in kwargs:  # noqa
            kwargs['interactive'] = True

        # 'is_list' is a swift-specific flag, not a gradio kwarg -- pop it.
        if 'is_list' in kwargs:
            self.is_list = kwargs.pop('is_list')

        if base_builder and base_builder.default(elem_id) is not None and not kwargs.get('value'):
            kwargs['value'] = base_builder.default(elem_id)

        if builder is not None:
            # Localized texts override whatever the caller passed.
            if elem_id in builder.locales(builder.lang):
                values = builder.locale(elem_id, builder.lang)
                if 'info' in values:
                    kwargs['info'] = values['info']
                if 'value' in values:
                    kwargs['value'] = values['value']
                if 'label' in values:
                    kwargs['label'] = values['label']
            if hasattr(builder, 'visible'):
                kwargs['visible'] = builder.visible
            # Show the backing CLI argument name in the label, e.g. "Model(--model)".
            argument = base_builder.argument(elem_id)
            if argument and 'label' in kwargs:
                kwargs['label'] = kwargs['label'] + f'({argument})'

        kwargs['elem_classes'] = 'align'
        ret = fn(self, **kwargs)
        # Remember the final constructor kwargs for later updates.
        self.constructor_args.update(kwargs)

        if builder is not None:
            builder.element_dict[elem_id] = self
        return ret

    return wrapper
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# Patch the constructors of every gradio component used by the web UI so each
# component automatically picks up choices/defaults/i18n texts from the current
# build context (see `update_data` above).
Textbox.__init__ = update_data(Textbox.__init__)
Dropdown.__init__ = update_data(Dropdown.__init__)
Checkbox.__init__ = update_data(Checkbox.__init__)
Slider.__init__ = update_data(Slider.__init__)
TabItem.__init__ = update_data(TabItem.__init__)
Accordion.__init__ = update_data(Accordion.__init__)
Button.__init__ = update_data(Button.__init__)
File.__init__ = update_data(File.__init__)
Image.__init__ = update_data(Image.__init__)
Video.__init__ = update_data(Video.__init__)
Audio.__init__ = update_data(Audio.__init__)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class BaseUI:
    # Base class for all web-UI tabs; subclasses register elements, choices,
    # defaults and localized texts through the class-level registries below.

    # Dropdown choices keyed by element id.
    choice_dict: Dict[str, List] = {}
    # Default values keyed by element id.
    default_dict: Dict[str, Any] = {}
    # i18n texts keyed by element id: {elem_id: {field: {lang: text}}}.
    locale_dict: Dict[str, Dict] = {}
    # Constructed gradio components keyed by element id.
    element_dict: Dict[str, Dict] = {}
    # Mapping from element id to the CLI argument it drives.
    arguments: Dict[str, str] = {}
    # Child UIs registered under this tab.
    sub_ui: List[Type['BaseUI']] = []
    group: str = None
    lang: str = all_langs[0]
    # Regexes used to coerce textbox strings back to typed values.
    int_regex = r'^[-+]?[0-9]+$'
    float_regex = r'[-+]?(?:\d*\.*\d+)'
    bool_regex = r'^(T|t)rue$|^(F|f)alse$'
    # On-disk cache for saved UI records, under the ModelScope cache dir
    # (created eagerly at class-definition time).
    cache_dir = os.path.join(get_cache_dir(), 'swift-web-ui')
    os.makedirs(cache_dir, exist_ok=True)
    # Quote character for building shell commands (platform dependent).
    quote = '\'' if sys.platform != 'win32' else '"'
    visible = True
    _locale = {
        'local_dir_alert': {
            'value': {
                'zh': '无法识别model_type和template,请手动选择',
                'en': 'Cannot recognize the model_type and template, please choose manually'
            }
        },
    }
|
| 111 |
+
|
| 112 |
+
    @classmethod
    def build_ui(cls, base_tab: Type['BaseUI']):
        """Build this UI's gradio elements.

        Temporarily installs ``cls``/``base_tab`` as the module-level build
        context consumed by the patched gradio constructors, then restores
        the previous context so nested builds compose correctly.
        """
        global builder, base_builder
        cls.element_dict = {}
        # Save/restore so recursive build_ui calls do not clobber the context.
        old_builder = builder
        old_base_builder = base_builder
        builder = cls
        base_builder = base_tab
        cls.do_build_ui(base_tab)
        builder = old_builder
        base_builder = old_base_builder
        # Only the root tab triggers the post-build pass over its children.
        if cls is base_tab:
            for ui in cls.sub_ui:
                ui.after_build_ui(base_tab)
|
| 127 |
+
|
| 128 |
+
    @classmethod
    def after_build_ui(cls, base_tab: Type['BaseUI']):
        """Hook invoked on each sub-UI after the root tab finishes building."""
        pass
|
| 131 |
+
|
| 132 |
+
    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build UI: subclasses create their gradio components here."""
        pass
|
| 136 |
+
|
| 137 |
+
@classmethod
|
| 138 |
+
def save_cache(cls, key, value):
|
| 139 |
+
timestamp = str(int(time.time()))
|
| 140 |
+
key = key.replace('/', '-')
|
| 141 |
+
filename = os.path.join(cls.cache_dir, key + '-' + timestamp)
|
| 142 |
+
with open(filename, 'w', encoding='utf-8') as f:
|
| 143 |
+
json.dump(value, f)
|
| 144 |
+
|
| 145 |
+
@classmethod
|
| 146 |
+
def list_cache(cls, key):
|
| 147 |
+
files = []
|
| 148 |
+
key = key.replace('/', '-')
|
| 149 |
+
for _, _, filenames in os.walk(cls.cache_dir):
|
| 150 |
+
for filename in filenames:
|
| 151 |
+
if filename.startswith(key):
|
| 152 |
+
idx = filename.rfind('-')
|
| 153 |
+
key, ts = filename[:idx], filename[idx + 1:]
|
| 154 |
+
dt_object = datetime.fromtimestamp(int(ts))
|
| 155 |
+
formatted_time = dt_object.strftime('%Y/%m/%d %H:%M:%S')
|
| 156 |
+
files.append(formatted_time)
|
| 157 |
+
return sorted(files, reverse=True)
|
| 158 |
+
|
| 159 |
+
@classmethod
|
| 160 |
+
def load_cache(cls, key, timestamp) -> BaseArguments:
|
| 161 |
+
dt_object = datetime.strptime(timestamp, '%Y/%m/%d %H:%M:%S')
|
| 162 |
+
timestamp = int(dt_object.timestamp())
|
| 163 |
+
key = key.replace('/', '-')
|
| 164 |
+
filename = key + '-' + str(timestamp)
|
| 165 |
+
with open(os.path.join(cls.cache_dir, filename), 'r', encoding='utf-8') as f:
|
| 166 |
+
return json.load(f)
|
| 167 |
+
|
| 168 |
+
@classmethod
|
| 169 |
+
def clear_cache(cls, key):
|
| 170 |
+
key = key.replace('/', '-')
|
| 171 |
+
for _, _, filenames in os.walk(cls.cache_dir):
|
| 172 |
+
for filename in filenames:
|
| 173 |
+
if filename.startswith(key):
|
| 174 |
+
os.remove(os.path.join(cls.cache_dir, filename))
|
| 175 |
+
|
| 176 |
+
@classmethod
|
| 177 |
+
def choice(cls, elem_id):
|
| 178 |
+
"""Get choice by elem_id"""
|
| 179 |
+
for sub_ui in BaseUI.sub_ui:
|
| 180 |
+
_choice = sub_ui.choice(elem_id)
|
| 181 |
+
if _choice:
|
| 182 |
+
return _choice
|
| 183 |
+
return cls.choice_dict.get(elem_id, [])
|
| 184 |
+
|
| 185 |
+
@classmethod
|
| 186 |
+
def default(cls, elem_id):
|
| 187 |
+
"""Get choice by elem_id"""
|
| 188 |
+
if elem_id in cls.default_dict:
|
| 189 |
+
return cls.default_dict.get(elem_id)
|
| 190 |
+
for sub_ui in BaseUI.sub_ui:
|
| 191 |
+
_choice = sub_ui.default(elem_id)
|
| 192 |
+
if _choice:
|
| 193 |
+
return _choice
|
| 194 |
+
return None
|
| 195 |
+
|
| 196 |
+
@classmethod
|
| 197 |
+
def locale(cls, elem_id, lang):
|
| 198 |
+
"""Get locale by elem_id"""
|
| 199 |
+
return cls.locales(lang)[elem_id]
|
| 200 |
+
|
| 201 |
+
@classmethod
|
| 202 |
+
def locales(cls, lang):
|
| 203 |
+
"""Get locale by lang"""
|
| 204 |
+
locales = OrderedDict()
|
| 205 |
+
for sub_ui in cls.sub_ui:
|
| 206 |
+
_locales = sub_ui.locales(lang)
|
| 207 |
+
locales.update(_locales)
|
| 208 |
+
for key, value in cls.locale_dict.items():
|
| 209 |
+
locales[key] = {k: v[lang] for k, v in value.items()}
|
| 210 |
+
return locales
|
| 211 |
+
|
| 212 |
+
@classmethod
|
| 213 |
+
def elements(cls):
|
| 214 |
+
"""Get all elements"""
|
| 215 |
+
elements = OrderedDict()
|
| 216 |
+
elements.update(cls.element_dict)
|
| 217 |
+
for sub_ui in cls.sub_ui:
|
| 218 |
+
_elements = sub_ui.elements()
|
| 219 |
+
elements.update(_elements)
|
| 220 |
+
return elements
|
| 221 |
+
|
| 222 |
+
@classmethod
|
| 223 |
+
def valid_elements(cls):
|
| 224 |
+
valid_elements = OrderedDict()
|
| 225 |
+
elements = cls.elements()
|
| 226 |
+
for key, value in elements.items():
|
| 227 |
+
if isinstance(value, (Textbox, Dropdown, Slider, Checkbox)) and key != 'train_record':
|
| 228 |
+
valid_elements[key] = value
|
| 229 |
+
return valid_elements
|
| 230 |
+
|
| 231 |
+
@classmethod
|
| 232 |
+
def element_keys(cls):
|
| 233 |
+
return list(cls.elements().keys())
|
| 234 |
+
|
| 235 |
+
@classmethod
|
| 236 |
+
def valid_element_keys(cls):
|
| 237 |
+
return [
|
| 238 |
+
key for key, value in cls.elements().items()
|
| 239 |
+
if isinstance(value, (Textbox, Dropdown, Slider, Checkbox)) and key != 'train_record'
|
| 240 |
+
]
|
| 241 |
+
|
| 242 |
+
@classmethod
|
| 243 |
+
def element(cls, elem_id):
|
| 244 |
+
"""Get element by elem_id"""
|
| 245 |
+
elements = cls.elements()
|
| 246 |
+
return elements[elem_id]
|
| 247 |
+
|
| 248 |
+
@classmethod
|
| 249 |
+
def argument(cls, elem_id):
|
| 250 |
+
"""Get argument by elem_id"""
|
| 251 |
+
return cls.arguments.get(elem_id)
|
| 252 |
+
|
| 253 |
+
@classmethod
|
| 254 |
+
def set_lang(cls, lang):
|
| 255 |
+
cls.lang = lang
|
| 256 |
+
for sub_ui in cls.sub_ui:
|
| 257 |
+
sub_ui.lang = lang
|
| 258 |
+
|
| 259 |
+
@staticmethod
|
| 260 |
+
def get_choices_from_dataclass(dataclass):
|
| 261 |
+
choice_dict = {}
|
| 262 |
+
for f in fields(dataclass):
|
| 263 |
+
default_value = f.default
|
| 264 |
+
if 'MISSING_TYPE' in str(default_value):
|
| 265 |
+
default_value = None
|
| 266 |
+
if 'choices' in f.metadata:
|
| 267 |
+
choice_dict[f.name] = list(f.metadata['choices'])
|
| 268 |
+
if 'Literal' in str(f.type) and typing.get_args(f.type):
|
| 269 |
+
choice_dict[f.name] = list(typing.get_args(f.type))
|
| 270 |
+
if f.name in choice_dict and default_value not in choice_dict[f.name]:
|
| 271 |
+
choice_dict[f.name].insert(0, default_value)
|
| 272 |
+
return choice_dict
|
| 273 |
+
|
| 274 |
+
@staticmethod
|
| 275 |
+
def get_default_value_from_dataclass(dataclass):
|
| 276 |
+
default_dict = {}
|
| 277 |
+
for f in fields(dataclass):
|
| 278 |
+
if f.default.__class__ is dataclasses._MISSING_TYPE:
|
| 279 |
+
default_dict[f.name] = f.default_factory()
|
| 280 |
+
else:
|
| 281 |
+
default_dict[f.name] = f.default
|
| 282 |
+
if isinstance(default_dict[f.name], list):
|
| 283 |
+
try:
|
| 284 |
+
default_dict[f.name] = ' '.join(default_dict[f.name])
|
| 285 |
+
except TypeError:
|
| 286 |
+
default_dict[f.name] = None
|
| 287 |
+
if not default_dict[f.name]:
|
| 288 |
+
default_dict[f.name] = None
|
| 289 |
+
return default_dict
|
| 290 |
+
|
| 291 |
+
@staticmethod
|
| 292 |
+
def get_argument_names(dataclass):
|
| 293 |
+
arguments = {}
|
| 294 |
+
for f in fields(dataclass):
|
| 295 |
+
arguments[f.name] = f'--{f.name}'
|
| 296 |
+
return arguments
|
| 297 |
+
|
| 298 |
+
@classmethod
|
| 299 |
+
def update_input_model(cls, model, allow_keys=None, has_record=True, arg_cls=BaseArguments, is_ref_model=False):
|
| 300 |
+
keys = cls.valid_element_keys()
|
| 301 |
+
if allow_keys:
|
| 302 |
+
keys = [key for key in keys if key in allow_keys]
|
| 303 |
+
|
| 304 |
+
if not model:
|
| 305 |
+
ret = [gr.update()] * (len(keys) + int(has_record))
|
| 306 |
+
if len(ret) == 1:
|
| 307 |
+
return ret[0]
|
| 308 |
+
else:
|
| 309 |
+
return ret
|
| 310 |
+
|
| 311 |
+
model_meta = get_matched_model_meta(model)
|
| 312 |
+
local_args_path = os.path.join(model, 'args.json')
|
| 313 |
+
if model_meta is None and not os.path.exists(local_args_path):
|
| 314 |
+
gr.Info(cls._locale['local_dir_alert']['value'][cls.lang])
|
| 315 |
+
ret = [gr.update()] * (len(keys) + int(has_record))
|
| 316 |
+
if len(ret) == 1:
|
| 317 |
+
return ret[0]
|
| 318 |
+
else:
|
| 319 |
+
return ret
|
| 320 |
+
|
| 321 |
+
if os.path.exists(local_args_path):
|
| 322 |
+
try:
|
| 323 |
+
if hasattr(arg_cls, 'resume_from_checkpoint'):
|
| 324 |
+
try:
|
| 325 |
+
args = arg_cls(resume_from_checkpoint=model, load_data_args=True)
|
| 326 |
+
except Exception as e:
|
| 327 |
+
if 'using `--model`' in str(e): # TODO a dirty fix
|
| 328 |
+
args = arg_cls(model=model, load_data_args=True)
|
| 329 |
+
else:
|
| 330 |
+
raise e
|
| 331 |
+
else:
|
| 332 |
+
args = arg_cls(ckpt_dir=model, load_data_args=True)
|
| 333 |
+
except ValueError:
|
| 334 |
+
return [gr.update()] * (len(keys) + int(has_record))
|
| 335 |
+
values = []
|
| 336 |
+
for key in keys:
|
| 337 |
+
arg_value = getattr(args, key, None)
|
| 338 |
+
if arg_value and key != 'model':
|
| 339 |
+
if key in ('torch_dtype', 'bnb_4bit_compute_dtype'):
|
| 340 |
+
arg_value = str(arg_value).split('.')[1]
|
| 341 |
+
if isinstance(arg_value, list) and key != 'dataset':
|
| 342 |
+
try:
|
| 343 |
+
arg_value = ' '.join(arg_value)
|
| 344 |
+
except Exception:
|
| 345 |
+
arg_value = None
|
| 346 |
+
values.append(gr.update(value=arg_value))
|
| 347 |
+
else:
|
| 348 |
+
values.append(gr.update())
|
| 349 |
+
ret = [gr.update(choices=[])] * int(has_record) + values
|
| 350 |
+
if len(ret) == 1:
|
| 351 |
+
return ret[0]
|
| 352 |
+
else:
|
| 353 |
+
return ret
|
| 354 |
+
else:
|
| 355 |
+
values = []
|
| 356 |
+
for key in keys:
|
| 357 |
+
if key not in ('template', 'model_type', 'ref_model_type', 'system'):
|
| 358 |
+
values.append(gr.update())
|
| 359 |
+
elif key in ('template', 'model_type', 'ref_model_type'):
|
| 360 |
+
if key == 'ref_model_type':
|
| 361 |
+
if is_ref_model:
|
| 362 |
+
values.append(gr.update(value=getattr(model_meta, 'model_type')))
|
| 363 |
+
else:
|
| 364 |
+
values.append(gr.update())
|
| 365 |
+
else:
|
| 366 |
+
values.append(gr.update(value=getattr(model_meta, key)))
|
| 367 |
+
else:
|
| 368 |
+
values.append(gr.update(value=TEMPLATE_MAPPING[model_meta.template].default_system))
|
| 369 |
+
|
| 370 |
+
if has_record:
|
| 371 |
+
return [gr.update(choices=cls.list_cache(model))] + values
|
| 372 |
+
else:
|
| 373 |
+
if len(values) == 1:
|
| 374 |
+
return values[0]
|
| 375 |
+
return values
|
| 376 |
+
|
| 377 |
+
@classmethod
|
| 378 |
+
def update_all_settings(cls, model, train_record, base_tab):
|
| 379 |
+
if not train_record:
|
| 380 |
+
return [gr.update()] * len(cls.elements())
|
| 381 |
+
cache = cls.load_cache(model, train_record)
|
| 382 |
+
updates = []
|
| 383 |
+
for key, value in base_tab.valid_elements().items():
|
| 384 |
+
if key in cache:
|
| 385 |
+
updates.append(gr.update(value=cache[key]))
|
| 386 |
+
else:
|
| 387 |
+
updates.append(gr.update())
|
| 388 |
+
return updates
|
ms-swift/swift/ui/llm_eval/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
ms-swift/swift/ui/llm_eval/eval.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Type
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
|
| 6 |
+
from swift.ui.base import BaseUI
|
| 7 |
+
from swift.utils import get_logger
|
| 8 |
+
|
| 9 |
+
logger = get_logger()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Eval(BaseUI):
    """UI tab for configuring an evaluation run (backend, datasets, limits, URLs)."""

    group = 'llm_eval'

    # elem_id -> {'label'/'info': {'zh': ..., 'en': ...}}
    locale_dict = {
        'eval_backend': {
            'label': {
                'zh': '评测后端',
                'en': 'Eval backend'
            },
            'info': {
                'zh': '选择评测后端',
                'en': 'Select eval backend'
            }
        },
        'eval_dataset': {
            'label': {
                'zh': '评测数据集',
                'en': 'Evaluation dataset'
            },
            'info': {
                'zh': '选择评测数据集,支持多选 (先选择评测后端)',
                'en': 'Select eval dataset, multiple datasets supported (select eval backend first)'
            }
        },
        'eval_limit': {
            'label': {
                'zh': '评测数据个数',
                'en': 'Eval numbers for each dataset'
            },
            'info': {
                'zh': '每个评测集的取样数',
                'en': 'Number of rows sampled from each dataset'
            }
        },
        'eval_output_dir': {
            'label': {
                'zh': '评测输出目录',
                'en': 'Eval output dir'
            },
            'info': {
                'zh': '评测结果的输出目录',
                'en': 'The dir to save the eval results'
            }
        },
        'custom_eval_config': {
            'label': {
                'zh': '自定义数据集评测配置',
                'en': 'Custom eval config'
            },
            'info': {
                'zh': '可以使用该配置评测自己的数据集,详见github文档的评测部分',
                'en': 'Use this config to eval your own datasets, check the docs in github for details'
            }
        },
        'eval_url': {
            'label': {
                'zh': '评测链接',
                'en': 'The eval url'
            },
            'info': {
                'zh':
                'OpenAI样式的评测链接(如:http://localhost:8080/v1/chat/completions),用于评测接口(模型类型输入为实际模型类型)',
                'en':
                'The OpenAI style link(like: http://localhost:8080/v1/chat/completions) for '
                'evaluation(Input actual model type into model_type)'
            }
        },
        'api_key': {
            'label': {
                'zh': '接口token',
                'en': 'The url token'
            },
            'info': {
                'zh': 'eval_url的token',
                'en': 'The token used with eval_url'
            }
        },
        'infer_backend': {
            'label': {
                'zh': '推理框架',
                'en': 'Infer backend'
            },
        }
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the eval tab widgets and wire backend -> dataset refresh."""
        try:
            from swift.llm.argument.eval_args import EvalArguments
            eval_dataset_dict = EvalArguments.list_eval_dataset()
            default_backend = EvalArguments.eval_backend
        except Exception as e:
            # evalscope (or another optional dependency) may be missing; the
            # tab still renders, just with empty choices.
            # Fix: logger.warn is deprecated, use logger.warning.
            logger.warning(e)
            eval_dataset_dict = {}
            default_backend = None

        with gr.Row():
            gr.Dropdown(elem_id='eval_backend', choices=list(eval_dataset_dict.keys()), value=default_backend, scale=20)
            gr.Dropdown(
                elem_id='eval_dataset',
                is_list=True,
                choices=eval_dataset_dict.get(default_backend, []),
                multiselect=True,
                allow_custom_value=True,
                scale=20)
            gr.Textbox(elem_id='eval_limit', scale=20)
            gr.Dropdown(elem_id='infer_backend', scale=20)
        with gr.Row():
            gr.Textbox(elem_id='custom_eval_config', scale=20)
            gr.Textbox(elem_id='eval_output_dir', scale=20)
            gr.Textbox(elem_id='eval_url', scale=20)
            gr.Textbox(elem_id='api_key', scale=20)

        def update_eval_dataset(backend):
            # Fix: use .get so an unknown/custom backend value does not raise
            # KeyError inside the gradio event handler.
            return gr.update(choices=eval_dataset_dict.get(backend, []))

        cls.element('eval_backend').change(update_eval_dataset, [cls.element('eval_backend')],
                                           [cls.element('eval_dataset')])
|
ms-swift/swift/ui/llm_eval/model.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from functools import partial
|
| 3 |
+
from typing import Type
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
from swift.llm import TEMPLATE_MAPPING, EvalArguments, ModelType
|
| 8 |
+
from swift.llm.model.register import get_all_models
|
| 9 |
+
from swift.ui.base import BaseUI
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Model(BaseUI):
    """Model-selection row for the eval tab: model id/path, model_type, template."""

    group = 'llm_eval'

    # elem_id -> {'label'/'info'/'value': {'zh': ..., 'en': ...}}
    locale_dict = {
        'checkpoint': {
            'value': {
                'zh': '训练后的模型',
                'en': 'Trained model'
            }
        },
        'model_type': {
            'label': {
                'zh': '选择模型类型',
                'en': 'Select Model Type'
            },
            'info': {
                'zh': 'SWIFT已支持的模型类型',
                'en': 'Base model type supported by SWIFT'
            }
        },
        'model': {
            'label': {
                'zh': '模型id或路径',
                'en': 'Model id or path'
            },
            'info': {
                'zh': '实际的模型id,如果是训练后的模型请填入checkpoint-xxx的目录',
                'en': 'The actual model id or path, if is a trained model, please fill in the checkpoint-xxx dir'
            }
        },
        'reset': {
            'value': {
                'zh': '恢复初始值',
                'en': 'Reset to default'
            },
        },
        'template': {
            'label': {
                'zh': '模型Prompt模板类型',
                'en': 'Prompt template type'
            },
            'info': {
                'zh': '选择匹配模型的Prompt模板',
                'en': 'Choose the template type of the model'
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Render the model / model_type / template dropdown row."""
        with gr.Row():
            gr.Dropdown(
                elem_id='model',
                scale=20,
                choices=get_all_models(),
                value='Qwen/Qwen2.5-7B-Instruct',
                allow_custom_value=True)
            gr.Dropdown(elem_id='model_type', choices=ModelType.get_model_name_list(), scale=20)
            gr.Dropdown(elem_id='template', choices=list(TEMPLATE_MAPPING.keys()), scale=20)

    @classmethod
    def after_build_ui(cls, base_tab: Type['BaseUI']):
        """Refresh dependent fields (model_type/template/...) whenever the model changes."""
        cls.element('model').change(
            partial(cls.update_input_model, arg_cls=EvalArguments, has_record=False),
            inputs=[cls.element('model')],
            outputs=list(cls.valid_elements().values()))
|
ms-swift/swift/ui/llm_export/llm_export.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import sys
|
| 5 |
+
import time
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from functools import partial
|
| 8 |
+
from typing import Type
|
| 9 |
+
|
| 10 |
+
import gradio as gr
|
| 11 |
+
import json
|
| 12 |
+
import torch
|
| 13 |
+
from json import JSONDecodeError
|
| 14 |
+
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
|
| 15 |
+
|
| 16 |
+
from swift.llm import ExportArguments
|
| 17 |
+
from swift.ui.base import BaseUI
|
| 18 |
+
from swift.ui.llm_export.export import Export
|
| 19 |
+
from swift.ui.llm_export.model import Model
|
| 20 |
+
from swift.ui.llm_export.runtime import ExportRuntime
|
| 21 |
+
from swift.utils import get_device_count
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class LLMExport(BaseUI):
    """Top-level export tab: composes Model/Export/ExportRuntime and launches `swift export`."""

    group = 'llm_export'

    sub_ui = [Model, Export, ExportRuntime]

    locale_dict = {
        'llm_export': {
            'label': {
                'zh': 'LLM导出',
                'en': 'LLM export',
            }
        },
        'more_params': {
            'label': {
                'zh': '更多参数',
                'en': 'More params'
            },
            'info': {
                'zh': '以json格式或--xxx xxx命令行格式填入',
                'en': 'Fill in with json format or --xxx xxx cmd format'
            }
        },
        'export': {
            'value': {
                'zh': '开始导出',
                'en': 'Begin Export'
            },
        },
        'gpu_id': {
            'label': {
                'zh': '选择可用GPU',
                'en': 'Choose GPU'
            },
            'info': {
                'zh': '选择使用的GPU号,如CUDA不可用只能选择CPU',
                'en': 'Select GPU to export'
            }
        },
    }

    choice_dict = BaseUI.get_choices_from_dataclass(ExportArguments)
    default_dict = BaseUI.get_default_value_from_dataclass(ExportArguments)
    arguments = BaseUI.get_argument_names(ExportArguments)

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the export tab and wire the export/kill/task-change events."""
        with gr.TabItem(elem_id='llm_export', label=''):
            default_device = 'cpu'
            device_count = get_device_count()
            if device_count > 0:
                default_device = '0'
            with gr.Blocks():
                Model.build_ui(base_tab)
                Export.build_ui(base_tab)
                ExportRuntime.build_ui(base_tab)
                with gr.Row():
                    gr.Textbox(elem_id='more_params', lines=4, scale=20)
                    gr.Button(elem_id='export', scale=2, variant='primary')
                    gr.Dropdown(
                        elem_id='gpu_id',
                        multiselect=True,
                        choices=[str(i) for i in range(device_count)] + ['cpu'],
                        value=default_device,
                        scale=8)

                cls.element('export').click(
                    cls.export_model, list(base_tab.valid_elements().values()),
                    [cls.element('runtime_tab'), cls.element('running_tasks')])

                base_tab.element('running_tasks').change(
                    partial(ExportRuntime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')],
                    list(base_tab.valid_elements().values()) + [cls.element('log')])
                ExportRuntime.element('kill_task').click(
                    ExportRuntime.kill_task,
                    [ExportRuntime.element('running_tasks')],
                    [ExportRuntime.element('running_tasks')] + [ExportRuntime.element('log')],
                )

    @classmethod
    def export(cls, *args):
        """Translate UI values into a `swift export` shell command.

        `args` arrives in the order of `cls.valid_element_keys()`.  Values that
        differ from the ExportArguments defaults become `--key value` flags;
        `more_params` is merged in as JSON or appended verbatim as raw CLI text.

        Returns:
            (run_command, export_args, log_file) — the shell command to run,
            the parsed ExportArguments, and the log file path.
        """
        default_args = cls.get_default_value_from_dataclass(ExportArguments)
        kwargs = {}
        kwargs_is_list = {}
        other_kwargs = {}
        more_params = {}
        more_params_cmd = ''
        keys = cls.valid_element_keys()
        for key, value in zip(keys, args):
            compare_value = default_args.get(key)
            # Compare as strings so '3' (textbox) equals 3 (dataclass default).
            compare_value_arg = str(compare_value) if not isinstance(compare_value, (list, dict)) else compare_value
            compare_value_ui = str(value) if not isinstance(value, (list, dict)) else value
            if key in default_args and compare_value_ui != compare_value_arg and value:
                # Coerce textbox strings back to their native types.
                if isinstance(value, str) and re.fullmatch(cls.int_regex, value):
                    value = int(value)
                elif isinstance(value, str) and re.fullmatch(cls.float_regex, value):
                    value = float(value)
                elif isinstance(value, str) and re.fullmatch(cls.bool_regex, value):
                    value = value.lower() == 'true'
                kwargs[key] = value if not isinstance(value, list) else ' '.join(value)
                kwargs_is_list[key] = isinstance(value, list) or getattr(cls.element(key), 'is_list', False)
            else:
                other_kwargs[key] = value
            if key == 'more_params' and value:
                try:
                    more_params = json.loads(value)
                except (JSONDecodeError, TypeError):
                    # Fix: the original `except (JSONDecodeError or TypeError)`
                    # only caught JSONDecodeError. Non-JSON input is treated as
                    # raw command-line text.
                    more_params_cmd = value

        kwargs.update(more_params)
        model = kwargs.get('model')
        # A local dir containing args.json is a trained checkpoint, not a model id.
        if os.path.exists(model) and os.path.exists(os.path.join(model, 'args.json')):
            kwargs['ckpt_dir'] = kwargs.pop('model')
        export_args = ExportArguments(
            **{
                key: value.split(' ') if key in kwargs_is_list and kwargs_is_list[key] else value
                for key, value in kwargs.items()
            })
        params = ''
        sep = f'{cls.quote} {cls.quote}'
        for e in kwargs:
            if isinstance(kwargs[e], list):
                params += f'--{e} {cls.quote}{sep.join(kwargs[e])}{cls.quote} '
            elif e in kwargs_is_list and kwargs_is_list[e]:
                all_args = [arg for arg in kwargs[e].split(' ') if arg.strip()]
                params += f'--{e} {cls.quote}{sep.join(all_args)}{cls.quote} '
            else:
                params += f'--{e} {cls.quote}{kwargs[e]}{cls.quote} '
        params += more_params_cmd + ' '
        devices = other_kwargs['gpu_id']
        devices = [d for d in devices if d]
        # 'cpu' cannot be combined with GPU ids in the visible-devices env var.
        assert (len(devices) == 1 or 'cpu' not in devices)
        gpus = ','.join(devices)
        cuda_param = ''
        if gpus != 'cpu':
            if is_torch_npu_available():
                cuda_param = f'ASCEND_RT_VISIBLE_DEVICES={gpus}'
            elif is_torch_cuda_available():
                cuda_param = f'CUDA_VISIBLE_DEVICES={gpus}'
            else:
                cuda_param = ''
        # Fix: zero-padded strftime avoids ambiguous/colliding dir names
        # (e.g. month=1,day=11 vs month=11,day=1 under the old f-string).
        time_str = datetime.now().strftime('%Y%m%d%H%M%S')
        file_path = f'output/{export_args.model_type}-{time_str}'
        os.makedirs(file_path, exist_ok=True)
        log_file = os.path.join(os.getcwd(), f'{file_path}/run_export.log')
        export_args.log_file = log_file
        params += f'--log_file "{log_file}" '
        params += '--ignore_args_error true '
        additional_param = ''
        if export_args.quant_method == 'gptq':
            # gptq quantization benefits from bounded OMP threading.
            additional_param = 'OMP_NUM_THREADS=14'
        if sys.platform == 'win32':
            if cuda_param:
                cuda_param = f'set {cuda_param} && '
            if additional_param:
                additional_param = f'set {additional_param} && '
            run_command = f'{cuda_param}{additional_param}start /b swift export {params} > {log_file} 2>&1'
        else:
            run_command = f'{cuda_param} {additional_param} nohup swift export {params} > {log_file} 2>&1 &'
        return run_command, export_args, log_file

    @classmethod
    def export_model(cls, *args):
        """Launch the export command in the background and open the runtime tab."""
        run_command, export_args, log_file = cls.export(*args)
        os.system(run_command)
        # Give the background process a moment to create its log file.
        time.sleep(2)
        return gr.update(open=True), ExportRuntime.refresh_tasks(log_file)
|
ms-swift/swift/ui/llm_export/model.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from functools import partial
|
| 3 |
+
from typing import Type
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
from swift.llm import TEMPLATE_MAPPING, ExportArguments, ModelType
|
| 8 |
+
from swift.llm.model.register import get_all_models
|
| 9 |
+
from swift.ui.base import BaseUI
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Model(BaseUI):
    """Model-selection row for the export tab: model id/path, model_type, template."""

    group = 'llm_export'

    # elem_id -> {'label'/'info'/'value': {'zh': ..., 'en': ...}}
    locale_dict = {
        'checkpoint': {
            'value': {
                'zh': '训练后的模型',
                'en': 'Trained model'
            }
        },
        'model_type': {
            'label': {
                'zh': '选择模型类型',
                'en': 'Select Model Type'
            },
            'info': {
                'zh': 'SWIFT已支持的模型类型',
                'en': 'Base model type supported by SWIFT'
            }
        },
        'model': {
            'label': {
                'zh': '模型id或路径',
                'en': 'Model id or path'
            },
            'info': {
                'zh': '实际的模型id,如果是训练后的模型请填入checkpoint-xxx的目录',
                'en': 'The actual model id or path, if is a trained model, please fill in the checkpoint-xxx dir'
            }
        },
        'reset': {
            'value': {
                'zh': '恢复初始值',
                'en': 'Reset to default'
            },
        },
        'template': {
            'label': {
                'zh': '模型Prompt模板类型',
                'en': 'Prompt template type'
            },
            'info': {
                'zh': '选择匹配模型的Prompt模板',
                'en': 'Choose the template type of the model'
            }
        },
    }

    # Substrings marking already-quantized model ids, which cannot be exported again.
    ignored_models = ['int1', 'int2', 'int4', 'int8', 'awq', 'gptq', 'bnb', 'eetq', 'aqlm', 'hqq']

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Render the model / model_type / template dropdown row (quantized models filtered out)."""
        with gr.Row():
            all_models = [
                model for model in get_all_models() if not any([ignored in model for ignored in cls.ignored_models])
            ]
            gr.Dropdown(
                elem_id='model',
                scale=20,
                choices=all_models,
                value='Qwen/Qwen2.5-7B-Instruct',
                allow_custom_value=True)
            gr.Dropdown(elem_id='model_type', choices=ModelType.get_model_name_list(), scale=20)
            gr.Dropdown(elem_id='template', choices=list(TEMPLATE_MAPPING.keys()), scale=20)

    @classmethod
    def after_build_ui(cls, base_tab: Type['BaseUI']):
        """Refresh dependent fields (model_type/template/...) whenever the model changes."""
        cls.element('model').change(
            partial(cls.update_input_model, arg_cls=ExportArguments, has_record=False),
            inputs=[cls.element('model')],
            outputs=list(cls.valid_elements().values()))
|
ms-swift/swift/ui/llm_export/runtime.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from swift.ui.llm_infer.runtime import Runtime
|
| 3 |
+
from swift.utils import get_logger
|
| 4 |
+
|
| 5 |
+
logger = get_logger()
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ExportRuntime(Runtime):
|
| 9 |
+
|
| 10 |
+
group = 'llm_export'
|
| 11 |
+
|
| 12 |
+
cmd = 'export'
|
| 13 |
+
|
| 14 |
+
locale_dict = {
|
| 15 |
+
'runtime_tab': {
|
| 16 |
+
'label': {
|
| 17 |
+
'zh': '运行时',
|
| 18 |
+
'en': 'Runtime'
|
| 19 |
+
},
|
| 20 |
+
},
|
| 21 |
+
'running_cmd': {
|
| 22 |
+
'label': {
|
| 23 |
+
'zh': '运行命令',
|
| 24 |
+
'en': 'Command line'
|
| 25 |
+
},
|
| 26 |
+
'info': {
|
| 27 |
+
'zh': '执行的实际命令',
|
| 28 |
+
'en': 'The actual command'
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
'show_log': {
|
| 32 |
+
'value': {
|
| 33 |
+
'zh': '展示导出状态',
|
| 34 |
+
'en': 'Show export status'
|
| 35 |
+
},
|
| 36 |
+
},
|
| 37 |
+
'stop_show_log': {
|
| 38 |
+
'value': {
|
| 39 |
+
'zh': '停止展示',
|
| 40 |
+
'en': 'Stop showing running status'
|
| 41 |
+
},
|
| 42 |
+
},
|
| 43 |
+
'log': {
|
| 44 |
+
'label': {
|
| 45 |
+
'zh': '日志输出',
|
| 46 |
+
'en': 'Logging content'
|
| 47 |
+
},
|
| 48 |
+
'info': {
|
| 49 |
+
'zh': '如果日志无更新请再次点击"展示日志内容"',
|
| 50 |
+
'en': 'Please press "Show log" if the log content is not updating'
|
| 51 |
+
}
|
| 52 |
+
},
|
| 53 |
+
'running_tasks': {
|
| 54 |
+
'label': {
|
| 55 |
+
'zh': '运行中导出任务',
|
| 56 |
+
'en': 'Running export task'
|
| 57 |
+
},
|
| 58 |
+
'info': {
|
| 59 |
+
'zh': '所有的swift export命令启动的任务',
|
| 60 |
+
'en': 'All tasks started by swift export'
|
| 61 |
+
}
|
| 62 |
+
},
|
| 63 |
+
'refresh_tasks': {
|
| 64 |
+
'value': {
|
| 65 |
+
'zh': '找回导出任务',
|
| 66 |
+
'en': 'Find export'
|
| 67 |
+
},
|
| 68 |
+
},
|
| 69 |
+
'kill_task': {
|
| 70 |
+
'value': {
|
| 71 |
+
'zh': '杀死导出任务',
|
| 72 |
+
'en': 'Kill export'
|
| 73 |
+
},
|
| 74 |
+
},
|
| 75 |
+
}
|
ms-swift/swift/ui/llm_infer/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
ms-swift/swift/ui/llm_infer/generate.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Type
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
|
| 6 |
+
from swift.ui.base import BaseUI
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Generate(BaseUI):
|
| 10 |
+
|
| 11 |
+
group = 'llm_infer'
|
| 12 |
+
|
| 13 |
+
locale_dict = {
|
| 14 |
+
'max_new_tokens': {
|
| 15 |
+
'label': {
|
| 16 |
+
'zh': '生成序列最大长度',
|
| 17 |
+
'en': 'Max new tokens'
|
| 18 |
+
},
|
| 19 |
+
},
|
| 20 |
+
'temperature': {
|
| 21 |
+
'label': {
|
| 22 |
+
'zh': 'temperature',
|
| 23 |
+
'en': 'temperature'
|
| 24 |
+
},
|
| 25 |
+
},
|
| 26 |
+
'top_k': {
|
| 27 |
+
'label': {
|
| 28 |
+
'zh': 'top_k',
|
| 29 |
+
'en': 'top_k'
|
| 30 |
+
},
|
| 31 |
+
},
|
| 32 |
+
'top_p': {
|
| 33 |
+
'label': {
|
| 34 |
+
'zh': 'top_p',
|
| 35 |
+
'en': 'top_p'
|
| 36 |
+
},
|
| 37 |
+
},
|
| 38 |
+
'repetition_penalty': {
|
| 39 |
+
'label': {
|
| 40 |
+
'zh': 'repetition_penalty',
|
| 41 |
+
'en': 'repetition_penalty'
|
| 42 |
+
},
|
| 43 |
+
},
|
| 44 |
+
'system': {
|
| 45 |
+
'label': {
|
| 46 |
+
'zh': 'system字段',
|
| 47 |
+
'en': 'system'
|
| 48 |
+
},
|
| 49 |
+
'info': {
|
| 50 |
+
'zh': 'system字段支持在加载模型后修改',
|
| 51 |
+
'en': 'system can be modified after the model weights loaded'
|
| 52 |
+
}
|
| 53 |
+
},
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
@classmethod
|
| 57 |
+
def do_build_ui(cls, base_tab: Type['BaseUI']):
|
| 58 |
+
with gr.Row():
|
| 59 |
+
gr.Textbox(elem_id='max_new_tokens', lines=1, value='2048')
|
| 60 |
+
gr.Slider(elem_id='temperature', minimum=0.0, maximum=10, step=0.1, value=0.3)
|
| 61 |
+
gr.Slider(elem_id='top_k', minimum=1, maximum=100, step=5, value=20)
|
| 62 |
+
gr.Slider(elem_id='top_p', minimum=0.0, maximum=1.0, step=0.05, value=0.7)
|
| 63 |
+
gr.Slider(elem_id='repetition_penalty', minimum=0.0, maximum=10, step=0.05, value=1.05)
|
| 64 |
+
with gr.Row():
|
| 65 |
+
gr.Textbox(elem_id='system', lines=4, scale=20)
|
ms-swift/swift/ui/llm_infer/llm_infer.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import signal
|
| 5 |
+
import sys
|
| 6 |
+
import time
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from functools import partial
|
| 10 |
+
from typing import List, Type
|
| 11 |
+
|
| 12 |
+
import gradio as gr
|
| 13 |
+
import json
|
| 14 |
+
import torch
|
| 15 |
+
from json import JSONDecodeError
|
| 16 |
+
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
|
| 17 |
+
|
| 18 |
+
from swift.llm import DeployArguments, InferArguments, InferClient, InferRequest, RequestConfig
|
| 19 |
+
from swift.ui.base import BaseUI
|
| 20 |
+
from swift.ui.llm_infer.model import Model
|
| 21 |
+
from swift.ui.llm_infer.runtime import Runtime
|
| 22 |
+
from swift.utils import get_device_count, get_logger
|
| 23 |
+
|
| 24 |
+
logger = get_logger()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class LLMInfer(BaseUI):
|
| 28 |
+
|
| 29 |
+
group = 'llm_infer'
|
| 30 |
+
|
| 31 |
+
is_multimodal = True
|
| 32 |
+
|
| 33 |
+
sub_ui = [Model, Runtime]
|
| 34 |
+
|
| 35 |
+
locale_dict = {
|
| 36 |
+
'generate_alert': {
|
| 37 |
+
'value': {
|
| 38 |
+
'zh': '请先部署模型',
|
| 39 |
+
'en': 'Please deploy model first',
|
| 40 |
+
}
|
| 41 |
+
},
|
| 42 |
+
'port': {
|
| 43 |
+
'label': {
|
| 44 |
+
'zh': '端口',
|
| 45 |
+
'en': 'port'
|
| 46 |
+
},
|
| 47 |
+
},
|
| 48 |
+
'llm_infer': {
|
| 49 |
+
'label': {
|
| 50 |
+
'zh': 'LLM推理',
|
| 51 |
+
'en': 'LLM Inference',
|
| 52 |
+
}
|
| 53 |
+
},
|
| 54 |
+
'load_alert': {
|
| 55 |
+
'value': {
|
| 56 |
+
'zh': '部署中,请点击"展示部署状态"查看',
|
| 57 |
+
'en': 'Start to deploy model, '
|
| 58 |
+
'please Click "Show running '
|
| 59 |
+
'status" to view details',
|
| 60 |
+
}
|
| 61 |
+
},
|
| 62 |
+
'loaded_alert': {
|
| 63 |
+
'value': {
|
| 64 |
+
'zh': '模型加载完成',
|
| 65 |
+
'en': 'Model loaded'
|
| 66 |
+
}
|
| 67 |
+
},
|
| 68 |
+
'port_alert': {
|
| 69 |
+
'value': {
|
| 70 |
+
'zh': '该端口已被占用',
|
| 71 |
+
'en': 'The port has been occupied'
|
| 72 |
+
}
|
| 73 |
+
},
|
| 74 |
+
'chatbot': {
|
| 75 |
+
'value': {
|
| 76 |
+
'zh': '对话框',
|
| 77 |
+
'en': 'Chat bot'
|
| 78 |
+
},
|
| 79 |
+
},
|
| 80 |
+
'infer_model_type': {
|
| 81 |
+
'label': {
|
| 82 |
+
'zh': 'Lora模块',
|
| 83 |
+
'en': 'Lora module'
|
| 84 |
+
},
|
| 85 |
+
'info': {
|
| 86 |
+
'zh': '发送给server端哪个LoRA,默认为`default`',
|
| 87 |
+
'en': 'Which LoRA to use on server, default value is `default`'
|
| 88 |
+
}
|
| 89 |
+
},
|
| 90 |
+
'prompt': {
|
| 91 |
+
'label': {
|
| 92 |
+
'zh': '请输入:',
|
| 93 |
+
'en': 'Input:'
|
| 94 |
+
},
|
| 95 |
+
},
|
| 96 |
+
'clear_history': {
|
| 97 |
+
'value': {
|
| 98 |
+
'zh': '清除对话信息',
|
| 99 |
+
'en': 'Clear history'
|
| 100 |
+
},
|
| 101 |
+
},
|
| 102 |
+
'submit': {
|
| 103 |
+
'value': {
|
| 104 |
+
'zh': '🚀 发送',
|
| 105 |
+
'en': '🚀 Send'
|
| 106 |
+
},
|
| 107 |
+
},
|
| 108 |
+
'gpu_id': {
|
| 109 |
+
'label': {
|
| 110 |
+
'zh': '选择可用GPU',
|
| 111 |
+
'en': 'Choose GPU'
|
| 112 |
+
},
|
| 113 |
+
'info': {
|
| 114 |
+
'zh': '选择训练使用的GPU号,如CUDA不可用只能选择CPU',
|
| 115 |
+
'en': 'Select GPU to train'
|
| 116 |
+
}
|
| 117 |
+
},
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
choice_dict = BaseUI.get_choices_from_dataclass(InferArguments)
|
| 121 |
+
default_dict = BaseUI.get_default_value_from_dataclass(InferArguments)
|
| 122 |
+
arguments = BaseUI.get_argument_names(InferArguments)
|
| 123 |
+
|
| 124 |
+
@classmethod
|
| 125 |
+
def do_build_ui(cls, base_tab: Type['BaseUI']):
|
| 126 |
+
with gr.TabItem(elem_id='llm_infer', label=''):
|
| 127 |
+
default_device = 'cpu'
|
| 128 |
+
device_count = get_device_count()
|
| 129 |
+
if device_count > 0:
|
| 130 |
+
default_device = '0'
|
| 131 |
+
with gr.Blocks():
|
| 132 |
+
infer_request = gr.State(None)
|
| 133 |
+
Model.build_ui(base_tab)
|
| 134 |
+
Runtime.build_ui(base_tab)
|
| 135 |
+
with gr.Row():
|
| 136 |
+
gr.Dropdown(
|
| 137 |
+
elem_id='gpu_id',
|
| 138 |
+
multiselect=True,
|
| 139 |
+
choices=[str(i) for i in range(device_count)] + ['cpu'],
|
| 140 |
+
value=default_device,
|
| 141 |
+
scale=8)
|
| 142 |
+
infer_model_type = gr.Textbox(elem_id='infer_model_type', scale=4)
|
| 143 |
+
gr.Textbox(elem_id='port', lines=1, value='8000', scale=4)
|
| 144 |
+
chatbot = gr.Chatbot(elem_id='chatbot', elem_classes='control-height')
|
| 145 |
+
with gr.Row():
|
| 146 |
+
prompt = gr.Textbox(elem_id='prompt', lines=1, interactive=True)
|
| 147 |
+
with gr.Tabs(visible=cls.is_multimodal):
|
| 148 |
+
with gr.TabItem(label='Image'):
|
| 149 |
+
image = gr.Image(type='filepath')
|
| 150 |
+
with gr.TabItem(label='Video'):
|
| 151 |
+
video = gr.Video()
|
| 152 |
+
with gr.TabItem(label='Audio'):
|
| 153 |
+
audio = gr.Audio(type='filepath')
|
| 154 |
+
|
| 155 |
+
with gr.Row():
|
| 156 |
+
clear_history = gr.Button(elem_id='clear_history')
|
| 157 |
+
submit = gr.Button(elem_id='submit')
|
| 158 |
+
|
| 159 |
+
cls.element('load_checkpoint').click(
|
| 160 |
+
cls.deploy_model, list(base_tab.valid_elements().values()),
|
| 161 |
+
[cls.element('runtime_tab'), cls.element('running_tasks')])
|
| 162 |
+
submit.click(
|
| 163 |
+
cls.send_message,
|
| 164 |
+
inputs=[
|
| 165 |
+
cls.element('running_tasks'),
|
| 166 |
+
cls.element('template'), prompt, image, video, audio, infer_request, infer_model_type,
|
| 167 |
+
cls.element('system'),
|
| 168 |
+
cls.element('max_new_tokens'),
|
| 169 |
+
cls.element('temperature'),
|
| 170 |
+
cls.element('top_k'),
|
| 171 |
+
cls.element('top_p'),
|
| 172 |
+
cls.element('repetition_penalty')
|
| 173 |
+
],
|
| 174 |
+
outputs=[prompt, chatbot, image, video, audio, infer_request],
|
| 175 |
+
queue=True)
|
| 176 |
+
|
| 177 |
+
clear_history.click(
|
| 178 |
+
fn=cls.clear_session, inputs=[], outputs=[prompt, chatbot, image, video, audio, infer_request])
|
| 179 |
+
|
| 180 |
+
base_tab.element('running_tasks').change(
|
| 181 |
+
partial(Runtime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')],
|
| 182 |
+
list(cls.valid_elements().values()) + [cls.element('log')])
|
| 183 |
+
Runtime.element('kill_task').click(
|
| 184 |
+
Runtime.kill_task,
|
| 185 |
+
[Runtime.element('running_tasks')],
|
| 186 |
+
[Runtime.element('running_tasks')] + [Runtime.element('log')],
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
@classmethod
|
| 190 |
+
def deploy(cls, *args):
|
| 191 |
+
deploy_args = cls.get_default_value_from_dataclass(DeployArguments)
|
| 192 |
+
kwargs = {}
|
| 193 |
+
kwargs_is_list = {}
|
| 194 |
+
other_kwargs = {}
|
| 195 |
+
more_params = {}
|
| 196 |
+
more_params_cmd = ''
|
| 197 |
+
keys = cls.valid_element_keys()
|
| 198 |
+
for key, value in zip(keys, args):
|
| 199 |
+
compare_value = deploy_args.get(key)
|
| 200 |
+
compare_value_arg = str(compare_value) if not isinstance(compare_value, (list, dict)) else compare_value
|
| 201 |
+
compare_value_ui = str(value) if not isinstance(value, (list, dict)) else value
|
| 202 |
+
if key in deploy_args and compare_value_ui != compare_value_arg and value:
|
| 203 |
+
if isinstance(value, str) and re.fullmatch(cls.int_regex, value):
|
| 204 |
+
value = int(value)
|
| 205 |
+
elif isinstance(value, str) and re.fullmatch(cls.float_regex, value):
|
| 206 |
+
value = float(value)
|
| 207 |
+
elif isinstance(value, str) and re.fullmatch(cls.bool_regex, value):
|
| 208 |
+
value = True if value.lower() == 'true' else False
|
| 209 |
+
kwargs[key] = value if not isinstance(value, list) else ' '.join(value)
|
| 210 |
+
kwargs_is_list[key] = isinstance(value, list) or getattr(cls.element(key), 'is_list', False)
|
| 211 |
+
else:
|
| 212 |
+
other_kwargs[key] = value
|
| 213 |
+
if key == 'more_params' and value:
|
| 214 |
+
try:
|
| 215 |
+
more_params = json.loads(value)
|
| 216 |
+
except (JSONDecodeError or TypeError):
|
| 217 |
+
more_params_cmd = value
|
| 218 |
+
|
| 219 |
+
kwargs.update(more_params)
|
| 220 |
+
model = kwargs.get('model')
|
| 221 |
+
if os.path.exists(model) and os.path.exists(os.path.join(model, 'args.json')):
|
| 222 |
+
kwargs['ckpt_dir'] = kwargs.pop('model')
|
| 223 |
+
with open(os.path.join(kwargs['ckpt_dir'], 'args.json'), 'r', encoding='utf-8') as f:
|
| 224 |
+
_json = json.load(f)
|
| 225 |
+
kwargs['model_type'] = _json['model_type']
|
| 226 |
+
kwargs['train_type'] = _json['train_type']
|
| 227 |
+
deploy_args = DeployArguments(
|
| 228 |
+
**{
|
| 229 |
+
key: value.split(' ') if key in kwargs_is_list and kwargs_is_list[key] else value
|
| 230 |
+
for key, value in kwargs.items()
|
| 231 |
+
})
|
| 232 |
+
if deploy_args.port in Runtime.get_all_ports():
|
| 233 |
+
raise gr.Error(cls.locale('port_alert', cls.lang)['value'])
|
| 234 |
+
params = ''
|
| 235 |
+
sep = f'{cls.quote} {cls.quote}'
|
| 236 |
+
for e in kwargs:
|
| 237 |
+
if isinstance(kwargs[e], list):
|
| 238 |
+
params += f'--{e} {cls.quote}{sep.join(kwargs[e])}{cls.quote} '
|
| 239 |
+
elif e in kwargs_is_list and kwargs_is_list[e]:
|
| 240 |
+
all_args = [arg for arg in kwargs[e].split(' ') if arg.strip()]
|
| 241 |
+
params += f'--{e} {cls.quote}{sep.join(all_args)}{cls.quote} '
|
| 242 |
+
else:
|
| 243 |
+
params += f'--{e} {cls.quote}{kwargs[e]}{cls.quote} '
|
| 244 |
+
if 'port' not in kwargs:
|
| 245 |
+
params += f'--port "{deploy_args.port}" '
|
| 246 |
+
params += more_params_cmd + ' '
|
| 247 |
+
devices = other_kwargs['gpu_id']
|
| 248 |
+
devices = [d for d in devices if d]
|
| 249 |
+
assert (len(devices) == 1 or 'cpu' not in devices)
|
| 250 |
+
gpus = ','.join(devices)
|
| 251 |
+
cuda_param = ''
|
| 252 |
+
if gpus != 'cpu':
|
| 253 |
+
if is_torch_npu_available():
|
| 254 |
+
cuda_param = f'ASCEND_RT_VISIBLE_DEVICES={gpus}'
|
| 255 |
+
elif is_torch_cuda_available():
|
| 256 |
+
cuda_param = f'CUDA_VISIBLE_DEVICES={gpus}'
|
| 257 |
+
else:
|
| 258 |
+
cuda_param = ''
|
| 259 |
+
now = datetime.now()
|
| 260 |
+
time_str = f'{now.year}{now.month}{now.day}{now.hour}{now.minute}{now.second}'
|
| 261 |
+
file_path = f'output/{deploy_args.model_type}-{time_str}'
|
| 262 |
+
if not os.path.exists(file_path):
|
| 263 |
+
os.makedirs(file_path, exist_ok=True)
|
| 264 |
+
log_file = os.path.join(os.getcwd(), f'{file_path}/run_deploy.log')
|
| 265 |
+
deploy_args.log_file = log_file
|
| 266 |
+
params += f'--log_file "{log_file}" '
|
| 267 |
+
params += '--ignore_args_error true '
|
| 268 |
+
if sys.platform == 'win32':
|
| 269 |
+
if cuda_param:
|
| 270 |
+
cuda_param = f'set {cuda_param} && '
|
| 271 |
+
run_command = f'{cuda_param}start /b swift deploy {params} > {log_file} 2>&1'
|
| 272 |
+
else:
|
| 273 |
+
run_command = f'{cuda_param} nohup swift deploy {params} > {log_file} 2>&1 &'
|
| 274 |
+
return run_command, deploy_args, log_file
|
| 275 |
+
|
| 276 |
+
@classmethod
|
| 277 |
+
def deploy_model(cls, *args):
|
| 278 |
+
run_command, deploy_args, log_file = cls.deploy(*args)
|
| 279 |
+
logger.info(f'Running deployment command: {run_command}')
|
| 280 |
+
os.system(run_command)
|
| 281 |
+
gr.Info(cls.locale('load_alert', cls.lang)['value'])
|
| 282 |
+
time.sleep(2)
|
| 283 |
+
running_task = Runtime.refresh_tasks(log_file)
|
| 284 |
+
return gr.update(open=True), running_task
|
| 285 |
+
|
| 286 |
+
@classmethod
|
| 287 |
+
def register_clean_hook(cls):
|
| 288 |
+
signal.signal(signal.SIGINT, LLMInfer.signal_handler)
|
| 289 |
+
if os.name != 'nt':
|
| 290 |
+
signal.signal(signal.SIGTERM, LLMInfer.signal_handler)
|
| 291 |
+
|
| 292 |
+
@staticmethod
|
| 293 |
+
def signal_handler(*args, **kwargs):
|
| 294 |
+
LLMInfer.clean_deployment()
|
| 295 |
+
sys.exit(0)
|
| 296 |
+
|
| 297 |
+
@classmethod
|
| 298 |
+
def clear_session(cls):
|
| 299 |
+
return '', [], gr.update(value=None), gr.update(value=None), gr.update(value=None), []
|
| 300 |
+
|
| 301 |
+
@classmethod
|
| 302 |
+
def _replace_tag_with_media(cls, infer_request: InferRequest):
|
| 303 |
+
total_history = []
|
| 304 |
+
messages = deepcopy(infer_request.messages)
|
| 305 |
+
if messages[0]['role'] == 'system':
|
| 306 |
+
messages.pop(0)
|
| 307 |
+
for i in range(0, len(messages), 2):
|
| 308 |
+
slices = messages[i:i + 2]
|
| 309 |
+
if len(slices) == 2:
|
| 310 |
+
user, assistant = slices
|
| 311 |
+
else:
|
| 312 |
+
user = slices[0]
|
| 313 |
+
assistant = {'role': 'assistant', 'content': None}
|
| 314 |
+
user['content'] = (user['content'] or '').replace('<image>', '').replace('<video>',
|
| 315 |
+
'').replace('<audio>', '').strip()
|
| 316 |
+
for media in user['medias']:
|
| 317 |
+
total_history.append([(media, ), None])
|
| 318 |
+
if user['content'] or assistant['content']:
|
| 319 |
+
total_history.append((user['content'], assistant['content']))
|
| 320 |
+
return total_history
|
| 321 |
+
|
| 322 |
+
@classmethod
|
| 323 |
+
def agent_type(cls, response):
|
| 324 |
+
if not response:
|
| 325 |
+
return None
|
| 326 |
+
if response.lower().endswith('observation:'):
|
| 327 |
+
return 'react'
|
| 328 |
+
if 'observation:' not in response.lower() and 'action input:' in response.lower():
|
| 329 |
+
return 'toolbench'
|
| 330 |
+
return None
|
| 331 |
+
|
| 332 |
+
@classmethod
|
| 333 |
+
def send_message(cls, running_task, template_type, prompt: str, image, video, audio, infer_request: InferRequest,
|
| 334 |
+
infer_model_type, system, max_new_tokens, temperature, top_k, top_p, repetition_penalty):
|
| 335 |
+
|
| 336 |
+
if not infer_request:
|
| 337 |
+
infer_request = InferRequest(messages=[])
|
| 338 |
+
if system:
|
| 339 |
+
if not infer_request.messages or infer_request.messages[0]['role'] != 'system':
|
| 340 |
+
infer_request.messages.insert(0, {'role': 'system', 'content': system})
|
| 341 |
+
else:
|
| 342 |
+
infer_request.messages[0]['content'] = system
|
| 343 |
+
if not infer_request.messages or infer_request.messages[-1]['role'] != 'user':
|
| 344 |
+
infer_request.messages.append({'role': 'user', 'content': '', 'medias': []})
|
| 345 |
+
media = image or video or audio
|
| 346 |
+
media_type = 'images' if image else 'videos' if video else 'audios'
|
| 347 |
+
if media:
|
| 348 |
+
_saved_medias: List = getattr(infer_request, media_type)
|
| 349 |
+
if not _saved_medias or _saved_medias[-1] != media:
|
| 350 |
+
_saved_medias.append(media)
|
| 351 |
+
infer_request.messages[-1]['content'] = infer_request.messages[-1]['content'] + f'<{media_type[:-1]}>'
|
| 352 |
+
infer_request.messages[-1]['medias'].append(media)
|
| 353 |
+
|
| 354 |
+
if not prompt:
|
| 355 |
+
yield '', cls._replace_tag_with_media(infer_request), gr.update(value=None), gr.update(
|
| 356 |
+
value=None), gr.update(value=None), infer_request
|
| 357 |
+
return
|
| 358 |
+
else:
|
| 359 |
+
infer_request.messages[-1]['content'] = infer_request.messages[-1]['content'] + prompt
|
| 360 |
+
|
| 361 |
+
_, args = Runtime.parse_info_from_cmdline(running_task)
|
| 362 |
+
request_config = RequestConfig(
|
| 363 |
+
temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty)
|
| 364 |
+
request_config.stream = True
|
| 365 |
+
request_config.stop = ['Observation:']
|
| 366 |
+
request_config.max_tokens = max_new_tokens
|
| 367 |
+
stream_resp_with_history = ''
|
| 368 |
+
response = ''
|
| 369 |
+
i = len(infer_request.messages) - 1
|
| 370 |
+
for i in range(len(infer_request.messages) - 1, -1, -1):
|
| 371 |
+
if infer_request.messages[i]['role'] == 'assistant':
|
| 372 |
+
response = infer_request.messages[i]['content']
|
| 373 |
+
agent_type = cls.agent_type(response)
|
| 374 |
+
if i != len(infer_request.messages) - 1 and agent_type == 'toolbench':
|
| 375 |
+
infer_request.messages[i + 1]['role'] = 'tool'
|
| 376 |
+
|
| 377 |
+
chat = not template_type.endswith('generation')
|
| 378 |
+
_infer_request = deepcopy(infer_request)
|
| 379 |
+
for m in _infer_request.messages:
|
| 380 |
+
if 'medias' in m:
|
| 381 |
+
m.pop('medias')
|
| 382 |
+
model_kwargs = {}
|
| 383 |
+
if infer_model_type:
|
| 384 |
+
model_kwargs = {'model': infer_model_type}
|
| 385 |
+
gen_list = InferClient(
|
| 386 |
+
port=args['port'], ).infer(
|
| 387 |
+
infer_requests=[_infer_request], request_config=request_config, **model_kwargs)
|
| 388 |
+
if infer_request.messages[-1]['role'] != 'assistant':
|
| 389 |
+
infer_request.messages.append({'role': 'assistant', 'content': ''})
|
| 390 |
+
for chunk in gen_list[0]:
|
| 391 |
+
if chunk is None:
|
| 392 |
+
continue
|
| 393 |
+
stream_resp_with_history += chunk.choices[0].delta.content if chat else chunk.choices[0].text
|
| 394 |
+
infer_request.messages[-1]['content'] = stream_resp_with_history
|
| 395 |
+
yield '', cls._replace_tag_with_media(infer_request), gr.update(value=None), gr.update(
|
| 396 |
+
value=None), gr.update(value=None), infer_request
|
ms-swift/swift/ui/llm_infer/model.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from functools import partial
|
| 3 |
+
from typing import Type
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
from swift.llm import TEMPLATE_MAPPING, DeployArguments, ModelType
|
| 8 |
+
from swift.llm.model.register import get_all_models
|
| 9 |
+
from swift.ui.base import BaseUI
|
| 10 |
+
from swift.ui.llm_infer.generate import Generate
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Model(BaseUI):
|
| 14 |
+
|
| 15 |
+
llm_train = 'llm_infer'
|
| 16 |
+
|
| 17 |
+
sub_ui = [Generate]
|
| 18 |
+
|
| 19 |
+
locale_dict = {
|
| 20 |
+
'model_type': {
|
| 21 |
+
'label': {
|
| 22 |
+
'zh': '选择模型类型',
|
| 23 |
+
'en': 'Select Model Type'
|
| 24 |
+
},
|
| 25 |
+
'info': {
|
| 26 |
+
'zh': 'SWIFT已支持的模型类型',
|
| 27 |
+
'en': 'Base model type supported by SWIFT'
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
'load_checkpoint': {
|
| 31 |
+
'value': {
|
| 32 |
+
'zh': '部署模型',
|
| 33 |
+
'en': 'Deploy model',
|
| 34 |
+
}
|
| 35 |
+
},
|
| 36 |
+
'model': {
|
| 37 |
+
'label': {
|
| 38 |
+
'zh': '模型id或路径',
|
| 39 |
+
'en': 'Model id or path'
|
| 40 |
+
},
|
| 41 |
+
'info': {
|
| 42 |
+
'zh': '实际的模型id,如果是训练后的模型请填入checkpoint-xxx的目录',
|
| 43 |
+
'en': 'The actual model id or path, if is a trained model, please fill in the checkpoint-xxx dir'
|
| 44 |
+
}
|
| 45 |
+
},
|
| 46 |
+
'template': {
|
| 47 |
+
'label': {
|
| 48 |
+
'zh': '模型Prompt模板类型',
|
| 49 |
+
'en': 'Prompt template type'
|
| 50 |
+
},
|
| 51 |
+
'info': {
|
| 52 |
+
'zh': '选择匹配模型的Prompt模板',
|
| 53 |
+
'en': 'Choose the template type of the model'
|
| 54 |
+
}
|
| 55 |
+
},
|
| 56 |
+
'merge_lora': {
|
| 57 |
+
'label': {
|
| 58 |
+
'zh': '合并lora',
|
| 59 |
+
'en': 'merge lora'
|
| 60 |
+
},
|
| 61 |
+
'info': {
|
| 62 |
+
'zh': '仅在sft_type=lora时可用',
|
| 63 |
+
'en': 'Only available when sft_type=lora'
|
| 64 |
+
}
|
| 65 |
+
},
|
| 66 |
+
'lora_modules': {
|
| 67 |
+
'label': {
|
| 68 |
+
'zh': '外部lora模块',
|
| 69 |
+
'en': 'More lora modules'
|
| 70 |
+
},
|
| 71 |
+
'info': {
|
| 72 |
+
'zh': '空格分割的name=/path1/path2键值对',
|
| 73 |
+
'en': 'name=/path1/path2 split by blanks'
|
| 74 |
+
}
|
| 75 |
+
},
|
| 76 |
+
'more_params': {
|
| 77 |
+
'label': {
|
| 78 |
+
'zh': '更多参数',
|
| 79 |
+
'en': 'More params'
|
| 80 |
+
},
|
| 81 |
+
'info': {
|
| 82 |
+
'zh': '以json格式或--xxx xxx命令行格式填入',
|
| 83 |
+
'en': 'Fill in with json format or --xxx xxx cmd format'
|
| 84 |
+
}
|
| 85 |
+
},
|
| 86 |
+
'reset': {
|
| 87 |
+
'value': {
|
| 88 |
+
'zh': '恢复初始值',
|
| 89 |
+
'en': 'Reset to default'
|
| 90 |
+
},
|
| 91 |
+
},
|
| 92 |
+
'infer_backend': {
|
| 93 |
+
'label': {
|
| 94 |
+
'zh': '推理框架',
|
| 95 |
+
'en': 'Infer backend'
|
| 96 |
+
},
|
| 97 |
+
},
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
@classmethod
|
| 101 |
+
def do_build_ui(cls, base_tab: Type['BaseUI']):
|
| 102 |
+
with gr.Row():
|
| 103 |
+
gr.Dropdown(
|
| 104 |
+
elem_id='model',
|
| 105 |
+
scale=20,
|
| 106 |
+
choices=get_all_models(),
|
| 107 |
+
value='Qwen/Qwen2.5-7B-Instruct',
|
| 108 |
+
allow_custom_value=True)
|
| 109 |
+
gr.Dropdown(elem_id='model_type', choices=ModelType.get_model_name_list(), scale=20)
|
| 110 |
+
gr.Dropdown(elem_id='template', choices=list(TEMPLATE_MAPPING.keys()), scale=20)
|
| 111 |
+
gr.Checkbox(elem_id='merge_lora', scale=4)
|
| 112 |
+
gr.Button(elem_id='reset', scale=2)
|
| 113 |
+
with gr.Row():
|
| 114 |
+
gr.Dropdown(elem_id='infer_backend', value='pt', scale=5)
|
| 115 |
+
Generate.build_ui(base_tab)
|
| 116 |
+
with gr.Row():
|
| 117 |
+
gr.Textbox(elem_id='lora_modules', lines=1, is_list=True, scale=40)
|
| 118 |
+
gr.Textbox(elem_id='more_params', lines=1, scale=20)
|
| 119 |
+
gr.Button(elem_id='load_checkpoint', scale=2, variant='primary')
|
| 120 |
+
|
| 121 |
+
@classmethod
|
| 122 |
+
def after_build_ui(cls, base_tab: Type['BaseUI']):
|
| 123 |
+
cls.element('model').change(
|
| 124 |
+
partial(cls.update_input_model, arg_cls=DeployArguments, has_record=False),
|
| 125 |
+
inputs=[cls.element('model')],
|
| 126 |
+
outputs=list(cls.valid_elements().values()))
|
ms-swift/swift/ui/llm_infer/runtime.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import collections
|
| 3 |
+
import os.path
|
| 4 |
+
import sys
|
| 5 |
+
import time
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from typing import Dict, List, Tuple, Type
|
| 8 |
+
|
| 9 |
+
import gradio as gr
|
| 10 |
+
import psutil
|
| 11 |
+
from packaging import version
|
| 12 |
+
|
| 13 |
+
from swift.ui.base import BaseUI
|
| 14 |
+
from swift.utils import get_logger
|
| 15 |
+
from swift.utils.utils import format_time
|
| 16 |
+
|
| 17 |
+
logger = get_logger()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Runtime(BaseUI):
    """Gradio tab for managing running ``swift deploy`` processes.

    Provides widgets to list live deployments (found by scanning the process
    table with psutil), stream their log files into a textbox, and kill them.
    ``do_build_ui`` builds the widgets; the classmethods below are the event
    handlers wired to them.
    """

    # elem_id -> (handler list, args tuple), managed by the BaseUI machinery.
    handlers: Dict[str, Tuple[List, Tuple]] = {}

    # Locale/element namespace shared with the rest of the inference tab.
    group = 'llm_infer'

    # The swift sub-command whose processes this tab manages.
    cmd = 'deploy'

    # Per-log-file stop flags: setting log_event[log_file] = True asks the
    # generator in `wait` tailing that file to terminate.
    log_event = {}

    # zh/en labels and tooltips, keyed by the elem_id of each widget below.
    locale_dict = {
        'runtime_tab': {
            'label': {
                'zh': '运行时',
                'en': 'Runtime'
            },
        },
        'running_cmd': {
            'label': {
                'zh': '运行命令',
                'en': 'Command line'
            },
            'info': {
                'zh': '执行的实际命令',
                'en': 'The actual command'
            }
        },
        'show_log': {
            'value': {
                'zh': '展示部署状态',
                'en': 'Show running status'
            },
        },
        'stop_show_log': {
            'value': {
                'zh': '停止展示',
                'en': 'Stop showing running status'
            },
        },
        'log': {
            'label': {
                'zh': '日志输出',
                'en': 'Logging content'
            },
            'info': {
                'zh': '如果日志无更新请再次点击"展示日志内容"',
                'en': 'Please press "Show log" if the log content is not updating'
            }
        },
        'running_tasks': {
            'label': {
                'zh': '运行中部署',
                'en': 'Running deployments'
            },
            'info': {
                'zh': '所有的swift deploy命令启动的任务',
                'en': 'Started by swift deploy'
            }
        },
        'refresh_tasks': {
            'value': {
                'zh': '找回部署',
                'en': 'Find deployments'
            },
        },
        'kill_task': {
            'value': {
                'zh': '杀死部署',
                'en': 'Kill running task'
            },
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the runtime accordion and wire its button callbacks."""
        with gr.Accordion(elem_id='runtime_tab', open=False, visible=True):
            with gr.Blocks():
                with gr.Row():
                    gr.Dropdown(elem_id='running_tasks', scale=10, allow_custom_value=True)
                    gr.Button(elem_id='refresh_tasks', scale=1, variant='primary')
                    gr.Button(elem_id='show_log', scale=1, variant='primary')
                    gr.Button(elem_id='stop_show_log', scale=1)
                    gr.Button(elem_id='kill_task', scale=1)
                with gr.Row():
                    # Hidden until "show_log" is clicked (see update_log).
                    gr.Textbox(elem_id='log', lines=6, visible=False)

        # gradio >= 4 renamed/introduced the per-event concurrency knob;
        # older versions reject the kwarg, so only pass it when supported.
        concurrency_limit = {}
        if version.parse(gr.__version__) >= version.parse('4.0.0'):
            concurrency_limit = {'concurrency_limit': 5}
        # First make the log box visible, then start streaming its content.
        base_tab.element('show_log').click(cls.update_log, [],
                                           [cls.element('log')]).then(cls.wait,
                                                                      [base_tab.element('running_tasks')],
                                                                      [cls.element('log')], **concurrency_limit)

        base_tab.element('stop_show_log').click(cls.break_log_event, [cls.element('running_tasks')], [])

        base_tab.element('refresh_tasks').click(
            cls.refresh_tasks,
            [base_tab.element('running_tasks')],
            [base_tab.element('running_tasks')],
        )

    @classmethod
    def break_log_event(cls, task):
        """Ask the log-tailing generator for *task*'s log file to stop."""
        if not task:
            return
        pid, all_args = cls.parse_info_from_cmdline(task)
        cls.log_event[all_args['log_file']] = True

    @classmethod
    def update_log(cls):
        """Make the log textbox visible before streaming starts."""
        return gr.update(visible=True)

    @classmethod
    def wait(cls, task):
        """Tail *task*'s log file, yielding the last N lines repeatedly.

        Generator used as a gradio streaming callback. Stops when the file
        produces no data for ~25s (50 polls of 0.5s), when `break_log_event`
        raises the stop flag, or on an I/O error opening/reading the file.
        The window size comes from the MAX_LOG_LINES env var (default 50).
        """
        if not task:
            return [None]
        _, args = cls.parse_info_from_cmdline(task)
        log_file = args['log_file']
        # Reset any stale stop request for this file.
        cls.log_event[log_file] = False
        offset = 0
        latest_data = ''
        lines = collections.deque(maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
        try:
            with open(log_file, 'r', encoding='utf-8') as input:
                input.seek(offset)
                fail_cnt = 0
                while True:
                    try:
                        latest_data += input.read()
                    except UnicodeDecodeError:
                        # NOTE(review): retries immediately without sleeping;
                        # a partially-written multi-byte char can spin briefly.
                        continue
                    if not latest_data:
                        time.sleep(0.5)
                        fail_cnt += 1
                        if fail_cnt > 50:
                            break

                    if cls.log_event.get(log_file, False):
                        cls.log_event[log_file] = False
                        break

                    if '\n' not in latest_data:
                        continue
                    latest_lines = latest_data.split('\n')
                    if latest_data[-1] != '\n':
                        # Keep the trailing partial line for the next read.
                        latest_data = latest_lines[-1]
                        latest_lines = latest_lines[:-1]
                    else:
                        latest_data = ''
                    lines.extend(latest_lines)
                    yield '\n'.join(lines)
        except IOError:
            pass

    @classmethod
    def get_all_ports(cls):
        """Return the set of ports used by running `swift deploy` processes.

        Falls back to 8000 when a process has no explicit --port argument.
        """
        process_name = 'swift'
        cmd_name = cls.cmd
        ports = set()
        for proc in psutil.process_iter():
            try:
                cmdlines = proc.cmdline()
            except (psutil.ZombieProcess, psutil.AccessDenied, psutil.NoSuchProcess):
                cmdlines = []
            if any([process_name in cmdline for cmdline in cmdlines]) and any(  # noqa
                    [cmd_name == cmdline for cmdline in cmdlines]):  # noqa
                try:
                    ports.add(int(cls.parse_info_from_cmdline(cls.construct_running_task(proc))[1].get('port', 8000)))
                except IndexError:
                    pass
        return ports

    @classmethod
    def refresh_tasks(cls, running_task=None):
        """Rescan the process table and refresh the deployments dropdown.

        Matches processes whose cmdline contains 'swift' (excluding the
        Windows 'swift.exe' shim) and the sub-command `cls.cmd`. Tries to keep
        the currently selected task selected; otherwise picks the first match.
        """
        # If running_task is a bare log file path (no 'pid:' prefix), use it
        # to re-identify the process by its --log_file argument.
        log_file = running_task if not running_task or 'pid:' not in running_task else None
        process_name = 'swift'
        negative_name = 'swift.exe'
        cmd_name = cls.cmd
        process = []
        selected = None
        for proc in psutil.process_iter():
            try:
                cmdlines = proc.cmdline()
            except (psutil.ZombieProcess, psutil.AccessDenied, psutil.NoSuchProcess):
                cmdlines = []
            if any([process_name in cmdline
                    for cmdline in cmdlines]) and not any([negative_name in cmdline
                                                           for cmdline in cmdlines]) and any(  # noqa
                                                               [cmd_name == cmdline for cmdline in cmdlines]):  # noqa
                process.append(cls.construct_running_task(proc))
                if log_file is not None and any(  # noqa
                        [log_file == cmdline for cmdline in cmdlines]):  # noqa
                    selected = cls.construct_running_task(proc)
        if not selected:
            if running_task and running_task in process:
                selected = running_task
        if not selected and process:
            selected = process[0]
        return gr.update(choices=process, value=selected)

    @staticmethod
    def construct_running_task(proc):
        """Format a psutil process as 'pid:../create:../running:../cmd:..'."""
        pid = proc.pid
        ts = time.time()
        create_time = proc.create_time()
        create_time_formatted = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d, %H:%M')

        return f'pid:{pid}/create:{create_time_formatted}' \
               f'/running:{format_time(ts - create_time)}/cmd:{" ".join(proc.cmdline())}'

    @classmethod
    def parse_info_from_cmdline(cls, task):
        """Parse a `construct_running_task` string back into (pid, args).

        Strips the three '/'-separated prefix fields (pid/create/running),
        then splits everything after 'swift <cmd>' into a {flag: value} dict
        from its '--flag value' pairs.
        """
        pid = None
        for i in range(3):
            slash = task.find('/')
            if i == 0:
                pid = task[:slash].split(':')[1]
            task = task[slash + 1:]
        args = task.split(f'swift {cls.cmd}')[1]
        args = [arg.strip() for arg in args.split('--') if arg.strip()]
        all_args = {}
        for i in range(len(args)):
            # Split each "flag value..." chunk at the first space only.
            space = args[i].find(' ')
            splits = args[i][:space], args[i][space + 1:]
            all_args[splits[0]] = splits[1]
        return pid, all_args

    @classmethod
    def kill_task(cls, task):
        """Kill the process behind *task* and reset the dropdown/log box."""
        if task:
            pid, all_args = cls.parse_info_from_cmdline(task)
            log_file = all_args['log_file']
            if sys.platform == 'win32':
                # /t kills the whole process tree on Windows.
                os.system(f'taskkill /f /t /pid "{pid}"')
            else:
                # Match by the unique log file path rather than the pid so
                # worker children started with the same cmdline die too.
                os.system(f'pkill -9 -f {log_file}')
            time.sleep(1)
            cls.break_log_event(task)
        return [cls.refresh_tasks()] + [gr.update(value=None)]

    @classmethod
    def task_changed(cls, task, base_tab):
        """Populate the form widgets from the selected task's cmdline args.

        Returns one gr.update per valid element of *base_tab*, plus a final
        update that clears the log textbox.
        """
        if task:
            _, all_args = cls.parse_info_from_cmdline(task)
        else:
            all_args = {}
        elements = list(base_tab.valid_elements().values())
        ret = []
        is_custom_path = 'ckpt_dir' in all_args
        for e in elements:
            if e.elem_id in all_args:
                if isinstance(e, gr.Dropdown) and e.multiselect:
                    # Multi-select dropdowns take a list of values.
                    arg = all_args[e.elem_id].split(' ')
                else:
                    if e.elem_id == 'model':
                        # Prefer an explicit checkpoint dir over the model id.
                        if is_custom_path:
                            arg = all_args['ckpt_dir']
                        else:
                            arg = all_args[e.elem_id]
                    else:
                        arg = all_args[e.elem_id]
                ret.append(gr.update(value=arg))
            else:
                ret.append(gr.update())
        cls.break_log_event(task)
        return ret + [gr.update(value=None)]
|
ms-swift/swift/ui/llm_train/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
ms-swift/swift/ui/llm_train/advanced.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Type
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
|
| 6 |
+
from swift.ui.base import BaseUI
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Advanced(BaseUI):
    """'Advanced settings' accordion of the LLM training tab.

    Declarative component: ``locale_dict`` holds zh/en labels and tooltips
    keyed by elem_id, and ``do_build_ui`` instantiates the matching widgets.
    """

    # Locale/element namespace shared with the rest of the training tab.
    group = 'llm_train'

    # zh/en labels and tooltips, keyed by the elem_id of each widget below.
    locale_dict = {
        'advanced_param': {
            'label': {
                'zh': '高级参数设置',
                'en': 'Advanced settings'
            },
        },
        'optim': {
            'label': {
                'zh': 'Optimizer类型',
                'en': 'The Optimizer type'
            },
            'info': {
                'zh': '设置Optimizer类型',
                'en': 'Set the Optimizer type'
            }
        },
        'weight_decay': {
            'label': {
                'zh': '权重衰减',
                'en': 'Weight decay'
            },
            'info': {
                'zh': '设置weight decay',
                'en': 'Set the weight decay'
            }
        },
        'logging_steps': {
            'label': {
                'zh': '日志打印步数',
                'en': 'Logging steps'
            },
            'info': {
                'zh': '设置日志打印的步数间隔',
                'en': 'Set the logging interval'
            }
        },
        'lr_scheduler_type': {
            'label': {
                'zh': 'LrScheduler类型',
                'en': 'The LrScheduler type'
            },
            'info': {
                'zh': '设置LrScheduler类型',
                'en': 'Set the LrScheduler type'
            }
        },
        'warmup_ratio': {
            'label': {
                'zh': '学习率warmup比例',
                'en': 'Lr warmup ratio'
            },
            'info': {
                'zh': '设置学习率warmup比例',
                'en': 'Set the warmup ratio in total steps'
            }
        },
        'more_params': {
            'label': {
                'zh': '其他高级参数',
                'en': 'Other params'
            },
            'info': {
                'zh': '以json格式或--xxx xxx命令行格式填入',
                'en': 'Fill in with json format or --xxx xxx cmd format'
            }
        },
        'truncation_strategy': {
            'label': {
                'zh': '数据集超长策略',
                'en': 'Dataset truncation strategy'
            },
            'info': {
                'zh': '如果token超长该如何处理',
                'en': 'How to deal with the rows exceed the max length'
            }
        },
        'max_steps': {
            'label': {
                'zh': '最大迭代步数',
                'en': 'Max steps',
            },
            'info': {
                'zh': '设置最大迭代步数,该值如果大于零则数据集迭代次数不生效',
                'en': 'Set the max steps, if the value > 0 then num_train_epochs has no effects',
            }
        },
        'per_device_eval_batch_size': {
            'label': {
                'zh': '验证batch size',
                'en': 'Val batch size',
            },
            'info': {
                'zh': '验证的batch size',
                'en': 'Set the val batch size',
            }
        },
        'max_grad_norm': {
            'label': {
                'zh': '梯度裁剪',
                'en': 'Max grad norm',
            },
            'info': {
                'zh': '设置梯度裁剪',
                'en': 'Set the max grad norm',
            }
        },
        'predict_with_generate': {
            'label': {
                'zh': '使用生成指标代替loss',
                'en': 'Use generate metric instead of loss',
            },
            'info': {
                'zh': '验证时使用generate/Rouge代替loss',
                'en': 'Use model.generate/Rouge instead of loss',
            }
        },
        'deepspeed': {
            'label': {
                'zh': 'deepspeed',
                'en': 'deepspeed',
            },
            'info': {
                'zh': '可以选择下拉列表,也支持传入路径',
                'en': 'Choose from the dropbox or fill in a valid path',
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the advanced-parameter widgets (labels from locale_dict)."""
        with gr.Accordion(elem_id='advanced_param', open=False):
            with gr.Blocks():
                with gr.Row():
                    gr.Textbox(elem_id='optim', lines=1, scale=20)
                    gr.Textbox(elem_id='weight_decay', lines=1, scale=20)
                    gr.Textbox(elem_id='logging_steps', lines=1, scale=20)
                    gr.Textbox(elem_id='lr_scheduler_type', lines=1, scale=20)
                    gr.Textbox(elem_id='max_steps', lines=1, scale=20)
                    gr.Slider(elem_id='warmup_ratio', minimum=0.0, maximum=1.0, step=0.05, scale=20)
                with gr.Row():
                    gr.Dropdown(elem_id='truncation_strategy', scale=20)
                    gr.Slider(elem_id='per_device_eval_batch_size', minimum=1, maximum=256, step=2, scale=20)
                    gr.Textbox(elem_id='max_grad_norm', lines=1, scale=20)
                    # Free-text deepspeed config path is allowed alongside
                    # the built-in zero presets.
                    gr.Dropdown(
                        elem_id='deepspeed',
                        scale=20,
                        allow_custom_value=True,
                        value=None,
                        choices=['zero0', 'zero1', 'zero2', 'zero3', 'zero2_offload', 'zero3_offload'])
                with gr.Row():
                    gr.Textbox(elem_id='more_params', lines=4, scale=20)
|
ms-swift/swift/ui/llm_train/dataset.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Type
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
|
| 6 |
+
from swift.llm.dataset.register import get_dataset_list
|
| 7 |
+
from swift.ui.base import BaseUI
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Dataset(BaseUI):
    """'Dataset settings' accordion of the LLM training tab.

    Declarative component: ``locale_dict`` holds zh/en labels and tooltips
    keyed by elem_id, and ``do_build_ui`` instantiates the matching widgets.
    """

    # Locale/element namespace shared with the rest of the training tab.
    group = 'llm_train'

    # zh/en labels and tooltips, keyed by the elem_id of each widget below.
    locale_dict = {
        'dataset': {
            'label': {
                'zh': '数据集名称',
                'en': 'Dataset Code'
            },
            'info': {
                'zh': '选择训练的数据集,支持复选/本地路径',
                'en': 'The dataset(s) to train the models, support multi select and local folder/files'
            }
        },
        'max_length': {
            'label': {
                'zh': '句子最大长度',
                'en': 'The max length',
            },
            'info': {
                'zh': '设置输入模型的最大长度',
                'en': 'Set the max length input to the model',
            }
        },
        'split_dataset_ratio': {
            'label': {
                'zh': '验证集拆分比例',
                'en': 'Split ratio of eval dataset'
            },
            'info': {
                'zh': '表示将总数据的多少拆分到验证集中',
                'en': 'Split the datasets by this ratio for eval'
            }
        },
        'train_dataset_sample': {
            'label': {
                'zh': '训练集采样数量',
                'en': 'The sample size from the train dataset'
            },
            'info': {
                'zh': '从训练集中采样一定行数进行训练',
                'en': 'Train with the sample size from the dataset',
            }
        },
        'val_dataset_sample': {
            'label': {
                'zh': '验证集采样数量',
                'en': 'The sample size from the val dataset'
            },
            'info': {
                'zh': '从验证集中采样一定行数进行训练',
                'en': 'Validate with the sample size from the dataset',
            }
        },
        'custom_dataset_info': {
            'label': {
                'zh': '外部数据集配置',
                'en': 'Custom dataset config'
            },
            'info': {
                'zh': '注册外部数据集的配置文件',
                'en': 'An extra dataset config to register your own datasets'
            }
        },
        'dataset_param': {
            'label': {
                'zh': '数据集设置',
                'en': 'Dataset settings'
            },
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the dataset-selection widgets (labels from locale_dict)."""
        with gr.Accordion(elem_id='dataset_param', open=True):
            with gr.Row():
                # Registered dataset names come from the swift registry;
                # custom values allow local files/folders.
                gr.Dropdown(
                    elem_id='dataset', multiselect=True, choices=get_dataset_list(), scale=20, allow_custom_value=True)
                gr.Textbox(elem_id='custom_dataset_info', is_list=False, scale=20)
                gr.Slider(elem_id='split_dataset_ratio', minimum=0.0, maximum=1.0, step=0.05, scale=10)
                gr.Slider(elem_id='max_length', minimum=32, maximum=32768, value=1024, step=1, scale=10)
|
ms-swift/swift/ui/llm_train/hyper.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Type
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
|
| 6 |
+
from swift.ui.base import BaseUI
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Hyper(BaseUI):
    """'Hyper settings' accordion of the LLM training tab.

    Holds the most common training hyper-parameters (batch size, lr, epochs,
    eval/save steps, ...); the long tail lives in the Advanced component.
    """

    # Locale/element namespace shared with the rest of the training tab.
    group = 'llm_train'

    # zh/en labels and tooltips, keyed by the elem_id of each widget below.
    locale_dict = {
        'hyper_param': {
            'label': {
                'zh': '超参数设置(更多参数->高级参数设置)',
                'en': 'Hyper settings(more params->Advanced settings)',
            },
        },
        'per_device_train_batch_size': {
            'label': {
                'zh': '训练batch size',
                'en': 'Train batch size',
            },
            'info': {
                'zh': '训练的batch size',
                'en': 'Set the train batch size',
            }
        },
        'learning_rate': {
            'label': {
                'zh': '学习率',
                'en': 'Learning rate',
            },
            'info': {
                'zh': '设置学习率',
                'en': 'Set the learning rate',
            }
        },
        'eval_steps': {
            'label': {
                'zh': '交叉验证步数',
                'en': 'Eval steps',
            },
            'info': {
                'zh': '设置每隔多少步数进行一次验证',
                'en': 'Set the step interval to validate',
            }
        },
        'num_train_epochs': {
            'label': {
                'zh': '数据集迭代轮次',
                'en': 'Train epoch',
            },
            'info': {
                'zh': '设置对数据集训练多少轮次',
                'en': 'Set the max train epoch',
            }
        },
        'gradient_accumulation_steps': {
            'label': {
                'zh': '梯度累计步数',
                'en': 'Gradient accumulation steps',
            },
            'info': {
                'zh': '设置梯度累计步数以减小显存占用',
                'en': 'Set the gradient accumulation steps',
            }
        },
        'attn_impl': {
            'label': {
                'zh': 'Flash Attention类型',
                'en': 'Flash Attention Type',
            },
        },
        'neftune_noise_alpha': {
            'label': {
                'zh': 'neftune_noise_alpha',
                'en': 'neftune_noise_alpha'
            },
            'info': {
                'zh': '使用neftune提升训练效果, 一般设置为5或者10',
                'en': 'Use neftune to improve performance, normally the value should be 5 or 10'
            }
        },
        'save_steps': {
            'label': {
                'zh': '存储步数',
                'en': 'save steps',
            },
            'info': {
                'zh': '设置每个多少步数进行存储',
                'en': 'Set the save steps',
            }
        },
        'output_dir': {
            'label': {
                'zh': '存储目录',
                'en': 'The output dir',
            },
            'info': {
                'zh': '设置输出模型存储在哪个文件夹下',
                'en': 'Set the output folder',
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the hyper-parameter widgets (labels from locale_dict)."""
        with gr.Accordion(elem_id='hyper_param', open=False):
            with gr.Blocks():
                with gr.Row():
                    gr.Slider(elem_id='per_device_train_batch_size', minimum=1, maximum=256, step=2, scale=20)
                    gr.Textbox(elem_id='learning_rate', value='1e-4', lines=1, scale=20)
                    gr.Textbox(elem_id='num_train_epochs', lines=1, scale=20)
                    gr.Dropdown(elem_id='attn_impl', scale=20, value='flash_attn')
                    gr.Slider(elem_id='gradient_accumulation_steps', minimum=1, maximum=256, step=2, value=16, scale=20)
                with gr.Row():
                    gr.Textbox(elem_id='eval_steps', lines=1, value='500', scale=20)
                    gr.Textbox(elem_id='save_steps', value='500', lines=1, scale=20)
                    gr.Textbox(elem_id='output_dir', scale=20)
                    gr.Slider(elem_id='neftune_noise_alpha', minimum=0.0, maximum=20.0, step=0.5, scale=20)

    @staticmethod
    def update_lr(sft_type):
        """Default learning rate for the chosen tuner: 1e-5 for full
        fine-tuning, 1e-4 for parameter-efficient tuners."""
        if sft_type == 'full':
            return 1e-5
        else:
            return 1e-4
|
ms-swift/swift/ui/llm_train/llamapro.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Type
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
|
| 6 |
+
from swift.ui.base import BaseUI
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class LlamaPro(BaseUI):
    """'LLAMAPRO settings' accordion of the LLM training tab.

    Exposes the two LLaMA-Pro tuner hyper-parameters as free-form textboxes;
    widget labels are resolved from ``locale_dict`` by the BaseUI machinery.
    """

    # Locale/element namespace shared with the rest of the training tab.
    group = 'llm_train'

    # elem_id -> {'label': {'zh': ..., 'en': ...}} lookup used by BaseUI.
    locale_dict = {
        'llamapro_tab': {
            'label': {'zh': 'LLAMAPRO参数设置', 'en': 'LLAMAPRO Settings'},
        },
        'llamapro_num_new_blocks': {
            'label': {'zh': 'LLAMAPRO插入层数', 'en': 'LLAMAPRO new layers'},
        },
        'llamapro_num_groups': {
            'label': {'zh': 'LLAMAPRO对原模型的分组数', 'en': 'LLAMAPRO groups of model'}
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Create the collapsed LLaMA-Pro accordion with its two textboxes."""
        with gr.Accordion(elem_id='llamapro_tab', open=False):
            with gr.Blocks(), gr.Row():
                gr.Textbox(elem_id='llamapro_num_new_blocks')
                gr.Textbox(elem_id='llamapro_num_groups')
|
ms-swift/swift/ui/llm_train/llm_train.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import collections
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import sys
|
| 6 |
+
import time
|
| 7 |
+
from functools import partial
|
| 8 |
+
from subprocess import PIPE, STDOUT, Popen
|
| 9 |
+
from typing import Dict, Type
|
| 10 |
+
|
| 11 |
+
import gradio as gr
|
| 12 |
+
import json
|
| 13 |
+
import torch
|
| 14 |
+
from json import JSONDecodeError
|
| 15 |
+
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
|
| 16 |
+
|
| 17 |
+
from swift.llm import RLHFArguments
|
| 18 |
+
from swift.llm.argument.base_args.base_args import get_supported_tuners
|
| 19 |
+
from swift.ui.base import BaseUI
|
| 20 |
+
from swift.ui.llm_train.advanced import Advanced
|
| 21 |
+
from swift.ui.llm_train.dataset import Dataset
|
| 22 |
+
from swift.ui.llm_train.galore import Galore
|
| 23 |
+
from swift.ui.llm_train.hyper import Hyper
|
| 24 |
+
from swift.ui.llm_train.lisa import Lisa
|
| 25 |
+
from swift.ui.llm_train.llamapro import LlamaPro
|
| 26 |
+
from swift.ui.llm_train.lora import LoRA
|
| 27 |
+
from swift.ui.llm_train.model import Model
|
| 28 |
+
from swift.ui.llm_train.quantization import Quantization
|
| 29 |
+
from swift.ui.llm_train.report_to import ReportTo
|
| 30 |
+
from swift.ui.llm_train.rlhf import RLHF
|
| 31 |
+
from swift.ui.llm_train.runtime import Runtime
|
| 32 |
+
from swift.ui.llm_train.save import Save
|
| 33 |
+
from swift.ui.llm_train.self_cog import SelfCog
|
| 34 |
+
from swift.utils import get_device_count, get_logger
|
| 35 |
+
|
| 36 |
+
logger = get_logger()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class LLMTrain(BaseUI):
|
| 40 |
+
group = 'llm_train'
|
| 41 |
+
|
| 42 |
+
sub_ui = [
|
| 43 |
+
Model,
|
| 44 |
+
Dataset,
|
| 45 |
+
Runtime,
|
| 46 |
+
Save,
|
| 47 |
+
LoRA,
|
| 48 |
+
Hyper,
|
| 49 |
+
Quantization,
|
| 50 |
+
SelfCog,
|
| 51 |
+
Advanced,
|
| 52 |
+
RLHF,
|
| 53 |
+
Lisa,
|
| 54 |
+
Galore,
|
| 55 |
+
LlamaPro,
|
| 56 |
+
ReportTo,
|
| 57 |
+
]
|
| 58 |
+
|
| 59 |
+
locale_dict: Dict[str, Dict] = {
|
| 60 |
+
'llm_train': {
|
| 61 |
+
'label': {
|
| 62 |
+
'zh': 'LLM训练',
|
| 63 |
+
'en': 'LLM Training',
|
| 64 |
+
}
|
| 65 |
+
},
|
| 66 |
+
'train_stage': {
|
| 67 |
+
'label': {
|
| 68 |
+
'zh': '训练Stage',
|
| 69 |
+
'en': 'Train Stage'
|
| 70 |
+
},
|
| 71 |
+
'info': {
|
| 72 |
+
'zh': '请注意选择与此匹配的数据集,人类对齐配置在页面下方',
|
| 73 |
+
'en': 'Please choose matched dataset, RLHF settings is at the bottom of the page'
|
| 74 |
+
}
|
| 75 |
+
},
|
| 76 |
+
'submit_alert': {
|
| 77 |
+
'value': {
|
| 78 |
+
'zh':
|
| 79 |
+
'任务已开始,请查看tensorboard或日志记录,关闭本页面不影响训练过程',
|
| 80 |
+
'en':
|
| 81 |
+
'Task started, please check the tensorboard or log file, '
|
| 82 |
+
'closing this page does not affect training'
|
| 83 |
+
}
|
| 84 |
+
},
|
| 85 |
+
'dataset_alert': {
|
| 86 |
+
'value': {
|
| 87 |
+
'zh': '请选择或填入一个数据集',
|
| 88 |
+
'en': 'Please input or select a dataset'
|
| 89 |
+
}
|
| 90 |
+
},
|
| 91 |
+
'submit': {
|
| 92 |
+
'value': {
|
| 93 |
+
'zh': '🚀 开始训练',
|
| 94 |
+
'en': '🚀 Begin'
|
| 95 |
+
}
|
| 96 |
+
},
|
| 97 |
+
'dry_run': {
|
| 98 |
+
'label': {
|
| 99 |
+
'zh': '仅生成运行命令',
|
| 100 |
+
'en': 'Dry-run'
|
| 101 |
+
},
|
| 102 |
+
'info': {
|
| 103 |
+
'zh': '仅生成运行命令,开发者自行运行',
|
| 104 |
+
'en': 'Generate run command only, for manually running'
|
| 105 |
+
}
|
| 106 |
+
},
|
| 107 |
+
'gpu_id': {
|
| 108 |
+
'label': {
|
| 109 |
+
'zh': '选择可用GPU',
|
| 110 |
+
'en': 'Choose GPU'
|
| 111 |
+
},
|
| 112 |
+
'info': {
|
| 113 |
+
'zh': '选择训练使用的GPU号,如CUDA不可用只能选择CPU',
|
| 114 |
+
'en': 'Select GPU to train'
|
| 115 |
+
}
|
| 116 |
+
},
|
| 117 |
+
'train_type': {
|
| 118 |
+
'label': {
|
| 119 |
+
'zh': '训练方式',
|
| 120 |
+
'en': 'Train type'
|
| 121 |
+
},
|
| 122 |
+
'info': {
|
| 123 |
+
'zh': '选择训练的方式',
|
| 124 |
+
'en': 'Select the training type'
|
| 125 |
+
}
|
| 126 |
+
},
|
| 127 |
+
'seed': {
|
| 128 |
+
'label': {
|
| 129 |
+
'zh': '随机数种子',
|
| 130 |
+
'en': 'Seed'
|
| 131 |
+
},
|
| 132 |
+
'info': {
|
| 133 |
+
'zh': '选择随机数种子',
|
| 134 |
+
'en': 'Select a random seed'
|
| 135 |
+
}
|
| 136 |
+
},
|
| 137 |
+
'torch_dtype': {
|
| 138 |
+
'label': {
|
| 139 |
+
'zh': '训练精度',
|
| 140 |
+
'en': 'Training Precision'
|
| 141 |
+
},
|
| 142 |
+
'info': {
|
| 143 |
+
'zh': '选择训练精度',
|
| 144 |
+
'en': 'Select the training precision'
|
| 145 |
+
}
|
| 146 |
+
},
|
| 147 |
+
'envs': {
|
| 148 |
+
'label': {
|
| 149 |
+
'zh': '环境变量',
|
| 150 |
+
'en': 'Extra env vars'
|
| 151 |
+
},
|
| 152 |
+
},
|
| 153 |
+
'use_ddp': {
|
| 154 |
+
'label': {
|
| 155 |
+
'zh': '使用DDP',
|
| 156 |
+
'en': 'Use DDP'
|
| 157 |
+
},
|
| 158 |
+
'info': {
|
| 159 |
+
'zh': '是否使用数据并行训练',
|
| 160 |
+
'en': 'Use Distributed Data Parallel to train'
|
| 161 |
+
}
|
| 162 |
+
},
|
| 163 |
+
'ddp_num': {
|
| 164 |
+
'label': {
|
| 165 |
+
'zh': 'DDP分片数量',
|
| 166 |
+
'en': 'Number of DDP sharding'
|
| 167 |
+
},
|
| 168 |
+
'info': {
|
| 169 |
+
'zh': '启用多少进程的数据并��',
|
| 170 |
+
'en': 'The data parallel size of DDP'
|
| 171 |
+
}
|
| 172 |
+
},
|
| 173 |
+
'tuner_backend': {
|
| 174 |
+
'label': {
|
| 175 |
+
'zh': 'Tuner backend',
|
| 176 |
+
'en': 'Tuner backend'
|
| 177 |
+
},
|
| 178 |
+
'info': {
|
| 179 |
+
'zh': 'tuner实现框架',
|
| 180 |
+
'en': 'The tuner backend'
|
| 181 |
+
}
|
| 182 |
+
},
|
| 183 |
+
'use_liger_kernel': {
|
| 184 |
+
'label': {
|
| 185 |
+
'zh': '使用Liger kernel',
|
| 186 |
+
'en': 'Use Liger kernel'
|
| 187 |
+
},
|
| 188 |
+
'info': {
|
| 189 |
+
'zh': 'Liger kernel可以有效降低显存使用',
|
| 190 |
+
'en': 'Liger kernel can reduce memory usage'
|
| 191 |
+
}
|
| 192 |
+
},
|
| 193 |
+
'train_param': {
|
| 194 |
+
'label': {
|
| 195 |
+
'zh': '训练参数设置',
|
| 196 |
+
'en': 'Train settings'
|
| 197 |
+
},
|
| 198 |
+
},
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
choice_dict = BaseUI.get_choices_from_dataclass(RLHFArguments)
|
| 202 |
+
default_dict = BaseUI.get_default_value_from_dataclass(RLHFArguments)
|
| 203 |
+
arguments = BaseUI.get_argument_names(RLHFArguments)
|
| 204 |
+
|
| 205 |
+
@classmethod
|
| 206 |
+
def do_build_ui(cls, base_tab: Type['BaseUI']):
|
| 207 |
+
with gr.TabItem(elem_id='llm_train', label=''):
|
| 208 |
+
default_device = 'cpu'
|
| 209 |
+
device_count = get_device_count()
|
| 210 |
+
if device_count > 0:
|
| 211 |
+
default_device = '0'
|
| 212 |
+
with gr.Blocks():
|
| 213 |
+
Model.build_ui(base_tab)
|
| 214 |
+
Dataset.build_ui(base_tab)
|
| 215 |
+
with gr.Accordion(elem_id='train_param', open=True):
|
| 216 |
+
with gr.Row():
|
| 217 |
+
gr.Dropdown(elem_id='train_stage', choices=['pt', 'sft', 'rlhf'], value='sft', scale=3)
|
| 218 |
+
gr.Dropdown(elem_id='train_type', scale=2, choices=list(get_supported_tuners()))
|
| 219 |
+
gr.Dropdown(elem_id='tuner_backend', scale=2)
|
| 220 |
+
with gr.Row():
|
| 221 |
+
gr.Textbox(elem_id='seed', scale=4)
|
| 222 |
+
gr.Dropdown(elem_id='torch_dtype', scale=4)
|
| 223 |
+
gr.Checkbox(elem_id='use_liger_kernel', scale=4)
|
| 224 |
+
gr.Checkbox(elem_id='use_ddp', value=False, scale=4)
|
| 225 |
+
gr.Textbox(elem_id='ddp_num', value='2', scale=4)
|
| 226 |
+
Hyper.build_ui(base_tab)
|
| 227 |
+
Runtime.build_ui(base_tab)
|
| 228 |
+
with gr.Row():
|
| 229 |
+
gr.Dropdown(
|
| 230 |
+
elem_id='gpu_id',
|
| 231 |
+
multiselect=True,
|
| 232 |
+
choices=[str(i) for i in range(device_count)] + ['cpu'],
|
| 233 |
+
value=default_device,
|
| 234 |
+
scale=8)
|
| 235 |
+
gr.Textbox(elem_id='envs', scale=8)
|
| 236 |
+
gr.Checkbox(elem_id='dry_run', value=False, scale=4)
|
| 237 |
+
submit = gr.Button(elem_id='submit', scale=4, variant='primary')
|
| 238 |
+
|
| 239 |
+
LoRA.build_ui(base_tab)
|
| 240 |
+
RLHF.build_ui(base_tab)
|
| 241 |
+
Quantization.build_ui(base_tab)
|
| 242 |
+
Galore.build_ui(base_tab)
|
| 243 |
+
Lisa.build_ui(base_tab)
|
| 244 |
+
LlamaPro.build_ui(base_tab)
|
| 245 |
+
SelfCog.build_ui(base_tab)
|
| 246 |
+
Save.build_ui(base_tab)
|
| 247 |
+
ReportTo.build_ui(base_tab)
|
| 248 |
+
Advanced.build_ui(base_tab)
|
| 249 |
+
|
| 250 |
+
cls.element('train_type').change(
|
| 251 |
+
Hyper.update_lr, inputs=[base_tab.element('train_type')], outputs=[cls.element('learning_rate')])
|
| 252 |
+
|
| 253 |
+
submit.click(
|
| 254 |
+
cls.train_local,
|
| 255 |
+
list(cls.valid_elements().values()), [
|
| 256 |
+
cls.element('running_cmd'),
|
| 257 |
+
cls.element('logging_dir'),
|
| 258 |
+
cls.element('runtime_tab'),
|
| 259 |
+
cls.element('running_tasks'),
|
| 260 |
+
cls.element('train_record'),
|
| 261 |
+
],
|
| 262 |
+
queue=True)
|
| 263 |
+
|
| 264 |
+
base_tab.element('running_tasks').change(
|
| 265 |
+
partial(Runtime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')],
|
| 266 |
+
list(base_tab.valid_elements().values()) + [cls.element('log')] + Runtime.all_plots)
|
| 267 |
+
Runtime.element('kill_task').click(
|
| 268 |
+
Runtime.kill_task,
|
| 269 |
+
[Runtime.element('running_tasks')],
|
| 270 |
+
[Runtime.element('running_tasks')] + [Runtime.element('log')] + Runtime.all_plots,
|
| 271 |
+
).then(Runtime.reset, [], [Runtime.element('logging_dir')] + [Hyper.element('output_dir')])
|
| 272 |
+
|
| 273 |
+
@classmethod
|
| 274 |
+
def update_runtime(cls):
|
| 275 |
+
return gr.update(open=True), gr.update(visible=True)
|
| 276 |
+
|
| 277 |
+
@classmethod
|
| 278 |
+
def train(cls, *args):
|
| 279 |
+
ignore_elements = ('logging_dir', 'more_params', 'train_stage', 'envs')
|
| 280 |
+
default_args = cls.get_default_value_from_dataclass(RLHFArguments)
|
| 281 |
+
kwargs = {}
|
| 282 |
+
kwargs_is_list = {}
|
| 283 |
+
other_kwargs = {}
|
| 284 |
+
more_params = {}
|
| 285 |
+
more_params_cmd = ''
|
| 286 |
+
keys = cls.valid_element_keys()
|
| 287 |
+
train_stage = 'sft'
|
| 288 |
+
for key, value in zip(keys, args):
|
| 289 |
+
compare_value = default_args.get(key)
|
| 290 |
+
if isinstance(value, str) and re.fullmatch(cls.int_regex, value):
|
| 291 |
+
value = int(value)
|
| 292 |
+
elif isinstance(value, str) and re.fullmatch(cls.float_regex, value):
|
| 293 |
+
value = float(value)
|
| 294 |
+
elif isinstance(value, str) and re.fullmatch(cls.bool_regex, value):
|
| 295 |
+
value = True if value.lower() == 'true' else False
|
| 296 |
+
if key not in ignore_elements and key in default_args and compare_value != value and value:
|
| 297 |
+
kwargs[key] = value if not isinstance(value, list) else ' '.join(value)
|
| 298 |
+
kwargs_is_list[key] = isinstance(value, list) or getattr(cls.element(key), 'is_list', False)
|
| 299 |
+
else:
|
| 300 |
+
other_kwargs[key] = value
|
| 301 |
+
if key == 'more_params' and value:
|
| 302 |
+
try:
|
| 303 |
+
more_params = json.loads(value)
|
| 304 |
+
except (JSONDecodeError or TypeError):
|
| 305 |
+
more_params_cmd = value
|
| 306 |
+
|
| 307 |
+
if key == 'train_stage':
|
| 308 |
+
train_stage = value
|
| 309 |
+
|
| 310 |
+
kwargs.update(more_params)
|
| 311 |
+
if 'dataset' not in kwargs and 'custom_train_dataset_path' not in kwargs:
|
| 312 |
+
raise gr.Error(cls.locale('dataset_alert', cls.lang)['value'])
|
| 313 |
+
|
| 314 |
+
model = kwargs.get('model')
|
| 315 |
+
if os.path.exists(model) and os.path.exists(os.path.join(model, 'args.json')):
|
| 316 |
+
kwargs['resume_from_checkpoint'] = kwargs.pop('model')
|
| 317 |
+
|
| 318 |
+
cmd = train_stage
|
| 319 |
+
if kwargs.get('deepspeed'):
|
| 320 |
+
more_params_cmd += f' --deepspeed {kwargs.pop("deepspeed")} '
|
| 321 |
+
try:
|
| 322 |
+
sft_args = RLHFArguments(
|
| 323 |
+
**{
|
| 324 |
+
key: value.split(' ') if kwargs_is_list.get(key, False) and isinstance(value, str) else value
|
| 325 |
+
for key, value in kwargs.items()
|
| 326 |
+
})
|
| 327 |
+
except Exception as e:
|
| 328 |
+
if 'using `--model`' in str(e): # TODO a dirty fix
|
| 329 |
+
kwargs['model'] = kwargs.pop('resume_from_checkpoint')
|
| 330 |
+
sft_args = RLHFArguments(
|
| 331 |
+
**{
|
| 332 |
+
key: value.split(' ') if kwargs_is_list.get(key, False) and isinstance(value, str) else value
|
| 333 |
+
for key, value in kwargs.items()
|
| 334 |
+
})
|
| 335 |
+
else:
|
| 336 |
+
raise e
|
| 337 |
+
params = ''
|
| 338 |
+
|
| 339 |
+
sep = f'{cls.quote} {cls.quote}'
|
| 340 |
+
for e in kwargs:
|
| 341 |
+
if isinstance(kwargs[e], list):
|
| 342 |
+
params += f'--{e} {cls.quote}{sep.join(kwargs[e])}{cls.quote} '
|
| 343 |
+
elif e in kwargs_is_list and kwargs_is_list[e]:
|
| 344 |
+
all_args = [arg for arg in kwargs[e].split(' ') if arg.strip()]
|
| 345 |
+
params += f'--{e} {cls.quote}{sep.join(all_args)}{cls.quote} '
|
| 346 |
+
else:
|
| 347 |
+
params += f'--{e} {cls.quote}{kwargs[e]}{cls.quote} '
|
| 348 |
+
params += more_params_cmd + ' '
|
| 349 |
+
params += f'--add_version False --output_dir {sft_args.output_dir} ' \
|
| 350 |
+
f'--logging_dir {sft_args.logging_dir} --ignore_args_error True'
|
| 351 |
+
ddp_param = ''
|
| 352 |
+
devices = other_kwargs['gpu_id']
|
| 353 |
+
envs = other_kwargs['envs'] or ''
|
| 354 |
+
envs = envs.strip()
|
| 355 |
+
devices = [d for d in devices if d]
|
| 356 |
+
if other_kwargs['use_ddp']:
|
| 357 |
+
assert int(other_kwargs['ddp_num']) > 0
|
| 358 |
+
ddp_param = f'NPROC_PER_NODE={int(other_kwargs["ddp_num"])}'
|
| 359 |
+
assert (len(devices) == 1 or 'cpu' not in devices)
|
| 360 |
+
gpus = ','.join(devices)
|
| 361 |
+
cuda_param = ''
|
| 362 |
+
if gpus != 'cpu':
|
| 363 |
+
if is_torch_npu_available():
|
| 364 |
+
cuda_param = f'ASCEND_RT_VISIBLE_DEVICES={gpus}'
|
| 365 |
+
elif is_torch_cuda_available():
|
| 366 |
+
cuda_param = f'CUDA_VISIBLE_DEVICES={gpus}'
|
| 367 |
+
else:
|
| 368 |
+
cuda_param = ''
|
| 369 |
+
|
| 370 |
+
log_file = os.path.join(sft_args.logging_dir, 'run.log')
|
| 371 |
+
if sys.platform == 'win32':
|
| 372 |
+
if cuda_param:
|
| 373 |
+
cuda_param = f'set {cuda_param} && '
|
| 374 |
+
if ddp_param:
|
| 375 |
+
ddp_param = f'set {ddp_param} && '
|
| 376 |
+
if envs:
|
| 377 |
+
envs = [env.strip() for env in envs.split(' ') if env.strip()]
|
| 378 |
+
_envs = ''
|
| 379 |
+
for env in envs:
|
| 380 |
+
_envs += f'set {env} && '
|
| 381 |
+
envs = _envs
|
| 382 |
+
run_command = f'{cuda_param}{ddp_param}{envs}start /b swift sft {params} > {log_file} 2>&1'
|
| 383 |
+
else:
|
| 384 |
+
run_command = f'{cuda_param} {ddp_param} {envs} nohup swift {cmd} {params} > {log_file} 2>&1 &'
|
| 385 |
+
logger.info(f'Run training: {run_command}')
|
| 386 |
+
if model:
|
| 387 |
+
record = {}
|
| 388 |
+
for key, value in zip(keys, args):
|
| 389 |
+
if key in default_args or key in ('more_params', 'train_stage', 'use_ddp', 'ddp_num', 'gpu_id', 'envs'):
|
| 390 |
+
record[key] = value or None
|
| 391 |
+
cls.save_cache(model, record)
|
| 392 |
+
return run_command, sft_args, other_kwargs
|
| 393 |
+
|
| 394 |
+
@classmethod
|
| 395 |
+
def train_studio(cls, *args):
|
| 396 |
+
run_command, sft_args, other_kwargs = cls.train(*args)
|
| 397 |
+
if not other_kwargs['dry_run']:
|
| 398 |
+
lines = collections.deque(maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
|
| 399 |
+
process = Popen(run_command, shell=True, stdout=PIPE, stderr=STDOUT)
|
| 400 |
+
with process.stdout:
|
| 401 |
+
for line in iter(process.stdout.readline, b''):
|
| 402 |
+
line = line.decode('utf-8')
|
| 403 |
+
lines.append(line)
|
| 404 |
+
yield ['\n'.join(lines)] + Runtime.plot(run_command) + [run_command]
|
| 405 |
+
else:
|
| 406 |
+
yield [
|
| 407 |
+
'Current is dryrun mode so you can only view the training cmd, please duplicate this space to '
|
| 408 |
+
'do training or use with inference.'
|
| 409 |
+
] + [None] * len(Runtime.sft_plot) + [run_command]
|
| 410 |
+
|
| 411 |
+
@classmethod
|
| 412 |
+
def train_local(cls, *args):
|
| 413 |
+
run_command, sft_args, other_kwargs = cls.train(*args)
|
| 414 |
+
if not other_kwargs['dry_run']:
|
| 415 |
+
os.makedirs(sft_args.logging_dir, exist_ok=True)
|
| 416 |
+
os.system(run_command)
|
| 417 |
+
time.sleep(1) # to make sure the log file has been created.
|
| 418 |
+
gr.Info(cls.locale('submit_alert', cls.lang)['value'])
|
| 419 |
+
return run_command, sft_args.logging_dir, gr.update(open=True), Runtime.refresh_tasks(
|
| 420 |
+
sft_args.output_dir), gr.update(choices=cls.list_cache(sft_args.model))
|
ms-swift/swift/ui/llm_train/lora.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Type
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
|
| 6 |
+
from swift.ui.base import BaseUI
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class LoRA(BaseUI):
|
| 10 |
+
|
| 11 |
+
group = 'llm_train'
|
| 12 |
+
|
| 13 |
+
locale_dict = {
|
| 14 |
+
'lora_tab': {
|
| 15 |
+
'label': {
|
| 16 |
+
'zh': 'LoRA参数设置',
|
| 17 |
+
'en': 'LoRA settings'
|
| 18 |
+
},
|
| 19 |
+
},
|
| 20 |
+
'target_modules': {
|
| 21 |
+
'label': {
|
| 22 |
+
'zh': 'LoRA目标模块',
|
| 23 |
+
'en': 'LoRA target modules'
|
| 24 |
+
},
|
| 25 |
+
'info': {
|
| 26 |
+
'zh': '设置LoRA目标模块,如训练所有Linear请改为`all-linear`',
|
| 27 |
+
'en': 'Set the LoRA target modules, fill in `all-linear` if train all Linears'
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
'lora_rank': {
|
| 31 |
+
'label': {
|
| 32 |
+
'zh': 'LoRA的秩',
|
| 33 |
+
'en': 'The LoRA rank'
|
| 34 |
+
}
|
| 35 |
+
},
|
| 36 |
+
'lora_alpha': {
|
| 37 |
+
'label': {
|
| 38 |
+
'zh': 'LoRA的alpha',
|
| 39 |
+
'en': 'The LoRA alpha'
|
| 40 |
+
}
|
| 41 |
+
},
|
| 42 |
+
'lora_dropout': {
|
| 43 |
+
'label': {
|
| 44 |
+
'zh': 'LoRA的dropout',
|
| 45 |
+
'en': 'The LoRA dropout'
|
| 46 |
+
}
|
| 47 |
+
},
|
| 48 |
+
'use_rslora': {
|
| 49 |
+
'label': {
|
| 50 |
+
'zh': '使用rslora',
|
| 51 |
+
'en': 'Use rslora'
|
| 52 |
+
}
|
| 53 |
+
},
|
| 54 |
+
'use_dora': {
|
| 55 |
+
'label': {
|
| 56 |
+
'zh': '使用dora',
|
| 57 |
+
'en': 'Use dora'
|
| 58 |
+
}
|
| 59 |
+
},
|
| 60 |
+
'lora_dtype': {
|
| 61 |
+
'label': {
|
| 62 |
+
'zh': 'lora部分的参数类型',
|
| 63 |
+
'en': 'The dtype of lora parameters'
|
| 64 |
+
}
|
| 65 |
+
},
|
| 66 |
+
'init_weights': {
|
| 67 |
+
'label': {
|
| 68 |
+
'zh': 'lora初始化方法',
|
| 69 |
+
'en': 'init lora weights'
|
| 70 |
+
},
|
| 71 |
+
'info': {
|
| 72 |
+
'zh': 'gaussian/pissa/pissa_niter_[n]/olora/loftq/true/false',
|
| 73 |
+
'en': 'gaussian/pissa/pissa_niter_[n]/olora/loftq/true/false',
|
| 74 |
+
}
|
| 75 |
+
},
|
| 76 |
+
'lorap_lr_ratio': {
|
| 77 |
+
'label': {
|
| 78 |
+
'zh': 'Lora+学习率倍率',
|
| 79 |
+
'en': 'The lr ratio of Lora+'
|
| 80 |
+
},
|
| 81 |
+
'info': {
|
| 82 |
+
'zh': '建议值16.0',
|
| 83 |
+
'en': 'Suggested value: 16.0'
|
| 84 |
+
}
|
| 85 |
+
},
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
@classmethod
|
| 89 |
+
def do_build_ui(cls, base_tab: Type['BaseUI']):
|
| 90 |
+
with gr.Accordion(elem_id='lora_tab', open=True):
|
| 91 |
+
with gr.Blocks():
|
| 92 |
+
with gr.Row():
|
| 93 |
+
gr.Textbox(elem_id='target_modules', lines=1, scale=5, value='all-linear', is_list=True)
|
| 94 |
+
gr.Slider(elem_id='lora_rank', value=8, minimum=1, maximum=512, step=8, scale=2)
|
| 95 |
+
gr.Slider(elem_id='lora_alpha', value=32, minimum=1, maximum=512, step=8, scale=2)
|
| 96 |
+
gr.Textbox(elem_id='lora_dropout', scale=2)
|
| 97 |
+
with gr.Row():
|
| 98 |
+
gr.Dropdown(elem_id='lora_dtype', scale=2, value=None)
|
| 99 |
+
gr.Textbox(elem_id='lorap_lr_ratio', scale=2)
|
| 100 |
+
gr.Checkbox(elem_id='use_rslora', scale=2)
|
| 101 |
+
gr.Checkbox(elem_id='use_dora', scale=2)
|
| 102 |
+
gr.Textbox(elem_id='init_weights', scale=4)
|
ms-swift/swift/ui/llm_train/model.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from functools import partial
|
| 3 |
+
from typing import Type
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
from swift.llm import TEMPLATE_MAPPING, ModelType, RLHFArguments
|
| 8 |
+
from swift.llm.model.register import get_all_models
|
| 9 |
+
from swift.ui.base import BaseUI
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Model(BaseUI):
|
| 13 |
+
group = 'llm_train'
|
| 14 |
+
|
| 15 |
+
locale_dict = {
|
| 16 |
+
'model_type': {
|
| 17 |
+
'label': {
|
| 18 |
+
'zh': '模型类型',
|
| 19 |
+
'en': 'Select Model Type'
|
| 20 |
+
},
|
| 21 |
+
'info': {
|
| 22 |
+
'zh': 'SWIFT已支持的模型类型',
|
| 23 |
+
'en': 'Base model type supported by SWIFT'
|
| 24 |
+
}
|
| 25 |
+
},
|
| 26 |
+
'model': {
|
| 27 |
+
'label': {
|
| 28 |
+
'zh': '模型id或路径',
|
| 29 |
+
'en': 'Model id or path'
|
| 30 |
+
},
|
| 31 |
+
'info': {
|
| 32 |
+
'zh': '实际的模型id',
|
| 33 |
+
'en': 'The actual model id or model path'
|
| 34 |
+
}
|
| 35 |
+
},
|
| 36 |
+
'template': {
|
| 37 |
+
'label': {
|
| 38 |
+
'zh': '模型Prompt模板类型',
|
| 39 |
+
'en': 'Prompt template type'
|
| 40 |
+
},
|
| 41 |
+
'info': {
|
| 42 |
+
'zh': '选择匹配模型的Prompt模板',
|
| 43 |
+
'en': 'Choose the template type of the model'
|
| 44 |
+
}
|
| 45 |
+
},
|
| 46 |
+
'system': {
|
| 47 |
+
'label': {
|
| 48 |
+
'zh': 'system字段',
|
| 49 |
+
'en': 'system'
|
| 50 |
+
},
|
| 51 |
+
'info': {
|
| 52 |
+
'zh': '选择system字段的内容',
|
| 53 |
+
'en': 'Choose the content of the system field'
|
| 54 |
+
}
|
| 55 |
+
},
|
| 56 |
+
'reset': {
|
| 57 |
+
'value': {
|
| 58 |
+
'zh': '恢复模型初始值',
|
| 59 |
+
'en': 'Reset model default'
|
| 60 |
+
},
|
| 61 |
+
},
|
| 62 |
+
'train_record': {
|
| 63 |
+
'label': {
|
| 64 |
+
'zh': '训练记录',
|
| 65 |
+
'en': 'Train record'
|
| 66 |
+
},
|
| 67 |
+
'info': {
|
| 68 |
+
'zh': '展示使用web-ui的历史训练及参数',
|
| 69 |
+
'en': 'Show the training history and parameters'
|
| 70 |
+
}
|
| 71 |
+
},
|
| 72 |
+
'clear_cache': {
|
| 73 |
+
'value': {
|
| 74 |
+
'zh': '删除训练记录',
|
| 75 |
+
'en': 'Delete train records'
|
| 76 |
+
},
|
| 77 |
+
},
|
| 78 |
+
'model_param': {
|
| 79 |
+
'label': {
|
| 80 |
+
'zh': '模型设置',
|
| 81 |
+
'en': 'Model settings'
|
| 82 |
+
},
|
| 83 |
+
},
|
| 84 |
+
'checkpoint': {
|
| 85 |
+
'value': {
|
| 86 |
+
'zh': '训练后的模型',
|
| 87 |
+
'en': 'Trained model'
|
| 88 |
+
}
|
| 89 |
+
},
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
@classmethod
|
| 93 |
+
def do_build_ui(cls, base_tab: Type['BaseUI']):
|
| 94 |
+
with gr.Accordion(elem_id='model_param', open=True):
|
| 95 |
+
with gr.Row():
|
| 96 |
+
model = gr.Dropdown(
|
| 97 |
+
elem_id='model',
|
| 98 |
+
scale=20,
|
| 99 |
+
choices=get_all_models(),
|
| 100 |
+
value='Qwen/Qwen2.5-7B-Instruct',
|
| 101 |
+
allow_custom_value=True)
|
| 102 |
+
gr.Dropdown(elem_id='model_type', choices=ModelType.get_model_name_list(), scale=20)
|
| 103 |
+
gr.Dropdown(elem_id='template', choices=list(TEMPLATE_MAPPING.keys()), scale=20)
|
| 104 |
+
train_record = gr.Dropdown(elem_id='train_record', choices=[], scale=20)
|
| 105 |
+
clear_cache = gr.Button(elem_id='clear_cache', scale=2)
|
| 106 |
+
with gr.Row():
|
| 107 |
+
gr.Textbox(elem_id='system', lines=1, scale=20)
|
| 108 |
+
|
| 109 |
+
def clear_record(model):
|
| 110 |
+
if model:
|
| 111 |
+
cls.clear_cache(model)
|
| 112 |
+
return gr.update(choices=[])
|
| 113 |
+
return gr.update()
|
| 114 |
+
|
| 115 |
+
clear_cache.click(clear_record, inputs=[model], outputs=[train_record])
|
| 116 |
+
|
| 117 |
+
@classmethod
|
| 118 |
+
def after_build_ui(cls, base_tab: Type['BaseUI']):
|
| 119 |
+
cls.element('model').change(
|
| 120 |
+
partial(base_tab.update_input_model, arg_cls=RLHFArguments),
|
| 121 |
+
inputs=[cls.element('model')],
|
| 122 |
+
outputs=[cls.element('train_record')] + list(base_tab.valid_elements().values()))
|
| 123 |
+
|
| 124 |
+
cls.element('train_record').change(
|
| 125 |
+
partial(base_tab.update_all_settings, base_tab=base_tab),
|
| 126 |
+
inputs=[cls.element('model'), cls.element('train_record')],
|
| 127 |
+
outputs=list(base_tab.valid_elements().values()))
|
ms-swift/swift/ui/llm_train/quantization.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Type
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
|
| 6 |
+
from swift.ui.base import BaseUI
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Quantization(BaseUI):
|
| 10 |
+
|
| 11 |
+
group = 'llm_train'
|
| 12 |
+
|
| 13 |
+
locale_dict = {
|
| 14 |
+
'quantization_tab': {
|
| 15 |
+
'label': {
|
| 16 |
+
'zh': '量化参数设置',
|
| 17 |
+
'en': 'Quantization settings'
|
| 18 |
+
},
|
| 19 |
+
},
|
| 20 |
+
'quant_method': {
|
| 21 |
+
'label': {
|
| 22 |
+
'zh': '量化方式',
|
| 23 |
+
'en': 'Quantization method'
|
| 24 |
+
},
|
| 25 |
+
'info': {
|
| 26 |
+
'zh': '如果制定了量化位数,本参数默认为bnb',
|
| 27 |
+
'en': 'Default is bnb if quantization_bit is specified'
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
'quant_bits': {
|
| 31 |
+
'label': {
|
| 32 |
+
'zh': '量化bit数',
|
| 33 |
+
'en': 'Quantization bit'
|
| 34 |
+
},
|
| 35 |
+
'info': {
|
| 36 |
+
'zh': '设置量化bit数, 0代表不进行量化',
|
| 37 |
+
'en': 'Set the quantization bit, 0 for no quantization'
|
| 38 |
+
}
|
| 39 |
+
},
|
| 40 |
+
'bnb_4bit_compute_dtype': {
|
| 41 |
+
'label': {
|
| 42 |
+
'zh': 'bnb_4bit_compute_dtype',
|
| 43 |
+
'en': 'bnb_4bit_compute_dtype'
|
| 44 |
+
},
|
| 45 |
+
},
|
| 46 |
+
'bnb_4bit_quant_type': {
|
| 47 |
+
'label': {
|
| 48 |
+
'zh': 'bnb_4bit_quant_type',
|
| 49 |
+
'en': 'bnb_4bit_quant_type'
|
| 50 |
+
},
|
| 51 |
+
},
|
| 52 |
+
'bnb_4bit_use_double_quant': {
|
| 53 |
+
'label': {
|
| 54 |
+
'zh': 'bnb_4bit_use_double_quant',
|
| 55 |
+
'en': 'bnb_4bit_use_double_quant'
|
| 56 |
+
},
|
| 57 |
+
},
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
@classmethod
|
| 61 |
+
def do_build_ui(cls, base_tab: Type['BaseUI']):
|
| 62 |
+
with gr.Accordion(elem_id='quantization_tab', open=False):
|
| 63 |
+
with gr.Row():
|
| 64 |
+
gr.Dropdown(elem_id='quant_bits', value=None)
|
| 65 |
+
gr.Dropdown(elem_id='quant_method', value=None)
|
| 66 |
+
gr.Dropdown(elem_id='bnb_4bit_compute_dtype', value=None)
|
| 67 |
+
gr.Dropdown(elem_id='bnb_4bit_quant_type', value=None)
|
| 68 |
+
gr.Checkbox(elem_id='bnb_4bit_use_double_quant', value=None)
|
ms-swift/swift/ui/llm_train/report_to.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Type
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
|
| 6 |
+
from swift.ui.base import BaseUI
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ReportTo(BaseUI):
|
| 10 |
+
|
| 11 |
+
group = 'llm_train'
|
| 12 |
+
|
| 13 |
+
locale_dict = {
|
| 14 |
+
'reporter': {
|
| 15 |
+
'label': {
|
| 16 |
+
'zh': '训练记录',
|
| 17 |
+
'en': 'Training report'
|
| 18 |
+
},
|
| 19 |
+
},
|
| 20 |
+
'report_to': {
|
| 21 |
+
'label': {
|
| 22 |
+
'zh': '训练记录方式',
|
| 23 |
+
'en': 'Report to'
|
| 24 |
+
},
|
| 25 |
+
},
|
| 26 |
+
'swanlab_token': {
|
| 27 |
+
'label': {
|
| 28 |
+
'zh': 'swanlab登录token',
|
| 29 |
+
'en': 'The login token of swanlab'
|
| 30 |
+
},
|
| 31 |
+
},
|
| 32 |
+
'swanlab_project': {
|
| 33 |
+
'label': {
|
| 34 |
+
'zh': 'swanlab项目名称',
|
| 35 |
+
'en': 'Project of swanlab'
|
| 36 |
+
},
|
| 37 |
+
},
|
| 38 |
+
'swanlab_workspace': {
|
| 39 |
+
'label': {
|
| 40 |
+
'zh': 'swanlab工作空间',
|
| 41 |
+
'en': 'Workspace of swanlab'
|
| 42 |
+
},
|
| 43 |
+
},
|
| 44 |
+
'swanlab_exp_name': {
|
| 45 |
+
'label': {
|
| 46 |
+
'zh': 'swanlab实验名称',
|
| 47 |
+
'en': 'Experiment of swanlab'
|
| 48 |
+
},
|
| 49 |
+
},
|
| 50 |
+
'swanlab_mode': {
|
| 51 |
+
'label': {
|
| 52 |
+
'zh': 'swanlab工作模式',
|
| 53 |
+
'en': 'Work mode of swanlab'
|
| 54 |
+
},
|
| 55 |
+
},
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
@classmethod
|
| 59 |
+
def do_build_ui(cls, base_tab: Type['BaseUI']):
|
| 60 |
+
with gr.Accordion(elem_id='reporter', open=False):
|
| 61 |
+
with gr.Blocks():
|
| 62 |
+
with gr.Row():
|
| 63 |
+
gr.Dropdown(
|
| 64 |
+
elem_id='report_to',
|
| 65 |
+
multiselect=True,
|
| 66 |
+
is_list=True,
|
| 67 |
+
choices=['tensorboard', 'wandb', 'swanlab'],
|
| 68 |
+
allow_custom_value=True,
|
| 69 |
+
scale=20)
|
| 70 |
+
gr.Textbox(elem_id='swanlab_token', lines=1, scale=20)
|
| 71 |
+
gr.Textbox(elem_id='swanlab_project', lines=1, scale=20)
|
| 72 |
+
with gr.Row():
|
| 73 |
+
gr.Textbox(elem_id='swanlab_workspace', lines=1, scale=20)
|
| 74 |
+
gr.Textbox(elem_id='swanlab_exp_name', lines=1, scale=20)
|
| 75 |
+
gr.Dropdown(elem_id='swanlab_mode', scale=20)
|
ms-swift/swift/ui/llm_train/rlhf.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from functools import partial
|
| 3 |
+
from typing import Type
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
from swift.llm import ModelType
|
| 8 |
+
from swift.llm.model.register import get_all_models
|
| 9 |
+
from swift.ui.base import BaseUI
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RLHF(BaseUI):
|
| 13 |
+
|
| 14 |
+
group = 'llm_train'
|
| 15 |
+
|
| 16 |
+
locale_dict = {
|
| 17 |
+
'rlhf_tab': {
|
| 18 |
+
'label': {
|
| 19 |
+
'zh': '人类对齐参数设置',
|
| 20 |
+
'en': 'RLHF settings'
|
| 21 |
+
},
|
| 22 |
+
},
|
| 23 |
+
'rlhf_type': {
|
| 24 |
+
'label': {
|
| 25 |
+
'zh': '人类对齐算法类型',
|
| 26 |
+
'en': 'RLHF type'
|
| 27 |
+
},
|
| 28 |
+
},
|
| 29 |
+
'ref_model_type': {
|
| 30 |
+
'label': {
|
| 31 |
+
'zh': '选择ref模型',
|
| 32 |
+
'en': 'Select ref model'
|
| 33 |
+
},
|
| 34 |
+
'info': {
|
| 35 |
+
'zh': 'SWIFT已支持的模型名称',
|
| 36 |
+
'en': 'Base model supported by SWIFT'
|
| 37 |
+
}
|
| 38 |
+
},
|
| 39 |
+
'ref_model': {
|
| 40 |
+
'label': {
|
| 41 |
+
'zh': 'ref模型id或路径',
|
| 42 |
+
'en': 'Ref model id or path'
|
| 43 |
+
},
|
| 44 |
+
'info': {
|
| 45 |
+
'zh': '实际的模型id或路径',
|
| 46 |
+
'en': 'The actual model id or path'
|
| 47 |
+
}
|
| 48 |
+
},
|
| 49 |
+
'beta': {
|
| 50 |
+
'label': {
|
| 51 |
+
'zh': 'KL正则项系数',
|
| 52 |
+
'en': 'KL regression ratio'
|
| 53 |
+
},
|
| 54 |
+
},
|
| 55 |
+
'rpo_alpha': {
|
| 56 |
+
'label': {
|
| 57 |
+
'zh': 'DPO中混合sft交叉熵的系数',
|
| 58 |
+
'en': 'DPO Cross Entropy ratio'
|
| 59 |
+
},
|
| 60 |
+
},
|
| 61 |
+
'simpo_gamma': {
|
| 62 |
+
'label': {
|
| 63 |
+
'zh': 'SimPO reward margin',
|
| 64 |
+
'en': 'SimPO reward margin'
|
| 65 |
+
},
|
| 66 |
+
},
|
| 67 |
+
'desirable_weight': {
|
| 68 |
+
'label': {
|
| 69 |
+
'zh': 'KTO符合项系数',
|
| 70 |
+
'en': 'KTO desirable ratio'
|
| 71 |
+
},
|
| 72 |
+
},
|
| 73 |
+
'undesirable_weight': {
|
| 74 |
+
'label': {
|
| 75 |
+
'zh': 'KTO不符合项系数',
|
| 76 |
+
'en': 'KTO undesirable ratio'
|
| 77 |
+
},
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
@classmethod
|
| 82 |
+
def do_build_ui(cls, base_tab: Type['BaseUI']):
|
| 83 |
+
with gr.Accordion(elem_id='rlhf_tab', open=False):
|
| 84 |
+
with gr.Blocks():
|
| 85 |
+
with gr.Row():
|
| 86 |
+
gr.Dropdown(elem_id='rlhf_type', value=None)
|
| 87 |
+
gr.Dropdown(
|
| 88 |
+
elem_id='ref_model', scale=20, value=None, choices=get_all_models(), allow_custom_value=True)
|
| 89 |
+
gr.Dropdown(elem_id='ref_model_type', choices=ModelType.get_model_name_list(), value=None, scale=20)
|
| 90 |
+
with gr.Row():
|
| 91 |
+
gr.Slider(elem_id='beta', minimum=0., maximum=5.0, step=0.1, scale=20)
|
| 92 |
+
gr.Slider(elem_id='rpo_alpha', minimum=0., maximum=2, step=0.1, scale=20)
|
| 93 |
+
gr.Slider(elem_id='simpo_gamma', minimum=0., maximum=2.0, step=0.1, scale=20)
|
| 94 |
+
gr.Slider(elem_id='desirable_weight', minimum=0., maximum=2.0, step=0.1, scale=20)
|
| 95 |
+
gr.Slider(elem_id='undesirable_weight', minimum=0., maximum=2.0, step=0.1, scale=20)
|
| 96 |
+
|
| 97 |
+
@classmethod
|
| 98 |
+
def after_build_ui(cls, base_tab: Type['BaseUI']):
|
| 99 |
+
cls.element('ref_model').change(
|
| 100 |
+
partial(cls.update_input_model, allow_keys=['ref_model_type'], has_record=False, is_ref_model=True),
|
| 101 |
+
inputs=[cls.element('ref_model')],
|
| 102 |
+
outputs=[cls.element('ref_model_type')])
|
ms-swift/swift/ui/llm_train/runtime.py
ADDED
|
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import collections
|
| 3 |
+
import os.path
|
| 4 |
+
import sys
|
| 5 |
+
import time
|
| 6 |
+
import webbrowser
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from typing import Dict, List, Tuple, Type
|
| 9 |
+
|
| 10 |
+
import gradio as gr
|
| 11 |
+
import json
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
import psutil
|
| 14 |
+
from packaging import version
|
| 15 |
+
from transformers import is_tensorboard_available
|
| 16 |
+
|
| 17 |
+
from swift.ui.base import BaseUI
|
| 18 |
+
from swift.ui.llm_train.utils import close_loop, run_command_in_subprocess
|
| 19 |
+
from swift.utils import TB_COLOR, TB_COLOR_SMOOTH, get_logger, read_tensorboard_file, tensorboard_smoothing
|
| 20 |
+
from swift.utils.utils import format_time
|
| 21 |
+
|
| 22 |
+
logger = get_logger()  # module-level logger shared by this UI module
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Runtime(BaseUI):
    """The 'Runtime' tab of the llm_train web UI.

    Shows the actual command line being run, tails the training log file,
    renders training-metric plots from tensorboard event files, manages an
    embedded tensorboard process, and lets the user find/kill running
    ``swift pt/sft/rlhf`` processes.
    """

    # logging_dir -> (tensorboard subprocess handle, localhost URL string)
    handlers: Dict[str, Tuple[List, Tuple]] = {}

    # UI group this tab belongs to (used by BaseUI element registration)
    group = 'llm_train'

    # gr.Plot components created in do_build_ui, one per metric in the plot config
    all_plots = None

    # logging_dir -> bool; set to True to ask the `wait` generator to stop tailing
    log_event = {}

    # Metric configs for pre-training / SFT runs. 'smooth' is the
    # tensorboard-style exponential smoothing factor (None = raw curve only).
    sft_plot = [
        {
            'name': 'train/loss',
            'smooth': 0.9,
        },
        {
            'name': 'train/acc',
            'smooth': None,
        },
        {
            'name': 'train/learning_rate',
            'smooth': None,
        },
        {
            'name': 'eval/loss',
            'smooth': 0.9,
        },
        {
            'name': 'eval/acc',
            'smooth': None,
        },
    ]

    # Metric configs for DPO-family RLHF runs (dpo/cpo/simpo).
    dpo_plot = [
        {
            'name': 'train/loss',
            'smooth': 0.9,
        },
        {
            'name': 'train/rewards/accuracies',
            'smooth': None,
        },
        {
            'name': 'train/rewards/margins',
            'smooth': 0.9,
        },
        {
            'name': 'train/logps/chosen',
            'smooth': 0.9,
        },
        {
            'name': 'train/logps/rejected',
            'smooth': 0.9,
        },
    ]

    # Metric configs for KTO runs.
    kto_plot = [
        {
            'name': 'kl',
            'smooth': None,
        },
        {
            'name': 'rewards/chosen_sum',
            'smooth': 0.9,
        },
        {
            'name': 'logps/chosen_sum',
            'smooth': 0.9,
        },
        {
            'name': 'rewards/rejected_sum',
            'smooth': 0.9,
        },
        {
            'name': 'logps/rejected_sum',
            'smooth': 0.9,
        },
    ]

    # Metric configs for ORPO runs.
    orpo_plot = [
        {
            'name': 'train/loss',
            'smooth': 0.9,
        },
        {
            'name': 'train/rewards/accuracies',
            'smooth': None,
        },
        {
            'name': 'train/rewards/margins',
            'smooth': 0.9,
        },
        {
            'name': 'train/rewards/chosen',
            'smooth': 0.9,
        },
        {
            'name': 'train/log_odds_ratio',
            'smooth': 0.9,
        },
    ]

    # elem_id -> per-language (zh/en) label/info texts consumed by BaseUI.
    locale_dict = {
        'runtime_tab': {
            'label': {
                'zh': '运行时',
                'en': 'Runtime'
            },
        },
        'tb_not_found': {
            'value': {
                'zh': 'tensorboard未安装,使用pip install tensorboard进行安装',
                'en': 'tensorboard not found, install it by pip install tensorboard',
            }
        },
        'running_cmd': {
            'label': {
                'zh': '运行命令',
                'en': 'Command line'
            },
            'info': {
                'zh': '执行的实际命令',
                'en': 'The actual command'
            }
        },
        'show_log': {
            'value': {
                'zh': '展示运行状态',
                'en': 'Show running status'
            },
        },
        'stop_show_log': {
            'value': {
                'zh': '停止展示运行状态',
                'en': 'Stop showing running status'
            },
        },
        'logging_dir': {
            'label': {
                'zh': '日志路径',
                'en': 'Logging dir'
            },
            'info': {
                'zh': '支持手动传入文件路径',
                'en': 'Support fill custom path in'
            }
        },
        'log': {
            'label': {
                'zh': '日志输出',
                'en': 'Logging content'
            },
            'info': {
                'zh': '如果日志无更新请再次点击"展示日志内容"',
                'en': 'Please press "Show log" if the log content is not updating'
            }
        },
        'running_tasks': {
            'label': {
                'zh': '运行中任务',
                'en': 'Running Tasks'
            },
            'info': {
                'zh': '运行中的任务(所有的swift sft命令)',
                'en': 'All running tasks(started by swift sft)'
            }
        },
        'refresh_tasks': {
            'value': {
                'zh': '找回运行时任务',
                'en': 'Find running tasks'
            },
        },
        'kill_task': {
            'value': {
                'zh': '杀死任务',
                'en': 'Kill running task'
            },
        },
        'tb_url': {
            'label': {
                'zh': 'Tensorboard链接',
                'en': 'Tensorboard URL'
            },
            'info': {
                'zh': '仅展示,不可编辑',
                'en': 'Not editable'
            }
        },
        'start_tb': {
            'value': {
                'zh': '打开TensorBoard',
                'en': 'Start TensorBoard'
            },
        },
        'close_tb': {
            'value': {
                'zh': '关闭TensorBoard',
                'en': 'Close TensorBoard'
            },
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the runtime accordion and wire all button callbacks."""
        with gr.Accordion(elem_id='runtime_tab', open=False, visible=True):
            with gr.Blocks():
                with gr.Row():
                    gr.Textbox(elem_id='running_cmd', lines=1, scale=20, interactive=False, max_lines=1)
                    gr.Textbox(elem_id='logging_dir', lines=1, scale=20, max_lines=1)
                    gr.Button(elem_id='show_log', scale=2, variant='primary')
                    gr.Button(elem_id='stop_show_log', scale=2)
                    gr.Textbox(elem_id='tb_url', lines=1, scale=10, interactive=False, max_lines=1)
                    gr.Button(elem_id='start_tb', scale=2, variant='primary')
                    gr.Button(elem_id='close_tb', scale=2)
                with gr.Row():
                    gr.Textbox(elem_id='log', lines=6, visible=False)
                with gr.Row():
                    gr.Dropdown(elem_id='running_tasks', scale=10)
                    gr.Button(elem_id='refresh_tasks', scale=1)
                    gr.Button(elem_id='kill_task', scale=1)

                with gr.Row():
                    # One plot per SFT metric; labels are retargeted at runtime
                    # by `update_log` when the task is an RLHF run.
                    cls.all_plots = []
                    for idx, k in enumerate(Runtime.sft_plot):
                        name = k['name']
                        cls.all_plots.append(gr.Plot(elem_id=str(idx), label=name))

                # gradio >= 4 renamed the per-event concurrency kwarg.
                concurrency_limit = {}
                if version.parse(gr.__version__) >= version.parse('4.0.0'):
                    concurrency_limit = {'concurrency_limit': 5}
                # "Show log" first reveals the log box, then starts the
                # `wait` generator that streams log text + plots.
                base_tab.element('show_log').click(
                    Runtime.update_log, [base_tab.element('running_tasks')], [cls.element('log')] + cls.all_plots).then(
                        Runtime.wait, [base_tab.element('logging_dir'),
                                       base_tab.element('running_tasks')], [cls.element('log')] + cls.all_plots,
                        **concurrency_limit)

                base_tab.element('stop_show_log').click(cls.break_log_event, [cls.element('running_tasks')], [])

                base_tab.element('start_tb').click(
                    Runtime.start_tb,
                    [base_tab.element('logging_dir')],
                    [base_tab.element('tb_url')],
                )

                base_tab.element('close_tb').click(
                    Runtime.close_tb,
                    [base_tab.element('logging_dir')],
                    [],
                )

                base_tab.element('refresh_tasks').click(
                    Runtime.refresh_tasks,
                    [base_tab.element('running_tasks')],
                    [base_tab.element('running_tasks')],
                )

    @classmethod
    def get_plot(cls, task):
        """Return the metric config list matching the task's command line.

        Falls back to the SFT config when no task is selected or the task is
        a `swift sft`/`swift pt` command; otherwise picks by `rlhf_type`.
        NOTE(review): an unrecognized rlhf_type falls through and returns
        None implicitly — confirm whether that is intended.
        """
        if not task or 'swift sft' in task or 'swift pt' in task:
            return cls.sft_plot

        args: dict = cls.parse_info_from_cmdline(task)[1]
        train_type = args.get('rlhf_type', 'dpo')
        if train_type in ('dpo', 'cpo', 'simpo'):
            return cls.dpo_plot
        elif train_type == 'kto':
            return cls.kto_plot
        elif train_type == 'orpo':
            return cls.orpo_plot

    @classmethod
    def update_log(cls, task):
        """Make the log textbox and all plots visible, relabeling each plot
        for the selected task's metric set."""
        ret = [gr.update(visible=True)]
        plot = Runtime.get_plot(task)
        for i in range(len(plot)):
            p = plot[i]
            ret.append(gr.update(visible=True, label=p['name']))
        return ret

    @classmethod
    def get_initial(cls, line):
        """Return the tqdm-style prefix of *line* ('Train:', 'Map:', ...) or
        None; used to collapse repeated progress-bar lines."""
        tqdm_starts = ['Train:', 'Map:', 'Val:', 'Filter:']
        for start in tqdm_starts:
            if line.startswith(start):
                return start
        return None

    @classmethod
    def wait(cls, logging_dir, task):
        """Generator that tails ``<logging_dir>/run.log`` and yields
        ``[log_text] + plots`` until the log stops updating or the user
        presses 'stop'.

        NOTE(review): this function contains a ``yield``, so it is a
        generator; the early ``return [None] + ...`` ends the generator with
        that list as the StopIteration value rather than emitting it —
        confirm gradio handles this as intended.
        """
        if not logging_dir:
            return [None] + Runtime.plot(task)
        log_file = os.path.join(logging_dir, 'run.log')
        # Clear any stale stop-request for this directory.
        cls.log_event[logging_dir] = False
        offset = 0
        latest_data = ''
        # Keep only the last MAX_LOG_LINES lines in the textbox.
        lines = collections.deque(maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
        try:
            # NOTE(review): `input` shadows the builtin of the same name.
            with open(log_file, 'r', encoding='utf-8') as input:
                input.seek(offset)
                fail_cnt = 0
                while True:
                    try:
                        latest_data += input.read()
                    except UnicodeDecodeError:
                        # Partial multi-byte char at EOF; retry on next pass.
                        # NOTE(review): no sleep here — a persistent decode
                        # error would busy-loop.
                        continue
                    if not latest_data:
                        time.sleep(0.5)
                        fail_cnt += 1
                        # Give up after ~25s with no new output.
                        # NOTE(review): fail_cnt is never reset after data
                        # arrives — confirm that is intended.
                        if fail_cnt > 50:
                            break

                    # User pressed 'stop showing running status'.
                    if cls.log_event.get(logging_dir, False):
                        cls.log_event[logging_dir] = False
                        break

                    if '\n' not in latest_data:
                        continue
                    latest_lines = latest_data.split('\n')
                    # Keep a trailing partial line buffered for the next read.
                    if latest_data[-1] != '\n':
                        latest_data = latest_lines[-1]
                        latest_lines = latest_lines[:-1]
                    else:
                        latest_data = ''
                    lines.extend(latest_lines)
                    # Collapse consecutive tqdm progress lines, keeping only
                    # the most recent one.
                    start = cls.get_initial(lines[-1])
                    if start:
                        i = len(lines) - 2
                        while i >= 0:
                            if lines[i].startswith(start):
                                del lines[i]
                                i -= 1
                            else:
                                break
                    yield ['\n'.join(lines)] + Runtime.plot(task)
        except IOError:
            # Log file not created yet (or removed); silently stop tailing.
            pass

    @classmethod
    def break_log_event(cls, task):
        """Ask the `wait` generator tailing this task's logging_dir to stop."""
        if not task:
            return
        pid, all_args = Runtime.parse_info_from_cmdline(task)
        cls.log_event[all_args['logging_dir']] = True

    @classmethod
    def show_log(cls, logging_dir):
        """Open run.log in the user's browser (new tab)."""
        webbrowser.open('file://' + os.path.join(logging_dir, 'run.log'), new=2)

    @classmethod
    def start_tb(cls, logging_dir):
        """Launch (or reuse) a tensorboard process for *logging_dir* and
        return its localhost URL; also opens it in the browser."""
        if not is_tensorboard_available():
            gr.Error(cls.locale('tb_not_found', cls.lang)['value'])
            return ''

        logging_dir = logging_dir.strip()
        logging_dir = logging_dir if not logging_dir.endswith(os.sep) else logging_dir[:-1]
        if logging_dir in cls.handlers:
            # Already running for this dir — return the cached URL.
            return cls.handlers[logging_dir][1]

        handler, lines = run_command_in_subprocess('tensorboard', '--logdir', logging_dir, timeout=2)
        localhost_addr = ''
        # Scrape the serving URL from tensorboard's startup output.
        for line in lines:
            if 'http://localhost:' in line:
                line = line[line.index('http://localhost:'):]
                localhost_addr = line[:line.index(' ')]
        cls.handlers[logging_dir] = (handler, localhost_addr)
        logger.info('===========Tensorboard Log============')
        logger.info('\n'.join(lines))
        webbrowser.open(localhost_addr, new=2)
        return localhost_addr

    @staticmethod
    def close_tb(logging_dir):
        """Terminate the tensorboard process started for *logging_dir*, if any."""
        if logging_dir in Runtime.handlers:
            close_loop(Runtime.handlers[logging_dir][0])
            Runtime.handlers.pop(logging_dir)

    @staticmethod
    def refresh_tasks(running_task=None):
        """Scan system processes for running ``swift pt/sft/rlhf`` commands
        and return a dropdown update listing them, keeping the current
        selection when possible."""
        # If the previous value is a plain output_dir (not a formatted
        # 'pid:...' entry), use it to re-select the matching process.
        output_dir = running_task if not running_task or 'pid:' not in running_task else None
        process_name = 'swift'
        negative_name = 'swift.exe'
        cmd_name = ['pt', 'sft', 'rlhf']
        process = []
        selected = None
        for proc in psutil.process_iter():
            try:
                cmdlines = proc.cmdline()
            except (psutil.ZombieProcess, psutil.AccessDenied, psutil.NoSuchProcess):
                cmdlines = []
            if any([process_name in cmdline
                    for cmdline in cmdlines]) and not any([negative_name in cmdline
                                                           for cmdline in cmdlines]) and any(  # noqa
                                                               [cmdline in cmd_name for cmdline in cmdlines]):  # noqa
                process.append(Runtime.construct_running_task(proc))
                if output_dir is not None and any(  # noqa
                        [output_dir == cmdline for cmdline in cmdlines]):  # noqa
                    selected = Runtime.construct_running_task(proc)
        if not selected:
            if running_task and running_task in process:
                selected = running_task
            if not selected and process:
                selected = process[0]
        return gr.update(choices=process, value=selected)

    @staticmethod
    def construct_running_task(proc):
        """Format a psutil process as the dropdown entry
        ``pid:<pid>/create:<time>/running:<elapsed>/cmd:<cmdline>``."""
        pid = proc.pid
        ts = time.time()
        create_time = proc.create_time()
        create_time_formatted = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d, %H:%M')

        return f'pid:{pid}/create:{create_time_formatted}' \
               f'/running:{format_time(ts-create_time)}/cmd:{" ".join(proc.cmdline())}'

    @staticmethod
    def parse_info_from_cmdline(task):
        """Parse a dropdown entry (see construct_running_task) back into
        ``(pid, {arg_name: value})``; values are refreshed from the run's
        args.json when it exists."""
        pid = None
        if '/cmd:' in task:
            # Strip the three 'pid:/create:/running:' prefix segments.
            for i in range(3):
                slash = task.find('/')
                if i == 0:
                    pid = task[:slash].split(':')[1]
                task = task[slash + 1:]
        if 'swift sft' in task:
            args = task.split('swift sft')[1]
        elif 'swift pt' in task:
            args = task.split('swift pt')[1]
        elif 'swift rlhf' in task:
            args = task.split('swift rlhf')[1]
        else:
            raise ValueError(f'Cannot parse cmd line: {task}')
        # '--name value' pairs -> dict.
        args = [arg.strip() for arg in args.split('--') if arg.strip()]
        all_args = {}
        for i in range(len(args)):
            space = args[i].find(' ')
            splits = args[i][:space], args[i][space + 1:]
            all_args[splits[0]] = splits[1]

        output_dir = all_args['output_dir']
        if os.path.exists(os.path.join(output_dir, 'args.json')):
            # Prefer the saved args.json values over cmdline-parsed ones.
            with open(os.path.join(output_dir, 'args.json'), 'r', encoding='utf-8') as f:
                _json = json.load(f)
                for key in all_args.keys():
                    all_args[key] = _json.get(key)
                    if isinstance(all_args[key], list):
                        # Re-quote list items containing spaces, then join.
                        if any([' ' in value for value in all_args[key]]):
                            all_args[key] = [f'"{value}"' for value in all_args[key]]
                        all_args[key] = ' '.join(all_args[key])
        return pid, all_args

    @staticmethod
    def kill_task(task):
        """Kill the selected training process (by pid on Windows, by
        output_dir pattern via pkill elsewhere) and reset the UI outputs."""
        if task:
            pid, all_args = Runtime.parse_info_from_cmdline(task)
            output_dir = all_args['output_dir']
            if sys.platform == 'win32':
                os.system(f'taskkill /f /t /pid "{pid}"')
            else:
                os.system(f'pkill -9 -f {output_dir}')
            time.sleep(1)
            Runtime.break_log_event(task)
        return [Runtime.refresh_tasks()] + [gr.update(value=None)] * (len(Runtime.get_plot(task)) + 1)

    @staticmethod
    def reset():
        """Default values for (logging_dir textbox, output dir)."""
        return None, 'output'

    @staticmethod
    def task_changed(task, base_tab):
        """When the selected task changes, push its parsed cmdline args back
        into the matching form elements and clear the log/plots."""
        if task:
            _, all_args = Runtime.parse_info_from_cmdline(task)
        else:
            all_args = {}
        elements = list(base_tab.valid_elements().values())
        ret = []
        for e in elements:
            if e.elem_id in all_args:
                if isinstance(e, gr.Dropdown) and e.multiselect:
                    arg = all_args[e.elem_id].split(' ')
                else:
                    arg = all_args[e.elem_id]
                ret.append(gr.update(value=arg))
            else:
                ret.append(gr.update())
        Runtime.break_log_event(task)
        return ret + [gr.update(value=None)] * (len(Runtime.get_plot(task)) + 1)

    @staticmethod
    def plot(task):
        """Render one matplotlib figure per configured metric from the
        task's tensorboard event file; None placeholders for missing metrics."""
        plot = Runtime.get_plot(task)
        if not task:
            return [None] * len(plot)
        _, all_args = Runtime.parse_info_from_cmdline(task)
        tb_dir = all_args['logging_dir']
        if not os.path.exists(tb_dir):
            return [None] * len(plot)
        fname = [
            fname for fname in os.listdir(tb_dir)
            if os.path.isfile(os.path.join(tb_dir, fname)) and fname.startswith('events.out')
        ]
        if fname:
            # Use the first event file found.
            fname = fname[0]
        else:
            return [None] * len(plot)
        tb_path = os.path.join(tb_dir, fname)
        data = read_tensorboard_file(tb_path)

        plots = []
        for k in plot:
            name = k['name']
            smooth = k['smooth']
            # Newer runs log token/seq accuracy under different tags; remap.
            if name == 'train/acc':
                if 'train/token_acc' in data:
                    name = 'train/token_acc'
                if 'train/seq_acc' in data:
                    name = 'train/seq_acc'
            if name == 'eval/acc':
                if 'eval/token_acc' in data:
                    name = 'eval/token_acc'
                if 'eval/seq_acc' in data:
                    name = 'eval/seq_acc'
            if name not in data:
                plots.append(None)
                continue
            _data = data[name]
            steps = [d['step'] for d in _data]
            values = [d['value'] for d in _data]
            if len(values) == 0:
                # NOTE(review): this `continue` skips appending, shifting
                # later figures relative to their plot slots — confirm.
                continue

            plt.close('all')
            fig = plt.figure()
            ax = fig.add_subplot()
            # _, ax = plt.subplots(1, 1, squeeze=True, figsize=(8, 5), dpi=100)
            ax.set_title(name)
            if len(values) == 1:
                # A single point cannot be drawn as a line.
                ax.scatter(steps, values, color=TB_COLOR_SMOOTH)
            elif smooth is not None:
                # Raw curve in the light color, smoothed curve on top.
                ax.plot(steps, values, color=TB_COLOR)
                values_s = tensorboard_smoothing(values, smooth)
                ax.plot(steps, values_s, color=TB_COLOR_SMOOTH)
            else:
                ax.plot(steps, values, color=TB_COLOR_SMOOTH)
            plots.append(fig)
        return plots
|
ms-swift/swift/ui/llm_train/save.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Type
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
|
| 6 |
+
from swift.ui.base import BaseUI
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Save(BaseUI):
    """The 'Saving settings' accordion of the llm_train web UI: options for
    pushing the trained model to the ModelScope hub."""

    # UI group this component belongs to
    group = 'llm_train'

    # elem_id -> per-language (zh/en) label/info texts consumed by BaseUI.
    locale_dict = {
        'save_param': {
            'label': {
                'zh': '存储参数设置',
                'en': 'Saving settings'
            },
        },
        'push_to_hub': {
            'label': {
                'zh': '推送魔搭Hub',
                'en': 'Push to modelscope hub',
            },
            'info': {
                'zh': '是否推送魔搭的模型库',
                'en': 'Whether push the output model to modelscope hub',
            }
        },
        'hub_model_id': {
            'label': {
                'zh': '魔搭模型id',
                'en': 'The model-id in modelscope',
            },
            'info': {
                'zh': '设置魔搭的模型id',
                'en': 'Set the model-id of modelscope',
            }
        },
        'hub_private_repo': {
            'label': {
                'zh': '设置仓库私有',
                'en': 'Model is private',
            },
            'info': {
                'zh': '以私有方式推送魔搭hub',
                'en': 'Set the model as private',
            }
        },
        'hub_strategy': {
            'label': {
                'zh': '推送策略',
                'en': 'Push strategy',
            },
            'info': {
                'zh': '设置模型推送策略',
                'en': 'Set the push strategy',
            }
        },
        'hub_token': {
            'label': {
                'zh': '仓库token',
                'en': 'The hub token',
            },
            'info': {
                'zh': '该token可以在www.modelscope.cn找到',
                'en': 'Find the token in www.modelscope.cn',
            }
        }
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the saving-settings accordion: hub push toggle, model id,
        privacy flag, push strategy, and hub token inputs."""
        with gr.Accordion(elem_id='save_param', open=False):
            with gr.Blocks():
                with gr.Row():
                    gr.Checkbox(elem_id='push_to_hub', scale=20)
                    gr.Textbox(elem_id='hub_model_id', lines=1, scale=20)
                    gr.Checkbox(elem_id='hub_private_repo', scale=20)
                    # Choices mirror transformers' hub_strategy options.
                    gr.Dropdown(
                        elem_id='hub_strategy',
                        scale=20,
                        choices=['end', 'every_save', 'checkpoint', 'all_checkpoints'])
                    gr.Textbox(elem_id='hub_token', lines=1, scale=20)
|
ms-swift/swift/ui/llm_train/self_cog.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Type
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
|
| 6 |
+
from swift.ui.base import BaseUI
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SelfCog(BaseUI):
    """The 'Self cognition' accordion of the llm_train web UI: lets the user
    set the name/author the fine-tuned model should report for itself."""

    # UI group this component belongs to
    group = 'llm_train'

    # elem_id -> per-language (zh/en) label/info texts consumed by BaseUI.
    locale_dict = {
        'self_cognition': {
            'label': {
                'zh': '自我认知任务参数设置',
                'en': 'Self cognition settings'
            },
        },
        'self_cognition_sample': {
            'label': {
                'zh': '数据及采样条数',
                'en': 'Dataset sample size'
            },
            'info': {
                'zh': '设置数据集采样的条数',
                'en': 'Set the dataset sample size'
            }
        },
        'model_name': {
            'label': {
                'zh': '模型认知名称',
                'en': 'Model name'
            },
            'info': {
                'zh': '设置模型应当认知自己的名字, 格式为:中文名字 英文名字,中间以空格分隔',
                'en': 'Set the name of the model think itself of, the format is Chinesename Englishname, split by space'
            }
        },
        'model_author': {
            'label': {
                'zh': '模型作者',
                'en': 'Model author'
            },
            'info': {
                'zh': '设置模型认知的自己的作者, 格式为:中文作者 英文作者,中间以空格分隔',
                'en': 'Set the author of the model, the format is Chineseauthor Englishauthor, split by space'
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the self-cognition accordion with model-name and
        model-author inputs (space-separated zh/en pairs)."""
        with gr.Accordion(elem_id='self_cognition', open=False):
            with gr.Row():
                # is_list=True: BaseUI splits the space-separated value.
                gr.Textbox(elem_id='model_name', scale=20, is_list=True)
                gr.Textbox(elem_id='model_author', scale=20, is_list=True)
|
ms-swift/swift/utils/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
from .env import (get_dist_setting, get_node_setting, get_pai_tensorboard_dir, is_deepspeed_enabled, is_dist,
|
| 4 |
+
is_dist_ta, is_local_master, is_master, is_mp, is_mp_ddp, is_pai_training_job, torchacc_trim_graph,
|
| 5 |
+
use_hf_hub, use_torchacc)
|
| 6 |
+
from .import_utils import (is_liger_available, is_lmdeploy_available, is_megatron_available, is_swanlab_available,
|
| 7 |
+
is_unsloth_available, is_vllm_ascend_available, is_vllm_available, is_wandb_available,
|
| 8 |
+
is_xtuner_available)
|
| 9 |
+
from .io_utils import JsonlWriter, append_to_jsonl, download_ms_file, get_file_mm_type, read_from_jsonl, write_to_jsonl
|
| 10 |
+
from .logger import get_logger
|
| 11 |
+
from .np_utils import get_seed, stat_array, transform_jsonl_to_df
|
| 12 |
+
from .tb_utils import TB_COLOR, TB_COLOR_SMOOTH, plot_images, read_tensorboard_file, tensorboard_smoothing
|
| 13 |
+
from .torch_utils import (Serializer, activate_parameters, find_all_linears, find_embedding, find_layers, find_norm,
|
| 14 |
+
freeze_parameters, gc_collect, get_current_device, get_device, get_device_count,
|
| 15 |
+
get_model_parameter_info, get_n_params_grads, init_process_group, safe_ddp_context,
|
| 16 |
+
set_default_ddp_config, set_device, show_layers, time_synchronize)
|
| 17 |
+
from .utils import (add_version_to_work_dir, check_json_format, copy_files_by_pattern, deep_getattr, find_free_port,
|
| 18 |
+
get_env_args, import_external_file, lower_bound, parse_args, patch_getattr, read_multi_line,
|
| 19 |
+
seed_everything, split_list, subprocess_run, test_time, upper_bound)
|
ms-swift/swift/utils/__pycache__/np_utils.cpython-310.pyc
ADDED
|
Binary file (1.56 kB). View file
|
|
|
ms-swift/swift/utils/constants.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# Regex-style patterns for binary model/artifact files.
# NOTE(review): the dot before each extension is unescaped, so '.*.bin'
# matches any char before 'bin' — confirm against how these are consumed.
BIN_EXTENSIONS = [
    '.*.bin',
    '.*.ts',
    '.*.pt',
    '.*.data-00000-of-00001',
    '.*.onnx',
    '.*.meta',
    '.*.pb',
    '.*.index',
]

# Keys identifying the tuner framework in saved adapter configs.
PEFT_TYPE_KEY = 'peft_type'
SWIFT_TYPE_KEY = 'swift_type'
# Default adapter name used when the user does not specify one.
DEFAULT_ADAPTER = 'default'
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Invoke:
    """Constant identifiers describing which component invoked an operation.

    ``KEY`` is the metadata key name; the remaining attributes are the
    possible values recorded under it.
    """

    # Metadata key under which the invoker is recorded.
    KEY = 'invoked_by'

    # Possible invoker values.
    THIRD_PARTY = 'third_party'
    PRETRAINED = 'from_pretrained'
    PIPELINE = 'pipeline'
    TRAINER = 'trainer'
    LOCAL_TRAINER = 'local_trainer'
    PREPROCESSOR = 'preprocessor'
    SWIFT = 'swift'
|
ms-swift/swift/utils/logger.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import importlib.util
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
from contextlib import contextmanager
|
| 6 |
+
from types import MethodType
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
from modelscope.utils.logger import get_logger as get_ms_logger
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Avoid circular reference
|
| 13 |
+
def _is_local_master():
    """Return True on the local master process (LOCAL_RANK unset/-1 or 0)."""
    local_rank = int(os.getenv('LOCAL_RANK', -1))
    return local_rank in {-1, 0}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Tracks which logger names have already been configured by get_logger().
init_loggers = {}

# old format
# formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger_format = logging.Formatter('[%(levelname)s:%(name)s] %(message)s')

# Deduplication keys already emitted by info_once / warning_once.
info_set = set()
warning_set = set()


def info_once(self, msg, *args, **kwargs):
    """Log *msg* at INFO level only once per process.

    Deduplication key is ``kwargs['hash_id']`` when given, else the message
    text itself. Bound onto logger instances as a method by get_logger().
    """
    key = kwargs.get('hash_id') or msg
    if key in info_set:
        return
    info_set.add(key)
    self.info(msg)


def warning_once(self, msg, *args, **kwargs):
    """Log *msg* at WARNING level only once per process.

    Same deduplication scheme as :func:`info_once`, with its own seen-set.
    """
    key = kwargs.get('hash_id') or msg
    if key in warning_set:
        return
    warning_set.add(key)
    self.warning(msg)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_logger(log_file: Optional[str] = None, log_level: Optional[int] = None, file_mode: str = 'w'):
    """ Get logging logger

    Args:
        log_file: Log filename, if specified, file handler will be added to
            logger
        log_level: Logging level. Defaults to the LOG_LEVEL env var (INFO).
        file_mode: Specifies the mode to open the file, if filename is
            specified (if filemode is unspecified, it defaults to 'w').
    """
    if log_level is None:
        log_level = os.getenv('LOG_LEVEL', 'INFO').upper()
        # Fall back to INFO if LOG_LEVEL is not a valid level name.
        log_level = getattr(logging, log_level, logging.INFO)
    # Use the top-level package name so all submodules share one logger.
    logger_name = __name__.split('.')[0]
    logger = logging.getLogger(logger_name)
    logger.propagate = False
    if logger_name in init_loggers:
        # Already configured; only attach a file handler if newly requested.
        add_file_handler_if_needed(logger, log_file, file_mode, log_level)
        return logger

    # handle duplicate logs to the console
    # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
    # to the root logger. As logger.propagate is True by default, this root
    # level handler causes logging messages from rank>0 processes to
    # unexpectedly show up on the console, creating much unwanted clutter.
    # To fix this issue, we set the root logger's StreamHandler, if any, to log
    # at the ERROR level.
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    is_worker0 = _is_local_master()

    # Only the local master process writes the log file.
    if is_worker0 and log_file is not None:
        file_handler = logging.FileHandler(log_file, file_mode)
        handlers.append(file_handler)

    for handler in handlers:
        handler.setFormatter(logger_format)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    # Non-master workers are silenced below ERROR to avoid duplicate output.
    if is_worker0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)

    init_loggers[logger_name] = True

    # Attach the once-only helpers as bound methods on this logger instance.
    logger.info_once = MethodType(info_once, logger)
    logger.warning_once = MethodType(warning_once, logger)
    return logger
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# Module-level loggers: the swift package logger and modelscope's logger.
logger = get_logger()
ms_logger = get_ms_logger()

# Unify formatting between the two loggers' first (console) handlers.
logger.handlers[0].setFormatter(logger_format)
ms_logger.handlers[0].setFormatter(logger_format)
log_level = os.getenv('LOG_LEVEL', 'INFO').upper()
# Silence modelscope's logger on non-master workers as well.
if _is_local_master():
    ms_logger.setLevel(log_level)
else:
    ms_logger.setLevel(logging.ERROR)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@contextmanager
def ms_logger_ignore_error():
    """Temporarily raise the modelscope logger's level to CRITICAL.

    Errors logged by modelscope inside the ``with`` block are suppressed; the
    previous level is always restored on exit, even if the block raises.
    """
    _logger = get_ms_logger()
    saved_level = _logger.level
    _logger.setLevel(logging.CRITICAL)
    try:
        yield
    finally:
        _logger.setLevel(saved_level)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def add_file_handler_if_needed(logger, log_file, file_mode, log_level):
    """Attach a ``FileHandler`` to ``logger`` unless one is already present.

    Only the local worker 0 (LOCAL_RANK in {-1, 0}) writes the log file when
    torch is importable; without torch, every process counts as worker 0.
    """
    # Bail out early if a FileHandler was already attached.
    if any(isinstance(h, logging.FileHandler) for h in logger.handlers):
        return

    if importlib.util.find_spec('torch') is not None:
        is_worker0 = int(os.getenv('LOCAL_RANK', -1)) in {-1, 0}
    else:
        is_worker0 = True

    if log_file is None or not is_worker0:
        return
    file_handler = logging.FileHandler(log_file, file_mode)
    file_handler.setFormatter(logger_format)
    file_handler.setLevel(log_level)
    logger.addHandler(file_handler)
|
ms-swift/swift/utils/tb_utils.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from typing import Dict, List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
|
| 8 |
+
|
| 9 |
+
Item = Dict[str, float]
|
| 10 |
+
TB_COLOR, TB_COLOR_SMOOTH = '#FFE2D9', '#FF7043'
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def read_tensorboard_file(fpath: str) -> Dict[str, List[Item]]:
    """Read all scalar series from a tensorboard event file.

    Args:
        fpath: Path to a tensorboard event file.

    Returns:
        Mapping from scalar tag to a list of ``{'step': ..., 'value': ...}``
        items, in recorded order.

    Raises:
        FileNotFoundError: If ``fpath`` is not an existing file.
    """
    if not os.path.isfile(fpath):
        raise FileNotFoundError(f'fpath: {fpath}')
    ea = EventAccumulator(fpath)
    ea.Reload()  # actually loads the events from disk
    res: Dict[str, List[Item]] = {}
    tags = ea.Tags()['scalars']
    for tag in tags:
        values = ea.Scalars(tag)
        r: List[Item] = []
        for v in values:
            r.append({'step': v.step, 'value': v.value})
        res[tag] = r
    return res
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def tensorboard_smoothing(values: List[float], smooth: float = 0.9) -> List[float]:
    """Exponential moving average with startup bias correction (tensorboard-style).

    Args:
        values: Raw scalar series.
        smooth: Decay factor; 0 returns the input unchanged.

    Returns:
        The smoothed series, same length as ``values``.
    """
    smoothed: List[float] = []
    running = 0.0
    weight_total = 0.0  # normalization factor that corrects the EMA startup bias
    for value in values:
        running = running * smooth + value  # exponential decay
        weight_total = weight_total * smooth + 1
        smoothed.append(running / weight_total)
    return smoothed
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def plot_images(images_dir: str,
                tb_dir: str,
                smooth_key: Optional[List[str]] = None,
                smooth_val: float = 0.9,
                figsize: Tuple[int, int] = (8, 5),
                dpi: int = 100) -> None:
    """Using tensorboard's data content to plot images

    Args:
        images_dir: Output directory for the generated images (created if missing).
        tb_dir: Directory containing a tensorboard event file.
        smooth_key: Tags to additionally plot as a smoothed curve.
        smooth_val: Smoothing factor passed to ``tensorboard_smoothing``.
        figsize: Figure size in inches.
        dpi: Figure resolution.
    """
    smooth_key = smooth_key or []
    os.makedirs(images_dir, exist_ok=True)
    # NOTE(review): takes the first regular file found in tb_dir — assumes the
    # directory holds exactly one event file; raises IndexError when empty.
    fname = [fname for fname in os.listdir(tb_dir) if os.path.isfile(os.path.join(tb_dir, fname))][0]
    tb_path = os.path.join(tb_dir, fname)
    data = read_tensorboard_file(tb_path)

    for k in data.keys():
        _data = data[k]
        steps = [d['step'] for d in _data]
        values = [d['value'] for d in _data]
        if len(values) == 0:
            continue
        _, ax = plt.subplots(1, 1, squeeze=True, figsize=figsize, dpi=dpi)
        ax.set_title(k)
        if len(values) == 1:
            # A single point cannot be drawn as a line.
            ax.scatter(steps, values, color=TB_COLOR_SMOOTH)
        elif k in smooth_key:
            # Raw curve in a light color with the smoothed curve on top.
            ax.plot(steps, values, color=TB_COLOR)
            values_s = tensorboard_smoothing(values, smooth_val)
            ax.plot(steps, values_s, color=TB_COLOR_SMOOTH)
        else:
            ax.plot(steps, values, color=TB_COLOR_SMOOTH)
        # Sanitize the tag so it is a valid file name; no extension is given,
        # so matplotlib falls back to its default image format.
        fpath = os.path.join(images_dir, k.replace('/', '_').replace('.', '_'))
        plt.savefig(fpath, dpi=dpi, bbox_inches='tight')
        plt.close()
|
ms-swift/swift/utils/torch_utils.py
ADDED
|
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import gc
|
| 3 |
+
import hashlib
|
| 4 |
+
import os
|
| 5 |
+
import pickle
|
| 6 |
+
import re
|
| 7 |
+
import time
|
| 8 |
+
import uuid
|
| 9 |
+
from bisect import bisect_right
|
| 10 |
+
from contextlib import contextmanager, nullcontext
|
| 11 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from datasets.utils.filelock import FileLock
|
| 18 |
+
from modelscope.hub.utils.utils import get_cache_dir
|
| 19 |
+
from transformers.integrations import is_deepspeed_zero3_enabled
|
| 20 |
+
from transformers.utils import is_torch_cuda_available, is_torch_mps_available, is_torch_npu_available
|
| 21 |
+
|
| 22 |
+
from .env import get_dist_setting, is_dist, is_dist_ta, is_local_master, is_master
|
| 23 |
+
from .logger import get_logger
|
| 24 |
+
from .utils import deep_getattr
|
| 25 |
+
|
| 26 |
+
logger = get_logger()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _find_local_mac() -> str:
|
| 30 |
+
mac = uuid.getnode()
|
| 31 |
+
mac_address = ':'.join(('%012x' % mac)[i:i + 2] for i in range(0, 12, 2))
|
| 32 |
+
return mac_address
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_n_params_grads(model) -> Tuple[List[int], List[int]]:
    """Per-parameter element counts for ``model``.

    Returns:
        ``(n_params, n_grads)`` where ``n_params[i]`` is the element count of
        the i-th parameter and ``n_grads[i]`` is the same count if that
        parameter requires grad, else 0.
    """
    n_params, n_grads = [], []
    for p in model.parameters():
        if is_deepspeed_zero3_enabled():
            import deepspeed
            # Under ZeRO-3 parameters are partitioned; gather to get real numel.
            context = deepspeed.zero.GatheredParameters(p)
        else:
            context = nullcontext()
        with context:
            n_params.append(p.numel())
            n_grads.append(p.numel() if p.requires_grad else 0)
    return n_params, n_grads
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_model_parameter_info(model: nn.Module, name: Optional[str] = None) -> str:
    """Summarize a model's parameter/buffer counts as a human-readable string.

    Args:
        model: The module to inspect.
        name: Display name; defaults to the model's class name.

    Returns:
        A string of the form
        ``'Name: X.XXXXM Params (Y.YYYYM Trainable [ZZ.ZZZZ%]), B.BBBBM Buffers.'``.
    """
    n_params, n_grads = get_n_params_grads(model)
    n_params = sum(n_params)
    n_grads = sum(n_grads)
    n_buffers = sum(p.numel() for p in model.buffers())

    if name is None:
        name = model.__class__.__name__

    # Guard against parameter-less models (avoids ZeroDivisionError).
    trainable_ratio = 100 * n_grads / n_params if n_params else 0.0
    n_params /= 1e6
    n_grads /= 1e6
    n_buffers /= 1e6
    s = (f'{name}: '
         f'{n_params:.4f}M Params ({n_grads:.4f}M Trainable '
         f'[{trainable_ratio:.4f}%]), '
         f'{n_buffers:.4f}M Buffers.')
    return s
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def find_sub_module(module: torch.nn.Module, module_name: str) -> List[torch.nn.Module]:
    """Collect every submodule whose qualified name ends with ``module_name``.

    The root module itself (empty name) is never included.
    """
    return [
        sub_module for name, sub_module in module.named_modules()
        if name and name.endswith(module_name)
    ]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def show_layers(model: nn.Module, max_lines: Optional[int] = 20) -> None:
    """Log requires_grad/dtype/device for each parameter of ``model``.

    At most ``max_lines`` parameters are shown (``None`` means unlimited);
    an ellipsis line marks truncation.
    """
    named_p = list(model.named_parameters())
    for i, (n, p) in enumerate(named_p):
        if max_lines is not None and i >= max_lines:
            logger.info('...')
            break
        logger.info(f'[{n}]: requires_grad={p.requires_grad}, dtype={p.dtype}, device={p.device}')
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def freeze_parameters(model: nn.Module,
                      freeze_parameters_ratio: float,
                      freeze_parameters: List[str],
                      freeze_parameters_regex: Optional[str] = None) -> None:
    """Freeze model parameters by ratio, name prefix, and/or name regex.

    Args:
        model: Model whose parameters are frozen in place.
        freeze_parameters_ratio: Fraction (by element count) of the leading
            parameters, in ``model.parameters()`` order, to freeze.
        freeze_parameters: Name prefixes; any parameter whose name starts
            with one of them is frozen.
        freeze_parameters_regex: Optional regex searched against parameter
            names; an invalid pattern is logged and the regex step is skipped.
    """
    if freeze_parameters_ratio > 0:
        n_parameters = get_n_params_grads(model)[0]
        n_parameters = np.array(n_parameters, dtype=np.int64)
        n_freeze_parameters = int(np.sum(n_parameters) * freeze_parameters_ratio)
        # Freeze the shortest leading run of parameters whose cumulative
        # element count stays within the requested fraction.
        n_parameters_cs = np.cumsum(n_parameters)
        idx = bisect_right(n_parameters_cs, n_freeze_parameters)
        for _, p in zip(range(idx), model.parameters()):
            p.requires_grad = False

    if len(freeze_parameters) > 0:
        for n, p in model.named_parameters():
            for freeze_p in freeze_parameters:
                if n.startswith(freeze_p):
                    p.requires_grad = False

    if freeze_parameters_regex is not None:
        try:
            pattern = re.compile(freeze_parameters_regex)
        except re.error as e:
            logger.warning(f"Invalid freeze_parameters_regex '{freeze_parameters_regex}': {e}")
            return

        for n, p in model.named_parameters():
            if pattern.search(n):
                p.requires_grad = False
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def activate_parameters(model: nn.Module,
                        additional_trainable_parameters: List[str],
                        trainable_parameters_regex: Optional[str] = None) -> None:
    """Re-enable gradients for parameters selected by name prefix and/or regex.

    Warns when a selector is given but matches no parameter; an invalid regex
    pattern is logged and the regex step is skipped.
    """
    has_activate = False
    if len(additional_trainable_parameters) > 0:
        for n, p in model.named_parameters():
            for additional_tp in additional_trainable_parameters:
                if n.startswith(additional_tp):
                    p.requires_grad = True
                    has_activate = True
        if not has_activate:
            logger.warning('len(additional_trainable_parameters) > 0 but no parameters are activated. '
                           f'additional_trainable_parameters: {additional_trainable_parameters}')

    # Reset the flag: the regex warning is independent of the prefix warning.
    has_activate = False
    if trainable_parameters_regex is not None:
        try:
            pattern = re.compile(trainable_parameters_regex)
        except re.error as e:
            logger.warning(f"Invalid trainable_parameters_regex '{trainable_parameters_regex}': {e}")
            return

        for n, p in model.named_parameters():
            if pattern.search(n):
                p.requires_grad = True
                has_activate = True

        if not has_activate:
            logger.warning('trainable_parameters_regex is provided but no parameters are activated. '
                           f'trainable_parameters_regex: {trainable_parameters_regex}')
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def time_synchronize() -> float:
    """Synchronize CUDA, then return a high-resolution timestamp in seconds."""
    # Without the sync, pending CUDA kernels would make timings meaningless.
    torch.cuda.synchronize()
    return time.perf_counter()  # second
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _get_max_memory(device_ids: List[int]) -> Dict[Union[int, str], int]:
    """add feat in accelerate to support MP + DDP

    Builds a ``max_memory`` mapping (device index -> free bytes, plus 'cpu')
    in which devices outside ``device_ids`` are capped at 0 so accelerate
    will not place weights on them.
    """
    import psutil
    # Make sure CUDA is initialized on each GPU to have the right memory info.
    for i in device_ids:
        _ = torch.tensor([0], device=i)

    device_ids_set = set(device_ids)
    max_memory = {}
    for i in range(get_device_count()):
        max_memory[i] = 0
        if i in device_ids_set:
            max_memory[i] = torch.cuda.mem_get_info(i)[0]  # free bytes
    max_memory['cpu'] = psutil.virtual_memory().available
    return max_memory
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _sync_max_memory(max_memory: Dict[Union[int, str], int]) -> Dict[Union[int, str], int]:
    """Make sure that the model structure of MP(device_map) is the same, when using DDP."""
    # All-gather every rank's per-GPU free-memory values and take the
    # element-wise min, so each rank derives an identical device_map.
    max_memory_list = [v for k, v in max_memory.items() if (v > 0 and k != 'cpu')]
    _, local_rank, world_size, _ = get_dist_setting()
    src_tensor = torch.tensor(max_memory_list).to(local_rank)
    tgt_tensor_list = [torch.zeros_like(src_tensor) for _ in range(world_size)]
    dist.all_gather(tgt_tensor_list, src_tensor)
    tgt_tensor = torch.stack(tgt_tensor_list, dim=0)
    new_max_memory_iter = iter(tgt_tensor.min(dim=0)[0].tolist())
    new_max_memory = {}
    for k, v in max_memory.items():
        new_max_memory[k] = v
        if v > 0 and k != 'cpu':
            # Replace with the cross-rank minimum, preserving key order.
            new_max_memory[k] = next(new_max_memory_iter)
    return new_max_memory
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def find_layers(
    model: nn.Module,
    cond: Callable[[str, nn.Module], bool],
    sub_module: Optional[str] = None,
    min_name_len: Optional[int] = None,
) -> List[str]:
    """Collect shortened module-name suffixes for modules satisfying ``cond``.

    Each matching module's name is shortened to a suffix, then extended to
    the left while it still collides with any "inner node" (a module path
    that fails ``cond``, with numeric path components generalized to '{}'),
    so the returned suffixes unambiguously identify target layers (e.g. for
    LoRA ``target_modules``).
    """
    # The content of target_module_names cannot exist in inner_nodes.
    sub_module_str = sub_module
    if sub_module is None:
        sub_module = model
    else:
        sub_module = deep_getattr(model, sub_module)
    inner_nodes = set()
    for name, module in model.named_modules():
        name = re.sub(r'\d+\.', '{}.', name)  # generalize layer indices
        if not cond(name, module):
            inner_nodes.add(name)
    target_module_names = set()
    for name, module in sub_module.named_modules():
        if sub_module_str:
            # Re-qualify names relative to the full model.
            name = f'{sub_module_str}.{name}' if name else sub_module_str
        if cond(name, module):
            module_name_list = name.split('.')
            module_name = module_name_list.pop()
            i = 1
            for inner_node in inner_nodes:
                # Prepend parent components while the suffix is still ambiguous
                # (or shorter than the requested minimum component count).
                while module_name_list and inner_node.endswith(re.sub(
                        r'\d+\.', '{}.', module_name)) or min_name_len and i < min_name_len:
                    module_name = f'{module_name_list.pop()}.{module_name}'
                    i += 1
            target_module_names.add(module_name)
    return list(target_module_names)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def find_norm(model: nn.Module) -> List[str]:
    # find_layer_norm
    # Matches nn.LayerNorm instances and any module whose class name contains
    # 'rmsnorm' (case-insensitive), e.g. LlamaRMSNorm.
    return find_layers(
        model,
        lambda name, module: isinstance(module, torch.nn.LayerNorm) or 'rmsnorm' in module.__class__.__name__.lower())
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def find_embedding(model: nn.Module) -> List[str]:
    """Return shortened names of all ``nn.Embedding`` layers in ``model``."""
    return find_layers(model, lambda name, module: isinstance(module, torch.nn.Embedding))
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def find_all_linears(model, model_arch=None, extra_layers=None, sub_module=None):
    """Return names of all linear-like layers usable as tuning targets.

    Excludes the lm_head (resolved from ``model_arch`` when available),
    classification/reward heads, and pre-existing LoRA sub-layers.

    Args:
        model: The model to scan.
        model_arch: Optional architecture metadata; looked up from the model
            when omitted.
        extra_layers: Optional extra module classes to also treat as targets.
        sub_module: Optional dotted path restricting the scan to a submodule.
    """
    if model_arch is None:
        from swift.llm import get_model_arch
        model_arch = get_model_arch(model.model_meta.model_arch)
    # lm_head
    if model_arch and model_arch.lm_head:
        output = model_arch.lm_head
        idx = output.rfind('.')
        lm_head_name = output[idx + 1:]  # last path component
    else:
        lm_head_name = 'lm_head'
    # 'score', 'classifier': classification model
    # 'v_head': reward model
    ignore_layers = [lm_head_name, 'score', 'v_head', 'classifier'] + ['lora_A', 'lora_B', 'base_layer']
    ignore_linear_cls = [
        'glulinear'  # phi4-mm
    ]

    def _cond(name, module):
        # Target: an extra_layers instance, or a class whose name contains
        # 'linear' (minus ignored linear classes) — and whose path avoids
        # every ignored layer name.
        module_name = module.__class__.__name__.lower()
        if (extra_layers and isinstance(module, tuple(extra_layers)) or
                ('linear' in module_name and all(linear_cls not in module_name
                                                 for linear_cls in ignore_linear_cls))) and all(layer not in name
                                                                                                for layer in ignore_layers):
            return True
        return False

    return find_layers(model, _cond, sub_module=sub_module)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
@contextmanager
def safe_ddp_context(hash_id: Optional[str], use_barrier: bool = False):
    """Serialize a critical section across distributed workers.

    With ``use_barrier`` and an initialized process group, the master ranks
    run the body first while the others wait at barriers; otherwise a file
    lock keyed on ``hash_id`` guards the section across processes on one
    machine. With neither, this is a no-op.
    """
    if use_barrier and dist.is_initialized():
        if is_dist() or is_dist_ta():
            # Non-master ranks wait here; the matching barrier is released
            # by the master after it has executed the body.
            if not is_master():
                dist.barrier()
            if not is_local_master():
                # Compatible with multi-machine scenarios,
                # where each machine uses different storage hardware.
                dist.barrier()
        yield
        if is_dist() or is_dist_ta():
            if is_master():
                dist.barrier()
            if is_local_master():
                dist.barrier()
    elif hash_id is not None:
        # Cross-process mutual exclusion via a file lock in the cache dir.
        lock_dir = os.path.join(get_cache_dir(), 'lockers')
        os.makedirs(lock_dir, exist_ok=True)
        file_path = hashlib.sha256(hash_id.encode('utf-8')).hexdigest() + '.lock'
        file_path = os.path.join(lock_dir, file_path)
        with FileLock(file_path):
            yield
    else:
        yield
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def get_device(local_rank: Optional[Union[str, int]] = None) -> str:
    """Return the device string ('npu:N' / 'mps:N' / 'cuda:N' / 'cpu') for a rank.

    Defaults to the current LOCAL_RANK from the dist setting (clamped to >= 0)
    when ``local_rank`` is not given.
    """
    if local_rank is None:
        local_rank = max(0, get_dist_setting()[1])
    local_rank = str(local_rank)
    # Backend priority: NPU, then MPS, then CUDA, then CPU fallback.
    if is_torch_npu_available():
        return f'npu:{local_rank}'
    if is_torch_mps_available():
        return f'mps:{local_rank}'
    if is_torch_cuda_available():
        return f'cuda:{local_rank}'
    return 'cpu'
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def get_current_device():
    """Return the active accelerator device index, or 'mps'/'cpu'."""
    # NPU takes precedence over CUDA, which takes precedence over MPS.
    if is_torch_npu_available():
        return torch.npu.current_device()
    if is_torch_cuda_available():
        return torch.cuda.current_device()
    if is_torch_mps_available():
        return 'mps'
    return 'cpu'
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def set_device(local_rank: Optional[Union[str, int]] = None):
    """Bind the current process to the NPU/CUDA device for ``local_rank``.

    Defaults to the current LOCAL_RANK (clamped to >= 0); a no-op when
    neither NPU nor CUDA is available.
    """
    if local_rank is None:
        local_rank = max(0, get_dist_setting()[1])
    if is_torch_npu_available():
        torch.npu.set_device(local_rank)
    elif is_torch_cuda_available():
        torch.cuda.set_device(local_rank)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def get_device_count() -> int:
    """Number of visible NPU or CUDA devices (0 when neither is available)."""
    if is_torch_npu_available():
        return torch.npu.device_count()
    if is_torch_cuda_available():
        return torch.cuda.device_count()
    return 0
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def gc_collect() -> None:
    """Run Python garbage collection, then empty the accelerator cache."""
    gc.collect()
    if is_torch_npu_available():
        torch.npu.empty_cache()
    elif is_torch_mps_available():
        torch.mps.empty_cache()
    elif is_torch_cuda_available():
        torch.cuda.empty_cache()
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
class Serializer:
    """Pickle arbitrary objects into/out of 1-D uint8 tensors.

    The byte stream is length-prefixed (little header of one int64), so
    trailing padding bytes — e.g. from padded collective gathers — are
    ignored on deserialization.
    """

    @staticmethod
    def to_tensor(obj):
        """Serialize ``obj`` into a uint8 tensor, prefixed with its length."""
        payload = pickle.dumps(obj)
        raw = np.array([len(payload)], dtype=np.int64).tobytes() + payload
        # Copy so the tensor owns its memory instead of the temporary buffer.
        buf = np.frombuffer(raw, dtype=np.uint8).copy()
        return torch.from_numpy(buf)

    @staticmethod
    def from_tensor(obj):
        """Inverse of :meth:`to_tensor`; accepts a tensor or ndarray."""
        if isinstance(obj, torch.Tensor):
            obj = obj.cpu().numpy()
        raw = obj.tobytes()
        payload_len = np.frombuffer(raw[:8], dtype=np.int64)[0]
        return pickle.loads(raw[8:8 + payload_len])
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def set_default_ddp_config():
    # It runs normally with Python as well.
    # When not launched by torchrun (RANK unset), fill in single-process
    # defaults so downstream distributed initialization still works.
    rank = int(os.getenv('RANK', -1))
    if rank == -1:
        os.environ['NPROC_PER_NODE'] = '1'
        os.environ['RANK'] = '0'
        os.environ['LOCAL_RANK'] = '0'
        os.environ['WORLD_SIZE'] = '1'
        os.environ['LOCAL_WORLD_SIZE'] = '1'
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        # Respect a pre-set MASTER_PORT; otherwise use the torch default.
        os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def init_process_group(ddp_backend: Optional[str] = None):
    """Initialize torch.distributed (idempotent).

    Chooses hccl (NPU) / nccl (CUDA) / gloo automatically when
    ``ddp_backend`` is not given.
    """
    if dist.is_initialized():
        return
    # Bind the process to its device before creating the process group.
    set_device()
    if ddp_backend is None:
        if is_torch_npu_available():
            ddp_backend = 'hccl'
        elif torch.cuda.is_available():
            ddp_backend = 'nccl'
        else:
            ddp_backend = 'gloo'
    dist.init_process_group(backend=ddp_backend)
|
ms-swift/tests/deploy/test_dataset.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def _test_client(port=8000):
    """Smoke-test an already-running deploy server through ``InferClient``.

    Loads a small dataset, waits for the server on ``port`` to come up, sends
    all samples as inference requests and prints the response count.
    """
    import time
    # Removed unused imports (aiohttp, run_deploy) from the original.
    from swift.llm import InferClient, InferRequest, RequestConfig, load_dataset
    dataset = load_dataset(['AI-ModelScope/alpaca-gpt4-data-zh#1000'], num_proc=4)
    infer_client = InferClient(port=port)
    # Poll until the server is reachable: accessing .models raises while the
    # server is still starting up.
    while True:
        try:
            infer_client.models
            break
        except Exception:
            time.sleep(1)
    infer_requests = []
    for data in dataset[0]:
        infer_requests.append(InferRequest(**data))
    request_config = RequestConfig(seed=42, max_tokens=256, temperature=0.8)

    resp = infer_client.infer(infer_requests, request_config=request_config, use_tqdm=False)
    print(len(resp))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _test(infer_backend):
    """Deploy Qwen2-7B-Instruct with ``infer_backend`` and run the client check."""
    import os
    # Pin the test to a single GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    from swift.llm import DeployArguments
    from swift.llm import run_deploy
    args = DeployArguments(model='Qwen/Qwen2-7B-Instruct', infer_backend=infer_backend, verbose=False)
    # run_deploy yields the port of the launched server and tears it down on exit.
    with run_deploy(args) as port:
        _test_client(port)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def test_vllm():
    """Deploy-and-infer smoke test using the vLLM backend."""
    _test('vllm')
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def test_lmdeploy():
    """Deploy-and-infer smoke test using the LMDeploy backend."""
    _test('lmdeploy')
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def test_pt():
    """Deploy-and-infer smoke test using the native PyTorch backend."""
    _test('pt')
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def test_vllm_origin():
    """Launch a stock vLLM OpenAI API server as a subprocess and smoke-test it."""
    import subprocess
    import sys
    from modelscope import snapshot_download
    model_dir = snapshot_download('Qwen/Qwen2-7B-Instruct')
    args = [sys.executable, '-m', 'vllm.entrypoints.openai.api_server', '--model', model_dir]
    process = subprocess.Popen(args)
    try:
        _test_client()
    finally:
        # Always shut the server down, even when the client check raises
        # (the original leaked the subprocess on failure).
        process.terminate()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
if __name__ == '__main__':
    # Manual entry point: enable exactly one backend test by (un)commenting.
    # test_vllm_origin()
    # test_vllm()
    test_lmdeploy()
    # test_pt()
|