# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Union
import json
from swift.hub import get_hub
from swift.llm import Processor, Template, get_model_tokenizer, get_template, load_by_unsloth, safe_snapshot_download
from swift.llm.utils import get_ckpt_dir
from swift.plugin import extra_tuners
from swift.utils import (check_json_format, get_dist_setting, get_logger, import_external_file, is_dist, is_master,
                         set_device, use_hf_hub)
from .data_args import DataArguments
from .generation_args import GenerationArguments
from .model_args import ModelArguments
from .quant_args import QuantizeArguments
from .template_args import TemplateArguments

logger = get_logger()


def get_supported_tuners():
    return {'lora', 'full', 'longlora', 'adalora', 'llamapro', 'adapter', 'vera', 'boft', 'fourierft', 'reft', 'bone'
            } | set(extra_tuners.keys())


@dataclass
class CompatArguments:
    ckpt_dir: Optional[str] = None
    lora_modules: List[str] = field(default_factory=list)

    def _handle_ckpt_dir(self: 'BaseArguments'):
        assert os.path.isdir(self.ckpt_dir), f'self.ckpt_dir: {self.ckpt_dir}'
        if (os.path.exists(os.path.join(self.ckpt_dir, 'adapter_config.json'))
                or os.path.exists(os.path.join(self.ckpt_dir, 'default', 'adapter_config.json'))
                or os.path.exists(os.path.join(self.ckpt_dir, 'reft'))):
            if self.ckpt_dir in self.adapters:
                return
            self.adapters.insert(0, self.ckpt_dir)
        else:
            self.model = self.ckpt_dir
        self.ckpt_dir = None

    def __post_init__(self: 'BaseArguments'):
        if self.ckpt_dir is not None:
            self._handle_ckpt_dir()
        if len(self.lora_modules) > 0:
            self.adapters += self.lora_modules


@dataclass
class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, DataArguments, TemplateArguments,
                    ModelArguments):
    """
    BaseArguments is a dataclass that inherits from multiple argument classes:
    CompatArguments, GenerationArguments, QuantizeArguments, DataArguments, TemplateArguments, ModelArguments.

    Args:
        tuner_backend (str): The tuner backend; supports 'peft' or 'unsloth'. Default is 'peft'.
        train_type (str): The training type; supports every registered tuner as well as 'full'. Default is 'lora'.
        seed (int): Random seed for reproducibility. Default is 42.
        model_kwargs (Optional[Union[dict, str]]): Additional keyword arguments for the model. Default is None.
        load_data_args (bool): Whether to also load the dataset configuration from a checkpoint. Default is False.
        use_hf (bool): Whether to use the Hugging Face hub (instead of ModelScope). Default is False.
        hub_token (Optional[str]): SDK token for hub authentication. Default is None.
        custom_register_path (List[str]): Paths to custom .py files for dataset registration. Default: [].
        ignore_args_error (bool): Whether to ignore argument errors, for notebook compatibility. Default is False.
        use_swift_lora (bool): Whether to use SWIFT-format LoRA; kept as a compatibility argument.
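
    Example:
        A minimal usage sketch (the model id below is only an assumption, and hub access is required)::

            args = BaseArguments(model='Qwen/Qwen2.5-7B-Instruct', train_type='lora')
            model, processor = args.get_model_processor()
            template = args.get_template(processor)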
"""
    tuner_backend: Literal['peft', 'unsloth'] = 'peft'
    train_type: str = field(default='lora', metadata={'help': f'train_type choices: {list(get_supported_tuners())}'})
    adapters: List[str] = field(default_factory=list)
    external_plugins: List[str] = field(default_factory=list)
    seed: int = 42
    model_kwargs: Optional[Union[dict, str]] = None
    load_args: bool = True
    load_data_args: bool = False
    use_hf: bool = False
    # None: use env var `MODELSCOPE_API_TOKEN`
    hub_token: Optional[str] = field(
        default=None, metadata={'help': 'SDK token can be found in https://modelscope.cn/my/myaccesstoken'})
    custom_register_path: List[str] = field(default_factory=list)  # .py

    # extra
    ignore_args_error: bool = False  # True: notebook compatibility
    use_swift_lora: bool = False  # True for using tuner_backend == swift, don't specify this unless you know what you are doing  # noqa

    def _prepare_training_args(self, training_args: Dict[str, Any]) -> None:
        pass

    def _init_custom_register(self) -> None:
        """Import custom .py files so that the datasets they register become available."""
        if isinstance(self.custom_register_path, str):
            self.custom_register_path = [self.custom_register_path]
        if not self.custom_register_path:
            return
        for path in self.custom_register_path:
            import_external_file(path)
        logger.info(f'Successfully registered {self.custom_register_path}.')
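
    # Illustrative note: passing `--custom_register_path my_register.py` (a hypothetical file) makes this method
    # import that file, so any dataset registration it performs at import time takes effect.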

    def _import_external_plugins(self):
        if isinstance(self.external_plugins, str):
            self.external_plugins = [self.external_plugins]
        if not self.external_plugins:
            return
        for external_plugin in self.external_plugins:
            import_external_file(external_plugin)
        logger.info(f'Successfully imported external_plugins: {self.external_plugins}.')

    @staticmethod
    def _check_is_adapter(adapter_dir: str) -> bool:
        if (os.path.exists(os.path.join(adapter_dir, 'adapter_config.json'))
                or os.path.exists(os.path.join(adapter_dir, 'default', 'adapter_config.json'))
                or os.path.exists(os.path.join(adapter_dir, 'reft'))):
            return True
        return False

    def _init_adapters(self):
        if isinstance(self.adapters, str):
            self.adapters = [self.adapters]
        self.adapters = [
            safe_snapshot_download(adapter, use_hf=self.use_hf, hub_token=self.hub_token) for adapter in self.adapters
        ]

    def __post_init__(self):
        if self.use_hf or use_hf_hub():
            self.use_hf = True
            os.environ['USE_HF'] = '1'
        CompatArguments.__post_init__(self)
        self._init_adapters()
        self._init_ckpt_dir()
        self._init_custom_register()
        self._import_external_plugins()
        self._init_model_kwargs()
        # Seq2SeqTrainingArguments exposes `world_size` as a property that cannot be assigned,
        # so the value is stored as `global_world_size` here.
        self.rank, self.local_rank, self.global_world_size, self.local_world_size = get_dist_setting()
        logger.info(f'rank: {self.rank}, local_rank: {self.local_rank}, '
                    f'world_size: {self.global_world_size}, local_world_size: {self.local_world_size}')
        if self.train_type not in extra_tuners:
            for adapter in self.adapters:
                assert self._check_is_adapter(adapter), (
                    f'`{adapter}` is not an adapter, please try using `--model` to pass it.')
        ModelArguments.__post_init__(self)
        QuantizeArguments.__post_init__(self)
        TemplateArguments.__post_init__(self)
        DataArguments.__post_init__(self)
        self.hub = get_hub(self.use_hf)
        if self.hub.try_login(self.hub_token):
            logger.info('hub login successful!')

    def _init_model_kwargs(self):
        """Parse `model_kwargs` and export each entry as an upper-cased environment variable."""
        self.model_kwargs: Dict[str, Any] = self.parse_to_dict(self.model_kwargs)
        for k, v in self.model_kwargs.items():
            k = k.upper()
            os.environ[k] = str(v)
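
    # Illustrative example (the key below is only an example, not a fixed contract):
    #   --model_kwargs '{"max_pixels": 1003520}'  ->  os.environ['MAX_PIXELS'] = '1003520'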

    @property
    def is_adapter(self) -> bool:
        return self.train_type not in {'full'}

    @property
    def supported_tuners(self):
        return get_supported_tuners()

    @property
    def adapters_can_be_merged(self):
        return {'lora', 'longlora', 'llamapro', 'adalora'}

    @classmethod
    def from_pretrained(cls, checkpoint_dir: str):
        self = super().__new__(cls)
        self.load_data_args = True
        self.ckpt_dir = checkpoint_dir
        self.load_args_from_ckpt()
        all_keys = list(f.name for f in fields(BaseArguments))
        for key in all_keys:
            if not hasattr(self, key):
                setattr(self, key, None)
        return self
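
    # Illustrative usage sketch (the checkpoint path is hypothetical): rebuild the arguments saved with a run, e.g.
    #   args = BaseArguments.from_pretrained('output/checkpoint-100')
    # which reads `args.json` from that directory via `load_args_from_ckpt`.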

    def _init_ckpt_dir(self, adapters=None):
        # compat megatron
        model = self.model or getattr(self, 'mcore_model', None) or getattr(self, 'load', None)
        self.ckpt_dir = get_ckpt_dir(model, adapters or self.adapters)
        if self.ckpt_dir and self.load_args:
            self.load_args_from_ckpt()

    def load_args_from_ckpt(self) -> None:
        from ..train_args import TrainArguments
        args_path = os.path.join(self.ckpt_dir, 'args.json')
        assert os.path.exists(args_path), f'args_path: {args_path}'
        with open(args_path, 'r', encoding='utf-8') as f:
            old_args = json.load(f)
        all_keys = list(f.name for f in fields(BaseArguments))
        data_keys = list(f.name for f in fields(DataArguments))
        load_keys = [
            # quant_args
            'bnb_4bit_quant_type',
            'bnb_4bit_use_double_quant',
            # base_args
            'train_type',
            'tuner_backend',
            'use_swift_lora',
            # data_args
            'model_name',
            'model_author',
            'split_dataset_ratio',
            # template_args
            'use_chat_template',
        ]
        skip_keys = list(f.name for f in fields(GenerationArguments) + fields(CompatArguments)) + ['adapters']
        if not isinstance(self, TrainArguments):
            skip_keys += ['max_length']
        all_keys = set(all_keys) - set(skip_keys)
        for key, old_value in old_args.items():
            if key not in all_keys or old_value is None:
                continue
            if not self.load_data_args and key in data_keys:
                continue
            value = getattr(self, key, None)
            # Overwrite only when the current value is unset/empty, or the key is whitelisted in `load_keys`.
            if value is None or isinstance(value, (list, tuple)) and len(value) == 0 or key in load_keys:
                setattr(self, key, old_value)
        logger.info(f'Successfully loaded {args_path}.')
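
    # Illustrative sketch of the merge behaviour (the values are made up): for a saved args.json of
    #   {"train_type": "lora", "max_length": 2048, "dataset": ["my_ds"]}
    # `train_type` is always taken from the checkpoint (it is in `load_keys`), `dataset` is applied only
    # when `load_data_args` is True, and `max_length` is skipped for non-training runs.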

    def save_args(self, output_dir=None) -> None:
        if is_master():
            output_dir = output_dir or self.output_dir
            os.makedirs(output_dir, exist_ok=True)
            fpath = os.path.join(output_dir, 'args.json')
            logger.info(f'The {self.__class__.__name__} will be saved in: {fpath}')
            with open(fpath, 'w', encoding='utf-8') as f:
                json.dump(check_json_format(self.__dict__), f, ensure_ascii=False, indent=2)

    def _init_device(self):
        if is_dist():
            set_device()

    def get_template(self, processor: 'Processor', template_type: Optional[str] = None) -> 'Template':
        template_kwargs = self.get_template_kwargs()
        template_type = template_type or self.template
        template = get_template(template_type, processor, **template_kwargs)
        return template

    def get_model_processor(self,
                            *,
                            model=None,
                            model_type=None,
                            model_revision=None,
                            task_type=None,
                            num_labels=None,
                            **kwargs):
        if self.tuner_backend == 'unsloth':
            return load_by_unsloth(self)
        kwargs.update(self.get_model_kwargs())
        # compat rlhf
        kwargs['model_id_or_path'] = model or self.model
        kwargs['model_type'] = model_type or self.model_type
        kwargs['model_revision'] = model_revision or self.model_revision
        kwargs['task_type'] = task_type or self.task_type
        kwargs['num_labels'] = num_labels or self.num_labels
        return get_model_tokenizer(**kwargs)