# Author: ju4nppp
# Last commit: "Update app.py" (e27d7b8, verified)
# File size: 4.23 kB
import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from PIL import Image
import os
import math
# Generator architecture - the default ngf=128 matches the training parameters.
class Generator(nn.Module):
    """DCGAN generator that upsamples a latent tensor into an ``nc``-channel image.

    The network is four (ConvTranspose2d -> BatchNorm2d -> ReLU) stages followed
    by a final ConvTranspose2d -> Tanh. For a 1x1 latent input the spatial size
    goes 1 -> 4 -> 8 -> 16 -> 32 -> 64.

    Layers are registered in ``self.main`` (an ``nn.Sequential``) in a fixed
    order, so ``state_dict`` keys (``main.0.weight``, ``main.1.running_mean``,
    ..., ``main.12.weight``) line up with checkpoints saved from this layout.

    Args:
        ngpu: Stored on the instance; ``forward`` does not use it.
        nz: Number of channels of the latent input.
        ngf: Base feature-map width of the generator.
        nc: Number of output image channels.
    """

    def __init__(self, ngpu=1, nz=100, ngf=128, nc=3):
        super().__init__()
        self.ngpu = ngpu
        # Channel widths along the upsampling path: nz -> 8ngf -> 4ngf -> 2ngf -> ngf.
        widths = [nz, ngf * 8, ngf * 4, ngf * 2, ngf]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # Stage 0 projects 1x1 -> 4x4 (stride 1, no padding); every later
            # stage doubles the spatial size (stride 2, padding 1).
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Output stage: ngf -> nc channels, 32x32 -> 64x64, squashed into [-1, 1].
        layers += [nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh()]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run a latent batch through the generator stack."""
        return self.main(input)
# Load the model - Update path to point to the models folder
device = torch.device("cpu")
model_path = "models/netG_epoch_246.pth"
# Print file existence for debugging
print(f"Checking if model file exists: {os.path.exists(model_path)}")
print(f"Listing contents of models directory: {os.listdir('models') if os.path.exists('models') else 'models directory not found'}")
# Initialize the model with ngf=128 to match your training parameters
model = Generator(ngf=128).to(device)


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``"module."`` removed from keys.

    Checkpoints saved from an ``nn.DataParallel``-wrapped network carry that
    prefix (e.g. ``module.main.0.weight``), which the bare ``Generator`` does
    not expect. Keys without the prefix pass through unchanged.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Weight loading is best-effort so the UI still starts, but we track whether
# the trained weights actually made it in: otherwise the generator keeps (some
# of) its random initialisation and its output is meaningless.
model_loaded = False
try:
    # NOTE(review): torch.load unpickles the file, so only point this at a
    # trusted checkpoint. On PyTorch >= 1.13, weights_only=True restricts
    # loading to tensors and plain containers.
    checkpoint = torch.load(model_path, map_location=device)
except Exception as e:
    print(f"Error loading model: {e}")
else:
    try:
        model.load_state_dict(checkpoint)
        model_loaded = True
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        # Fallback: undo a possible DataParallel prefix and load leniently,
        # reusing the checkpoint already in memory instead of re-reading it.
        # Size mismatches still raise here even with strict=False.
        try:
            result = model.load_state_dict(_strip_module_prefix(checkpoint), strict=False)
            print("Model loaded with strict=False")
            if result.missing_keys:
                print(f"Missing keys (left randomly initialised): {result.missing_keys}")
            if result.unexpected_keys:
                print(f"Unexpected keys (ignored): {result.unexpected_keys}")
            model_loaded = not result.missing_keys
        except Exception as e2:
            print(f"Error with alternative loading: {e2}")
if not model_loaded:
    print("WARNING: trained weights were not fully loaded; some or all generator weights are random.")
# Set model to evaluation mode (BatchNorm then uses its running statistics).
model.eval()
print(f"Model initialized: {model is not None}, trained weights loaded: {model_loaded}")
def create_image_grid(images, rows, cols):
    """Tile PIL images into a single RGB canvas in row-major order.

    Each grid cell has the size of ``images[0]``; the canvas is ``cols`` cells
    wide and ``rows`` cells tall. Image ``k`` is pasted at column ``k % cols``,
    row ``k // cols``. An empty list raises ``IndexError``.
    """
    tile_w, tile_h = images[0].size
    canvas = Image.new('RGB', size=(tile_w * cols, tile_h * rows))
    for idx, tile in enumerate(images):
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas
def generate_multiple_images(random_seed=42, num_images=4):
    """Sample images from the module-level generator and tile them into a grid.

    Args:
        random_seed: Seed for ``torch.manual_seed``; the same seed reproduces the
            same grid. Coerced to ``int`` because UI sliders may deliver floats.
        num_images: Number of images to sample (must be >= 1). Coerced to
            ``int`` for the same reason (``range`` rejects floats).

    Returns:
        A PIL RGB image with ``int(sqrt(num_images))`` rows and as many columns
        as needed to hold every sample.

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    random_seed = int(random_seed)
    num_images = int(num_images)
    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}")

    # Seeds torch's global RNG; every noise draw below continues from it, so
    # each image gets different noise while the whole grid stays reproducible.
    torch.manual_seed(random_seed)

    images = []
    with torch.no_grad():
        for _ in range(num_images):
            # 100 latent channels: must equal the Generator's nz (default 100).
            noise = torch.randn(1, 100, 1, 1, device=device)
            fake_image = model(noise).cpu()
            # Generator ends in Tanh ([-1, 1]); map to [0, 1], then CHW -> HWC uint8.
            fake_img = (fake_image * 0.5 + 0.5).squeeze(0).permute(1, 2, 0).numpy()
            fake_img = np.clip(fake_img * 255, 0, 255).astype(np.uint8)
            images.append(Image.fromarray(fake_img))

    rows = int(math.sqrt(num_images))
    cols = int(math.ceil(num_images / rows))
    return create_image_grid(images, rows, cols)
# Create Gradio interface
demo = gr.Interface(
    fn=generate_multiple_images,
    inputs=[
        # gr.Slider takes its initial position via `value=`; `default=` is the
        # legacy gr.inputs.Slider spelling and is ignored or rejected by
        # current Gradio releases.
        gr.Slider(minimum=1, maximum=100, step=1, value=42, label="Random Seed"),
        gr.Slider(minimum=1, maximum=16, step=1, value=4, label="Number of Images"),
    ],
    outputs=gr.Image(type="pil", label="Generated Computer Mice"),
    title="DCGAN Computer Mouse Generator",
    description="Generate multiple unique computer mouse designs using a DCGAN model.",
    # Each example row is [random_seed, num_images].
    examples=[[42, 4], [23, 9], [7, 16]],
)

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()