Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| import matplotlib.pyplot as plt | |
| from torchvision.utils import make_grid | |
| import torchvision.transforms.functional as func | |
def get_one_hot_labels(labels, n_classes):
    """Return one-hot encodings of ``labels`` with ``n_classes`` columns.

    Args:
        labels: integer (long) tensor of class indices. Values outside
            ``[0, n_classes)`` are wrapped into range with ``% n_classes``.
        n_classes: number of classes, i.e. the width of the one-hot dimension.

    Returns:
        Long tensor of shape ``labels.shape + (n_classes,)``.
    """
    # Wrap by the actual class count rather than a hard-coded 10, so the
    # indices always fit the requested one-hot width.
    return F.one_hot(labels % n_classes, num_classes=n_classes)
def get_input_dimensions(z_dim, mnist_shape, n_classes):
    """Compute the input sizes for the conditional generator and discriminator.

    Args:
        z_dim: length of the noise vector.
        mnist_shape: image shape whose first entry is the channel count.
        n_classes: number of label classes.

    Returns:
        ``(generator_input_dim, discriminator_im_chan)`` where the generator
        input is noise plus a one-hot label (``z_dim + n_classes``) and the
        discriminator channel count is the image channels plus ``n_classes``
        (presumably one label channel per class; the discriminator itself is
        not defined in this file).
    """
    image_channels = mnist_shape[0]
    return z_dim + n_classes, image_channels + n_classes
def combine_vectors(x, y):
    """Concatenate ``x`` and ``y`` along dimension 1 and cast to float32."""
    return torch.cat([x, y], dim=1).to(torch.float)
def get_noise(n_samples, input_dim, device='cpu'):
    """Sample an ``(n_samples, input_dim)`` tensor of standard-normal noise on ``device``."""
    shape = (n_samples, input_dim)
    return torch.randn(shape, device=device)
class Generator(nn.Module):
    """DCGAN-style generator mapping a flat input vector to an image.

    Args:
        input_dim: number of values per input sample (in this app: noise
            concatenated with a one-hot label).
        im_chan: number of channels in the generated image.
        hidden_dim: base channel width of the transposed-convolution stack.

    ``forward`` reshapes each sample to a 1x1 feature map; with the layer
    settings below the spatial size then grows 1 -> 3 -> 6 -> 13 -> 28.
    """

    def __init__(self, input_dim=10, im_chan=1, hidden_dim=64):
        super().__init__()
        self.input_dim = input_dim
        # Module order/nesting determines state_dict keys (gen.<i>.<j>.*), which
        # must stay fixed so saved weights keep loading.
        self.gen = nn.Sequential(
            # 1x1 -> 3x3
            self.make_gen_block(input_dim, 4 * hidden_dim),
            # 3x3 -> 6x6
            self.make_gen_block(4 * hidden_dim, 2 * hidden_dim, kernel_size=4, stride=1),
            # 6x6 -> 13x13
            self.make_gen_block(2 * hidden_dim, hidden_dim),
            # 13x13 -> 28x28, output squashed to [-1, 1] by Tanh
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        """Build one upsampling stage.

        Hidden stages are ConvTranspose2d -> BatchNorm2d -> ReLU; the final
        stage is ConvTranspose2d -> Tanh.
        """
        upsample = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(
            upsample,
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        """Run the generator on ``noise`` (``input_dim`` values per sample)."""
        as_feature_map = noise.view(len(noise), self.input_dim, 1, 1)
        return self.gen(as_feature_map)
# Model configuration; must match the settings the checkpoint was trained with.
z_dim = 64
n_classes = 10
device = "cpu"
mnist_shape = (1, 28, 28)
generator_input_dim, discriminator_im_chan = get_input_dimensions(z_dim, mnist_shape, n_classes)

gen = Generator(input_dim=generator_input_dim).to(device)
# map_location lets weights saved on another device load onto `device`.
gen.load_state_dict(torch.load("gen.pt", map_location=device))
# Inference only: make BatchNorm use its stored running statistics. In the
# default training mode each single-sample request would be normalised with its
# own batch statistics and would also overwrite the running averages.
gen.eval()
def generate(label):
    """Generate a single image of the requested digit.

    Args:
        label: digit class to condition on (possibly a float, e.g. from a UI
            slider); it is truncated to an int.

    Returns:
        A PIL image produced by the module-level generator ``gen``.
    """
    labels = torch.tensor([int(label)], dtype=torch.long, device=device)
    one_hot_labels = get_one_hot_labels(labels, n_classes)
    fake_noise = get_noise(1, z_dim, device=device)
    noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
    # Pure inference: no need to track gradients.
    with torch.no_grad():
        fake = gen(noise_and_labels)
    # The generator ends in Tanh, so pixels are in [-1, 1]. to_pil_image
    # treats float tensors as [0, 1] (it multiplies by 255 and casts to uint8),
    # so rescale first; otherwise negative pixels do not map to valid uint8
    # values.
    fake = ((fake + 1) / 2).clamp(0, 1)
    return func.to_pil_image(fake[0])
with gr.Blocks() as demo:
    gr.Markdown("# DCGANs Model")
    gr.Markdown("## This model generates handwritten digits 0-9. Select the digit you want to generate below.")
    # gr.Interface lays out its own input and output components, so the slider
    # is passed to it directly rather than placed in a separate gr.Column.
    gr.Interface(
        fn=generate,
        inputs=gr.Slider(minimum=0, maximum=9, step=1, label="Select Digit"),
        outputs=gr.Image(height=300, width=300),
    )

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()