import os
import glob
import random

import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image


class UCF_JHMDB_Dataset(Dataset):
    def __init__(self,
                 data_root,
                 dataset='ucf24',
                 img_size=224,
                 transform=None,
                 is_train=False,
                 len_clip=16,
                 sampling_rate=1):
        self.data_root = data_root
        self.dataset = dataset
        self.transform = transform
        self.is_train = is_train
        self.img_size = img_size
        self.len_clip = len_clip
        self.sampling_rate = sampling_rate

        if self.is_train:
            self.split_list = 'trainlist.txt'
        else:
            self.split_list = 'testlist.txt'

        # load the list of key-frame paths for this split
        with open(os.path.join(data_root, self.split_list), 'r') as file:
            self.file_names = file.readlines()
        self.num_samples = len(self.file_names)

        if dataset == 'ucf24':
            self.num_classes = 24
        elif dataset == 'jhmdb21':
            self.num_classes = 21
        else:
            raise ValueError('unknown dataset: {}'.format(dataset))

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        frame_idx, video_clip, target = self.pull_item(index)

        return frame_idx, video_clip, target

    def pull_item(self, index):
        """ load a data """
        assert index < len(self), 'index range error'
        image_path = self.file_names[index].rstrip()

        # e.g. 'labels/Basketball/v_Basketball_g01_c01/00001.txt'
        img_split = image_path.split('/')
        # key-frame id, parsed from the file name
        img_id = int(img_split[-1][:5])

        # path to the annotation of the key frame
        label_path = os.path.join(self.data_root, img_split[0], img_split[1], img_split[2], '{:05d}.txt'.format(img_id))

        # folder that holds all frames of this video
        img_folder = os.path.join(self.data_root, 'rgb-images', img_split[1], img_split[2])

        # number of frames in this video
        if self.dataset == 'ucf24':
            max_num = len(os.listdir(img_folder))
        elif self.dataset == 'jhmdb21':
            max_num = len(os.listdir(img_folder)) - 1

        # temporal sampling stride: randomly jittered during training
        if self.is_train:
            d = random.randint(1, 2)
        else:
            d = self.sampling_rate

        # load the clip that ends at the key frame
        video_clip = []
        for i in reversed(range(self.len_clip)):
            # clamp the frame id to the valid range [1, max_num]
            img_id_temp = img_id - i * d
            if img_id_temp < 1:
                img_id_temp = 1
            elif img_id_temp > max_num:
                img_id_temp = max_num

            # UCF24 stores frames as .jpg, JHMDB21 as .png
            if self.dataset == 'ucf24':
                path_tmp = os.path.join(self.data_root, 'rgb-images', img_split[1], img_split[2], '{:05d}.jpg'.format(img_id_temp))
            elif self.dataset == 'jhmdb21':
                path_tmp = os.path.join(self.data_root, 'rgb-images', img_split[1], img_split[2], '{:05d}.png'.format(img_id_temp))
            frame = Image.open(path_tmp).convert('RGB')
            ow, oh = frame.width, frame.height

            video_clip.append(frame)

        # e.g. 'Basketball_v_Basketball_g01_c01_00001.txt'
        frame_id = img_split[1] + '_' + img_split[2] + '_' + img_split[3]

        # load the annotation: each row is [label, x1, y1, x2, y2]
        if os.path.getsize(label_path):
            target = np.loadtxt(label_path)
        else:
            # empty annotation file -> no boxes for this key frame
            target = np.empty((0, 5))

        # reorder to [x1, y1, x2, y2, label]
        label = target[..., :1]
        boxes = target[..., 1:]
        target = np.concatenate([boxes, label], axis=-1).reshape(-1, 5)

        # apply the spatial transform / augmentation
        video_clip, target = self.transform(video_clip, target)
        # List[Tensor] -> Tensor of shape [C, T, H, W]
        video_clip = torch.stack(video_clip, dim=1)

        target = {
            'boxes': target[:, :4].float(),      # [N, 4]
            'labels': target[:, -1].long() - 1,  # class ids start at 0
            'orig_size': [ow, oh],
            'video_idx': frame_id[:-10]          # strip the '_00001.txt' suffix
        }

        return frame_id, video_clip, target

    def pull_anno(self, index):
        """ load the annotation of a key frame """
        assert index < len(self), 'index range error'
        image_path = self.file_names[index].rstrip()

        img_split = image_path.split('/')
        # key-frame id, parsed from the file name
        img_id = int(img_split[-1][:5])

        # path to the annotation of the key frame
        label_path = os.path.join(self.data_root, img_split[0], img_split[1], img_split[2], '{:05d}.txt'.format(img_id))

        # each row is [label, x1, y1, x2, y2]
        target = np.loadtxt(label_path)
        target = target.reshape(-1, 5)

        return target
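

# A minimal sketch of a DataLoader collate function for UCF_JHMDB_Dataset
# (assumption: `ucf_jhmdb_collate_fn` is illustrative and not part of the
# original API). __getitem__ returns a (frame_id, clip, target-dict) triple
# with a variable number of boxes per sample, so the default collate fails;
# here clips are stacked and targets are kept as a list of dicts.
def ucf_jhmdb_collate_fn(batch):
    frame_ids = [sample[0] for sample in batch]
    # each clip is [C, T, H, W] -> batch to [B, C, T, H, W]
    clips = torch.stack([sample[1] for sample in batch], dim=0)
    # keep per-sample target dicts; box counts differ across samples
    targets = [sample[2] for sample in batch]
    return frame_ids, clips, targets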


class UCF_JHMDB_VIDEO_Dataset(Dataset):
    def __init__(self,
                 data_root,
                 dataset='ucf24',
                 img_size=224,
                 transform=None,
                 len_clip=16,
                 sampling_rate=1):
        self.data_root = data_root
        self.dataset = dataset
        self.transform = transform
        self.img_size = img_size
        self.len_clip = len_clip
        self.sampling_rate = sampling_rate

        if dataset == 'ucf24':
            self.num_classes = 24
        elif dataset == 'jhmdb21':
            self.num_classes = 21
        else:
            raise ValueError('unknown dataset: {}'.format(dataset))

    def set_video_data(self, line):
        # `line` is a relative video path, e.g. 'Basketball/v_Basketball_g01_c01'
        self.line = line
        self.img_folder = os.path.join(self.data_root, 'rgb-images', self.line)

        # collect all frame paths of the video (UCF24: .jpg, JHMDB21: .png)
        if self.dataset == 'ucf24':
            self.label_paths = sorted(glob.glob(os.path.join(self.img_folder, '*.jpg')))
        elif self.dataset == 'jhmdb21':
            self.label_paths = sorted(glob.glob(os.path.join(self.img_folder, '*.png')))

    def __len__(self):
        return len(self.label_paths)

    def __getitem__(self, index):
        return self.pull_item(index)

    def pull_item(self, index):
        image_path = self.label_paths[index]

        video_split = self.line.split('/')
        video_class = video_split[0]
        video_file = video_split[1]

        # key-frame id, parsed from the file name (use basename so that
        # glob results parse correctly on any OS)
        img_id = int(os.path.basename(image_path)[:5])
        max_num = len(os.listdir(self.img_folder))
        if self.dataset == 'ucf24':
            img_name = os.path.join(video_class, video_file, '{:05d}.jpg'.format(img_id))
        elif self.dataset == 'jhmdb21':
            img_name = os.path.join(video_class, video_file, '{:05d}.png'.format(img_id))

        # load the clip that ends at the key frame
        # (note: the temporal stride is fixed at 1 here; sampling_rate is unused)
        video_clip = []
        for i in reversed(range(self.len_clip)):
            # clamp the frame id to the valid range [1, max_num]
            img_id_temp = img_id - i
            if img_id_temp < 1:
                img_id_temp = 1
            elif img_id_temp > max_num:
                img_id_temp = max_num

            if self.dataset == 'ucf24':
                path_tmp = os.path.join(self.data_root, 'rgb-images', video_class, video_file, '{:05d}.jpg'.format(img_id_temp))
            elif self.dataset == 'jhmdb21':
                path_tmp = os.path.join(self.data_root, 'rgb-images', video_class, video_file, '{:05d}.png'.format(img_id_temp))
            frame = Image.open(path_tmp).convert('RGB')
            ow, oh = frame.width, frame.height

            video_clip.append(frame)

        # transform without normalization; there are no boxes at inference time
        video_clip, _ = self.transform(video_clip, normalize=False)
        # List[Tensor] -> Tensor of shape [C, T, H, W]
        video_clip = torch.stack(video_clip, dim=1)

        target = {'orig_size': [ow, oh]}

        return img_name, video_clip, target
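
# A minimal usage sketch of UCF_JHMDB_VIDEO_Dataset (the video path and
# transform below are illustrative): a video is selected via set_video_data()
# before indexing, so a single instance can be reused across videos.
#
#   dataset = UCF_JHMDB_VIDEO_Dataset(data_root, dataset='ucf24',
#                                     transform=BaseTransform(img_size=224))
#   dataset.set_video_data('Basketball/v_Basketball_g01_c01')
#   for i in range(len(dataset)):
#       img_name, video_clip, target = dataset[i]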


if __name__ == '__main__':
    import cv2
    from transforms import Augmentation, BaseTransform

    data_root = 'D:/python_work/spatial-temporal_action_detection/dataset/ucf24'
    dataset = 'ucf24'
    is_train = True
    img_size = 224
    len_clip = 16
    trans_config = {
        'jitter': 0.2,
        'hue': 0.1,
        'saturation': 1.5,
        'exposure': 1.5
    }
    train_transform = Augmentation(
        img_size=img_size,
        jitter=trans_config['jitter'],
        saturation=trans_config['saturation'],
        exposure=trans_config['exposure']
    )
    val_transform = BaseTransform(img_size=img_size)

    train_dataset = UCF_JHMDB_Dataset(
        data_root=data_root,
        dataset=dataset,
        img_size=img_size,
        transform=train_transform,
        is_train=is_train,
        len_clip=len_clip,
        sampling_rate=1
    )

    print(len(train_dataset))
    for i in range(len(train_dataset)):
        frame_id, video_clip, target = train_dataset[i]
        # the key frame is the last frame of the clip: [C, H, W]
        key_frame = video_clip[:, -1, :, :]

        # [C, H, W] -> [H, W, C]
        key_frame = key_frame.permute(1, 2, 0).numpy()
        key_frame = key_frame.astype(np.uint8)

        # RGB -> BGR for OpenCV display
        key_frame = key_frame[..., (2, 1, 0)]
        H, W, C = key_frame.shape

        key_frame = key_frame.copy()
        bboxes = target['boxes']
        labels = target['labels']

        # boxes are normalized [x1, y1, x2, y2]; scale to pixel coordinates
        for box, cls_id in zip(bboxes, labels):
            x1, y1, x2, y2 = box
            x1 = int(x1 * W)
            y1 = int(y1 * H)
            x2 = int(x2 * W)
            y2 = int(y2 * H)
            key_frame = cv2.rectangle(key_frame, (x1, y1), (x2, y2), (255, 0, 0))

        cv2.imshow('key frame', key_frame)
        cv2.waitKey(0)