| | |
| | |
| |
|
| | import csv |
| | import logging |
| | import numpy as np |
| | from typing import Any, Callable, Dict, List, Optional, Union |
| | import av |
| | import torch |
| | from torch.utils.data.dataset import Dataset |
| |
|
| | from detectron2.utils.file_io import PathManager |
| |
|
| | from ..utils import maybe_prepend_base_path |
| | from .frame_selector import FrameSelector, FrameTsList |
| |
|
| | FrameList = List[av.frame.Frame] |
| | FrameTransform = Callable[[torch.Tensor], torch.Tensor] |
| |
|
| |
|
def list_keyframes(video_fpath: str, video_stream_idx: int = 0) -> FrameTsList:
    """
    Traverses all keyframes of a video file. Returns a list of keyframe
    timestamps. Timestamps are counts in timebase units.

    Errors while opening or seeking the file are logged, not raised. Depending
    on the failure, either the keyframes collected so far or an empty list is
    returned (see the inline comments).

    Args:
        video_fpath (str): Video file path
        video_stream_idx (int): Video stream index (default: 0)
    Returns:
        List[int]: list of keyframe timestamps (timestamp is a count in timebase
            units)
    """
    logger = logging.getLogger(__name__)
    try:
        with PathManager.open(video_fpath, "rb") as io:
            container = av.open(io, mode="r")
            try:
                stream = container.streams.video[video_stream_idx]
                keyframes = []
                pts = -1
                # A forward seek can still yield a packet located at or before
                # the last visited timestamp. Such results are skipped, but only
                # a limited number of times in a row; the counter is reset as
                # soon as a seek makes progress.
                tolerance_backward_seeks = 2
                while True:
                    try:
                        container.seek(pts + 1, backward=False, any_frame=False, stream=stream)
                    except av.AVError as e:
                        # Seeking failed at the container level: stop here and
                        # return whatever has been collected so far.
                        logger.debug(
                            f"List keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts + 1}, AV error: {e}"
                        )
                        return keyframes
                    except OSError as e:
                        logger.warning(
                            f"List keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts + 1}, OS error: {e}"
                        )
                        return []
                    # An exhausted demuxer means there is nothing left to read;
                    # treat it as end of stream instead of leaking StopIteration.
                    packet = next(container.demux(video=video_stream_idx), None)
                    if packet is None:
                        return keyframes
                    if packet.pts is not None and packet.pts <= pts:
                        logger.warning(
                            f"Video file {video_fpath}, stream {video_stream_idx}: "
                            f"bad seek for packet {pts + 1} (got packet {packet.pts}), "
                            f"tolerance {tolerance_backward_seeks}."
                        )
                        tolerance_backward_seeks -= 1
                        if tolerance_backward_seeks == 0:
                            # Too many consecutive bad seeks: give up on this file.
                            return []
                        pts += 1
                        continue
                    tolerance_backward_seeks = 2
                    pts = packet.pts
                    if pts is None:
                        # A packet without a timestamp ends the traversal.
                        return keyframes
                    if packet.is_keyframe:
                        keyframes.append(pts)
            finally:
                # Release the demuxer on every exit path (returns and errors).
                container.close()
    except OSError as e:
        logger.warning(
            f"List keyframes: Error opening video file container {video_fpath}, "
            f"OS error: {e}"
        )
    except RuntimeError as e:
        logger.warning(
            f"List keyframes: Error opening video file container {video_fpath}, "
            f"Runtime error: {e}"
        )
    return []
