Spaces:
Runtime error
Runtime error
Upload 12 files
Browse files- eval_existingOnes.py +73 -0
- gen_best_ep.py +85 -0
- inference.py +120 -0
- loss.py +248 -0
- make_a_copy.sh +16 -0
- rm_cache.sh +25 -0
- sub.sh +17 -0
- test.sh +25 -0
- train.py +262 -0
- train.sh +42 -0
- train_test.sh +12 -0
- utils.py +100 -0
eval_existingOnes.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
from glob import glob
|
| 4 |
+
import prettytable as pt
|
| 5 |
+
|
| 6 |
+
from evaluation.metrics import evaluator, sort_and_round_scores
|
| 7 |
+
from config import Config
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
config = Config()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def do_eval(args):
|
| 14 |
+
task_to_field_names = {
|
| 15 |
+
'DIS5K': ["Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "HCE", "maxEm", "meanFm", "adpEm", "adpFm", 'mBA', 'maxBIoU', 'meanBIoU'],
|
| 16 |
+
'COD': ["Dataset", "Method", "Smeasure", "wFmeasure", "meanFm", "meanEm", "maxEm", 'MAE', "maxFm", "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU'],
|
| 17 |
+
'HRSOD': ["Dataset", "Method", "Smeasure", "maxFm", "meanEm", 'MAE', "maxEm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU'],
|
| 18 |
+
'General': ["Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "HCE", "maxEm", "meanFm", "adpEm", "adpFm", 'mBA', 'maxBIoU', 'meanBIoU'],
|
| 19 |
+
'Matting': ["Dataset", "Method", "Smeasure", "maxFm", "meanEm", 'MSE', "maxEm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU'],
|
| 20 |
+
'General-2K': ["Dataset", "Method", "maxFm", "wFmeasure", 'MAE', "Smeasure", "meanEm", "HCE", "maxEm", "meanFm", "adpEm", "adpFm", 'mBA', 'maxBIoU', 'meanBIoU'],
|
| 21 |
+
'Others': ["Dataset", "Method", "Smeasure", 'MAE', "maxEm", "meanEm", "maxFm", "meanFm", "wFmeasure", "adpEm", "adpFm", "HCE", 'mBA', 'maxBIoU', 'meanBIoU'],
|
| 22 |
+
}
|
| 23 |
+
for data_name in args.data_lst.split('+'):
|
| 24 |
+
print('#' * 20, data_name, '#' * 20)
|
| 25 |
+
if not glob(os.path.join(args.pred_root, args.model_lst[0], data_name)):
|
| 26 |
+
print('Skip dataset {}.'.format(data_name))
|
| 27 |
+
continue
|
| 28 |
+
gt_paths = sorted(glob(os.path.join(args.gt_root, data_name, 'gt', '*')))
|
| 29 |
+
|
| 30 |
+
tb = pt.PrettyTable()
|
| 31 |
+
tb.vertical_char = '&'
|
| 32 |
+
tb.field_names = task_to_field_names[config.task] if config.task in task_to_field_names else task_to_field_names['Others']
|
| 33 |
+
for model_name in args.model_lst[:]:
|
| 34 |
+
print('\t', 'Evaluating model: {}...'.format(model_name))
|
| 35 |
+
pred_paths = [p.replace(args.gt_root, os.path.join(args.pred_root, model_name)).replace('/gt/', '/') for p in gt_paths]
|
| 36 |
+
|
| 37 |
+
em, sm, fm, mae, mse, wfm, hce, mba, biou = evaluator(
|
| 38 |
+
gt_paths=gt_paths,
|
| 39 |
+
pred_paths=pred_paths,
|
| 40 |
+
metrics=args.metrics.split('+'),
|
| 41 |
+
verbose=config.verbose_eval,
|
| 42 |
+
num_workers=min(8, int(os.cpu_count() * 0.9)),
|
| 43 |
+
)
|
| 44 |
+
scores = sort_and_round_scores(config.task, [em, sm, fm, mae, mse, wfm, hce, mba, biou])
|
| 45 |
+
for idx_score, score in enumerate(scores):
|
| 46 |
+
scores[idx_score] = '.' + format(score, '.3f').split('.')[-1] if score <= 1 else format(score, '<4')
|
| 47 |
+
records = [data_name, model_name] + scores
|
| 48 |
+
tb.add_row(records)
|
| 49 |
+
os.makedirs(args.save_dir, exist_ok=True)
|
| 50 |
+
with open(os.path.join(args.save_dir, '{}_eval.txt'.format(data_name)), 'w+') as file_to_write:
|
| 51 |
+
file_to_write.write(str(tb)+'\n')
|
| 52 |
+
print(tb)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
if __name__ == '__main__':
|
| 56 |
+
parser = argparse.ArgumentParser()
|
| 57 |
+
parser.add_argument('--gt_root', type=str, help='ground-truth root', default=os.path.join(config.data_root_dir, config.task))
|
| 58 |
+
parser.add_argument('--pred_root', type=str, help='prediction root', default='./e_preds')
|
| 59 |
+
parser.add_argument('--data_lst', type=str, help='test datasets', default=config.testsets.replace(',', '+'))
|
| 60 |
+
parser.add_argument('--save_dir', type=str, help='directory to save results', default='e_results')
|
| 61 |
+
parser.add_argument('--metrics', type=str, help='candidate competitors', default='+'.join(['S', 'MAE']))
|
| 62 |
+
args = parser.parse_args()
|
| 63 |
+
|
| 64 |
+
if args.metrics == 'all':
|
| 65 |
+
args.metrics = '+'.join(['S', 'MAE', 'E', 'F', 'WF', 'MBA', 'BIoU', 'MSE', 'HCE'][:100 if sum(['DIS-' in _data for _data in args.data_lst.split('+')]) else -1])
|
| 66 |
+
|
| 67 |
+
try:
|
| 68 |
+
args.model_lst = [m for m in sorted(os.listdir(args.pred_root), key=lambda x: int(x.split('epoch_')[-1].split('-')[0]), reverse=True) if int(m.split('epoch_')[-1].split('-')[0]) % 1 == 0]
|
| 69 |
+
except Exception as e:
|
| 70 |
+
print(f"Exception: {type(e).__name__} at line {e.__traceback__.tb_lineno} of {__file__}: {e}")
|
| 71 |
+
args.model_lst = [m for m in sorted(os.listdir(args.pred_root))]
|
| 72 |
+
|
| 73 |
+
do_eval(args)
|
gen_best_ep.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from glob import glob
|
| 3 |
+
import numpy as np
|
| 4 |
+
from config import Config
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
config = Config()
|
| 8 |
+
|
| 9 |
+
eval_txts = sorted(glob('e_results/*_eval.txt'))
|
| 10 |
+
print('eval_txts:', [_.split(os.sep)[-1] for _ in eval_txts])
|
| 11 |
+
score_panel = {}
|
| 12 |
+
sep = '&'
|
| 13 |
+
metrics = ['sm', 'wfm', 'hce'] # we used HCE for DIS and wFm for others.
|
| 14 |
+
if 'DIS5K' not in config.task:
|
| 15 |
+
metrics.remove('hce')
|
| 16 |
+
|
| 17 |
+
for metric in metrics:
|
| 18 |
+
print('Metric:', metric)
|
| 19 |
+
current_line_nums = []
|
| 20 |
+
for idx_et, eval_txt in enumerate(eval_txts):
|
| 21 |
+
with open(eval_txt, 'r') as f:
|
| 22 |
+
lines = [l for l in f.readlines()[3:] if '.' in l]
|
| 23 |
+
current_line_nums.append(len(lines))
|
| 24 |
+
for idx_et, eval_txt in enumerate(eval_txts):
|
| 25 |
+
with open(eval_txt, 'r') as f:
|
| 26 |
+
lines = [l for l in f.readlines()[3:] if '.' in l]
|
| 27 |
+
for idx_line, line in enumerate(lines[:min(current_line_nums)]): # Consist line numbers by the minimal result file.
|
| 28 |
+
properties = line.strip().strip(sep).split(sep)
|
| 29 |
+
dataset = properties[0].strip()
|
| 30 |
+
ckpt = properties[1].strip()
|
| 31 |
+
if int(ckpt.split('--epoch_')[-1].strip()) < 0:
|
| 32 |
+
continue
|
| 33 |
+
targe_idx = {
|
| 34 |
+
'sm': [5, 2, 2, 5, 5, 2],
|
| 35 |
+
'wfm': [3, 3, 8, 3, 3, 8],
|
| 36 |
+
'hce': [7, -1, -1, 7, 7, -1]
|
| 37 |
+
}[metric][['DIS5K', 'COD', 'HRSOD', 'General', 'General-2K', 'Matting'].index(config.task)]
|
| 38 |
+
if metric != 'hce':
|
| 39 |
+
score_sm = float(properties[targe_idx].strip())
|
| 40 |
+
else:
|
| 41 |
+
score_sm = int(properties[targe_idx].strip().strip('.'))
|
| 42 |
+
if idx_et == 0:
|
| 43 |
+
score_panel[ckpt] = []
|
| 44 |
+
score_panel[ckpt].append(score_sm)
|
| 45 |
+
|
| 46 |
+
metrics_min = ['hce', 'mae']
|
| 47 |
+
max_or_min = min if metric in metrics_min else max
|
| 48 |
+
score_max = max_or_min(score_panel.values(), key=lambda x: np.sum(x))
|
| 49 |
+
|
| 50 |
+
good_models = []
|
| 51 |
+
for k, v in score_panel.items():
|
| 52 |
+
if (np.sum(v) <= np.sum(score_max)) if metric in metrics_min else (np.sum(v) >= np.sum(score_max)):
|
| 53 |
+
print(k, v)
|
| 54 |
+
good_models.append(k)
|
| 55 |
+
|
| 56 |
+
# Write
|
| 57 |
+
with open(eval_txt, 'r') as f:
|
| 58 |
+
lines = f.readlines()
|
| 59 |
+
info4good_models = lines[:3]
|
| 60 |
+
metric_names = [m.strip() for m in lines[1].strip().strip('&').split('&')[2:]]
|
| 61 |
+
testset_mean_values = {metric_name: [] for metric_name in metric_names}
|
| 62 |
+
for good_model in good_models:
|
| 63 |
+
for idx_et, eval_txt in enumerate(eval_txts):
|
| 64 |
+
with open(eval_txt, 'r') as f:
|
| 65 |
+
lines = f.readlines()
|
| 66 |
+
for line in lines:
|
| 67 |
+
if set([good_model]) & set([_.strip() for _ in line.split(sep)]):
|
| 68 |
+
info4good_models.append(line)
|
| 69 |
+
metric_scores = [float(m.strip()) for m in line.strip().strip('&').split('&')[2:]]
|
| 70 |
+
for idx_score, metric_score in enumerate(metric_scores):
|
| 71 |
+
testset_mean_values[metric_names[idx_score]].append(metric_score)
|
| 72 |
+
|
| 73 |
+
if 'DIS5K' in config.task:
|
| 74 |
+
testset_mean_values_lst = ['{:<4}'.format(int(np.mean(v_lst[:-1]).round())) if name == 'HCE' else '{:.3f}'.format(np.mean(v_lst[:-1])).lstrip('0') for name, v_lst in testset_mean_values.items()] # [:-1] to remove DIS-VD
|
| 75 |
+
sample_line_for_placing_mean_values = info4good_models[-2]
|
| 76 |
+
numbers_placed_well = sample_line_for_placing_mean_values.replace(sample_line_for_placing_mean_values.split('&')[1].strip(), 'DIS-TEs').strip().split('&')[3:]
|
| 77 |
+
for idx_number, (number_placed_well, testset_mean_value) in enumerate(zip(numbers_placed_well, testset_mean_values_lst)):
|
| 78 |
+
numbers_placed_well[idx_number] = number_placed_well.replace(number_placed_well.strip(), testset_mean_value)
|
| 79 |
+
testset_mean_line = '&'.join(sample_line_for_placing_mean_values.replace(sample_line_for_placing_mean_values.split('&')[1].strip(), 'DIS-TEs').split('&')[:3] + numbers_placed_well) + '\n'
|
| 80 |
+
info4good_models.append(testset_mean_line)
|
| 81 |
+
info4good_models.append(lines[-1])
|
| 82 |
+
info = ''.join(info4good_models)
|
| 83 |
+
print(info)
|
| 84 |
+
with open(os.path.join('e_results', 'eval-{}_best_on_{}.txt'.format(config.task, metric)), 'w') as f:
|
| 85 |
+
f.write(info + '\n')
|
inference.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
from glob import glob
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import cv2
|
| 6 |
+
import torch
|
| 7 |
+
from contextlib import nullcontext
|
| 8 |
+
|
| 9 |
+
from dataset import MyData
|
| 10 |
+
from models.birefnet import BiRefNet
|
| 11 |
+
from utils import save_tensor_img, check_state_dict
|
| 12 |
+
from config import Config
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
config = Config()
|
| 16 |
+
|
| 17 |
+
mixed_precision = config.mixed_precision
|
| 18 |
+
if mixed_precision == 'fp16':
|
| 19 |
+
mixed_dtype = torch.float16
|
| 20 |
+
elif mixed_precision == 'bf16':
|
| 21 |
+
mixed_dtype = torch.bfloat16
|
| 22 |
+
else:
|
| 23 |
+
mixed_dtype = None
|
| 24 |
+
|
| 25 |
+
autocast_ctx = torch.amp.autocast(device_type='cuda', dtype=mixed_dtype) if mixed_dtype else nullcontext()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def inference(model, data_loader_test, pred_root, method, testset, device=0):
|
| 29 |
+
model_training = model.training
|
| 30 |
+
if model_training:
|
| 31 |
+
model.eval()
|
| 32 |
+
for batch in tqdm(data_loader_test, total=len(data_loader_test)) if config.verbose_eval else data_loader_test:
|
| 33 |
+
inputs = batch[0].to(device)
|
| 34 |
+
label_paths = batch[-1]
|
| 35 |
+
with autocast_ctx, torch.no_grad():
|
| 36 |
+
scaled_preds = model(inputs)[-1].sigmoid().to(torch.float32)
|
| 37 |
+
|
| 38 |
+
os.makedirs(os.path.join(pred_root, method, testset), exist_ok=True)
|
| 39 |
+
|
| 40 |
+
for idx_sample in range(scaled_preds.shape[0]):
|
| 41 |
+
res = torch.nn.functional.interpolate(
|
| 42 |
+
scaled_preds[idx_sample].unsqueeze(0),
|
| 43 |
+
size=cv2.imread(label_paths[idx_sample], cv2.IMREAD_GRAYSCALE).shape[:2],
|
| 44 |
+
mode='bilinear',
|
| 45 |
+
align_corners=True
|
| 46 |
+
)
|
| 47 |
+
save_tensor_img(res, os.path.join(os.path.join(pred_root, method, testset), label_paths[idx_sample].replace('\\', '/').split('/')[-1])) # test set dir + file name
|
| 48 |
+
if model_training:
|
| 49 |
+
model.train()
|
| 50 |
+
return None
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def main(args):
|
| 54 |
+
device = config.device
|
| 55 |
+
if args.ckpt_folder:
|
| 56 |
+
print('Testing with models in {}'.format(args.ckpt_folder))
|
| 57 |
+
else:
|
| 58 |
+
print('Testing with model {}'.format(args.ckpt))
|
| 59 |
+
|
| 60 |
+
if config.model == 'BiRefNet':
|
| 61 |
+
model = BiRefNet(bb_pretrained=False)
|
| 62 |
+
else:
|
| 63 |
+
print('Undefined model: {}.'.format(config.model))
|
| 64 |
+
return None
|
| 65 |
+
weights_lst = sorted(
|
| 66 |
+
glob(os.path.join(args.ckpt_folder, '*.pth')) if args.ckpt_folder else [args.ckpt],
|
| 67 |
+
key=lambda x: int(x.split('epoch_')[-1].split('.pth')[0]),
|
| 68 |
+
reverse=True
|
| 69 |
+
)
|
| 70 |
+
try:
|
| 71 |
+
if args.resolution in [None, 'None', 0, '']:
|
| 72 |
+
# Use original resolution for inference.
|
| 73 |
+
data_size = None
|
| 74 |
+
elif args.resolution in ['config.size']:
|
| 75 |
+
data_size = config.size
|
| 76 |
+
else:
|
| 77 |
+
data_size = [int(l) for l in args.resolution.split('x')]
|
| 78 |
+
except Exception as e:
|
| 79 |
+
print(f"Exception: {type(e).__name__} at line {e.__traceback__.tb_lineno} of {__file__}: {e}")
|
| 80 |
+
# default as the config.size.
|
| 81 |
+
data_size = config.size
|
| 82 |
+
|
| 83 |
+
for testset in args.testsets.split('+'):
|
| 84 |
+
print('>>>> Testset: {}...'.format(testset))
|
| 85 |
+
data_loader_test = torch.utils.data.DataLoader(
|
| 86 |
+
dataset=MyData(testset, data_size=data_size, is_train=False),
|
| 87 |
+
batch_size=config.batch_size_valid, shuffle=False, num_workers=config.num_workers, pin_memory=True
|
| 88 |
+
)
|
| 89 |
+
for weights in weights_lst:
|
| 90 |
+
if int(weights.strip('.pth').split('epoch_')[-1]) % 1 != 0:
|
| 91 |
+
continue
|
| 92 |
+
print('\tInferencing {}...'.format(weights))
|
| 93 |
+
state_dict = torch.load(weights, map_location='cpu', weights_only=True)
|
| 94 |
+
state_dict = check_state_dict(state_dict)
|
| 95 |
+
model.load_state_dict(state_dict)
|
| 96 |
+
model = model.to(device)
|
| 97 |
+
inference(
|
| 98 |
+
model, data_loader_test=data_loader_test, pred_root=args.pred_root,
|
| 99 |
+
method='--'.join([w.rstrip('.pth') for w in weights.split(os.sep)[-2:]]) + '-reso_{}'.format('x'.join([str(s) for s in data_size])),
|
| 100 |
+
testset=testset, device=config.device
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
if __name__ == '__main__':
|
| 105 |
+
# Parameter from command line
|
| 106 |
+
parser = argparse.ArgumentParser(description='')
|
| 107 |
+
parser.add_argument('--ckpt', type=str, help='model folder')
|
| 108 |
+
parser.add_argument('--ckpt_folder', default=sorted(glob(os.path.join('ckpts', '*')))[-1], type=str, help='model folder')
|
| 109 |
+
parser.add_argument('--pred_root', default='e_preds', type=str, help='Output folder')
|
| 110 |
+
parser.add_argument('--resolution', default='default', type=str, help='WeixHei')
|
| 111 |
+
parser.add_argument('--testsets',
|
| 112 |
+
default=config.testsets.replace(',', '+'),
|
| 113 |
+
type=str,
|
| 114 |
+
help="Test all sets: DIS5K -> 'DIS-VD+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'")
|
| 115 |
+
|
| 116 |
+
args = parser.parse_args()
|
| 117 |
+
|
| 118 |
+
if config.precisionHigh:
|
| 119 |
+
torch.set_float32_matmul_precision('high')
|
| 120 |
+
main(args)
|
loss.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.autograd import Variable
|
| 5 |
+
from math import exp
|
| 6 |
+
from config import Config
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ContourLoss(torch.nn.Module):
|
| 10 |
+
def __init__(self):
|
| 11 |
+
super(ContourLoss, self).__init__()
|
| 12 |
+
|
| 13 |
+
def forward(self, pred, target, weight=10):
|
| 14 |
+
'''
|
| 15 |
+
target, pred: tensor of shape (B, C, H, W), where target[:,:,region_in_contour] == 1,
|
| 16 |
+
target[:,:,region_out_contour] == 0.
|
| 17 |
+
weight: scalar, length term weight.
|
| 18 |
+
'''
|
| 19 |
+
# length term
|
| 20 |
+
delta_r = pred[:,:,1:,:] - pred[:,:,:-1,:] # horizontal gradient (B, C, H-1, W)
|
| 21 |
+
delta_c = pred[:,:,:,1:] - pred[:,:,:,:-1] # vertical gradient (B, C, H, W-1)
|
| 22 |
+
|
| 23 |
+
delta_r = delta_r[:,:,1:,:-2]**2 # (B, C, H-2, W-2)
|
| 24 |
+
delta_c = delta_c[:,:,:-2,1:]**2 # (B, C, H-2, W-2)
|
| 25 |
+
delta_pred = torch.abs(delta_r + delta_c)
|
| 26 |
+
|
| 27 |
+
epsilon = 1e-8 # where is a parameter to avoid square root is zero in practice.
|
| 28 |
+
length = torch.mean(torch.sqrt(delta_pred + epsilon)) # eq.(11) in the paper, mean is used instead of sum.
|
| 29 |
+
|
| 30 |
+
c_in = torch.ones_like(pred)
|
| 31 |
+
c_out = torch.zeros_like(pred)
|
| 32 |
+
|
| 33 |
+
region_in = torch.mean( pred * (target - c_in )**2 ) # equ.(12) in the paper, mean is used instead of sum.
|
| 34 |
+
region_out = torch.mean( (1-pred) * (target - c_out)**2 )
|
| 35 |
+
region = region_in + region_out
|
| 36 |
+
|
| 37 |
+
loss = weight * length + region
|
| 38 |
+
|
| 39 |
+
return loss
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class IoULoss(torch.nn.Module):
|
| 43 |
+
def __init__(self):
|
| 44 |
+
super(IoULoss, self).__init__()
|
| 45 |
+
|
| 46 |
+
def forward(self, pred, target):
|
| 47 |
+
b = pred.shape[0]
|
| 48 |
+
IoU = 0.0
|
| 49 |
+
for i in range(0, b):
|
| 50 |
+
# compute the IoU of the foreground
|
| 51 |
+
Iand1 = torch.sum(target[i, :, :, :] * pred[i, :, :, :])
|
| 52 |
+
Ior1 = torch.sum(target[i, :, :, :]) + torch.sum(pred[i, :, :, :]) - Iand1
|
| 53 |
+
IoU1 = Iand1 / Ior1
|
| 54 |
+
# IoU loss is (1-IoU1)
|
| 55 |
+
IoU = IoU + (1-IoU1)
|
| 56 |
+
# return IoU/b
|
| 57 |
+
return IoU
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class StructureLoss(torch.nn.Module):
|
| 61 |
+
def __init__(self):
|
| 62 |
+
super(StructureLoss, self).__init__()
|
| 63 |
+
|
| 64 |
+
def forward(self, pred, target):
|
| 65 |
+
weit = 1+5*torch.abs(F.avg_pool2d(target, kernel_size=31, stride=1, padding=15)-target)
|
| 66 |
+
wbce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
|
| 67 |
+
wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3))
|
| 68 |
+
|
| 69 |
+
pred = torch.sigmoid(pred)
|
| 70 |
+
inter = ((pred * target) * weit).sum(dim=(2, 3))
|
| 71 |
+
union = ((pred + target) * weit).sum(dim=(2, 3))
|
| 72 |
+
wiou = 1-(inter+1)/(union-inter+1)
|
| 73 |
+
|
| 74 |
+
return (wbce+wiou).mean()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class PatchIoULoss(torch.nn.Module):
|
| 78 |
+
def __init__(self):
|
| 79 |
+
super(PatchIoULoss, self).__init__()
|
| 80 |
+
self.iou_loss = IoULoss()
|
| 81 |
+
|
| 82 |
+
def forward(self, pred, target):
|
| 83 |
+
win_y, win_x = 64, 64
|
| 84 |
+
iou_loss = 0.
|
| 85 |
+
for anchor_y in range(0, target.shape[0], win_y):
|
| 86 |
+
for anchor_x in range(0, target.shape[1], win_y):
|
| 87 |
+
patch_pred = pred[:, :, anchor_y:anchor_y+win_y, anchor_x:anchor_x+win_x]
|
| 88 |
+
patch_target = target[:, :, anchor_y:anchor_y+win_y, anchor_x:anchor_x+win_x]
|
| 89 |
+
patch_iou_loss = self.iou_loss(patch_pred, patch_target)
|
| 90 |
+
iou_loss += patch_iou_loss
|
| 91 |
+
return iou_loss
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class ThrReg_loss(torch.nn.Module):
|
| 95 |
+
def __init__(self):
|
| 96 |
+
super(ThrReg_loss, self).__init__()
|
| 97 |
+
|
| 98 |
+
def forward(self, pred, gt=None):
|
| 99 |
+
return torch.mean(1 - ((pred - 0) ** 2 + (pred - 1) ** 2))
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class ClsLoss(nn.Module):
|
| 103 |
+
"""
|
| 104 |
+
Auxiliary classification loss for each refined class output.
|
| 105 |
+
"""
|
| 106 |
+
def __init__(self):
|
| 107 |
+
super(ClsLoss, self).__init__()
|
| 108 |
+
self.config = Config()
|
| 109 |
+
self.lambdas_cls = self.config.lambdas_cls
|
| 110 |
+
|
| 111 |
+
self.criterions_last = {
|
| 112 |
+
'ce': nn.CrossEntropyLoss()
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
def forward(self, preds, gt):
|
| 116 |
+
loss = 0.
|
| 117 |
+
for _, pred_lvl in enumerate(preds):
|
| 118 |
+
if pred_lvl is None:
|
| 119 |
+
continue
|
| 120 |
+
for criterion_name, criterion in self.criterions_last.items():
|
| 121 |
+
loss += criterion(pred_lvl, gt) * self.lambdas_cls[criterion_name]
|
| 122 |
+
return loss
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class PixLoss(nn.Module):
|
| 126 |
+
"""
|
| 127 |
+
Pixel loss for each refined map output.
|
| 128 |
+
"""
|
| 129 |
+
def __init__(self):
|
| 130 |
+
super(PixLoss, self).__init__()
|
| 131 |
+
self.config = Config()
|
| 132 |
+
self.lambdas_pix_last = self.config.lambdas_pix_last
|
| 133 |
+
|
| 134 |
+
self.criterions_last = {}
|
| 135 |
+
if 'bce' in self.lambdas_pix_last and self.lambdas_pix_last['bce']:
|
| 136 |
+
self.criterions_last['bce'] = nn.BCELoss()
|
| 137 |
+
if 'iou' in self.lambdas_pix_last and self.lambdas_pix_last['iou']:
|
| 138 |
+
self.criterions_last['iou'] = IoULoss()
|
| 139 |
+
if 'iou_patch' in self.lambdas_pix_last and self.lambdas_pix_last['iou_patch']:
|
| 140 |
+
self.criterions_last['iou_patch'] = PatchIoULoss()
|
| 141 |
+
if 'ssim' in self.lambdas_pix_last and self.lambdas_pix_last['ssim']:
|
| 142 |
+
self.criterions_last['ssim'] = SSIMLoss()
|
| 143 |
+
if 'mae' in self.lambdas_pix_last and self.lambdas_pix_last['mae']:
|
| 144 |
+
self.criterions_last['mae'] = nn.L1Loss()
|
| 145 |
+
if 'mse' in self.lambdas_pix_last and self.lambdas_pix_last['mse']:
|
| 146 |
+
self.criterions_last['mse'] = nn.MSELoss()
|
| 147 |
+
if 'reg' in self.lambdas_pix_last and self.lambdas_pix_last['reg']:
|
| 148 |
+
self.criterions_last['reg'] = ThrReg_loss()
|
| 149 |
+
if 'cnt' in self.lambdas_pix_last and self.lambdas_pix_last['cnt']:
|
| 150 |
+
self.criterions_last['cnt'] = ContourLoss()
|
| 151 |
+
if 'structure' in self.lambdas_pix_last and self.lambdas_pix_last['structure']:
|
| 152 |
+
self.criterions_last['structure'] = StructureLoss()
|
| 153 |
+
|
| 154 |
+
def forward(self, scaled_preds, gt, pix_loss_lambda=1.0):
|
| 155 |
+
loss = 0.
|
| 156 |
+
loss_dict = {}
|
| 157 |
+
for _, pred_lvl in enumerate(scaled_preds):
|
| 158 |
+
if pred_lvl.shape != gt.shape:
|
| 159 |
+
pred_lvl = nn.functional.interpolate(pred_lvl, size=gt.shape[2:], mode='bilinear', align_corners=True)
|
| 160 |
+
for criterion_name, criterion in self.criterions_last.items():
|
| 161 |
+
_loss = criterion(pred_lvl.sigmoid(), gt) * self.lambdas_pix_last[criterion_name] * pix_loss_lambda
|
| 162 |
+
loss += _loss
|
| 163 |
+
loss_dict[criterion_name] = loss_dict.get(criterion_name, 0.) + _loss.item() / len(scaled_preds)
|
| 164 |
+
# print(criterion_name, _loss.item())
|
| 165 |
+
return loss, loss_dict
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class SSIMLoss(torch.nn.Module):
|
| 169 |
+
def __init__(self, window_size=11, size_average=True):
|
| 170 |
+
super(SSIMLoss, self).__init__()
|
| 171 |
+
self.window_size = window_size
|
| 172 |
+
self.size_average = size_average
|
| 173 |
+
self.channel = 1
|
| 174 |
+
self.window = create_window(window_size, self.channel)
|
| 175 |
+
|
| 176 |
+
def forward(self, img1, img2):
|
| 177 |
+
(_, channel, _, _) = img1.size()
|
| 178 |
+
if channel == self.channel and self.window.data.type() == img1.data.type():
|
| 179 |
+
window = self.window
|
| 180 |
+
else:
|
| 181 |
+
window = create_window(self.window_size, channel)
|
| 182 |
+
if img1.is_cuda:
|
| 183 |
+
window = window.cuda(img1.get_device())
|
| 184 |
+
window = window.type_as(img1)
|
| 185 |
+
self.window = window
|
| 186 |
+
self.channel = channel
|
| 187 |
+
return 1 - (1 + _ssim(img1, img2, window, self.window_size, channel, self.size_average)) / 2
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def gaussian(window_size, sigma):
|
| 191 |
+
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
|
| 192 |
+
return gauss/gauss.sum()
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def create_window(window_size, channel):
|
| 196 |
+
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
| 197 |
+
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
| 198 |
+
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
|
| 199 |
+
return window
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def _ssim(img1, img2, window, window_size, channel, size_average=True):
|
| 203 |
+
mu1 = F.conv2d(img1, window, padding = window_size//2, groups=channel)
|
| 204 |
+
mu2 = F.conv2d(img2, window, padding = window_size//2, groups=channel)
|
| 205 |
+
|
| 206 |
+
mu1_sq = mu1.pow(2)
|
| 207 |
+
mu2_sq = mu2.pow(2)
|
| 208 |
+
mu1_mu2 = mu1*mu2
|
| 209 |
+
|
| 210 |
+
sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=channel) - mu1_sq
|
| 211 |
+
sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=channel) - mu2_sq
|
| 212 |
+
sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=channel) - mu1_mu2
|
| 213 |
+
|
| 214 |
+
C1 = 0.01**2
|
| 215 |
+
C2 = 0.03**2
|
| 216 |
+
|
| 217 |
+
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
|
| 218 |
+
|
| 219 |
+
if size_average:
|
| 220 |
+
return ssim_map.mean()
|
| 221 |
+
else:
|
| 222 |
+
return ssim_map.mean(1).mean(1).mean(1)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def SSIM(x, y):
|
| 226 |
+
C1 = 0.01 ** 2
|
| 227 |
+
C2 = 0.03 ** 2
|
| 228 |
+
|
| 229 |
+
mu_x = nn.AvgPool2d(3, 1, 1)(x)
|
| 230 |
+
mu_y = nn.AvgPool2d(3, 1, 1)(y)
|
| 231 |
+
mu_x_mu_y = mu_x * mu_y
|
| 232 |
+
mu_x_sq = mu_x.pow(2)
|
| 233 |
+
mu_y_sq = mu_y.pow(2)
|
| 234 |
+
|
| 235 |
+
sigma_x = nn.AvgPool2d(3, 1, 1)(x * x) - mu_x_sq
|
| 236 |
+
sigma_y = nn.AvgPool2d(3, 1, 1)(y * y) - mu_y_sq
|
| 237 |
+
sigma_xy = nn.AvgPool2d(3, 1, 1)(x * y) - mu_x_mu_y
|
| 238 |
+
|
| 239 |
+
SSIM_n = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2)
|
| 240 |
+
SSIM_d = (mu_x_sq + mu_y_sq + C1) * (sigma_x + sigma_y + C2)
|
| 241 |
+
SSIM = SSIM_n / SSIM_d
|
| 242 |
+
|
| 243 |
+
return torch.clamp((1 - SSIM) / 2, 0, 1)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def saliency_structure_consistency(x, y):
|
| 247 |
+
ssim = torch.mean(SSIM(x,y))
|
| 248 |
+
return ssim
|
make_a_copy.sh
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Set dst repo here.
|
| 3 |
+
repo=$1
|
| 4 |
+
mkdir ../${repo}
|
| 5 |
+
mkdir ../${repo}/evaluation
|
| 6 |
+
mkdir ../${repo}/models
|
| 7 |
+
mkdir ../${repo}/models/backbones
|
| 8 |
+
mkdir ../${repo}/models/modules
|
| 9 |
+
|
| 10 |
+
cp ./*.sh ../${repo}
|
| 11 |
+
cp ./*.py ../${repo}
|
| 12 |
+
cp ./evaluation/*.py ../${repo}/evaluation
|
| 13 |
+
cp ./models/*.py ../${repo}/models
|
| 14 |
+
cp ./models/backbones/*.py ../${repo}/models/backbones
|
| 15 |
+
cp ./models/modules/*.py ../${repo}/models/modules
|
| 16 |
+
cp -r ./.git* ../${repo}
|
rm_cache.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
rm -rf __pycache__ */__pycache__ */*/__pycache__
|
| 3 |
+
|
| 4 |
+
# Val
|
| 5 |
+
rm -r tmp*
|
| 6 |
+
|
| 7 |
+
# Train
|
| 8 |
+
rm slurm*
|
| 9 |
+
rm -r ckpts
|
| 10 |
+
rm nohup.out*
|
| 11 |
+
rm nohup.log*
|
| 12 |
+
|
| 13 |
+
# Eval
|
| 14 |
+
rm -r evaluation/eval-*
|
| 15 |
+
rm -r tmp*
|
| 16 |
+
rm -r e_logs/
|
| 17 |
+
|
| 18 |
+
# System
|
| 19 |
+
rm core-*-python-*
|
| 20 |
+
|
| 21 |
+
# Inference cache
|
| 22 |
+
rm -rf images_todo/
|
| 23 |
+
rm -rf predictions/
|
| 24 |
+
|
| 25 |
+
clear
|
sub.sh
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Example: ./sub.sh tmp_proj 0,1,2,3 3 --> Use 0,1,2,3 for training, release GPUs, use GPU:3 for inference.
|
| 3 |
+
|
| 4 |
+
# module load gcc/11.2.0 cuda/11.8 cudnn/8.6.0_cu11x && cpu_core_num=6
|
| 5 |
+
module load compilers/cuda/11.8 compilers/gcc/12.2.0 cudnn/8.4.0.27_cuda11.x && cpu_core_num=32
|
| 6 |
+
|
| 7 |
+
export PYTHONUNBUFFERED=1
|
| 8 |
+
|
| 9 |
+
method=${1:-"BSL"}
|
| 10 |
+
devices=${2:-"0,1"}
|
| 11 |
+
gpu_num=$(($(echo ${devices%%,} | grep -o "," | wc -l)+1))
|
| 12 |
+
|
| 13 |
+
sbatch --nodes=1 -p vip_gpu_ailab -A ai4bio \
|
| 14 |
+
--gres=gpu:${gpu_num} --ntasks-per-node=1 --cpus-per-task=$((gpu_num*cpu_core_num)) \
|
| 15 |
+
./train_test.sh ${method} ${devices}
|
| 16 |
+
|
| 17 |
+
hostname
|
test.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
devices=${1:-0}
|
| 2 |
+
pred_root=${2:-e_preds}
|
| 3 |
+
resolutions=${3:-"config.size"}
|
| 4 |
+
|
| 5 |
+
# Inference
|
| 6 |
+
# resolutions="1024x1024 None"
|
| 7 |
+
for resolution in ${resolutions}; do
|
| 8 |
+
CUDA_VISIBLE_DEVICES=${devices} python inference.py --pred_root ${pred_root} --resolution ${resolution}
|
| 9 |
+
done
|
| 10 |
+
|
| 11 |
+
echo Inference finished at $(date)
|
| 12 |
+
|
| 13 |
+
# Evaluation
|
| 14 |
+
log_dir=e_logs && mkdir ${log_dir}
|
| 15 |
+
|
| 16 |
+
task=$(python3 config.py --print_task)
|
| 17 |
+
testsets=$(python3 config.py --print_testsets)
|
| 18 |
+
|
| 19 |
+
testsets=(`echo ${testsets} | tr ',' ' '`) && testsets=${testsets[@]}
|
| 20 |
+
|
| 21 |
+
for testset in ${testsets}; do
|
| 22 |
+
python eval_existingOnes.py --pred_root ${pred_root} --data_lst ${testset} --metrics 'all' > ${log_dir}/eval_${testset}.out
|
| 23 |
+
done
|
| 24 |
+
|
| 25 |
+
echo Evaluation started at $(date)
|
train.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import datetime
|
| 3 |
+
from contextlib import nullcontext
|
| 4 |
+
import argparse
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
if tuple(map(int, torch.__version__.split('+')[0].split(".")[:3])) >= (2, 5, 0):
|
| 9 |
+
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
| 10 |
+
|
| 11 |
+
from config import Config
|
| 12 |
+
from loss import PixLoss, ClsLoss
|
| 13 |
+
from dataset import MyData
|
| 14 |
+
from models.birefnet import BiRefNet
|
| 15 |
+
from utils import Logger, AverageMeter, set_seed, check_state_dict
|
| 16 |
+
|
| 17 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 18 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 19 |
+
from torch.distributed import init_process_group, destroy_process_group
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
parser = argparse.ArgumentParser(description='')
|
| 23 |
+
parser.add_argument('--resume', default=None, type=str, help='path to latest checkpoint')
|
| 24 |
+
parser.add_argument('--epochs', default=120, type=int)
|
| 25 |
+
parser.add_argument('--ckpt_dir', default='ckpts/tmp', help='Temporary folder')
|
| 26 |
+
parser.add_argument('--dist', default=False, type=lambda x: x == 'True')
|
| 27 |
+
parser.add_argument('--use_accelerate', action='store_true', help='`accelerate launch --multi_gpu train.py --use_accelerate`. Use accelerate for training, good for FP16/BF16/...')
|
| 28 |
+
args = parser.parse_args()
|
| 29 |
+
|
| 30 |
+
config = Config()
|
| 31 |
+
|
| 32 |
+
if args.use_accelerate:
|
| 33 |
+
from accelerate import Accelerator, utils
|
| 34 |
+
mixed_precision = config.mixed_precision
|
| 35 |
+
kwargs_handlers = [
|
| 36 |
+
utils.InitProcessGroupKwargs(backend="nccl", timeout=datetime.timedelta(seconds=3600*10)),
|
| 37 |
+
utils.DistributedDataParallelKwargs(find_unused_parameters=False),
|
| 38 |
+
utils.GradScalerKwargs(backoff_factor=0.5),
|
| 39 |
+
]
|
| 40 |
+
if mixed_precision == 'fp8':
|
| 41 |
+
kwargs_handlers.append(utils.AORecipeKwargs())
|
| 42 |
+
accelerator = Accelerator(
|
| 43 |
+
mixed_precision=mixed_precision,
|
| 44 |
+
gradient_accumulation_steps=1,
|
| 45 |
+
kwargs_handlers=kwargs_handlers,
|
| 46 |
+
)
|
| 47 |
+
accelerator.print(accelerator.state)
|
| 48 |
+
accelerator.print('backbone:', config.bb, ', freeze_bb:', config.freeze_bb)
|
| 49 |
+
args.dist = False
|
| 50 |
+
|
| 51 |
+
# DDP
|
| 52 |
+
to_be_distributed = args.dist
|
| 53 |
+
if to_be_distributed:
|
| 54 |
+
init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=3600*10))
|
| 55 |
+
device = int(os.environ["LOCAL_RANK"])
|
| 56 |
+
else:
|
| 57 |
+
if args.use_accelerate:
|
| 58 |
+
device = accelerator.local_process_index
|
| 59 |
+
else:
|
| 60 |
+
device = config.device
|
| 61 |
+
|
| 62 |
+
if config.rand_seed:
|
| 63 |
+
set_seed(config.rand_seed + device)
|
| 64 |
+
|
| 65 |
+
epoch_st = 1
|
| 66 |
+
# make dir for ckpt
|
| 67 |
+
os.makedirs(args.ckpt_dir, exist_ok=True)
|
| 68 |
+
|
| 69 |
+
# Init log file
|
| 70 |
+
logger = Logger(os.path.join(args.ckpt_dir, "log.txt"))
|
| 71 |
+
logger_loss_idx = 1
|
| 72 |
+
|
| 73 |
+
# log model and optimizer params
|
| 74 |
+
# logger.info("Model details:"); logger.info(model)
|
| 75 |
+
# if args.use_accelerate and accelerator.mixed_precision != 'no':
|
| 76 |
+
# config.compile = False
|
| 77 |
+
logger.info("datasets: load_all={}, compile={}.".format(config.load_all, config.compile))
|
| 78 |
+
logger.info("Other hyperparameters:"); logger.info(args)
|
| 79 |
+
print('batch size:', config.batch_size)
|
| 80 |
+
|
| 81 |
+
from dataset import custom_collate_fn
|
| 82 |
+
|
| 83 |
+
def prepare_dataloader(dataset: torch.utils.data.Dataset, batch_size: int, to_be_distributed=False, is_train=True):
|
| 84 |
+
# Prepare dataloaders
|
| 85 |
+
if to_be_distributed:
|
| 86 |
+
return torch.utils.data.DataLoader(
|
| 87 |
+
dataset=dataset, batch_size=batch_size, num_workers=min(config.num_workers, batch_size), pin_memory=True,
|
| 88 |
+
shuffle=False, sampler=DistributedSampler(dataset), drop_last=True, collate_fn=custom_collate_fn if is_train and config.dynamic_size else None
|
| 89 |
+
)
|
| 90 |
+
else:
|
| 91 |
+
return torch.utils.data.DataLoader(
|
| 92 |
+
dataset=dataset, batch_size=batch_size, num_workers=min(config.num_workers, batch_size), pin_memory=True,
|
| 93 |
+
shuffle=is_train, sampler=None, drop_last=True, collate_fn=custom_collate_fn if is_train and config.dynamic_size else None
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def init_data_loaders(to_be_distributed):
|
| 98 |
+
# Prepare datasets
|
| 99 |
+
train_loader = prepare_dataloader(
|
| 100 |
+
MyData(datasets=config.training_set, data_size=None if config.dynamic_size else config.size, is_train=True),
|
| 101 |
+
config.batch_size, to_be_distributed=to_be_distributed, is_train=True
|
| 102 |
+
)
|
| 103 |
+
print(len(train_loader), "batches of train dataloader {} have been created.".format(config.training_set))
|
| 104 |
+
return train_loader
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def init_models_optimizers(epochs, to_be_distributed):
|
| 108 |
+
# Init models
|
| 109 |
+
if config.model == 'BiRefNet':
|
| 110 |
+
model = BiRefNet(bb_pretrained=True and not os.path.isfile(str(args.resume)))
|
| 111 |
+
else:
|
| 112 |
+
print('Undefined model: {}.'.format(config.model))
|
| 113 |
+
return None
|
| 114 |
+
if args.resume:
|
| 115 |
+
if os.path.isfile(args.resume):
|
| 116 |
+
logger.info("=> loading checkpoint '{}'".format(args.resume))
|
| 117 |
+
state_dict = torch.load(args.resume, map_location='cpu', weights_only=True)
|
| 118 |
+
state_dict = check_state_dict(state_dict)
|
| 119 |
+
model.load_state_dict(state_dict)
|
| 120 |
+
global epoch_st
|
| 121 |
+
epoch_st = int(args.resume.rstrip('.pth').split('epoch_')[-1]) + 1
|
| 122 |
+
else:
|
| 123 |
+
logger.info("=> no checkpoint found at '{}'".format(args.resume))
|
| 124 |
+
if not args.use_accelerate:
|
| 125 |
+
if to_be_distributed:
|
| 126 |
+
model = model.to(device)
|
| 127 |
+
model = DDP(model, device_ids=[device])
|
| 128 |
+
else:
|
| 129 |
+
model = model.to(device)
|
| 130 |
+
if config.compile:
|
| 131 |
+
model = torch.compile(model, mode=['default', 'reduce-overhead', 'max-autotune'][0])
|
| 132 |
+
if config.precisionHigh:
|
| 133 |
+
torch.set_float32_matmul_precision('high')
|
| 134 |
+
|
| 135 |
+
# Setting optimizer
|
| 136 |
+
if config.optimizer == 'AdamW':
|
| 137 |
+
optimizer = optim.AdamW(params=[p for p in model.parameters() if p.requires_grad], lr=config.lr, weight_decay=1e-2)
|
| 138 |
+
elif config.optimizer == 'Adam':
|
| 139 |
+
optimizer = optim.Adam(params=[p for p in model.parameters() if p.requires_grad], lr=config.lr, weight_decay=0)
|
| 140 |
+
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
| 141 |
+
optimizer,
|
| 142 |
+
milestones=[lde if lde > 0 else epochs + lde + 1 for lde in config.lr_decay_epochs],
|
| 143 |
+
gamma=config.lr_decay_rate
|
| 144 |
+
)
|
| 145 |
+
# logger.info("Optimizer details:"); logger.info(optimizer)
|
| 146 |
+
|
| 147 |
+
return model, optimizer, lr_scheduler
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class Trainer:
|
| 151 |
+
def __init__(
|
| 152 |
+
self, data_loaders, model_opt_lrsch,
|
| 153 |
+
):
|
| 154 |
+
self.model, self.optimizer, self.lr_scheduler = model_opt_lrsch
|
| 155 |
+
self.train_loader = data_loaders
|
| 156 |
+
if args.use_accelerate:
|
| 157 |
+
self.train_loader, self.model, self.optimizer = accelerator.prepare(self.train_loader, self.model, self.optimizer)
|
| 158 |
+
if config.out_ref:
|
| 159 |
+
self.criterion_gdt = nn.BCELoss()
|
| 160 |
+
|
| 161 |
+
# Setting Losses
|
| 162 |
+
self.pix_loss = PixLoss()
|
| 163 |
+
self.cls_loss = ClsLoss()
|
| 164 |
+
|
| 165 |
+
# Others
|
| 166 |
+
self.loss_log = AverageMeter()
|
| 167 |
+
|
| 168 |
+
def _train_batch(self, batch):
|
| 169 |
+
if args.use_accelerate:
|
| 170 |
+
inputs = batch[0]#.to(device)
|
| 171 |
+
gts = batch[1]#.to(device)
|
| 172 |
+
class_labels = batch[2]#.to(device)
|
| 173 |
+
else:
|
| 174 |
+
inputs = batch[0].to(device)
|
| 175 |
+
gts = batch[1].to(device)
|
| 176 |
+
class_labels = batch[2].to(device)
|
| 177 |
+
self.optimizer.zero_grad()
|
| 178 |
+
scaled_preds, class_preds_lst = self.model(inputs)
|
| 179 |
+
if config.out_ref:
|
| 180 |
+
(outs_gdt_pred, outs_gdt_label), scaled_preds = scaled_preds
|
| 181 |
+
for _idx, (_gdt_pred, _gdt_label) in enumerate(zip(outs_gdt_pred, outs_gdt_label)):
|
| 182 |
+
_gdt_pred = nn.functional.interpolate(_gdt_pred, size=_gdt_label.shape[2:], mode='bilinear', align_corners=True).sigmoid()
|
| 183 |
+
_gdt_label = _gdt_label.sigmoid()
|
| 184 |
+
loss_gdt = self.criterion_gdt(_gdt_pred, _gdt_label) if _idx == 0 else self.criterion_gdt(_gdt_pred, _gdt_label) + loss_gdt
|
| 185 |
+
# self.loss_dict['loss_gdt'] = loss_gdt.item()
|
| 186 |
+
if None in class_preds_lst:
|
| 187 |
+
loss_cls = 0.
|
| 188 |
+
else:
|
| 189 |
+
loss_cls = self.cls_loss(class_preds_lst, class_labels)
|
| 190 |
+
self.loss_dict['loss_cls'] = loss_cls.item()
|
| 191 |
+
|
| 192 |
+
# Loss
|
| 193 |
+
loss_pix, loss_dict_pix = self.pix_loss(scaled_preds, torch.clamp(gts, 0, 1), pix_loss_lambda=1.0)
|
| 194 |
+
self.loss_dict.update(loss_dict_pix)
|
| 195 |
+
self.loss_dict['loss_pix'] = loss_pix.item()
|
| 196 |
+
# since there may be several losses for sal, the lambdas for them (lambdas_pix) are inside the loss.py
|
| 197 |
+
loss = loss_pix + loss_cls
|
| 198 |
+
if config.out_ref:
|
| 199 |
+
loss = loss + loss_gdt * 1.0
|
| 200 |
+
|
| 201 |
+
self.loss_log.update(loss.item(), inputs.size(0))
|
| 202 |
+
if args.use_accelerate:
|
| 203 |
+
loss = loss / accelerator.gradient_accumulation_steps
|
| 204 |
+
accelerator.backward(loss)
|
| 205 |
+
else:
|
| 206 |
+
loss.backward()
|
| 207 |
+
self.optimizer.step()
|
| 208 |
+
|
| 209 |
+
def train_epoch(self, epoch):
|
| 210 |
+
global logger_loss_idx
|
| 211 |
+
self.model.train()
|
| 212 |
+
self.loss_dict = {}
|
| 213 |
+
if epoch > args.epochs + config.finetune_last_epochs:
|
| 214 |
+
if config.task == 'Matting':
|
| 215 |
+
self.pix_loss.lambdas_pix_last['mae'] *= 1
|
| 216 |
+
self.pix_loss.lambdas_pix_last['mse'] *= 0.9
|
| 217 |
+
self.pix_loss.lambdas_pix_last['ssim'] *= 0.9
|
| 218 |
+
else:
|
| 219 |
+
self.pix_loss.lambdas_pix_last['bce'] *= 0
|
| 220 |
+
self.pix_loss.lambdas_pix_last['ssim'] *= 1
|
| 221 |
+
self.pix_loss.lambdas_pix_last['iou'] *= 0.5
|
| 222 |
+
self.pix_loss.lambdas_pix_last['mae'] *= 0.9
|
| 223 |
+
|
| 224 |
+
for batch_idx, batch in enumerate(self.train_loader):
|
| 225 |
+
# with nullcontext if not args.use_accelerate or accelerator.gradient_accumulation_steps <= 1 else accelerator.accumulate(self.model):
|
| 226 |
+
self._train_batch(batch)
|
| 227 |
+
# Logger
|
| 228 |
+
if (epoch < 2 and batch_idx < 100 and batch_idx % 20 == 0) or batch_idx % max(100, len(self.train_loader) / 100 // 100 * 100) == 0:
|
| 229 |
+
info_progress = f'Epoch[{epoch}/{args.epochs}] Iter[{batch_idx}/{len(self.train_loader)}].'
|
| 230 |
+
info_loss = 'Training Losses:'
|
| 231 |
+
for loss_name, loss_value in self.loss_dict.items():
|
| 232 |
+
info_loss += f' {loss_name}: {loss_value:.5g} |'
|
| 233 |
+
logger.info(' '.join((info_progress, info_loss)))
|
| 234 |
+
info_loss = f'@==Final== Epoch[{epoch}/{args.epochs}] Training Loss: {self.loss_log.avg:.5g} '
|
| 235 |
+
logger.info(info_loss)
|
| 236 |
+
|
| 237 |
+
self.lr_scheduler.step()
|
| 238 |
+
return self.loss_log.avg
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def main():
|
| 242 |
+
|
| 243 |
+
trainer = Trainer(
|
| 244 |
+
data_loaders=init_data_loaders(to_be_distributed),
|
| 245 |
+
model_opt_lrsch=init_models_optimizers(args.epochs, to_be_distributed)
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
for epoch in range(epoch_st, args.epochs+1):
|
| 249 |
+
train_loss = trainer.train_epoch(epoch)
|
| 250 |
+
# Save checkpoint
|
| 251 |
+
if epoch >= args.epochs - config.save_last and epoch % config.save_step == 0:
|
| 252 |
+
if args.use_accelerate:
|
| 253 |
+
state_dict = trainer.model.state_dict()
|
| 254 |
+
else:
|
| 255 |
+
state_dict = trainer.model.module.state_dict() if to_be_distributed else trainer.model.state_dict()
|
| 256 |
+
torch.save(state_dict, os.path.join(args.ckpt_dir, 'epoch_{}.pth'.format(epoch)))
|
| 257 |
+
if to_be_distributed:
|
| 258 |
+
destroy_process_group()
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
if __name__ == '__main__':
|
| 262 |
+
main()
|
train.sh
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Run script
|
| 3 |
+
# Settings of training & test for different tasks.
|
| 4 |
+
method="$1"
|
| 5 |
+
task=$(python3 config.py --print_task)
|
| 6 |
+
case "${task}" in
|
| 7 |
+
'DIS5K') epochs=500 && val_last=50 && step=5 ;;
|
| 8 |
+
'COD') epochs=150 && val_last=50 && step=5 ;;
|
| 9 |
+
'HRSOD') epochs=150 && val_last=50 && step=5 ;;
|
| 10 |
+
'General') epochs=200 && val_last=50 && step=5 ;;
|
| 11 |
+
'General-2K') epochs=250 && val_last=30 && step=2 ;;
|
| 12 |
+
'Matting') epochs=150 && val_last=50 && step=5 ;;
|
| 13 |
+
esac
|
| 14 |
+
|
| 15 |
+
# Train
|
| 16 |
+
devices=$2
|
| 17 |
+
nproc_per_node=$(echo ${devices%%,} | grep -o "," | wc -l)
|
| 18 |
+
|
| 19 |
+
to_be_distributed=`echo ${nproc_per_node} | awk '{if($e > 0) print "True"; else print "False";}'`
|
| 20 |
+
|
| 21 |
+
echo Training started at $(date)
|
| 22 |
+
resume_weights_path='path_to_a_pth'
|
| 23 |
+
if [ ${to_be_distributed} == "True" ]
|
| 24 |
+
then
|
| 25 |
+
# Adapt the nproc_per_node by the number of GPUs. Give 8989 as the default value of master_port.
|
| 26 |
+
echo "Multi-GPU mode received..."
|
| 27 |
+
CUDA_VISIBLE_DEVICES=${devices} \
|
| 28 |
+
torchrun --standalone --nproc_per_node $((nproc_per_node+1)) \
|
| 29 |
+
train.py --ckpt_dir ckpts/${method} --epochs ${epochs} \
|
| 30 |
+
--dist ${to_be_distributed} \
|
| 31 |
+
--resume ${resume_weights_path} \
|
| 32 |
+
--use_accelerate
|
| 33 |
+
else
|
| 34 |
+
echo "Single-GPU mode received..."
|
| 35 |
+
CUDA_VISIBLE_DEVICES=${devices} \
|
| 36 |
+
python train.py --ckpt_dir ckpts/${method} --epochs ${epochs} \
|
| 37 |
+
--dist ${to_be_distributed} \
|
| 38 |
+
--resume ${resume_weights_path} \
|
| 39 |
+
--use_accelerate
|
| 40 |
+
fi
|
| 41 |
+
|
| 42 |
+
echo Training finished at $(date)
|
train_test.sh
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Example: `setsid nohup ./train_test.sh BiRefNet 0,1,2,3,4,5,6,7 0 &>nohup.log &`
|
| 3 |
+
|
| 4 |
+
method=${1:-"BSL"}
|
| 5 |
+
devices=${2:-"0,1,2,3,4,5,6,7"}
|
| 6 |
+
|
| 7 |
+
bash train.sh ${method} ${devices}
|
| 8 |
+
|
| 9 |
+
devices_test=${3:-0}
|
| 10 |
+
bash test.sh ${devices_test}
|
| 11 |
+
|
| 12 |
+
hostname
|
utils.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
from torchvision import transforms
|
| 5 |
+
import numpy as np
|
| 6 |
+
import random
|
| 7 |
+
import cv2
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def path_to_image(path, size=(1024, 1024), color_type=['rgb', 'gray'][0]):
|
| 12 |
+
if color_type.lower() == 'rgb':
|
| 13 |
+
image = cv2.imread(path)
|
| 14 |
+
elif color_type.lower() == 'gray':
|
| 15 |
+
image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
|
| 16 |
+
else:
|
| 17 |
+
print('Select the color_type to return, either to RGB or gray image.')
|
| 18 |
+
return
|
| 19 |
+
if size:
|
| 20 |
+
image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
|
| 21 |
+
if color_type.lower() == 'rgb':
|
| 22 |
+
image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).convert('RGB')
|
| 23 |
+
else:
|
| 24 |
+
image = Image.fromarray(image).convert('L')
|
| 25 |
+
return image
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def check_state_dict(state_dict, unwanted_prefixes=['module.', '_orig_mod.']):
|
| 30 |
+
for k, v in list(state_dict.items()):
|
| 31 |
+
prefix_length = 0
|
| 32 |
+
for unwanted_prefix in unwanted_prefixes:
|
| 33 |
+
if k[prefix_length:].startswith(unwanted_prefix):
|
| 34 |
+
prefix_length += len(unwanted_prefix)
|
| 35 |
+
state_dict[k[prefix_length:]] = state_dict.pop(k)
|
| 36 |
+
return state_dict
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def generate_smoothed_gt(gts):
|
| 40 |
+
epsilon = 0.001
|
| 41 |
+
new_gts = (1-epsilon)*gts+epsilon/2
|
| 42 |
+
return new_gts
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Logger():
|
| 46 |
+
def __init__(self, path="log.txt"):
|
| 47 |
+
self.logger = logging.getLogger('BiRefNet')
|
| 48 |
+
self.file_handler = logging.FileHandler(path, "w")
|
| 49 |
+
self.stdout_handler = logging.StreamHandler()
|
| 50 |
+
self.stdout_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
|
| 51 |
+
self.file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
|
| 52 |
+
self.logger.addHandler(self.file_handler)
|
| 53 |
+
self.logger.addHandler(self.stdout_handler)
|
| 54 |
+
self.logger.setLevel(logging.INFO)
|
| 55 |
+
self.logger.propagate = False
|
| 56 |
+
|
| 57 |
+
def info(self, txt):
|
| 58 |
+
self.logger.info(txt)
|
| 59 |
+
|
| 60 |
+
def close(self):
|
| 61 |
+
self.file_handler.close()
|
| 62 |
+
self.stdout_handler.close()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class AverageMeter(object):
|
| 66 |
+
"""Computes and stores the average and current value"""
|
| 67 |
+
def __init__(self):
|
| 68 |
+
self.reset()
|
| 69 |
+
|
| 70 |
+
def reset(self):
|
| 71 |
+
self.val = 0.0
|
| 72 |
+
self.avg = 0.0
|
| 73 |
+
self.sum = 0.0
|
| 74 |
+
self.count = 0.0
|
| 75 |
+
|
| 76 |
+
def update(self, val, n=1):
|
| 77 |
+
self.val = val
|
| 78 |
+
self.sum += val * n
|
| 79 |
+
self.count += n
|
| 80 |
+
self.avg = self.sum / self.count
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def save_checkpoint(state, path, filename="latest.pth"):
|
| 84 |
+
torch.save(state, os.path.join(path, filename))
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def save_tensor_img(tenor_im, path):
|
| 88 |
+
im = tenor_im.cpu().clone()
|
| 89 |
+
im = im.squeeze(0)
|
| 90 |
+
tensor2pil = transforms.ToPILImage()
|
| 91 |
+
im = tensor2pil(im)
|
| 92 |
+
im.save(path)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def set_seed(seed):
|
| 96 |
+
torch.manual_seed(seed)
|
| 97 |
+
torch.cuda.manual_seed_all(seed)
|
| 98 |
+
np.random.seed(seed)
|
| 99 |
+
random.seed(seed)
|
| 100 |
+
torch.backends.cudnn.deterministic = True
|