Klaus04's picture
Upload 3 files
7f22ed9 verified
import gradio as gr
import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import torchvision.transforms.functional as func
def get_one_hot_labels(labels, n_classes):
    """Return one-hot encodings for a tensor of integer class labels.

    Labels are first reduced modulo 10, then expanded into one-hot rows
    with ``n_classes`` columns.

    Args:
        labels: Integer tensor of class indices.
        n_classes: Width of each one-hot vector.

    Returns:
        Integer tensor with one extra trailing dimension of size ``n_classes``.
    """
    wrapped = labels % 10
    encoded = F.one_hot(wrapped, num_classes=n_classes)
    return encoded
def get_input_dimensions(z_dim, mnist_shape, n_classes):
    """Compute the input sizes of the conditional generator and discriminator.

    Args:
        z_dim: Length of the noise vector.
        mnist_shape: Image shape whose first entry is the channel count.
        n_classes: Number of classes appended as conditioning information.

    Returns:
        Tuple ``(generator_input_dim, discriminator_im_chan)``: the noise
        length plus ``n_classes``, and the image channel count plus
        ``n_classes``.
    """
    image_channels = mnist_shape[0]
    return z_dim + n_classes, image_channels + n_classes
def combine_vectors(x, y):
    """Concatenate two tensors along dimension 1 and cast the result to float.

    Args:
        x: First tensor.
        y: Second tensor; must match ``x`` in every dimension except dim 1.

    Returns:
        A ``torch.float`` tensor holding ``x`` followed by ``y`` along dim 1.
    """
    joined = torch.cat([x, y], dim=1)
    return joined.to(torch.float)
def get_noise(n_samples, input_dim, device='cpu'):
    """Draw a batch of standard-normal noise vectors.

    Args:
        n_samples: Number of noise vectors (rows).
        input_dim: Length of each noise vector (columns).
        device: Device on which the tensor is created.

    Returns:
        Tensor of shape ``(n_samples, input_dim)`` sampled with ``torch.randn``.
    """
    noise_shape = (n_samples, input_dim)
    return torch.randn(noise_shape, device=device)
class Generator(nn.Module):
    """Transposed-convolution generator that maps an input vector to an image.

    The input vector is reshaped to ``(batch, input_dim, 1, 1)`` and passed
    through four transposed-convolution blocks. With the kernel sizes and
    strides used here the spatial size grows 1 -> 3 -> 6 -> 13 -> 28.

    Args:
        input_dim: Length of the input vector (channels of the 1x1 input).
        im_chan: Number of channels in the generated image.
        hidden_dim: Base width used to size the intermediate layers.
    """

    def __init__(self, input_dim=10, im_chan=1, hidden_dim=64):
        super().__init__()
        self.input_dim = input_dim
        # Blocks are created in a fixed order so that sub-module indices
        # (and therefore state_dict keys) stay gen.0 ... gen.3.
        blocks = [
            self.make_gen_block(input_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        ]
        self.gen = nn.Sequential(*blocks)

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        """Build one generator block.

        Every block starts with a transposed convolution. Hidden blocks
        follow it with batch norm and ReLU; the final block follows it
        with Tanh instead.

        Args:
            input_channels: Channels entering the block.
            output_channels: Channels leaving the block.
            kernel_size: Kernel size of the transposed convolution.
            stride: Stride of the transposed convolution.
            final_layer: Whether this is the output block.

        Returns:
            An ``nn.Sequential`` containing the block's layers.
        """
        layers = [nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)]
        if final_layer:
            layers.append(nn.Tanh())
        else:
            layers.append(nn.BatchNorm2d(output_channels))
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, noise):
        """Generate images from a batch of input vectors.

        Args:
            noise: Tensor viewable as ``(len(noise), input_dim, 1, 1)``.

        Returns:
            The output of the generator stack for the reshaped input.
        """
        batch_size = len(noise)
        reshaped = noise.view(batch_size, self.input_dim, 1, 1)
        return self.gen(reshaped)
# --- Model configuration -------------------------------------------------
z_dim = 64                  # length of the noise vector fed to the generator
n_classes = 10              # number of digit classes used for conditioning
device = "cpu"              # inference runs on CPU
mnist_shape = (1, 28, 28)   # (channels, height, width) of the image shape

generator_input_dim, discriminator_im_chan = get_input_dimensions(z_dim, mnist_shape, n_classes)

# Build the generator and restore its trained weights.
gen = Generator(input_dim=generator_input_dim).to(device)
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
gen.load_state_dict(torch.load("gen.pt", map_location=torch.device(device)))
# Switch BatchNorm layers to inference mode. Without this, each single-sample
# request is normalized with its own batch statistics (and updates the running
# statistics) instead of using the statistics learned during training.
gen.eval()
def generate(label):
    """Generate one image of the requested digit with the conditional generator.

    Args:
        label: Digit class to generate. The UI slider supplies a number in
            0-9; it is converted to an integer class index.

    Returns:
        A PIL image built from the first (only) sample of the generator output.
    """
    # Build the class index directly as an integer tensor instead of going
    # through a float tensor and casting.
    labels = torch.tensor([int(label)], dtype=torch.long, device=device)
    one_hot_labels = get_one_hot_labels(labels, n_classes)
    fake_noise = get_noise(1, z_dim, device=device)
    noise_and_labels = combine_vectors(fake_noise, one_hot_labels)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        fake = gen(noise_and_labels)

    # The generator ends in Tanh, so pixel values lie in [-1, 1]. to_pil_image
    # scales float tensors by 255 and casts to uint8, which corrupts negative
    # values; map to [0, 1] first.
    image = ((fake[0] + 1) / 2).clamp(0, 1).cpu()
    return func.to_pil_image(image)
# --- Gradio user interface -----------------------------------------------
# NOTE(review): the original indentation of this section was lost; the
# Interface is assumed to sit inside the Column — confirm the intended nesting.
with gr.Blocks() as demo:
    # Page heading and short usage instructions.
    gr.Markdown("# DCGANs Model")
    gr.Markdown("## This model generates handwritten digits 0-9. Select the digit you want to generate below.")
    with gr.Column():
        # Slider restricted to the integer digits 0 through 9.
        label = gr.Slider(
            minimum=0,
            maximum=9,
            step=1,
            label="Select Digit",
        )
        # Wire the slider to generate() and show the result as an image.
        interface = gr.Interface(
            fn=generate,
            inputs=label,
            outputs=gr.Image(height=300, width=300),
        )

# Start the web server for the demo.
demo.launch()