"""Gradio demo that serves random samples from a pretrained GAN generator.

Loads weights from ``generator.pth`` into ``model.Generator`` and exposes a
no-input Gradio interface that returns one freshly sampled image per request.
"""

import torch
import gradio as gr
from torchvision.utils import save_image
from torchvision.transforms import ToPILImage
from model import Generator

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Length of the latent vector fed to the generator. Presumably this must match
# the latent size Generator was built/trained with — not checked here.
NOISE_DIM = 256

G = Generator().to(DEVICE)
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs, so a tampered checkpoint cannot execute
# arbitrary pickled code on load.
G.load_state_dict(
    torch.load("generator.pth", map_location=DEVICE, weights_only=True)
)
G.eval()

to_pil = ToPILImage()


def generate_image():
    """Sample one image from the generator and return it as a PIL image.

    A fresh standard-normal latent of shape ``(1, NOISE_DIM)`` is drawn on
    every call, so each request yields a different image.

    Returns:
        PIL.Image.Image: the generated sample.
    """
    # Create the noise directly on the target device rather than allocating
    # on the CPU and copying it over.
    noise = torch.randn(1, NOISE_DIM, device=DEVICE)
    with torch.no_grad():
        image = G(noise)
    # Rescale from [-1, 1] to [0, 1] (assumes a tanh-style output range —
    # TODO confirm against model.Generator). Clamp because ToPILImage scales
    # float tensors by 255 and casts to uint8, so out-of-range values would
    # wrap around instead of saturating.
    image = ((image + 1) / 2).clamp(0, 1)
    # Drop the batch axis; presumably leaves a (C, H, W) tensor.
    return to_pil(image.squeeze(0))


demo = gr.Interface(
    fn=generate_image,
    inputs=None,
    outputs=gr.Image(),
    title="GAN Image Generator API",
)

# Only start the server when run as a script, so importing this module
# (e.g. for tests or `gradio` reload mode, which picks up `demo`) does not
# block on launch.
if __name__ == "__main__":
    demo.launch()