Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from huggingface_hub import PyTorchModelHubMixin | |
| from manipulate_model.encoder.encoder import Encoder | |
| from manipulate_model.decoder.decoder import Decoder | |
class Model(nn.Module, PyTorchModelHubMixin):
    """Encoder followed by a decoder, both built from one shared config.

    The constructor builds the encoder first, copies the encoder's reported
    temporal and encoding dimensions into ``config.model.decoder``, and only
    then builds the decoder from that same (now updated) config object.
    """

    def __init__(self, config):
        """Build the encoder and decoder.

        Args:
            config: Configuration object. This block reads
                ``config.model.encoder_freeze`` (in ``forward``) and writes
                ``config.model.decoder.temporal_dim`` and
                ``config.model.decoder.encoding_dim``. The object is mutated
                in place, so the caller's config sees the new values too.
        """
        super().__init__()
        self.config = config

        # The encoder has to exist before the decoder: the decoder's config
        # is filled in from dimensions the encoder reports.
        self.encoder = Encoder(self.config)
        self.config.model.decoder.temporal_dim = self.encoder.get_temporal_dim()
        self.config.model.decoder.encoding_dim = self.encoder.get_encoding_dim()

        self.decoder = Decoder(self.config)

    def forward(self, x):
        """Run ``x`` through the encoder, then the decoder.

        When ``config.model.encoder_freeze`` is truthy the encoder call is
        wrapped in ``torch.no_grad()``, so no autograd graph is recorded for
        it; the decoder call always runs under the ambient grad mode.

        Args:
            x: Input passed straight to the encoder.

        Returns:
            Whatever the decoder returns for the encoder's output.
        """
        if not self.config.model.encoder_freeze:
            return self.decoder(self.encoder(x))

        # NOTE(review): no_grad only disables gradient tracking; it does not
        # put the encoder in eval mode -- confirm that is handled elsewhere
        # if the encoder has dropout / batch-norm layers.
        with torch.no_grad():
            encoded = self.encoder(x)
        return self.decoder(encoded)