| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
from typing import Optional

import torch

from .models.tsal.sal_perceiver import AlignedShapeLatentPerceiver, ShapeAsLatentPerceiverEncoder
| |
|
def get_encoder(
    pretrained_path: Optional[str] = None,
    freeze_decoder: bool = False,
    **kwargs
) -> AlignedShapeLatentPerceiver:
    """Build an ``AlignedShapeLatentPerceiver``, optionally loading weights
    and freezing its decoder-side submodules.

    Args:
        pretrained_path: Path to a saved state dict to load; skipped when
            ``None``. Loaded with ``weights_only=True`` so the checkpoint
            cannot execute arbitrary pickled code.
        freeze_decoder: When ``True``, disable gradients on the decoder-side
            parts of the model so only the remaining parameters train.
        **kwargs: Forwarded verbatim to the ``AlignedShapeLatentPerceiver``
            constructor.

    Returns:
        The constructed model, with pretrained weights loaded and decoder
        parts frozen as requested.
    """
    model = AlignedShapeLatentPerceiver(**kwargs)
    if pretrained_path is not None:
        state_dict = torch.load(pretrained_path, weights_only=True)
        model.load_state_dict(state_dict)
    if freeze_decoder:
        # Freeze the decoder-side pieces: geometry decoder, the encoder's
        # latent query, both KL projections, and the transformer.
        # NOTE(review): ``encoder.query`` may be a Parameter rather than a
        # Module; ``requires_grad_`` works on both.
        frozen_parts = (
            model.geo_decoder,
            model.encoder.query,
            model.pre_kl,
            model.post_kl,
            model.transformer,
        )
        for part in frozen_parts:
            part.requires_grad_(False)
    return model
| |
|
def get_encoder_simplified(
    pretrained_path: Optional[str] = None,
    **kwargs
) -> ShapeAsLatentPerceiverEncoder:
    """Build a ``ShapeAsLatentPerceiverEncoder``, optionally loading weights.

    Args:
        pretrained_path: Path to a saved state dict to load; skipped when
            ``None``. Loaded with ``weights_only=True`` so the checkpoint
            cannot execute arbitrary pickled code.
        **kwargs: Forwarded verbatim to the ``ShapeAsLatentPerceiverEncoder``
            constructor.

    Returns:
        The constructed encoder model.
    """
    model = ShapeAsLatentPerceiverEncoder(**kwargs)
    if pretrained_path is not None:
        state_dict = torch.load(pretrained_path, weights_only=True)
        model.load_state_dict(state_dict)
    return model