"""Gradio demo serving a conditional DCGAN that draws MNIST-style digits.

Generator weights are read from ``gen.pt``. The user picks a digit with a
slider and the app returns one freshly sampled 28x28 image of that digit.
"""
import gradio as gr
import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import torchvision.transforms.functional as func


def get_one_hot_labels(labels, n_classes):
    """Return one-hot encodings of ``labels`` with ``n_classes`` columns.

    Labels are wrapped modulo ``n_classes`` so every label indexes a valid
    class column (a fixed ``% 10`` would be wrong for other class counts).
    """
    return F.one_hot(labels % n_classes, num_classes=n_classes)


def get_input_dimensions(z_dim, mnist_shape, n_classes):
    """Return (generator input size, discriminator input channels).

    The generator consumes noise concatenated with a one-hot class vector.
    The discriminator consumes image channels stacked with one class channel
    per class.
    """
    generator_input_dim = z_dim + n_classes
    discriminator_im_chan = mnist_shape[0] + n_classes
    return generator_input_dim, discriminator_im_chan


def combine_vectors(x, y):
    """Concatenate ``x`` and ``y`` along dim 1 and cast the result to float."""
    return torch.cat((x, y), dim=1).float()


def get_noise(n_samples, input_dim, device='cpu'):
    """Return an ``(n_samples, input_dim)`` tensor of standard-normal noise."""
    return torch.randn(n_samples, input_dim, device=device)


class Generator(nn.Module):
    """DCGAN generator: maps a flat input vector to a 1x28x28 image in [-1, 1].

    The vector is reshaped to ``(N, input_dim, 1, 1)`` and upsampled by four
    transposed-conv blocks (spatial size 1 -> 3 -> 6 -> 13 -> 28). The last
    block ends in Tanh, which bounds every output pixel to [-1, 1].
    """

    def __init__(self, input_dim=10, im_chan=1, hidden_dim=64):
        super(Generator, self).__init__()
        self.input_dim = input_dim
        self.gen = nn.Sequential(
            self.make_gen_block(input_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        """Build one upsampling block: ConvT+BN+ReLU, or ConvT+Tanh if final."""
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        return nn.Sequential(
            nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
            nn.Tanh(),
        )

    def forward(self, noise):
        """Generate images from a batch of flat ``(N, input_dim)`` vectors."""
        x = noise.view(len(noise), self.input_dim, 1, 1)
        return self.gen(x)


z_dim = 64
n_classes = 10
device = "cpu"
mnist_shape = (1, 28, 28)

generator_input_dim, discriminator_im_chan = get_input_dimensions(z_dim, mnist_shape, n_classes)
gen = Generator(input_dim=generator_input_dim).to(device)
# NOTE: torch.load unpickles the file, so only load a trusted checkpoint.
gen.load_state_dict(torch.load("gen.pt", map_location=device))
# Inference only: make BatchNorm use its learned running statistics instead of
# per-request batch statistics (and stop updating them on every call).
gen.eval()


def generate(label):
    """Sample one image of digit ``label`` and return it as a PIL image."""
    # int() truncates like the former .long() cast (slider values may be floats).
    labels = torch.tensor([int(label)], dtype=torch.long, device=device)
    one_hot_labels = get_one_hot_labels(labels, n_classes)
    with torch.no_grad():
        fake_noise = get_noise(1, z_dim, device=device)
        noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
        fake = gen(noise_and_labels)
    # The generator ends in Tanh, so pixels lie in [-1, 1]. to_pil_image
    # multiplies float input by 255 and assumes [0, 1], so negative values
    # would wrap around when cast to bytes. Rescale to [0, 1] first.
    image = ((fake[0] + 1) / 2).clamp(0, 1)
    return func.to_pil_image(image)


# Created outside the Blocks context so that only the Interface renders it.
# A component already rendered in the Blocks would be rendered twice (or
# rejected as a duplicate block) when passed to gr.Interface.
label = gr.Slider(minimum=0, maximum=n_classes - 1, step=1, label="Select Digit")

with gr.Blocks() as demo:
    gr.Markdown("# DCGANs Model")
    gr.Markdown("## This model generates handwritten digits 0-9. Select the digit you want to generate below.")
    with gr.Column():
        interface = gr.Interface(fn=generate, inputs=label, outputs=gr.Image(height=300, width=300))

if __name__ == "__main__":
    demo.launch()