File size: 4,186 Bytes
9279e0d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import torch
import numpy as np
import cv2
import tempfile
import matplotlib.pyplot as plt
from cog import BasePredictor, Path, Input, BaseModel
from basicsr.models import create_model
from basicsr.utils import img2tensor as _img2tensor, tensor2img, imwrite
from basicsr.utils.options import parse
class Predictor(BasePredictor):
    """Cog predictor that serves three NAFNet image-restoration tasks.

    ``setup`` builds one model per task and stores them in ``self.models``,
    keyed by the same task names that ``predict`` accepts as ``task_type``.
    """

    def setup(self):
        """Parse the option file of every task and build its model once."""
        option_files = {
            "Image Denoising": "options/test/SIDD/NAFNet-width64.yml",
            "Image Debluring": "options/test/GoPro/NAFNet-width64.yml",
            "Stereo Image Super-Resolution": "options/test/NAFSSR/NAFSSR-L_4x.yml",
        }

        # Parse all option files first and build the models afterwards, so the
        # sequence of parse/create_model calls stays: parse x3, then create x3.
        options = {}
        for task, opt_path in option_files.items():
            opt = parse(opt_path, is_train=False)
            opt["dist"] = False  # the "dist" option is forced off for every task
            options[task] = opt

        self.models = {task: create_model(opt) for task, opt in options.items()}

    def predict(
        self,
        task_type: str = Input(
            choices=[
                "Image Denoising",
                "Image Debluring",
                "Stereo Image Super-Resolution",
            ],
            default="Image Debluring",
            description="Choose task type.",
        ),
        image: Path = Input(
            description="Input image. Stereo Image Super-Resolution, upload the left image here.",
        ),
        image_r: Path = Input(
            default=None,
            description="Right Input image for Stereo Image Super-Resolution. Optional, only valid for Stereo"
            " Image Super-Resolution task.",
        ),
    ) -> Path:
        """Run the selected task on the input image(s) and return the result file.

        Args:
            task_type: Key selecting which model from ``self.models`` to run.
            image: Input image; the left view for the stereo task.
            image_r: Right view, read only for the stereo task.

        Returns:
            Path to ``output.png`` inside a freshly created temporary directory.

        Raises:
            ValueError: If the stereo task is selected without ``image_r``.
        """
        is_stereo = task_type == "Stereo Image Super-Resolution"

        # Validate with an explicit exception: an ``assert`` would be stripped
        # when Python runs with -O, letting a missing right image through.
        if is_stereo and image_r is None:
            raise ValueError(
                "Please provide both left and right input image for "
                "Stereo Image Super-Resolution task."
            )

        model = self.models[task_type]

        # Create the output directory only once, and only after validation and
        # model lookup succeeded, so no temporary directory is left orphaned.
        out_path = Path(tempfile.mkdtemp()) / "output.png"

        if is_stereo:
            inp_l = img2tensor(imread(str(image)))
            inp_r = img2tensor(imread(str(image_r)))
            stereo_image_inference(model, inp_l, inp_r, str(out_path))
        else:
            inp = img2tensor(imread(str(image)))
            single_image_inference(model, inp, str(out_path))

        return out_path
def imread(img_path):
    """Read an image file with OpenCV and return it in RGB channel order.

    Args:
        img_path: Path of the image file, passed straight to ``cv2.imread``.

    Returns:
        The decoded image after a BGR-to-RGB conversion.

    Raises:
        ValueError: If ``cv2.imread`` could not load the file.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        raise ValueError(f"Could not read image file: {img_path!r}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def img2tensor(img, bgr2rgb=False, float32=True):
    """Scale an image by 1/255 and hand it to basicsr's ``img2tensor``.

    Args:
        img: Image array; it is cast to float32 before scaling.
            NOTE(review): presumably 8-bit data in [0, 255] - confirm with callers.
        bgr2rgb: Forwarded unchanged to the basicsr helper.
        float32: Forwarded unchanged to the basicsr helper.

    Returns:
        Whatever the basicsr ``img2tensor`` helper returns for the scaled image.
    """
    as_float = img.astype(np.float32)
    scaled = np.true_divide(as_float, 255.0)
    return _img2tensor(scaled, bgr2rgb=bgr2rgb, float32=float32)
def single_image_inference(model, img, save_path):
    """Run ``model`` on one image tensor and write the result to ``save_path``.

    Args:
        model: Model object exposing ``feed_data``, ``test``, ``grids``,
            ``grids_inverse``, ``get_current_visuals`` and an ``opt`` mapping.
        img: Input tensor; a leading dimension is added before it is fed in.
        save_path: Destination file path for the output image.
    """
    batch = img.unsqueeze(dim=0)
    model.feed_data(data={"lq": batch})

    # The "grids" validation option switches the grids()/grids_inverse()
    # calls around the forward pass on or off.
    use_grids = model.opt["val"].get("grids", False)

    if use_grids:
        model.grids()
    model.test()
    if use_grids:
        model.grids_inverse()

    result = model.get_current_visuals()["result"]
    restored = tensor2img([result])
    imwrite(restored, save_path)
def stereo_image_inference(model, img_l, img_r, out_path):
    """Run the stereo model on a left/right pair and save both outputs as one figure.

    The two inputs are concatenated along their first dimension, fed to the
    model as a single batch entry, and the result is split back into two
    images that are drawn one above the other in a matplotlib figure.

    Args:
        model: Model object exposing ``feed_data``, ``test``, ``grids``,
            ``grids_inverse``, ``get_current_visuals`` and an ``opt`` mapping.
        img_l: Left-view input tensor.
            NOTE(review): assumes (3, H, W) so the [:, :3] split below
            recovers the left view - confirm with callers.
        img_r: Right-view input tensor, concatenated after ``img_l``.
        out_path: Destination file path for the rendered figure.
    """
    pair = torch.cat([img_l, img_r], dim=0)
    model.feed_data(data={"lq": pair.unsqueeze(dim=0)})

    if model.opt["val"].get("grids", False):
        model.grids()
    model.test()
    if model.opt["val"].get("grids", False):
        model.grids_inverse()

    result = model.get_current_visuals()["result"]
    # First three channels are taken as the left output, the rest as the right.
    img_L, img_R = tensor2img([result[:, :3], result[:, 3:]], rgb2bgr=False)

    h, w = img_L.shape[:2]
    # Integer division gives 0 for images smaller than 40 px on a side, which
    # would request a zero-sized figure; clamp each dimension to at least 1.
    fig = plt.figure(figsize=(max(1, w // 40), max(1, h // 40)))
    try:
        panels = (
            ("NAFSSR output (Left)", img_L),
            ("NAFSSR output (Right)", img_R),
        )
        for row, (title, view) in enumerate(panels, start=1):
            ax = fig.add_subplot(2, 1, row)
            ax.set_title(title, fontsize=14)
            ax.axis("off")
            ax.imshow(view)
        fig.subplots_adjust(hspace=0.08)
        fig.savefig(str(out_path), bbox_inches="tight", dpi=600)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # release it here to avoid accumulating figures across calls.
        plt.close(fig)
|