| |
| import os |
| from dataclasses import dataclass |
| from typing import Literal, Optional |
|
|
| import torch |
| import torch.distributed as dist |
|
|
| from swift.utils import get_logger, init_process_group, set_default_ddp_config |
| from .base_args import BaseArguments, to_abspath |
| from .merge_args import MergeArguments |
|
|
| logger = get_logger() |
|
|
|
|
@dataclass
class ExportArguments(MergeArguments, BaseArguments):
    """Arguments for model export: quantization, format conversion and hub push.

    Inherits from MergeArguments and BaseArguments.

    Args:
        output_dir (Optional[str]): Directory where the output will be saved.
            If None, it is derived from ``ckpt_dir`` plus a task-specific
            suffix in ``_init_output_dir``.
        quant_method (Optional[str]): Quantization backend: 'awq', 'gptq' or 'bnb'.
        quant_n_samples (int): Number of samples for quantization.
        max_length (int): Sequence length for quantization.
        quant_batch_size (Optional[int]): Batch size for quantization.
            A value of ``-1`` is normalized to None in ``__post_init__``.
        group_size (int): Group size used for weight quantization.
        to_ollama (bool): Flag to indicate export model to ollama format.
        to_mcore (bool): Convert weights to Megatron-Core format.
        to_hf (bool): Convert Megatron-Core weights back to HF format.
        mcore_model (Optional[str]): Path to the Megatron-Core model.
        thread_count (Optional[int]): Thread count used for conversion.
        test_convert_precision (bool): Verify precision after mcore/hf conversion.
        push_to_hub (bool): Flag to indicate if the output should be pushed to the model hub.
        hub_model_id (Optional[str]): Model ID for the hub.
        hub_private_repo (bool): Flag to indicate if the hub repository is private.
        commit_message (str): Commit message for pushing to the hub.
        to_peft_format (bool): Flag to indicate if the output should be in PEFT format.
            This argument is useless for now.
        exist_ok (bool): Allow ``output_dir`` to already exist instead of raising.
    """
    output_dir: Optional[str] = None

    # Quantization options.
    # NOTE: default is None, so the annotation must be Optional.
    quant_method: Optional[Literal['awq', 'gptq', 'bnb']] = None
    quant_n_samples: int = 256
    max_length: int = 2048
    # Optional because -1 is normalized to None in __post_init__.
    quant_batch_size: Optional[int] = 1
    group_size: int = 128

    # Ollama export.
    to_ollama: bool = False

    # Megatron-Core <-> HF conversion.
    to_mcore: bool = False
    to_hf: bool = False
    mcore_model: Optional[str] = None
    thread_count: Optional[int] = None
    test_convert_precision: bool = False

    # Hub push options.
    push_to_hub: bool = False
    hub_model_id: Optional[str] = None
    hub_private_repo: bool = False
    commit_message: str = 'update files'

    to_peft_format: bool = False
    exist_ok: bool = False

    def _init_output_dir(self) -> None:
        """Derive ``output_dir`` when unset, absolutize it, and guard against overwrite."""
        if self.output_dir is None:
            ckpt_dir = self.ckpt_dir or f'./{self.model_suffix}'
            ckpt_dir, ckpt_name = os.path.split(ckpt_dir)
            # Pick a suffix describing the selected export task; first match wins.
            if self.to_peft_format:
                suffix = 'peft'
            elif self.quant_method:
                suffix = f'{self.quant_method}-int{self.quant_bits}'
            elif self.to_ollama:
                suffix = 'ollama'
            elif self.merge_lora:
                suffix = 'merged'
            elif self.to_mcore:
                suffix = 'mcore'
            elif self.to_hf:
                suffix = 'hf'
            else:
                # No export task selected: leave output_dir as None.
                return

            self.output_dir = os.path.join(ckpt_dir, f'{ckpt_name}-{suffix}')

        self.output_dir = to_abspath(self.output_dir)
        if not self.exist_ok and os.path.exists(self.output_dir):
            raise FileExistsError(f'args.output_dir: `{self.output_dir}` already exists.')
        logger.info(f'args.output_dir: `{self.output_dir}`')

    def __post_init__(self):
        """Validate quantization arguments and prepare conversion/output settings.

        Raises:
            ValueError: If only one of ``quant_bits``/``quant_method`` is given,
                or if gptq/awq quantization is requested without a dataset.
            FileExistsError: If ``output_dir`` exists and ``exist_ok`` is False.
        """
        # Normalize the -1 sentinel to "no batching".
        if self.quant_batch_size == -1:
            self.quant_batch_size = None
        # quant_bits and quant_method must be provided together.
        if self.quant_bits and self.quant_method is None:
            raise ValueError('Please specify the quantization method using `--quant_method awq/gptq/bnb`.')
        if self.quant_method and self.quant_bits is None:
            raise ValueError('Please specify `--quant_bits`.')
        # gptq/awq calibration defaults to fp16 weights.
        if self.quant_method in {'gptq', 'awq'} and self.torch_dtype is None:
            self.torch_dtype = torch.float16
        if self.to_mcore or self.to_hf:
            self.mcore_model = to_abspath(self.mcore_model, check_path_exist=True)
            # mcore/hf conversion runs under torch.distributed; initialize lazily.
            if not dist.is_initialized():
                set_default_ddp_config()
                init_process_group()

        BaseArguments.__post_init__(self)
        self._init_output_dir()
        # gptq/awq need a calibration dataset.
        if self.quant_method in {'gptq', 'awq'} and len(self.dataset) == 0:
            raise ValueError(f'self.dataset: {self.dataset}, Please input the quant dataset.')
|
|