| |
| import os |
| from contextlib import contextmanager |
| from functools import wraps |
| from types import MethodType |
| from typing import Dict, List, Optional, Union |
|
|
| import accelerate |
| import torch |
| import torch.nn as nn |
| import transformers |
| from accelerate.utils import find_device |
| from packaging import version |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from transformers import PreTrainedModel, dynamic_module_utils, trainer |
| from transformers.modeling_outputs import SequenceClassifierOutputWithPast |
|
|
| from swift.llm import to_device, to_float_dtype |
| from swift.utils import get_dist_setting, get_logger, is_mp_ddp, safe_ddp_context, use_torchacc |
| from swift.utils.torch_utils import _get_max_memory, _sync_max_memory, get_device_count |
| from .model_arch import get_model_arch |
| from .utils import HfConfigFactory |
|
|
# Shared module-level logger for all patch helpers in this file.
logger = get_logger()
|
|
|
|
def patch_fixed_float_dtype(module: torch.nn.Module, dtype):
    """Register a forward hook that converts ``module``'s output to a fixed float dtype.

    Args:
        module: The module to patch in place.
        dtype: Target floating dtype passed to ``to_float_dtype``.
    """

    def _cast_output_hook(hook_module, hook_input, hook_output):
        # Convert the (possibly nested) output structure via the project helper.
        return to_float_dtype(hook_output, dtype)

    module.register_forward_hook(_cast_output_hook)
|
|
|
|
def patch_fixed_device(module: torch.nn.Module, device):
    """Register a forward hook that moves ``module``'s output onto ``device``.

    Args:
        module: The module to patch in place.
        device: Target device passed to ``to_device``.
    """

    def _move_output_hook(hook_module, hook_input, hook_output):
        # Relocate the (possibly nested) output structure via the project helper.
        return to_device(hook_output, device)

    module.register_forward_hook(_move_output_hook)
|
|
|
|
def patch_output_clone(module: torch.nn.Module):
    """Register a forward hook that returns a gradient-enabled clone of the output.

    Cloning decouples the returned tensor from the module's own output storage,
    avoiding in-place modification problems downstream.
    """

    def _clone_output_hook(hook_module, hook_input, hook_output):
        hook_output.requires_grad_(True)
        return hook_output.clone()

    module.register_forward_hook(_clone_output_hook)
|
|
|
|
def patch_output_normalizer(module: torch.nn.Module, model_meta):
    """Patch a model in place so ``forward`` returns L2-normalized last-token embeddings.

    The model's lm_head (or equivalent) is replaced by an identity forward so that
    ``outputs.logits`` actually carries the final hidden states, which are then
    pooled at each sequence's last real token and normalized.

    Args:
        module: The model to patch in place.
        model_meta: Model metadata; ``model_meta.model_arch`` is used to locate the
            inner language model of a multimodal wrapper.
    """

    def lm_head_forward(self, hidden_states):
        # Identity: pass hidden states through so `logits` are the hidden states.
        return hidden_states

    # Candidate attribute names used by different architectures for the LM head.
    lm_heads = ['lm_head', 'output', 'embed_out', 'output_layer']
    llm_prefix = getattr(get_model_arch(model_meta.model_arch), 'language_model', None)
    if llm_prefix:
        # Multimodal wrapper: descend into the inner language model.
        llm_model = getattr(module, llm_prefix[0])
    else:
        llm_model = module

    if 'CausalLM' not in llm_model.__class__.__name__:
        # Located submodule is not a causal LM wrapper; fall back to the top level.
        llm_model = module

    found = False
    for lm_head in lm_heads:
        if hasattr(llm_model, lm_head):
            # Replace the head's forward with the identity defined above.
            getattr(llm_model, lm_head).forward = MethodType(lm_head_forward, getattr(llm_model, lm_head))
            found = True
            break

    assert found, 'Cannot find the proper lm_head name'

    def forward(self, input_ids: torch.LongTensor = None, attention_mask=None, *args, **kwargs):
        outputs = self.forward_origin(input_ids=input_ids, attention_mask=attention_mask, *args, **kwargs)
        # With the identity lm_head above, `logits` hold the final hidden states.
        hidden_states = outputs.logits
        # Every row attends at the last position -> batch is left-padded, so the
        # last position is the last real token for all sequences.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            embeddings = hidden_states[:, -1]
        else:
            # Right padding: pick each sequence's last attended position.
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = hidden_states.shape[0]
            embeddings = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths]
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return {
            'last_hidden_state': embeddings.contiguous(),
        }

    # Keep the original forward reachable, then install the pooling forward.
    # NOTE(review): patching the same model twice would make forward_origin point
    # at the patched forward and recurse — assumed to be applied only once.
    llm_model.forward_origin = llm_model.forward
    llm_model.forward = MethodType(forward, llm_model)
