# Author: ma4389 — commit 2c7c815 "Update app.py" (verified)
import torch
import torch.nn as nn
import gradio as gr
from torchvision.utils import make_grid
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
# Select the compute device: CUDA when PyTorch reports it is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Fixed Generator definition (with bias=True)
class Generator(nn.Module):
    """DCGAN-style generator built from transposed convolutions.

    For a latent input of shape (N, z_dim, 1, 1) the network upsamples
    1 -> 4 -> 8 -> 16 -> 32 -> 64 and returns (N, channels_img, 64, 64)
    values squashed into [-1, 1] by a final Tanh.

    The layers live in ``self.net`` (an ``nn.Sequential``) at indices 0-13.
    Checkpoint keys such as ``net.0.weight`` depend on that exact layout.

    Args:
        z_dim: Number of channels of the latent input.
        channels_img: Number of channels of the produced image.
        features_g: Base width; hidden stages use 8x, 4x, 2x and 1x this.
    """

    def __init__(self, z_dim=100, channels_img=3, features_g=64):
        super().__init__()
        layers = []
        in_ch = z_dim
        for stage, mult in enumerate((8, 4, 2, 1)):
            out_ch = features_g * mult
            # The first stage turns the 1x1 latent into a 4x4 map (stride 1,
            # no padding). Every later stage doubles H and W (stride 2, pad 1).
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            layers.extend(
                [
                    nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=True),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(True),
                ]
            )
            in_ch = out_ch
        # Output stage: project to image channels and squash into [-1, 1].
        layers.extend(
            [
                nn.ConvTranspose2d(in_ch, channels_img, 4, 2, 1, bias=True),
                nn.Tanh(),
            ]
        )
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the latent batch ``x`` through the generator network."""
        return self.net(x)
# Load generator: build the network and restore the trained weights.
# z_dim must match the latent size the checkpoint was trained with; otherwise
# load_state_dict fails on the shape of net.0.weight.
z_dim = 100
generator = Generator(z_dim=z_dim).to(device)
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs. It avoids running arbitrary pickled code and silences
# torch's FutureWarning about the old unsafe default.
generator.load_state_dict(
    torch.load("generator_best.pth", map_location=device, weights_only=True)
)
# eval() makes BatchNorm use its running statistics at inference time.
generator.eval()
# Image generation function
def generate_image(seed: int = 42):
    """Generate one image from the loaded DCGAN generator for ``seed``.

    Args:
        seed: Seed for the latent noise vector. A Gradio ``gr.Number`` input
            may deliver a float (e.g. ``42.0``), which is truncated with
            ``int()``. It may also deliver ``None`` when the box is cleared;
            that falls back to the default seed 42.

    Returns:
        The generated image as a PIL image (via ``transforms.ToPILImage``).
    """
    seed = 42 if seed is None else int(seed)
    # Seed a dedicated CPU RNG instead of calling torch.manual_seed. This
    # leaves the process-wide RNG state untouched, so concurrent requests
    # cannot reseed each other. It also makes a seed map to the same noise on
    # CPU and GPU hosts.
    rng = torch.Generator().manual_seed(seed)
    noise = torch.randn(1, z_dim, 1, 1, generator=rng).to(device)
    with torch.no_grad():
        fake_image = generator(noise).cpu()
    # The final Tanh outputs [-1, 1]; ToPILImage expects float tensors in [0, 1].
    img_tensor = ((fake_image + 1) / 2).squeeze(0)  # drop the batch dimension
    return transforms.ToPILImage()(img_tensor)
# Gradio UI
# precision=0 makes gr.Number pass an int (not a float) to generate_image.
# Binding the interface to `demo` follows Gradio's convention, which its
# reload tooling looks for.
demo = gr.Interface(
    fn=generate_image,
    inputs=gr.Number(value=42, label="Random Seed", precision=0),
    outputs=gr.Image(type="pil"),
    title="DCGAN Image Generator",
    description="Generate fake images using your trained DCGAN Generator",
)
demo.launch()