| |
|
| |
|
def read_keyframes(
    video_fpath: str, keyframes: FrameTsList, video_stream_idx: int = 0
) -> FrameList:
    """
    Reads keyframe data from a video file.

    Errors are logged, not raised. If seeking or decoding fails for one of the
    timestamps, reading stops and the frames decoded up to that point are
    returned. If the file cannot be opened, an empty list is returned.

    Args:
        video_fpath (str): Video file path
        keyframes (List[int]): List of keyframe timestamps (as counts in
            timebase units to be used in container seek operations)
        video_stream_idx (int): Video stream index (default: 0)
    Returns:
        List[Frame]: list of frames that correspond to the specified timestamps
    """
    logger = logging.getLogger(__name__)
    try:
        with PathManager.open(video_fpath, "rb") as io:
            container = av.open(io)
            try:
                stream = container.streams.video[video_stream_idx]
                frames = []
                for pts in keyframes:
                    try:
                        container.seek(pts, any_frame=False, stream=stream)
                        # Decode from the same stream that was used for seeking.
                        frame = next(container.decode(video=video_stream_idx))
                        frames.append(frame)
                    except av.AVError as e:
                        logger.warning(
                            f"Read keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts}, AV error: {e}"
                        )
                        return frames
                    except OSError as e:
                        logger.warning(
                            f"Read keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts}, OS error: {e}"
                        )
                        return frames
                    except StopIteration:
                        # The decoder produced no frame after the seek.
                        logger.warning(
                            f"Read keyframes: Error decoding frame from {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts}"
                        )
                        return frames
                return frames
            finally:
                # Single cleanup point for all exit paths, including unexpected
                # exceptions such as an out-of-range stream index.
                container.close()
    except OSError as e:
        logger.warning(
            f"Read keyframes: Error opening video file container {video_fpath}, OS error: {e}"
        )
    except RuntimeError as e:
        logger.warning(
            f"Read keyframes: Error opening video file container {video_fpath}, Runtime error: {e}"
        )
    return []
| |
|
| |
|
def video_list_from_file(video_list_fpath: str, base_path: Optional[str] = None):
    """
    Build a list of video file paths from a plain text file.

    Every line of the file yields one entry: the line is stripped of
    surrounding whitespace and handed to `maybe_prepend_base_path` together
    with `base_path`.

    Args:
        video_list_fpath (str): path to a plain text file with the list of videos
        base_path (str): base path for entries from the video list (default: None)
    Returns:
        list with one entry per line of the file, in file order
    """
    with PathManager.open(video_list_fpath, "r") as list_file:
        return [
            maybe_prepend_base_path(base_path, str(entry.strip())) for entry in list_file
        ]
| |
|
| |
|
def read_keyframe_helper_data(fpath: str):
    """
    Read keyframe data from a file in CSV format: the header should contain
    "video_id" and "keyframes" fields. Value specifications are:
        video_id: int
        keyframes: list(int)
    Example of contents:
        video_id,keyframes
        2,"[1,11,21,31,41,51,61,71,81]"

    Reading is best-effort: any error (unreadable file, missing header field,
    malformed value, duplicate video ID) is logged as a warning and the entries
    parsed before the error are returned.

    Args:
        fpath (str): File containing keyframe data

    Return:
        video_id_to_keyframes (dict: int -> list(int)): for a given video ID it
            contains a list of keyframes for that video
    """
    video_id_to_keyframes = {}
    try:
        with PathManager.open(fpath, "r") as io:
            csv_reader = csv.reader(io)
            header = next(csv_reader)
            video_id_idx = header.index("video_id")
            keyframes_idx = header.index("keyframes")
            for row in csv_reader:
                video_id = int(row[video_id_idx])
                # Explicit check instead of `assert`, so duplicates are still
                # detected when Python runs with assertions disabled (-O).
                if video_id in video_id_to_keyframes:
                    raise ValueError(
                        f"Duplicate keyframes entry for video {video_id} in {fpath}"
                    )
                keyframes_str = row[keyframes_idx]
                # The value is a bracketed list such as "[1,11,21]"; a value of
                # two characters or fewer (e.g. "[]") means no keyframes.
                if len(keyframes_str) > 2:
                    video_id_to_keyframes[video_id] = [
                        int(v) for v in keyframes_str[1:-1].split(",")
                    ]
                else:
                    video_id_to_keyframes[video_id] = []
    except Exception as e:
        # Deliberately broad: helper data is optional, so a bad file must not
        # abort the caller.
        logger = logging.getLogger(__name__)
        logger.warning(f"Error reading keyframe helper data from {fpath}: {e}")
    return video_id_to_keyframes
