import os
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Union

import json

from swift.hub import get_hub
from swift.llm import Processor, Template, get_model_tokenizer, get_template, load_by_unsloth, safe_snapshot_download
from swift.llm.utils import get_ckpt_dir
from swift.plugin import extra_tuners
from swift.utils import (check_json_format, get_dist_setting, get_logger, import_external_file, is_dist, is_master,
                         set_device, use_hf_hub)

from .data_args import DataArguments
from .generation_args import GenerationArguments
from .model_args import ModelArguments
from .quant_args import QuantizeArguments
from .template_args import TemplateArguments

logger = get_logger()


def get_supported_tuners():
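    """Return the supported train types: the built-in tuners, 'full', and any tuners registered in `extra_tuners`."""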
    return {'lora', 'full', 'longlora', 'adalora', 'llamapro', 'adapter', 'vera', 'boft', 'fourierft', 'reft', 'bone'
            } | set(extra_tuners.keys())


@dataclass
class CompatArguments:
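    """Arguments kept for backward compatibility; `ckpt_dir` and `lora_modules` are mapped onto `model` and
    `adapters` in `__post_init__`."""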
    ckpt_dir: Optional[str] = None
    lora_modules: List[str] = field(default_factory=list)

    def _handle_ckpt_dir(self: 'BaseArguments'):
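        """Route the legacy `ckpt_dir` to `adapters` (adapter checkpoint) or to `model` (full checkpoint)."""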
        assert os.path.isdir(self.ckpt_dir), f'self.ckpt_dir: {self.ckpt_dir}'
        if (os.path.exists(os.path.join(self.ckpt_dir, 'adapter_config.json'))
                or os.path.exists(os.path.join(self.ckpt_dir, 'default', 'adapter_config.json'))
                or os.path.exists(os.path.join(self.ckpt_dir, 'reft'))):
            if self.ckpt_dir in self.adapters:
                return
            self.adapters.insert(0, self.ckpt_dir)
        else:
            self.model = self.ckpt_dir
        self.ckpt_dir = None

    def __post_init__(self: 'BaseArguments'):
        if self.ckpt_dir is not None:
            self._handle_ckpt_dir()

        if len(self.lora_modules) > 0:
            self.adapters += self.lora_modules


@dataclass
class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, DataArguments, TemplateArguments,
                    ModelArguments):
""" |
|
|
BaseArguments class is a dataclass that inherits from multiple argument classes: |
|
|
GenerationArguments, QuantizeArguments, DataArguments, TemplateArguments, ModelArguments. |
|
|
|
|
|
Args: |
|
|
tuner_backend(str): Support peft or unsloth. |
|
|
train_type(str): The training type, support all supported tuners and `full`. |
|
|
seed (int): Random seed for reproducibility. Default is 42. |
|
|
model_kwargs (Optional[str]): Additional keyword arguments for the model. Default is None. |
|
|
load_data_args (bool): Flag to determine if dataset configuration should be loaded. Default is False. |
|
|
use_hf (bool): Flag to determine if Hugging Face should be used. Default is False. |
|
|
hub_token (Optional[str]): SDK token for authentication. Default is None. |
|
|
custom_register_path (List[str]): Path to custom .py file for dataset registration. Default is None. |
|
|
ignore_args_error (bool): Flag to ignore argument errors for notebook compatibility. Default is False. |
|
|
use_swift_lora (bool): Use swift lora, a compatible argument |
|
|
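
    Example:
        A minimal sketch (the model id below is a placeholder):

        >>> args = BaseArguments(model='<model_id_or_path>')
        >>> model, processor = args.get_model_processor()
        >>> template = args.get_template(processor)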
""" |
|
|
    tuner_backend: Literal['peft', 'unsloth'] = 'peft'
    train_type: str = field(default='lora', metadata={'help': f'train_type choices: {list(get_supported_tuners())}'})
    adapters: List[str] = field(default_factory=list)
    external_plugins: List[str] = field(default_factory=list)

    seed: int = 42
    model_kwargs: Optional[Union[dict, str]] = None
    load_args: bool = True
    load_data_args: bool = False

    use_hf: bool = False

    hub_token: Optional[str] = field(
        default=None, metadata={'help': 'SDK token can be found in https://modelscope.cn/my/myaccesstoken'})
    custom_register_path: List[str] = field(default_factory=list)

    ignore_args_error: bool = False
    use_swift_lora: bool = False

    def _prepare_training_args(self, training_args: Dict[str, Any]) -> None:
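        """Hook for subclasses to adjust `training_args`; the base implementation does nothing."""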
        pass

    def _init_custom_register(self) -> None:
"""Register custom .py file to datasets""" |
|
|
if isinstance(self.custom_register_path, str): |
|
|
self.custom_register_path = [self.custom_register_path] |
|
|
if not self.custom_register_path: |
|
|
return |
|
|
for path in self.custom_register_path: |
|
|
import_external_file(path) |
|
|
logger.info(f'Successfully registered {self.custom_register_path}.') |
|
|
|
|
|
def _import_external_plugins(self): |
|
|
if isinstance(self.external_plugins, str): |
|
|
self.external_plugins = [self.external_plugins] |
|
|
if not self.external_plugins: |
|
|
return |
|
|
for external_plugin in self.external_plugins: |
|
|
import_external_file(external_plugin) |
|
|
logger.info(f'Successfully imported external_plugins: {self.external_plugins}.') |
|
|
|
|
|
@staticmethod |
|
|
def _check_is_adapter(adapter_dir: str) -> bool: |
|
|
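        """Return True if `adapter_dir` contains an adapter checkpoint (an `adapter_config.json` or a `reft` dir)."""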
        if (os.path.exists(os.path.join(adapter_dir, 'adapter_config.json'))
                or os.path.exists(os.path.join(adapter_dir, 'default', 'adapter_config.json'))
                or os.path.exists(os.path.join(adapter_dir, 'reft'))):
            return True
        return False

    def _init_adapters(self):
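        """Normalize `adapters` to a list and resolve each entry to a local directory via `safe_snapshot_download`."""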
        if isinstance(self.adapters, str):
            self.adapters = [self.adapters]
        self.adapters = [
            safe_snapshot_download(adapter, use_hf=self.use_hf, hub_token=self.hub_token) for adapter in self.adapters
        ]

    def __post_init__(self):
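        """Resolve hub settings, adapters and `ckpt_dir`, import custom registrations and external plugins, and
        run the `__post_init__` of the parent argument groups."""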
        if self.use_hf or use_hf_hub():
            self.use_hf = True
            os.environ['USE_HF'] = '1'
        CompatArguments.__post_init__(self)
        self._init_adapters()
        self._init_ckpt_dir()
        self._init_custom_register()
        self._import_external_plugins()
        self._init_model_kwargs()

        self.rank, self.local_rank, self.global_world_size, self.local_world_size = get_dist_setting()
        logger.info(f'rank: {self.rank}, local_rank: {self.local_rank}, '
                    f'world_size: {self.global_world_size}, local_world_size: {self.local_world_size}')
        if self.train_type not in extra_tuners:
            for adapter in self.adapters:
                assert self._check_is_adapter(adapter), (
                    f'`{adapter}` is not an adapter, please try using `--model` to pass it.')
        ModelArguments.__post_init__(self)
        QuantizeArguments.__post_init__(self)
        TemplateArguments.__post_init__(self)
        DataArguments.__post_init__(self)

        self.hub = get_hub(self.use_hf)
        if self.hub.try_login(self.hub_token):
            logger.info('hub login successful!')

    def _init_model_kwargs(self):
"""Prepare model kwargs and set them to the env""" |
|
|
self.model_kwargs: Dict[str, Any] = self.parse_to_dict(self.model_kwargs) |
|
|
for k, v in self.model_kwargs.items(): |
|
|
k = k.upper() |
|
|
os.environ[k] = str(v) |
|
|
|
|
|
@property |
|
|
def is_adapter(self) -> bool: |
|
|
return self.train_type not in {'full'} |
|
|
|
|
|
@property |
|
|
def supported_tuners(self): |
|
|
return get_supported_tuners() |
|
|
|
|
|
@property |
|
|
def adapters_can_be_merged(self): |
|
|
return {'lora', 'longlora', 'llamapro', 'adalora'} |
|
|
|
|
|
@classmethod |
|
|
def from_pretrained(cls, checkpoint_dir: str): |
|
|
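        """Build a BaseArguments instance from the `args.json` saved in `checkpoint_dir` (bypasses `__init__`)."""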
        self = super().__new__(cls)
        self.load_data_args = True
        self.ckpt_dir = checkpoint_dir
        self.load_args_from_ckpt()
        all_keys = list(f.name for f in fields(BaseArguments))
        for key in all_keys:
            if not hasattr(self, key):
                setattr(self, key, None)
        return self

    def _init_ckpt_dir(self, adapters=None):
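        """Derive `ckpt_dir` from the model/adapter paths and, if `load_args` is set, load arguments from it."""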
        model = self.model or getattr(self, 'mcore_model', None) or getattr(self, 'load', None)
        self.ckpt_dir = get_ckpt_dir(model, adapters or self.adapters)
        if self.ckpt_dir and self.load_args:
            self.load_args_from_ckpt()

    def load_args_from_ckpt(self) -> None:
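        """Fill unset (or always-loaded) arguments from `<ckpt_dir>/args.json`, skipping generation and compat args."""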
        from ..train_args import TrainArguments
        args_path = os.path.join(self.ckpt_dir, 'args.json')
        assert os.path.exists(args_path), f'args_path: {args_path}'
        with open(args_path, 'r', encoding='utf-8') as f:
            old_args = json.load(f)
        all_keys = list(f.name for f in fields(BaseArguments))
        data_keys = list(f.name for f in fields(DataArguments))
        # keys loaded from the checkpoint even if a value is already set
        load_keys = [
            'bnb_4bit_quant_type',
            'bnb_4bit_use_double_quant',
            'train_type',
            'tuner_backend',
            'use_swift_lora',
            'model_name',
            'model_author',
            'split_dataset_ratio',
            'use_chat_template',
        ]
        skip_keys = list(f.name for f in fields(GenerationArguments) + fields(CompatArguments)) + ['adapters']
        if not isinstance(self, TrainArguments):
            skip_keys += ['max_length']
        all_keys = set(all_keys) - set(skip_keys)
        for key, old_value in old_args.items():
            if key not in all_keys or old_value is None:
                continue
            if not self.load_data_args and key in data_keys:
                continue
            value = getattr(self, key, None)
            if value is None or (isinstance(value, (list, tuple)) and len(value) == 0) or key in load_keys:
                setattr(self, key, old_value)
        logger.info(f'Successfully loaded {args_path}.')

    def save_args(self, output_dir=None) -> None:
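        """Save the current arguments to `<output_dir>/args.json` (master process only)."""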
        if is_master():
            output_dir = output_dir or self.output_dir
            os.makedirs(output_dir, exist_ok=True)
            fpath = os.path.join(output_dir, 'args.json')
            logger.info(f'The {self.__class__.__name__} will be saved in: {fpath}')
            with open(fpath, 'w', encoding='utf-8') as f:
                json.dump(check_json_format(self.__dict__), f, ensure_ascii=False, indent=2)

    def _init_device(self):
        if is_dist():
            set_device()

    def get_template(self, processor: 'Processor', template_type: Optional[str] = None) -> 'Template':
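        """Create the template for `processor`; `template_type` defaults to `self.template`."""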
        template_kwargs = self.get_template_kwargs()
        template_type = template_type or self.template
        template = get_template(template_type, processor, **template_kwargs)
        return template

    def get_model_processor(self,
                            *,
                            model=None,
                            model_type=None,
                            model_revision=None,
                            task_type=None,
                            num_labels=None,
                            **kwargs):
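        """Load the model and its processor/tokenizer; keyword arguments not given explicitly fall back to `self`."""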
        if self.tuner_backend == 'unsloth':
            return load_by_unsloth(self)
        kwargs.update(self.get_model_kwargs())

        kwargs['model_id_or_path'] = model or self.model
        kwargs['model_type'] = model_type or self.model_type
        kwargs['model_revision'] = model_revision or self.model_revision
        kwargs['task_type'] = task_type or self.task_type
        kwargs['num_labels'] = num_labels or self.num_labels

        return get_model_tokenizer(**kwargs)