Upload 7 files

- Inference.py +53 -0
- basics.py +74 -0
- data_loader_cache.py +385 -0
- hce_metric_main.py +188 -0
- pytorch18.yml +92 -0
- requirements.txt +87 -4
- train_valid_inference_main.py +731 -0
Inference.py
ADDED
@@ -0,0 +1,53 @@
import os
import time
import numpy as np
from skimage import io
import time
from glob import glob
from tqdm import tqdm

import torch, gc
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms.functional import normalize

from models import *


if __name__ == "__main__":
    dataset_path="../demo_datasets/your_dataset"  # Your dataset path
    model_path="../saved_models/IS-Net/isnet-general-use.pth"  # the model path
    result_path="../demo_datasets/your_dataset_result"  # The folder path that you want to save the results
    input_size=[1024,1024]
    net=ISNetDIS()

    if torch.cuda.is_available():
        net.load_state_dict(torch.load(model_path))
        net=net.cuda()
    else:
        net.load_state_dict(torch.load(model_path,map_location="cpu"))
    net.eval()
    im_list = glob(dataset_path+"/*.jpg")+glob(dataset_path+"/*.JPG")+glob(dataset_path+"/*.jpeg")+glob(dataset_path+"/*.JPEG")+glob(dataset_path+"/*.png")+glob(dataset_path+"/*.PNG")+glob(dataset_path+"/*.bmp")+glob(dataset_path+"/*.BMP")+glob(dataset_path+"/*.tiff")+glob(dataset_path+"/*.TIFF")
    with torch.no_grad():
        for i, im_path in tqdm(enumerate(im_list), total=len(im_list)):
            print("im_path: ", im_path)
            im = io.imread(im_path)
            if len(im.shape) < 3:
                im = im[:, :, np.newaxis]
            im_shp=im.shape[0:2]
            im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
            im_tensor = F.upsample(torch.unsqueeze(im_tensor,0), input_size, mode="bilinear").type(torch.uint8)
            image = torch.divide(im_tensor,255.0)
            image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])

            if torch.cuda.is_available():
                image=image.cuda()
            result=net(image)
            result=torch.squeeze(F.upsample(result[0][0],im_shp,mode='bilinear'),0)
            ma = torch.max(result)
            mi = torch.min(result)
            result = (result-mi)/(ma-mi)
            im_name=im_path.split('/')[-1].split('.')[0]
            io.imsave(os.path.join(result_path,im_name+".png"),(result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8))
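A minimal sketch (not part of the commit) of how the per-image steps in Inference.py could be wrapped into a single-image helper; the helper name, the grayscale-to-RGB repeat, and the checkpoint handling are assumptions, while ISNetDIS and the normalization values come from the file above.

# hypothetical helper distilled from the inference loop above
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from skimage import io
from models import ISNetDIS  # provided by the repo's models package

def predict_mask(net, im_path, input_size=(1024, 1024)):
    im = io.imread(im_path)
    if im.ndim < 3:
        im = im[:, :, np.newaxis]
    if im.shape[2] == 1:
        im = np.repeat(im, 3, axis=2)          # assumption: expand grayscale to 3 channels
    h, w = im.shape[0:2]
    x = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    x = normalize(x / 255.0, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])  # same normalization as Inference.py
    x = F.interpolate(x.unsqueeze(0), input_size, mode="bilinear")
    with torch.no_grad():
        pred = net(x)[0][0]                     # first side output, as result[0][0] above
    pred = F.interpolate(pred, (h, w), mode="bilinear")[0, 0]
    pred = (pred - pred.min()) / (pred.max() - pred.min())
    return (pred * 255).byte().cpu().numpy()    # uint8 mask at the original resolution

Usage would mirror the script: build ISNetDIS(), load a checkpoint with load_state_dict, call net.eval(), then save the returned array with io.imsave.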
basics.py
ADDED
@@ -0,0 +1,74 @@
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '2'
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import glob

def mae_torch(pred,gt):

    h,w = gt.shape[0:2]
    sumError = torch.sum(torch.absolute(torch.sub(pred.float(), gt.float())))
    maeError = torch.divide(sumError,float(h)*float(w)*255.0+1e-4)

    return maeError

def f1score_torch(pd,gt):

    # print(gt.shape)
    gtNum = torch.sum((gt>128).float()*1)  ## number of ground truth pixels

    pp = pd[gt>128]
    nn = pd[gt<=128]

    pp_hist = torch.histc(pp,bins=255,min=0,max=255)
    nn_hist = torch.histc(nn,bins=255,min=0,max=255)


    pp_hist_flip = torch.flipud(pp_hist)
    nn_hist_flip = torch.flipud(nn_hist)

    pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)
    nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)

    precision = (pp_hist_flip_cum)/(pp_hist_flip_cum + nn_hist_flip_cum + 1e-4) #torch.divide(pp_hist_flip_cum,torch.sum(torch.sum(pp_hist_flip_cum, nn_hist_flip_cum), 1e-4))
    recall = (pp_hist_flip_cum)/(gtNum + 1e-4)
    f1 = (1+0.3)*precision*recall/(0.3*precision+recall + 1e-4)

    return torch.reshape(precision,(1,precision.shape[0])),torch.reshape(recall,(1,recall.shape[0])),torch.reshape(f1,(1,f1.shape[0]))


def f1_mae_torch(pred, gt, valid_dataset, idx, mybins, hypar):

    import time
    tic = time.time()

    if(len(gt.shape)>2):
        gt = gt[:,:,0]

    pre, rec, f1 = f1score_torch(pred,gt)
    mae = mae_torch(pred,gt)


    # hypar["valid_out_dir"] = hypar["valid_out_dir"]+"-eval" ###
    if(hypar["valid_out_dir"]!=""):
        if(not os.path.exists(hypar["valid_out_dir"])):
            os.mkdir(hypar["valid_out_dir"])
        dataset_folder = os.path.join(hypar["valid_out_dir"],valid_dataset.dataset["data_name"][idx])
        if(not os.path.exists(dataset_folder)):
            os.mkdir(dataset_folder)
        io.imsave(os.path.join(dataset_folder,valid_dataset.dataset["im_name"][idx]+".png"),pred.cpu().data.numpy().astype(np.uint8))
    print(valid_dataset.dataset["im_name"][idx]+".png")
    print("time for evaluation : ", time.time()-tic)

    return pre.cpu().data.numpy(), rec.cpu().data.numpy(), f1.cpu().data.numpy(), mae.cpu().data.numpy()
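f1score_torch builds 255-bin histograms of the predicted values inside and outside the ground-truth mask and turns their flipped cumulative sums into precision/recall at every threshold in one pass, rather than binarizing 255 times. A minimal sketch (not from the commit) of calling these helpers on synthetic 0-255 tensors:

import torch
from basics import f1score_torch, mae_torch

# synthetic prediction and ground truth, both H x W in the 0-255 range (illustrative only)
pred = torch.randint(0, 256, (64, 64)).float()
gt = (torch.rand(64, 64) > 0.5).float() * 255

pre, rec, f1 = f1score_torch(pred, gt)   # each has shape (1, 255): one value per threshold
mae = mae_torch(pred, gt)                # mean absolute error, normalized by 255
print(float(f1.max()), float(mae))       # max F-beta (beta^2 = 0.3) over thresholds, and MAE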
data_loader_cache.py
ADDED
@@ -0,0 +1,385 @@
## data loader
## Acknowledgement:
## We would like to thank Dr. Ibrahim Almakky (https://scholar.google.co.uk/citations?user=T9MTcK0AAAAJ&hl=en)
## for his help in implementing the cache mechanism of our DIS dataloader.
from __future__ import print_function, division

import numpy as np
import random
from copy import deepcopy
import json
from tqdm import tqdm
from skimage import io
import os
from glob import glob

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision.transforms.functional import normalize
import torch.nn.functional as F

#### --------------------- DIS dataloader cache ---------------------####

def get_im_gt_name_dict(datasets, flag='valid'):
    print("------------------------------", flag, "--------------------------------")
    name_im_gt_list = []
    for i in range(len(datasets)):
        print("--->>>", flag, " dataset ",i,"/",len(datasets)," ",datasets[i]["name"],"<<<---")
        tmp_im_list, tmp_gt_list = [], []
        tmp_im_list = glob(datasets[i]["im_dir"]+os.sep+'*'+datasets[i]["im_ext"])

        # img_name_dict[im_dirs[i][0]] = tmp_im_list
        print('-im-',datasets[i]["name"],datasets[i]["im_dir"], ': ',len(tmp_im_list))

        if(datasets[i]["gt_dir"]==""):
            print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found')
            tmp_gt_list = []
        else:
            tmp_gt_list = [datasets[i]["gt_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["gt_ext"] for x in tmp_im_list]

            # lbl_name_dict[im_dirs[i][0]] = tmp_gt_list
            print('-gt-', datasets[i]["name"],datasets[i]["gt_dir"], ': ',len(tmp_gt_list))


        if flag=="train":  ## combine multiple training sets into one dataset
            if len(name_im_gt_list)==0:
                name_im_gt_list.append({"dataset_name":datasets[i]["name"],
                                        "im_path":tmp_im_list,
                                        "gt_path":tmp_gt_list,
                                        "im_ext":datasets[i]["im_ext"],
                                        "gt_ext":datasets[i]["gt_ext"],
                                        "cache_dir":datasets[i]["cache_dir"]})
            else:
                name_im_gt_list[0]["dataset_name"] = name_im_gt_list[0]["dataset_name"] + "_" + datasets[i]["name"]
                name_im_gt_list[0]["im_path"] = name_im_gt_list[0]["im_path"] + tmp_im_list
                name_im_gt_list[0]["gt_path"] = name_im_gt_list[0]["gt_path"] + tmp_gt_list
                if datasets[i]["im_ext"]!=".jpg" or datasets[i]["gt_ext"]!=".png":
                    print("Error: Please make sure all your images and ground truth masks are in jpg and png format respectively !!!")
                    exit()
                name_im_gt_list[0]["im_ext"] = ".jpg"
                name_im_gt_list[0]["gt_ext"] = ".png"
                name_im_gt_list[0]["cache_dir"] = os.sep.join(datasets[i]["cache_dir"].split(os.sep)[0:-1])+os.sep+name_im_gt_list[0]["dataset_name"]
        else:  ## keep different validation or inference datasets as separate ones
            name_im_gt_list.append({"dataset_name":datasets[i]["name"],
                                    "im_path":tmp_im_list,
                                    "gt_path":tmp_gt_list,
                                    "im_ext":datasets[i]["im_ext"],
                                    "gt_ext":datasets[i]["gt_ext"],
                                    "cache_dir":datasets[i]["cache_dir"]})

    return name_im_gt_list

def create_dataloaders(name_im_gt_list, cache_size=[], cache_boost=True, my_transforms=[], batch_size=1, shuffle=False):
    ## model="train": return one dataloader for training
    ## model="valid": return a list of dataloaders for validation or testing

    gos_dataloaders = []
    gos_datasets = []

    if(len(name_im_gt_list)==0):
        return gos_dataloaders, gos_datasets

    num_workers_ = 1
    if(batch_size>1):
        num_workers_ = 2
    if(batch_size>4):
        num_workers_ = 4
    if(batch_size>8):
        num_workers_ = 8

    for i in range(0,len(name_im_gt_list)):
        gos_dataset = GOSDatasetCache([name_im_gt_list[i]],
                                      cache_size = cache_size,
                                      cache_path = name_im_gt_list[i]["cache_dir"],
                                      cache_boost = cache_boost,
                                      transform = transforms.Compose(my_transforms))
        gos_dataloaders.append(DataLoader(gos_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers_))
        gos_datasets.append(gos_dataset)

    return gos_dataloaders, gos_datasets

def im_reader(im_path):
    return io.imread(im_path)

def im_preprocess(im,size):
    if len(im.shape) < 3:
        im = im[:, :, np.newaxis]
    if im.shape[2] == 1:
        im = np.repeat(im, 3, axis=2)
    im_tensor = torch.tensor(im.copy(), dtype=torch.float32)
    im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1)
    if(len(size)<2):
        return im_tensor, im.shape[0:2]
    else:
        im_tensor = torch.unsqueeze(im_tensor,0)
        im_tensor = F.upsample(im_tensor, size, mode="bilinear")
        im_tensor = torch.squeeze(im_tensor,0)

    return im_tensor.type(torch.uint8), im.shape[0:2]

def gt_preprocess(gt,size):
    if len(gt.shape) > 2:
        gt = gt[:, :, 0]

    gt_tensor = torch.unsqueeze(torch.tensor(gt, dtype=torch.uint8),0)

    if(len(size)<2):
        return gt_tensor.type(torch.uint8), gt.shape[0:2]
    else:
        gt_tensor = torch.unsqueeze(torch.tensor(gt_tensor, dtype=torch.float32),0)
        gt_tensor = F.upsample(gt_tensor, size, mode="bilinear")
        gt_tensor = torch.squeeze(gt_tensor,0)

    return gt_tensor.type(torch.uint8), gt.shape[0:2]
    # return gt_tensor, gt.shape[0:2]

class GOSRandomHFlip(object):
    def __init__(self,prob=0.5):
        self.prob = prob
    def __call__(self,sample):
        imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']

        # random horizontal flip
        if random.random() >= self.prob:
            image = torch.flip(image,dims=[2])
            label = torch.flip(label,dims=[2])

        return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}

class GOSResize(object):
    def __init__(self,size=[320,320]):
        self.size = size
    def __call__(self,sample):
        imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']

        # import time
        # start = time.time()

        image = torch.squeeze(F.upsample(torch.unsqueeze(image,0),self.size,mode='bilinear'),dim=0)
        label = torch.squeeze(F.upsample(torch.unsqueeze(label,0),self.size,mode='bilinear'),dim=0)

        # print("time for resize: ", time.time()-start)

        return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}

class GOSRandomCrop(object):
    def __init__(self,size=[288,288]):
        self.size = size
    def __call__(self,sample):
        imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']

        h, w = image.shape[1:]
        new_h, new_w = self.size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        image = image[:,top:top+new_h,left:left+new_w]
        label = label[:,top:top+new_h,left:left+new_w]

        return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}


class GOSNormalize(object):
    def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
        self.mean = mean
        self.std = std

    def __call__(self,sample):

        imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']
        image = normalize(image,self.mean,self.std)

        return {'imidx':imidx,'image':image, 'label':label, 'shape':shape}


class GOSDatasetCache(Dataset):

    def __init__(self, name_im_gt_list, cache_size=[], cache_path='./cache', cache_file_name='dataset.json', cache_boost=False, transform=None):


        self.cache_size = cache_size
        self.cache_path = cache_path
        self.cache_file_name = cache_file_name
        self.cache_boost_name = ""

        self.cache_boost = cache_boost
        # self.ims_npy = None
        # self.gts_npy = None

        ## cache all the images and ground truth into a single pytorch tensor
        self.ims_pt = None
        self.gts_pt = None

        ## we will cache the npy as well regardless of the cache_boost
        # if(self.cache_boost):
        self.cache_boost_name = cache_file_name.split('.json')[0]

        self.transform = transform

        self.dataset = {}

        ## combine different datasets into one
        dataset_names = []
        dt_name_list = []  # dataset name per image
        im_name_list = []  # image name
        im_path_list = []  # im path
        gt_path_list = []  # gt path
        im_ext_list = []   # im ext
        gt_ext_list = []   # gt ext
        for i in range(0,len(name_im_gt_list)):
            dataset_names.append(name_im_gt_list[i]["dataset_name"])
            # dataset name repeated based on the number of images in this dataset
            dt_name_list.extend([name_im_gt_list[i]["dataset_name"] for x in name_im_gt_list[i]["im_path"]])
            im_name_list.extend([x.split(os.sep)[-1].split(name_im_gt_list[i]["im_ext"])[0] for x in name_im_gt_list[i]["im_path"]])
            im_path_list.extend(name_im_gt_list[i]["im_path"])
            gt_path_list.extend(name_im_gt_list[i]["gt_path"])
            im_ext_list.extend([name_im_gt_list[i]["im_ext"] for x in name_im_gt_list[i]["im_path"]])
            gt_ext_list.extend([name_im_gt_list[i]["gt_ext"] for x in name_im_gt_list[i]["gt_path"]])


        self.dataset["data_name"] = dt_name_list
        self.dataset["im_name"] = im_name_list
        self.dataset["im_path"] = im_path_list
        self.dataset["ori_im_path"] = deepcopy(im_path_list)
        self.dataset["gt_path"] = gt_path_list
        self.dataset["ori_gt_path"] = deepcopy(gt_path_list)
        self.dataset["im_shp"] = []
        self.dataset["gt_shp"] = []
        self.dataset["im_ext"] = im_ext_list
        self.dataset["gt_ext"] = gt_ext_list


        self.dataset["ims_pt_dir"] = ""
        self.dataset["gts_pt_dir"] = ""

        self.dataset = self.manage_cache(dataset_names)

    def manage_cache(self,dataset_names):
        if not os.path.exists(self.cache_path):  # create the folder for cache
            os.makedirs(self.cache_path)
        cache_folder = os.path.join(self.cache_path, "_".join(dataset_names)+"_"+"x".join([str(x) for x in self.cache_size]))
        if not os.path.exists(cache_folder):  # check if the cache files are there, if not then cache
            return self.cache(cache_folder)
        return self.load_cache(cache_folder)

    def cache(self,cache_folder):
        os.mkdir(cache_folder)
        cached_dataset = deepcopy(self.dataset)

        # ims_list = []
        # gts_list = []
        ims_pt_list = []
        gts_pt_list = []
        for i, im_path in tqdm(enumerate(self.dataset["im_path"]), total=len(self.dataset["im_path"])):

            im_id = cached_dataset["im_name"][i]
            print("im_path: ", im_path)
            im = im_reader(im_path)
            im, im_shp = im_preprocess(im,self.cache_size)
            im_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_im.pt")
            torch.save(im,im_cache_file)

            cached_dataset["im_path"][i] = im_cache_file
            if(self.cache_boost):
                ims_pt_list.append(torch.unsqueeze(im,0))
            # ims_list.append(im.cpu().data.numpy().astype(np.uint8))

            gt = np.zeros(im.shape[0:2])
            if len(self.dataset["gt_path"])!=0:
                gt = im_reader(self.dataset["gt_path"][i])
            gt, gt_shp = gt_preprocess(gt,self.cache_size)
            gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt")
            torch.save(gt,gt_cache_file)
            if len(self.dataset["gt_path"])>0:
                cached_dataset["gt_path"][i] = gt_cache_file
            else:
                cached_dataset["gt_path"].append(gt_cache_file)
            if(self.cache_boost):
                gts_pt_list.append(torch.unsqueeze(gt,0))
            # gts_list.append(gt.cpu().data.numpy().astype(np.uint8))

            # im_shp_cache_file = os.path.join(cache_folder,im_id + "_im_shp.pt")
            # torch.save(gt_shp, shp_cache_file)
            cached_dataset["im_shp"].append(im_shp)
            # self.dataset["im_shp"].append(im_shp)

            # shp_cache_file = os.path.join(cache_folder,im_id + "_gt_shp.pt")
            # torch.save(gt_shp, shp_cache_file)
            cached_dataset["gt_shp"].append(gt_shp)
            # self.dataset["gt_shp"].append(gt_shp)

        if(self.cache_boost):
            cached_dataset["ims_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_ims.pt')
            cached_dataset["gts_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_gts.pt')
            self.ims_pt = torch.cat(ims_pt_list,dim=0)
            self.gts_pt = torch.cat(gts_pt_list,dim=0)
            torch.save(torch.cat(ims_pt_list,dim=0),cached_dataset["ims_pt_dir"])
            torch.save(torch.cat(gts_pt_list,dim=0),cached_dataset["gts_pt_dir"])

        try:
            json_file = open(os.path.join(cache_folder, self.cache_file_name),"w")
            json.dump(cached_dataset, json_file)
            json_file.close()
        except Exception:
            raise FileNotFoundError("Cannot create JSON")
        return cached_dataset

    def load_cache(self, cache_folder):
        json_file = open(os.path.join(cache_folder,self.cache_file_name),"r")
        dataset = json.load(json_file)
        json_file.close()
        ## if cache_boost is true, we will load the image npy and ground truth npy into the RAM
        ## otherwise the pytorch tensor will be loaded
        if(self.cache_boost):
            # self.ims_npy = np.load(dataset["ims_npy_dir"])
            # self.gts_npy = np.load(dataset["gts_npy_dir"])
            self.ims_pt = torch.load(dataset["ims_pt_dir"], map_location='cpu')
            self.gts_pt = torch.load(dataset["gts_pt_dir"], map_location='cpu')
        return dataset

    def __len__(self):
        return len(self.dataset["im_path"])

    def __getitem__(self, idx):

        im = None
        gt = None
        if(self.cache_boost and self.ims_pt is not None):

            # start = time.time()
            im = self.ims_pt[idx]  #.type(torch.float32)
            gt = self.gts_pt[idx]  #.type(torch.float32)
            # print(idx, 'time for pt loading: ', time.time()-start)

        else:
            # import time
            # start = time.time()
            # print("tensor***")
            im_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["im_path"][idx].split(os.sep)[-2:]))
            im = torch.load(im_pt_path)  #(self.dataset["im_path"][idx])
            gt_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["gt_path"][idx].split(os.sep)[-2:]))
            gt = torch.load(gt_pt_path)  #(self.dataset["gt_path"][idx])
            # print(idx,'time for tensor loading: ', time.time()-start)


        im_shp = self.dataset["im_shp"][idx]
        # print("time for loading im and gt: ", time.time()-start)

        # start_time = time.time()
        im = torch.divide(im,255.0)
        gt = torch.divide(gt,255.0)
        # print(idx, 'time for normalize torch divide: ', time.time()-start_time)

        sample = {
            "imidx": torch.from_numpy(np.array(idx)),
            "image": im,
            "label": gt,
            "shape": torch.from_numpy(np.array(im_shp)),
        }

        if self.transform:
            sample = self.transform(sample)

        return sample
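Taken together, get_im_gt_name_dict collects the image/ground-truth path pairs per dataset, and GOSDatasetCache resizes everything once to cache_size and saves each tensor as a .pt file (plus one stacked tensor per split when cache_boost is enabled), so later epochs skip image decoding. A minimal sketch (not from the commit) of wiring one validation set through these helpers; the directory and cache paths are placeholders:

from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSNormalize

# hypothetical dataset description; im_dir/gt_dir/cache_dir are placeholder paths
dataset_vd = {"name": "DIS-VD",
              "im_dir": "../DIS5K/DIS-VD/im",
              "gt_dir": "../DIS5K/DIS-VD/gt",
              "im_ext": ".jpg",
              "gt_ext": ".png",
              "cache_dir": "../DIS5K-Cache/DIS-VD"}

valid_nm_im_gt_list = get_im_gt_name_dict([dataset_vd], flag="valid")
valid_dataloaders, valid_datasets = create_dataloaders(
    valid_nm_im_gt_list,
    cache_size=[1024, 1024],    # images and masks are cached at this resolution
    cache_boost=False,          # load per-sample .pt files instead of one big tensor
    my_transforms=[GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])],
    batch_size=1,
    shuffle=False)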
hce_metric_main.py
ADDED
@@ -0,0 +1,188 @@
## hce_metric.py
import numpy as np
from skimage import io
import matplotlib.pyplot as plt
import cv2 as cv
from skimage.morphology import skeletonize
from skimage.morphology import erosion, dilation, disk
from skimage.measure import label

import os
import sys
from tqdm import tqdm
from glob import glob
import pickle as pkl

def filter_bdy_cond(bdy_, mask, cond):

    cond = cv.dilate(cond.astype(np.uint8),disk(1))
    labels = label(mask)               # find the connected regions
    lbls = np.unique(labels)           # the indices of the connected regions
    indep = np.ones(lbls.shape[0])     # the label of each connected region
    indep[0] = 0                       # 0 indicates the background region

    boundaries = []
    h,w = cond.shape[0:2]
    ind_map = np.zeros((h,w))
    indep_cnt = 0

    for i in range(0,len(bdy_)):
        tmp_bdies = []
        tmp_bdy = []
        for j in range(0,bdy_[i].shape[0]):
            r, c = bdy_[i][j,0,1],bdy_[i][j,0,0]

            if(np.sum(cond[r,c])==0 or ind_map[r,c]!=0):
                if(len(tmp_bdy)>0):
                    tmp_bdies.append(tmp_bdy)
                    tmp_bdy = []
                continue
            tmp_bdy.append([c,r])
            ind_map[r,c] = ind_map[r,c] + 1
            indep[labels[r,c]] = 0  # indicates that part of the boundary of this region needs human correction
        if(len(tmp_bdy)>0):
            tmp_bdies.append(tmp_bdy)

        # check if the first and the last boundaries are connected
        # if yes, invert the first boundary and attach it after the last boundary
        if(len(tmp_bdies)>1):
            first_x, first_y = tmp_bdies[0][0]
            last_x, last_y = tmp_bdies[-1][-1]
            if((abs(first_x-last_x)==1 and first_y==last_y) or
               (first_x==last_x and abs(first_y-last_y)==1) or
               (abs(first_x-last_x)==1 and abs(first_y-last_y)==1)
              ):
                tmp_bdies[-1].extend(tmp_bdies[0][::-1])
                del tmp_bdies[0]

        for k in range(0,len(tmp_bdies)):
            tmp_bdies[k] = np.array(tmp_bdies[k])[:,np.newaxis,:]
        if(len(tmp_bdies)>0):
            boundaries.extend(tmp_bdies)

    return boundaries, np.sum(indep)

# this function approximates each boundary by the DP algorithm
# https://en.wikipedia.org/wiki/Ramer%E2%80%93Douglas%E2%80%93Peucker_algorithm
def approximate_RDP(boundaries,epsilon=1.0):

    boundaries_ = []
    boundaries_len_ = []
    pixel_cnt_ = 0

    # polygon approximation of each boundary
    for i in range(0,len(boundaries)):
        boundaries_.append(cv.approxPolyDP(boundaries[i],epsilon,False))

    # count the control point number of each boundary and the total control point number of all the boundaries
    for i in range(0,len(boundaries_)):
        boundaries_len_.append(len(boundaries_[i]))
        pixel_cnt_ = pixel_cnt_ + len(boundaries_[i])

    return boundaries_, boundaries_len_, pixel_cnt_


def relax_HCE(gt, rs, gt_ske, relax=5, epsilon=2.0):
    # print("max(gt_ske): ", np.amax(gt_ske))
    # gt_ske = gt_ske>128
    # print("max(gt_ske): ", np.amax(gt_ske))

    # Binarize gt
    if(len(gt.shape)>2):
        gt = gt[:,:,0]

    epsilon_gt = 128  #(np.amin(gt)+np.amax(gt))/2.0
    gt = (gt>epsilon_gt).astype(np.uint8)

    # Binarize rs
    if(len(rs.shape)>2):
        rs = rs[:,:,0]
    epsilon_rs = 128  #(np.amin(rs)+np.amax(rs))/2.0
    rs = (rs>epsilon_rs).astype(np.uint8)

    Union = np.logical_or(gt,rs)
    TP = np.logical_and(gt,rs)
    FP = rs - TP
    FN = gt - TP

    # relax the Union of gt and rs
    Union_erode = Union.copy()
    Union_erode = cv.erode(Union_erode.astype(np.uint8),disk(1),iterations=relax)

    # --- get the relaxed False Positive regions for computing the human efforts in correcting them ---
    FP_ = np.logical_and(FP,Union_erode)  # get the relaxed FP
    for i in range(0,relax):
        FP_ = cv.dilate(FP_.astype(np.uint8),disk(1))
        FP_ = np.logical_and(FP_, 1-np.logical_or(TP,FN))
    FP_ = np.logical_and(FP, FP_)

    # --- get the relaxed False Negative regions for computing the human efforts in correcting them ---
    FN_ = np.logical_and(FN,Union_erode)  # preserve the structural components of FN
    ## recover the FN, where pixels are not close to the TP borders
    for i in range(0,relax):
        FN_ = cv.dilate(FN_.astype(np.uint8),disk(1))
        FN_ = np.logical_and(FN_,1-np.logical_or(TP,FP))
    FN_ = np.logical_and(FN,FN_)
    FN_ = np.logical_or(FN_, np.logical_xor(gt_ske,np.logical_and(TP,gt_ske)))  # preserve the structural components of FN

    ## 2. =============Find exact polygon control points and independent regions==============
    ## find contours from FP_
    ctrs_FP, hier_FP = cv.findContours(FP_.astype(np.uint8), cv.RETR_TREE, cv.CHAIN_APPROX_NONE)
    ## find control points and independent regions for human correction
    bdies_FP, indep_cnt_FP = filter_bdy_cond(ctrs_FP, FP_, np.logical_or(TP,FN_))
    ## find contours from FN_
    ctrs_FN, hier_FN = cv.findContours(FN_.astype(np.uint8), cv.RETR_TREE, cv.CHAIN_APPROX_NONE)
    ## find control points and independent regions for human correction
    bdies_FN, indep_cnt_FN = filter_bdy_cond(ctrs_FN, FN_, 1-np.logical_or(np.logical_or(TP,FP_),FN_))

    poly_FP, poly_FP_len, poly_FP_point_cnt = approximate_RDP(bdies_FP,epsilon=epsilon)
    poly_FN, poly_FN_len, poly_FN_point_cnt = approximate_RDP(bdies_FN,epsilon=epsilon)

    return poly_FP_point_cnt, indep_cnt_FP, poly_FN_point_cnt, indep_cnt_FN

def compute_hce(pred_root,gt_root,gt_ske_root):

    gt_name_list = glob(pred_root+'/*.png')
    gt_name_list = sorted([x.split('/')[-1] for x in gt_name_list])

    hces = []
    for gt_name in tqdm(gt_name_list, total=len(gt_name_list)):
        gt_path = os.path.join(gt_root, gt_name)
        pred_path = os.path.join(pred_root, gt_name)

        gt = cv.imread(gt_path, cv.IMREAD_GRAYSCALE)
        pred = cv.imread(pred_path, cv.IMREAD_GRAYSCALE)

        ske_path = os.path.join(gt_ske_root,gt_name)
        if os.path.exists(ske_path):
            ske = cv.imread(ske_path,cv.IMREAD_GRAYSCALE)
            ske = ske>128
        else:
            ske = skeletonize(gt>128)

        FP_points, FP_indep, FN_points, FN_indep = relax_HCE(gt, pred,ske)
        print(gt_path.split('/')[-1],FP_points, FP_indep, FN_points, FN_indep)
        hces.append([FP_points, FP_indep, FN_points, FN_indep, FP_points+FP_indep+FN_points+FN_indep])

    hce_metric ={'names': gt_name_list,
                 'hces': hces}


    file_metric = open(pred_root+'/hce_metric.pkl','wb')
    pkl.dump(hce_metric,file_metric)
    # file_metrics.write(cmn_metrics)
    file_metric.close()

    return np.mean(np.array(hces)[:,-1])

def main():

    gt_root = "../DIS5K/DIS-VD/gt"
    gt_ske_root = ""
    pred_root = "../Results/isnet(ours)/DIS-VD"

    print("The average HCE metric: ", compute_hce(pred_root,gt_root,gt_ske_root))


if __name__ == '__main__':
    main()
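compute_hce expects the prediction folder to contain PNG masks with the same file names as the ground-truth masks, and relax_HCE returns four counts (FP/FN polygon control points and independent regions) whose sum is the per-image HCE. A minimal single-image sketch (not from the commit; the file names are placeholders):

import cv2 as cv
from skimage.morphology import skeletonize
from hce_metric_main import relax_HCE

# placeholder paths; any grayscale ground-truth/prediction pair of the same size works
gt = cv.imread("gt_mask.png", cv.IMREAD_GRAYSCALE)
pred = cv.imread("pred_mask.png", cv.IMREAD_GRAYSCALE)
ske = skeletonize(gt > 128)              # skeleton used to preserve thin structures in FN

fp_pts, fp_indep, fn_pts, fn_indep = relax_HCE(gt, pred, ske)
print("HCE:", fp_pts + fp_indep + fn_pts + fn_indep)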
pytorch18.yml
ADDED
@@ -0,0 +1,92 @@
name: pytorch18
channels:
  - conda-forge
  - anaconda
  - pytorch
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=4.5=1_gnu
  - blas=1.0=mkl
  - brotli=1.0.9=he6710b0_2
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2022.2.1=h06a4308_0
  - certifi=2021.10.8=py37h06a4308_2
  - cloudpickle=2.0.0=pyhd3eb1b0_0
  - colorama=0.4.4=pyhd3eb1b0_0
  - cudatoolkit=10.2.89=hfd86e86_1
  - cycler=0.11.0=pyhd3eb1b0_0
  - cytoolz=0.11.0=py37h7b6447c_0
  - dask-core=2021.10.0=pyhd3eb1b0_0
  - ffmpeg=4.3=hf484d3e_0
  - fonttools=4.25.0=pyhd3eb1b0_0
  - freetype=2.11.0=h70c0345_0
  - fsspec=2022.2.0=pyhd3eb1b0_0
  - gmp=6.2.1=h2531618_2
  - gnutls=3.6.15=he1e5248_0
  - imageio=2.9.0=pyhd3eb1b0_0
  - intel-openmp=2021.4.0=h06a4308_3561
  - jpeg=9b=h024ee3a_2
  - kiwisolver=1.3.2=py37h295c915_0
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.35.1=h7274673_9
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.3.0=h5101ec6_17
  - libgfortran-ng=7.5.0=ha8ba4b0_17
  - libgfortran4=7.5.0=ha8ba4b0_17
  - libgomp=9.3.0=h5101ec6_17
  - libiconv=1.15=h63c8f33_5
  - libidn2=2.3.2=h7f8727e_0
  - libpng=1.6.37=hbc83047_0
  - libstdcxx-ng=9.3.0=hd4cf53a_17
  - libtasn1=4.16.0=h27cfd23_0
  - libtiff=4.2.0=h85742a9_0
  - libunistring=0.9.10=h27cfd23_0
  - libuv=1.40.0=h7b6447c_0
  - libwebp-base=1.2.2=h7f8727e_0
  - locket=0.2.1=py37h06a4308_2
  - lz4-c=1.9.3=h295c915_1
  - matplotlib-base=3.5.1=py37ha18d171_1
  - mkl=2021.4.0=h06a4308_640
  - mkl-service=2.4.0=py37h7f8727e_0
  - mkl_fft=1.3.1=py37hd3c417c_0
  - mkl_random=1.2.2=py37h51133e4_0
  - munkres=1.1.4=py_0
  - ncurses=6.3=h7f8727e_2
  - nettle=3.7.3=hbbd107a_1
  - networkx=2.6.3=pyhd3eb1b0_0
  - ninja=1.10.2=py37hd09550d_3
  - numpy=1.21.2=py37h20f2e39_0
  - numpy-base=1.21.2=py37h79a1101_0
  - olefile=0.46=py37_0
  - openh264=2.1.1=h4ff587b_0
  - openssl=1.1.1n=h7f8727e_0
  - packaging=21.3=pyhd3eb1b0_0
  - partd=1.2.0=pyhd3eb1b0_1
  - pillow=8.0.0=py37h9a89aac_0
  - pip=21.2.2=py37h06a4308_0
  - pyparsing=3.0.4=pyhd3eb1b0_0
  - python=3.7.11=h12debd9_0
  - python-dateutil=2.8.2=pyhd3eb1b0_0
  - pytorch=1.8.0=py3.7_cuda10.2_cudnn7.6.5_0
  - pywavelets=1.1.1=py37h7b6447c_2
  - pyyaml=6.0=py37h7f8727e_1
  - readline=8.1.2=h7f8727e_1
  - scikit-image=0.15.0=py37hb3f55d8_2
  - scipy=1.7.3=py37hc147768_0
  - setuptools=58.0.4=py37h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.38.0=hc218d9a_0
  - tk=8.6.11=h1ccaba5_0
  - toolz=0.11.2=pyhd3eb1b0_0
  - torchaudio=0.8.0=py37
  - torchvision=0.9.0=py37_cu102
  - tqdm=4.63.0=pyhd8ed1ab_0
  - typing_extensions=3.10.0.2=pyh06a4308_0
  - wheel=0.37.1=pyhd3eb1b0_0
  - xz=5.2.5=h7b6447c_0
  - yaml=0.2.5=h7b6447c_0
  - zlib=1.2.11=h7f8727e_4
  - zstd=1.4.9=haebb681_0
prefix: /home/solar/anaconda3/envs/pytorch18
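This file can be instantiated with the standard conda command, conda env create -f pytorch18.yml, which reads the name field above; the prefix line records the original author's local environment path and is typically ignored or overridden by conda on other machines.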
requirements.txt
CHANGED
@@ -1,4 +1,87 @@
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
_libgcc_mutex=0.1=main
_openmp_mutex=4.5=1_gnu
blas=1.0=mkl
brotli=1.0.9=he6710b0_2
bzip2=1.0.8=h7b6447c_0
ca-certificates=2022.2.1=h06a4308_0
certifi=2021.10.8=py37h06a4308_2
cloudpickle=2.0.0=pyhd3eb1b0_0
colorama=0.4.4=pyhd3eb1b0_0
cudatoolkit=10.2.89=hfd86e86_1
cycler=0.11.0=pyhd3eb1b0_0
cytoolz=0.11.0=py37h7b6447c_0
dask-core=2021.10.0=pyhd3eb1b0_0
ffmpeg=4.3=hf484d3e_0
fonttools=4.25.0=pyhd3eb1b0_0
freetype=2.11.0=h70c0345_0
fsspec=2022.2.0=pyhd3eb1b0_0
gmp=6.2.1=h2531618_2
gnutls=3.6.15=he1e5248_0
imageio=2.9.0=pyhd3eb1b0_0
intel-openmp=2021.4.0=h06a4308_3561
jpeg=9b=h024ee3a_2
kiwisolver=1.3.2=py37h295c915_0
lame=3.100=h7b6447c_0
lcms2=2.12=h3be6417_0
ld_impl_linux-64=2.35.1=h7274673_9
libffi=3.3=he6710b0_2
libgcc-ng=9.3.0=h5101ec6_17
libgfortran-ng=7.5.0=ha8ba4b0_17
libgfortran4=7.5.0=ha8ba4b0_17
libgomp=9.3.0=h5101ec6_17
libiconv=1.15=h63c8f33_5
libidn2=2.3.2=h7f8727e_0
libpng=1.6.37=hbc83047_0
libstdcxx-ng=9.3.0=hd4cf53a_17
libtasn1=4.16.0=h27cfd23_0
libtiff=4.2.0=h85742a9_0
libunistring=0.9.10=h27cfd23_0
libuv=1.40.0=h7b6447c_0
libwebp-base=1.2.2=h7f8727e_0
locket=0.2.1=py37h06a4308_2
lz4-c=1.9.3=h295c915_1
matplotlib-base=3.5.1=py37ha18d171_1
mkl=2021.4.0=h06a4308_640
mkl-service=2.4.0=py37h7f8727e_0
mkl_fft=1.3.1=py37hd3c417c_0
mkl_random=1.2.2=py37h51133e4_0
munkres=1.1.4=py_0
ncurses=6.3=h7f8727e_2
nettle=3.7.3=hbbd107a_1
networkx=2.6.3=pyhd3eb1b0_0
ninja=1.10.2=py37hd09550d_3
numpy=1.21.2=py37h20f2e39_0
numpy-base=1.21.2=py37h79a1101_0
olefile=0.46=py37_0
openh264=2.1.1=h4ff587b_0
openssl=1.1.1n=h7f8727e_0
packaging=21.3=pyhd3eb1b0_0
partd=1.2.0=pyhd3eb1b0_1
pillow=8.0.0=py37h9a89aac_0
pip=21.2.2=py37h06a4308_0
pyparsing=3.0.4=pyhd3eb1b0_0
python=3.7.11=h12debd9_0
python-dateutil=2.8.2=pyhd3eb1b0_0
pytorch=1.8.0=py3.7_cuda10.2_cudnn7.6.5_0
pywavelets=1.1.1=py37h7b6447c_2
pyyaml=6.0=py37h7f8727e_1
readline=8.1.2=h7f8727e_1
scikit-image=0.15.0=py37hb3f55d8_2
scipy=1.7.3=py37hc147768_0
setuptools=58.0.4=py37h06a4308_0
six=1.16.0=pyhd3eb1b0_1
sqlite=3.38.0=hc218d9a_0
tk=8.6.11=h1ccaba5_0
toolz=0.11.2=pyhd3eb1b0_0
torchaudio=0.8.0=py37
torchvision=0.9.0=py37_cu102
tqdm=4.63.0=pyhd8ed1ab_0
typing_extensions=3.10.0.2=pyh06a4308_0
wheel=0.37.1=pyhd3eb1b0_0
xz=5.2.5=h7b6447c_0
yaml=0.2.5=h7b6447c_0
zlib=1.2.11=h7f8727e_4
zstd=1.4.9=haebb681_0
train_valid_inference_main.py
ADDED
@@ -0,0 +1,731 @@
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import numpy as np
|
| 4 |
+
from skimage import io
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
import torch, gc
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from torch.autograd import Variable
|
| 10 |
+
import torch.optim as optim
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandomHFlip, GOSResize, GOSRandomCrop, GOSNormalize #GOSDatasetCache,
|
| 14 |
+
from basics import f1_mae_torch #normPRED, GOSPRF1ScoresCache,f1score_torch,
|
| 15 |
+
from models import *
|
| 16 |
+
|
| 17 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 18 |
+
|
| 19 |
+
def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000):
|
| 20 |
+
|
| 21 |
+
# train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
|
| 22 |
+
# cache_size = hypar["cache_size"],
|
| 23 |
+
# cache_boost = hypar["cache_boost_train"],
|
| 24 |
+
# my_transforms = [
|
| 25 |
+
# GOSRandomHFlip(),
|
| 26 |
+
# # GOSResize(hypar["input_size"]),
|
| 27 |
+
# # GOSRandomCrop(hypar["crop_size"]),
|
| 28 |
+
# GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
| 29 |
+
# ],
|
| 30 |
+
# batch_size = hypar["batch_size_train"],
|
| 31 |
+
# shuffle = True)
|
| 32 |
+
|
| 33 |
+
torch.manual_seed(hypar["seed"])
|
| 34 |
+
if torch.cuda.is_available():
|
| 35 |
+
torch.cuda.manual_seed(hypar["seed"])
|
| 36 |
+
|
| 37 |
+
print("define gt encoder ...")
|
| 38 |
+
net = ISNetGTEncoder() #UNETGTENCODERCombine()
|
| 39 |
+
## load the existing model gt encoder
|
| 40 |
+
if(hypar["gt_encoder_model"]!=""):
|
| 41 |
+
model_path = hypar["model_path"]+"/"+hypar["gt_encoder_model"]
|
| 42 |
+
if torch.cuda.is_available():
|
| 43 |
+
net.load_state_dict(torch.load(model_path))
|
| 44 |
+
net.cuda()
|
| 45 |
+
else:
|
| 46 |
+
net.load_state_dict(torch.load(model_path,map_location="cpu"))
|
| 47 |
+
print("gt encoder restored from the saved weights ...")
|
| 48 |
+
return net ############
|
| 49 |
+
|
| 50 |
+
if torch.cuda.is_available():
|
| 51 |
+
net.cuda()
|
| 52 |
+
|
| 53 |
+
print("--- define optimizer for GT Encoder---")
|
| 54 |
+
optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
| 55 |
+
|
| 56 |
+
model_path = hypar["model_path"]
|
| 57 |
+
model_save_fre = hypar["model_save_fre"]
|
| 58 |
+
max_ite = hypar["max_ite"]
|
| 59 |
+
batch_size_train = hypar["batch_size_train"]
|
| 60 |
+
batch_size_valid = hypar["batch_size_valid"]
|
| 61 |
+
|
| 62 |
+
if(not os.path.exists(model_path)):
|
| 63 |
+
os.mkdir(model_path)
|
| 64 |
+
|
| 65 |
+
ite_num = hypar["start_ite"] # count the total iteration number
|
| 66 |
+
ite_num4val = 0 #
|
| 67 |
+
running_loss = 0.0 # count the toal loss
|
| 68 |
+
running_tar_loss = 0.0 # count the target output loss
|
| 69 |
+
last_f1 = [0 for x in range(len(valid_dataloaders))]
|
| 70 |
+
|
| 71 |
+
train_num = train_datasets[0].__len__()
|
| 72 |
+
|
| 73 |
+
net.train()
|
| 74 |
+
|
| 75 |
+
start_last = time.time()
|
| 76 |
+
gos_dataloader = train_dataloaders[0]
|
| 77 |
+
epoch_num = hypar["max_epoch_num"]
|
| 78 |
+
notgood_cnt = 0
|
| 79 |
+
for epoch in range(epoch_num): ## set the epoch num as 100000
|
| 80 |
+
|
| 81 |
+
for i, data in enumerate(gos_dataloader):
|
| 82 |
+
|
| 83 |
+
if(ite_num >= max_ite):
|
| 84 |
+
print("Training Reached the Maximal Iteration Number ", max_ite)
|
| 85 |
+
exit()
|
| 86 |
+
|
| 87 |
+
# start_read = time.time()
|
| 88 |
+
ite_num = ite_num + 1
|
| 89 |
+
ite_num4val = ite_num4val + 1
|
| 90 |
+
|
| 91 |
+
# get the inputs
|
| 92 |
+
labels = data['label']
|
| 93 |
+
|
| 94 |
+
if(hypar["model_digit"]=="full"):
|
| 95 |
+
labels = labels.type(torch.FloatTensor)
|
| 96 |
+
else:
|
| 97 |
+
labels = labels.type(torch.HalfTensor)
|
| 98 |
+
|
| 99 |
+
# wrap them in Variable
|
| 100 |
+
if torch.cuda.is_available():
|
| 101 |
+
labels_v = Variable(labels.cuda(), requires_grad=False)
|
| 102 |
+
else:
|
| 103 |
+
labels_v = Variable(labels, requires_grad=False)
|
| 104 |
+
|
| 105 |
+
# print("time lapse for data preparation: ", time.time()-start_read, ' s')
|
| 106 |
+
|
| 107 |
+
# y zero the parameter gradients
|
| 108 |
+
start_inf_loss_back = time.time()
|
| 109 |
+
optimizer.zero_grad()
|
| 110 |
+
|
| 111 |
+
ds, fs = net(labels_v)#net(inputs_v)
|
| 112 |
+
loss2, loss = net.compute_loss(ds, labels_v)
|
| 113 |
+
|
| 114 |
+
loss.backward()
|
| 115 |
+
optimizer.step()
|
| 116 |
+
|
| 117 |
+
running_loss += loss.item()
|
| 118 |
+
running_tar_loss += loss2.item()
|
| 119 |
+
|
| 120 |
+
# del outputs, loss
|
| 121 |
+
del ds, loss2, loss
|
| 122 |
+
end_inf_loss_back = time.time()-start_inf_loss_back
|
| 123 |
+
|
| 124 |
+
print("GT Encoder Training>>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
|
| 125 |
+
epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back))
|
| 126 |
+
start_last = time.time()
|
| 127 |
+
|
| 128 |
+
if ite_num % model_save_fre == 0: # validate every 2000 iterations
|
| 129 |
+
notgood_cnt += 1
|
| 130 |
+
# net.eval()
|
| 131 |
+
# tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch)
|
| 132 |
+
tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, train_dataloaders_val, train_datasets_val, hypar, epoch)
|
| 133 |
+
|
| 134 |
+
net.train() # resume train
|
| 135 |
+
|
| 136 |
+
tmp_out = 0
|
| 137 |
+
print("last_f1:",last_f1)
|
| 138 |
+
print("tmp_f1:",tmp_f1)
|
| 139 |
+
for fi in range(len(last_f1)):
|
| 140 |
+
if(tmp_f1[fi]>last_f1[fi]):
|
| 141 |
+
tmp_out = 1
|
| 142 |
+
print("tmp_out:",tmp_out)
|
| 143 |
+
if(tmp_out):
|
| 144 |
+
notgood_cnt = 0
|
| 145 |
+
last_f1 = tmp_f1
|
| 146 |
+
tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1]
|
| 147 |
+
tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae]
|
| 148 |
+
maxf1 = '_'.join(tmp_f1_str)
|
| 149 |
+
meanM = '_'.join(tmp_mae_str)
|
| 150 |
+
# .cpu().detach().numpy()
|
| 151 |
+
model_name = "/GTENCODER-gpu_itr_"+str(ite_num)+\
|
| 152 |
+
"_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\
|
| 153 |
+
"_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\
|
| 154 |
+
"_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\
|
| 155 |
+
"_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \
|
| 156 |
+
"_maxF1_" + maxf1 + \
|
| 157 |
+
"_mae_" + meanM + \
|
| 158 |
+
"_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth"
|
| 159 |
+
torch.save(net.state_dict(), model_path + model_name)
|
| 160 |
+
|
| 161 |
+
running_loss = 0.0
|
| 162 |
+
running_tar_loss = 0.0
|
| 163 |
+
ite_num4val = 0
|
| 164 |
+
|
| 165 |
+
if(tmp_f1[0]>0.99):
|
| 166 |
+
print("GT encoder is well-trained and obtained...")
|
| 167 |
+
return net
|
| 168 |
+
|
| 169 |
+
if(notgood_cnt >= hypar["early_stop"]):
|
| 170 |
+
print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped !")
|
| 171 |
+
exit()
|
| 172 |
+
|
| 173 |
+
print("Training Reaches The Maximum Epoch Number")
|
| 174 |
+
return net
|
| 175 |
+
|
| 176 |
+
def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
| 177 |
+
net.eval()
|
| 178 |
+
print("Validating...")
|
| 179 |
+
epoch_num = hypar["max_epoch_num"]
|
| 180 |
+
|
| 181 |
+
val_loss = 0.0
|
| 182 |
+
tar_loss = 0.0
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
tmp_f1 = []
|
| 186 |
+
tmp_mae = []
|
| 187 |
+
tmp_time = []
|
| 188 |
+
|
| 189 |
+
start_valid = time.time()
|
| 190 |
+
for k in range(len(valid_dataloaders)):
|
| 191 |
+
|
| 192 |
+
valid_dataloader = valid_dataloaders[k]
|
| 193 |
+
valid_dataset = valid_datasets[k]
|
| 194 |
+
|
| 195 |
+
val_num = valid_dataset.__len__()
|
| 196 |
+
mybins = np.arange(0,256)
|
| 197 |
+
PRE = np.zeros((val_num,len(mybins)-1))
|
| 198 |
+
REC = np.zeros((val_num,len(mybins)-1))
|
| 199 |
+
F1 = np.zeros((val_num,len(mybins)-1))
|
| 200 |
+
MAE = np.zeros((val_num))
|
| 201 |
+
|
| 202 |
+
val_cnt = 0.0
|
| 203 |
+
i_val = None
|
| 204 |
+
|
| 205 |
+
for i_val, data_val in enumerate(valid_dataloader):
|
| 206 |
+
|
| 207 |
+
# imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']
|
| 208 |
+
imidx_val, labels_val, shapes_val = data_val['imidx'], data_val['label'], data_val['shape']
|
| 209 |
+
|
| 210 |
+
if(hypar["model_digit"]=="full"):
|
| 211 |
+
labels_val = labels_val.type(torch.FloatTensor)
|
| 212 |
+
else:
|
| 213 |
+
labels_val = labels_val.type(torch.HalfTensor)
|
| 214 |
+
|
| 215 |
+
# wrap them in Variable
|
| 216 |
+
if torch.cuda.is_available():
|
| 217 |
+
labels_val_v = Variable(labels_val.cuda(), requires_grad=False)
|
| 218 |
+
else:
|
| 219 |
+
labels_val_v = Variable(labels_val,requires_grad=False)
|
| 220 |
+
|
| 221 |
+
t_start = time.time()
|
| 222 |
+
ds_val = net(labels_val_v)[0]
|
| 223 |
+
t_end = time.time()-t_start
|
| 224 |
+
tmp_time.append(t_end)
|
| 225 |
+
|
| 226 |
+
# loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
|
| 227 |
+
loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
|
| 228 |
+
|
| 229 |
+
# compute F measure
|
| 230 |
+
for t in range(hypar["batch_size_valid"]):
|
| 231 |
+
val_cnt = val_cnt + 1.0
|
| 232 |
+
print("num of val: ", val_cnt)
|
| 233 |
+
i_test = imidx_val[t].data.numpy()
|
| 234 |
+
|
| 235 |
+
pred_val = ds_val[0][t,:,:,:] # B x 1 x H x W
|
| 236 |
+
|
| 237 |
+
## recover the prediction spatial size to the original image size
|
| 238 |
+
pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear'))
|
| 239 |
+
|
| 240 |
+
ma = torch.max(pred_val)
|
| 241 |
+
mi = torch.min(pred_val)
|
| 242 |
+
pred_val = (pred_val-mi)/(ma-mi) # max = 1
|
| 243 |
+
# pred_val = normPRED(pred_val)
|
| 244 |
+
|
| 245 |
+
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
| 246 |
+
if gt.max()==1:
|
| 247 |
+
gt=gt*255
|
| 248 |
+
with torch.no_grad():
|
| 249 |
+
gt = torch.tensor(gt).to(device)
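# note: "device" is assumed to be defined earlier in this file (e.g., "cuda" when available, otherwise "cpu")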
|
| 250 |
+
|
| 251 |
+
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
|
| 252 |
+
|
| 253 |
+
PRE[i_test,:]=pre
|
| 254 |
+
REC[i_test,:] = rec
|
| 255 |
+
F1[i_test,:] = f1
|
| 256 |
+
MAE[i_test] = mae
|
| 257 |
+
|
| 258 |
+
del ds_val, gt
|
| 259 |
+
gc.collect()
|
| 260 |
+
torch.cuda.empty_cache()
|
| 261 |
+
|
| 262 |
+
# if(loss_val.data[0]>1):
|
| 263 |
+
val_loss += loss_val.item()#data[0]
|
| 264 |
+
tar_loss += loss2_val.item()#data[0]
|
| 265 |
+
|
| 266 |
+
print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end))
|
| 267 |
+
|
| 268 |
+
del loss2_val, loss_val
|
| 269 |
+
|
| 270 |
+
print('============================')
|
| 271 |
+
PRE_m = np.mean(PRE,0)
|
| 272 |
+
REC_m = np.mean(REC,0)
|
| 273 |
+
f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8)
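# F-beta measure with beta^2 = 0.3, computed per threshold bin; the 1e-8 guards against division by zero
# and np.amax below reports the best score over all thresholds (maxF1)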
|
| 274 |
+
# print('--------------:', np.mean(f1_m))
|
| 275 |
+
tmp_f1.append(np.amax(f1_m))
|
| 276 |
+
tmp_mae.append(np.mean(MAE))
|
| 277 |
+
print("The max F1 Score: %f"%(np.max(f1_m)))
|
| 278 |
+
print("MAE: ", np.mean(MAE))
|
| 279 |
+
|
| 280 |
+
# print('[epoch: %3d/%3d, ite: %5d] tra_ls: %3f, val_ls: %3f, tar_ls: %3f, maxf1: %3f, val_time: %6f'% (epoch + 1, epoch_num, ite_num, running_loss / ite_num4val, val_loss/val_cnt, tar_loss/val_cnt, tmp_f1[-1], time.time()-start_valid))
|
| 281 |
+
|
| 282 |
+
return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
|
| 283 |
+
|
| 284 |
+
def train(net, optimizer, train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000):
|
| 285 |
+
|
| 286 |
+
if hypar["interm_sup"]:
|
| 287 |
+
print("Get the gt encoder ...")
|
| 288 |
+
featurenet = get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val)
|
| 289 |
+
## freeze the weights of gt encoder
|
| 290 |
+
for param in featurenet.parameters():
|
| 291 |
+
param.requires_grad=False
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
model_path = hypar["model_path"]
|
| 295 |
+
model_save_fre = hypar["model_save_fre"]
|
| 296 |
+
max_ite = hypar["max_ite"]
|
| 297 |
+
batch_size_train = hypar["batch_size_train"]
|
| 298 |
+
batch_size_valid = hypar["batch_size_valid"]
|
| 299 |
+
|
| 300 |
+
if(not os.path.exists(model_path)):
|
| 301 |
+
os.mkdir(model_path)
|
| 302 |
+
|
| 303 |
+
ite_num = hypar["start_ite"] # count the total iteration number
|
| 304 |
+
ite_num4val = 0 #
|
| 305 |
+
running_loss = 0.0 # count the total loss
|
| 306 |
+
running_tar_loss = 0.0 # count the target output loss
|
| 307 |
+
last_f1 = [0 for x in range(len(valid_dataloaders))]
|
| 308 |
+
|
| 309 |
+
train_num = train_datasets[0].__len__()
|
| 310 |
+
|
| 311 |
+
net.train()
|
| 312 |
+
|
| 313 |
+
start_last = time.time()
|
| 314 |
+
gos_dataloader = train_dataloaders[0]
|
| 315 |
+
epoch_num = hypar["max_epoch_num"]
|
| 316 |
+
notgood_cnt = 0
|
| 317 |
+
for epoch in range(epoch_num): ## set the epoch num as 100000
|
| 318 |
+
|
| 319 |
+
for i, data in enumerate(gos_dataloader):
|
| 320 |
+
|
| 321 |
+
if(ite_num >= max_ite):
|
| 322 |
+
print("Training Reached the Maximal Iteration Number ", max_ite)
|
| 323 |
+
exit()
|
| 324 |
+
|
| 325 |
+
# start_read = time.time()
|
| 326 |
+
ite_num = ite_num + 1
|
| 327 |
+
ite_num4val = ite_num4val + 1
|
| 328 |
+
|
| 329 |
+
# get the inputs
|
| 330 |
+
inputs, labels = data['image'], data['label']
|
| 331 |
+
|
| 332 |
+
if(hypar["model_digit"]=="full"):
|
| 333 |
+
inputs = inputs.type(torch.FloatTensor)
|
| 334 |
+
labels = labels.type(torch.FloatTensor)
|
| 335 |
+
else:
|
| 336 |
+
inputs = inputs.type(torch.HalfTensor)
|
| 337 |
+
labels = labels.type(torch.HalfTensor)
|
| 338 |
+
|
| 339 |
+
# wrap them in Variable
|
| 340 |
+
if torch.cuda.is_available():
|
| 341 |
+
inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False)
|
| 342 |
+
else:
|
| 343 |
+
inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)
|
| 344 |
+
|
| 345 |
+
# print("time lapse for data preparation: ", time.time()-start_read, ' s')
|
| 346 |
+
|
| 347 |
+
# zero the parameter gradients
|
| 348 |
+
start_inf_loss_back = time.time()
|
| 349 |
+
optimizer.zero_grad()
|
| 350 |
+
|
| 351 |
+
if hypar["interm_sup"]:
|
| 352 |
+
# forward + backward + optimize
|
| 353 |
+
ds,dfs = net(inputs_v)
|
| 354 |
+
_,fs = featurenet(labels_v) ## extract the gt encodings
|
| 355 |
+
loss2, loss = net.compute_loss_kl(ds, labels_v, dfs, fs, mode='MSE')
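# intermediate supervision: in addition to the mask losses on ds, the decoder features dfs are
# pulled towards the frozen GT-encoder features fs (MSE mode here); the exact terms are defined by
# compute_loss_kl in the model code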
|
| 356 |
+
else:
|
| 357 |
+
# forward + backward + optimize
|
| 358 |
+
ds,_ = net(inputs_v)
|
| 359 |
+
loss2, loss = net.compute_loss(ds, labels_v)
|
| 360 |
+
|
| 361 |
+
loss.backward()
|
| 362 |
+
optimizer.step()
|
| 363 |
+
|
| 364 |
+
# # print statistics
|
| 365 |
+
running_loss += loss.item()
|
| 366 |
+
running_tar_loss += loss2.item()
|
| 367 |
+
|
| 368 |
+
# del outputs, loss
|
| 369 |
+
del ds, loss2, loss
|
| 370 |
+
end_inf_loss_back = time.time()-start_inf_loss_back
|
| 371 |
+
|
| 372 |
+
print(">>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
|
| 373 |
+
epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back))
|
| 374 |
+
start_last = time.time()
|
| 375 |
+
|
| 376 |
+
if ite_num % model_save_fre == 0: # validate and save every model_save_fre iterations (2000 by default)
|
| 377 |
+
notgood_cnt += 1
|
| 378 |
+
net.eval()
|
| 379 |
+
tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid(net, valid_dataloaders, valid_datasets, hypar, epoch)
|
| 380 |
+
net.train() # resume train
|
| 381 |
+
|
| 382 |
+
tmp_out = 0
|
| 383 |
+
print("last_f1:",last_f1)
|
| 384 |
+
print("tmp_f1:",tmp_f1)
|
| 385 |
+
for fi in range(len(last_f1)):
|
| 386 |
+
if(tmp_f1[fi]>last_f1[fi]):
|
| 387 |
+
tmp_out = 1
|
| 388 |
+
print("tmp_out:",tmp_out)
|
| 389 |
+
if(tmp_out):
|
| 390 |
+
notgood_cnt = 0
|
| 391 |
+
last_f1 = tmp_f1
|
| 392 |
+
tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1]
|
| 393 |
+
tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae]
|
| 394 |
+
maxf1 = '_'.join(tmp_f1_str)
|
| 395 |
+
meanM = '_'.join(tmp_mae_str)
|
| 396 |
+
# .cpu().detach().numpy()
|
| 397 |
+
model_name = "/gpu_itr_"+str(ite_num)+\
|
| 398 |
+
"_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\
|
| 399 |
+
"_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\
|
| 400 |
+
"_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\
|
| 401 |
+
"_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \
|
| 402 |
+
"_maxF1_" + maxf1 + \
|
| 403 |
+
"_mae_" + meanM + \
|
| 404 |
+
"_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth"
|
| 405 |
+
torch.save(net.state_dict(), model_path + model_name)
|
| 406 |
+
|
| 407 |
+
running_loss = 0.0
|
| 408 |
+
running_tar_loss = 0.0
|
| 409 |
+
ite_num4val = 0
|
| 410 |
+
|
| 411 |
+
if(notgood_cnt >= hypar["early_stop"]):
|
| 412 |
+
print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped !")
|
| 413 |
+
exit()
|
| 414 |
+
|
| 415 |
+
print("Training Reaches The Maximum Epoch Number")
|
| 416 |
+
|
| 417 |
+
def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
| 418 |
+
net.eval()
|
| 419 |
+
print("Validating...")
|
| 420 |
+
epoch_num = hypar["max_epoch_num"]
|
| 421 |
+
|
| 422 |
+
val_loss = 0.0
|
| 423 |
+
tar_loss = 0.0
|
| 424 |
+
val_cnt = 0.0
|
| 425 |
+
|
| 426 |
+
tmp_f1 = []
|
| 427 |
+
tmp_mae = []
|
| 428 |
+
tmp_time = []
|
| 429 |
+
|
| 430 |
+
start_valid = time.time()
|
| 431 |
+
|
| 432 |
+
for k in range(len(valid_dataloaders)):
|
| 433 |
+
|
| 434 |
+
valid_dataloader = valid_dataloaders[k]
|
| 435 |
+
valid_dataset = valid_datasets[k]
|
| 436 |
+
|
| 437 |
+
val_num = valid_dataset.__len__()
|
| 438 |
+
mybins = np.arange(0,256)
|
| 439 |
+
PRE = np.zeros((val_num,len(mybins)-1))
|
| 440 |
+
REC = np.zeros((val_num,len(mybins)-1))
|
| 441 |
+
F1 = np.zeros((val_num,len(mybins)-1))
|
| 442 |
+
MAE = np.zeros((val_num))
|
| 443 |
+
|
| 444 |
+
for i_val, data_val in enumerate(valid_dataloader):
|
| 445 |
+
val_cnt = val_cnt + 1.0
|
| 446 |
+
imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']
|
| 447 |
+
|
| 448 |
+
if(hypar["model_digit"]=="full"):
|
| 449 |
+
inputs_val = inputs_val.type(torch.FloatTensor)
|
| 450 |
+
labels_val = labels_val.type(torch.FloatTensor)
|
| 451 |
+
else:
|
| 452 |
+
inputs_val = inputs_val.type(torch.HalfTensor)
|
| 453 |
+
labels_val = labels_val.type(torch.HalfTensor)
|
| 454 |
+
|
| 455 |
+
# wrap them in Variable
|
| 456 |
+
if torch.cuda.is_available():
|
| 457 |
+
inputs_val_v, labels_val_v = Variable(inputs_val.cuda(), requires_grad=False), Variable(labels_val.cuda(), requires_grad=False)
|
| 458 |
+
else:
|
| 459 |
+
inputs_val_v, labels_val_v = Variable(inputs_val, requires_grad=False), Variable(labels_val,requires_grad=False)
|
| 460 |
+
|
| 461 |
+
t_start = time.time()
|
| 462 |
+
ds_val = net(inputs_val_v)[0]
|
| 463 |
+
t_end = time.time()-t_start
|
| 464 |
+
tmp_time.append(t_end)
|
| 465 |
+
|
| 466 |
+
# loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
|
| 467 |
+
loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
|
| 468 |
+
|
| 469 |
+
# compute F measure
|
| 470 |
+
for t in range(hypar["batch_size_valid"]):
|
| 471 |
+
i_test = imidx_val[t].data.numpy()
|
| 472 |
+
|
| 473 |
+
pred_val = ds_val[0][t,:,:,:] # B x 1 x H x W
|
| 474 |
+
|
| 475 |
+
## recover the prediction spatial size to the original image size
|
| 476 |
+
pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear'))
|
| 477 |
+
|
| 478 |
+
# pred_val = normPRED(pred_val)
|
| 479 |
+
ma = torch.max(pred_val)
|
| 480 |
+
mi = torch.min(pred_val)
|
| 481 |
+
pred_val = (pred_val-mi)/(ma-mi) # max = 1
|
| 482 |
+
|
| 483 |
+
if len(valid_dataset.dataset["ori_gt_path"]) != 0:
|
| 484 |
+
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
| 485 |
+
if gt.max()==1:
|
| 486 |
+
gt=gt*255
|
| 487 |
+
else:
|
| 488 |
+
gt = np.zeros((shapes_val[t][0],shapes_val[t][1]))
|
| 489 |
+
with torch.no_grad():
|
| 490 |
+
gt = torch.tensor(gt).to(device)
|
| 491 |
+
|
| 492 |
+
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
PRE[i_test,:]=pre
|
| 496 |
+
REC[i_test,:] = rec
|
| 497 |
+
F1[i_test,:] = f1
|
| 498 |
+
MAE[i_test] = mae
|
| 499 |
+
|
| 500 |
+
del ds_val, gt
|
| 501 |
+
gc.collect()
|
| 502 |
+
torch.cuda.empty_cache()
|
| 503 |
+
|
| 504 |
+
# if(loss_val.data[0]>1):
|
| 505 |
+
val_loss += loss_val.item()#data[0]
|
| 506 |
+
tar_loss += loss2_val.item()#data[0]
|
| 507 |
+
|
| 508 |
+
print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end))
|
| 509 |
+
|
| 510 |
+
del loss2_val, loss_val
|
| 511 |
+
|
| 512 |
+
print('============================')
|
| 513 |
+
PRE_m = np.mean(PRE,0)
|
| 514 |
+
REC_m = np.mean(REC,0)
|
| 515 |
+
f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8)
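# same F-beta (beta^2 = 0.3) as in valid_gt_encoder; np.amax over the threshold bins gives maxF1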
|
| 516 |
+
|
| 517 |
+
tmp_f1.append(np.amax(f1_m))
|
| 518 |
+
tmp_mae.append(np.mean(MAE))
|
| 519 |
+
|
| 520 |
+
return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
|
| 521 |
+
|
| 522 |
+
def main(train_datasets,
|
| 523 |
+
valid_datasets,
|
| 524 |
+
hypar): # model: "train", "test"
|
| 525 |
+
|
| 526 |
+
### --- Step 1: Build datasets and dataloaders ---
|
| 527 |
+
dataloaders_train = []
|
| 528 |
+
dataloaders_valid = []
|
| 529 |
+
|
| 530 |
+
if(hypar["mode"]=="train"):
|
| 531 |
+
print("--- create training dataloader ---")
|
| 532 |
+
## collect training dataset
|
| 533 |
+
train_nm_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")
|
| 534 |
+
## build dataloader for training datasets
|
| 535 |
+
train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
|
| 536 |
+
cache_size = hypar["cache_size"],
|
| 537 |
+
cache_boost = hypar["cache_boost_train"],
|
| 538 |
+
my_transforms = [
|
| 539 |
+
GOSRandomHFlip(), ## horizontal flip augmentation; comment this line out to disable it
|
| 540 |
+
# GOSResize(hypar["input_size"]),
|
| 541 |
+
# GOSRandomCrop(hypar["crop_size"]), ## uncomment this line for random-crop augmentation
|
| 542 |
+
GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
| 543 |
+
],
|
| 544 |
+
batch_size = hypar["batch_size_train"],
|
| 545 |
+
shuffle = True)
|
| 546 |
+
train_dataloaders_val, train_datasets_val = create_dataloaders(train_nm_im_gt_list,
|
| 547 |
+
cache_size = hypar["cache_size"],
|
| 548 |
+
cache_boost = hypar["cache_boost_train"],
|
| 549 |
+
my_transforms = [
|
| 550 |
+
GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
| 551 |
+
],
|
| 552 |
+
batch_size = hypar["batch_size_valid"],
|
| 553 |
+
shuffle = False)
|
| 554 |
+
print(len(train_dataloaders), " train dataloaders created")
|
| 555 |
+
|
| 556 |
+
print("--- create valid dataloader ---")
|
| 557 |
+
## build dataloader for validation or testing
|
| 558 |
+
valid_nm_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")
|
| 559 |
+
## build dataloader for validation datasets
|
| 560 |
+
valid_dataloaders, valid_datasets = create_dataloaders(valid_nm_im_gt_list,
|
| 561 |
+
cache_size = hypar["cache_size"],
|
| 562 |
+
cache_boost = hypar["cache_boost_valid"],
|
| 563 |
+
my_transforms = [
|
| 564 |
+
GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
| 565 |
+
# GOSResize(hypar["input_size"])
|
| 566 |
+
],
|
| 567 |
+
batch_size=hypar["batch_size_valid"],
|
| 568 |
+
shuffle=False)
|
| 569 |
+
print(len(valid_dataloaders), " valid dataloaders created")
|
| 570 |
+
# print(valid_datasets[0]["data_name"])
|
| 571 |
+
|
| 572 |
+
### --- Step 2: Build Model and Optimizer ---
|
| 573 |
+
print("--- build model ---")
|
| 574 |
+
net = hypar["model"]#GOSNETINC(3,1)
|
| 575 |
+
|
| 576 |
+
# convert to half precision
|
| 577 |
+
if(hypar["model_digit"]=="half"):
|
| 578 |
+
net.half()
|
| 579 |
+
for layer in net.modules():
|
| 580 |
+
if isinstance(layer, nn.BatchNorm2d):
|
| 581 |
+
layer.float()
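# BatchNorm layers are kept in float32 even in half-precision mode so their running statistics stay numerically stable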
|
| 582 |
+
|
| 583 |
+
if torch.cuda.is_available():
|
| 584 |
+
net.cuda()
|
| 585 |
+
|
| 586 |
+
if(hypar["restore_model"]!=""):
|
| 587 |
+
print("restore model from:")
|
| 588 |
+
print(hypar["model_path"]+"/"+hypar["restore_model"])
|
| 589 |
+
if torch.cuda.is_available():
|
| 590 |
+
net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"]))
|
| 591 |
+
else:
|
| 592 |
+
net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"],map_location="cpu"))
|
| 593 |
+
|
| 594 |
+
print("--- define optimizer ---")
|
| 595 |
+
optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
| 596 |
+
|
| 597 |
+
### --- Step 3: Train or Valid Model ---
|
| 598 |
+
if(hypar["mode"]=="train"):
|
| 599 |
+
train(net,
|
| 600 |
+
optimizer,
|
| 601 |
+
train_dataloaders,
|
| 602 |
+
train_datasets,
|
| 603 |
+
valid_dataloaders,
|
| 604 |
+
valid_datasets,
|
| 605 |
+
hypar,
|
| 606 |
+
train_dataloaders_val, train_datasets_val)
|
| 607 |
+
else:
|
| 608 |
+
valid(net,
|
| 609 |
+
valid_dataloaders,
|
| 610 |
+
valid_datasets,
|
| 611 |
+
hypar)
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
if __name__ == "__main__":
|
| 615 |
+
|
| 616 |
+
### --------------- STEP 1: Configuring the Train, Valid and Test datasets ---------------
|
| 617 |
+
## configure the train, valid and inference datasets
|
| 618 |
+
train_datasets, valid_datasets = [], []
|
| 619 |
+
dataset_1, dataset_2 = {}, {}
|
| 620 |
+
|
| 621 |
+
dataset_tr = {"name": "DIS5K-TR",
|
| 622 |
+
"im_dir": "../DIS5K/DIS-TR/im",
|
| 623 |
+
"gt_dir": "../DIS5K/DIS-TR/gt",
|
| 624 |
+
"im_ext": ".jpg",
|
| 625 |
+
"gt_ext": ".png",
|
| 626 |
+
"cache_dir":"../DIS5K-Cache/DIS-TR"}
|
| 627 |
+
|
| 628 |
+
dataset_vd = {"name": "DIS5K-VD",
|
| 629 |
+
"im_dir": "../DIS5K/DIS-VD/im",
|
| 630 |
+
"gt_dir": "../DIS5K/DIS-VD/gt",
|
| 631 |
+
"im_ext": ".jpg",
|
| 632 |
+
"gt_ext": ".png",
|
| 633 |
+
"cache_dir":"../DIS5K-Cache/DIS-VD"}
|
| 634 |
+
|
| 635 |
+
dataset_te1 = {"name": "DIS5K-TE1",
|
| 636 |
+
"im_dir": "../DIS5K/DIS-TE1/im",
|
| 637 |
+
"gt_dir": "../DIS5K/DIS-TE1/gt",
|
| 638 |
+
"im_ext": ".jpg",
|
| 639 |
+
"gt_ext": ".png",
|
| 640 |
+
"cache_dir":"../DIS5K-Cache/DIS-TE1"}
|
| 641 |
+
|
| 642 |
+
dataset_te2 = {"name": "DIS5K-TE2",
|
| 643 |
+
"im_dir": "../DIS5K/DIS-TE2/im",
|
| 644 |
+
"gt_dir": "../DIS5K/DIS-TE2/gt",
|
| 645 |
+
"im_ext": ".jpg",
|
| 646 |
+
"gt_ext": ".png",
|
| 647 |
+
"cache_dir":"../DIS5K-Cache/DIS-TE2"}
|
| 648 |
+
|
| 649 |
+
dataset_te3 = {"name": "DIS5K-TE3",
|
| 650 |
+
"im_dir": "../DIS5K/DIS-TE3/im",
|
| 651 |
+
"gt_dir": "../DIS5K/DIS-TE3/gt",
|
| 652 |
+
"im_ext": ".jpg",
|
| 653 |
+
"gt_ext": ".png",
|
| 654 |
+
"cache_dir":"../DIS5K-Cache/DIS-TE3"}
|
| 655 |
+
|
| 656 |
+
dataset_te4 = {"name": "DIS5K-TE4",
|
| 657 |
+
"im_dir": "../DIS5K/DIS-TE4/im",
|
| 658 |
+
"gt_dir": "../DIS5K/DIS-TE4/gt",
|
| 659 |
+
"im_ext": ".jpg",
|
| 660 |
+
"gt_ext": ".png",
|
| 661 |
+
"cache_dir":"../DIS5K-Cache/DIS-TE4"}
|
| 662 |
+
### test your own dataset
|
| 663 |
+
dataset_demo = {"name": "your-dataset",
|
| 664 |
+
"im_dir": "../your-dataset/im",
|
| 665 |
+
"gt_dir": "",
|
| 666 |
+
"im_ext": ".jpg",
|
| 667 |
+
"gt_ext": "",
|
| 668 |
+
"cache_dir":"../your-dataset/cache"}
|
| 669 |
+
|
| 670 |
+
train_datasets = [dataset_tr] ## users can list multiple dataset dictionaries here to combine them into the training set
|
| 671 |
+
# valid_datasets = [dataset_vd] ## users can list multiple dataset dictionaries here as validation or inference sets
|
| 672 |
+
valid_datasets = [dataset_vd] ## e.g., [dataset_vd, dataset_te1, dataset_te2, dataset_te3, dataset_te4]; set hypar["mode"] = "valid" for inference
|
| 673 |
+
|
| 674 |
+
### --------------- STEP 2: Configuring the hyperparameters for training, validation and inference ---------------
|
| 675 |
+
hypar = {}
|
| 676 |
+
|
| 677 |
+
## -- 2.1. configure the model saving or restoring path --
|
| 678 |
+
hypar["mode"] = "train"
|
| 679 |
+
## "train": for training,
|
| 680 |
+
## "valid": for validation and inferening,
|
| 681 |
+
## in "valid" mode, it will calculate the accuracy as well as save the prediciton results into the "hypar["valid_out_dir"]", which shouldn't be ""
|
| 682 |
+
## otherwise only accuracy will be calculated and no predictions will be saved
|
| 683 |
+
hypar["interm_sup"] = False ## in-dicate if activate intermediate feature supervision
|
| 684 |
+
|
| 685 |
+
if hypar["mode"] == "train":
|
| 686 |
+
hypar["valid_out_dir"] = "" ## for "train" model leave it as "", for "valid"("inference") mode: set it according to your local directory
|
| 687 |
+
hypar["model_path"] ="../saved_models/IS-Net-test" ## model weights saving (or restoring) path
|
| 688 |
+
hypar["restore_model"] = "" ## name of the segmentation model weights .pth for resume training process from last stop or for the inferencing
|
| 689 |
+
hypar["start_ite"] = 0 ## start iteration for the training, can be changed to match the restored training process
|
| 690 |
+
hypar["gt_encoder_model"] = ""
|
| 691 |
+
else: ## configure the segmentation output path and the to-be-used model weights path
|
| 692 |
+
hypar["valid_out_dir"] = "../your-results/"##"../DIS5K-Results-test" ## output inferenced segmentation maps into this fold
|
| 693 |
+
hypar["model_path"] = "../saved_models/IS-Net" ## load trained weights from this path
|
| 694 |
+
hypar["restore_model"] = "isnet.pth"##"isnet.pth" ## name of the to-be-loaded weights
|
| 695 |
+
|
| 696 |
+
# if hypar["restore_model"]!="":
|
| 697 |
+
# hypar["start_ite"] = int(hypar["restore_model"].split("_")[2])
|
| 698 |
+
|
| 699 |
+
## -- 2.2. choose floating-point precision --
|
| 700 |
+
hypar["model_digit"] = "full" ## indicates "half" or "full" accuracy of float number
|
| 701 |
+
hypar["seed"] = 0
|
| 702 |
+
|
| 703 |
+
## -- 2.3. cache data spatial size --
|
| 704 |
+
## To handle large input images, which take a long time to load during training,
|
| 705 |
+
# we introduce a cache mechanism that pre-converts and resizes the jpg and png images into .pt files
|
| 706 |
+
hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured into different size
|
| 707 |
+
hypar["cache_boost_train"] = False ## "True" or "False", indicates wheather to load all the training datasets into RAM, True will greatly speed the training process while requires more RAM
|
| 708 |
+
hypar["cache_boost_valid"] = False ## "True" or "False", indicates wheather to load all the validation datasets into RAM, True will greatly speed the training process while requires more RAM
|
| 709 |
+
|
| 710 |
+
## --- 2.4. data augmentation parameters ---
|
| 711 |
+
hypar["input_size"] = [1024, 1024] ## mdoel input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images
|
| 712 |
+
hypar["crop_size"] = [1024, 1024] ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation
|
| 713 |
+
hypar["random_flip_h"] = 1 ## horizontal flip, currently hard coded in the dataloader and it is not in use
|
| 714 |
+
hypar["random_flip_v"] = 0 ## vertical flip , currently not in use
|
| 715 |
+
|
| 716 |
+
## --- 2.5. define model ---
|
| 717 |
+
print("building model...")
|
| 718 |
+
hypar["model"] = ISNetDIS() #U2NETFASTFEATURESUP()
|
| 719 |
+
hypar["early_stop"] = 20 ## stop the training when no improvement in the past 20 validation periods, smaller numbers can be used here e.g., 5 or 10.
|
| 720 |
+
hypar["model_save_fre"] = 2000 ## valid and save model weights every 2000 iterations
|
| 721 |
+
|
| 722 |
+
hypar["batch_size_train"] = 8 ## batch size for training
|
| 723 |
+
hypar["batch_size_valid"] = 1 ## batch size for validation and inferencing
|
| 724 |
+
print("batch size: ", hypar["batch_size_train"])
|
| 725 |
+
|
| 726 |
+
hypar["max_ite"] = 10000000 ## if early stop couldn't stop the training process, stop it by the max_ite_num
|
| 727 |
+
hypar["max_epoch_num"] = 1000000 ## if early stop and max_ite couldn't stop the training process, stop it by the max_epoch_num
|
| 728 |
+
|
| 729 |
+
main(train_datasets,
|
| 730 |
+
valid_datasets,
|
| 731 |
+
hypar=hypar)
|
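## Inference configuration (a minimal sketch, using only the settings already defined above):
##   hypar["mode"] = "valid"                          # run validation/inference instead of training
##   hypar["valid_out_dir"] = "../your-results/"      # predicted segmentation maps are saved here
##   hypar["model_path"] = "../saved_models/IS-Net"   # folder holding the trained weights
##   hypar["restore_model"] = "isnet.pth"             # weight file to load
##   valid_datasets = [dataset_vd]                    # or [dataset_demo] for your own images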