Mr7Explorer commited on
Commit
06e6b8c
·
verified ·
1 Parent(s): 67390a4

Upload 12 files

Browse files
Files changed (12) hide show
  1. eval_existingOnes.py +73 -0
  2. gen_best_ep.py +85 -0
  3. inference.py +120 -0
  4. loss.py +248 -0
  5. make_a_copy.sh +16 -0
  6. rm_cache.sh +25 -0
  7. sub.sh +17 -0
  8. test.sh +25 -0
  9. train.py +262 -0
  10. train.sh +42 -0
  11. train_test.sh +12 -0
  12. utils.py +100 -0
eval_existingOnes.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from glob import glob
4
+ import prettytable as pt
5
+
6
+ from evaluation.metrics import evaluator, sort_and_round_scores
7
+ from config import Config
8
+
9
+
10
+ config = Config()
11
+
12
+
13
+ def do_eval(args):
14
+ task_to_field_names = {
15
+ 'DIS5K': ["Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "HCE", "maxEm", "meanFm", "adpEm", "adpFm", 'mBA', 'maxBIoU', 'meanBIoU'],
16
+ 'COD': ["Dataset", "Method", "Smeasure", "wFmeasure", "meanFm", "meanEm", "maxEm", 'MAE', "maxFm", "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU'],
17
+ 'HRSOD': ["Dataset", "Method", "Smeasure", "maxFm", "meanEm", 'MAE', "maxEm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU'],
18
+ 'General': ["Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "HCE", "maxEm", "meanFm", "adpEm", "adpFm", 'mBA', 'maxBIoU', 'meanBIoU'],
19
+ 'Matting': ["Dataset", "Method", "Smeasure", "maxFm", "meanEm", 'MSE', "maxEm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU'],
20
+ 'General-2K': ["Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "HCE", "maxEm", "meanFm", "adpEm", "adpFm", 'mBA', 'maxBIoU', 'meanBIoU'],
21
+ 'Others': ["Dataset", "Method", "Smeasure", 'MAE', "maxEm", "meanEm", "maxFm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU'],
22
+ }
23
+ for data_name in args.data_lst.split('+'):
24
+ print('#' * 20, data_name, '#' * 20)
25
+ if not glob(os.path.join(args.pred_root, args.model_lst[0], data_name)):
26
+ print('Skip dataset {}.'.format(data_name))
27
+ continue
28
+ gt_paths = sorted(glob(os.path.join(args.gt_root, data_name, 'gt', '*')))
29
+
30
+ tb = pt.PrettyTable()
31
+ tb.vertical_char = '&'
32
+ tb.field_names = task_to_field_names[config.task] if config.task in task_to_field_names else task_to_field_names['Others']
33
+ for model_name in args.model_lst[:]:
34
+ print('\t', 'Evaluating model: {}...'.format(model_name))
35
+ pred_paths = [p.replace(args.gt_root, os.path.join(args.pred_root, model_name)).replace('/gt/', '/') for p in gt_paths]
36
+
37
+ em, sm, fm, mae, mse, wfm, hce, mba, biou = evaluator(
38
+ gt_paths=gt_paths,
39
+ pred_paths=pred_paths,
40
+ metrics=args.metrics.split('+'),
41
+ verbose=config.verbose_eval,
42
+ num_workers=min(8, int(os.cpu_count() * 0.9)),
43
+ )
44
+ scores = sort_and_round_scores(config.task, [em, sm, fm, mae, mse, wfm, hce, mba, biou])
45
+ for idx_score, score in enumerate(scores):
46
+ scores[idx_score] = '.' + format(score, '.3f').split('.')[-1] if score <= 1 else format(score, '<4')
47
+ records = [data_name, model_name] + scores
48
+ tb.add_row(records)
49
+ os.makedirs(args.save_dir, exist_ok=True)
50
+ with open(os.path.join(args.save_dir, '{}_eval.txt'.format(data_name)), 'w+') as file_to_write:
51
+ file_to_write.write(str(tb)+'\n')
52
+ print(tb)
53
+
54
+
55
+ if __name__ == '__main__':
56
+ parser = argparse.ArgumentParser()
57
+ parser.add_argument('--gt_root', type=str, help='ground-truth root', default=os.path.join(config.data_root_dir, config.task))
58
+ parser.add_argument('--pred_root', type=str, help='prediction root', default='./e_preds')
59
+ parser.add_argument('--data_lst', type=str, help='test datasets', default=config.testsets.replace(',', '+'))
60
+ parser.add_argument('--save_dir', type=str, help='directory to save results', default='e_results')
61
+ parser.add_argument('--metrics', type=str, help='candidate competitors', default='+'.join(['S', 'MAE']))
62
+ args = parser.parse_args()
63
+
64
+ if args.metrics == 'all':
65
+ args.metrics = '+'.join(['S', 'MAE', 'E', 'F', 'WF', 'MBA', 'BIoU', 'MSE', 'HCE'][:100 if sum(['DIS-' in _data for _data in args.data_lst.split('+')]) else -1])
66
+
67
+ try:
68
+ args.model_lst = [m for m in sorted(os.listdir(args.pred_root), key=lambda x: int(x.split('epoch_')[-1].split('-')[0]), reverse=True) if int(m.split('epoch_')[-1].split('-')[0]) % 1 == 0]
69
+ except Exception as e:
70
+ print(f"Exception: {type(e).__name__} at line {e.__traceback__.tb_lineno} of {__file__}: {e}")
71
+ args.model_lst = [m for m in sorted(os.listdir(args.pred_root))]
72
+
73
+ do_eval(args)
gen_best_ep.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from glob import glob
3
+ import numpy as np
4
+ from config import Config
5
+
6
+
7
+ config = Config()
8
+
9
+ eval_txts = sorted(glob('e_results/*_eval.txt'))
10
+ print('eval_txts:', [_.split(os.sep)[-1] for _ in eval_txts])
11
+ score_panel = {}
12
+ sep = '&'
13
+ metrics = ['sm', 'wfm', 'hce'] # we used HCE for DIS and wFm for others.
14
+ if 'DIS5K' not in config.task:
15
+ metrics.remove('hce')
16
+
17
+ for metric in metrics:
18
+ print('Metric:', metric)
19
+ current_line_nums = []
20
+ for idx_et, eval_txt in enumerate(eval_txts):
21
+ with open(eval_txt, 'r') as f:
22
+ lines = [l for l in f.readlines()[3:] if '.' in l]
23
+ current_line_nums.append(len(lines))
24
+ for idx_et, eval_txt in enumerate(eval_txts):
25
+ with open(eval_txt, 'r') as f:
26
+ lines = [l for l in f.readlines()[3:] if '.' in l]
27
+ for idx_line, line in enumerate(lines[:min(current_line_nums)]): # Consist line numbers by the minimal result file.
28
+ properties = line.strip().strip(sep).split(sep)
29
+ dataset = properties[0].strip()
30
+ ckpt = properties[1].strip()
31
+ if int(ckpt.split('--epoch_')[-1].strip()) < 0:
32
+ continue
33
+ targe_idx = {
34
+ 'sm': [5, 2, 2, 5, 5, 2],
35
+ 'wfm': [3, 3, 8, 3, 3, 8],
36
+ 'hce': [7, -1, -1, 7, 7, -1]
37
+ }[metric][['DIS5K', 'COD', 'HRSOD', 'General', 'General-2K', 'Matting'].index(config.task)]
38
+ if metric != 'hce':
39
+ score_sm = float(properties[targe_idx].strip())
40
+ else:
41
+ score_sm = int(properties[targe_idx].strip().strip('.'))
42
+ if idx_et == 0:
43
+ score_panel[ckpt] = []
44
+ score_panel[ckpt].append(score_sm)
45
+
46
+ metrics_min = ['hce', 'mae']
47
+ max_or_min = min if metric in metrics_min else max
48
+ score_max = max_or_min(score_panel.values(), key=lambda x: np.sum(x))
49
+
50
+ good_models = []
51
+ for k, v in score_panel.items():
52
+ if (np.sum(v) <= np.sum(score_max)) if metric in metrics_min else (np.sum(v) >= np.sum(score_max)):
53
+ print(k, v)
54
+ good_models.append(k)
55
+
56
+ # Write
57
+ with open(eval_txt, 'r') as f:
58
+ lines = f.readlines()
59
+ info4good_models = lines[:3]
60
+ metric_names = [m.strip() for m in lines[1].strip().strip('&').split('&')[2:]]
61
+ testset_mean_values = {metric_name: [] for metric_name in metric_names}
62
+ for good_model in good_models:
63
+ for idx_et, eval_txt in enumerate(eval_txts):
64
+ with open(eval_txt, 'r') as f:
65
+ lines = f.readlines()
66
+ for line in lines:
67
+ if set([good_model]) & set([_.strip() for _ in line.split(sep)]):
68
+ info4good_models.append(line)
69
+ metric_scores = [float(m.strip()) for m in line.strip().strip('&').split('&')[2:]]
70
+ for idx_score, metric_score in enumerate(metric_scores):
71
+ testset_mean_values[metric_names[idx_score]].append(metric_score)
72
+
73
+ if 'DIS5K' in config.task:
74
+ testset_mean_values_lst = ['{:<4}'.format(int(np.mean(v_lst[:-1]).round())) if name == 'HCE' else '{:.3f}'.format(np.mean(v_lst[:-1])).lstrip('0') for name, v_lst in testset_mean_values.items()] # [:-1] to remove DIS-VD
75
+ sample_line_for_placing_mean_values = info4good_models[-2]
76
+ numbers_placed_well = sample_line_for_placing_mean_values.replace(sample_line_for_placing_mean_values.split('&')[1].strip(), 'DIS-TEs').strip().split('&')[3:]
77
+ for idx_number, (number_placed_well, testset_mean_value) in enumerate(zip(numbers_placed_well, testset_mean_values_lst)):
78
+ numbers_placed_well[idx_number] = number_placed_well.replace(number_placed_well.strip(), testset_mean_value)
79
+ testset_mean_line = '&'.join(sample_line_for_placing_mean_values.replace(sample_line_for_placing_mean_values.split('&')[1].strip(), 'DIS-TEs').split('&')[:3] + numbers_placed_well) + '\n'
80
+ info4good_models.append(testset_mean_line)
81
+ info4good_models.append(lines[-1])
82
+ info = ''.join(info4good_models)
83
+ print(info)
84
+ with open(os.path.join('e_results', 'eval-{}_best_on_{}.txt'.format(config.task, metric)), 'w') as f:
85
+ f.write(info + '\n')
inference.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from glob import glob
4
+ from tqdm import tqdm
5
+ import cv2
6
+ import torch
7
+ from contextlib import nullcontext
8
+
9
+ from dataset import MyData
10
+ from models.birefnet import BiRefNet
11
+ from utils import save_tensor_img, check_state_dict
12
+ from config import Config
13
+
14
+
15
+ config = Config()
16
+
17
+ mixed_precision = config.mixed_precision
18
+ if mixed_precision == 'fp16':
19
+ mixed_dtype = torch.float16
20
+ elif mixed_precision == 'bf16':
21
+ mixed_dtype = torch.bfloat16
22
+ else:
23
+ mixed_dtype = None
24
+
25
+ autocast_ctx = torch.amp.autocast(device_type='cuda', dtype=mixed_dtype) if mixed_dtype else nullcontext()
26
+
27
+
28
+ def inference(model, data_loader_test, pred_root, method, testset, device=0):
29
+ model_training = model.training
30
+ if model_training:
31
+ model.eval()
32
+ for batch in tqdm(data_loader_test, total=len(data_loader_test)) if config.verbose_eval else data_loader_test:
33
+ inputs = batch[0].to(device)
34
+ label_paths = batch[-1]
35
+ with autocast_ctx, torch.no_grad():
36
+ scaled_preds = model(inputs)[-1].sigmoid().to(torch.float32)
37
+
38
+ os.makedirs(os.path.join(pred_root, method, testset), exist_ok=True)
39
+
40
+ for idx_sample in range(scaled_preds.shape[0]):
41
+ res = torch.nn.functional.interpolate(
42
+ scaled_preds[idx_sample].unsqueeze(0),
43
+ size=cv2.imread(label_paths[idx_sample], cv2.IMREAD_GRAYSCALE).shape[:2],
44
+ mode='bilinear',
45
+ align_corners=True
46
+ )
47
+ save_tensor_img(res, os.path.join(os.path.join(pred_root, method, testset), label_paths[idx_sample].replace('\\', '/').split('/')[-1])) # test set dir + file name
48
+ if model_training:
49
+ model.train()
50
+ return None
51
+
52
+
53
+ def main(args):
54
+ device = config.device
55
+ if args.ckpt_folder:
56
+ print('Testing with models in {}'.format(args.ckpt_folder))
57
+ else:
58
+ print('Testing with model {}'.format(args.ckpt))
59
+
60
+ if config.model == 'BiRefNet':
61
+ model = BiRefNet(bb_pretrained=False)
62
+ else:
63
+ print('Undefined model: {}.'.format(config.model))
64
+ return None
65
+ weights_lst = sorted(
66
+ glob(os.path.join(args.ckpt_folder, '*.pth')) if args.ckpt_folder else [args.ckpt],
67
+ key=lambda x: int(x.split('epoch_')[-1].split('.pth')[0]),
68
+ reverse=True
69
+ )
70
+ try:
71
+ if args.resolution in [None, 'None', 0, '']:
72
+ # Use original resolution for inference.
73
+ data_size = None
74
+ elif args.resolution in ['config.size']:
75
+ data_size = config.size
76
+ else:
77
+ data_size = [int(l) for l in args.resolution.split('x')]
78
+ except Exception as e:
79
+ print(f"Exception: {type(e).__name__} at line {e.__traceback__.tb_lineno} of {__file__}: {e}")
80
+ # default as the config.size.
81
+ data_size = config.size
82
+
83
+ for testset in args.testsets.split('+'):
84
+ print('>>>> Testset: {}...'.format(testset))
85
+ data_loader_test = torch.utils.data.DataLoader(
86
+ dataset=MyData(testset, data_size=data_size, is_train=False),
87
+ batch_size=config.batch_size_valid, shuffle=False, num_workers=config.num_workers, pin_memory=True
88
+ )
89
+ for weights in weights_lst:
90
+ if int(weights.strip('.pth').split('epoch_')[-1]) % 1 != 0:
91
+ continue
92
+ print('\tInferencing {}...'.format(weights))
93
+ state_dict = torch.load(weights, map_location='cpu', weights_only=True)
94
+ state_dict = check_state_dict(state_dict)
95
+ model.load_state_dict(state_dict)
96
+ model = model.to(device)
97
+ inference(
98
+ model, data_loader_test=data_loader_test, pred_root=args.pred_root,
99
+ method='--'.join([w.rstrip('.pth') for w in weights.split(os.sep)[-2:]]) + '-reso_{}'.format('x'.join([str(s) for s in data_size])),
100
+ testset=testset, device=config.device
101
+ )
102
+
103
+
104
+ if __name__ == '__main__':
105
+ # Parameter from command line
106
+ parser = argparse.ArgumentParser(description='')
107
+ parser.add_argument('--ckpt', type=str, help='model folder')
108
+ parser.add_argument('--ckpt_folder', default=sorted(glob(os.path.join('ckpts', '*')))[-1], type=str, help='model folder')
109
+ parser.add_argument('--pred_root', default='e_preds', type=str, help='Output folder')
110
+ parser.add_argument('--resolution', default='default', type=str, help='WeixHei')
111
+ parser.add_argument('--testsets',
112
+ default=config.testsets.replace(',', '+'),
113
+ type=str,
114
+ help="Test all sets: DIS5K -> 'DIS-VD+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'")
115
+
116
+ args = parser.parse_args()
117
+
118
+ if config.precisionHigh:
119
+ torch.set_float32_matmul_precision('high')
120
+ main(args)
loss.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Variable
5
+ from math import exp
6
+ from config import Config
7
+
8
+
9
+ class ContourLoss(torch.nn.Module):
10
+ def __init__(self):
11
+ super(ContourLoss, self).__init__()
12
+
13
+ def forward(self, pred, target, weight=10):
14
+ '''
15
+ target, pred: tensor of shape (B, C, H, W), where target[:,:,region_in_contour] == 1,
16
+ target[:,:,region_out_contour] == 0.
17
+ weight: scalar, length term weight.
18
+ '''
19
+ # length term
20
+ delta_r = pred[:,:,1:,:] - pred[:,:,:-1,:] # horizontal gradient (B, C, H-1, W)
21
+ delta_c = pred[:,:,:,1:] - pred[:,:,:,:-1] # vertical gradient (B, C, H, W-1)
22
+
23
+ delta_r = delta_r[:,:,1:,:-2]**2 # (B, C, H-2, W-2)
24
+ delta_c = delta_c[:,:,:-2,1:]**2 # (B, C, H-2, W-2)
25
+ delta_pred = torch.abs(delta_r + delta_c)
26
+
27
+ epsilon = 1e-8 # where is a parameter to avoid square root is zero in practice.
28
+ length = torch.mean(torch.sqrt(delta_pred + epsilon)) # eq.(11) in the paper, mean is used instead of sum.
29
+
30
+ c_in = torch.ones_like(pred)
31
+ c_out = torch.zeros_like(pred)
32
+
33
+ region_in = torch.mean( pred * (target - c_in )**2 ) # equ.(12) in the paper, mean is used instead of sum.
34
+ region_out = torch.mean( (1-pred) * (target - c_out)**2 )
35
+ region = region_in + region_out
36
+
37
+ loss = weight * length + region
38
+
39
+ return loss
40
+
41
+
42
+ class IoULoss(torch.nn.Module):
43
+ def __init__(self):
44
+ super(IoULoss, self).__init__()
45
+
46
+ def forward(self, pred, target):
47
+ b = pred.shape[0]
48
+ IoU = 0.0
49
+ for i in range(0, b):
50
+ # compute the IoU of the foreground
51
+ Iand1 = torch.sum(target[i, :, :, :] * pred[i, :, :, :])
52
+ Ior1 = torch.sum(target[i, :, :, :]) + torch.sum(pred[i, :, :, :]) - Iand1
53
+ IoU1 = Iand1 / Ior1
54
+ # IoU loss is (1-IoU1)
55
+ IoU = IoU + (1-IoU1)
56
+ # return IoU/b
57
+ return IoU
58
+
59
+
60
+ class StructureLoss(torch.nn.Module):
61
+ def __init__(self):
62
+ super(StructureLoss, self).__init__()
63
+
64
+ def forward(self, pred, target):
65
+ weit = 1+5*torch.abs(F.avg_pool2d(target, kernel_size=31, stride=1, padding=15)-target)
66
+ wbce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
67
+ wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3))
68
+
69
+ pred = torch.sigmoid(pred)
70
+ inter = ((pred * target) * weit).sum(dim=(2, 3))
71
+ union = ((pred + target) * weit).sum(dim=(2, 3))
72
+ wiou = 1-(inter+1)/(union-inter+1)
73
+
74
+ return (wbce+wiou).mean()
75
+
76
+
77
+ class PatchIoULoss(torch.nn.Module):
78
+ def __init__(self):
79
+ super(PatchIoULoss, self).__init__()
80
+ self.iou_loss = IoULoss()
81
+
82
+ def forward(self, pred, target):
83
+ win_y, win_x = 64, 64
84
+ iou_loss = 0.
85
+ for anchor_y in range(0, target.shape[0], win_y):
86
+ for anchor_x in range(0, target.shape[1], win_y):
87
+ patch_pred = pred[:, :, anchor_y:anchor_y+win_y, anchor_x:anchor_x+win_x]
88
+ patch_target = target[:, :, anchor_y:anchor_y+win_y, anchor_x:anchor_x+win_x]
89
+ patch_iou_loss = self.iou_loss(patch_pred, patch_target)
90
+ iou_loss += patch_iou_loss
91
+ return iou_loss
92
+
93
+
94
+ class ThrReg_loss(torch.nn.Module):
95
+ def __init__(self):
96
+ super(ThrReg_loss, self).__init__()
97
+
98
+ def forward(self, pred, gt=None):
99
+ return torch.mean(1 - ((pred - 0) ** 2 + (pred - 1) ** 2))
100
+
101
+
102
+ class ClsLoss(nn.Module):
103
+ """
104
+ Auxiliary classification loss for each refined class output.
105
+ """
106
+ def __init__(self):
107
+ super(ClsLoss, self).__init__()
108
+ self.config = Config()
109
+ self.lambdas_cls = self.config.lambdas_cls
110
+
111
+ self.criterions_last = {
112
+ 'ce': nn.CrossEntropyLoss()
113
+ }
114
+
115
+ def forward(self, preds, gt):
116
+ loss = 0.
117
+ for _, pred_lvl in enumerate(preds):
118
+ if pred_lvl is None:
119
+ continue
120
+ for criterion_name, criterion in self.criterions_last.items():
121
+ loss += criterion(pred_lvl, gt) * self.lambdas_cls[criterion_name]
122
+ return loss
123
+
124
+
125
+ class PixLoss(nn.Module):
126
+ """
127
+ Pixel loss for each refined map output.
128
+ """
129
+ def __init__(self):
130
+ super(PixLoss, self).__init__()
131
+ self.config = Config()
132
+ self.lambdas_pix_last = self.config.lambdas_pix_last
133
+
134
+ self.criterions_last = {}
135
+ if 'bce' in self.lambdas_pix_last and self.lambdas_pix_last['bce']:
136
+ self.criterions_last['bce'] = nn.BCELoss()
137
+ if 'iou' in self.lambdas_pix_last and self.lambdas_pix_last['iou']:
138
+ self.criterions_last['iou'] = IoULoss()
139
+ if 'iou_patch' in self.lambdas_pix_last and self.lambdas_pix_last['iou_patch']:
140
+ self.criterions_last['iou_patch'] = PatchIoULoss()
141
+ if 'ssim' in self.lambdas_pix_last and self.lambdas_pix_last['ssim']:
142
+ self.criterions_last['ssim'] = SSIMLoss()
143
+ if 'mae' in self.lambdas_pix_last and self.lambdas_pix_last['mae']:
144
+ self.criterions_last['mae'] = nn.L1Loss()
145
+ if 'mse' in self.lambdas_pix_last and self.lambdas_pix_last['mse']:
146
+ self.criterions_last['mse'] = nn.MSELoss()
147
+ if 'reg' in self.lambdas_pix_last and self.lambdas_pix_last['reg']:
148
+ self.criterions_last['reg'] = ThrReg_loss()
149
+ if 'cnt' in self.lambdas_pix_last and self.lambdas_pix_last['cnt']:
150
+ self.criterions_last['cnt'] = ContourLoss()
151
+ if 'structure' in self.lambdas_pix_last and self.lambdas_pix_last['structure']:
152
+ self.criterions_last['structure'] = StructureLoss()
153
+
154
+ def forward(self, scaled_preds, gt, pix_loss_lambda=1.0):
155
+ loss = 0.
156
+ loss_dict = {}
157
+ for _, pred_lvl in enumerate(scaled_preds):
158
+ if pred_lvl.shape != gt.shape:
159
+ pred_lvl = nn.functional.interpolate(pred_lvl, size=gt.shape[2:], mode='bilinear', align_corners=True)
160
+ for criterion_name, criterion in self.criterions_last.items():
161
+ _loss = criterion(pred_lvl.sigmoid(), gt) * self.lambdas_pix_last[criterion_name] * pix_loss_lambda
162
+ loss += _loss
163
+ loss_dict[criterion_name] = loss_dict.get(criterion_name, 0.) + _loss.item() / len(scaled_preds)
164
+ # print(criterion_name, _loss.item())
165
+ return loss, loss_dict
166
+
167
+
168
+ class SSIMLoss(torch.nn.Module):
169
+ def __init__(self, window_size=11, size_average=True):
170
+ super(SSIMLoss, self).__init__()
171
+ self.window_size = window_size
172
+ self.size_average = size_average
173
+ self.channel = 1
174
+ self.window = create_window(window_size, self.channel)
175
+
176
+ def forward(self, img1, img2):
177
+ (_, channel, _, _) = img1.size()
178
+ if channel == self.channel and self.window.data.type() == img1.data.type():
179
+ window = self.window
180
+ else:
181
+ window = create_window(self.window_size, channel)
182
+ if img1.is_cuda:
183
+ window = window.cuda(img1.get_device())
184
+ window = window.type_as(img1)
185
+ self.window = window
186
+ self.channel = channel
187
+ return 1 - (1 + _ssim(img1, img2, window, self.window_size, channel, self.size_average)) / 2
188
+
189
+
190
+ def gaussian(window_size, sigma):
191
+ gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
192
+ return gauss/gauss.sum()
193
+
194
+
195
+ def create_window(window_size, channel):
196
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
197
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
198
+ window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
199
+ return window
200
+
201
+
202
+ def _ssim(img1, img2, window, window_size, channel, size_average=True):
203
+ mu1 = F.conv2d(img1, window, padding = window_size//2, groups=channel)
204
+ mu2 = F.conv2d(img2, window, padding = window_size//2, groups=channel)
205
+
206
+ mu1_sq = mu1.pow(2)
207
+ mu2_sq = mu2.pow(2)
208
+ mu1_mu2 = mu1*mu2
209
+
210
+ sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=channel) - mu1_sq
211
+ sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=channel) - mu2_sq
212
+ sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=channel) - mu1_mu2
213
+
214
+ C1 = 0.01**2
215
+ C2 = 0.03**2
216
+
217
+ ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
218
+
219
+ if size_average:
220
+ return ssim_map.mean()
221
+ else:
222
+ return ssim_map.mean(1).mean(1).mean(1)
223
+
224
+
225
+ def SSIM(x, y):
226
+ C1 = 0.01 ** 2
227
+ C2 = 0.03 ** 2
228
+
229
+ mu_x = nn.AvgPool2d(3, 1, 1)(x)
230
+ mu_y = nn.AvgPool2d(3, 1, 1)(y)
231
+ mu_x_mu_y = mu_x * mu_y
232
+ mu_x_sq = mu_x.pow(2)
233
+ mu_y_sq = mu_y.pow(2)
234
+
235
+ sigma_x = nn.AvgPool2d(3, 1, 1)(x * x) - mu_x_sq
236
+ sigma_y = nn.AvgPool2d(3, 1, 1)(y * y) - mu_y_sq
237
+ sigma_xy = nn.AvgPool2d(3, 1, 1)(x * y) - mu_x_mu_y
238
+
239
+ SSIM_n = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2)
240
+ SSIM_d = (mu_x_sq + mu_y_sq + C1) * (sigma_x + sigma_y + C2)
241
+ SSIM = SSIM_n / SSIM_d
242
+
243
+ return torch.clamp((1 - SSIM) / 2, 0, 1)
244
+
245
+
246
+ def saliency_structure_consistency(x, y):
247
+ ssim = torch.mean(SSIM(x,y))
248
+ return ssim
make_a_copy.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Set dst repo here.
3
+ repo=$1
4
+ mkdir ../${repo}
5
+ mkdir ../${repo}/evaluation
6
+ mkdir ../${repo}/models
7
+ mkdir ../${repo}/models/backbones
8
+ mkdir ../${repo}/models/modules
9
+
10
+ cp ./*.sh ../${repo}
11
+ cp ./*.py ../${repo}
12
+ cp ./evaluation/*.py ../${repo}/evaluation
13
+ cp ./models/*.py ../${repo}/models
14
+ cp ./models/backbones/*.py ../${repo}/models/backbones
15
+ cp ./models/modules/*.py ../${repo}/models/modules
16
+ cp -r ./.git* ../${repo}
rm_cache.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ rm -rf __pycache__ */__pycache__ */*/__pycache__
3
+
4
+ # Val
5
+ rm -r tmp*
6
+
7
+ # Train
8
+ rm slurm*
9
+ rm -r ckpts
10
+ rm nohup.out*
11
+ rm nohup.log*
12
+
13
+ # Eval
14
+ rm -r evaluation/eval-*
15
+ rm -r tmp*
16
+ rm -r e_logs/
17
+
18
+ # System
19
+ rm core-*-python-*
20
+
21
+ # Inference cache
22
+ rm -rf images_todo/
23
+ rm -rf predictions/
24
+
25
+ clear
sub.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Example: ./sub.sh tmp_proj 0,1,2,3 3 --> Use 0,1,2,3 for training, release GPUs, use GPU:3 for inference.
3
+
4
+ # module load gcc/11.2.0 cuda/11.8 cudnn/8.6.0_cu11x && cpu_core_num=6
5
+ module load compilers/cuda/11.8 compilers/gcc/12.2.0 cudnn/8.4.0.27_cuda11.x && cpu_core_num=32
6
+
7
+ export PYTHONUNBUFFERED=1
8
+
9
+ method=${1:-"BSL"}
10
+ devices=${2:-"0,1"}
11
+ gpu_num=$(($(echo ${devices%%,} | grep -o "," | wc -l)+1))
12
+
13
+ sbatch --nodes=1 -p vip_gpu_ailab -A ai4bio \
14
+ --gres=gpu:${gpu_num} --ntasks-per-node=1 --cpus-per-task=$((gpu_num*cpu_core_num)) \
15
+ ./train_test.sh ${method} ${devices}
16
+
17
+ hostname
test.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ devices=${1:-0}
2
+ pred_root=${2:-e_preds}
3
+ resolutions=${3:-"config.size"}
4
+
5
+ # Inference
6
+ # resolutions="1024x1024 None"
7
+ for resolution in ${resolutions}; do
8
+ CUDA_VISIBLE_DEVICES=${devices} python inference.py --pred_root ${pred_root} --resolution ${resolution}
9
+ done
10
+
11
+ echo Inference finished at $(date)
12
+
13
+ # Evaluation
14
+ log_dir=e_logs && mkdir ${log_dir}
15
+
16
+ task=$(python3 config.py --print_task)
17
+ testsets=$(python3 config.py --print_testsets)
18
+
19
+ testsets=(`echo ${testsets} | tr ',' ' '`) && testsets=${testsets[@]}
20
+
21
+ for testset in ${testsets}; do
22
+ python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testset} --metrics 'all' > ${log_dir}/eval_${testset}.out
23
+ done
24
+
25
+ echo Evaluation started at $(date)
train.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import datetime
3
+ from contextlib import nullcontext
4
+ import argparse
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+ if tuple(map(int, torch.__version__.split('+')[0].split(".")[:3])) >= (2, 5, 0):
9
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
10
+
11
+ from config import Config
12
+ from loss import PixLoss, ClsLoss
13
+ from dataset import MyData
14
+ from models.birefnet import BiRefNet
15
+ from utils import Logger, AverageMeter, set_seed, check_state_dict
16
+
17
+ from torch.utils.data.distributed import DistributedSampler
18
+ from torch.nn.parallel import DistributedDataParallel as DDP
19
+ from torch.distributed import init_process_group, destroy_process_group
20
+
21
+
22
+ parser = argparse.ArgumentParser(description='')
23
+ parser.add_argument('--resume', default=None, type=str, help='path to latest checkpoint')
24
+ parser.add_argument('--epochs', default=120, type=int)
25
+ parser.add_argument('--ckpt_dir', default='ckpts/tmp', help='Temporary folder')
26
+ parser.add_argument('--dist', default=False, type=lambda x: x == 'True')
27
+ parser.add_argument('--use_accelerate', action='store_true', help='`accelerate launch --multi_gpu train.py --use_accelerate`. Use accelerate for training, good for FP16/BF16/...')
28
+ args = parser.parse_args()
29
+
30
+ config = Config()
31
+
32
+ if args.use_accelerate:
33
+ from accelerate import Accelerator, utils
34
+ mixed_precision = config.mixed_precision
35
+ kwargs_handlers = [
36
+ utils.InitProcessGroupKwargs(backend="nccl", timeout=datetime.timedelta(seconds=3600*10)),
37
+ utils.DistributedDataParallelKwargs(find_unused_parameters=False),
38
+ utils.GradScalerKwargs(backoff_factor=0.5),
39
+ ]
40
+ if mixed_precision == 'fp8':
41
+ kwargs_handlers.append(utils.AORecipeKwargs())
42
+ accelerator = Accelerator(
43
+ mixed_precision=mixed_precision,
44
+ gradient_accumulation_steps=1,
45
+ kwargs_handlers=kwargs_handlers,
46
+ )
47
+ accelerator.print(accelerator.state)
48
+ accelerator.print('backbone:', config.bb, ', freeze_bb:', config.freeze_bb)
49
+ args.dist = False
50
+
51
+ # DDP
52
+ to_be_distributed = args.dist
53
+ if to_be_distributed:
54
+ init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=3600*10))
55
+ device = int(os.environ["LOCAL_RANK"])
56
+ else:
57
+ if args.use_accelerate:
58
+ device = accelerator.local_process_index
59
+ else:
60
+ device = config.device
61
+
62
+ if config.rand_seed:
63
+ set_seed(config.rand_seed + device)
64
+
65
+ epoch_st = 1
66
+ # make dir for ckpt
67
+ os.makedirs(args.ckpt_dir, exist_ok=True)
68
+
69
+ # Init log file
70
+ logger = Logger(os.path.join(args.ckpt_dir, "log.txt"))
71
+ logger_loss_idx = 1
72
+
73
+ # log model and optimizer params
74
+ # logger.info("Model details:"); logger.info(model)
75
+ # if args.use_accelerate and accelerator.mixed_precision != 'no':
76
+ # config.compile = False
77
+ logger.info("datasets: load_all={}, compile={}.".format(config.load_all, config.compile))
78
+ logger.info("Other hyperparameters:"); logger.info(args)
79
+ print('batch size:', config.batch_size)
80
+
81
+ from dataset import custom_collate_fn
82
+
83
+ def prepare_dataloader(dataset: torch.utils.data.Dataset, batch_size: int, to_be_distributed=False, is_train=True):
84
+ # Prepare dataloaders
85
+ if to_be_distributed:
86
+ return torch.utils.data.DataLoader(
87
+ dataset=dataset, batch_size=batch_size, num_workers=min(config.num_workers, batch_size), pin_memory=True,
88
+ shuffle=False, sampler=DistributedSampler(dataset), drop_last=True, collate_fn=custom_collate_fn if is_train and config.dynamic_size else None
89
+ )
90
+ else:
91
+ return torch.utils.data.DataLoader(
92
+ dataset=dataset, batch_size=batch_size, num_workers=min(config.num_workers, batch_size), pin_memory=True,
93
+ shuffle=is_train, sampler=None, drop_last=True, collate_fn=custom_collate_fn if is_train and config.dynamic_size else None
94
+ )
95
+
96
+
97
+ def init_data_loaders(to_be_distributed):
98
+ # Prepare datasets
99
+ train_loader = prepare_dataloader(
100
+ MyData(datasets=config.training_set, data_size=None if config.dynamic_size else config.size, is_train=True),
101
+ config.batch_size, to_be_distributed=to_be_distributed, is_train=True
102
+ )
103
+ print(len(train_loader), "batches of train dataloader {} have been created.".format(config.training_set))
104
+ return train_loader
105
+
106
+
107
+ def init_models_optimizers(epochs, to_be_distributed):
108
+ # Init models
109
+ if config.model == 'BiRefNet':
110
+ model = BiRefNet(bb_pretrained=True and not os.path.isfile(str(args.resume)))
111
+ else:
112
+ print('Undefined model: {}.'.format(config.model))
113
+ return None
114
+ if args.resume:
115
+ if os.path.isfile(args.resume):
116
+ logger.info("=> loading checkpoint '{}'".format(args.resume))
117
+ state_dict = torch.load(args.resume, map_location='cpu', weights_only=True)
118
+ state_dict = check_state_dict(state_dict)
119
+ model.load_state_dict(state_dict)
120
+ global epoch_st
121
+ epoch_st = int(args.resume.rstrip('.pth').split('epoch_')[-1]) + 1
122
+ else:
123
+ logger.info("=> no checkpoint found at '{}'".format(args.resume))
124
+ if not args.use_accelerate:
125
+ if to_be_distributed:
126
+ model = model.to(device)
127
+ model = DDP(model, device_ids=[device])
128
+ else:
129
+ model = model.to(device)
130
+ if config.compile:
131
+ model = torch.compile(model, mode=['default', 'reduce-overhead', 'max-autotune'][0])
132
+ if config.precisionHigh:
133
+ torch.set_float32_matmul_precision('high')
134
+
135
+ # Setting optimizer
136
+ if config.optimizer == 'AdamW':
137
+ optimizer = optim.AdamW(params=[p for p in model.parameters() if p.requires_grad], lr=config.lr, weight_decay=1e-2)
138
+ elif config.optimizer == 'Adam':
139
+ optimizer = optim.Adam(params=[p for p in model.parameters() if p.requires_grad], lr=config.lr, weight_decay=0)
140
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
141
+ optimizer,
142
+ milestones=[lde if lde > 0 else epochs + lde + 1 for lde in config.lr_decay_epochs],
143
+ gamma=config.lr_decay_rate
144
+ )
145
+ # logger.info("Optimizer details:"); logger.info(optimizer)
146
+
147
+ return model, optimizer, lr_scheduler
148
+
149
+
150
+ class Trainer:
151
+ def __init__(
152
+ self, data_loaders, model_opt_lrsch,
153
+ ):
154
+ self.model, self.optimizer, self.lr_scheduler = model_opt_lrsch
155
+ self.train_loader = data_loaders
156
+ if args.use_accelerate:
157
+ self.train_loader, self.model, self.optimizer = accelerator.prepare(self.train_loader, self.model, self.optimizer)
158
+ if config.out_ref:
159
+ self.criterion_gdt = nn.BCELoss()
160
+
161
+ # Setting Losses
162
+ self.pix_loss = PixLoss()
163
+ self.cls_loss = ClsLoss()
164
+
165
+ # Others
166
+ self.loss_log = AverageMeter()
167
+
168
+ def _train_batch(self, batch):
169
+ if args.use_accelerate:
170
+ inputs = batch[0]#.to(device)
171
+ gts = batch[1]#.to(device)
172
+ class_labels = batch[2]#.to(device)
173
+ else:
174
+ inputs = batch[0].to(device)
175
+ gts = batch[1].to(device)
176
+ class_labels = batch[2].to(device)
177
+ self.optimizer.zero_grad()
178
+ scaled_preds, class_preds_lst = self.model(inputs)
179
+ if config.out_ref:
180
+ (outs_gdt_pred, outs_gdt_label), scaled_preds = scaled_preds
181
+ for _idx, (_gdt_pred, _gdt_label) in enumerate(zip(outs_gdt_pred, outs_gdt_label)):
182
+ _gdt_pred = nn.functional.interpolate(_gdt_pred, size=_gdt_label.shape[2:], mode='bilinear', align_corners=True).sigmoid()
183
+ _gdt_label = _gdt_label.sigmoid()
184
+ loss_gdt = self.criterion_gdt(_gdt_pred, _gdt_label) if _idx == 0 else self.criterion_gdt(_gdt_pred, _gdt_label) + loss_gdt
185
+ # self.loss_dict['loss_gdt'] = loss_gdt.item()
186
+ if None in class_preds_lst:
187
+ loss_cls = 0.
188
+ else:
189
+ loss_cls = self.cls_loss(class_preds_lst, class_labels)
190
+ self.loss_dict['loss_cls'] = loss_cls.item()
191
+
192
+ # Loss
193
+ loss_pix, loss_dict_pix = self.pix_loss(scaled_preds, torch.clamp(gts, 0, 1), pix_loss_lambda=1.0)
194
+ self.loss_dict.update(loss_dict_pix)
195
+ self.loss_dict['loss_pix'] = loss_pix.item()
196
+ # since there may be several losses for sal, the lambdas for them (lambdas_pix) are inside the loss.py
197
+ loss = loss_pix + loss_cls
198
+ if config.out_ref:
199
+ loss = loss + loss_gdt * 1.0
200
+
201
+ self.loss_log.update(loss.item(), inputs.size(0))
202
+ if args.use_accelerate:
203
+ loss = loss / accelerator.gradient_accumulation_steps
204
+ accelerator.backward(loss)
205
+ else:
206
+ loss.backward()
207
+ self.optimizer.step()
208
+
209
+ def train_epoch(self, epoch):
210
+ global logger_loss_idx
211
+ self.model.train()
212
+ self.loss_dict = {}
213
+ if epoch > args.epochs + config.finetune_last_epochs:
214
+ if config.task == 'Matting':
215
+ self.pix_loss.lambdas_pix_last['mae'] *= 1
216
+ self.pix_loss.lambdas_pix_last['mse'] *= 0.9
217
+ self.pix_loss.lambdas_pix_last['ssim'] *= 0.9
218
+ else:
219
+ self.pix_loss.lambdas_pix_last['bce'] *= 0
220
+ self.pix_loss.lambdas_pix_last['ssim'] *= 1
221
+ self.pix_loss.lambdas_pix_last['iou'] *= 0.5
222
+ self.pix_loss.lambdas_pix_last['mae'] *= 0.9
223
+
224
+ for batch_idx, batch in enumerate(self.train_loader):
225
+ # with nullcontext if not args.use_accelerate or accelerator.gradient_accumulation_steps <= 1 else accelerator.accumulate(self.model):
226
+ self._train_batch(batch)
227
+ # Logger
228
+ if (epoch < 2 and batch_idx < 100 and batch_idx % 20 == 0) or batch_idx % max(100, len(self.train_loader) / 100 // 100 * 100) == 0:
229
+ info_progress = f'Epoch[{epoch}/{args.epochs}] Iter[{batch_idx}/{len(self.train_loader)}].'
230
+ info_loss = 'Training Losses:'
231
+ for loss_name, loss_value in self.loss_dict.items():
232
+ info_loss += f' {loss_name}: {loss_value:.5g} |'
233
+ logger.info(' '.join((info_progress, info_loss)))
234
+ info_loss = f'@==Final== Epoch[{epoch}/{args.epochs}] Training Loss: {self.loss_log.avg:.5g} '
235
+ logger.info(info_loss)
236
+
237
+ self.lr_scheduler.step()
238
+ return self.loss_log.avg
239
+
240
+
241
+ def main():
242
+
243
+ trainer = Trainer(
244
+ data_loaders=init_data_loaders(to_be_distributed),
245
+ model_opt_lrsch=init_models_optimizers(args.epochs, to_be_distributed)
246
+ )
247
+
248
+ for epoch in range(epoch_st, args.epochs+1):
249
+ train_loss = trainer.train_epoch(epoch)
250
+ # Save checkpoint
251
+ if epoch >= args.epochs - config.save_last and epoch % config.save_step == 0:
252
+ if args.use_accelerate:
253
+ state_dict = trainer.model.state_dict()
254
+ else:
255
+ state_dict = trainer.model.module.state_dict() if to_be_distributed else trainer.model.state_dict()
256
+ torch.save(state_dict, os.path.join(args.ckpt_dir, 'epoch_{}.pth'.format(epoch)))
257
+ if to_be_distributed:
258
+ destroy_process_group()
259
+
260
+
261
+ if __name__ == '__main__':
262
+ main()
train.sh ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Run script
3
+ # Settings of training & test for different tasks.
4
+ method="$1"
5
+ task=$(python3 config.py --print_task)
6
+ case "${task}" in
7
+ 'DIS5K') epochs=500 && val_last=50 && step=5 ;;
8
+ 'COD') epochs=150 && val_last=50 && step=5 ;;
9
+ 'HRSOD') epochs=150 && val_last=50 && step=5 ;;
10
+ 'General') epochs=200 && val_last=50 && step=5 ;;
11
+ 'General-2K') epochs=250 && val_last=30 && step=2 ;;
12
+ 'Matting') epochs=150 && val_last=50 && step=5 ;;
13
+ esac
14
+
15
+ # Train
16
+ devices=$2
17
+ nproc_per_node=$(echo ${devices%%,} | grep -o "," | wc -l)
18
+
19
+ to_be_distributed=`echo ${nproc_per_node} | awk '{if($e > 0) print "True"; else print "False";}'`
20
+
21
+ echo Training started at $(date)
22
+ resume_weights_path='path_to_a_pth'
23
+ if [ ${to_be_distributed} == "True" ]
24
+ then
25
+ # Adapt the nproc_per_node by the number of GPUs. Give 8989 as the default value of master_port.
26
+ echo "Multi-GPU mode received..."
27
+ CUDA_VISIBLE_DEVICES=${devices} \
28
+ torchrun --standalone --nproc_per_node $((nproc_per_node+1)) \
29
+ train.py --ckpt_dir ckpts/${method} --epochs ${epochs} \
30
+ --dist ${to_be_distributed} \
31
+ --resume ${resume_weights_path} \
32
+ --use_accelerate
33
+ else
34
+ echo "Single-GPU mode received..."
35
+ CUDA_VISIBLE_DEVICES=${devices} \
36
+ python train.py --ckpt_dir ckpts/${method} --epochs ${epochs} \
37
+ --dist ${to_be_distributed} \
38
+ --resume ${resume_weights_path} \
39
+ --use_accelerate
40
+ fi
41
+
42
+ echo Training finished at $(date)
train_test.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Example: `setsid nohup ./train_test.sh BiRefNet 0,1,2,3,4,5,6,7 0 &>nohup.log &`
3
+
4
+ method=${1:-"BSL"}
5
+ devices=${2:-"0,1,2,3,4,5,6,7"}
6
+
7
+ bash train.sh ${method} ${devices}
8
+
9
+ devices_test=${3:-0}
10
+ bash test.sh ${devices_test}
11
+
12
+ hostname
utils.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import torch
4
+ from torchvision import transforms
5
+ import numpy as np
6
+ import random
7
+ import cv2
8
+ from PIL import Image
9
+
10
+
11
+ def path_to_image(path, size=(1024, 1024), color_type=['rgb', 'gray'][0]):
12
+ if color_type.lower() == 'rgb':
13
+ image = cv2.imread(path)
14
+ elif color_type.lower() == 'gray':
15
+ image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
16
+ else:
17
+ print('Select the color_type to return, either to RGB or gray image.')
18
+ return
19
+ if size:
20
+ image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
21
+ if color_type.lower() == 'rgb':
22
+ image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).convert('RGB')
23
+ else:
24
+ image = Image.fromarray(image).convert('L')
25
+ return image
26
+
27
+
28
+
29
+ def check_state_dict(state_dict, unwanted_prefixes=['module.', '_orig_mod.']):
30
+ for k, v in list(state_dict.items()):
31
+ prefix_length = 0
32
+ for unwanted_prefix in unwanted_prefixes:
33
+ if k[prefix_length:].startswith(unwanted_prefix):
34
+ prefix_length += len(unwanted_prefix)
35
+ state_dict[k[prefix_length:]] = state_dict.pop(k)
36
+ return state_dict
37
+
38
+
39
+ def generate_smoothed_gt(gts):
40
+ epsilon = 0.001
41
+ new_gts = (1-epsilon)*gts+epsilon/2
42
+ return new_gts
43
+
44
+
45
+ class Logger():
46
+ def __init__(self, path="log.txt"):
47
+ self.logger = logging.getLogger('BiRefNet')
48
+ self.file_handler = logging.FileHandler(path, "w")
49
+ self.stdout_handler = logging.StreamHandler()
50
+ self.stdout_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
51
+ self.file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
52
+ self.logger.addHandler(self.file_handler)
53
+ self.logger.addHandler(self.stdout_handler)
54
+ self.logger.setLevel(logging.INFO)
55
+ self.logger.propagate = False
56
+
57
+ def info(self, txt):
58
+ self.logger.info(txt)
59
+
60
+ def close(self):
61
+ self.file_handler.close()
62
+ self.stdout_handler.close()
63
+
64
+
65
+ class AverageMeter(object):
66
+ """Computes and stores the average and current value"""
67
+ def __init__(self):
68
+ self.reset()
69
+
70
+ def reset(self):
71
+ self.val = 0.0
72
+ self.avg = 0.0
73
+ self.sum = 0.0
74
+ self.count = 0.0
75
+
76
+ def update(self, val, n=1):
77
+ self.val = val
78
+ self.sum += val * n
79
+ self.count += n
80
+ self.avg = self.sum / self.count
81
+
82
+
83
+ def save_checkpoint(state, path, filename="latest.pth"):
84
+ torch.save(state, os.path.join(path, filename))
85
+
86
+
87
+ def save_tensor_img(tenor_im, path):
88
+ im = tenor_im.cpu().clone()
89
+ im = im.squeeze(0)
90
+ tensor2pil = transforms.ToPILImage()
91
+ im = tensor2pil(im)
92
+ im.save(path)
93
+
94
+
95
+ def set_seed(seed):
96
+ torch.manual_seed(seed)
97
+ torch.cuda.manual_seed_all(seed)
98
+ np.random.seed(seed)
99
+ random.seed(seed)
100
+ torch.backends.cudnn.deterministic = True