| |
|
| |
|
class VideoKeyframeDataset(Dataset):
    """
    Dataset that provides keyframes for a set of videos.
    """

    # Placeholder returned as "images" when a video yields no frames.
    _EMPTY_FRAMES = torch.empty((0, 3, 1, 1))

    def __init__(
        self,
        video_list: List[str],
        category_list: Union[str, List[str], None] = None,
        frame_selector: Optional[FrameSelector] = None,
        transform: Optional[FrameTransform] = None,
        keyframe_helper_fpath: Optional[str] = None,
    ):
        """
        Dataset constructor

        Args:
            video_list (List[str]): list of paths to video files
            category_list (Union[str, List[str], None]): list of animal categories for each
                video file. If it is a string, or None, this applies to all videos
            frame_selector (Callable: KeyFrameList -> KeyFrameList):
                selects keyframes to process, keyframes are given by
                packet timestamps in timebase counts. If None, all keyframes
                are selected (default: None)
            transform (Callable: torch.Tensor -> torch.Tensor):
                transforms a batch of RGB images (tensors of size [B, 3, H, W]),
                returns a tensor of the same size. If None, no transform is
                applied (default: None)
            keyframe_helper_fpath (Optional[str]): path to a CSV file with
                precomputed keyframes, read with `read_keyframe_helper_data`.
                Entries are looked up by the index of the video in
                `video_list`; videos without an entry have their keyframes
                listed from the video file. If None, no helper data is used
                (default: None)
        """
        # isinstance (rather than an exact type comparison) so that list
        # subclasses are also treated as per-video category lists.
        if isinstance(category_list, list):
            self.category_list = category_list
        else:
            self.category_list = [category_list] * len(video_list)
        assert len(video_list) == len(
            self.category_list
        ), "length of video and category lists must be equal"
        self.video_list = video_list
        self.frame_selector = frame_selector
        self.transform = transform
        self.keyframe_helper_data = (
            read_keyframe_helper_data(keyframe_helper_fpath)
            if keyframe_helper_fpath is not None
            else None
        )

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Gets selected keyframes from a given video

        Args:
            idx (int): video index in the video list file
        Returns:
            A dictionary containing two keys:
                images (torch.Tensor): float tensor of size [N, 3, H, W] with
                    channels reordered from RGB to BGR, or of size defined by
                    the transform, that contains keyframes data. If no
                    keyframes or frames could be obtained, an empty tensor of
                    size [0, 3, 1, 1] is returned instead
                categories (List[str]): categories of the frames (a
                    single-element list with the category of the video, or an
                    empty list if no frames could be obtained)
        """
        categories = [self.category_list[idx]]
        fpath = self.video_list[idx]
        # Prefer precomputed keyframes when available for this index, otherwise
        # scan the video file.
        if self.keyframe_helper_data is None or idx not in self.keyframe_helper_data:
            keyframes = list_keyframes(fpath)
        else:
            keyframes = self.keyframe_helper_data[idx]
        if not keyframes:
            return {"images": self._EMPTY_FRAMES, "categories": []}
        if self.frame_selector is not None:
            keyframes = self.frame_selector(keyframes)
        frames = read_keyframes(fpath, keyframes)
        if not frames:
            return {"images": self._EMPTY_FRAMES, "categories": []}
        # Stack decoded RGB frames into a single [N, H, W, 3] array.
        frames = np.stack([frame.to_rgb().to_ndarray() for frame in frames])
        frames = torch.as_tensor(frames, device=torch.device("cpu"))
        # Reverse the channel order (RGB -> BGR).
        frames = frames[..., [2, 1, 0]]
        # Channels-last -> channels-first: [N, H, W, 3] -> [N, 3, H, W].
        frames = frames.permute(0, 3, 1, 2).float()
        if self.transform is not None:
            frames = self.transform(frames)
        return {"images": frames, "categories": categories}

    def __len__(self) -> int:
        """Returns the number of videos in the dataset."""
        return len(self.video_list)
| |
|