"""Gradio demo that runs a pre-trained image-to-image ``Generator`` on CPU."""

import albumentations as A
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from PIL import Image

from model import Generator

# Normalisation constants shared by the input transform and the output
# de-normalisation in ``main`` so the two cannot drift apart.
_NORM_MEAN = 0.5
_NORM_STD = 0.5

model = Generator(3)
model_path = 'state_dict.pth'
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider ``weights_only=True`` where the installed torch
# version supports it).
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
# Switch dropout / batch-norm style layers to inference behaviour; without
# this the output may be non-deterministic and running statistics could be
# updated on every request.
model.eval()

# Resize to 256x256, map pixel values from [0, 255] to [-1, 1], and convert
# the HWC numpy array into a CHW tensor.
transform = A.Compose([
    A.Resize(width=256, height=256),
    A.Normalize(
        mean=[_NORM_MEAN] * 3,
        std=[_NORM_STD] * 3,
        max_pixel_value=255.0,
    ),
    ToTensorV2(),
])


def main(image):
    """Run the generator on one image and return a displayable result.

    Args:
        image: The array supplied by the ``gr.Image()`` input (presumably an
            H x W x 3 uint8 RGB array with Gradio's default settings), or
            ``None`` when the user submits without providing an image.

    Returns:
        An H x W x C ``uint8`` numpy array with values in [0, 255], or
        ``None`` when no input image was given.
    """
    if image is None:
        # Nothing to process; returning None leaves the output empty instead
        # of failing inside the transform.
        return None

    augmented = transform(image=image)
    tensor_img = augmented['image']

    with torch.inference_mode():
        pred = model(tensor_img.unsqueeze(0))
        # Drop the batch dimension and go from CHW to HWC for display.
        pred = pred.squeeze(0).permute(1, 2, 0)
        # Undo the input normalisation.
        # NOTE(review): assumes the generator emits values in the same
        # [-1, 1] range as its normalised input (e.g. a tanh head) - confirm
        # against model.py.
        pred = (pred * _NORM_STD + _NORM_MEAN).clamp(0.0, 1.0)
        # Hand Gradio an unambiguous uint8 image rather than a float array
        # whose range it has to guess or reject.
        return (pred * 255.0).round().to(torch.uint8).numpy()


app = gr.Interface(
    fn=main,
    inputs=gr.Image(),
    outputs=gr.Image(),
    examples=['1.jpg', '2.jpg', '3.jpg', '4.jpg'],
)

if __name__ == "__main__":
    app.launch()