Spaces:
Runtime error
Runtime error
Commit
·
ae2d652
1
Parent(s):
a94775f
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,7 +9,7 @@ from huggingface_hub import hf_hub_download
|
|
| 9 |
from PIL import Image
|
| 10 |
from torch import nn
|
| 11 |
from torchvision.utils import save_image
|
| 12 |
-
|
| 13 |
|
| 14 |
class Generator(nn.Module):
|
| 15 |
def __init__(self, num_channels=4, latent_dim=100, hidden_size=64):
|
|
@@ -42,10 +42,6 @@ class Generator(nn.Module):
|
|
| 42 |
|
| 43 |
return pixel_values
|
| 44 |
|
| 45 |
-
model = Generator()
|
| 46 |
-
weights_path = hf_hub_download('huggingnft/dooggies', 'pytorch_model.bin')
|
| 47 |
-
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
|
| 48 |
-
|
| 49 |
|
| 50 |
@torch.no_grad()
|
| 51 |
def interpolate(save_dir='./lerp/', frames=100, rows=8, cols=8):
|
|
@@ -76,7 +72,10 @@ def interpolate(save_dir='./lerp/', frames=100, rows=8, cols=8):
|
|
| 76 |
save_all=True, duration=100, loop=1)
|
| 77 |
|
| 78 |
|
| 79 |
-
def predict(choice, seed):
|
|
|
|
|
|
|
|
|
|
| 80 |
torch.manual_seed(seed)
|
| 81 |
|
| 82 |
if choice == 'interpolation':
|
|
@@ -92,9 +91,11 @@ def predict(choice, seed):
|
|
| 92 |
return 'image.png'
|
| 93 |
|
| 94 |
|
|
|
|
| 95 |
gr.Interface(
|
| 96 |
predict,
|
| 97 |
inputs=[
|
|
|
|
| 98 |
gr.inputs.Dropdown(['image', 'interpolation'], label='Output Type'),
|
| 99 |
gr.inputs.Slider(label='Seed', minimum=0, maximum=1000, default=42),
|
| 100 |
],
|
|
@@ -102,5 +103,5 @@ gr.Interface(
|
|
| 102 |
title="Cryptopunks GAN",
|
| 103 |
description="These CryptoPunks do not exist. You have the choice of either generating random punks, or a gif showing the interpolation between two random punk grids.",
|
| 104 |
article="<p style='text-align: center'><a href='https://arxiv.org/pdf/1511.06434.pdf'>Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks</a> | <a href='https://github.com/teddykoker/cryptopunks-gan'>Github Repo</a></p>",
|
| 105 |
-
examples=[["interpolation",
|
| 106 |
).launch(cache_examples=True)
|
|
|
|
| 9 |
from PIL import Image
|
| 10 |
from torch import nn
|
| 11 |
from torchvision.utils import save_image
|
| 12 |
+
hfapi = HfApi()
|
| 13 |
|
| 14 |
class Generator(nn.Module):
|
| 15 |
def __init__(self, num_channels=4, latent_dim=100, hidden_size=64):
|
|
|
|
| 42 |
|
| 43 |
return pixel_values
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
@torch.no_grad()
|
| 47 |
def interpolate(save_dir='./lerp/', frames=100, rows=8, cols=8):
|
|
|
|
| 72 |
save_all=True, duration=100, loop=1)
|
| 73 |
|
| 74 |
|
| 75 |
+
def predict(model_name, choice, seed):
|
| 76 |
+
model = Generator()
|
| 77 |
+
weights_path = hf_hub_download(f'huggingnft/{model_name}', 'pytorch_model.bin')
|
| 78 |
+
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
|
| 79 |
torch.manual_seed(seed)
|
| 80 |
|
| 81 |
if choice == 'interpolation':
|
|
|
|
| 91 |
return 'image.png'
|
| 92 |
|
| 93 |
|
| 94 |
+
models = [model.modelId[model.modelId.index("/") + 1:] for model in hfapi.list_models(author="huggingnft")]
|
| 95 |
gr.Interface(
|
| 96 |
predict,
|
| 97 |
inputs=[
|
| 98 |
+
gr.inputs.Dropdown(models, label='Model'),
|
| 99 |
gr.inputs.Dropdown(['image', 'interpolation'], label='Output Type'),
|
| 100 |
gr.inputs.Slider(label='Seed', minimum=0, maximum=1000, default=42),
|
| 101 |
],
|
|
|
|
| 103 |
title="Cryptopunks GAN",
|
| 104 |
description="These CryptoPunks do not exist. You have the choice of either generating random punks, or a gif showing the interpolation between two random punk grids.",
|
| 105 |
article="<p style='text-align: center'><a href='https://arxiv.org/pdf/1511.06434.pdf'>Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks</a> | <a href='https://github.com/teddykoker/cryptopunks-gan'>Github Repo</a></p>",
|
| 106 |
+
examples=[["interpolation", 100], ["interpolation", 500], ["image", 100], ["image", 500]],
|
| 107 |
).launch(cache_examples=True)
|