---
language: en
tags:
- lightweightgan
license: apache-2.0
datasets:
- glid3_orbs
---

# orbgan

A lightweight GAN trained on my GLID-3 orbs dataset (https://huggingface.co/datasets/johnowhitaker/glid3_orbs) for a demo I'm working on.

Training notebook: https://colab.research.google.com/drive/16o1TdrxnQ54Msbr813XfPVsnEt2QTRAa?usp=sharing

Inference notebook: https://colab.research.google.com/drive/1e7dR2dptM8F1xhRcyy-Aqow9YSe0NE3z?usp=sharing

The lightweight_gan code has an assert that requires a GPU. For inference on the CPU, we need to re-define the Generator class and some other functions; see the minimal example here: https://colab.research.google.com/drive/1fnHLdJ7niPMGOOBjGkNsnc6iADpf1Ujd?usp=sharing. This approach was used to build the demo Space here: https://huggingface.co/spaces/johnowhitaker/orbgan_demo

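The CPU workaround linked above re-defines the Generator, and that code is not reproduced here. Independently of how the class is defined, weights saved from a GPU run generally have to be loaded with `map_location="cpu"` on a machine without CUDA. The sketch below shows only that step. It uses a toy `nn.Linear` as a stand-in for the Generator, and `load_state_dict_on_cpu` is a hypothetical helper name, not part of lightweight_gan.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_state_dict_on_cpu(model: nn.Module, path: str) -> nn.Module:
    """Hypothetical helper: load a saved state dict into `model` on the CPU."""
    # map_location="cpu" remaps any CUDA tensors in the checkpoint to CPU memory,
    # so loading works on machines without a GPU.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model.to("cpu")


if __name__ == "__main__":
    # Toy stand-in for the (re-defined) Generator, just to show the round trip.
    original = nn.Linear(4, 2)
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "toy_state_dict.pt")
        torch.save(original.state_dict(), path)
        restored = load_state_dict_on_cpu(nn.Linear(4, 2), path)
    print(torch.equal(restored.weight, original.weight))  # True
```

For the real model, `model` would be a Generator built with the same arguments used in the demo (`latent_dim=256, image_size=256, attn_res_layers=[32]`).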
Please credit me if you use this. Feedback on the code is welcome :)

EDIT: you may need to use an older version of lightweight_gan (e.g. commit 708633205d60c99b1b9d4e6b47eb3722aa4159d6), because the library has changed since this model was trained.

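To pin the library to the commit mentioned above, pip can install directly from a Git commit. This sketch builds and prints the command instead of running it. The repository URL (https://github.com/lucidrains/lightweight-gan) is an assumption about where lightweight_gan lives upstream, and `pinned_install_command` is a hypothetical helper name.

```python
import shlex
import sys
from typing import List

# Commit from the note above.
COMMIT = "708633205d60c99b1b9d4e6b47eb3722aa4159d6"
# Assumption: upstream repository for the `lightweight_gan` package.
REPO_URL = "https://github.com/lucidrains/lightweight-gan.git"


def pinned_install_command(commit: str = COMMIT, repo_url: str = REPO_URL) -> List[str]:
    """Hypothetical helper: pip arguments that install lightweight-gan at one commit."""
    # pip accepts VCS requirements of the form git+<url>@<commit-ish>.
    return [sys.executable, "-m", "pip", "install", f"git+{repo_url}@{commit}"]


if __name__ == "__main__":
    # Print a shell-ready command; run it yourself (it needs network access).
    print(shlex.join(pinned_install_command()))
```

To perform the install from Python instead of printing the command, you could pass the same list to `subprocess.check_call`.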
## Demo

```python
import torch
from huggingface_hub import PyTorchModelHubMixin
from lightweight_gan import Generator
from matplotlib import pyplot as plt

# Generator hyperparameters used for this model
gen_kwargs = dict(latent_dim=256, image_size=256, attn_res_layers=[32])

# Option 1: load from a locally saved state dict
# gan_new = Generator(**gen_kwargs)
# gan_new.load_state_dict(torch.load('/content/orbgan_e3_state_dict.pt', map_location='cpu'))

# Option 2: load from the model hub by adding the hub mixin to Generator
class GeneratorWithPyTorchModelHubMixin(Generator, PyTorchModelHubMixin):
    pass

gan_new = GeneratorWithPyTorchModelHubMixin.from_pretrained('johnowhitaker/orbgan_e1', **gen_kwargs)

# View some examples: sample 256-dim latents and clamp outputs to [0, 1]
n_rows = 3
with torch.no_grad():
    ims = gan_new(torch.randn(n_rows**2, 256)).clamp_(0., 1.)

fig, axs = plt.subplots(n_rows, n_rows, figsize=(9, 9))
for i, ax in enumerate(axs.flatten()):
    # (C, H, W) -> (H, W, C) for matplotlib
    ax.imshow(ims[i].permute(1, 2, 0).cpu().numpy())
    ax.axis('off')
plt.tight_layout()
plt.show()
```
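
The demo draws each sample in its own subplot. If you would rather save the samples as a single image, a small helper can tile the generator output into one array. The sketch below assumes outputs shaped `(N, C, H, W)` with values in `[0, 1]`, as produced by the demo. `make_image_grid` is a hypothetical name, not part of lightweight_gan.

```python
import numpy as np
import torch


def make_image_grid(images: torch.Tensor, n_rows: int) -> np.ndarray:
    """Hypothetical helper: tile (N, C, H, W) images in [0, 1] into one uint8 HWC array."""
    n, c, h, w = images.shape
    n_cols = -(-n // n_rows)  # ceiling division so every image gets a cell
    grid = np.zeros((n_rows * h, n_cols * w, c), dtype=np.uint8)
    # Drop autograd history, move to CPU, go channels-last, scale to 0..255.
    arr = images.detach().cpu().clamp(0, 1).permute(0, 2, 3, 1).numpy()
    arr = (arr * 255).round().astype(np.uint8)
    for i in range(n):
        row, col = divmod(i, n_cols)  # fill the grid row by row
        grid[row * h:(row + 1) * h, col * w:(col + 1) * w] = arr[i]
    return grid


if __name__ == "__main__":
    # Random stand-in for generator output: 9 RGB images of 8x8 pixels.
    grid = make_image_grid(torch.rand(9, 3, 8, 8), n_rows=3)
    print(grid.shape)  # (24, 24, 3)
```

With the demo above, `plt.imsave('orbs.png', make_image_grid(ims, n_rows))` would write the 3x3 grid of samples to disk.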