File size: 4,568 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Callable, List, Optional, Union
import mmengine
import numpy as np
import torch
from mmengine.fileio import exists
from mmaction.registry import DATASETS
from mmaction.utils import ConfigType
from .base import BaseActionDataset
try:
import nltk
nltk_imported = True
except ImportError:
nltk_imported = False
@DATASETS.register_module()
class CharadesSTADataset(BaseActionDataset):
    """Charades-STA dataset for temporal sentence grounding.

    Each annotation line has the form
    ``"<vid_name> <start_time> <end_time>##<query sentence>"``. The query is
    tokenized with nltk and mapped to ids through ``word2id``; the ground-truth
    segment is normalized by the video duration into ``gt_bbox``. Sliding-window
    proposals are generated per video and later mapped onto pre-extracted
    features stored as ``<vid_name>.pt`` under ``data_prefix['video']``.

    Args:
        ann_file (str): Path to the annotation file.
        pipeline (list): Processing pipeline of dicts or callables.
        word2id_file (str): Path to the word-to-id mapping file.
        fps_file (str): Path to the per-video FPS mapping file.
        duration_file (str): Path to the per-video duration mapping file.
        num_frames_file (str): Path to the per-video frame-count mapping file.
        window_size (int): Feature window size (frames per feature window).
        ft_overlap (float): Overlap ratio between adjacent feature windows,
            in ``[0, 1)``; the feature stride is ``window_size * (1 - ft_overlap)``.
        data_prefix (dict, optional): Path prefixes; key ``'video'`` locates
            the per-video feature files. Defaults to ``dict(video='')``.
        test_mode (bool): Store ``True`` when building the test dataset.
            Defaults to ``False``.
    """

    def __init__(self,
                 ann_file: str,
                 pipeline: List[Union[dict, Callable]],
                 word2id_file: str,
                 fps_file: str,
                 duration_file: str,
                 num_frames_file: str,
                 window_size: int,
                 ft_overlap: float,
                 data_prefix: Optional[ConfigType] = None,
                 test_mode: bool = False,
                 **kwargs):
        if not nltk_imported:
            raise ImportError('nltk is required for CharadesSTADataset')
        # Avoid the shared-mutable-default pitfall: build the default dict
        # per instance instead of using `dict(video='')` in the signature.
        if data_prefix is None:
            data_prefix = dict(video='')
        self.fps_info = mmengine.load(fps_file)
        self.duration_info = mmengine.load(duration_file)
        self.num_frames = mmengine.load(num_frames_file)
        self.word2id = mmengine.load(word2id_file)
        # Stride (in frames) between consecutive feature windows.
        self.ft_interval = int(window_size * (1 - ft_overlap))
        super().__init__(
            ann_file,
            pipeline=pipeline,
            data_prefix=data_prefix,
            test_mode=test_mode,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load the annotation file and build per-sample info dicts."""
        # `exists` returns a bool; the original code discarded it, so a
        # missing file was only caught later with a confusing error. Fail fast.
        if not exists(self.ann_file):
            raise FileNotFoundError(
                f'Annotation file {self.ann_file} does not exist.')
        data_list = []
        with open(self.ann_file) as f:
            anno_database = f.readlines()
        for item in anno_database:
            first_part, query_sentence = item.strip().split('##')
            query_sentence = query_sentence.replace('.', '')
            query_words = nltk.word_tokenize(query_sentence)
            # NOTE: raises KeyError on out-of-vocabulary words by design —
            # the vocabulary file must cover the annotation queries.
            query_tokens = [self.word2id[word] for word in query_words]
            query_length = len(query_tokens)
            query_tokens = torch.from_numpy(np.array(query_tokens))

            vid_name, start_time, end_time = first_part.split()
            duration = float(self.duration_info[vid_name])
            fps = float(self.fps_info[vid_name])

            gt_start_time = float(start_time)
            gt_end_time = float(end_time)
            # Normalized (start, end); end is clipped since some annotations
            # exceed the recorded video duration.
            gt_bbox = (gt_start_time / duration,
                       min(gt_end_time / duration, 1))

            num_frames = int(self.num_frames[vid_name])
            proposal_frames = self.get_proposals(num_frames)
            # Normalized proposal boundaries in [0, 1].
            proposals = proposal_frames / num_frames
            proposals = torch.from_numpy(proposals)
            # Map frame boundaries to feature-window indices.
            proposal_indexes = proposal_frames / self.ft_interval
            proposal_indexes = proposal_indexes.astype(np.int32)

            info = dict(
                vid_name=vid_name,
                fps=fps,
                num_frames=num_frames,
                duration=duration,
                query_tokens=query_tokens,
                query_length=query_length,
                gt_start_time=gt_start_time,
                gt_end_time=gt_end_time,
                gt_bbox=gt_bbox,
                proposals=proposals,
                num_proposals=proposals.shape[0],
                proposal_indexes=proposal_indexes)
            data_list.append(info)
        return data_list

    def get_proposals(self, num_frames: int,
                      num_proposals: int = 32) -> np.ndarray:
        """Split ``[0, num_frames - 1]`` into equal proposal segments.

        Args:
            num_frames (int): Number of frames in the video.
            num_proposals (int): Number of segments to generate.
                Defaults to ``32`` (the original hard-coded value).

        Returns:
            np.ndarray: ``(num_proposals, 2)`` int32 array of
            ``(start_frame, end_frame)`` pairs.
        """
        boundaries = (num_frames - 1) / num_proposals * np.arange(
            num_proposals + 1)
        boundaries = boundaries.astype(np.int32)
        # Pair consecutive boundaries into (start, end) rows.
        proposals = np.stack([boundaries[:-1], boundaries[1:]]).T
        return proposals

    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index and attach max-pooled proposal features."""
        data_info = super().get_data_info(idx)
        vid_name = data_info['vid_name']
        feature_path = os.path.join(self.data_prefix['video'],
                                    f'{vid_name}.pt')
        vid_feature = torch.load(feature_path)
        proposal_feats = []
        # Clip so proposals past the last feature window stay in bounds.
        proposal_indexes = data_info['proposal_indexes'].clip(
            max=vid_feature.shape[0] - 1)
        for s, e in proposal_indexes:
            # Max-pool the feature windows covered by the proposal (inclusive).
            prop_feature, _ = vid_feature[s:e + 1].max(dim=0)
            proposal_feats.append(prop_feature)
        proposal_feats = torch.stack(proposal_feats)
        data_info['raw_feature'] = proposal_feats
        return data_info
|