# Hugging Face Space status at capture time: Sleeping
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
| import os | |
| import math | |
# Generator architecture (DCGAN-style). The default ngf=128 is meant to match
# the width used when the checkpoint loaded below was trained.
class Generator(nn.Module):
    """DCGAN generator mapping a latent tensor to an ``nc``-channel image.

    Args:
        ngpu: Stored as ``self.ngpu``; not used by ``forward``.
        nz: Number of input (latent) channels.
        ngf: Base feature width; the hidden stages use 8x, 4x, 2x and 1x this.
        nc: Number of output image channels.

    The layers live in ``self.main`` (an ``nn.Sequential``) in the order
    [ConvTranspose2d, BatchNorm2d, ReLU] x 4 followed by
    [ConvTranspose2d, Tanh], i.e. indices 0..13. Saved checkpoints key their
    tensors by those indices, so the order must not change. For a
    (N, nz, 1, 1) input the spatial size grows 1 -> 4 -> 8 -> 16 -> 32 -> 64,
    and the final Tanh bounds outputs to [-1, 1].
    """

    def __init__(self, ngpu=1, nz=100, ngf=128, nc=3):
        super().__init__()
        self.ngpu = ngpu

        widths = [ngf * 8, ngf * 4, ngf * 2, ngf]

        # Project the 1x1 latent input up to a 4x4 feature map.
        layers = [
            nn.ConvTranspose2d(nz, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # Each stage doubles the spatial size and halves the channel count.
        for c_in, c_out in zip(widths, widths[1:]):
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        # Last upsample to nc channels; Tanh squashes values into [-1, 1].
        layers.extend([
            nn.ConvTranspose2d(widths[-1], nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)
# Load the model - Update path to point to the models folder
device = torch.device("cpu")
model_path = "models/netG_epoch_246.pth"

# Print file existence for debugging
print(f"Checking if model file exists: {os.path.exists(model_path)}")
print(f"Listing contents of models directory: {os.listdir('models') if os.path.exists('models') else 'models directory not found'}")


def _strip_data_parallel_prefix(state_dict):
    """Drop a leading ``module.`` from every key when *all* keys carry it.

    A state_dict saved from a model wrapped in ``nn.DataParallel`` stores its
    keys as ``module.<name>``, while the bare Generator expects ``<name>``.
    Without this, the strict load fails and a ``strict=False`` load matches
    nothing. Non-dict or empty inputs are returned unchanged so that
    ``load_state_dict`` can report its own error for them.
    """
    prefix = "module."
    if not isinstance(state_dict, dict) or not state_dict:
        return state_dict
    if not all(key.startswith(prefix) for key in state_dict):
        return state_dict
    return {key[len(prefix):]: value for key, value in state_dict.items()}


# Initialize the model with ngf=128 to match your training parameters
model = Generator(ngf=128).to(device)

# True only once trained weights have actually been copied into `model`
# (a `model is not None` check says nothing about that).
model_loaded = False

# Read the checkpoint once; both load attempts below reuse it.
# NOTE(review): torch.load unpickles the file - only point it at trusted
# checkpoints (recent torch versions also accept weights_only=True).
state_dict = None
try:
    state_dict = _strip_data_parallel_prefix(torch.load(model_path, map_location=device))
except Exception as e:
    print(f"Error loading model: {e}")

if state_dict is not None:
    try:
        model.load_state_dict(state_dict)
    except Exception as e:
        print(f"Error loading model: {e}")
        # Fall back to a partial load, but say exactly what did not match:
        # with strict=False any missing parameter keeps its random init.
        try:
            result = model.load_state_dict(state_dict, strict=False)
        except Exception as e2:
            print(f"Error with alternative loading: {e2}")
        else:
            model_loaded = not result.missing_keys
            print("Model loaded with strict=False")
            print(f"  missing keys (left randomly initialized): {result.missing_keys}")
            print(f"  unexpected keys (ignored): {result.unexpected_keys}")
    else:
        model_loaded = True
        print("Model loaded successfully!")

# Set model to evaluation mode
model.eval()
print(f"Model weights loaded: {model_loaded}")
def create_image_grid(images, rows, cols):
    """Paste equally sized PIL images into a ``rows`` x ``cols`` RGB grid.

    Images are placed left-to-right, top-to-bottom. Every cell takes the size
    of ``images[0]``, so the canvas is ``(cols * w, rows * h)`` pixels. The
    caller is expected to choose ``rows * cols >= len(images)``.

    Args:
        images: Non-empty sequence of PIL images, all the same size.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB ``PIL.Image.Image`` containing the grid.

    Raises:
        ValueError: If ``images`` is empty (there is no cell size to use).
    """
    if not images:
        raise ValueError("create_image_grid needs at least one image")
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        # Row-major placement: divmod yields (row, column) for index i.
        row, col = divmod(i, cols)
        grid.paste(image, box=(col * w, row * h))
    return grid
def generate_multiple_images(random_seed=42, num_images=4):
    """Generate ``num_images`` DCGAN samples and return them as one grid image.

    Args:
        random_seed: Seed for the latent noise; the same seed and count give
            the same grid. Coerced with ``int()`` because UI sliders may pass
            numbers as floats.
        num_images: How many samples to draw (must be >= 1); also coerced
            with ``int()``.

    Returns:
        A PIL image with the samples in a near-square grid
        (``int(sqrt(num_images))`` rows).

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    num_images = int(num_images)
    if num_images < 1:
        # Without this guard rows == 0 and the column count divides by zero.
        raise ValueError(f"num_images must be at least 1, got {num_images}")

    # A private, per-call generator keeps each request reproducible even when
    # the UI serves requests concurrently, and leaves torch's global RNG
    # untouched. Intended to produce the same noise stream that
    # torch.manual_seed(seed) did before.
    rng = torch.Generator(device=device)
    rng.manual_seed(int(random_seed))

    images = []
    with torch.no_grad():
        for _ in range(num_images):
            # Successive draws from the seeded stream give distinct samples.
            noise = torch.randn(1, 100, 1, 1, device=device, generator=rng)
            fake_image = model(noise).cpu()
            # Map Tanh output [-1, 1] -> [0, 1], then CHW float -> HWC uint8.
            fake_img = fake_image * 0.5 + 0.5
            fake_img = fake_img.squeeze(0).permute(1, 2, 0).numpy()
            fake_img = np.clip(fake_img * 255, 0, 255).astype(np.uint8)
            images.append(Image.fromarray(fake_img))

    rows = int(math.sqrt(num_images))
    cols = int(math.ceil(num_images / rows))
    return create_image_grid(images, rows, cols)
# Create Gradio interface
demo = gr.Interface(
    fn=generate_multiple_images,
    inputs=[
        # gr.Slider takes its initial position from `value`; the legacy
        # `default` keyword is not part of the current Slider signature, so
        # with it the intended defaults (42 and 4) are not applied.
        gr.Slider(minimum=1, maximum=100, step=1, value=42, label="Random Seed"),
        gr.Slider(minimum=1, maximum=16, step=1, value=4, label="Number of Images"),
    ],
    outputs=gr.Image(type="pil", label="Generated Computer Mice"),
    title="DCGAN Computer Mouse Generator",
    description="Generate multiple unique computer mouse designs using a DCGAN model.",
    examples=[[42, 4], [23, 9], [7, 16]],
)

# Launch the app
if __name__ == "__main__":
    demo.launch()