from typing import Tuple, Union

import cv2
import numpy as np
import torch
from torch import nn
from torchvision import transforms as T
|
|
|
|
|
class SRCNN(nn.Module):
    """SRCNN-style three-layer super-resolution network with tiled inference.

    The network is three unpadded convolutions (9x9 -> 1x1 -> 5x5), so a square
    patch of side ``input_size`` produces a prediction that is smaller by
    ``(9 - 1) + (5 - 1) = 12`` pixels (33 -> 21 with the defaults).
    ``input_size - label_size`` is therefore expected to be 12; other values make
    the predicted patch size disagree with ``label_size``.

    Args:
        input_channels: Channels of the image fed to the network.
        output_channels: Channels produced by the network.
        input_size: Side length of the patches fed to the network.
        label_size: Side length of the patches the network predicts.
        scale: Bicubic pre-upscaling factor applied before the network refines
            the image.
        device: Device used for inference. Defaults to CUDA when available,
            otherwise CPU.
    """

    def __init__(
        self,
        input_channels=3,
        output_channels=3,
        input_size=33,
        label_size=21,
        scale=2,
        device=None,
    ):
        super().__init__()
        self.input_size = input_size
        self.label_size = label_size
        # Offset between a patch's top-left corner and the top-left corner of
        # the region the network predicts for it (6 with the defaults).
        self.pad = (self.input_size - self.label_size) // 2
        self.scale = scale
        # Kept so enhance() can size its output buffer by the network's output
        # channels rather than by the input's channels.
        self.output_channels = output_channels

        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 64, 9),
            nn.ReLU(),
            nn.Conv2d(64, 32, 1),
            nn.ReLU(),
            nn.Conv2d(32, output_channels, 5),
            nn.ReLU(),
        )
        self.transform = T.Compose([T.ToTensor()])

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        # enhance() sends patches to self.device, so the weights must live there
        # too; otherwise a freshly built model fails with a device mismatch.
        self.to(self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the three convolution layers on a batched tensor ``x``."""
        return self.model(x)

    @torch.no_grad()
    def pre_process(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Convert an image to a float tensor for the network.

        Tensors are divided by 255. Arrays go through ``self.transform``
        (torchvision ``ToTensor``).
        """
        if torch.is_tensor(x):
            return x / 255.0
        return self.transform(x)

    @torch.no_grad()
    def post_process(self, x: torch.Tensor) -> torch.Tensor:
        """Clip network output to [0, 1] and rescale it to [0, 255]."""
        return x.clip(0, 1) * 255.0

    @torch.no_grad()
    def enhance(
        self, image: np.ndarray, outscale: float = 2
    ) -> Tuple[np.ndarray, None]:
        """Upscale ``image`` with bicubic interpolation refined by the network.

        The image is resized by ``self.scale`` (bicubic) and edge-padded so that
        every pixel gets a full receptive field. It is then processed in
        ``input_size`` patches, whose ``label_size`` predictions are stitched
        back together. Before the optional ``outscale`` resize, the result is
        exactly ``self.scale`` times the input resolution.

        Args:
            image: ``(H, W)`` or ``(H, W, C)`` array accepted by ``cv2.resize``.
                Presumably uint8 (``ToTensor`` only normalizes uint8 input to
                [0, 1]) -- confirm with callers.
            outscale: Final scale relative to the input size. When it differs
                from ``self.scale``, the network output is resized to
                ``outscale`` times the input size.

        Returns:
            ``(output, None)``, where ``output`` is a uint8 array. The second
            element is always ``None`` and is kept for callers that unpack two
            values.
        """
        h, w = image.shape[:2]
        up_w = int(round(w * self.scale))
        up_h = int(round(h * self.scale))
        # Uniform resize: both axes are scaled by the same factor, so the
        # aspect ratio is preserved.
        upscaled = cv2.resize(image, (up_w, up_h), interpolation=cv2.INTER_CUBIC)

        label = self.label_size
        # Ceil-divide so the stitched predictions cover every upscaled pixel.
        tiles_y = -(-up_h // label)
        tiles_x = -(-up_w // label)
        covered_h = tiles_y * label
        covered_w = tiles_x * label

        # Context the network consumes past the predicted region on the
        # bottom/right (equals self.pad when input_size - label_size is even).
        tail = self.input_size - label - self.pad
        padded = cv2.copyMakeBorder(
            upscaled,
            self.pad,
            covered_h - up_h + tail,
            self.pad,
            covered_w - up_w + tail,
            cv2.BORDER_REPLICATE,
        )

        in_tensor = self.pre_process(padded)
        out_tensor = torch.zeros(
            (self.output_channels, covered_h, covered_w), dtype=in_tensor.dtype
        )

        for y in range(0, covered_h, label):
            for x in range(0, covered_w, label):
                # Because of the top/left pad, padded rows [y, y + input_size)
                # predict upscaled rows [y, y + label_size); same for columns.
                crop = in_tensor[:, y : y + self.input_size, x : x + self.input_size]
                pred = self.forward(crop.unsqueeze(0).to(self.device))
                # Drop only the batch axis; squeeze() would also drop the
                # channel axis of single-channel outputs.
                out_tensor[:, y : y + label, x : x + label] = pred[0].cpu()

        # Discard the padding-only rows/columns beyond the upscaled image.
        out_tensor = self.post_process(out_tensor[:, :up_h, :up_w])
        output = out_tensor.permute(1, 2, 0).numpy()
        # Round rather than truncate so quantization is unbiased.
        output = np.clip(np.rint(output), 0, 255).astype("uint8")

        if outscale != self.scale:
            interpolation = (
                cv2.INTER_AREA if outscale < self.scale else cv2.INTER_LANCZOS4
            )
            output = cv2.resize(
                output,
                (max(1, int(w * outscale)), max(1, int(h * outscale))),
                interpolation=interpolation,
            )

        return output, None
|
|
|