# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union

import collections
import gc
import itertools
import os
import re
import shutil
import tempfile
from dataclasses import dataclass

import torch
import torch.nn as nn

from transformers import PreTrainedModel
from transformers.generation.utils import GenerateNonBeamOutput
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.streamers import BaseStreamer
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.pytorch_utils import id_tensor_storage
# NOTE: `_load_state_dict_into_model` is not imported from `transformers.modeling_utils`
# because this module defines its own version below.
from transformers.modeling_utils import (
    is_fsdp_enabled,
    is_local_dist_rank_0,
    load_state_dict,
    set_initialized_submodules,
    _load_state_dict_into_meta_model,
    expand_device_map,
    get_disk_only_shard_files,
)
from transformers.utils import ModelOutput, logging, is_accelerate_available

if is_accelerate_available():
    from accelerate.utils import (
        find_tied_parameters,
        load_offloaded_weights,
        save_offload_index,
        set_module_tensor_to_device,
    )

logger = logging.get_logger(__name__)

PARAM_RENAME_WARNING = "A parameter name that contains `{}` will be renamed internally to `{}`. Please use a different name to suppress this warning."

@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
    """
    Outputs of decoder-only generation models, when using non-beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or
            shorter if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before
            SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one
            element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before
            SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one
            element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
            Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
            tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`, and optionally, if
            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
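
# A minimal sketch (not part of the library code) of how this output container is typically
# consumed when `return_dict_in_generate=True`: `sequences` is always populated, while the other
# fields are per-step tuples that exist only when the matching `output_*` flag was set. The helper
# name below is hypothetical and exists purely for illustration.
def _inspect_generate_output(out: GenerateDecoderOnlyOutput) -> None:
    # `sequences` holds the prompt plus all generated token ids.
    print("sequences shape:", tuple(out.sequences.shape))
    if out.scores is not None:
        print("generation steps with scores:", len(out.scores), "per-step shape:", tuple(out.scores[0].shape))
    if out.hidden_states is not None:
        print("generation steps with hidden states:", len(out.hidden_states))
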

def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
    # Convert old format to new format if needed from a PyTorch state_dict
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if "gamma" in key and ("vision_tower.vision_tower" not in key and "dav2_model" not in key):
            logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
            new_key = key.replace("gamma", "weight")
        if "beta" in key and "vision_tower.vision_tower" not in key:
            logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
            new_key = key.replace("beta", "bias")
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants,
    # so we need to apply the function recursively.
    def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # Parameters of module and children will start with prefix. We can exit early if there are none in this
        # state_dict
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            if is_deepspeed_zero3_enabled():
                import deepspeed

                # In sharded models, each shard has only part of the full state_dict, so only gather
                # parameters that are in the current state_dict.
                named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
                params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
                if len(params_to_gather) > 0:
                    # because zero3 puts placeholders in model params, this context
                    # manager gathers (unpartitions) the params of the current layer, then loads from
                    # the state dict and then re-partitions them again
                    with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                        if torch.distributed.get_rank() == 0:
                            module._load_from_state_dict(*args)
            else:
                module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".", assign_to_params_buffers)

    load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
    # Delete `state_dict` so it can be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
    # it's safe to delete it.
    del state_dict

    return error_msgs
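
# A minimal, self-contained sketch of the `gamma`/`beta` key renaming performed above. The module
# and key names here are hypothetical; the point is that matching legacy keys are rewritten (with a
# warning) before loading, unless they live under the excluded `vision_tower.vision_tower` /
# `dav2_model` scopes.
def _demo_param_renaming() -> None:
    model = nn.Module()
    model.encoder = nn.Module()
    model.encoder.norm = nn.LayerNorm(4)
    # "gamma" is the legacy name for LayerNorm's `weight`; it gets renamed and then loaded.
    state_dict = {"encoder.norm.gamma": torch.full((4,), 2.0)}
    errors = _load_state_dict_into_model(model, state_dict, start_prefix="")
    assert not errors
    assert torch.equal(model.encoder.norm.weight, torch.full((4,), 2.0))
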

def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
    """
    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the
    model's parameters.

    Note: We fully disable this if we are using `deepspeed`
    """
    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
        return False

    if is_deepspeed_zero3_enabled():
        return False

    # Some models explicitly do not support param buffer assignment
    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
        logger.debug(
            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
        )
        return False

    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
    first_key = list(model_to_load.state_dict().keys())[0]
    if start_prefix + first_key in state_dict:
        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype

    # For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`)
    return False
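
# A small sketch of how the check above behaves on a throwaway module, assuming DeepSpeed ZeRO-3 is
# not active in the environment. The dtype comparison is what decides whether fast in-place
# assignment (instead of a copy) can be used when loading the state dict.
def _demo_buffer_assignment_check() -> None:
    model = nn.Linear(2, 2)  # fp32 parameters
    fp32_sd = {k: v.clone() for k, v in model.state_dict().items()}
    fp16_sd = {k: v.half() for k, v in model.state_dict().items()}
    assert check_support_param_buffer_assignment(model, fp32_sd)       # same dtype -> assignment allowed
    assert not check_support_param_buffer_assignment(model, fp16_sd)   # dtype mismatch -> copy instead
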

class BaseCausalLM(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        logits_warper: Optional[LogitsProcessorList] = None,
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            generation_config ([`~generation.GenerationConfig`]):
                The generation configuration to be used as parametrization of the decoding method.
            synced_gpus (`bool`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3).
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            logits_warper (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
                to warp the prediction score distribution of the language modeling head applied before multinomial
                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
                `generation_config`).
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.
        """
        # init values
        pad_token_id = generation_config.pad_token_id
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
        do_sample = generation_config.do_sample
        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
            raise ValueError(
                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
                f"{logits_warper})."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        batch_size = input_ids.shape[0]
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            if do_sample:
                next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            probs = nn.functional.softmax(next_token_scores, dim=-1)
            # token selection
            if do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return input_ids

    @classmethod
    def _load_pretrained_model(
        cls,
        model,
        state_dict,
        loaded_keys,
        resolved_archive_file,
        pretrained_model_name_or_path,
        ignore_mismatched_sizes=False,
        sharded_metadata=None,
        _fast_init=True,
        low_cpu_mem_usage=False,
        device_map=None,
        offload_folder=None,
        offload_state_dict=None,
        dtype=None,
        hf_quantizer=None,
        keep_in_fp32_modules=None,
        gguf_path=None,
    ):
        is_safetensors = False
        is_quantized = hf_quantizer is not None
        state_dict_folder = None
        state_dict_index = None

        if device_map is not None and "disk" in device_map.values():
            archive_file = (
                resolved_archive_file[0] if isinstance(resolved_archive_file, (list, tuple)) else resolved_archive_file
            )
            is_safetensors = archive_file.endswith(".safetensors")
            if offload_folder is None and not is_safetensors:
                raise ValueError(
                    "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
                    " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
                    " offers the weights in this format."
                )
            if offload_folder is not None:
                os.makedirs(offload_folder, exist_ok=True)
            if offload_state_dict is None:
                offload_state_dict = True

        is_sharded_safetensors = is_safetensors and sharded_metadata is not None

        for key, param in model.state_dict().items():
            if param.device == torch.device("meta"):
                try:
                    set_module_tensor_to_device(
                        model, key, "cuda", torch.empty(*param.size(), dtype=dtype)
                    )
                except Exception:
                    # leave the parameter on the meta device if it cannot be materialized on CUDA
                    pass

        # tie the model weights before retrieving the state_dict
        model.tie_weights()

        # Retrieve missing & unexpected_keys
        model_state_dict = model.state_dict()
        expected_keys = list(model_state_dict.keys())
        prefix = model.base_model_prefix

        def _fix_key(key):
            if "beta" in key and "vision_tower.vision_tower" not in key:
                return key.replace("beta", "bias")
            if "gamma" in key and ("vision_tower.vision_tower" not in key and "dav2_model" not in key):
                return key.replace("gamma", "weight")
            return key

        original_loaded_keys = loaded_keys
        loaded_keys = [_fix_key(key) for key in loaded_keys]

        if len(prefix) > 0:
            has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
            expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
        else:
            has_prefix_module = False
            expects_prefix_module = False

        # key re-naming operations are never done on the keys
        # that are loaded, but always on the keys of the newly initialized model
        remove_prefix_from_model = not has_prefix_module and expects_prefix_module
        add_prefix_to_model = has_prefix_module and not expects_prefix_module

        if remove_prefix_from_model:
            _prefix = f"{prefix}."
            expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)]
            expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys]
        elif add_prefix_to_model:
            expected_keys = [".".join([prefix, s]) for s in expected_keys]

        missing_keys = sorted(set(expected_keys) - set(loaded_keys))
        unexpected_keys = set(loaded_keys) - set(expected_keys)

        # Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
        # buffers
        model_buffers = {n for n, _ in model.named_buffers()}
        if remove_prefix_from_model:
            model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers}
        elif add_prefix_to_model:
            model_buffers = {".".join([prefix, key]) for key in model_buffers}
        unexpected_keys = sorted(unexpected_keys - model_buffers)

        model.tie_weights()
        if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
            ptrs = collections.defaultdict(list)
            for name, tensor in model.state_dict().items():
                id_tensor = id_tensor_storage(tensor)
                ptrs[id_tensor].append(name)

            # These are all the pointers of shared tensors.
            tied_params = [names for _, names in ptrs.items() if len(names) > 1]
        else:
            # id function doesn't work for meta tensor so we need this function
            tied_params = find_tied_parameters(model)

        for group in tied_params:
            if remove_prefix_from_model:
                group = [key[len(_prefix) :] if key.startswith(_prefix) else key for key in group]
            elif add_prefix_to_model:
                group = [".".join([prefix, key]) for key in group]
            missing_in_group = [k for k in missing_keys if k in group]
            if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
                missing_keys = [k for k in missing_keys if k not in missing_in_group]

        # Some models may have keys that are not in the state by design, removing them before needlessly warning
        # the user.
        if cls._keys_to_ignore_on_load_missing is not None:
            for pat in cls._keys_to_ignore_on_load_missing:
                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

        if cls._keys_to_ignore_on_load_unexpected is not None:
            for pat in cls._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        if hf_quantizer is not None:
            missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)

        # retrieve weights on meta device and put them back on CPU.
        # This is not ideal in terms of memory, but if we don't do that now, we can't initialize them in the next step
        if low_cpu_mem_usage:
            for key in missing_keys:
                if key in list(model_state_dict.keys()):
                    key = key
                elif f"{prefix}.{key}" in list(model_state_dict.keys()):
                    key = f"{prefix}.{key}"
                elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()):
                    key = ".".join(key.split(".")[1:])
                param = model_state_dict[key]

                # upcast in fp32 if any
                target_dtype = dtype
                if (
                    keep_in_fp32_modules is not None
                    and dtype == torch.float16
                    and any(
                        module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
                    )
                ):
                    target_dtype = torch.float32

                if param.device == torch.device("meta"):
                    value = torch.empty(*param.size(), dtype=target_dtype)
                    if (
                        not is_quantized
                        or getattr(hf_quantizer, "requires_parameters_quantization", False)
                        or not hf_quantizer.check_quantized_param(
                            model, param_value=value, param_name=key, state_dict={}
                        )
                    ):
                        set_module_tensor_to_device(model, key, "cpu", value)
                    else:
                        hf_quantizer.create_quantized_param(model, value, key, "cpu", state_dict, unexpected_keys)

        # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
        if _fast_init:
            if not ignore_mismatched_sizes:
                if remove_prefix_from_model:
                    _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
                elif add_prefix_to_model:
                    _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
                else:
                    _loaded_keys = loaded_keys
                not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
                # If we're about to tie the output embeds to the input embeds we don't need to init them
                if hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings:
                    output_embeddings = model.get_output_embeddings()
                    if output_embeddings is not None:
                        # Still need to initialize if there is a bias term since biases are not tied.
                        if not hasattr(output_embeddings, "bias") or output_embeddings.bias is None:
                            output_embeddings._is_hf_initialized = True
            else:
                not_initialized_submodules = dict(model.named_modules())
            # This will only initialize submodules that are not marked as initialized by the line above.
            if is_deepspeed_zero3_enabled() and not is_quantized:
                import deepspeed

                not_initialized_parameters = list(
                    set(
                        itertools.chain.from_iterable(
                            submodule.parameters(recurse=False) for submodule in not_initialized_submodules.values()
                        )
                    )
                )
                with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
                    model.apply(model._initialize_weights)
            else:
                model.apply(model._initialize_weights)

        # Set some modules to fp32 if any
        if keep_in_fp32_modules is not None:
            for name, param in model.named_parameters():
                if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
                    # param = param.to(torch.float32) does not work here as only in the local scope.
                    param.data = param.data.to(torch.float32)

        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ""
        model_to_load = model
        if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
            start_prefix = cls.base_model_prefix + "."
        if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
            model_to_load = getattr(model, cls.base_model_prefix)
            base_model_expected_keys = list(model_to_load.state_dict().keys())
            if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys):
                raise ValueError(
                    "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
                    "properly saved?"
                )
            if device_map is not None:
                device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()}

        def _find_mismatched_keys(
            state_dict,
            model_state_dict,
            loaded_keys,
            add_prefix_to_model,
            remove_prefix_from_model,
            ignore_mismatched_sizes,
        ):
            mismatched_keys = []
            if ignore_mismatched_sizes:
                for checkpoint_key in loaded_keys:
                    # If the checkpoint is sharded, we may not have the key here.
                    if checkpoint_key not in state_dict:
                        continue
                    model_key = checkpoint_key
                    if remove_prefix_from_model:
                        # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
                        model_key = f"{prefix}.{checkpoint_key}"
                    elif add_prefix_to_model:
                        # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
                        model_key = ".".join(checkpoint_key.split(".")[1:])

                    if (
                        model_key in model_state_dict
                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
                    ):
                        if (
                            state_dict[checkpoint_key].shape[-1] == 1
                            and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
                        ):
                            # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container,
                            # causing size differences. Without matching on module type or parameter type, it seems
                            # like a practical way to detect valid 4-bit weights.
                            pass
                        else:
                            mismatched_keys.append(
                                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
                            )
                            del state_dict[checkpoint_key]
            return mismatched_keys

        if resolved_archive_file is not None:
            folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
        else:
            folder = None
        if device_map is not None and is_safetensors:
            param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
            str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
            if sharded_metadata is None:
                archive_file = (
                    resolved_archive_file[0]
                    if isinstance(resolved_archive_file, (list, tuple))
                    else resolved_archive_file
                )
                weight_map = {p: archive_file for p in original_loaded_keys}
            else:
                weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
            offload_index = {
                p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
                for p, f in weight_map.items()
                if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
            }
        else:
            offload_index = None

        if state_dict is not None:
            # Whole checkpoint
            mismatched_keys = _find_mismatched_keys(
                state_dict,
                model_state_dict,
                original_loaded_keys,
                add_prefix_to_model,
                remove_prefix_from_model,
                ignore_mismatched_sizes,
            )

            # For GGUF models `state_dict` is never set to None as the state dict is always small
            if gguf_path:
                error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
                    model_to_load,
                    state_dict,
                    loaded_keys,
                    start_prefix,
                    expected_keys,
                    device_map=device_map,
                    offload_folder=offload_folder,
                    offload_index=offload_index,
                    state_dict_folder=state_dict_folder,
                    state_dict_index=state_dict_index,
                    dtype=dtype,
                    hf_quantizer=hf_quantizer,
                    is_safetensors=is_safetensors,
                    keep_in_fp32_modules=keep_in_fp32_modules,
                    unexpected_keys=unexpected_keys,
                )
            else:
                # Whole, non-GGUF checkpoint: load the state dict directly into the model.
                assign_to_params_buffers = check_support_param_buffer_assignment(
                    model_to_load, state_dict, start_prefix
                )
                error_msgs = _load_state_dict_into_model(
                    model_to_load, state_dict, start_prefix, assign_to_params_buffers
                )
        else:
            # This should always be a list, but check just to be sure.
            if not isinstance(resolved_archive_file, list):
                resolved_archive_file = [resolved_archive_file]

            error_msgs = []
            mismatched_keys = []
            if not is_safetensors:
                offload_index = {} if device_map is not None and "disk" in device_map.values() else None
            if offload_state_dict:
                state_dict_folder = tempfile.mkdtemp()
                state_dict_index = {}
            else:
                state_dict_folder = None
                state_dict_index = None

            if is_sharded_safetensors:
                disk_only_shard_files = get_disk_only_shard_files(
                    device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
                )
                disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
            else:
                disk_only_shard_files = []

            if len(resolved_archive_file) > 1:
                resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
            assign_to_params_buffers = None
            for shard_file in resolved_archive_file:
                # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
                if shard_file in disk_only_shard_files:
                    continue
                state_dict = load_state_dict(shard_file, is_quantized=is_quantized)

                # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
                # matching the weights in the model.
                mismatched_keys += _find_mismatched_keys(
                    state_dict,
                    model_state_dict,
                    original_loaded_keys,
                    add_prefix_to_model,
                    remove_prefix_from_model,
                    ignore_mismatched_sizes,
                )
                if low_cpu_mem_usage:
                    if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
                        for key, param in model_to_load.state_dict().items():
                            if param.device == torch.device("meta"):
                                set_module_tensor_to_device(
                                    model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
                                )
                    else:
                        new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
                            model_to_load,
                            state_dict,
                            loaded_keys,
                            start_prefix,
                            expected_keys,
                            device_map=device_map,
                            offload_folder=offload_folder,
                            offload_index=offload_index,
                            state_dict_folder=state_dict_folder,
                            state_dict_index=state_dict_index,
                            dtype=dtype,
                            hf_quantizer=hf_quantizer,
                            is_safetensors=is_safetensors,
                            keep_in_fp32_modules=keep_in_fp32_modules,
                            unexpected_keys=unexpected_keys,
                        )
                        error_msgs += new_error_msgs
                else:
                    # Sharded checkpoint or whole but low_cpu_mem_usage==True
                    if assign_to_params_buffers is None:
                        assign_to_params_buffers = check_support_param_buffer_assignment(
                            model_to_load, state_dict, start_prefix
                        )
                    error_msgs += _load_state_dict_into_model(
                        model_to_load, state_dict, start_prefix, assign_to_params_buffers
                    )

                # force memory release
                del state_dict
                gc.collect()

        if offload_index is not None and len(offload_index) > 0:
            if model != model_to_load:
                # We need to add the prefix of the base model
                prefix = cls.base_model_prefix
                if not is_safetensors:
                    for weight_name in offload_index:
                        shutil.move(
                            os.path.join(offload_folder, f"{weight_name}.dat"),
                            os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
                        )
                offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()}
            if not is_safetensors:
                save_offload_index(offload_index, offload_folder)
                offload_index = None

        if offload_state_dict:
            # Load back temporarily offloaded state dict
            load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
            shutil.rmtree(state_dict_folder)

        if len(error_msgs) > 0:
            error_msg = "\n\t".join(error_msgs)
            if "size mismatch" in error_msg:
                error_msg += (
                    "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
                )
            raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

        if len(unexpected_keys) > 0:
            archs = [] if model.config.architectures is None else model.config.architectures
            warner = logger.warning if model.__class__.__name__ in archs else logger.info
            warner(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
                " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
                " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
                " training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
                " to use it for predictions and inference."
            )

        return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
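
# A minimal usage sketch, not part of the library code. It assumes a concrete subclass of
# `BaseCausalLM` (hypothetically called `MyCausalLM` here) and a matching tokenizer checkpoint.
# `generate()` dispatches to the overridden `_sample` above, and the custom `_load_pretrained_model`
# is exercised transparently by `from_pretrained`.
def _demo_generation(checkpoint_dir: str, prompt: str) -> str:
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    model = MyCausalLM.from_pretrained(checkpoint_dir, torch_dtype=torch.float16)  # hypothetical subclass
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(
        **inputs,
        do_sample=True,
        temperature=0.7,
        max_new_tokens=64,
        return_dict_in_generate=True,
        output_scores=True,
    )
    # `output` is a GenerateDecoderOnlyOutput, so both the token ids and per-step scores are available.
    return tokenizer.decode(output.sequences[0], skip_special_tokens=True)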