|
|
|
|
def patch_output_to_input_device(module: torch.nn.Module):
    """Patch the module, to make sure the output is in the same device with the input.

    Args:
        module: The module to be patched
    """

    def _relocate_output_hook(hook_module, hook_args, hook_kwargs, result):
        # Locate any device among positional/keyword inputs and move the output there.
        target = find_device(hook_args) or find_device(hook_kwargs)
        return to_device(result, target)

    module.register_forward_hook(_relocate_output_hook, with_kwargs=True)
|
|
|
|
@contextmanager
def patch_device_map():
    """Temporarily make device_map dispatch tolerate models whose
    ``_no_split_modules`` is undefined, by defaulting it to an empty list."""
    origin_fn = PreTrainedModel._get_no_split_modules

    def patched_get_no_split_modules(self, device_map: str):
        # Give every nested PreTrainedModel an explicit (empty) no-split list
        # before delegating to the original implementation.
        for submodule in self.modules():
            if isinstance(submodule, PreTrainedModel) and submodule._no_split_modules is None:
                submodule.__class__._no_split_modules = []
        return origin_fn(self, device_map)

    PreTrainedModel._get_no_split_modules = patched_get_no_split_modules
    try:
        yield
    finally:
        # Always restore the original classmethod, even on exceptions.
        PreTrainedModel._get_no_split_modules = origin_fn
|
|
|
|
@contextmanager
def patch_ignore_check_imports():
    """Temporarily replace transformers' dynamic-module import check so only
    relative imports are collected (no installed-package verification)."""
    import transformers.dynamic_module_utils as td

    origin_check_imports = td.check_imports

    def _relative_imports_only(filename) -> List[str]:
        return td.get_relative_imports(filename)

    td.check_imports = _relative_imports_only
    try:
        yield
    finally:
        # Always restore the original checker, even on exceptions.
        td.check_imports = origin_check_imports
|
|
|
|
def _patch_sequence_classification(model, model_meta):
    """Convert a loaded CausalLM into a sequence-classification model in place.

    Replaces the LM head with ``nn.Identity``, attaches a fresh ``score`` linear
    head, and wraps ``forward`` to pool the logits at each sequence's last real
    token and compute the appropriate loss, mirroring transformers'
    ``*ForSequenceClassification`` implementations.

    Args:
        model: The loaded model to patch in place.
        model_meta: Model metadata used to locate the inner language model.
    """
    hidden_size = HfConfigFactory.get_config_attr(model.config, 'hidden_size')
    initializer_range = HfConfigFactory.get_config_attr(model.config, 'initializer_range')

    # Candidate attribute names used by different architectures for the LM head.
    lm_heads = ['lm_head', 'output', 'embed_out', 'output_layer']
    llm_prefix = getattr(get_model_arch(model_meta.model_arch), 'language_model', None)
    if llm_prefix:
        # Multimodal wrapper: descend into the inner language model.
        llm_model = getattr(model, llm_prefix[0])
    else:
        llm_model = model
    if 'CausalLM' not in llm_model.__class__.__name__:
        llm_model = model
    llm_model.num_labels = model.config.num_labels
    # New classification head, initialized from the config's initializer_range.
    llm_model.score = nn.Linear(hidden_size, llm_model.num_labels, bias=False, dtype=llm_model.dtype)
    if llm_model.score.weight.device == torch.device('meta'):
        # Materialize a meta-device weight before initializing it.
        llm_model.score.to_empty(device='cpu')
    llm_model.score.weight.data.normal_(mean=0.0, std=initializer_range)
    for lm_head in lm_heads:
        if hasattr(llm_model, lm_head):
            # Drop the LM head so `logits` become the final hidden states.
            setattr(llm_model, lm_head, nn.Identity())
            break

    origin_forward = llm_model.forward.__func__

    @wraps(origin_forward)
    def new_forward(self, *args, **kwargs):
        labels = kwargs.pop('labels', None)
        return_dict = kwargs.pop('return_dict', None)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        input_ids = kwargs.get('input_ids')
        inputs_embeds = kwargs.get('inputs_embeds')

        output = origin_forward(self, *args, **kwargs)
        # lm_head is Identity, so `output.logits` are the final hidden states.
        output.logits = output.logits.to(self.score.weight.dtype)
        logits = self.score(output.logits)
        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.')
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # First pad-token index minus one = last real token. If no pad
                # token occurs, argmax returns 0 and the -1 wraps (via the
                # modulo below) to the last position.
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        # Pool the per-token logits at each sequence's last real token.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                # Infer the problem type from num_labels / label dtype, exactly
                # as transformers' sequence-classification heads do.
                if self.num_labels == 1:
                    self.config.problem_type = 'regression'
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = 'single_label_classification'
                else:
                    self.config.problem_type = 'multi_label_classification'

            if self.config.problem_type == 'regression':
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == 'single_label_classification':
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == 'multi_label_classification':
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            # Tuple output path (return_dict=False), matching HF conventions.
            output = (pooled_logits, ) + output[1:]
            return ((loss, ) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=output.past_key_values,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )

    llm_model.forward = MethodType(new_forward, llm_model)
