import os
from typing import Callable, List, Optional, Union

import mmengine
import numpy as np
import torch
from mmengine.fileio import exists

from mmaction.registry import DATASETS
from mmaction.utils import ConfigType
from .base import BaseActionDataset

try:
    import nltk
    nltk_imported = True
except ImportError:
    nltk_imported = False


@DATASETS.register_module()
class CharadesSTADataset(BaseActionDataset):
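    """Charades-STA dataset for temporal sentence grounding.

    Each annotation line has the form ``"VID START END##query sentence."``:
    the query is tokenized with nltk and encoded through the ``word2id``
    vocabulary, while per-video fps, duration and frame counts are looked
    up in the auxiliary files. ``window_size`` and ``ft_overlap`` describe
    the sliding window presumably used when the video features were
    extracted; together they give ``ft_interval``, the stride between
    adjacent feature windows.
    """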

    def __init__(self,
                 ann_file: str,
                 pipeline: List[Union[dict, Callable]],
                 word2id_file: str,
                 fps_file: str,
                 duration_file: str,
                 num_frames_file: str,
                 window_size: int,
                 ft_overlap: float,
                 data_prefix: Optional[ConfigType] = dict(video=''),
                 test_mode: bool = False,
                 **kwargs):
        if not nltk_imported:
            raise ImportError('nltk is required for CharadesSTADataset')

        self.fps_info = mmengine.load(fps_file)
        self.duration_info = mmengine.load(duration_file)
        self.num_frames = mmengine.load(num_frames_file)
        self.word2id = mmengine.load(word2id_file)
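        # Stride (in frames) between the starts of adjacent feature
        # windows, assuming `window_size` and `ft_overlap` match the
        # settings used for feature extraction.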
        self.ft_interval = int(window_size * (1 - ft_overlap))

        super().__init__(
            ann_file,
            pipeline=pipeline,
            data_prefix=data_prefix,
            test_mode=test_mode,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotation file to get video information."""
        exists(self.ann_file)
        data_list = []
        with open(self.ann_file) as f:
            anno_database = f.readlines()

        for item in anno_database:
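            # Annotation format: 'VID START END##query sentence.'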
            first_part, query_sentence = item.strip().split('##')
            query_sentence = query_sentence.replace('.', '')
            query_words = nltk.word_tokenize(query_sentence)
            query_tokens = [self.word2id[word] for word in query_words]
            query_length = len(query_tokens)
            query_tokens = torch.from_numpy(np.array(query_tokens))

            vid_name, start_time, end_time = first_part.split()
            duration = float(self.duration_info[vid_name])
            fps = float(self.fps_info[vid_name])

            gt_start_time = float(start_time)
            gt_end_time = float(end_time)

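            # Normalize the ground-truth segment to [0, 1]; the end is
            # clamped in case an annotated end time exceeds the duration.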
            gt_bbox = (gt_start_time / duration,
                       min(gt_end_time / duration, 1))

            num_frames = int(self.num_frames[vid_name])
            proposal_frames = self.get_proposals(num_frames)

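            # Express proposals both as fractions of the video length and
            # as integer indices into the extracted feature sequence.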
            proposals = proposal_frames / num_frames
            proposals = torch.from_numpy(proposals)
            proposal_indexes = proposal_frames / self.ft_interval
            proposal_indexes = proposal_indexes.astype(np.int32)

            info = dict(
                vid_name=vid_name,
                fps=fps,
                num_frames=num_frames,
                duration=duration,
                query_tokens=query_tokens,
                query_length=query_length,
                gt_start_time=gt_start_time,
                gt_end_time=gt_end_time,
                gt_bbox=gt_bbox,
                proposals=proposals,
                num_proposals=proposals.shape[0],
                proposal_indexes=proposal_indexes)
            data_list.append(info)
        return data_list

    def get_proposals(self, num_frames):
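        """Split the video into 32 equal-length segment proposals.

        Pairs 33 evenly spaced frame boundaries into a ``(32, 2)`` array
        of ``(start, end)`` frame indices.
        """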
        proposals = (num_frames - 1) / 32 * np.arange(33)
        proposals = proposals.astype(np.int32)
        proposals = np.stack([proposals[:-1], proposals[1:]]).T
        return proposals

    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index."""
        data_info = super().get_data_info(idx)
        vid_name = data_info['vid_name']
        feature_path = os.path.join(self.data_prefix['video'],
                                    f'{vid_name}.pt')
        vid_feature = torch.load(feature_path)
        proposal_feats = []
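        # Clamp feature indices so no proposal reads past the end of the
        # loaded feature sequence.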
        proposal_indexes = data_info['proposal_indexes'].clip(
            max=vid_feature.shape[0] - 1)
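        # Max-pool the window features inside each proposal interval.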
        for s, e in proposal_indexes:
            prop_feature, _ = vid_feature[s:e + 1].max(dim=0)
            proposal_feats.append(prop_feature)

        proposal_feats = torch.stack(proposal_feats)

        data_info['raw_feature'] = proposal_feats
        return data_info