# utils/misc.py
import os
import torch
import torch.nn as nn
from dataset.ucf_jhmdb import UCF_JHMDB_Dataset
from dataset.ava import AVA_Dataset
from dataset.transforms import Augmentation, BaseTransform
from evaluator.ucf_jhmdb_evaluator import UCF_JHMDB_Evaluator
from evaluator.ava_evaluator import AVA_Evaluator

def build_dataset(d_cfg, args, is_train=False):
    """
    Build the dataset and its evaluator.

    d_cfg: dataset config dict
    args: parsed command-line arguments
    is_train: whether to build the training split
    """
# transform
augmentation = Augmentation(
img_size=d_cfg['train_size'],
jitter=d_cfg['jitter'],
hue=d_cfg['hue'],
saturation=d_cfg['saturation'],
exposure=d_cfg['exposure']
)
basetransform = BaseTransform(
img_size=d_cfg['test_size'],
)
# dataset
if args.dataset in ['ucf24', 'jhmdb21']:
data_dir = os.path.join(args.root, 'ucf24')
# dataset
dataset = UCF_JHMDB_Dataset(
data_root=data_dir,
dataset=args.dataset,
img_size=d_cfg['train_size'],
transform=augmentation,
is_train=is_train,
len_clip=args.len_clip,
sampling_rate=d_cfg['sampling_rate']
)
num_classes = dataset.num_classes
# evaluator
evaluator = UCF_JHMDB_Evaluator(
data_root=data_dir,
dataset=args.dataset,
model_name=args.version,
metric='fmap',
img_size=d_cfg['test_size'],
len_clip=args.len_clip,
batch_size=args.test_batch_size,
conf_thresh=0.01,
iou_thresh=0.5,
gt_folder=d_cfg['gt_folder'],
save_path='./evaluator/eval_results/',
transform=basetransform,
collate_fn=CollateFunc()
)
elif args.dataset == 'ava_v2.2':
#data_dir = os.path.join(args.root, 'AVA_Dataset')
data_dir = args.root
# dataset
dataset = AVA_Dataset(
cfg=d_cfg,
data_root=data_dir,
            is_train=is_train,
img_size=d_cfg['train_size'],
transform=augmentation,
len_clip=args.len_clip,
sampling_rate=d_cfg['sampling_rate']
)
        num_classes = 3  # class count is hardcoded for this AVA setup
# evaluator
evaluator = AVA_Evaluator(
d_cfg=d_cfg,
data_root=data_dir,
img_size=d_cfg['test_size'],
len_clip=args.len_clip,
sampling_rate=d_cfg['sampling_rate'],
batch_size=args.test_batch_size,
transform=basetransform,
collate_fn=CollateFunc(),
full_test_on_val=False,
version='v2.2'
)
else:
        print('Unknown dataset: only ucf24, jhmdb21 and ava_v2.2 are supported.')
        exit(1)
print('==============================')
    print('Dataset:', args.dataset)
    print('Dataset size:', len(dataset))
if not args.eval:
# no evaluator during training stage
evaluator = None
return dataset, evaluator, num_classes
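

# Usage sketch (illustrative; assumes an argparse namespace `args` carrying the
# fields read above, e.g. args.dataset, args.root, args.len_clip, args.eval):
#
#   dataset, evaluator, num_classes = build_dataset(d_cfg, args, is_train=True)
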
def build_dataloader(args, dataset, batch_size, collate_fn=None, is_train=False):
if is_train:
# distributed
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
else:
sampler = torch.utils.data.RandomSampler(dataset)
batch_sampler_train = torch.utils.data.BatchSampler(sampler,
batch_size,
drop_last=True)
# train dataloader
dataloader = torch.utils.data.DataLoader(
dataset=dataset,
batch_sampler=batch_sampler_train,
collate_fn=collate_fn,
num_workers=args.num_workers,
pin_memory=True
)
else:
# test dataloader
        dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,  # use the requested batch size at test time too
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=args.num_workers,
            drop_last=False,
            pin_memory=True
            )
return dataloader
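

# Usage sketch (illustrative): with DDP, re-seed the shuffling each epoch via the
# DistributedSampler, which is reachable through the loader's batch sampler:
#
#   loader = build_dataloader(args, dataset, batch_size=8,
#                             collate_fn=CollateFunc(), is_train=True)
#   if args.distributed:
#       loader.batch_sampler.sampler.set_epoch(epoch)
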
def load_weight(model, path_to_ckpt=None):
    if path_to_ckpt is None:
        print('No trained weight provided; returning the model unchanged.')
        return model

    checkpoint = torch.load(path_to_ckpt, map_location='cpu')
    # checkpoint state dict
    checkpoint_state_dict = checkpoint.pop("model")
    # model state dict
    model_state_dict = model.state_dict()
    # drop checkpoint entries that do not fit the current model
    for k in list(checkpoint_state_dict.keys()):
        if k in model_state_dict:
            shape_model = tuple(model_state_dict[k].shape)
            shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
            if shape_model != shape_checkpoint:
                # same key, different tensor shape: skip it
                checkpoint_state_dict.pop(k)
                print('Skipped (shape mismatch):', k)
        else:
            # key absent from the model: skip it
            checkpoint_state_dict.pop(k)
            print('Skipped (unexpected key):', k)

    # strict=False because keys pruned above would otherwise be reported as missing
    model.load_state_dict(checkpoint_state_dict, strict=False)
    print('Finished loading model!')
    return model
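

# Usage sketch (illustrative): load_weight expects a checkpoint dict with a
# "model" entry holding the state dict, e.g. one saved as:
#
#   torch.save({'model': model.state_dict(), 'epoch': epoch}, 'ckpt.pth')
#   model = load_weight(model, path_to_ckpt='ckpt.pth')
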
def is_parallel(model):
# Returns True if model is of type DP or DDP
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
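
# e.g. is_parallel(nn.DataParallel(net)) -> True; unwrap such a model with
# `model.module` when saving weights without the wrapper's key prefix.
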
class CollateFunc(object):
def __call__(self, batch):
batch_frame_id = []
batch_key_target = []
batch_video_clips = []
for sample in batch:
key_frame_id = sample[0]
video_clip = sample[1]
key_target = sample[2]
batch_frame_id.append(key_frame_id)
batch_video_clips.append(video_clip)
batch_key_target.append(key_target)
        # list of B tensors, each [3, T, H, W] -> one tensor [B, 3, T, H, W]
batch_video_clips = torch.stack(batch_video_clips)
return batch_frame_id, batch_video_clips, batch_key_target
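

if __name__ == '__main__':
    # Minimal self-check (illustrative): collate three fake samples shaped like the
    # (key_frame_id, video_clip, key_target) tuples the datasets above yield; the
    # id string and target dict below are placeholders, not real dataset output.
    fake_batch = [
        ('%05d.jpg' % i,                  # placeholder key-frame id
         torch.randn(3, 16, 224, 224),    # video clip [3, T, H, W]
         {'boxes': torch.zeros(1, 4)})    # placeholder target
        for i in range(3)
    ]
    frame_ids, clips, targets = CollateFunc()(fake_batch)
    print(len(frame_ids), clips.shape, len(targets))
    # -> 3 torch.Size([3, 3, 16, 224, 224]) 3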