# backup / scripts /maskscratches /scratches_detector.py
# killbill007's picture
# Upload 754 files
# 93871a1 verified
#############################################################################
#
# Source from:
# https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life
# Forked from:
#
# Reimplemented by: Leonel Hernández
#
##############################################################################
import logging
import os
import numpy as np
import torch
import torch.nn.functional as F
import torchvision as tv
from PIL import ImageFile, Image
from scripts.maskscratches.detection_models import networks
from scripts.util import data_transforms, scale_tensor, tensor_to_ndarray
ImageFile.LOAD_TRUNCATED_IMAGES = True
class ScratchesDetector:
    """Detect scratches in a photo with a pretrained UNet mask model.

    The UNet is loaded once from a checkpoint in ``__init__`` and is always
    run on the CPU (both the model and its inputs are moved there).
    """

    def __init__(self, snapshot_folder):
        """Build the scratch-detection UNet and load its weights.

        Args:
            snapshot_folder: Directory that contains
                ``detection/FT_Epoch_latest.pt``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist
                (raised by ``torch.load``).
            KeyError: If the checkpoint has no ``"model_state"`` entry.
        """
        self.model_mask = networks.UNet(
            in_channels=1,
            out_channels=1,
            depth=4,
            conv_num=2,
            wf=6,
            padding=True,
            batch_norm=True,
            up_mode="upsample",
            with_tanh=False,
            sync_bn=True,
            antialiasing=True,
        )
        model_path = os.path.join(snapshot_folder, "detection", "FT_Epoch_latest.pt")
        # The model is only ever run on the CPU, so map the stored tensors
        # straight there instead of first staging them on a GPU.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(model_path, map_location="cpu")
        self.model_mask.load_state_dict(checkpoint["model_state"])
        self.model_mask.cpu()
        self.model_mask.eval()

    def process(
        self, image: Image.Image, *, threshold: float = 0.4
    ) -> "tuple[np.ndarray, np.ndarray]":
        """Predict a binary scratch mask for ``image``.

        Args:
            image: Input PIL image.
            threshold: Sigmoid probability at or above which a pixel is
                marked as a scratch. Defaults to 0.4.

        Returns:
            A ``(transformed_image, scratches_mask_image)`` pair:
            ``transformed_image`` is the output of
            ``data_transforms(image, size="full_size")`` converted with
            ``np.array``. ``scratches_mask_image`` is the thresholded
            prediction, resized to the transformed image's height and width
            and converted by ``tensor_to_ndarray``.
        """
        logging.info("Start detecting scratches")
        transformed_image = data_transforms(image, size="full_size")

        # The detector consumes a single grayscale channel normalized with
        # mean 0.5 / std 0.5.
        gray_image = transformed_image.convert("L")
        input_tensor = tv.transforms.ToTensor()(gray_image)
        input_tensor = tv.transforms.Normalize([0.5], [0.5])(input_tensor)
        input_tensor = torch.unsqueeze(input_tensor, 0)  # add batch dimension
        # Layout is (N, C, H, W); keep the spatial size so the prediction
        # can be resized back to it after the model runs.
        _, _, height, width = input_tensor.shape

        scaled_input = scale_tensor(input_tensor).cpu()
        with torch.no_grad():
            prediction = torch.sigmoid(self.model_mask(scaled_input))
        prediction = prediction.cpu()
        # Nearest-neighbour keeps the mask values unblended before thresholding.
        prediction = F.interpolate(prediction, [height, width], mode="nearest")
        tensor_mask = (prediction >= threshold).float()

        scratches_mask_image = tensor_to_ndarray(tensor_mask)
        return np.array(transformed_image), scratches_mask_image