File size: 2,511 Bytes
93871a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#############################################################################
#
#   Original source:
#   https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life
#   Forked from:
#
#   Reimplemented by: Leonel Hernández
#
##############################################################################
import logging
import os.path

import PIL.Image
import cv2
import numpy as np
import torch
from torchvision.transforms import transforms

from scripts.erasescratches.models import Pix2PixHDModel_Mapping
from scripts.erasescratches.options import Options
from scripts.maskscratches import ScratchesDetector
from scripts.util import irregular_hole_synthesize, tensor_to_ndarray

# NOTE(review): REPO_ID is not referenced by the code in this file; it looks like
# the identifier of the upstream repository ("owner/name") the model snapshots
# come from — confirm against callers before relying on it.
REPO_ID = "leonelhs/zeroscratches"


class EraseScratches:
    """Remove scratches from photos.

    A ``ScratchesDetector`` produces a scratch mask for the input image, and a
    ``Pix2PixHDModel_Mapping`` network then restores the masked regions.
    """

    def __init__(self, snapshot_folder: str = './models/zeroscratches'):
        """Load the scratch detector and the restoration model.

        Args:
            snapshot_folder: Directory holding the model snapshots. The detector
                is built from this folder, and the restoration model options from
                its ``restoration`` sub-folder. The default is a relative path,
                so it is resolved against the current working directory.
        """
        model_path = os.path.join(snapshot_folder, "restoration")
        self.detector = ScratchesDetector(snapshot_folder)
        # Use every visible CUDA device; an empty list selects CPU execution.
        gpu_ids = list(range(torch.cuda.device_count())) if torch.cuda.is_available() else []
        self.options = Options(model_path, gpu_ids)
        self.model_scratches = Pix2PixHDModel_Mapping()
        self.model_scratches.initialize(self.options)
        self.model_scratches.eval()

    def erase(self, image) -> np.ndarray:
        """Detect scratches in ``image`` and return the restored picture.

        Args:
            image: Input photo, passed unchanged to ``ScratchesDetector.process``
                (presumably a PIL image — confirm against the detector).

        Returns:
            The model output rescaled with ``(x + 1) / 2`` and converted by
            ``tensor_to_ndarray``.

        Raises:
            TypeError: If the restoration model's inference raises; the original
                exception is chained as ``__cause__``. Errors raised before
                inference (detection, transforms) propagate unchanged.
        """
        transformed, mask = self.detector.process(image)
        logging.info("Start erasing scratches")

        # Normalize with mean/std 0.5 maps ToTensor's [0, 1] range to [-1, 1];
        # the inverse mapping is applied to the model output below.
        img_transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        )
        mask_transform = transforms.ToTensor()

        if self.options.mask_dilation != 0:
            # Widen the masked regions: 3x3 dilation, `mask_dilation` iterations.
            kernel = np.ones((3, 3), np.uint8)
            dilated = cv2.dilate(np.array(mask), kernel, iterations=self.options.mask_dilation)
            mask = PIL.Image.fromarray(dilated.astype('uint8'))

        transformed = irregular_hole_synthesize(transformed, mask)
        # Keep only the first channel and add a batch dimension of 1.
        mask_tensor = mask_transform(mask)[:1, :, :].unsqueeze(0)
        image_tensor = img_transform(transformed).unsqueeze(0)

        try:
            with torch.no_grad():
                generated = self.model_scratches.inference(image_tensor, mask_tensor)
        except Exception as ex:
            # TypeError is kept for existing callers; chaining preserves the
            # original traceback for debugging.
            raise TypeError("Skip photo due to an error:\n%s" % str(ex)) from ex

        tensor_restored = (generated.detach().cpu() + 1.0) / 2.0
        return tensor_to_ndarray(tensor_restored)