import os
import platform
import re
from copy import deepcopy
from dataclasses import asdict, dataclass, field
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
import transformers
from packaging import version
from peft import PeftModel
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification,
                          AutoTokenizer, GenerationConfig, PretrainedConfig, PreTrainedModel, PreTrainedTokenizerBase)
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import (is_torch_bf16_gpu_available, is_torch_cuda_available, is_torch_mps_available,
                                is_torch_npu_available, strtobool)
from transformers.utils.versions import require_version

from swift.utils import get_dist_setting, get_logger, is_mp, is_unsloth_available, patch_getattr, use_torchacc
from .constant import ModelType
from .patcher import (patch_automodel, patch_automodel_for_sequence_classification, patch_get_dynamic_module,
                      patch_mp_ddp, patch_tp_plan)
from .utils import AttnImpl, HfConfigFactory, InitModelStrategy, ModelInfo, safe_snapshot_download

GetModelTokenizerFunction = Callable[..., Tuple[Optional[PreTrainedModel], PreTrainedTokenizerBase]]
logger = get_logger()


@dataclass
class Model:
    # Provide at least one of: a ModelScope model id, a Hugging Face model id, or a local path.
    ms_model_id: Optional[str] = None
    hf_model_id: Optional[str] = None
    model_path: Optional[str] = None

    # Optional repository revisions on ModelScope / Hugging Face.
    ms_revision: Optional[str] = None
    hf_revision: Optional[str] = None


@dataclass
class ModelGroup:
    models: List[Model]

    # When set, these group-level fields take precedence over the same-named fields on ModelMeta.
    ignore_patterns: Optional[List[str]] = None
    requires: Optional[List[str]] = None
    tags: List[str] = field(default_factory=list)

    def __post_init__(self):
        if not isinstance(self.models, (tuple, list)):
            self.models = [self.models]


@dataclass
class ModelMeta:
    # The unique ID for the model type; models sharing a model_type also share
    # architectures, template, get_function, etc.
    model_type: Optional[str]
    model_groups: List[ModelGroup]
    template: Optional[str]
    get_function: GetModelTokenizerFunction

    model_arch: Optional[str] = None
    architectures: List[str] = field(default_factory=list)
    # Files beyond the standard weights/tokenizer that should also be saved.
    additional_saved_files: List[str] = field(default_factory=list)
    torch_dtype: Optional[torch.dtype] = None

    is_multimodal: bool = False
    is_reward: bool = False
    task_type: Optional[str] = None

    # File patterns to skip when downloading the model.
    ignore_patterns: Optional[List[str]] = None
    # Package requirements verified by `check_requires`, e.g. ['transformers>=4.45'].
    requires: List[str] = field(default_factory=list)
    tags: List[str] = field(default_factory=list)

    def __post_init__(self):
        if self.template is None:
            self.template = 'dummy'
        if not isinstance(self.model_groups, (list, tuple)):
            self.model_groups = [self.model_groups]

    def get_matched_model_group(self, model_name: str) -> Optional[ModelGroup]:
        for model_group in self.model_groups:
            for model in model_group.models:
                for key in ['ms_model_id', 'hf_model_id', 'model_path']:
                    value = getattr(model, key)
                    if isinstance(value, str) and model_name == value.rsplit('/', 1)[-1].lower():
                        return model_group
        return None

    def check_requires(self, model_info=None):
        extra_requires = []
        if model_info and model_info.quant_method:
            mapping = {'bnb': ['bitsandbytes'], 'awq': ['autoawq'], 'gptq': ['auto_gptq'], 'aqlm': ['aqlm']}
            extra_requires += mapping.get(model_info.quant_method, [])
        missing_requires = []
        for require in self.requires + extra_requires:
            try:
                require_version(require)
            except ImportError:
                missing_requires.append(f'"{require}"')
        if missing_requires:
            logger.warning(f'Please install the package: `pip install {" ".join(missing_requires)} -U`.')
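
# Usage sketch: `requires` entries are pip-style version specifiers checked via
# `require_version`; missing packages only trigger a warning. `my-model` below is a
# hypothetical model_type used for illustration.
#
#     meta = ModelMeta('my-model', [], None, get_model_tokenizer_with_flash_attn,
#                      requires=['transformers>=4.45'])
#     meta.check_requires()  # warns if transformers is older than 4.45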


MODEL_MAPPING: Dict[str, ModelMeta] = {}


def register_model(model_meta: ModelMeta, *, exist_ok: bool = False) -> None:
    """Register a ModelMeta into MODEL_MAPPING.

    model_type: The unique ID for the model type. Models with the same model_type share
        the same architectures, template, get_function, etc.
    """
    model_type = model_meta.model_type
    if not exist_ok and model_type in MODEL_MAPPING:
        raise ValueError(f'The model_type `{model_type}` has already been registered in MODEL_MAPPING.')
    from .constant import MLLMModelType, RMModelType
    if model_type in MLLMModelType.__dict__:
        model_meta.is_multimodal = True
    if model_type in RMModelType.__dict__:
        model_meta.is_reward = True
    MODEL_MAPPING[model_type] = model_meta
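
# A minimal registration sketch. Everything below is illustrative: `my-model` is a
# hypothetical model_type and the model ids are placeholders, not real repositories.
#
#     register_model(
#         ModelMeta(
#             model_type='my-model',
#             model_groups=[ModelGroup([Model(ms_model_id='my-org/my-model', hf_model_id='my-org/my-model')])],
#             template='dummy',
#             get_function=get_model_tokenizer_with_flash_attn,
#             architectures=['MyModelForCausalLM'],
#             requires=['transformers>=4.42'],
#         ))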


def load_by_unsloth(args):
    """Load the model with unsloth."""
    assert is_unsloth_available(), 'please install unsloth if using `use_unsloth=True`: `pip install unsloth`'
    os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
    os.environ['UNSLOTH_DISABLE_STATISTICS'] = '1'
    model_info = args.model_info
    model_meta = args.model_meta
    if model_meta.is_multimodal:
        from unsloth import FastVisionModel as UnslothModel
    else:
        from unsloth import FastLanguageModel as UnslothModel
    model, processor = UnslothModel.from_pretrained(
        model_name=args.adapters[0] if args.adapters else args.model_dir,
        dtype=args.torch_dtype,
        max_seq_length=args.max_length,
        full_finetuning=args.quant_bits is None,
        load_in_4bit=args.quant_bits == 4,
        load_in_8bit=args.quant_bits == 8,
    )
    base_model = model.model if isinstance(model, PeftModel) else model
    base_model.model_dir = args.model_dir
    base_model.model_info = model_info
    base_model.model_meta = model_meta
    processor.model_info = model_info
    processor.model_meta = model_meta
    return model, processor


def _patch_awq_compat(model_info):
    if version.parse(transformers.__version__) < version.parse('4.50') or model_info.quant_method != 'awq':
        return
    try:
        # Compat with transformers>=4.50: extend `modules_to_not_convert` with the keys
        # detected by `get_keys_to_not_convert` before the AWQ weights are loaded.
        from transformers.quantizers.quantizer_awq import AwqQuantizer
        from transformers.integrations import get_keys_to_not_convert
        _process_model_before_weight_loading = AwqQuantizer._process_model_before_weight_loading

        def _new_process_model_before_weight_loading(self, model, *args, **kwargs):
            modules_to_not_convert = self.quantization_config.modules_to_not_convert
            if modules_to_not_convert is not None:
                self.quantization_config.modules_to_not_convert = list(
                    modules_to_not_convert) + get_keys_to_not_convert(model)
            return _process_model_before_weight_loading(self, model, *args, **kwargs)

        AwqQuantizer._process_model_before_weight_loading = _new_process_model_before_weight_loading
    except Exception:
        pass


def get_model_tokenizer_from_local(model_dir: str,
                                   model_info: ModelInfo,
                                   model_kwargs: Dict[str, Any],
                                   load_model: bool = True,
                                   *,
                                   tokenizer=None,
                                   model_config=None,
                                   automodel_class=None,
                                   **kwargs):
    """Load the model and tokenizer from the local model_dir."""
    if model_config is None:
        model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    # Make sure `past_key_values` is ignored at inference.
    if not hasattr(model_config, 'keys_to_ignore_at_inference'):
        model_config.keys_to_ignore_at_inference = []
    if 'past_key_values' not in model_config.keys_to_ignore_at_inference:
        model_config.keys_to_ignore_at_inference.append('past_key_values')

    torch_dtype = model_info.torch_dtype
    model_config.torch_dtype = torch_dtype
    HfConfigFactory.compat_zero3(model_config)
    rope_scaling = kwargs.get('rope_scaling')
    if rope_scaling:
        HfConfigFactory.set_config_attr(model_config, 'rope_scaling', rope_scaling)

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

    num_labels = model_info.num_labels or getattr(model_config, 'num_labels', None)
    if num_labels and model_info.task_type == 'seq_cls':
        model_info.num_labels = num_labels
        model_config.num_labels = num_labels

    model = None
    if load_model:
        _patch_awq_compat(model_info)
        logger.info(f'model_kwargs: {model_kwargs}')
        # Prefer the native sequence-classification head when the architecture provides one.
        if model_info.task_type == 'seq_cls' and automodel_class is None:
            try:
                model = AutoModelForSequenceClassification.from_pretrained(
                    model_dir, config=model_config, torch_dtype=torch_dtype, trust_remote_code=True, **model_kwargs)
            except ValueError:
                model = None

        automodel_class = automodel_class or AutoModelForCausalLM
        model_meta = kwargs['model_meta']
        if model is None:
            if model_info.task_type == 'seq_cls' and not model_meta.is_reward:
                context = partial(patch_automodel_for_sequence_classification, model_meta=model_meta)
            elif model_info.task_type == 'seq_cls' and model_meta.is_reward and model_config.num_labels > 1:
                logger.warning('You are using a reward model for the seq_cls task and num_labels > 1; '
                               '`ignore_mismatched_sizes` will be set to True.')
                model_kwargs['ignore_mismatched_sizes'] = True
                context = partial(patch_automodel_for_sequence_classification, model_meta=model_meta)
            else:
                context = partial(patch_automodel, automodel_class=automodel_class, model_info=model_info)
            with context():
                model = automodel_class.from_pretrained(
                    model_dir, config=model_config, torch_dtype=torch_dtype, trust_remote_code=True, **model_kwargs)

        # If the checkpoint ships remote code, record the auto class so that
        # `save_pretrained` keeps the custom-code reference.
        auto_map = getattr(model_config, 'auto_map', None) or {}
        has_remote_code = automodel_class.__name__ in auto_map
        if has_remote_code and model._auto_class is None:
            model._auto_class = automodel_class.__name__

        if model_info.task_type == 'embedding' and automodel_class.__name__ != 'AutoModel':
            from swift.llm.model.patcher import patch_output_normalizer
            patch_output_normalizer(model, model_meta=model_meta)

        init_strategy = kwargs.get('init_strategy')
        if init_strategy is not None:
            InitModelStrategy.init_parameters(model, init_strategy)

    model_info.config = model_config if model is None else model.config
    if model:
        # Align pad_token_id between the model config and the tokenizer.
        pad_token_id = model.config.pad_token_id or tokenizer.pad_token_id
        HfConfigFactory.set_model_config_attr(model, 'pad_token_id', pad_token_id)
    return model, tokenizer
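
# Sketch of a custom `get_function` built on this helper, as typically passed to
# `ModelMeta.get_function`. The config tweak shown here is illustrative only.
#
#     def get_model_tokenizer_my_model(model_dir, model_info, model_kwargs, load_model=True, **kwargs):
#         model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
#         model_config.use_cache = False  # hypothetical model-specific adjustment
#         kwargs['model_config'] = model_config
#         return get_model_tokenizer_from_local(model_dir, model_info, model_kwargs, load_model, **kwargs)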


def get_model_tokenizer_with_flash_attn(model_dir: str,
                                        model_info: ModelInfo,
                                        model_kwargs: Dict[str, Any],
                                        load_model: bool = True,
                                        **kwargs):
    model_config = kwargs.get('model_config')
    if model_config is None:
        model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    AttnImpl.update_attn_impl(model_config, kwargs.get('attn_impl'), kwargs.get('attn_impl_keys'))
    kwargs['model_config'] = model_config
    return get_model_tokenizer_from_local(model_dir, model_info, model_kwargs, load_model, **kwargs)


def get_model_tokenizer_multimodal(model_dir: str, *args, **kwargs):
    from transformers import AutoProcessor
    processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
    kwargs['tokenizer'] = processor.tokenizer
    model, _ = get_model_tokenizer_with_flash_attn(model_dir, *args, **kwargs)
    return model, processor


def get_model_tokenizer_reward_model(model_dir, *args, **kwargs):
    model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    # Some reward models register their custom code under `AutoModel` in auto_map.
    if 'AutoModel' in (getattr(model_config, 'auto_map', None) or {}):
        kwargs['automodel_class'] = AutoModel
    return get_model_tokenizer_with_flash_attn(model_dir, *args, **kwargs)


def fix_do_sample_warning(generation_config: GenerationConfig) -> None:
    # Reset temperature/top_p/top_k to their defaults when sampling is disabled,
    # to avoid the transformers do_sample warnings.
    if generation_config.temperature == 0:
        generation_config.do_sample = False
    if generation_config.do_sample is False:
        generation_config.temperature = 1.
        generation_config.top_p = 1.
        generation_config.top_k = 50
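
# Usage sketch (constructs a config locally; nothing is downloaded):
#
#     generation_config = GenerationConfig(temperature=0, top_p=0.9)
#     fix_do_sample_warning(generation_config)
#     assert generation_config.do_sample is False and generation_config.top_p == 1.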


def get_default_device_map():
    if is_deepspeed_zero3_enabled() or os.environ.get('ACCELERATE_USE_FSDP', 'false').lower() == 'true':
        return None
    local_rank = get_dist_setting()[1]
    if local_rank == -1:
        local_rank = 0
    if is_torch_npu_available():
        return 'auto' if is_mp() else f'npu:{local_rank}'
    elif is_torch_mps_available():
        return f'mps:{local_rank}'
    elif is_torch_cuda_available():
        return 'auto' if is_mp() else f'cuda:{local_rank}'
    else:
        return 'cpu'
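
# Illustrative behavior: on a single-process CUDA setup this returns 'cuda:0' (or 'auto'
# under model parallelism); with DeepSpeed ZeRO-3 or FSDP it returns None so that no
# device_map is passed to `from_pretrained`.
#
#     device_map = get_default_device_map()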


def get_default_torch_dtype(torch_dtype: Optional[torch.dtype]):
    # `torch_dtype` is the dtype parsed from the model's config.json.
    if torch_dtype is not None:
        return torch_dtype

    try:
        is_bf16_available = is_torch_bf16_gpu_available() or (is_torch_npu_available()
                                                              and torch.npu.is_bf16_supported())
    except Exception:
        is_bf16_available = False

    if is_torch_cuda_available() or is_torch_npu_available():
        if is_bf16_available:
            return torch.bfloat16
        else:
            return torch.float16
    else:
        # CPU (and other non-CUDA/NPU backends) defaults to float32.
        return torch.float32
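
# Usage sketch: on a bf16-capable GPU this resolves to torch.bfloat16, on older GPUs to
# torch.float16, and on CPU-only machines to torch.float32.
#
#     torch_dtype = get_default_torch_dtype(None)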


def get_model_name(model_id_or_path: str) -> Optional[str]:
    assert isinstance(model_id_or_path, str), f'model_id_or_path: {model_id_or_path}'
    # Compat with the Hugging Face hub cache layout: `.../models--{org}--{name}/snapshots/...`.
    model_id_or_path = model_id_or_path.rstrip('/')
    match_ = re.search('/models--.+?--(.+?)/snapshots/', model_id_or_path)
    if match_ is not None:
        return match_.group(1)

    model_name = model_id_or_path.rsplit('/', 1)[-1]
    if platform.system().lower() == 'windows':
        model_name = model_name.rsplit('\\', 1)[-1]
    # The ModelScope cache replaces '.' with '___' in directory names; undo that here.
    model_name = model_name.replace('___', '.')
    return model_name
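
# Examples (pure string manipulation, no downloads):
#
#     get_model_name('Qwen/Qwen2-7B-Instruct')                                # 'Qwen2-7B-Instruct'
#     get_model_name('/cache/models--Qwen--Qwen2-7B-Instruct/snapshots/abc')  # 'Qwen2-7B-Instruct'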


def get_all_models() -> List[str]:
    use_hf = strtobool(os.environ.get('USE_HF', 'False'))
    models = []
    for model_type in ModelType.get_model_name_list():
        model_meta = MODEL_MAPPING.get(model_type)
        if model_meta is None:
            continue
        for group in model_meta.model_groups:
            for model in group.models:
                model_id = model.hf_model_id if use_hf else model.ms_model_id
                if model_id:
                    models.append(model_id)
    return models


def get_matched_model_meta(model_id_or_path: str) -> Optional[ModelMeta]:
    model_name = get_model_name(model_id_or_path).lower()
    for model_meta in MODEL_MAPPING.values():
        model_group = model_meta.get_matched_model_group(model_name)
        if model_group is not None:
            model_meta = deepcopy(model_meta)
            # Group-level fields (requires, tags, ignore_patterns, ...) override the meta's.
            for k, v in asdict(model_group).items():
                if v is not None and k in model_meta.__dict__:
                    setattr(model_meta, k, v)
            return model_meta
    return None


def _get_arch_mapping():
    res = {}
    for model_type, model_meta in MODEL_MAPPING.items():
        # Don't mutate model_meta.architectures; fall back to 'null' for metas without one.
        architectures = model_meta.architectures or ['null']
        for arch in architectures:
            res.setdefault(arch, []).append(model_type)
    return res


def get_matched_model_types(architectures: Optional[List[str]]) -> List[str]:
    """Get the possible model_types; only the first architecture is considered."""
    arch = (architectures or ['null'])[0]
    arch_mapping = _get_arch_mapping()
    return arch_mapping.get(arch) or []
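
# Usage sketch: architectures usually come straight from config.json, e.g.
#
#     model_types = get_matched_model_types(['Qwen2ForCausalLM'])
#     # -> all registered model_types whose ModelMeta lists the 'Qwen2ForCausalLM' architecture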


def _read_args_json_model_type(model_dir):
    # Checkpoints trained by swift store their arguments in `args.json`.
    if not os.path.exists(os.path.join(model_dir, 'args.json')):
        return None
    from swift.llm import BaseArguments
    args = BaseArguments.from_pretrained(model_dir)
    return args.model_type


def _get_model_info(model_dir: str, model_type: Optional[str], quantization_config) -> ModelInfo:
    try:
        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    except Exception:
        config = PretrainedConfig.get_config_dict(model_dir)[0]
    if quantization_config is not None:
        HfConfigFactory.set_config_attr(config, 'quantization_config', quantization_config)
    quant_info = HfConfigFactory.get_quant_info(config) or {}
    torch_dtype = HfConfigFactory.get_torch_dtype(config, quant_info)
    max_model_len = HfConfigFactory.get_max_model_len(config)
    rope_scaling = HfConfigFactory.get_config_attr(config, 'rope_scaling')

    if model_type is None:
        model_type = _read_args_json_model_type(model_dir)
    if model_type is None:
        architectures = HfConfigFactory.get_config_attr(config, 'architectures')
        model_types = get_matched_model_types(architectures)
        if len(model_types) > 1:
            raise ValueError('Please explicitly pass the model_type. For reference, '
                             f'the available model_types: {model_types}.')
        elif len(model_types) == 1:
            model_type = model_types[0]
    elif model_type not in MODEL_MAPPING:
        raise ValueError(f"model_type: '{model_type}' not in {list(MODEL_MAPPING.keys())}")

    res = ModelInfo(
        model_type,
        model_dir,
        torch_dtype,
        max_model_len,
        quant_info.get('quant_method'),
        quant_info.get('quant_bits'),
        rope_scaling=rope_scaling)
    return res


def get_model_info_meta(
        model_id_or_path: str,
        torch_dtype: Optional[torch.dtype] = None,
        *,
        # hub
        use_hf: Optional[bool] = None,
        hub_token: Optional[str] = None,
        revision: Optional[str] = None,
        download_model: bool = False,
        # model
        model_type: Optional[str] = None,
        quantization_config=None,
        task_type=None,
        num_labels=None,
        **kwargs) -> Tuple[ModelInfo, ModelMeta]:
    model_meta = get_matched_model_meta(model_id_or_path)
    model_dir = safe_snapshot_download(
        model_id_or_path,
        revision=revision,
        download_model=download_model,
        use_hf=use_hf,
        ignore_patterns=getattr(model_meta, 'ignore_patterns', None),
        hub_token=hub_token)

    model_type = model_type or getattr(model_meta, 'model_type', None)
    model_info = _get_model_info(model_dir, model_type, quantization_config=quantization_config)
    if model_type is None and model_info.model_type is not None:
        model_type = model_info.model_type
        logger.info(f'Setting model_type: {model_type}')
    if model_meta is None and model_type is not None:
        model_meta = MODEL_MAPPING[model_type]
    if model_meta is None:
        model_meta = ModelMeta(None, [], 'dummy', get_model_tokenizer_from_local, model_arch=None)
        logger.info(f'Temporarily created model_meta: {model_meta}')

    if torch_dtype is None:
        torch_dtype = model_meta.torch_dtype or get_default_torch_dtype(model_info.torch_dtype)
        logger.info(f'Setting torch_dtype: {torch_dtype}')
    model_info.torch_dtype = torch_dtype
    if task_type is None:
        if model_meta.is_reward:
            num_labels = 1
        if num_labels is None:
            task_type = 'causal_lm'
        else:
            task_type = 'seq_cls'
    if task_type == 'seq_cls':
        assert num_labels is not None, 'Please pass the parameter `num_labels`.'
    if model_meta.task_type is not None:
        task_type = model_meta.task_type
    model_info.task_type = task_type
    model_info.num_labels = num_labels

    model_meta.check_requires(model_info)
    return model_info, model_meta
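
# Usage sketch (resolves model_type/torch_dtype without loading weights; substitute your
# own model id):
#
#     model_info, model_meta = get_model_info_meta('Qwen/Qwen2-7B-Instruct')
#     print(model_info.model_type, model_info.torch_dtype)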


def get_model_tokenizer(
        model_id_or_path: str,
        torch_dtype: Optional[torch.dtype] = None,
        device_map: Union[str, Dict[str, Any], None] = None,
        *,
        load_model: bool = True,
        # hub
        use_hf: Optional[bool] = None,
        hub_token: Optional[str] = None,
        revision: Optional[str] = None,
        download_model: Optional[bool] = None,
        # model
        model_type: Optional[str] = None,
        quantization_config=None,
        max_memory: Optional[Union[str, Dict[str, Any]]] = None,
        attn_impl: Literal['flash_attn', 'sdpa', 'eager', None] = None,
        rope_scaling: Optional[Dict[str, Any]] = None,
        automodel_class=None,
        task_type: Optional[Literal['causal_lm', 'seq_cls']] = None,
        num_labels: Optional[int] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs) -> Tuple[Optional[PreTrainedModel], PreTrainedTokenizerBase]:
    """
    model_id_or_path: The path to the model or the model_id from ModelScope/Hugging Face (controlled by `use_hf`).
    torch_dtype: If `None` is passed, the torch_dtype is read from the config.json file.
    model_kwargs: Passed to `automodel_class.from_pretrained`.
    load_model: Whether to load the model. If set to False, the returned model is `None`.
    use_hf: Whether the model is downloaded from ModelScope or from Hugging Face.
    model_type: Required if the model_type cannot be uniquely determined from the architectures in config.json.
    attn_impl: If set to 'flash_attn', the implementation name is converted automatically based on the model.
        If set to None, sdpa or eager is selected automatically.
    download_model: Whether to download the model weights. If `None`, it is derived from load_model.
    """
    patch_mp_ddp()
    if model_kwargs is None:
        model_kwargs = {}
    if download_model is None:
        download_model = load_model

    model_info, model_meta = get_model_info_meta(
        model_id_or_path,
        torch_dtype,
        use_hf=use_hf,
        hub_token=hub_token,
        revision=revision,
        download_model=download_model,
        model_type=model_type,
        quantization_config=quantization_config,
        task_type=task_type,
        num_labels=num_labels)

    if not use_torchacc() and device_map is None:
        device_map = get_default_device_map()
    model_kwargs['device_map'] = device_map
    if quantization_config:
        model_kwargs['quantization_config'] = quantization_config
    if max_memory:
        model_kwargs['max_memory'] = max_memory
    model_dir = model_info.model_dir
    get_function = model_meta.get_function
    kwargs['automodel_class'] = automodel_class
    kwargs['attn_impl'] = attn_impl
    kwargs['rope_scaling'] = rope_scaling
    kwargs['model_meta'] = model_meta
    with patch_get_dynamic_module(), patch_tp_plan():
        model, processor = get_function(model_dir, model_info, model_kwargs, load_model, **kwargs)

    if not isinstance(processor, PreTrainedTokenizerBase) and hasattr(processor, 'tokenizer'):
        tokenizer = processor.tokenizer
        patch_getattr(processor.__class__, 'tokenizer')
    else:
        tokenizer = processor
    problem_type = kwargs.get('problem_type')
    if problem_type is None and model_info.num_labels == 1:
        problem_type = 'regression'
    if problem_type is not None:
        model_info.config.problem_type = problem_type
    tokenizer.model_info = model_info
    tokenizer.model_meta = model_meta

    pad_token = tokenizer.pad_token_id
    if pad_token is None:
        pad_token = tokenizer.eos_token_id
    if tokenizer.eos_token_id is None:
        tokenizer.eos_token_id = pad_token
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = pad_token
    assert tokenizer.eos_token_id is not None
    assert tokenizer.pad_token_id is not None

    if model is not None:
        model.model_info = model_info
        model.model_meta = model_meta
        model.model_dir = model_dir

        # Restore the generation_config when the attribute is missing
        # but generation_config.json exists in the model directory.
        generation_config_path = os.path.join(model_dir, 'generation_config.json')
        if not hasattr(model, 'generation_config') and os.path.isfile(generation_config_path):
            model.generation_config = GenerationConfig.from_pretrained(model_dir)
        # Silence the transformers do_sample warning.
        if getattr(model, 'generation_config', None):
            fix_do_sample_warning(model.generation_config)
    return model, processor
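
# End-to-end usage sketch (downloads weights; the model id is illustrative, pick any
# registered model):
#
#     model, tokenizer = get_model_tokenizer(
#         'Qwen/Qwen2-7B-Instruct', torch_dtype=torch.bfloat16, attn_impl='sdpa')
#     print(model.model_info, tokenizer.pad_token_id)
#
# With load_model=False only the tokenizer/processor is materialized:
#
#     _, tokenizer = get_model_tokenizer('Qwen/Qwen2-7B-Instruct', load_model=False)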