| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Dict |
| |
|
| | import torch |
| | from compressed_tensors import TRANSFORM_CONFIG_NAME |
| | from compressed_tensors.transform import TransformConfig, TransformFactory |
| | from compressed_tensors.utils.offload import has_offloaded_params |
| |
|
| |
|
# Public API of this module; `_tie_offloaded_tensors` is an internal helper.
__all__ = ["apply_transform_config"]
| |
|
| |
|
def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
    """
    Apply every transform group in a transform config to a model.

    Each entry of ``config.config_groups`` is turned into a ``TransformFactory``
    (named after its group) which then applies itself to ``model``. Weight
    transforms are fused into weights, while activation transforms are attached
    as submodules and are triggered via pytorch hooks.

    After all groups are applied, the config is stored on the model under
    ``TRANSFORM_CONFIG_NAME``, and offloaded meta tensors that share offloaded
    storage are tied together so the model can be serialized.

    :param model: model to apply config to (modified in place)
    :param config: transform config to apply
    """
    for group_name, group_scheme in config.config_groups.items():
        TransformFactory.from_scheme(group_scheme, name=group_name).apply_to_model(
            model
        )

    # keep the config on the model so later compression/serialization can find it
    setattr(model, TRANSFORM_CONFIG_NAME, config)

    # must run after all factories: applying transforms may create new offloaded
    # parameters whose meta tensors need to be tied for serialization
    _tie_offloaded_tensors(model)
| |
|
| |
|
def _tie_offloaded_tensors(model: torch.nn.Module):
    """
    Make meta tensors that share offloaded storage the same Python object.

    When accelerate replaces tensors with meta tensors during offloading, the
    meta tensors may be distinct objects even when their offloaded values are
    the same. Transformers can only serialize correctly if such meta tensors are
    identical (see transformers#39263).

    For every module with offloaded parameters, each parameter is keyed by the
    ``data_ptr()`` of its offloaded value in the module's hook ``weights_map``.
    The first parameter seen for a given pointer becomes the canonical one, and
    every later parameter with the same pointer is replaced by it. This lets the
    duplicates be removed during serialization.

    :param model: model potentially containing offloaded meta tensors to fix
    """
    # offloaded storage pointer -> first parameter object found for it
    canonical: Dict[int, torch.nn.Parameter] = {}

    for submodule in model.modules():
        if not has_offloaded_params(submodule):
            continue

        for param_name, _ in submodule.named_parameters(recurse=False):
            # identity of the offloaded copy, not of the (meta) parameter itself
            storage_ptr = submodule._hf_hook.weights_map[param_name].data_ptr()

            try:
                shared = canonical[storage_ptr]
            except KeyError:
                shared = canonical[storage_ptr] = getattr(submodule, param_name)

            # re-assigning an existing parameter name does not resize the
            # underlying dict, so this is safe mid-iteration
            setattr(submodule, param_name, shared)
| |
|