---
language: en
tags:
- lightweightgan
license: apache-2.0
datasets:
- glid3_orbs
---

# orbgan

A lightweight GAN trained on my glid-3 orbs dataset (https://huggingface.co/datasets/johnowhitaker/glid3_orbs) for a demo I'm working on.

Training notebook: https://colab.research.google.com/drive/16o1TdrxnQ54Msbr813XfPVsnEt2QTRAa?usp=sharing

Inference notebook: https://colab.research.google.com/drive/1e7dR2dptM8F1xhRcyy-Aqow9YSe0NE3z?usp=sharing

The lightweightgan code has an assert that requires a GPU. For inference on the CPU, we need to redefine the Generator class and some other functions; see the minimal example here: https://colab.research.google.com/drive/1fnHLdJ7niPMGOOBjGkNsnc6iADpf1Ujd?usp=sharing. This approach was used to build the demo Space here: https://huggingface.co/spaces/johnowhitaker/orbgan_demo

Please credit if you use this, and feedback on the code is welcome :)

EDIT: you may need to use an older version of lightweightgan, e.g. from commit 708633205d60c99b1b9d4e6b47eb3722aa4159d6, since some changes were made to the library after this model was trained.

## Demo:

```python
from lightweight_gan import Generator
import torch
from matplotlib import pyplot as plt
from huggingface_hub import PyTorchModelHubMixin

# Initialize a generator model
gan_new = Generator(latent_dim=256, image_size=256, attn_res_layers=[32])

# Load from local saved state dict
# gan_new.load_state_dict(torch.load('/content/orbgan_e3_state_dict.pt'))

# Load from model hub:
class GeneratorWithPyTorchModelHubMixin(gan_new.__class__, PyTorchModelHubMixin):
    pass

gan_new.__class__ = GeneratorWithPyTorchModelHubMixin
gan_new = gan_new.from_pretrained('johnowhitaker/orbgan_e1', latent_dim=256, image_size=256, attn_res_layers=[32])

# View some examples
n_rows = 3
ims = gan_new(torch.randn(n_rows**2, 256)).clamp_(0., 1.)
fig, axs = plt.subplots(n_rows, n_rows, figsize=(9, 9))
for i, ax in enumerate(axs.flatten()):
    # Convert each image from (C, H, W) to (H, W, C) for matplotlib
    ax.imshow(ims[i].permute(1, 2, 0).detach().cpu().numpy())
plt.tight_layout()
```
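The "Load from model hub" step works by creating a subclass that inherits from both the Generator's class and `PyTorchModelHubMixin`, then reassigning `gan_new.__class__` so the existing object picks up the mixin's `from_pretrained` classmethod. The sketch below shows that trick in plain Python. `Generator` and `LoaderMixin` here are hypothetical stand-ins, not the real lightweightgan or huggingface_hub classes. The assumption, which is presumably why the demo passes `latent_dim`, `image_size` and `attn_res_layers` to `from_pretrained` again, is that the mixin builds a fresh model from those keyword arguments before loading the weights.

```python
class Generator:
    """Hypothetical stand-in for lightweight_gan's Generator."""
    def __init__(self, latent_dim=256, image_size=256):
        self.latent_dim = latent_dim
        self.image_size = image_size


class LoaderMixin:
    """Hypothetical stand-in for PyTorchModelHubMixin (not the real API)."""
    @classmethod
    def from_pretrained(cls, repo_id, **model_kwargs):
        # Build a new instance of whatever class this is called on,
        # forwarding the constructor kwargs; a real loader would then
        # download and load the weights from repo_id.
        model = cls(**model_kwargs)
        model.source = repo_id
        return model


gen = Generator(latent_dim=256, image_size=256)

# Subclass combining the instance's current class with the mixin.
class GeneratorWithLoader(gen.__class__, LoaderMixin):
    pass

# Swap the class of the existing object so it gains from_pretrained.
gen.__class__ = GeneratorWithLoader

# Calling a classmethod on an instance passes the instance's class as cls,
# so this constructs a new GeneratorWithLoader from the given kwargs.
loaded = gen.from_pretrained("some-user/some-repo", latent_dim=256, image_size=256)
```

Because `from_pretrained` is a classmethod, calling it directly on the subclass (e.g. `GeneratorWithLoader.from_pretrained(...)`) should work just as well; swapping the instance's class is only a convenience.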