| import torch | |
| from transformers import PreTrainedModel | |
| from .LightweightGANConfig import LightweightGANConfig | |
| from .deploy import Generator | |
class LightweightGANModel(PreTrainedModel):
    """Expose the Lightweight GAN ``Generator`` through the ``PreTrainedModel`` API.

    The wrapped generator is built from the fields of a
    ``LightweightGANConfig`` and stored on ``self.model``; ``forward``
    delegates straight to it.
    """

    # Config type associated with this model class.
    config_class = LightweightGANConfig

    def __init__(self, config: LightweightGANConfig) -> None:
        """Build the underlying ``Generator`` from ``config``.

        Args:
            config: Configuration object providing ``image_size``,
                ``latent_dim``, ``fmap_max``, ``fmap_inverse_coef``,
                ``transparent``, ``greyscale``, ``attn_res_layers``,
                ``freq_chan_attn``, ``syncbatchnorm`` and ``antialias``.
        """
        super().__init__(config)
        self.model = Generator(
            image_size=config.image_size,
            latent_dim=config.latent_dim,
            fmap_max=config.fmap_max,
            fmap_inverse_coef=config.fmap_inverse_coef,
            transparent=config.transparent,
            greyscale=config.greyscale,
            attn_res_layers=config.attn_res_layers,
            freq_chan_attn=config.freq_chan_attn,
            syncbatchnorm=config.syncbatchnorm,
            antialias=config.antialias,
        )

    def forward(self, tensor):
        """Run the wrapped generator on ``tensor`` and return its output.

        Args:
            tensor: Input passed unchanged to the generator.
                NOTE(review): presumably a batch of latent vectors of
                width ``config.latent_dim`` -- confirm against ``Generator``.

        Returns:
            Whatever ``Generator.__call__`` returns for ``tensor``.
        """
        return self.model(tensor)

    def load_params(self, pt_file) -> None:
        """Load generator weights from a serialized state dict.

        Args:
            pt_file: Path or file-like object accepted by ``torch.load``
                that contains a state dict for the wrapped generator.

        Raises:
            RuntimeError: Propagated from ``load_state_dict`` when the
                checkpoint keys or shapes do not match the generator.
        """
        # NOTE(security): torch.load unpickles its input, which can execute
        # arbitrary code. Only call this with checkpoints from a trusted
        # source (or pass ``weights_only=True`` on torch versions that
        # support it).
        #
        # map_location="cpu" lets a checkpoint saved from CUDA tensors load
        # on a machine without a GPU; load_state_dict then copies the values
        # into the existing parameters, so the model's own device placement
        # is unchanged.
        state_dict = torch.load(pt_file, map_location="cpu")
        self.model.load_state_dict(state_dict)