johnowhitaker commited on
Commit
3ee3292
·
1 Parent(s): 97d4a8e

Added example usage

Browse files
Files changed (1) hide show
  1. README.md +28 -1
README.md CHANGED
@@ -1 +1,28 @@
1
- lightweight GAN trained on glid-3 orbs for demo I'm working on.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ lightweight GAN trained on glid-3 orbs for demo I'm working on.
2
+
3
+ Demo:
4
+
5
+ from lightweight_gan import Generator
6
+ import torch
7
+ from matplotlib import pyplot as plt
8
+ from huggingface_hub import PyTorchModelHubMixin
9
+ class GeneratorWithPyTorchModelHubMixin(gan.__class__, PyTorchModelHubMixin):
10
+ pass
11
+
12
+ # Initialize a generator model
13
+ gan_new = Generator(latent_dim=256, image_size=256, attn_res_layers = [32])
14
+
15
+ # Load from local saved state dict
16
+ # gan_new.load_state_dict(torch.load('/content/orbgan_e3_state_dict.pt'))
17
+
18
+ # Load from model hub:
19
+ gan_new.__class__ = GeneratorWithPyTorchModelHubMixin
20
+ gan_new = gan_new.from_pretrained('johnowhitaker/orbgan_e1', latent_dim=256, image_size=256, attn_res_layers = [32])
21
+
22
+ # View some examples
23
+ n_rows = 3
24
+ ims = gan_new(torch.randn(n_rows**2, 256)).clamp_(0., 1.)
25
+ fig, axs = plt.subplots(n_rows, n_rows, figsize=(9, 9))
26
+ for i, ax in enumerate(axs.flatten()):
27
+ ax.imshow(ims[i].permute(1, 2, 0).detach().cpu().numpy())
28
+ plt.tight_layout()