|
|
|
|
|
import os.path as osp
|
|
|
from collections import defaultdict
|
|
|
from typing import Callable, List, Optional, Union
|
|
|
|
|
|
import numpy as np
|
|
|
from mmengine.fileio import exists, list_from_file, load
|
|
|
from mmengine.logging import MMLogger
|
|
|
|
|
|
from mmaction.evaluation import read_labelmap
|
|
|
from mmaction.registry import DATASETS
|
|
|
from mmaction.utils import ConfigType
|
|
|
from .base import BaseActionDataset
|
|
|
|
|
|
|
|
|
@DATASETS.register_module()
class AVADataset(BaseActionDataset):
    """STAD dataset for spatial temporal action detection.

    The dataset loads raw frames/video files, bounding boxes,
    proposals and applies specified transformations to return
    a dict containing the frame tensors and other information.

    This dataset can load information from the following files:

    .. code-block:: txt

        ann_file -> ava_{train, val}_{v2.1, v2.2}.csv
        exclude_file -> ava_{train, val}_excluded_timestamps_{v2.1, v2.2}.csv
        label_file -> ava_action_list_{v2.1, v2.2}.pbtxt /
                      ava_action_list_{v2.1, v2.2}_for_activitynet_2019.pbtxt
        proposal_file -> ava_dense_proposals_{train, val}.FAIR.recall_93.9.pkl

    Particularly, the proposal_file is a pickle file which contains
    ``img_key`` (in format of ``{video_id},{timestamp}``). Example of a pickle
    file:

    .. code-block:: JSON

        {
            ...
            '0f39OWEqJ24,0902':
                array([[0.011   , 0.157   , 0.655   , 0.983   , 0.998163]]),
            '0f39OWEqJ24,0912':
                array([[0.054   , 0.088   , 0.91    , 0.998   , 0.068273],
                       [0.016   , 0.161   , 0.519   , 0.974   , 0.984025],
                       [0.493   , 0.283   , 0.981   , 0.984   , 0.983621]]),
            ...
        }

    Args:
        ann_file (str): Path to the annotation file like
            ``ava_{train, val}_{v2.1, v2.2}.csv``.
        pipeline (List[Union[dict, ConfigDict, Callable]]): A sequence of
            data transforms.
        exclude_file (str, optional): Path to the excluded timestamp file
            like ``ava_{train, val}_excluded_timestamps_{v2.1, v2.2}.csv``.
            When it is None, no sample is filtered out. Defaults to None.
        label_file (str, optional): Path to the label file like
            ``ava_action_list_{v2.1, v2.2}.pbtxt`` or
            ``ava_action_list_{v2.1, v2.2}_for_activitynet_2019.pbtxt``.
            It is only read when ``custom_classes`` is given.
            Defaults to None.
        filename_tmpl (str): Template for each filename.
            Defaults to 'img_{:05}.jpg'.
        start_index (int): Specify a start index for frames in consideration of
            different filename format. It should be set to 1 for AVA, since
            frame index start from 1 in AVA dataset. Defaults to 1.
        proposal_file (str, optional): Path to the proposal file like
            ``ava_dense_proposals_{train, val}.FAIR.recall_93.9.pkl``.
            Defaults to None.
        person_det_score_thr (float): The threshold of person detection scores,
            bboxes with scores no lower than the threshold will be used.
            Note that 0 <= person_det_score_thr <= 1. If no proposal has
            detection score larger than the threshold, the one with the largest
            detection score will be used. Default: 0.9.
        num_classes (int): The number of classes of the dataset. Default: 81.
            (AVA has 80 action classes, another 1-dim is added for potential
            usage)
        custom_classes (List[int], optional): A subset of class ids from origin
            dataset. Please note that 0 should NOT be selected, and
            ``num_classes`` should be equal to ``len(custom_classes) + 1``.
        data_prefix (dict or ConfigDict): Path to a directory where video
            frames are held. Defaults to ``dict(img='')``.
        modality (str): Modality of data. Support ``RGB``, ``Flow``.
            Defaults to ``RGB``.
        test_mode (bool): Store True when building test or validation dataset.
            Defaults to False.
        num_max_proposals (int): Max proposals number to store.
            Defaults to 1000.
        timestamp_start (int): The start point of included timestamps.
            Defaults to 900.
        timestamp_end (int): The end point of included timestamps.
            Defaults to 1800.
        use_frames (bool): Whether to use rawframes as input.
            Defaults to True.
        fps (int): Overrides the default FPS for the dataset. If set to 1,
            means counting timestamp by frame, e.g. MultiSports dataset.
            Otherwise by second. Defaults to 30.
        multilabel (bool): Determines whether it is a multilabel recognition
            task. Defaults to True.
    """

    def __init__(self,
                 ann_file: str,
                 pipeline: List[Union[ConfigType, Callable]],
                 exclude_file: Optional[str] = None,
                 label_file: Optional[str] = None,
                 filename_tmpl: str = 'img_{:05}.jpg',
                 start_index: int = 1,
                 proposal_file: Optional[str] = None,
                 person_det_score_thr: float = 0.9,
                 num_classes: int = 81,
                 custom_classes: Optional[List[int]] = None,
                 data_prefix: ConfigType = dict(img=''),
                 modality: str = 'RGB',
                 test_mode: bool = False,
                 num_max_proposals: int = 1000,
                 timestamp_start: int = 900,
                 timestamp_end: int = 1800,
                 use_frames: bool = True,
                 fps: int = 30,
                 multilabel: bool = True,
                 **kwargs) -> None:
        self._FPS = fps
        self.custom_classes = custom_classes
        if custom_classes is not None:
            assert num_classes == len(custom_classes) + 1
            assert 0 not in custom_classes
            # Close the label file deterministically instead of leaking the
            # handle to the garbage collector.
            with open(label_file) as f:
                _, class_whitelist = read_labelmap(f)
            assert set(custom_classes).issubset(class_whitelist)

            # Index 0 is reserved, so selected class ``c`` is remapped to
            # ``self.custom_classes.index(c)`` (always >= 1).
            self.custom_classes = [0, *custom_classes]
        self.exclude_file = exclude_file
        self.label_file = label_file
        self.proposal_file = proposal_file
        assert 0 <= person_det_score_thr <= 1, (
            'The value of '
            'person_det_score_thr should in [0, 1]. ')
        self.person_det_score_thr = person_det_score_thr
        self.timestamp_start = timestamp_start
        self.timestamp_end = timestamp_end
        self.num_max_proposals = num_max_proposals
        self.filename_tmpl = filename_tmpl
        self.use_frames = use_frames
        self.multilabel = multilabel

        # All attributes above are assigned before the parent constructor
        # runs, so the loading hooks defined below can rely on them if the
        # parent invokes them during initialisation.
        super().__init__(
            ann_file,
            pipeline=pipeline,
            data_prefix=data_prefix,
            test_mode=test_mode,
            num_classes=num_classes,
            start_index=start_index,
            modality=modality,
            **kwargs)

        if self.proposal_file is not None:
            self.proposals = load(self.proposal_file)
        else:
            self.proposals = None

    def parse_img_record(self, img_records: List[dict]) -> tuple:
        """Merge image records of the same entity at the same time.

        Records that share exactly the same ``entity_box`` are treated as
        one entity and their labels are merged.

        Args:
            img_records (List[dict]): List of img_records (lines in AVA
                annotations).

        Returns:
            Tuple(list): A tuple consists of lists of bboxes, action labels and
            entity_ids.
        """
        bboxes, labels, entity_ids = [], [], []
        while len(img_records) > 0:
            img_record = img_records[0]
            entity_box = img_record['entity_box']

            # The head record always belongs to its own group, so the loop
            # makes progress on every iteration.
            selected_records = [img_record]
            remaining_records = []
            for record in img_records[1:]:
                if np.array_equal(record['entity_box'], entity_box):
                    selected_records.append(record)
                else:
                    remaining_records.append(record)
            img_records = remaining_records

            bboxes.append(entity_box)
            valid_labels = np.array([
                selected_record['label']
                for selected_record in selected_records
            ])

            if self.multilabel:
                # Multi-hot encoding over all classes.
                label = np.zeros(self.num_classes, dtype=np.float32)
                label[valid_labels] = 1.
            else:
                label = valid_labels

            labels.append(label)
            entity_ids.append(img_record['entity_id'])
        bboxes = np.stack(bboxes)
        labels = np.stack(labels)
        entity_ids = np.stack(entity_ids)
        return bboxes, labels, entity_ids

    def load_data_list(self) -> List[dict]:
        """Load AVA annotations.

        Each annotation line is split on commas and read as
        ``video_id, timestamp, 4 box coordinates, label, entity_id``. Lines
        are grouped by ``{video_id},{timestamp:04d}`` and every group becomes
        one sample.

        Returns:
            List[dict]: One info dict per annotated image key.

        Raises:
            FileNotFoundError: If ``self.ann_file`` does not exist.
        """
        if not exists(self.ann_file):
            raise FileNotFoundError(
                f'Annotation file {self.ann_file} does not exist.')
        data_list = []
        records_dict_by_img = defaultdict(list)

        # The shot range only depends on dataset-level settings, so it is
        # computed once and shared by every sample.
        if self.use_frames:
            shot_info = (0, (self.timestamp_end - self.timestamp_start) *
                         self._FPS)
        else:
            shot_info = None

        for line in list_from_file(self.ann_file):
            line_split = line.strip().split(',')

            label = int(line_split[6])
            if self.custom_classes is not None:
                if label not in self.custom_classes:
                    continue
                label = self.custom_classes.index(label)

            video_id = line_split[0]
            timestamp = int(line_split[1])
            img_key = f'{video_id},{timestamp:04d}'

            entity_box = np.array(list(map(float, line_split[2:6])))
            entity_id = int(line_split[7])

            video_info = dict(
                video_id=video_id,
                timestamp=timestamp,
                entity_box=entity_box,
                label=label,
                entity_id=entity_id,
                shot_info=shot_info)
            records_dict_by_img[img_key].append(video_info)

        for img_key, img_records in records_dict_by_img.items():
            video_id, timestamp = img_key.split(',')
            bboxes, labels, entity_ids = self.parse_img_record(img_records)
            ann = dict(
                gt_bboxes=bboxes, gt_labels=labels, entity_ids=entity_ids)
            frame_dir = video_id
            if self.data_prefix['img'] is not None:
                frame_dir = osp.join(self.data_prefix['img'], frame_dir)
            video_info = dict(
                frame_dir=frame_dir,
                video_id=video_id,
                timestamp=int(timestamp),
                img_key=img_key,
                shot_info=shot_info,
                fps=self._FPS,
                ann=ann)
            if not self.use_frames:
                # Video input is addressed by file name, not frame folder.
                video_info['filename'] = video_info.pop('frame_dir')
            data_list.append(video_info)

        return data_list

    def filter_data(self) -> List[dict]:
        """Filter out records in the exclude_file.

        Returns:
            List[dict]: Items of ``self.data_list`` whose
            ``(video_id, timestamp)`` pair is not listed in
            ``self.exclude_file``. Everything is kept when no exclude file
            is configured.
        """
        if self.exclude_file is None:
            valid_indexes = list(range(len(self.data_list)))
        else:
            # Read the exclusion list once into a set so that every sample
            # is checked with a single lookup; blank lines are ignored.
            with open(self.exclude_file) as f:
                excluded = set()
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    video_id, timestamp = line.split(',')
                    excluded.add((video_id, int(timestamp)))
            valid_indexes = [
                i for i, data_info in enumerate(self.data_list)
                if (data_info['video_id'],
                    data_info['timestamp']) not in excluded
            ]

        logger = MMLogger.get_current_instance()
        logger.info(f'{len(valid_indexes)} out of {len(self.data_list)}'
                    f' frames are valid.')
        data_list = [self.data_list[i] for i in valid_indexes]

        return data_list

    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        On top of the parent's data info, this adds the filename template,
        the timestamp range, the person proposals (when a proposal file was
        loaded) and flattens the ``ann`` dict into top-level keys.

        Args:
            idx (int): Index of the sample.

        Returns:
            dict: The data info consumed by the pipeline.
        """
        data_info = super().get_data_info(idx)
        img_key = data_info['img_key']

        data_info['filename_tmpl'] = self.filename_tmpl
        data_info['timestamp_start'] = self.timestamp_start
        data_info['timestamp_end'] = self.timestamp_end

        if self.proposals is not None:
            if img_key not in self.proposals:
                # No detection for this key: fall back to one proposal
                # covering the whole (relative-coordinate) image.
                data_info['proposals'] = np.array([[0, 0, 1, 1]])
                data_info['scores'] = np.array([1])
            else:
                proposals = self.proposals[img_key]
                assert proposals.shape[-1] in [4, 5]
                if proposals.shape[-1] == 5:
                    scores = proposals[:, 4]
                    # Lower the threshold to the best score if nothing
                    # reaches it, so at least one proposal survives.
                    thr = min(self.person_det_score_thr, scores.max())
                    proposals = proposals[scores >= thr]
                    proposals = proposals[:self.num_max_proposals]
                    data_info['proposals'] = proposals[:, :4]
                    data_info['scores'] = proposals[:, 4]
                else:
                    proposals = proposals[:self.num_max_proposals]
                    data_info['proposals'] = proposals

            assert data_info['proposals'].max() <= 1 and \
                data_info['proposals'].min() >= 0, \
                (f'relative proposals invalid: max value '
                 f'{data_info["proposals"].max()}, min value '
                 f'{data_info["proposals"].min()}')

        ann = data_info.pop('ann')
        data_info['gt_bboxes'] = ann['gt_bboxes']
        data_info['gt_labels'] = ann['gt_labels']
        data_info['entity_ids'] = ann['entity_ids']

        return data_info
|
|
|
|
|
|
|
|
|
@DATASETS.register_module()
class AVAKineticsDataset(BaseActionDataset):
    """AVA-Kinetics dataset for spatial temporal detection.

    Based on official AVA annotation files, the dataset loads raw frames,
    bounding boxes, proposals and applies specified transformations to return
    a dict containing the frame tensors and other information.

    This dataset can load information from the following files:

    .. code-block:: txt

        ann_file -> ava_{train, val}_{v2.1, v2.2}.csv
        exclude_file -> ava_{train, val}_excluded_timestamps_{v2.1, v2.2}.csv
        label_file -> ava_action_list_{v2.1, v2.2}.pbtxt /
                      ava_action_list_{v2.1, v2.2}_for_activitynet_2019.pbtxt
        proposal_file -> ava_dense_proposals_{train, val}.FAIR.recall_93.9.pkl

    Particularly, the proposal_file is a pickle file which contains
    ``img_key`` (in format of ``{video_id},{timestamp}``). Example of a pickle
    file:

    .. code-block:: JSON

        {
            ...
            '0f39OWEqJ24,0902':
                array([[0.011   , 0.157   , 0.655   , 0.983   , 0.998163]]),
            '0f39OWEqJ24,0912':
                array([[0.054   , 0.088   , 0.91    , 0.998   , 0.068273],
                       [0.016   , 0.161   , 0.519   , 0.974   , 0.984025],
                       [0.493   , 0.283   , 0.981   , 0.984   , 0.983621]]),
            ...
        }

    Args:
        ann_file (str): Path to the annotation file like
            ``ava_{train, val}_{v2.1, v2.2}.csv``.
        exclude_file (str): Path to the excluded timestamp file like
            ``ava_{train, val}_excluded_timestamps_{v2.1, v2.2}.csv``.
            When it is None, no sample is filtered out.
        pipeline (List[Union[dict, ConfigDict, Callable]]): A sequence of
            data transforms.
        label_file (str): Path to the label file like
            ``ava_action_list_{v2.1, v2.2}.pbtxt`` or
            ``ava_action_list_{v2.1, v2.2}_for_activitynet_2019.pbtxt``.
            It is only read when ``custom_classes`` is given.
        filename_tmpl (str): Template for each filename.
            Defaults to 'img_{:05}.jpg'.
        start_index (int): Specify a start index for frames in consideration of
            different filename format. However, when taking frames as input,
            it should be set to 0, since frame indices start from 0.
            Defaults to 0.
        proposal_file (str, optional): Path to the proposal file like
            ``ava_dense_proposals_{train, val}.FAIR.recall_93.9.pkl``.
            Defaults to None.
        person_det_score_thr (float): The threshold of person detection scores,
            bboxes with scores no lower than the threshold will be used.
            Note that 0 <= person_det_score_thr <= 1. If no proposal has
            detection score larger than the threshold, the one with the largest
            detection score will be used. Default: 0.9.
        num_classes (int): The number of classes of the dataset. Default: 81.
            (AVA has 80 action classes, another 1-dim is added for potential
            usage)
        custom_classes (List[int], optional): A subset of class ids from origin
            dataset. Please note that 0 should NOT be selected, and
            ``num_classes`` should be equal to ``len(custom_classes) + 1``.
        data_prefix (dict or ConfigDict): Path to a directory where video
            frames are held. Defaults to ``dict(img='')``.
        modality (str): Modality of data. Support ``RGB``, ``Flow``.
            Defaults to ``RGB``.
        test_mode (bool): Store True when building test or validation dataset.
            Defaults to False.
        num_max_proposals (int): Max proposals number to store.
            Defaults to 1000.
        timestamp_start (int): The start point of included timestamps, used
            for videos whose id does not encode its own range.
            Defaults to 900.
        timestamp_end (int): The end point of included timestamps, used
            for videos whose id does not encode its own range.
            Defaults to 1800.
        fps (int): Overrides the default FPS for the dataset. Defaults to 30.
    """

    def __init__(self,
                 ann_file: str,
                 exclude_file: str,
                 pipeline: List[Union[ConfigType, Callable]],
                 label_file: str,
                 filename_tmpl: str = 'img_{:05}.jpg',
                 start_index: int = 0,
                 proposal_file: Optional[str] = None,
                 person_det_score_thr: float = 0.9,
                 num_classes: int = 81,
                 custom_classes: Optional[List[int]] = None,
                 data_prefix: ConfigType = dict(img=''),
                 modality: str = 'RGB',
                 test_mode: bool = False,
                 num_max_proposals: int = 1000,
                 timestamp_start: int = 900,
                 timestamp_end: int = 1800,
                 fps: int = 30,
                 **kwargs) -> None:
        self._FPS = fps
        self.custom_classes = custom_classes
        if custom_classes is not None:
            assert num_classes == len(custom_classes) + 1
            assert 0 not in custom_classes
            # Close the label file deterministically instead of leaking the
            # handle to the garbage collector.
            with open(label_file) as f:
                _, class_whitelist = read_labelmap(f)
            assert set(custom_classes).issubset(class_whitelist)

            # Index 0 is reserved, so selected class ``c`` is remapped to
            # ``self.custom_classes.index(c)`` (always >= 1).
            self.custom_classes = [0, *custom_classes]
        self.exclude_file = exclude_file
        self.label_file = label_file
        self.proposal_file = proposal_file
        assert 0 <= person_det_score_thr <= 1, (
            'The value of '
            'person_det_score_thr should in [0, 1]. ')
        self.person_det_score_thr = person_det_score_thr
        self.timestamp_start = timestamp_start
        self.timestamp_end = timestamp_end
        self.num_max_proposals = num_max_proposals
        self.filename_tmpl = filename_tmpl

        # All attributes above are assigned before the parent constructor
        # runs, so the loading hooks defined below can rely on them if the
        # parent invokes them during initialisation.
        super().__init__(
            ann_file,
            pipeline=pipeline,
            data_prefix=data_prefix,
            test_mode=test_mode,
            num_classes=num_classes,
            start_index=start_index,
            modality=modality,
            **kwargs)

        if self.proposal_file is not None:
            self.proposals = load(self.proposal_file)
        else:
            self.proposals = None

    def parse_img_record(self, img_records: List[dict]) -> tuple:
        """Merge image records of the same entity at the same time.

        Records that share exactly the same ``entity_box`` are treated as
        one entity and their labels are merged into one multi-hot vector.

        Args:
            img_records (List[dict]): List of img_records (lines in AVA
                annotations).

        Returns:
            Tuple(list): A tuple consists of lists of bboxes, action labels and
            entity_ids.
        """
        bboxes, labels, entity_ids = [], [], []
        while len(img_records) > 0:
            img_record = img_records[0]
            entity_box = img_record['entity_box']

            # The head record always belongs to its own group, so the loop
            # makes progress on every iteration.
            selected_records = [img_record]
            remaining_records = []
            for record in img_records[1:]:
                if np.array_equal(record['entity_box'], entity_box):
                    selected_records.append(record)
                else:
                    remaining_records.append(record)
            img_records = remaining_records

            bboxes.append(entity_box)
            valid_labels = np.array([
                selected_record['label']
                for selected_record in selected_records
            ])

            # Multi-hot encoding over all classes.
            label = np.zeros(self.num_classes, dtype=np.float32)
            label[valid_labels] = 1.

            labels.append(label)
            entity_ids.append(img_record['entity_id'])

        bboxes = np.stack(bboxes)
        labels = np.stack(labels)
        entity_ids = np.stack(entity_ids)
        return bboxes, labels, entity_ids

    def filter_data(self) -> List[dict]:
        """Filter out records in the exclude_file.

        Returns:
            List[dict]: Items of ``self.data_list`` whose
            ``(video_id, timestamp)`` pair is not listed in
            ``self.exclude_file``. Everything is kept when no exclude file
            is configured.
        """
        if self.exclude_file is None:
            valid_indexes = list(range(len(self.data_list)))
        else:
            # Read the exclusion list once into a set so that every sample
            # is checked with a single lookup; blank lines are ignored.
            with open(self.exclude_file) as f:
                excluded = set()
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    video_id, timestamp = line.split(',')
                    excluded.add((video_id, int(timestamp)))
            valid_indexes = [
                i for i, data_info in enumerate(self.data_list)
                if (data_info['video_id'],
                    data_info['timestamp']) not in excluded
            ]

        logger = MMLogger.get_current_instance()
        logger.info(f'{len(valid_indexes)} out of {len(self.data_list)}'
                    f' frames are valid.')
        data_list = [self.data_list[i] for i in valid_indexes]

        return data_list

    def get_timestamp(self, video_id: str) -> tuple:
        """Get the timestamp range of a video from its id.

        Args:
            video_id (str): The video id taken from the annotation file.

        Returns:
            tuple: ``(start, end)``. Ids of exactly 11 characters, and ids
            with fewer than three ``'_'``-separated parts, use the dataset
            level ``timestamp_start`` / ``timestamp_end``. Otherwise the last
            two ``'_'``-separated parts of the id are parsed as integers.
        """
        # 11-character ids are returned early, before any splitting, so
        # underscores inside such an id are never interpreted as a range.
        if len(video_id) == 11:
            return self.timestamp_start, self.timestamp_end
        parts = video_id.split('_')
        if len(parts) >= 3:
            return int(parts[-2]), int(parts[-1])
        return self.timestamp_start, self.timestamp_end

    def load_data_list(self) -> List[dict]:
        """Load AVA annotations.

        Each annotation line is split on commas and read as
        ``video_id, timestamp, 4 box coordinates, label, entity_id``. Lines
        are grouped by ``{video_id},{timestamp:04d}`` and every group becomes
        one sample.

        Returns:
            List[dict]: One info dict per annotated image key.

        Raises:
            FileNotFoundError: If ``self.ann_file`` does not exist.
        """
        if not exists(self.ann_file):
            raise FileNotFoundError(
                f'Annotation file {self.ann_file} does not exist.')
        data_list = []
        records_dict_by_img = defaultdict(list)
        for line in list_from_file(self.ann_file):
            line_split = line.strip().split(',')

            label = int(line_split[6])
            if self.custom_classes is not None:
                if label not in self.custom_classes:
                    continue
                label = self.custom_classes.index(label)

            video_id = line_split[0]
            timestamp = int(line_split[1])
            img_key = f'{video_id},{timestamp:04d}'

            entity_box = np.array(list(map(float, line_split[2:6])))
            entity_id = int(line_split[7])
            start, end = self.get_timestamp(video_id)
            shot_info = (1, (end - start) * self._FPS + 1)

            video_info = dict(
                video_id=video_id,
                timestamp=timestamp,
                entity_box=entity_box,
                label=label,
                entity_id=entity_id,
                shot_info=shot_info)
            records_dict_by_img[img_key].append(video_info)

        for img_key, img_records in records_dict_by_img.items():
            video_id, timestamp = img_key.split(',')
            start, end = self.get_timestamp(video_id)
            # The frame range depends on the video, so it must be derived
            # from this video's own timestamps rather than reused from
            # whichever annotation line was parsed last.
            shot_info = (1, (end - start) * self._FPS + 1)
            bboxes, labels, entity_ids = self.parse_img_record(img_records)
            ann = dict(
                gt_bboxes=bboxes, gt_labels=labels, entity_ids=entity_ids)
            frame_dir = video_id
            if self.data_prefix['img'] is not None:
                frame_dir = osp.join(self.data_prefix['img'], frame_dir)
            video_info = dict(
                frame_dir=frame_dir,
                video_id=video_id,
                timestamp=int(timestamp),
                timestamp_start=start,
                timestamp_end=end,
                img_key=img_key,
                shot_info=shot_info,
                fps=self._FPS,
                ann=ann)
            data_list.append(video_info)

        return data_list

    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        On top of the parent's data info, this adds the filename template,
        the timestamp range (when the sample does not carry its own), the
        person proposals (when a proposal file was loaded) and flattens the
        ``ann`` dict into top-level keys.

        Args:
            idx (int): Index of the sample.

        Returns:
            dict: The data info consumed by the pipeline.
        """
        data_info = super().get_data_info(idx)
        img_key = data_info['img_key']
        data_info['filename_tmpl'] = self.filename_tmpl
        if 'timestamp_start' not in data_info:
            data_info['timestamp_start'] = self.timestamp_start
            data_info['timestamp_end'] = self.timestamp_end

        if self.proposals is not None:
            if len(img_key) == 16:
                # 11-character video id + ',' + 4-digit timestamp: the key
                # is used for the proposal lookup unchanged.
                proposal_key = img_key
            else:
                # Otherwise drop the last two '_'-separated parts of the
                # video id before looking up the proposals.
                video_id, timestamp = img_key.split(',')
                vid = '_'.join(video_id.split('_')[:-2])
                timestamp = int(timestamp)
                proposal_key = f'{vid},{timestamp:04d}'

            if proposal_key not in self.proposals:
                # No detection for this key: fall back to one proposal
                # covering the whole (relative-coordinate) image.
                data_info['proposals'] = np.array([[0, 0, 1, 1]])
                data_info['scores'] = np.array([1])
            else:
                proposals = self.proposals[proposal_key]
                assert proposals.shape[-1] in [4, 5]
                if proposals.shape[-1] == 5:
                    scores = proposals[:, 4]
                    # Lower the threshold to the best score if nothing
                    # reaches it, so at least one proposal survives.
                    thr = min(self.person_det_score_thr, scores.max())
                    proposals = proposals[scores >= thr]
                    proposals = proposals[:self.num_max_proposals]
                    data_info['proposals'] = proposals[:, :4]
                    data_info['scores'] = proposals[:, 4]
                else:
                    proposals = proposals[:self.num_max_proposals]
                    data_info['proposals'] = proposals

        ann = data_info.pop('ann')
        data_info['gt_bboxes'] = ann['gt_bboxes']
        data_info['gt_labels'] = ann['gt_labels']
        data_info['entity_ids'] = ann['entity_ids']

        return data_info
|
|
|
|