|
|
|
|
@contextmanager
def patch_automodel_for_sequence_classification(model_meta):
    """Context manager: make ``PreTrainedModel.from_pretrained`` turn each
    constructed model into a sequence-classification model.

    While active, the wrapped ``from_pretrained`` temporarily swaps the target
    class's ``__init__`` for one that calls ``_patch_sequence_classification``
    right after construction.

    Args:
        model_meta: Metadata forwarded to ``_patch_sequence_classification``.
    """
    from_pretrained = PreTrainedModel.from_pretrained.__func__

    @classmethod
    def _new_from_pretrained(cls, *args, **kwargs):
        __init__ = cls.__init__

        def __new_init__(self, *args, **kwargs):
            __init__(self, *args, **kwargs)
            _patch_sequence_classification(self, model_meta)

        cls.__init__ = __new_init__
        if hasattr(cls, '_tp_plan'):
            # Ensure _tp_plan is a dict rather than None — presumably required
            # by newer transformers tensor-parallel code; TODO confirm.
            cls._tp_plan = cls._tp_plan or {}
        try:
            res = from_pretrained(cls, *args, **kwargs)
        finally:
            # Fix: restore __init__ even when loading raises; previously an
            # exception left the class permanently patched.
            cls.__init__ = __init__
        return res

    PreTrainedModel.from_pretrained = _new_from_pretrained

    try:
        yield
    finally:
        PreTrainedModel.from_pretrained = classmethod(from_pretrained)
|
|
|
|
@contextmanager
def patch_automodel(automodel_class, model_info):
    """Context manager: wrap ``PreTrainedModel.from_pretrained`` with small
    per-quantization loading workarounds."""
    origin_from_pretrained = PreTrainedModel.from_pretrained.__func__

    @classmethod
    def _patched_from_pretrained(cls, *args, **kwargs):
        if 'AutoAWQFor' in automodel_class.__name__:
            # AWQ loader path: strip `use_cache` before delegating.
            kwargs.pop('use_cache', None)
        if model_info.quant_method == 'gptq':
            cls.main_input_name = 'input_ids'
        if hasattr(cls, '_tp_plan'):
            # Keep _tp_plan a dict rather than None.
            cls._tp_plan = cls._tp_plan or {}
        return origin_from_pretrained(cls, *args, **kwargs)

    PreTrainedModel.from_pretrained = _patched_from_pretrained

    try:
        yield
    finally:
        # Always restore the original classmethod, even on exceptions.
        PreTrainedModel.from_pretrained = classmethod(origin_from_pretrained)
