import itertools

import numpy as np
import torch
from torch.utils.data.dataloader import default_collate

# Modified by Zhang Ruicheng on 2024.04.21


def _prefix_mask(counts):
    """Build a boolean padding mask from per-sample instance counts.

    Row ``i`` of the returned ``(len(counts), max(counts))`` bool tensor has its
    first ``counts[i]`` entries set to True and the rest False. A bool mask is
    used because indexing with ``torch.uint8`` masks is deprecated in PyTorch.
    """
    counts = torch.as_tensor(counts).reshape(-1)
    width = int(counts.max()) if counts.numel() > 0 else 0
    return torch.arange(width).unsqueeze(0) < counts.unsqueeze(1)


def _scatter_per_sample(target, mask, batch, key):
    """Write the concatenated per-sample sequences ``b[key]`` into ``target[mask]`` in place.

    The sequences are concatenated in batch order (linear time, unlike
    ``sum(lists, [])``) and converted to ``target.dtype``. The concatenation must
    provide exactly ``mask.sum()`` entries. Nothing is written when ``mask``
    selects no entries, so a batch without any instances (zero-width
    ``target``) does not raise a broadcast error.
    """
    if not bool(mask.any()):
        return
    values = list(itertools.chain.from_iterable(b[key] for b in batch))
    # Going through NumPy first avoids torch's very slow element-wise path for
    # a Python list of ndarrays.
    target[mask] = torch.as_tensor(np.asarray(values), dtype=target.dtype)


def _padded_polys(batch, mask, key, num_points):
    """Return a zero-padded ``(B, max_count, num_points, 2)`` float tensor filled from ``b[key]``.

    ``mask`` is the ``(B, max_count)`` instance mask from ``_prefix_mask``; each
    per-instance entry of ``b[key]`` must be convertible to ``(num_points, 2)``.
    """
    polys = torch.zeros([*mask.shape, num_points, 2], dtype=torch.float)
    _scatter_per_sample(polys, mask, batch, key)
    return polys


def _collate_detection(batch, ct_01):
    """Collate 'ct_hm', 'wh', 'ct_cls', 'ct_ind' and 'ct_01' for the slots selected by ``ct_01``."""
    ct_hm = default_collate([b['ct_hm'] for b in batch])
    wh = torch.zeros([*ct_01.shape, 2], dtype=torch.float)
    ct_cls = torch.zeros(ct_01.shape, dtype=torch.int64)
    ct_ind = torch.zeros(ct_01.shape, dtype=torch.int64)
    # A 'reg' target used to be collated alongside 'wh'; it is currently disabled.
    _scatter_per_sample(wh, ct_01, batch, 'wh')
    _scatter_per_sample(ct_cls, ct_01, batch, 'ct_cls')
    _scatter_per_sample(ct_ind, ct_01, batch, 'ct_ind')
    return {'ct_hm': ct_hm, 'wh': wh, 'ct_cls': ct_cls, 'ct_ind': ct_ind, 'ct_01': ct_01.float()}


def _collate_yolo_targets(batch):
    """Concatenate optional per-sample YOLO-style targets into flat batch-level tensors.

    Returns ``{}`` unless at least one sample has all of 'bboxes', 'cls' and
    'batch_idx'. Otherwise 'bboxes' and 'cls' are concatenated along dim 0 over
    the samples with non-empty boxes. 'batch_idx' is rebuilt as an ``(n, 1)``
    float column holding each box's sample position in ``batch``; the samples'
    own 'batch_idx' values are ignored. When no sample has boxes, empty float32
    tensors of shapes (0, 4), (0, 1) and (0, 1) are returned.
    """
    if not any('bboxes' in b and 'cls' in b and 'batch_idx' in b for b in batch):
        return {}

    all_bboxes, all_cls, all_batch = [], [], []
    for i, b in enumerate(batch):
        bboxes = b.get('bboxes')
        if bboxes is None or bboxes.numel() == 0:
            continue
        all_bboxes.append(bboxes)
        all_cls.append(b['cls'])
        # Override any per-sample batch_idx with the actual image index i.
        all_batch.append(torch.full((bboxes.shape[0], 1), float(i)))

    if not all_bboxes:
        return {
            'bboxes': torch.zeros((0, 4), dtype=torch.float32),
            'cls': torch.zeros((0, 1), dtype=torch.float32),
            'batch_idx': torch.zeros((0, 1), dtype=torch.float32),
        }
    return {
        'bboxes': torch.cat(all_bboxes, dim=0),
        'cls': torch.cat(all_cls, dim=0),
        'batch_idx': torch.cat(all_batch, dim=0),
    }


