| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import annotations |
|
|
| import functools |
| from contextlib import contextmanager |
| from typing import Literal, Optional |
|
|
| import packaging.version |
| import torch |
| import transformers |
| from torch import nn |
|
|
|
|
| def check_deepspeed_zero3_enabled() -> bool: |
| if packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.33.0"): |
| from transformers.integrations import is_deepspeed_zero3_enabled |
| else: |
| from transformers.deepspeed import is_deepspeed_zero3_enabled |
| return is_deepspeed_zero3_enabled() |
|
|
|
|
| @contextmanager |
| def gather_params_ctx(param, modifier_rank: Optional[int] = 0, fwd_module: torch.nn.Module = None): |
| """Call DeepSpeed GatheredParameters context manager if DeepSpeed is enabled, otherwise do nothing.""" |
|
|
| if not check_deepspeed_zero3_enabled(): |
| yield |
| return |
|
|
| import deepspeed |
|
|
| with deepspeed.zero.GatheredParameters(param, modifier_rank=modifier_rank, fwd_module=fwd_module): |
| yield |
| return |
|
|
|
|
| def dequantize_module_weight(module: torch.nn.Module) -> torch.nn.Parameter: |
| """ |
| Helper function to dequantize a quantized weight. |
| |
| This function should be extended if more quantization schemes are added to the library. |
| |
| If the weight is not quantized, it will be returned as is. |
| """ |
| if hasattr(module, "W_q"): |
| weight = module.dequantize() |
| return weight |
| elif type(module.weight).__module__.startswith("torchao."): |
| |
| weight = module.weight.dequantize() |
| return weight |
|
|
| weight = module.weight |
| if not isinstance(weight, torch.nn.Parameter): |
| if isinstance(weight, torch.Tensor): |
| |
| return weight |
| raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead") |
|
|
| cls_name = weight.__class__.__name__ |
| if cls_name not in ("Params4bit", "Int8Params"): |
| return weight |
|
|
| quant_state = getattr(module, "state", None) |
| device = weight.device |
| is_cpu = device.type == torch.device("cpu").type |
| weight = dequantize_bnb_weight(weight, state=quant_state) |
| if is_cpu: |
| |
| module.weight = module.weight.to(device) |
| return weight |
|
|
|
|
| def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None): |
| """Helper function to dequantize 4bit or 8bit bnb weights.""" |
| import bitsandbytes as bnb |
|
|
| device = weight.device |
|
|
| cls_name = weight.__class__.__name__ |
| if cls_name == "Params4bit": |
| dequantized = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) |
| return dequantized |
|
|
| |
| if state is None: |
| raise ValueError( |
| "No `state` was passed for bnb 8bit quantized weights. Please open an issue on the PEFT repository and " |
| "report the error: https://github.com/huggingface/peft/issues" |
| ) |
|
|
| if state.SCB is None: |
| state.SCB = weight.SCB |
|
|
| if hasattr(bnb.functional, "int8_vectorwise_dequant"): |
| |
| dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB) |
| else: |
| |
| dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3 |
|
|
| return dequantized |
|
|
|
|
| def get_bnb_param_type(param: torch.nn.Parameter) -> Literal[False, "4bit", "8bit"]: |
| """Returns '4bit' or '8bit' if bitsandbytes parameter, else False""" |
| if param.__class__.__name__ == "Params4bit": |
| return "4bit" |
| if param.__class__.__name__ == "Int8Params": |
| return "8bit" |
| return False |
|
|
|
|
| |
| |
| def get_layer_device_map(model): |
| """ |
| Derive the device map for the layers of the model. |
| """ |
| main_device = [d for d in model.hf_device_map.values() if d not in ["cpu", "disk"]][0] |
|
|
| execution_device_map = { |
| name: main_device if device in ["cpu", "disk"] else device for name, device in model.hf_device_map.items() |
| } |
|
|
| if execution_device_map is None: |
| return None |
|
|
| if len(execution_device_map) == 1 and "" in execution_device_map: |
| return {idx: execution_device_map[""] for idx in range(model.config.num_hidden_layers)} |
|
|
| layer_device_map = {} |
| for layer in execution_device_map: |
| for idx in range(model.config.num_hidden_layers): |
| if f".{idx}." in f"{layer}.": |
| layer_device_map[idx] = execution_device_map[layer] |
| break |
| for idx in range(model.config.num_hidden_layers): |
| if idx not in layer_device_map: |
| raise RuntimeError(f"layer {idx} has not been mapped to a device.") |
| return layer_device_map |
|
|
|
|
| |
| |
| def map_cache_to_layer_device_map(model, cache) -> None: |
| """ |
| Ensure that the key and value cache of the model are on the same device as their corresponding layers. |
| """ |
| if not (isinstance(cache, transformers.Cache) and hasattr(model, "hf_device_map")): |
| return |
|
|
| if isinstance(cache, transformers.EncoderDecoderCache): |
| map_cache_to_layer_device_map(model, cache.self_attention_cache) |
| return |
|
|
| layer_device_map = get_layer_device_map(model) |
| for idx in range(model.config.num_hidden_layers): |
| layer_device = layer_device_map[idx] |
| if hasattr(cache, "layers"): |
| |
| layer = cache.layers[idx] |
| layer.keys = layer.keys.to(layer_device) |
| layer.values = layer.values.to(layer_device) |
| else: |
| |
| |
| cache.key_cache[idx] = cache.key_cache[idx].to(layer_device) |
| cache.value_cache[idx] = cache.value_cache[idx].to(layer_device) |
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
| @contextmanager |
| def init_empty_weights(include_buffers: bool = None): |
| |
| with _init_on_device(torch.device("meta"), include_buffers=include_buffers) as f: |
| yield f |
|
|
|
|
| @contextmanager |
| def _init_on_device(device: torch.device, include_buffers: bool = None): |
| |
| old_register_parameter = nn.Module.register_parameter |
| if include_buffers: |
| old_register_buffer = nn.Module.register_buffer |
|
|
| def register_empty_parameter(module, name, param): |
| |
| |
| |
| old_register_parameter(module, name, param) |
| if (param is not None) and (getattr(_init_on_device, "_skip", False) is not True): |
| param_cls = type(module._parameters[name]) |
| kwargs = module._parameters[name].__dict__ |
| kwargs["requires_grad"] = param.requires_grad |
| module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) |
|
|
| def register_empty_buffer(module, name, buffer, persistent=True): |
| old_register_buffer(module, name, buffer, persistent=persistent) |
| if buffer is not None: |
| module._buffers[name] = module._buffers[name].to(device) |
|
|
| |
| if include_buffers: |
| tensor_constructors_to_patch = { |
| torch_function_name: getattr(torch, torch_function_name) |
| for torch_function_name in ["empty", "zeros", "ones", "full"] |
| } |
| else: |
| tensor_constructors_to_patch = {} |
|
|
| def patch_tensor_constructor(fn): |
| def wrapper(*args, **kwargs): |
| kwargs["device"] = device |
| return fn(*args, **kwargs) |
|
|
| return wrapper |
|
|
| try: |
| nn.Module.register_parameter = register_empty_parameter |
| if include_buffers: |
| nn.Module.register_buffer = register_empty_buffer |
| for torch_function_name in tensor_constructors_to_patch.keys(): |
| setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) |
| yield |
| finally: |
| nn.Module.register_parameter = old_register_parameter |
| if include_buffers: |
| nn.Module.register_buffer = old_register_buffer |
| for torch_function_name, old_torch_function in tensor_constructors_to_patch.items(): |
| setattr(torch, torch_function_name, old_torch_function) |
|
|
|
|
| @contextmanager |
| def _skip_init_on_device(): |
| |
| old_val = getattr(_init_on_device, "_skip", False) |
| try: |
| _init_on_device._skip = True |
| yield |
| finally: |
| _init_on_device._skip = old_val |
|
|
|
|
| def skip_init_on_device(func): |
| """ |
| Ignore the init_on_device context manager when calling the decorated function. |
| |
| This is a narrow use decorator that allows us to avoid initializing on meta device even when we're inside the |
| init_empty_weights context. |
| |
| """ |
|
|
| |
| |
| |
| |
| @functools.wraps(func) |
| def wrapper(*args, **kwargs): |
| with _skip_init_on_device(): |
| return func(*args, **kwargs) |
|
|
| return wrapper |
|
|
|
|
| |
| |
| |
|
|