Spaces:
Build error
Build error
| """ Import Libraries """ | |
| import torch | |
| import torch.nn as nn | |
| from model import Generator | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
| import numpy as np | |
| import gradio as gr | |
""" Loading Model """
# Rebuild the generator and load its trained weights. map_location forces every
# tensor in the checkpoint onto the CPU, so a checkpoint saved from a GPU run
# still loads on a CPU-only host.
state_dict_path = 'gen_monet_dict_1.pth'
model = Generator(3)  # presumably 3 = RGB input channels — confirm against model.py
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# On torch >= 1.13, passing weights_only=True would restrict loading to tensors.
model.load_state_dict(torch.load(state_dict_path, map_location=torch.device('cpu')))
# Put the network in inference mode. torch.inference_mode() (used in main) only
# disables autograd; it does NOT switch dropout / batch-norm layers (if Generator
# has any) from training to inference behaviour — eval() does.
model.eval()
""" Init Transform """
# Input preprocessing for the generator.
# With mean = std = 0.5 and max_pixel_value = 255, albumentations computes
# (x - 0.5 * 255) / (0.5 * 255), which maps uint8 pixels in [0, 255] onto
# [-1, 1]. ToTensorV2 then converts the HWC numpy array to a CHW torch tensor.
augment = A.Compose(
    [
        A.Normalize(
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]
)
def main(image):
    """Run the Monet generator on a single image.

    Args:
        image: the value Gradio passes from ``gr.Image()``. This is presumably
            an HxWx3 uint8 RGB numpy array (Gradio's default numpy / RGB mode);
            confirm if the input component's settings change. It is ``None``
            when the user submits without providing an image.

    Returns:
        A channels-last float32 numpy array with values clamped to [0, 1],
        or ``None`` when no image was supplied.
    """
    # Nothing to stylize; returning None leaves the output panel empty instead
    # of crashing inside the albumentations pipeline.
    if image is None:
        return None

    tensor_img = augment(image=image)['image']  # CHW tensor scaled to [-1, 1]

    with torch.inference_mode():
        pred = model(tensor_img.unsqueeze(0))  # add a batch dimension of 1
        # Undo the [-1, 1] input scaling (x * 0.5 + 0.5). Then clamp, because
        # the generator's output is not guaranteed to stay in range, and Gradio
        # converts float images to uint8 expecting values in [0, 1].
        pred = (pred.squeeze(0) * 0.5 + 0.5).clamp(0.0, 1.0)
        # CHW -> HWC for image display.
        pred = pred.permute(1, 2, 0)

    return np.array(pred.cpu())
""" Launch App """
# One image in, one stylized image out. The example file names are relative
# paths, so they are resolved against the process's working directory.
image_input = gr.Image()
image_output = gr.Image()
app = gr.Interface(
    fn=main,
    inputs=image_input,
    outputs=image_output,
    examples=['2.jpeg', '4.jpg', '5.jpeg', '6.jpg'],
)
# Start the web server. There is no __main__ guard, so importing this module
# also launches the app.
app.launch()