def snake_collator(batch):
    """Collate a list of snake samples into one batch dict.

    'inp', 'orig_img', 'img_path' and 'meta' are collated with
    ``default_collate``. If the collated meta contains 'vis_GT', only those
    keys are returned. Otherwise the per-instance targets are zero-padded to
    the largest ``meta['ct_num']`` in the batch:

    - detection: 'ct_hm', 'wh', 'ct_cls', 'ct_ind', 'ct_01';
    - init polygons: 'i_it_4py', 'c_it_4py' (``snake_config.init_poly_num``
      points) and 'i_gt_4py', 'c_gt_4py' (4 points);
    - evolution polygons: 'i_it_py', 'c_it_py' (``snake_config.poly_num``
      points) and 'i_gt_py', 'c_gt_py' (``snake_config.gt_poly_num`` points);
    - optional YOLO targets 'bboxes', 'cls', 'batch_idx' (see
      ``_collate_yolo_targets``).

    'ct_01' is a float 0/1 tensor marking which padded slots hold real instances.
    """
    ret = {
        'inp': default_collate([b['inp'] for b in batch]),
        'orig_img': default_collate([b['orig_img'] for b in batch]),
        'img_path': default_collate([b['img_path'] for b in batch]),
    }
    meta = default_collate([b['meta'] for b in batch])
    ret['meta'] = meta
    if 'vis_GT' in meta:
        return ret

    # ct_01[i, j] is True for the first meta['ct_num'][i] slots of sample i.
    ct_01 = _prefix_mask(meta['ct_num'])
    ret.update(_collate_detection(batch, ct_01))

    from lib.utils.snake import snake_config

    # init
    ret.update({
        'i_it_4py': _padded_polys(batch, ct_01, 'i_it_4py', snake_config.init_poly_num),
        'c_it_4py': _padded_polys(batch, ct_01, 'c_it_4py', snake_config.init_poly_num),
        'i_gt_4py': _padded_polys(batch, ct_01, 'i_gt_4py', 4),
        'c_gt_4py': _padded_polys(batch, ct_01, 'c_gt_4py', 4),
    })

    # evolution
    ret.update({
        'i_it_py': _padded_polys(batch, ct_01, 'i_it_py', snake_config.poly_num),
        'c_it_py': _padded_polys(batch, ct_01, 'c_it_py', snake_config.poly_num),
        'i_gt_py': _padded_polys(batch, ct_01, 'i_gt_py', snake_config.gt_poly_num),
        'c_gt_py': _padded_polys(batch, ct_01, 'c_gt_py', snake_config.gt_poly_num),
    })

    # Aggregate YOLO targets if present in samples.
    ret.update(_collate_yolo_targets(batch))
    return ret


def dsnake_collator(batch):
    """Collate like ``snake_collator`` and add the 'act' detection targets.

    Adds 'act_hm', 'awh', 'act_ind' and 'act_01', padded to the largest
    ``meta['act_num']`` in the batch. Batches whose meta contains 'vis_GT' are
    returned exactly as ``snake_collator`` produced them, without targets.
    """
    ret = snake_collator(batch)
    meta = ret['meta']
    if 'vis_GT' in meta:
        return ret

    # detection
    act_hm = default_collate([b['act_hm'] for b in batch])
    act_01 = _prefix_mask(meta['act_num'])
    awh = torch.zeros([*act_01.shape, 2], dtype=torch.float)
    act_ind = torch.zeros(act_01.shape, dtype=torch.int64)
    # _scatter_per_sample skips empty masks, so a batch with act_num == 0 no longer crashes.
    _scatter_per_sample(awh, act_01, batch, 'awh')
    _scatter_per_sample(act_ind, act_01, batch, 'act_ind')

    ret.update({'act_hm': act_hm, 'awh': awh, 'act_ind': act_ind, 'act_01': act_01.float()})
    return ret