|
|
|
|
# One-shot guard: ensures patch_mp_ddp() installs its monkey patches only once
# per process.
_mp_ddp_patched: bool = False
|
|
|
|
def patch_mp_ddp():
    """Patch ddp with device_map.
    After patching, the ddp can run with the device_map.
    This should be called before any training starts.
    """
    global _mp_ddp_patched
    if is_mp_ddp() and not _mp_ddp_patched:
        # Apply the patches only once per process.
        _mp_ddp_patched = True
        from accelerate.utils.modeling import get_balanced_memory, infer_auto_device_map

        @wraps(infer_auto_device_map)
        def _infer_auto_device_map_patch(model: nn.Module,
                                         max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
                                         **kwargs) -> Dict[str, Union[int, str, torch.device]]:
            """The auxiliary function for supports MP + DDP. Monkey Patching.
            add feat in accelerate to support MP + DDP"""
            verbose = kwargs.pop('verbose', False)
            n_gpu = get_device_count()
            _, local_rank, _, local_world_size = get_dist_setting()
            # Each rank owns a strided subset of the visible devices.
            device_ids = list(range(local_rank, n_gpu, local_world_size))
            # NOTE: the incoming `max_memory` argument is intentionally ignored;
            # it is recomputed from this rank's devices and synced across ranks.
            max_memory = _get_max_memory(device_ids)
            max_memory = _sync_max_memory(max_memory)
            max_memory = get_balanced_memory(model, max_memory, low_zero=False, **kwargs)
            # Drop devices that got no memory budget.
            max_memory = {k: v for k, v in max_memory.items() if v > 0}
            return infer_auto_device_map(model, max_memory, verbose=verbose, **kwargs)

        _old_ddp_init = DDP.__init__
        # Swallow the positional `device_ids`/`output_device` that the caller
        # passes, so a device-mapped (multi-device) model is not pinned to one
        # device by DDP.
        accelerate.accelerator.torch.nn.parallel.DistributedDataParallel.__init__ = (
            lambda self, model, device_ids, output_device, *args, **kwargs: _old_ddp_init(self, model, *args, **kwargs))
        transformers.modeling_utils.get_balanced_memory = lambda *args, **kwargs: None
        transformers.modeling_utils.infer_auto_device_map = _infer_auto_device_map_patch

    if is_mp_ddp() or use_torchacc():
        _old_accelerator_init = trainer.Accelerator.__init__
        # Default device_placement to False so the Accelerator does not move an
        # already device-mapped model.
        trainer.Accelerator.__init__ = (lambda self, device_placement=False, *args, **kwargs: _old_accelerator_init(
            self, device_placement=device_placement, *args, **kwargs))
        # Skip Accelerate's device-map verification entirely.
        trainer.Accelerator.verify_device_map = lambda *args, **kwargs: False
|
|
|
|
@contextmanager
def patch_get_dynamic_module():
    """Serialize remote-code module fetching across DDP ranks by wrapping
    transformers' ``get_cached_module_file`` in a per-model lock."""
    origin_fn = dynamic_module_utils.get_cached_module_file

    def _locked_get_cached_module_file(pretrained_model_name_or_path, *args, **kwargs):
        # One lock per model id, so different models don't serialize each other.
        with safe_ddp_context(hash_id=str(pretrained_model_name_or_path)):
            return origin_fn(pretrained_model_name_or_path, *args, **kwargs)

    dynamic_module_utils.get_cached_module_file = _locked_get_cached_module_file
    try:
        yield
    finally:
        # Always restore the original function, even on exceptions.
        dynamic_module_utils.get_cached_module_file = origin_fn
|
|
|
|
@contextmanager
def patch_tp_plan():
    """Context manager: temporarily hide ``WORLD_SIZE`` from the environment.

    Active only under MP+DDP with transformers>=4.50 — presumably to keep
    transformers' tensor-parallel planning from engaging (TODO confirm). The
    value is stashed in ``_PATCH_WORLD_SIZE`` and restored on exit; otherwise
    this is a no-op.
    """
    if not is_mp_ddp() or version.parse(transformers.__version__) < version.parse('4.50'):
        yield
        return
    # NOTE(review): assumes WORLD_SIZE is set whenever is_mp_ddp() is true;
    # a missing variable would make the assignment below raise.
    WORLD_SIZE = os.environ.get('WORLD_SIZE')
    os.environ['_PATCH_WORLD_SIZE'] = WORLD_SIZE
    os.environ.pop('WORLD_SIZE')
    try:
        yield
    finally:
        # Fix: restore WORLD_SIZE even when the managed body raises; previously
        # an exception left the variable unset for the rest of the process.
        os.environ['WORLD_SIZE'] = WORLD_SIZE
|
|