|
|
import os |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from dataset.ucf_jhmdb import UCF_JHMDB_Dataset |
|
|
from dataset.ava import AVA_Dataset |
|
|
from dataset.transforms import Augmentation, BaseTransform |
|
|
|
|
|
from evaluator.ucf_jhmdb_evaluator import UCF_JHMDB_Evaluator |
|
|
from evaluator.ava_evaluator import AVA_Evaluator |
|
|
|
|
|
|
|
|
def build_dataset(d_cfg, args, is_train=False):
    """
    Build a dataset and (optionally) its evaluator.

    Args:
        d_cfg (dict): dataset config — sizes, color-jitter params,
            sampling rate, gt folder.
        args: parsed command-line args; reads `dataset`, `root`,
            `len_clip`, `version`, `test_batch_size`, `eval`.
        is_train (bool): build the training split when True.

    Returns:
        tuple: (dataset, evaluator, num_classes). `evaluator` is None
        when `args.eval` is False.
    """
    # Training-time augmentation and test-time resize transform.
    augmentation = Augmentation(
        img_size=d_cfg['train_size'],
        jitter=d_cfg['jitter'],
        hue=d_cfg['hue'],
        saturation=d_cfg['saturation'],
        exposure=d_cfg['exposure']
    )
    basetransform = BaseTransform(
        img_size=d_cfg['test_size'],
    )

    if args.dataset in ['ucf24', 'jhmdb21']:
        # Both UCF24 and JHMDB21 use the 'ucf24' directory layout.
        data_dir = os.path.join(args.root, 'ucf24')

        dataset = UCF_JHMDB_Dataset(
            data_root=data_dir,
            dataset=args.dataset,
            img_size=d_cfg['train_size'],
            transform=augmentation,
            is_train=is_train,
            len_clip=args.len_clip,
            sampling_rate=d_cfg['sampling_rate']
        )
        num_classes = dataset.num_classes

        evaluator = UCF_JHMDB_Evaluator(
            data_root=data_dir,
            dataset=args.dataset,
            model_name=args.version,
            metric='fmap',
            img_size=d_cfg['test_size'],
            len_clip=args.len_clip,
            batch_size=args.test_batch_size,
            conf_thresh=0.01,
            iou_thresh=0.5,
            gt_folder=d_cfg['gt_folder'],
            save_path='./evaluator/eval_results/',
            transform=basetransform,
            collate_fn=CollateFunc()
        )

    elif args.dataset == 'ava_v2.2':
        data_dir = args.root

        dataset = AVA_Dataset(
            cfg=d_cfg,
            data_root=data_dir,
            # BUGFIX: honor the caller's `is_train` flag instead of
            # hard-coding the training split (the UCF/JHMDB branch
            # already passes it through).
            is_train=is_train,
            img_size=d_cfg['train_size'],
            transform=augmentation,
            len_clip=args.len_clip,
            sampling_rate=d_cfg['sampling_rate']
        )
        # NOTE(review): hard-coded class count for AVA — confirm this
        # matches the label map AVA_Dataset actually uses.
        num_classes = 3

        evaluator = AVA_Evaluator(
            d_cfg=d_cfg,
            data_root=data_dir,
            img_size=d_cfg['test_size'],
            len_clip=args.len_clip,
            sampling_rate=d_cfg['sampling_rate'],
            batch_size=args.test_batch_size,
            transform=basetransform,
            collate_fn=CollateFunc(),
            full_test_on_val=False,
            version='v2.2'
        )

    else:
        # BUGFIX: fixed "unknow" typo and exit with a non-zero status,
        # since an unsupported dataset is an error condition.
        print('unknown dataset !! Only support ucf24 & jhmdb21 & ava_v2.2 !!')
        exit(1)

    print('==============================')
    print('Training model on:', args.dataset)
    print('The dataset size:', len(dataset))

    # Drop the evaluator when evaluation is disabled.
    if not args.eval:
        evaluator = None

    return dataset, evaluator, num_classes
|
|
|
|
|
|
|
|
def build_dataloader(args, dataset, batch_size, collate_fn=None, is_train=False):
    """
    Wrap `dataset` in a DataLoader.

    Training: a distributed-aware (or random) sampler feeds a
    BatchSampler with `drop_last=True`. Evaluation: a sequential loader
    using the DataLoader's default batch size (`batch_size` is only used
    for the training path, matching the original behavior).
    """
    if not is_train:
        # Evaluation loader: no shuffling, keep the final partial batch.
        return torch.utils.data.DataLoader(
            dataset=dataset,
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=args.num_workers,
            drop_last=False,
            pin_memory=True
        )

    # Training loader: shard the dataset across processes when running
    # distributed, otherwise shuffle randomly.
    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        sampler = torch.utils.data.RandomSampler(dataset)

    train_batch_sampler = torch.utils.data.BatchSampler(
        sampler, batch_size, drop_last=True)

    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=train_batch_sampler,
        collate_fn=collate_fn,
        num_workers=args.num_workers,
        pin_memory=True
    )
|
|
|
|
|
|
|
|
def load_weight(model, path_to_ckpt=None):
    """
    Load checkpoint weights into `model[0]`, skipping incompatible tensors.

    Args:
        model: an indexable container whose first element is the
            nn.Module to restore (interface kept from the original code).
        path_to_ckpt (str | None): checkpoint path; when None the model
            is returned untouched.

    Returns:
        nn.Module: `model[0]` with all compatible weights loaded.
    """
    if path_to_ckpt is None:
        print('No trained weight ..')
        # NOTE(review): this path returns the container, not model[0],
        # unlike the success path — confirm callers expect that.
        return model

    # Load on CPU so the checkpoint never requires the training device.
    checkpoint = torch.load(path_to_ckpt, map_location='cpu')
    checkpoint_state_dict = checkpoint.pop("model")
    model_state_dict = model[0].state_dict()

    # Drop checkpoint tensors that do not exist in the model or whose
    # shapes disagree with the model's parameters.
    for k in list(checkpoint_state_dict.keys()):
        if k in model_state_dict:
            shape_model = tuple(model_state_dict[k].shape)
            shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
            if shape_model != shape_checkpoint:
                checkpoint_state_dict.pop(k)
                print('Shape mismatch, skipping:', k)
        else:
            checkpoint_state_dict.pop(k)
            print('Unknown checkpoint key, skipping:', k)

    # BUGFIX: strict=False — keys may have been popped above, and a
    # strict load would raise on the resulting missing keys, defeating
    # the whole skip-incompatible-tensors loop.
    model[0].load_state_dict(checkpoint_state_dict, strict=False)
    print('Finished loading model!')

    return model[0]
|
|
|
|
|
|
|
|
def is_parallel(model):
    """Return True iff `model` is wrapped in (Distributed)DataParallel."""
    # Exact type match on purpose (mirrors the original `type(...) in`
    # check); subclasses of the wrappers are not counted.
    parallel_wrappers = (nn.parallel.DataParallel,
                         nn.parallel.DistributedDataParallel)
    return type(model) in parallel_wrappers
|
|
|
|
|
|
|
|
class CollateFunc(object):
    """Collate (frame_id, video_clip, key_target) samples into a batch.

    Clips are stacked into one tensor; frame ids and targets are
    returned as plain Python lists.
    """

    def __call__(self, batch):
        batch_frame_id = [sample[0] for sample in batch]
        batch_video_clips = torch.stack([sample[1] for sample in batch])
        batch_key_target = [sample[2] for sample in batch]

        return batch_frame_id, batch_video_clips, batch_key_target
|
|
|