# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from contextlib import contextmanager
from functools import wraps
from types import MethodType
from typing import Dict, List, Optional, Union

import accelerate
import torch
import torch.nn as nn
import transformers
from accelerate.utils import find_device
from packaging import version
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import PreTrainedModel, dynamic_module_utils, trainer
from transformers.modeling_outputs import SequenceClassifierOutputWithPast

from swift.llm import to_device, to_float_dtype
from swift.utils import get_dist_setting, get_logger, is_mp_ddp, safe_ddp_context, use_torchacc
from swift.utils.torch_utils import _get_max_memory, _sync_max_memory, get_device_count
from .model_arch import get_model_arch
from .utils import HfConfigFactory

logger = get_logger()


def patch_fixed_float_dtype(module: torch.nn.Module, dtype):
    """Patch the module to ensure a consistent output dtype."""

    def get_float_dtype_hook(dtype):

        def _float_dtype_hook(module, input, output):
            return to_float_dtype(output, dtype)

        return _float_dtype_hook

    module.register_forward_hook(get_float_dtype_hook(dtype))


def patch_fixed_device(module: torch.nn.Module, device):
    """Move the output to the specified device."""

    def get_device_hook(device):

        def _device_hook(module, input, output):
            return to_device(output, device)

        return _device_hook

    module.register_forward_hook(get_device_hook(device))


def patch_output_clone(module: torch.nn.Module):
    """Clone the output to avoid in-place modification problems."""

    def _clone_hook(module, input, output):
        return output.requires_grad_(True).clone()

    module.register_forward_hook(_clone_hook)


def patch_output_normalizer(module: torch.nn.Module, model_meta):

    def lm_head_forward(self, hidden_states):
        return hidden_states

    lm_heads = ['lm_head', 'output', 'embed_out', 'output_layer']
    llm_prefix = getattr(get_model_arch(model_meta.model_arch), 'language_model', None)
    if llm_prefix:
        llm_model = getattr(module, llm_prefix[0])
    else:
        llm_model = module

    if 'CausalLM' not in llm_model.__class__.__name__:
        llm_model = module

    # Replace the lm_head forward with an identity mapping so `logits` carries the hidden states.
    found = False
    for lm_head in lm_heads:
        if hasattr(llm_model, lm_head):
            getattr(llm_model, lm_head).forward = MethodType(lm_head_forward, getattr(llm_model, lm_head))
            found = True
            break

    assert found, 'Cannot find the proper lm_head name'

    def forward(self, input_ids: torch.LongTensor = None, attention_mask=None, *args, **kwargs):
        outputs = self.forward_origin(input_ids=input_ids, attention_mask=attention_mask, *args, **kwargs)
        hidden_states = outputs.logits
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            embeddings = hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = hidden_states.shape[0]
            embeddings = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths]
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return {
            'last_hidden_state': embeddings.contiguous(),
        }

    llm_model.forward_origin = llm_model.forward
    llm_model.forward = MethodType(forward, llm_model)


def patch_output_to_input_device(module: torch.nn.Module):
    """Patch the module to ensure the output is on the same device as the input.

    Args:
        module: The module to be patched
    """

    def _output_to_input_device_hook(module, args, kwargs, output):
        device = find_device(args) or find_device(kwargs)
        return to_device(output, device)

    module.register_forward_hook(_output_to_input_device_hook, with_kwargs=True)
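# Illustrative usage of the hook helpers above (a sketch; the module, dtype and
# device here are arbitrary examples, not values this file prescribes):
#
#     linear = nn.Linear(8, 8)
#     patch_fixed_float_dtype(linear, torch.float32)  # cast float outputs to fp32
#     patch_output_to_input_device(linear)            # keep output on the input's device
#     out = linear(torch.randn(2, 8))                 # hooks fire after each forward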
@contextmanager
def patch_device_map():
    _get_no_split_modules = PreTrainedModel._get_no_split_modules

    def _new_get_no_split_modules(self, device_map: str):
        for module in self.modules():
            if isinstance(module, PreTrainedModel) and module._no_split_modules is None:
                module.__class__._no_split_modules = []
        return _get_no_split_modules(self, device_map)

    PreTrainedModel._get_no_split_modules = _new_get_no_split_modules
    try:
        yield
    finally:
        PreTrainedModel._get_no_split_modules = _get_no_split_modules


@contextmanager
def patch_ignore_check_imports():
    import transformers.dynamic_module_utils as td

    def _check_imports(filename) -> List[str]:
        return td.get_relative_imports(filename)

    _old_check_imports = td.check_imports
    td.check_imports = _check_imports
    try:
        yield
    finally:
        td.check_imports = _old_check_imports


def _patch_sequence_classification(model, model_meta):
    hidden_size = HfConfigFactory.get_config_attr(model.config, 'hidden_size')
    initializer_range = HfConfigFactory.get_config_attr(model.config, 'initializer_range')

    lm_heads = ['lm_head', 'output', 'embed_out', 'output_layer']
    llm_prefix = getattr(get_model_arch(model_meta.model_arch), 'language_model', None)
    if llm_prefix:
        llm_model = getattr(model, llm_prefix[0])
    else:
        llm_model = model

    if 'CausalLM' not in llm_model.__class__.__name__:  # fix qwen2_vl
        llm_model = model
    llm_model.num_labels = model.config.num_labels
    llm_model.score = nn.Linear(hidden_size, llm_model.num_labels, bias=False, dtype=llm_model.dtype)
    if llm_model.score.weight.device == torch.device('meta'):
        llm_model.score.to_empty(device='cpu')
    llm_model.score.weight.data.normal_(mean=0.0, std=initializer_range)
    for lm_head in lm_heads:
        if hasattr(llm_model, lm_head):
            setattr(llm_model, lm_head, nn.Identity())
            break

    origin_forward = llm_model.forward.__func__

    @wraps(origin_forward)
    def new_forward(self, *args, **kwargs):
        labels = kwargs.pop('labels', None)
        return_dict = kwargs.pop('return_dict', None)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        input_ids = kwargs.get('input_ids')
        inputs_embeds = kwargs.get('inputs_embeds')
        output = origin_forward(self, *args, **kwargs)
        output.logits = output.logits.to(self.score.weight.dtype)
        logits = self.score(output.logits)
        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.')
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = 'regression'
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = 'single_label_classification'
                else:
                    self.config.problem_type = 'multi_label_classification'

            if self.config.problem_type == 'regression':
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == 'single_label_classification':
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == 'multi_label_classification':
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits, ) + output[1:]
            return ((loss, ) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=output.past_key_values,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )

    llm_model.forward = MethodType(new_forward, llm_model)
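# A minimal sketch of the last-token pooling in `new_forward` above, assuming
# right padding and a configured pad_token_id (toy tensors for clarity):
#
#     input_ids = torch.tensor([[5, 6, 0, 0]])               # pad_token_id = 0
#     seq_len = torch.eq(input_ids, 0).int().argmax(-1) - 1  # -> tensor([1])
#     seq_len = seq_len % input_ids.shape[-1]                 # ONNX-friendly wrap-around
#     pooled = logits[torch.arange(1), seq_len]               # logits of the last real token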
@contextmanager
def patch_automodel_for_sequence_classification(model_meta):
    from_pretrained = PreTrainedModel.from_pretrained.__func__

    @classmethod
    def _new_from_pretrained(cls, *args, **kwargs):
        __init__ = cls.__init__

        def __new_init__(self, *args, **kwargs):
            __init__(self, *args, **kwargs)
            _patch_sequence_classification(self, model_meta)

        cls.__init__ = __new_init__
        if hasattr(cls, '_tp_plan'):  # fix tp_plan
            cls._tp_plan = cls._tp_plan or {}
        res = from_pretrained(cls, *args, **kwargs)
        cls.__init__ = __init__
        return res

    PreTrainedModel.from_pretrained = _new_from_pretrained
    try:
        yield
    finally:
        PreTrainedModel.from_pretrained = classmethod(from_pretrained)


@contextmanager
def patch_automodel(automodel_class, model_info):
    from_pretrained = PreTrainedModel.from_pretrained.__func__

    @classmethod
    def _new_from_pretrained(cls, *args, **kwargs):
        if 'AutoAWQFor' in automodel_class.__name__:
            kwargs.pop('use_cache', None)
        if model_info.quant_method == 'gptq':
            cls.main_input_name = 'input_ids'
        if hasattr(cls, '_tp_plan'):  # fix tp_plan
            cls._tp_plan = cls._tp_plan or {}
        model = from_pretrained(cls, *args, **kwargs)
        return model

    PreTrainedModel.from_pretrained = _new_from_pretrained
    try:
        yield
    finally:
        PreTrainedModel.from_pretrained = classmethod(from_pretrained)
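# Illustrative usage of the two context managers above (the model path and the
# `model_info`/`model_meta` objects are assumptions for the sketch):
#
#     with patch_automodel_for_sequence_classification(model_meta):
#         model = AutoModelForCausalLM.from_pretrained('some/model')  # gains a `score` head
#
#     with patch_automodel(AutoModelForCausalLM, model_info):
#         model = AutoModelForCausalLM.from_pretrained('some/model')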
_mp_ddp_patched = False


def patch_mp_ddp():
    """Patch DDP so it can run together with a device_map (MP + DDP).

    This should be called before any training starts.
    """
    global _mp_ddp_patched
    if is_mp_ddp() and not _mp_ddp_patched:
        _mp_ddp_patched = True
        from accelerate.utils.modeling import get_balanced_memory, infer_auto_device_map

        @wraps(infer_auto_device_map)
        def _infer_auto_device_map_patch(model: nn.Module,
                                         max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
                                         **kwargs) -> Dict[str, Union[int, str, torch.device]]:
            """Auxiliary monkey patch of accelerate's `infer_auto_device_map` to support MP + DDP."""
            verbose = kwargs.pop('verbose', False)
            n_gpu = get_device_count()
            _, local_rank, _, local_world_size = get_dist_setting()
            # Restrict each rank to its own slice of the visible devices.
            device_ids = list(range(local_rank, n_gpu, local_world_size))
            max_memory = _get_max_memory(device_ids)
            max_memory = _sync_max_memory(max_memory)
            max_memory = get_balanced_memory(model, max_memory, low_zero=False, **kwargs)
            max_memory = {k: v for k, v in max_memory.items() if v > 0}
            return infer_auto_device_map(model, max_memory, verbose=verbose, **kwargs)

        _old_ddp_init = DDP.__init__
        # Drop device_ids/output_device so DDP does not force the whole model onto one device.
        accelerate.accelerator.torch.nn.parallel.DistributedDataParallel.__init__ = (
            lambda self, model, device_ids, output_device, *args, **kwargs: _old_ddp_init(self, model, *args, **kwargs))
        transformers.modeling_utils.get_balanced_memory = lambda *args, **kwargs: None
        transformers.modeling_utils.infer_auto_device_map = _infer_auto_device_map_patch

    if is_mp_ddp() or use_torchacc():
        _old_accelerator_init = trainer.Accelerator.__init__
        trainer.Accelerator.__init__ = (lambda self, device_placement=False, *args, **kwargs: _old_accelerator_init(
            self, device_placement=device_placement, *args, **kwargs))
        trainer.Accelerator.verify_device_map = lambda *args, **kwargs: False


@contextmanager
def patch_get_dynamic_module():
    origin_get_cached_module_file = dynamic_module_utils.get_cached_module_file

    def new_get_cached_module_file(pretrained_model_name_or_path, *args, **kwargs):
        with safe_ddp_context(hash_id=str(pretrained_model_name_or_path)):
            return origin_get_cached_module_file(pretrained_model_name_or_path, *args, **kwargs)

    dynamic_module_utils.get_cached_module_file = new_get_cached_module_file
    try:
        yield
    finally:
        dynamic_module_utils.get_cached_module_file = origin_get_cached_module_file


@contextmanager
def patch_tp_plan():
    if not is_mp_ddp() or version.parse(transformers.__version__) < version.parse('4.50'):
        yield
        return
    # Temporarily hide WORLD_SIZE so transformers >= 4.50 does not build a tensor-parallel plan.
    WORLD_SIZE = os.environ.get('WORLD_SIZE')
    os.environ['_PATCH_WORLD_SIZE'] = WORLD_SIZE
    os.environ.pop('WORLD_SIZE')
    yield
    os.environ['WORLD_SIZE'] = WORLD_SIZE
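# Illustrative: under MP + DDP with transformers >= 4.50, `patch_tp_plan` hides
# WORLD_SIZE for the duration of the load (the model name is an arbitrary example):
#
#     with patch_tp_plan():
#         model = AutoModelForCausalLM.from_pretrained('some/model')
#     # WORLD_SIZE is restored once the context exits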