# Modified by Zhang Ruicheng on 2024.04.21
# Disabled collators kept for reference (rcnn_snake / ext_snake). This is a bare
# string literal: none of the code inside it is executed or importable.
'''
def rcnn_snake_collator(batch):
    ret = {'inp': default_collate([b['inp'] for b in batch])}
    meta = default_collate([b['meta'] for b in batch])
    ret.update({'meta': meta})

    if 'vis_GT' in meta:
        return ret

    # detection
    act_hm = default_collate([b['act_hm'] for b in batch])

    batch_size = len(batch)
    act_num = torch.max(meta['act_num'])
    awh = torch.zeros([batch_size, act_num, 2], dtype=torch.float)
    act_ind = torch.zeros([batch_size, act_num], dtype=torch.int64)
    act_01 = torch.zeros([batch_size, act_num], dtype=torch.uint8)
    for i in range(batch_size):
        act_01[i, :meta['act_num'][i]] = 1

    if act_num != 0:
        awh[act_01] = torch.Tensor(sum([b['awh'] for b in batch], []))
        act_ind[act_01] = torch.LongTensor(sum([b['act_ind'] for b in batch], []))

    detection = {'act_hm': act_hm, 'awh': awh, 'act_ind': act_ind, 'act_01': act_01.float()}
    ret.update(detection)

    from lib.utils.rcnn_snake import rcnn_snake_config as snake_config

    ct_num = torch.max(meta['ct_num'])
    ct_01 = torch.zeros([batch_size, ct_num], dtype=torch.uint8)
    for i in range(batch_size):
        ct_01[i, :meta['ct_num'][i]] = 1
    ret.update({'ct_01': ct_01.float()})

    # component detection
    cp_hm = default_collate(sum([[hm for hm in b['cp_hm']] for b in batch], []))
    cp_num_list = sum([[len(wh) for wh in b['cp_wh']] for b in batch], [])
    cp_num = max(cp_num_list)
    cp_wh = torch.zeros([len(cp_hm), cp_num, 2], dtype=torch.float)
    cp_ind = torch.zeros([len(cp_hm), cp_num], dtype=torch.int64)
    cp_01 = torch.zeros([len(cp_hm), cp_num], dtype=torch.uint8)
    for i in range(len(cp_hm)):
        cp_01[i, :cp_num_list[i]] = 1
    if cp_num != 0:
        cp_wh[cp_01] = torch.Tensor(sum(sum([b['cp_wh'] for b in batch], []), []))
        cp_ind[cp_01] = torch.LongTensor(sum(sum([b['cp_ind'] for b in batch], []), []))

    cp_hm_ = torch.zeros([batch_size, act_num, 1, snake_config.cp_h, snake_config.cp_w], dtype=torch.float)
    cp_wh_ = torch.zeros([batch_size, act_num, cp_num, 2], dtype=torch.float)
    cp_ind_ = torch.zeros([batch_size, act_num, cp_num], dtype=torch.int64)
    cp_01_ = torch.zeros([batch_size, act_num, cp_num], dtype=torch.uint8)
    cp_hm_[act_01] = cp_hm
    cp_wh_[act_01] = cp_wh
    cp_ind_[act_01] = cp_ind
    cp_01_[act_01] = cp_01

    cp_detection = {'cp_hm': cp_hm_, 'cp_wh': cp_wh_, 'cp_ind': cp_ind_, 'cp_01': cp_01_.float()}
    ret.update(cp_detection)

    # init
    i_it_4pys = torch.zeros([batch_size, ct_num, snake_config.init_poly_num, 2], dtype=torch.float)
    c_it_4pys = torch.zeros([batch_size, ct_num, snake_config.init_poly_num, 2], dtype=torch.float)
    i_gt_4pys = torch.zeros([batch_size, ct_num, 4, 2], dtype=torch.float)
    c_gt_4pys = torch.zeros([batch_size, ct_num, 4, 2], dtype=torch.float)
    if ct_num != 0:
        i_it_4pys[ct_01] = torch.Tensor(sum([b['i_it_4py'] for b in batch], []))
        c_it_4pys[ct_01] = torch.Tensor(sum([b['c_it_4py'] for b in batch], []))
        i_gt_4pys[ct_01] = torch.Tensor(sum([b['i_gt_4py'] for b in batch], []))
        c_gt_4pys[ct_01] = torch.Tensor(sum([b['c_gt_4py'] for b in batch], []))
    init = {'i_it_4py': i_it_4pys, 'c_it_4py': c_it_4pys, 'i_gt_4py': i_gt_4pys, 'c_gt_4py': c_gt_4pys}
    ret.update(init)

    # evolution
    i_it_pys = torch.zeros([batch_size, ct_num, snake_config.poly_num, 2], dtype=torch.float)
    c_it_pys = torch.zeros([batch_size, ct_num, snake_config.poly_num, 2], dtype=torch.float)
    i_gt_pys = torch.zeros([batch_size, ct_num, snake_config.gt_poly_num, 2], dtype=torch.float)
    c_gt_pys = torch.zeros([batch_size, ct_num, snake_config.gt_poly_num, 2], dtype=torch.float)
    if ct_num != 0:
        i_it_pys[ct_01] = torch.Tensor(sum([b['i_it_py'] for b in batch], []))
        c_it_pys[ct_01] = torch.Tensor(sum([b['c_it_py'] for b in batch], []))
        i_gt_pys[ct_01] = torch.Tensor(sum([b['i_gt_py'] for b in batch], []))
        c_gt_pys[ct_01] = torch.Tensor(sum([b['c_gt_py'] for b in batch], []))
    evolution = {'i_it_py': i_it_pys, 'c_it_py': c_it_pys, 'i_gt_py': i_gt_pys, 'c_gt_py': c_gt_pys}
    ret.update(evolution)

    return ret


def ext_snake_collator(batch):
    ret = {'inp': default_collate([b['inp'] for b in batch])}
    meta = default_collate([b['meta'] for b in batch])
    ret.update({'meta': meta})

    if 'vis_GT' in meta:
        return ret

    # detection
    ct_hm = default_collate([b['ct_hm'] for b in batch])

    batch_size = len(batch)
    ct_num = torch.max(meta['ct_num'])
    ext = torch.zeros([batch_size, ct_num, 8], dtype=torch.float)
    ct_cls = torch.zeros([batch_size, ct_num], dtype=torch.int64)
    ct_ind = torch.zeros([batch_size, ct_num], dtype=torch.int64)
    ct_01 = torch.zeros([batch_size, ct_num], dtype=torch.uint8)
    for i in range(batch_size):
        ct_01[i, :meta['ct_num'][i]] = 1

    ext[ct_01] = torch.Tensor(sum([b['ext'] for b in batch], []))
    ct_cls[ct_01] = torch.LongTensor(sum([b['ct_cls'] for b in batch], []))
    ct_ind[ct_01] = torch.LongTensor(sum([b['ct_ind'] for b in batch], []))

    detection = {'ct_hm': ct_hm, 'ext': ext, 'ct_cls': ct_cls, 'ct_ind': ct_ind, 'ct_01': ct_01.float()}
    ret.update(detection)

    from lib.utils.snake import snake_config

    # init
    i_it_4pys = torch.zeros([batch_size, ct_num, snake_config.init_poly_num, 2], dtype=torch.float)
    c_it_4pys = torch.zeros([batch_size, ct_num, snake_config.init_poly_num, 2], dtype=torch.float)
    i_gt_4pys = torch.zeros([batch_size, ct_num, 4, 2], dtype=torch.float)
    c_gt_4pys = torch.zeros([batch_size, ct_num, 4, 2], dtype=torch.float)
    i_it_4pys[ct_01] = torch.Tensor(sum([b['i_it_4py'] for b in batch], []))
    c_it_4pys[ct_01] = torch.Tensor(sum([b['c_it_4py'] for b in batch], []))
    i_gt_4pys[ct_01] = torch.Tensor(sum([b['i_gt_4py'] for b in batch], []))
    c_gt_4pys[ct_01] = torch.Tensor(sum([b['c_gt_4py'] for b in batch], []))
    init = {'i_it_4py': i_it_4pys, 'c_it_4py': c_it_4pys, 'i_gt_4py': i_gt_4pys, 'c_gt_4py': c_gt_4pys}
    ret.update(init)

    # evolution
    i_it_pys = torch.zeros([batch_size, ct_num, snake_config.poly_num, 2], dtype=torch.float)
    c_it_pys = torch.zeros([batch_size, ct_num, snake_config.poly_num, 2], dtype=torch.float)
    i_gt_pys = torch.zeros([batch_size, ct_num, snake_config.gt_poly_num, 2], dtype=torch.float)
    c_gt_pys = torch.zeros([batch_size, ct_num, snake_config.gt_poly_num, 2], dtype=torch.float)
    i_it_pys[ct_01] = torch.Tensor(sum([b['i_it_py'] for b in batch], []))
    c_it_pys[ct_01] = torch.Tensor(sum([b['c_it_py'] for b in batch], []))
    i_gt_pys[ct_01] = torch.Tensor(sum([b['i_gt_py'] for b in batch], []))
    c_gt_pys[ct_01] = torch.Tensor(sum([b['c_gt_py'] for b in batch], []))
    evolution = {'i_it_py': i_it_pys, 'c_it_py': c_it_pys, 'i_gt_py': i_gt_pys, 'c_gt_py': c_gt_pys}
    ret.update(evolution)

    return ret
'''


# Modified by Zhang Ruicheng on 2024.04.21
# Maps cfg.task to its collate function; unknown tasks fall back to default_collate.
_collators = {
    'snake': snake_collator,
    'ct': snake_collator,
    # 'rcnn_snake': rcnn_snake_collator,
    # 'ct_rcnn': rcnn_snake_collator,
}


def make_collator(cfg):
    """Return the collate function registered for ``cfg.task``, or ``default_collate``."""
    return _collators.get(cfg.task, default_collate)