| import torch | |
| import gradio as gr | |
| from torchvision.utils import save_image | |
| from torchvision.transforms import ToPILImage | |
| from model import Generator | |
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Length of the latent noise vector fed to the generator.
# NOTE(review): must match the input size Generator (model.py) expects — confirm there.
NOISE_DIM = 256

G = Generator().to(DEVICE)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code when loaded; a state_dict needs
# nothing more. (Requires PyTorch >= 1.13; it is the default from 2.6 onward.)
G.load_state_dict(
    torch.load("generator.pth", map_location=DEVICE, weights_only=True)
)
# Inference only: switch the module to evaluation mode (affects layers such as
# dropout / batch-norm, if Generator contains any).
G.eval()

# Converts a single (C, H, W) tensor into a PIL image for Gradio to display.
to_pil = ToPILImage()
def generate_image():
    """Sample one image from the generator and return it as a PIL image.

    Draws a fresh standard-normal noise batch of shape (1, NOISE_DIM) on
    DEVICE and runs it through G without tracking gradients. It then rescales
    the output from the assumed [-1, 1] range to [0, 1] before conversion.

    Returns:
        PIL.Image.Image: the generated sample.
    """
    # Allocate the noise directly on the target device instead of creating it
    # on the CPU and copying it over.
    noise = torch.randn(1, NOISE_DIM, device=DEVICE)
    with torch.no_grad():
        image = G(noise)
    # NOTE(review): assumes G ends in tanh (outputs in [-1, 1]) — confirm in model.py.
    # Clamp because ToPILImage scales float tensors by 255 and casts to uint8
    # without saturating, so any value outside [0, 1] could wrap around.
    image = ((image + 1) / 2).clamp(0.0, 1.0)
    # Drop the batch dimension: ToPILImage expects a single (C, H, W) image.
    return to_pil(image.squeeze(0))
# Expose generate_image as a Gradio app with no inputs and a single image
# output. Each submit calls generate_image again, which samples new noise.
image_output = gr.Image()
demo = gr.Interface(
    generate_image,
    inputs=None,
    outputs=image_output,
    title="GAN Image Generator API",
)
# NOTE(review): launch() runs at import time; wrap it in
# `if __name__ == "__main__":` if this module is ever imported elsewhere.
demo.launch()