# nosiy_model / app.py
# Tracy7777's picture
# Update app.py
# d218eb2 verified
import torch
import numpy as np
import gradio as gr
from PIL import Image
from models import DehazeFormer
def infer(raw_image):
network = DehazeFormer()
network = torch.load('./best_t_model_weights.pth')
network.to(torch.device('cpu'))
network.load_state_dict(torch.load('./best_t_model_weights.pth'))
#network.load_state_dict(torch.load('./best_t_model_weights.pth', map_location=torch.device('cpu'))['state_dict'])
# torch.save({'state_dict': network.state_dict()}, './saved_models/dehazeformer.pth')
network.eval()
image = np.array(raw_image, np.float32) / 255. * 2 - 1
image = torch.from_numpy(image)
image = image.permute((2, 0, 1)).unsqueeze(0)
with torch.no_grad():
output = network(image).clamp_(-1, 1)[0] * 0.5 + 0.5
output = output.permute((1, 2, 0))
output = np.array(output, np.float32)
output = np.round(output * 255.0)
output = Image.fromarray(output.astype(np.uint8))
return output
# Heading shown at the top of the demo page.
title = "noisy_model"
#description = f"We use a mixed dataset to train the model, allowing the trained model to work better on real hazy images. To allow the model to process high-resolution images more efficiently and effectively, we extend it to the [MCT](https://github.com/IDKiro/MCT) variant."

# One image in, one image out, processed by `infer`; flagging is disabled.
iface = gr.Interface(
    fn=infer,
    inputs="image",
    outputs="image",
    title=title,
    # description=description,
    allow_flagging="never",
)

# Start the Gradio server when the script runs.
iface.launch()