"""Factory helpers that build shape-latent perceiver models.

Each factory constructs a model from keyword arguments and, when a
checkpoint path is supplied, loads the stored state dict into it.
"""

from typing import Optional

import torch

from .models.tsal.sal_perceiver import AlignedShapeLatentPerceiver, ShapeAsLatentPerceiverEncoder


def _load_pretrained(model: torch.nn.Module, pretrained_path: Optional[str]) -> None:
    """Load the state dict stored at *pretrained_path* into *model* in place.

    Does nothing when *pretrained_path* is ``None``.
    """
    if pretrained_path is None:
        return
    # map_location="cpu" lets a checkpoint saved from CUDA tensors be read on
    # a host without a GPU; load_state_dict then copies the values into the
    # model's existing parameters, wherever those live.
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(pretrained_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)


def get_encoder(
    pretrained_path: Optional[str] = None,
    freeze_decoder: bool = False,
    **kwargs
) -> AlignedShapeLatentPerceiver:
    """Build an ``AlignedShapeLatentPerceiver``, optionally from a checkpoint.

    Args:
        pretrained_path: Path to a state-dict checkpoint. When ``None`` the
            model keeps whatever initialization its constructor produced.
        freeze_decoder: When true, gradient tracking is disabled on
            ``geo_decoder``, ``encoder.query``, ``pre_kl``, ``post_kl`` and
            ``transformer``; everything else stays trainable.
        **kwargs: Forwarded unchanged to ``AlignedShapeLatentPerceiver``.

    Returns:
        The constructed (and possibly loaded / partially frozen) model.
    """
    model = AlignedShapeLatentPerceiver(**kwargs)
    _load_pretrained(model, pretrained_path)

    if freeze_decoder:
        model.geo_decoder.requires_grad_(False)
        model.encoder.query.requires_grad_(False)
        model.pre_kl.requires_grad_(False)
        model.post_kl.requires_grad_(False)
        model.transformer.requires_grad_(False)

    return model


def get_encoder_simplified(
    pretrained_path: Optional[str] = None,
    **kwargs
) -> ShapeAsLatentPerceiverEncoder:
    """Build a ``ShapeAsLatentPerceiverEncoder``, optionally from a checkpoint.

    Args:
        pretrained_path: Path to a state-dict checkpoint. When ``None`` the
            model keeps whatever initialization its constructor produced.
        **kwargs: Forwarded unchanged to ``ShapeAsLatentPerceiverEncoder``.

    Returns:
        The constructed (and possibly loaded) encoder.
    """
    model = ShapeAsLatentPerceiverEncoder(**kwargs)
    _load_pretrained(model, pretrained_path)
    return model