Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torchvision.utils as vutils | |
| import gradio as gr | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
# Define Generator architecture - must match what you used during training
class Generator(nn.Module):
    """DCGAN-style generator that upsamples a latent tensor into an ``nc``-channel image.

    The order of modules inside ``self.main`` must not change: saved checkpoints
    address parameters by their positional index (``main.0.weight``,
    ``main.1.running_mean``, ...).

    Args:
        ngpu: Stored on the instance as ``self.ngpu``; ``forward`` does not use it.
        nz: Number of channels of the latent input.
        ngf: Base feature-map width; the hidden stages use 8x, 4x, 2x and 1x this.
        nc: Number of channels in the generated image.
    """

    def __init__(self, ngpu=1, nz=100, ngf=64, nc=3):
        super().__init__()
        self.ngpu = ngpu

        # Stage 0: (nz) x 1 x 1 -> (ngf*8) x 4 x 4 (kernel 4, stride 1, no padding),
        # assuming a 1x1 spatial latent input.
        stack = [
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
        ]

        # Stride-2 stages double H and W while narrowing channels:
        # (ngf*8) 4x4 -> (ngf*4) 8x8 -> (ngf*2) 16x16 -> (ngf) 32x32.
        for width_in, width_out in ((8, 4), (4, 2), (2, 1)):
            stack.extend(
                [
                    nn.ConvTranspose2d(ngf * width_in, ngf * width_out, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf * width_out),
                    nn.ReLU(True),
                ]
            )

        # Output stage: (ngf) 32x32 -> (nc) 64x64; Tanh bounds pixels to [-1, 1].
        stack.extend([nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh()])

        self.main = nn.Sequential(*stack)

    def forward(self, input):
        """Map a batch of latent tensors through the generator stack."""
        return self.main(input)
# Load the generator
def load_model(model_path="model/netG_best.pth"):
    """Build a Generator and load trained weights from ``model_path``.

    Args:
        model_path: Path to a saved ``state_dict`` for :class:`Generator`
            (nz=100, ngf=64, nc=3).

    Returns:
        ``(netG, device)`` with ``netG`` on ``device`` in eval mode on success, or
        ``(None, device)`` if loading failed for any reason. The error is printed,
        not raised.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    netG = Generator(ngpu=1, nz=100, ngf=64, nc=3).to(device)
    try:
        # weights_only=True limits unpickling to tensors and plain containers, so a
        # malicious checkpoint file cannot run arbitrary code during the load.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        # Weights saved from an nn.DataParallel wrapper carry a "module." prefix on
        # every key; strip it so they match this unwrapped Generator.
        prefix = "module."
        if state_dict and all(key.startswith(prefix) for key in state_dict):
            state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
        netG.load_state_dict(state_dict)
        netG.eval()  # Inference mode: BatchNorm uses its stored running statistics.
        print(f"Model loaded successfully from {model_path}")
        return netG, device
    except Exception as e:
        # Deliberate best-effort: callers check for a None model instead of catching.
        print(f"Error loading model: {e}")
        return None, device
# Generate images using the model
def generate_images(num_images=16, seed=None, randomize=True):
    """Sample images from the generator and tile them into a single grid image.

    Args:
        num_images: Number of latent vectors to sample. It is coerced to ``int``,
            because UI widgets such as sliders may deliver a float, and clamped to
            at least 1.
        seed: Seed for reproducible output. It is used only when ``randomize`` is false.
        randomize: When true, ``seed`` is ignored and fresh noise is drawn.

    Returns:
        An ``(H, W, 3)`` float numpy array with values in ``[0, 1]``. If the model
        could not be loaded, a black ``(299, 299, 3)`` array is returned instead.
    """
    global model, device
    # Lazily load on first use. The module pre-defines ``model = None``, so the
    # check must test the value; merely testing presence in globals() is always
    # true and would skip loading forever.
    if globals().get("model") is None:
        model, device = load_model()
    if model is None:
        return np.zeros((299, 299, 3))

    num_images = max(1, int(num_images))
    nz = 100  # Size of the latent vector; must match the generator's ``nz``.

    # Seeded requests use a private RNG so reseeding never disturbs the global
    # torch RNG that later randomized requests draw from.
    rng = None
    if seed is not None and not randomize:
        rng = torch.Generator(device=device)
        rng.manual_seed(int(seed))
    noise = torch.randn(num_images, nz, 1, 1, device=device, generator=rng)

    with torch.no_grad():
        fake_images = model(noise).detach().cpu()

    # normalize=True rescales the batch's min..max range into [0, 1] for display.
    grid = vutils.make_grid(fake_images, padding=2, normalize=True, nrow=int(np.sqrt(num_images)))
    # CHW tensor -> HWC numpy array, the layout Gradio expects.
    grid_np = grid.numpy().transpose((1, 2, 0))
    # Guard against tiny floating-point excursions outside [0, 1].
    return np.clip(grid_np, 0, 1)
# Create Gradio interface
def create_gradio_app():
    """Assemble the Gradio Blocks UI that samples computer mice from the GAN.

    The left column holds the sampling controls, and the right column shows the
    generated grid. Clicking the button calls ``generate_images(num_images,
    seed, randomize)``.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    about_text = """This model was trained using a PyTorch DCGAN implementation on a dataset of computer mouse images.
The training process used data augmentation to expand a small dataset of 300+ original images into 2,500+ training samples through techniques like flipping, rotation, and brightness/contrast adjustments.
The generator creates brand new, never-before-seen computer mice from random noise!"""

    with gr.Blocks(title="Computer Mouse Generator") as app:
        gr.Markdown("# Computer Mouse GAN Generator")
        gr.Markdown("Generate computer mice using a Deep Convolutional GAN trained on ~2,500 augmented images")

        with gr.Row():
            # Controls: these map positionally onto generate_images' parameters.
            with gr.Column():
                count_slider = gr.Slider(minimum=1, maximum=64, value=16, step=1, label="Number of Images")
                seed_box = gr.Number(label="Random Seed", value=42, precision=0)
                randomize_box = gr.Checkbox(label="Use Random Seeds (ignore seed value)", value=True)
                generate_btn = gr.Button("Generate Mice")
            # Output: the tiled grid of generated images.
            with gr.Column():
                result_image = gr.Image(label="Generated Computer Mice")

        generate_btn.click(
            fn=generate_images,
            inputs=[count_slider, seed_box, randomize_box],
            outputs=result_image,
        )

        gr.Markdown("## About")
        gr.Markdown(about_text)

    return app
# Module-level slots for the generator and its device; generate_images() reads
# and assigns these through its ``global`` declaration.
model, device = None, None

# Build and serve the UI only when executed as a script, not on import.
if __name__ == "__main__":
    app = create_gradio_app()
    app.launch()