Spaces:
Build error
Build error
| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| import logging | |
| from typing import List | |
| from torch import nn | |
def tie_encoder_decoder_weights(
    encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key: str
):
    """Make ``encoder`` share the ``weight``/``bias`` objects of ``decoder``.

    The ``_modules`` trees of both models are walked in parallel. Whenever a
    decoder module exposes a ``weight`` attribute and its path does not
    contain ``skip_key``, the matching encoder module's ``weight`` (and
    ``bias``, if the decoder module has one) is rebound to the decoder's
    object, so both modules reference the same parameter.

    Args:
        encoder: Module whose parameters are replaced by the decoder's.
        decoder: Module that provides the shared parameters.
        base_model_prefix: Root of the "/"-separated module paths that are
            matched against ``skip_key`` and used in messages.
        skip_key: Modules whose path contains this substring are not tied.
            An empty (or ``None``) value means "skip nothing".

    Raises:
        AssertionError: If a visited pointer is not an ``nn.Module``, or the
            encoder lacks a ``weight``/``bias``/sub-modules that the decoder
            module has.
        ValueError: If the recursion goes deeper than 500 levels on a
            same-named child.
    """
    uninitialized_encoder_weights: List[str] = []
    if decoder.__class__ != encoder.__class__:
        logging.info(
            f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
        )

    def tie_encoder_to_decoder_recursively(
        decoder_pointer: nn.Module,
        encoder_pointer: nn.Module,
        module_name: str,
        uninitialized_encoder_weights: List[str],
        skip_key: str,
        depth=0,
    ):
        assert isinstance(decoder_pointer, nn.Module) and isinstance(
            encoder_pointer, nn.Module
        ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"

        # "" is a substring of every path, so a bare `skip_key in module_name`
        # would skip *everything* for an empty key. Treat empty as "no skip".
        is_skipped = bool(skip_key) and skip_key in module_name
        if hasattr(decoder_pointer, "weight") and not is_skipped:
            assert hasattr(encoder_pointer, "weight")
            encoder_pointer.weight = decoder_pointer.weight
            if hasattr(decoder_pointer, "bias"):
                assert hasattr(encoder_pointer, "bias")
                encoder_pointer.bias = decoder_pointer.bias
            print(module_name + " is tied")
            return

        encoder_modules = encoder_pointer._modules
        decoder_modules = decoder_pointer._modules
        if not decoder_modules:
            return

        assert (
            len(encoder_modules) > 0
        ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"

        # Every encoder child starts as "not tied"; entries are removed as
        # the matching decoder child is processed.
        all_encoder_weights = {
            module_name + "/" + sub_name for sub_name in encoder_modules
        }
        encoder_layer_pos = 0
        for name in decoder_modules:
            if name.isdigit():
                encoder_name = str(int(name) + encoder_layer_pos)
                decoder_name = name
                if not isinstance(
                    decoder_modules[decoder_name],
                    type(encoder_modules[encoder_name]),
                ) and len(encoder_modules) != len(decoder_modules):
                    # this can happen if the name corresponds to the position in a list module list of layers
                    # in this case the decoder has added a cross-attention that the encoder does not have
                    # thus skip this step and subtract one layer pos from encoder
                    encoder_layer_pos -= 1
                    continue
            elif name not in encoder_modules:
                continue
            elif depth > 500:
                raise ValueError(
                    "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
                )
            else:
                decoder_name = encoder_name = name
            tie_encoder_to_decoder_recursively(
                decoder_modules[decoder_name],
                encoder_modules[encoder_name],
                module_name + "/" + name,
                uninitialized_encoder_weights,
                skip_key,
                depth=depth + 1,
            )
            all_encoder_weights.remove(module_name + "/" + encoder_name)

        # Sorted so the reported order is deterministic (set order is not).
        uninitialized_encoder_weights += sorted(all_encoder_weights)

    # tie weights recursively
    tie_encoder_to_decoder_recursively(
        decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key
    )

    # Report encoder sub-modules that had no decoder counterpart; previously
    # this list was collected and then silently discarded.
    if uninitialized_encoder_weights:
        logging.warning(
            f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
        )