Spaces:
Configuration error
Configuration error
| from torch import load, nn | |
| from visualizr import logger | |
| from visualizr.networks.encoder import Encoder | |
| from visualizr.networks.styledecoder import Synthesis | |
| class LIA_Model(nn.Module): | |
| def __init__( | |
| self, | |
| size=256, | |
| style_dim=512, | |
| motion_dim=20, | |
| channel_multiplier=1, | |
| blur_kernel=[1, 3, 3, 1], | |
| fusion_type="", | |
| ): | |
| super().__init__() | |
| self.enc = Encoder(size, style_dim, motion_dim, fusion_type) | |
| self.dec = Synthesis( | |
| size, style_dim, motion_dim, blur_kernel, channel_multiplier | |
| ) | |
| def get_start_direction_code(self, x_start, x_target, x_face, x_aug): | |
| enc_dic = self.enc(x_start, x_target, x_face, x_aug) | |
| wa, alpha, feats = enc_dic["h_source"], enc_dic["h_motion"], enc_dic["feats"] | |
| return wa, alpha, feats | |
| def render(self, start, direction, feats): | |
| return self.dec(start, direction, feats) | |
| def load_lightning_model(self, lia_pretrained_model_path): | |
| selfState = self.state_dict() | |
| state = load(lia_pretrained_model_path, map_location="cpu") | |
| for name, param in state.items(): | |
| origName = name | |
| if name not in selfState: | |
| name = name.replace("lia.", "") | |
| if name not in selfState: | |
| logger.exception("%s is not in the model." % origName) | |
| # You can ignore those errors as some parameters are only used for training | |
| continue | |
| if selfState[name].size() != state[origName].size(): | |
| logger.exception( | |
| "Wrong parameter length: %s, model: %s, loaded: %s" | |
| % (origName, selfState[name].size(), state[origName].size()) | |
| ) | |
| continue | |
| selfState[name].copy_(param) | |