File size: 676 Bytes
33fa605 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | import torch
import gradio as gr
from torchvision.utils import save_image
from torchvision.transforms import ToPILImage
from model import Generator
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Size of the latent vector passed to the generator.
# NOTE(review): presumably must match the input width Generator expects in
# model.py — confirm.
NOISE_DIM = 256

# Instantiate the generator on DEVICE and restore its trained weights.
# map_location remaps stored tensor locations onto DEVICE, so a checkpoint
# saved on a GPU can still be loaded on a CPU-only machine.
G = Generator().to(DEVICE)
_checkpoint = torch.load("generator.pth", map_location=DEVICE)
# NOTE(review): torch.load unpickles the file. If generator.pth could ever come
# from an untrusted source, consider weights_only=True (check that the
# installed torch version supports it).
G.load_state_dict(_checkpoint)
del _checkpoint  # the values have been copied into G; drop the extra reference
# Put the model in eval mode (affects layers such as dropout / batch norm, if any).
G.eval()

# Shared tensor -> PIL image converter used by generate_image.
to_pil = ToPILImage()
def generate_image():
    """Sample one image from the generator and return it as a PIL image.

    Draws a single latent vector from a standard normal distribution, runs it
    through the module-level generator ``G`` with gradient tracking disabled,
    maps the output from ``[-1, 1]`` to ``[0, 1]`` and converts it with
    ``to_pil``. Takes no arguments because the Gradio interface that calls it
    declares no inputs.

    Returns:
        The generated image as a ``PIL.Image.Image``.
    """
    # Sample directly on the target device rather than allocating on the CPU
    # and copying the tensor over afterwards.
    noise = torch.randn(1, NOISE_DIM, device=DEVICE)
    with torch.no_grad():
        image = G(noise)
    # The (x + 1) / 2 rescale assumes the generator outputs values in [-1, 1]
    # (e.g. a final tanh) — NOTE(review): confirm against model.py. Clamp so
    # any out-of-range value saturates instead of wrapping around when
    # ToPILImage multiplies by 255 and casts to uint8.
    image = ((image + 1) / 2).clamp(0, 1)
    # Drop the batch dimension: ToPILImage takes a single C x H x W tensor.
    return to_pil(image.squeeze(0))
# Gradio app with no inputs: each run calls generate_image and displays the
# image it returns.
demo = gr.Interface(
    fn=generate_image,
    inputs=None,
    outputs=gr.Image(),
    title="GAN Image Generator API",
)

# Start the web server only when this file is executed as a script, so that
# importing the module (tests, tooling) does not launch it as a side effect.
# `demo` stays defined at module level either way.
if __name__ == "__main__":
    demo.launch()