Spaces:
Runtime error
Runtime error
import os
# Install mmcv with the `mim` CLI at import time, deliberately before
# `import models` below. NOTE(review): presumably `models` needs mmcv — confirm.
# The exit status of os.system is ignored, so a failed install may only
# surface later as an import error.
os.system('mim install mmcv')
import numpy as np
import models
import gradio as gr
import torch
from torchvision import transforms
from torchvision.transforms import InterpolationMode
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
def construct_sample(img, mean=0.5, std=0.5, size=(48, 48)):
    """Convert an input image into the normalised tensor fed to the model.

    Args:
        img: Image accepted by ``transforms.ToTensor`` (the UI passes a PIL
            image; an HWC numpy array also works). PIL images are converted
            to RGB first, so grayscale / RGBA / palette uploads still yield
            the three channels the rest of the demo assumes.
        mean: Mean passed to ``transforms.Normalize`` (scalar or per-channel).
        std: Standard deviation passed to ``transforms.Normalize``.
        size: Target spatial size for the bicubic resize; defaults to the
            previously hard-coded ``(48, 48)``.

    Returns:
        A ``C x H x W`` float tensor without a batch dimension.
    """
    if hasattr(img, 'convert'):
        # PIL image: force three channels; RGBA / L / P modes would otherwise
        # give 4- or 1-channel tensors.
        img = img.convert('RGB')
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(size, InterpolationMode.BICUBIC),
        transforms.Normalize(mean, std),
    ])
    return pipeline(img)
def build_model(cp):
    """Instantiate the network described by the checkpoint file *cp*.

    The checkpoint is read onto the CPU, the ``'args'`` of its ``'model'``
    entry are printed, and that entry is handed to ``models.make`` with
    ``load_sd=True`` (presumably restoring the stored weights). The built
    module is moved to the module-level ``device`` before being returned.
    """
    saved = torch.load(cp, map_location='cpu')
    spec = saved['model']
    print(spec['args'])
    net = models.make(spec, load_sd=True)
    return net.to(device)
# --- Super-resolution inference used by the Gradio interface ---------------

# Checkpoint choice offered in the UI -> pretrained weight file.
_CHECKPOINTS = {
    'UC': 'pretrain/UC_FunSR_RDN.pth',
    'AID': 'pretrain/AID_FunSR_RDN.pth',
}

# Models built so far, keyed by checkpoint path. Each weight file is then
# loaded once per process instead of once per request (and once per cached
# example at startup).
_MODEL_CACHE = {}


def _get_model(checkpoint):
    """Return the eval-mode model for *checkpoint*, building it on first use."""
    model = _MODEL_CACHE.get(checkpoint)
    if model is None:
        model = build_model(checkpoint)
        model.eval()
        _MODEL_CACHE[checkpoint] = model
    return model


def _to_uint8_hwc(tensor):
    """Undo the mean=0.5/std=0.5 normalisation and return an HWC uint8 array.

    *tensor* is a single ``C x H x W`` image; values are clamped to [0, 255]
    before the cast so out-of-range model outputs do not wrap around.
    """
    tensor = (tensor * 0.5 + 0.5) * 255
    tensor = torch.clamp(tensor.detach().cpu(), 0, 255)
    return tensor.permute(1, 2, 0).numpy().astype(np.uint8)


def sr_func(img, cp, scale):
    """Super-resolve *img* and show it beside a nearest-neighbour upscale.

    Args:
        img: Input image (the UI passes a PIL image); preprocessed with
            ``construct_sample``.
        cp: Checkpoint name, ``'UC'`` or ``'AID'``.
        scale: Upscaling factor. The output size is the preprocessed input
            size multiplied by *scale*, truncated to integers.

    Returns:
        An ``H x W x C`` uint8 array: the nearest-neighbour upscaled input, a
        5-pixel white separator, and the model prediction, side by side.

    Raises:
        NotImplementedError: if *cp* is not a known checkpoint name.
    """
    checkpoint = _CHECKPOINTS.get(cp) if isinstance(cp, str) else None
    if checkpoint is None:
        raise NotImplementedError(
            f'Unknown checkpoint {cp!r}; expected one of {sorted(_CHECKPOINTS)}')
    sample = construct_sample(img)
    print('Use: ', device)
    model = _get_model(checkpoint)
    sample = sample.to(device).unsqueeze(0)  # add batch dim -> BCHW
    ori_size = torch.tensor(sample.shape[2:])
    target_size = (ori_size * scale).long().tolist()
    # Size the preview from the same integer target size as the prediction.
    # With `scale_factor` its rounding could differ, and the side-by-side
    # concatenation below would then fail on mismatched heights.
    lr_target_size_img = torch.nn.functional.interpolate(
        sample, size=target_size, mode='nearest')
    with torch.no_grad():
        pred = model(sample, target_size)
    if isinstance(pred, list):
        # Some model variants return intermediate outputs; the last is final.
        pred = pred[-1]
    pred = _to_uint8_hwc(pred[0])
    lr_img = _to_uint8_hwc(lr_target_size_img[0])
    separator = np.full((pred.shape[0], 5, pred.shape[2]), 255, dtype=np.uint8)
    return np.concatenate((lr_img, separator, pred), axis=1)
# Text shown around the Gradio interface.
title = "FunSR"
description = (
    "Gradio demo for continuous remote sensing image super-resolution. Upload image from UCMerced or AID Dataset or click any one of the examples, "
    "Then change the upscaling magnification, and click \"Submit\" and wait for the super-resolved result. \n"
    "Paper: Continuous Remote Sensing Image Super-Resolution based on Context Interaction in Implicit Function Space"
)
article = (
    "<p style='text-align: center'><a href='https://kyanchen.github.io/FunSR/' target='_blank'>FunSR Project "
    "Page</a></p> "
)

# Every bundled example runs at the same magnification.
default_scale = 2.0

# (image path, checkpoint name) pairs for the example gallery.
_EXAMPLE_INPUTS = (
    ('examples/AID_school_161_LR.png', 'AID'),
    ('examples/AID_bridge_19_LR.png', 'AID'),
    ('examples/AID_parking_60_LR.png', 'AID'),
    ('examples/AID_commercial_32_LR.png', 'AID'),
    ('examples/UC_airplane95_LR.png', 'UC'),
    ('examples/UC_freeway35_LR.png', 'UC'),
    ('examples/UC_storagetanks54_LR.png', 'UC'),
    ('examples/UC_airplane00_LR.png', 'UC'),
)
examples = [[path, ckpt, default_scale] for path, ckpt in _EXAMPLE_INPUTS]
# Build the demo directly as a gr.Interface. The components are created
# outside any gr.Blocks context so that the Interface is the only thing that
# renders them. Creating them inside `with gr.Blocks()` and then passing them
# to an Interface nested in that Blocks renders each component twice, which
# Gradio's render check rejects as a duplicate block.
image_input = gr.Image(type='pil', label='Input Img')
image_output = gr.Image(label='SR Result', type='numpy')
checkpoint = gr.Radio(['UC', 'AID'], label='Checkpoint')
scale = gr.Slider(1, 10, value=4.0, step=0.1, label='scale')

demo = gr.Interface(
    fn=sr_func,
    inputs=[image_input, checkpoint, scale],
    outputs=[image_output],
    title=title,
    description=description,
    article=article,
    # NOTE(review): `allow_flagging` was renamed in later Gradio releases —
    # confirm against the Gradio version this Space pins.
    allow_flagging='auto',
    examples=examples,
    # Runs sr_func on every example at startup.
    cache_examples=True,
)
# The Interface was previously exposed as `io`; keep that name as an alias.
io = demo
io.launch()