MonetGAN / app.py
vishnuraggav's picture
model last
d646b93
raw
history blame contribute delete
928 Bytes
""" Import Libraries """
import torch
import torch.nn as nn
from model import Generator
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import gradio as gr
""" Loading Model """
# Generator weights bundled with the app.
state_dict_path = 'gen_monet_dict_1.pth'
# NOTE(review): the argument 3 is presumably the number of image channels (RGB) — confirm against model.Generator.
model = Generator(3)
# map_location forces all tensors onto the CPU, so a checkpoint saved on a GPU
# still loads on a CPU-only host.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True if the installed torch supports it).
model.load_state_dict(torch.load(state_dict_path, map_location=torch.device('cpu')))
# torch.inference_mode() in main() only disables autograd; it does not change
# layer behaviour. Switch to evaluation mode so any BatchNorm/Dropout layers
# inside Generator behave deterministically at inference time.
model.eval()
# Preprocessing pipeline applied to every input image:
#   1. Normalize: (pixel - 0.5 * 255) / (0.5 * 255), mapping [0, 255] to [-1, 1]
#      per channel (the inverse, * 0.5 + 0.5, is applied in main()).
#   2. ToTensorV2: convert the HWC NumPy array into a CHW torch tensor.
augment = A.Compose(
    transforms=[
        A.Normalize(
            mean=[0.5] * 3,
            std=[0.5] * 3,
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]
)
def main(image):
    """Stylise a single image with the loaded Monet generator.

    Args:
        image: Array produced by the ``gr.Image()`` input component, or
            ``None`` when the user submits without providing an image.
            Presumably an ``(H, W, 3)`` uint8 RGB array — TODO confirm
            against the installed Gradio version.

    Returns:
        A channels-last float NumPy array with values clamped to ``[0, 1]``,
        or ``None`` when no image was supplied.
    """
    if image is None:
        # Nothing to stylise; return None so the output component stays empty
        # instead of crashing inside the albumentations pipeline.
        return None

    # Normalise pixels to [-1, 1] and convert HWC array -> CHW tensor.
    tensor_img = augment(image=image)['image']

    with torch.inference_mode():
        # The generator expects a batch, so add a leading batch axis of size 1.
        pred = model(tensor_img.unsqueeze(0))
        # Drop the batch axis, move channels last for display, and undo the
        # [-1, 1] normalisation. Clamp in case the generator's output is not
        # already bounded: Gradio's float->uint8 image conversion may reject
        # values outside the valid range.
        pred = (pred.squeeze(0).permute(1, 2, 0) * 0.5 + 0.5).clamp(0.0, 1.0)
        return pred.cpu().numpy()
# Single-image-in / single-image-out Gradio UI around main(). The example
# entries are bare filenames, so they are presumably expected to sit in the
# directory the app is launched from.
app = gr.Interface(
    main,
    gr.Image(),
    gr.Image(),
    examples=["2.jpeg", "4.jpg", "5.jpeg", "6.jpg"],
)
# Start the web server (runs as soon as this module executes).
app.launch()