Spaces:
Paused
Paused
from typing import Optional

import torch

from .models.tsal.sal_perceiver import AlignedShapeLatentPerceiver, ShapeAsLatentPerceiverEncoder
def get_encoder(
    pretrained_path: Optional[str] = None,
    freeze_decoder: bool = False,
    **kwargs,
) -> AlignedShapeLatentPerceiver:
    """Build an ``AlignedShapeLatentPerceiver``, optionally load weights and freeze parts.

    Args:
        pretrained_path: Path to a checkpoint file containing a plain
            ``state_dict`` for the model. When ``None`` the model keeps its
            freshly initialized weights.
        freeze_decoder: When True, disable gradients for ``geo_decoder``,
            ``encoder.query``, ``pre_kl``, ``post_kl`` and ``transformer``.
        **kwargs: Forwarded unchanged to ``AlignedShapeLatentPerceiver``.

    Returns:
        The constructed (and possibly loaded / partially frozen) model.

    Raises:
        RuntimeError: From ``load_state_dict`` (strict mode) when the
            checkpoint keys or shapes do not match the model.
    """
    model = AlignedShapeLatentPerceiver(**kwargs)
    if pretrained_path is not None:
        # Deserialize onto CPU: a checkpoint saved from GPU would otherwise fail
        # to load on a CPU-only host. load_state_dict copies the tensors into
        # the model's existing parameters, so the final device is unaffected.
        # weights_only=True restricts unpickling to tensors and plain containers.
        state_dict = torch.load(pretrained_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
    if freeze_decoder:
        # NOTE(review): besides the decoder itself, this also freezes the
        # encoder's query plus the KL projections and transformer, presumably
        # leaving only the remaining encoder layers trainable — confirm intent.
        model.geo_decoder.requires_grad_(False)
        model.encoder.query.requires_grad_(False)
        model.pre_kl.requires_grad_(False)
        model.post_kl.requires_grad_(False)
        model.transformer.requires_grad_(False)
    return model
def get_encoder_simplified(
    pretrained_path: Optional[str] = None,
    **kwargs,
) -> ShapeAsLatentPerceiverEncoder:
    """Build a ``ShapeAsLatentPerceiverEncoder`` and optionally load weights.

    Args:
        pretrained_path: Path to a checkpoint file containing a plain
            ``state_dict`` for the encoder. When ``None`` the model keeps its
            freshly initialized weights.
        **kwargs: Forwarded unchanged to ``ShapeAsLatentPerceiverEncoder``.

    Returns:
        The constructed (and possibly loaded) encoder.

    Raises:
        RuntimeError: From ``load_state_dict`` (strict mode) when the
            checkpoint keys or shapes do not match the model.
    """
    model = ShapeAsLatentPerceiverEncoder(**kwargs)
    if pretrained_path is not None:
        # Deserialize onto CPU so GPU-saved checkpoints also load on CPU-only
        # hosts; load_state_dict then copies into the existing parameters.
        state_dict = torch.load(pretrained_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
    return model