---
language: en
tags:
- lightweightgan
license: apache-2.0
datasets:
- glid3_orbs
---

# orbgan

A lightweight GAN trained on my GLID-3 orbs dataset (https://huggingface.co/datasets/johnowhitaker/glid3_orbs) for a demo I'm working on.

Training notebook: https://colab.research.google.com/drive/16o1TdrxnQ54Msbr813XfPVsnEt2QTRAa?usp=sharing

Inference notebook: https://colab.research.google.com/drive/1e7dR2dptM8F1xhRcyy-Aqow9YSe0NE3z?usp=sharing

The lightweight-gan code has an assert requiring a GPU. For inference on the CPU we need to re-define the Generator class and some other functions; see this minimal example: https://colab.research.google.com/drive/1fnHLdJ7niPMGOOBjGkNsnc6iADpf1Ujd?usp=sharing. This approach was used to make the demo Space here: https://huggingface.co/spaces/johnowhitaker/orbgan_demo

Please credit me if you use this, and feedback on the code is welcome :)


EDIT: you may need to use an older version of lightweight-gan, e.g. from commit 708633205d60c99b1b9d4e6b47eb3722aa4159d6, because some changes were made to the library after this model was trained.
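One way to get that older version is to have pip install the library straight from git at the pinned commit. The sketch below only builds and prints the command. It assumes the library is hosted at github.com/lucidrains/lightweight-gan (this card does not say so), and the `pinned_install_command` helper is hypothetical.

```python
import shlex
import sys
from typing import List

# Assumption: the library's git repository (not stated in this card)
REPO_URL = "https://github.com/lucidrains/lightweight-gan.git"
# Commit mentioned in the EDIT note
COMMIT = "708633205d60c99b1b9d4e6b47eb3722aa4159d6"


def pinned_install_command(repo_url: str = REPO_URL, commit: str = COMMIT) -> List[str]:
    """Hypothetical helper: build a pip command that installs the library from git at a fixed commit."""
    # sys.executable makes pip install into the interpreter running this code
    return [sys.executable, "-m", "pip", "install", f"git+{repo_url}@{commit}"]


cmd = pinned_install_command()
# Only prints the command; to actually install, run subprocess.check_call(cmd)
print(shlex.join(cmd))
```

Pinning to a full commit hash rather than a branch name keeps the install reproducible even as the library keeps changing.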

## Demo

```python
import torch
from matplotlib import pyplot as plt
from huggingface_hub import PyTorchModelHubMixin
from lightweight_gan import Generator  # lightweight_gan asserts that a CUDA GPU is available

# Add the Hub mixin to the Generator so it gains a from_pretrained() classmethod
class GeneratorWithPyTorchModelHubMixin(Generator, PyTorchModelHubMixin):
    pass

# Load from the model hub (the keyword arguments are passed on to Generator.__init__)
gan_new = GeneratorWithPyTorchModelHubMixin.from_pretrained(
    'johnowhitaker/orbgan_e1', latent_dim=256, image_size=256, attn_res_layers=[32]
)

# Alternatively, initialize a generator and load from a local saved state dict
# gan_new = Generator(latent_dim=256, image_size=256, attn_res_layers=[32])
# gan_new.load_state_dict(torch.load('/content/orbgan_e3_state_dict.pt', map_location='cpu'))

gan_new.eval()

# View some examples: sample n_rows**2 latent vectors and plot the images in a grid
n_rows = 3
with torch.no_grad():
    ims = gan_new(torch.randn(n_rows**2, 256)).clamp_(0., 1.)
fig, axs = plt.subplots(n_rows, n_rows, figsize=(9, 9))
for i, ax in enumerate(axs.flatten()):
    # Convert each image from (C, H, W) to (H, W, C) for imshow
    ax.imshow(ims[i].permute(1, 2, 0).cpu().numpy())
plt.tight_layout()
plt.show()
```