| import huggingface_hub | |
| from .config import Config | |
| from transformers import PreTrainedModel | |
| from .dcgan import Generator | |
| # config = Config() | |
| # config.save_pretrained("WGAN-GP") | |
class WGAN_GP(PreTrainedModel):
    """Expose the project's ``Generator`` through the ``PreTrainedModel`` API.

    The model owns a single submodule, ``self.generator``, and its forward
    pass is a direct call to that submodule.
    """

    # Configuration class associated with this model.
    config_class = Config

    def __init__(self, config):
        """Build the generator from the settings stored in ``config.cfg``.

        Args:
            config: Configuration object (presumably a ``Config`` instance,
                given ``config_class`` -- confirm against callers) whose
                ``cfg`` mapping provides the keys ``"imsize"``, ``"img_ch"``
                and ``"zdim"``, plus a ``"g"`` entry under both
                ``"norm_type"`` and ``"final_activation"``.

        Raises:
            KeyError: If ``config.cfg`` is a plain mapping and any of the
                keys listed above is missing.
        """
        super().__init__(config)
        settings = config.cfg
        # Generator is called positionally; keep this order in sync with its
        # signature in the ``dcgan`` module.
        generator_args = (
            settings["imsize"],
            settings["img_ch"],
            settings["zdim"],
            settings["norm_type"]["g"],
            settings["final_activation"]["g"],
        )
        self.generator = Generator(*generator_args)

    def forward(self, input):
        """Return the generator's output for ``input``.

        Args:
            input: Value passed unchanged to ``self.generator``. The name
                shadows the builtin but is kept so keyword callers keep
                working. Presumably a batch of latent vectors of size
                ``zdim`` -- confirm against ``Generator``.

        Returns:
            Whatever ``self.generator(input)`` returns.
        """
        output = self.generator(input)
        return output