| import os |
| import json |
| import numpy as np |
|
|
| import torch |
| from torch.utils.data import Dataset |
| from torch.nn import functional as F |
|
|
| from .datasets import register_dataset |
| from .data_utils import truncate_feats |
| from IPython import embed |
|
|
@register_dataset("vidf")
class VidF(Dataset):
    """Video temporal-grounding dataset read from ``<data_dir>/<split>.txt`` files.

    Each annotation line has the form ``<video> <duration> <s>=<e>+<s>=<e>...``.
    Lines of a split whose name contains ``"real"`` carry only
    ``<video> <duration>`` (no segments). Per-video features are loaded from
    ``.npy`` files and resampled to ``num_frames`` clips in ``__getitem__``.
    """

    def __init__(
        self,
        is_training,
        split,
        feat_folder,
        json_file,
        feat_stride,
        num_frames,
        default_fps,
        downsample_rate,
        max_seq_len,
        trunc_thresh,
        crop_ratio,
        input_dim,
        num_classes,
        file_prefix,
        file_ext,
        force_upsampling,
        **kwargs,
    ):
        """Load the annotations of every requested split.

        Args:
            split: a split name or a sequence of split names; each one is read
                from ``<data_dir>/<name>.txt``.
            num_frames: number of clips every feature sequence is resampled to.
            input_dim: feature dimension (stored on the instance).
            **kwargs: must provide ``version``, ``train_annot_num`` and
                ``test_annot_num``; ``real_ratio`` is read for ``"real"``
                splits; an optional ``data_dir`` overrides the default
                dataset directory.
            The remaining arguments are accepted for interface compatibility
            and are not used by this class.
        """
        self.num_frames = num_frames
        self.input_dim = input_dim

        self.db_attributes = {
            'dataset_name': 'vidf',
            # tIoU thresholds 0.3, 0.4, 0.5, 0.6, 0.7
            'tiou_thresholds': np.linspace(0.3, 0.7, 5),
            'empty_label_ids': [],
        }

        self.version = kwargs['version']
        # The default location is machine-specific, so allow an explicit override.
        default_dir = f'/home/users/xxx/scratch/dataset/vidf/{self.version}'
        self.data_dir = kwargs.get('data_dir') or default_dir
        assert os.path.exists(self.data_dir), 'Please specify data_dir'

        requested = [split] if isinstance(split, str) else list(split)
        # Non-"real" splits go first: a "real" split is capped below by the
        # number of annotations already collected from the preceding splits.
        self.split = (
            [s for s in requested if "real" not in s]
            + [s for s in requested if "real" in s]
        )

        annotations = []
        for split_itm in self.split:
            tmp_annotations = self._read_split_file(split_itm)
            # NOTE(review): this assert rejects "real" splits outright, so the
            # branch below only runs when assertions are disabled (python -O).
            assert 'real' not in split_itm
            if 'real' in split_itm:
                num_by_ratio = int(len(tmp_annotations) * kwargs['real_ratio'])
                num_kept = min(num_by_ratio, len(annotations))
                tmp_annotations = tmp_annotations[:num_kept]
            annotations += tmp_annotations

        # Use the normalized list: for a plain-string ``split``, ``split[0]``
        # would be just its first character and never match 'train'.
        if 'train' in requested[0]:
            annot_num = kwargs['train_annot_num']
        else:
            annot_num = kwargs['test_annot_num']
        # Evenly subsample ``annot_num`` annotations (skipped when none loaded,
        # where linspace(0, -1, ...) would produce invalid indices).
        if annot_num > 0 and annotations:
            indices = np.linspace(0, len(annotations) - 1, annot_num, dtype=int)
            annotations = [annotations[i] for i in indices]

        self.annotations = annotations
        self.feature_type = 'clipL14'

    def _read_split_file(self, split_itm):
        """Parse ``<data_dir>/<split_itm>.txt`` into a list of annotation dicts.

        Each dict holds ``video`` (str), ``times`` (list of ``[start, end]``
        floats, with ``end`` clipped to the duration) and ``duration`` (float).
        Splits whose name contains ``"real"`` yield empty ``times``.
        """
        anno_path = os.path.join(self.data_dir, "{}.txt".format(split_itm))
        is_real = 'real' in split_itm
        split_annotations = []
        # The context manager closes the file even if a line fails to parse.
        with open(anno_path, 'r') as anno_file:
            for line in anno_file:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines, e.g. a trailing newline
                if is_real:
                    vid, duration = fields
                    pairs = []
                else:
                    vid, duration, time_str = fields
                    pairs = [x.split('=') for x in time_str.split('+')]
                duration = float(duration)

                time_list = []
                for p in pairs:
                    assert len(p) == 2, f"Invalid format: '{'='.join(p)}' is not in start=end format"
                    start_str, end_str = p
                    # Segments may not extend past the end of the video.
                    time_list.append([float(start_str), min(float(end_str), duration)])

                split_annotations.append(
                    {'video': vid, 'times': time_list, 'duration': duration})
        return split_annotations

    def get_attributes(self):
        """Return the dataset-level attribute dict built in ``__init__``."""
        return self.db_attributes

    def _load_json_db(self, json_file):
        """Load a JSON annotation database (``{'database': {...}}``).

        Returns ``(dict_db, label_dict)`` where ``dict_db`` is a tuple of
        per-video dicts for videos in ``self.split`` whose feature file exists.

        NOTE(review): not called within this class, and it reads attributes
        (``feat_folder``, ``file_prefix``, ``file_ext``, ``default_fps``) that
        ``__init__`` never sets -- looks like a leftover from another dataset
        class; confirm before relying on it.
        """
        with open(json_file, 'r') as fid:
            json_data = json.load(fid)
        json_db = json_data['database']

        # Reuse a preset label mapping if one exists, otherwise build it
        # from the annotations (previously ``label_dict`` was left unbound
        # whenever ``self.label_dict`` was set).
        label_dict = getattr(self, 'label_dict', None)
        if label_dict is None:
            label_dict = {}
            for value in json_db.values():
                for act in value['annotations']:
                    label_dict[act['label']] = act['label_id']

        entries = []
        for key, value in json_db.items():
            if value['subset'].lower() not in self.split:
                continue

            feat_file = os.path.join(self.feat_folder,
                                     self.file_prefix + key + self.file_ext)
            if not os.path.exists(feat_file):
                continue

            if self.default_fps is not None:
                fps = self.default_fps
            elif 'fps' in value:
                fps = value['fps']
            else:
                assert False, "Unknown video FPS."

            # Videos without a duration get a large sentinel value.
            duration = value['duration'] if 'duration' in value else 1e8

            if ('annotations' in value) and (len(value['annotations']) > 0):
                segments = np.asarray(
                    [act['segment'] for act in value['annotations']],
                    dtype=np.float32)
                labels = np.asarray(
                    [label_dict[act['label']] for act in value['annotations']],
                    dtype=np.int64)
            else:
                segments = None
                labels = None
            entries.append({'id': key,
                            'fps': fps,
                            'duration': duration,
                            'segments': segments,
                            'labels': labels})

        # Build a list and convert once instead of re-copying a tuple per video.
        return tuple(entries), label_dict

    def __len__(self):
        """Return the number of loaded annotations."""
        return len(self.annotations)

    @staticmethod
    def _average_to_fixed_length(visual_input, num_sample_clips):
        """Resample a 2-D ``(T, C)`` feature sequence to ``(num_sample_clips, C)``.

        Each output clip is the mean of the input rows in its bin; when a bin
        is empty, the row at the bin start is copied instead.
        """
        num_clips = visual_input.shape[0]
        idxs = torch.arange(0, num_sample_clips + 1, 1.0) / num_sample_clips * num_clips
        # Bin edges are clamped to the last valid row index, so the final row
        # only contributes through the empty-bin fallback.
        idxs = torch.min(torch.round(idxs).long(), torch.tensor(num_clips - 1))
        resampled = []
        for i in range(num_sample_clips):
            s_idx, e_idx = idxs[i].item(), idxs[i + 1].item()
            if s_idx < e_idx:
                resampled.append(torch.mean(visual_input[s_idx:e_idx], dim=0))
            else:
                resampled.append(visual_input[s_idx])
        return torch.stack(resampled, dim=0)

    def __getitem__(self, idx):
        """Return resampled features and scaled segments for annotation ``idx``.

        Segment boundaries are expressed in resampled-clip units
        (``time / duration * num_frames``); every segment gets label 0.
        Note that ``video_id`` is the annotation index as a string.
        """
        anno = self.annotations[idx]
        video_id = anno['video'].split('.mp4')[0]
        visual_input = self.get_video_features(video_id)

        visual_input = self._average_to_fixed_length(visual_input, self.num_frames)
        # (num_frames, C) -> (C, num_frames)
        feats = visual_input.permute(1, 0)

        # reshape keeps a (0, 2) shape for annotations without any segment.
        times = torch.tensor(anno['times']).reshape(-1, 2)
        segments = times / anno['duration'] * self.num_frames
        labels = torch.zeros((times.shape[0],)).long()

        return {'video_id': str(idx),
                'feats': feats,
                'segments': segments,
                'labels': labels,
                'feat_num_frames': self.num_frames,
                'duration': anno['duration'],
                'gt_time': anno['times'],
                }

    def get_video_features(self, vid):
        """Load the pre-extracted features of video ``vid`` as a float tensor.

        Raises:
            ValueError: if ``self.feature_type`` is not a supported type
                (previously this surfaced as an UnboundLocalError).
        """
        if 'clipL14' not in self.feature_type:
            raise ValueError(f'Unsupported feature type: {self.feature_type}')
        feat_path = os.path.join(self.data_dir, f'../feat/01a.2a_L14/{vid}.npy')
        features = np.load(feat_path)
        return torch.from_numpy(features).float()
|
|