"""
author: Min Seok Lee and Wooseok Shin
"""
import os
import cv2
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import transforms
from tqdm import tqdm
from dataloader import get_test_augmentation, get_loader
from model.TRACER import TRACER
from util.utils import load_pretrained
class Inference:
    """Run TRACER salient-object detection on a test set and save results.

    Loads a pre-trained TRACER model, iterates over the test loader, and —
    when ``args.save_map`` is set — writes a per-image saliency mask to
    ``mask/<dataset>/`` and an alpha-matted object cut-out to
    ``object/<dataset>/``.
    """

    def __init__(self, args, save_path):
        # Prefer CUDA when available; everything below follows this device.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.test_transform = get_test_augmentation(img_size=args.img_size)
        self.args = args
        self.save_path = save_path

        # Network
        self.model = TRACER(args).to(self.device)
        if args.multi_gpu:
            self.model = nn.DataParallel(self.model).to(self.device)

        # Restore pre-trained weights for the selected backbone (TE-<arch>).
        path = load_pretrained(f'TE-{args.arch}')
        self.model.load_state_dict(path)
        print('###### pre-trained Model restored #####')

        te_img_folder = os.path.join(args.data_path, args.dataset)
        te_gt_folder = None  # no ground truth at inference time
        self.test_loader = get_loader(te_img_folder, te_gt_folder, edge_folder=None, phase='test',
                                      batch_size=args.batch_size, shuffle=False,
                                      num_workers=args.num_workers, transform=self.test_transform)

        if args.save_map is not None:
            os.makedirs(os.path.join('mask', self.args.dataset), exist_ok=True)
            os.makedirs(os.path.join('object', self.args.dataset), exist_ok=True)

    def test(self):
        """Run inference over the whole test loader, saving maps if enabled."""
        self.model.eval()
        t = time.time()
        with torch.no_grad():
            for images, original_size, image_name in tqdm(self.test_loader):
                # FIX: the loader already yields a tensor; torch.tensor() on a
                # tensor makes a redundant copy and emits a UserWarning.
                # Use .to() to move/cast in one step instead.
                images = images.to(self.device, dtype=torch.float32)
                outputs, edge_mask, ds_map = self.model(images)
                H, W = original_size

                # FIX: the inner loop previously reused `i`, shadowing the
                # outer batch index; use a distinct name per sample.
                for idx in range(images.size(0)):
                    # Resize each prediction back to its original resolution.
                    h, w = H[idx].item(), W[idx].item()
                    output = F.interpolate(outputs[idx].unsqueeze(0), size=(h, w), mode='bilinear')

                    # Save prediction map
                    if self.args.save_map is not None:
                        output = (output.squeeze().detach().cpu().numpy() * 255.0).astype(np.uint8)
                        salient_object = self.post_processing(images[idx], output, h, w)
                        cv2.imwrite(os.path.join('mask', self.args.dataset, image_name[idx] + '.png'), output)
                        cv2.imwrite(os.path.join('object', self.args.dataset, image_name[idx] + '.png'), salient_object)

        print(f'time: {time.time() - t:.3f}s')

    def post_processing(self, original_image, output_image, height, width, threshold=200):
        """Cut the salient object out of the original image via alpha matting.

        Args:
            original_image: normalized image tensor (C, H, W) as fed to the net.
            output_image: uint8 saliency mask (height, width), 0-255.
            height, width: original image resolution.
            threshold: alpha cut-off; pixels whose mask value is <= threshold
                become fully transparent.

        Returns:
            BGRA uint8 array where non-salient pixels have alpha 0.
        """
        # Invert the ImageNet normalization (mean/std) applied by the loader.
        invTrans = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.],
                                                            std=[1 / 0.229, 1 / 0.224, 1 / 0.225]),
                                       transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                                                            std=[1., 1., 1.]),
                                       ])
        original_image = invTrans(original_image)
        original_image = F.interpolate(original_image.unsqueeze(0), size=(height, width), mode='bilinear')
        original_image = (original_image.squeeze().permute(1, 2, 0).detach().cpu().numpy() * 255.0).astype(np.uint8)

        rgba_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2BGRA)
        output_rbga_image = cv2.cvtColor(output_image, cv2.COLOR_BGR2BGRA)
        output_rbga_image[:, :, 3] = output_image  # mask drives the alpha channel

        # Coordinates where the mask falls at/below the threshold -> transparent.
        edge_y, edge_x, _ = np.where(output_rbga_image <= threshold)
        rgba_image[edge_y, edge_x, 3] = 0

        return cv2.cvtColor(rgba_image, cv2.COLOR_RGBA2BGRA)