import functools

import gradio as gr
import numpy as np
import torch
from PIL import Image

from models import DehazeFormer

# Path of the checkpoint loaded by the demo, and the device inference runs on.
WEIGHTS_PATH = './best_t_model_weights.pth'
DEVICE = torch.device('cpu')


@functools.lru_cache(maxsize=1)
def _load_network(weights_path=WEIGHTS_PATH):
    """Build the dehazing network from *weights_path* once and cache it.

    Three checkpoint layouts are accepted:
      * a whole pickled ``nn.Module`` (used as-is),
      * ``{'state_dict': ...}`` (the layout written by the old
        ``torch.save({'state_dict': network.state_dict()}, ...)`` helper),
      * a bare state dict (loaded into a fresh ``DehazeFormer()``).

    ``map_location`` keeps GPU-saved weights loadable on a CPU-only host.
    NOTE(review): on torch>=2.6 ``torch.load`` defaults to
    ``weights_only=True``, which rejects whole pickled modules — confirm the
    checkpoint is a state dict if running on a recent torch.
    """
    checkpoint = torch.load(weights_path, map_location=DEVICE)
    if isinstance(checkpoint, torch.nn.Module):
        network = checkpoint
    else:
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']
        network = DehazeFormer()
        network.load_state_dict(checkpoint)
    network.to(DEVICE)
    network.eval()
    return network


def infer(raw_image):
    """Dehaze one image and return the result as a ``PIL.Image``.

    Args:
        raw_image: image as delivered by Gradio's ``"image"`` input — an
            array-like of shape (H, W, C) or (H, W) with 0-255 values.

    Returns:
        A uint8 RGB ``PIL.Image`` of the network output.
    """
    network = _load_network()

    image = np.asarray(raw_image, dtype=np.float32)
    if image.ndim == 2:
        # Grayscale -> replicate to 3 channels; the model is assumed to take
        # 3-channel input (TODO confirm against DehazeFormer's in_chans).
        image = np.repeat(image[..., None], 3, axis=-1)
    image = image[..., :3]  # drop an alpha channel if present

    # Scale [0, 255] -> [-1, 1], then HWC -> 1xCxHxW.
    tensor = torch.from_numpy(image / 255.0 * 2 - 1).permute(2, 0, 1).unsqueeze(0)

    with torch.no_grad():
        # Map the network's [-1, 1] output back to [0, 1].
        output = network(tensor.to(DEVICE)).clamp_(-1, 1)[0] * 0.5 + 0.5

    output = output.permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(np.round(output * 255.0).astype(np.uint8))


title = "noisy_model"
#description = f"We use a mixed dataset to train the model, allowing the trained model to work better on real hazy images. To allow the model to process high-resolution images more efficiently and effectively, we extend it to the [MCT](https://github.com/IDKiro/MCT) variant."

iface = gr.Interface(
    infer,
    inputs="image",
    outputs="image",
    title=title,
    # description=description,
    allow_flagging='never',
)
iface.launch()