|
|
|
|
|
|
| import csv
|
| import logging
|
| import numpy as np
|
| from typing import Any, Callable, Dict, List, Optional, Union
|
| import av
|
| import torch
|
| from torch.utils.data.dataset import Dataset
|
|
|
| from detectron2.utils.file_io import PathManager
|
|
|
| from ..utils import maybe_prepend_base_path
|
| from .frame_selector import FrameSelector, FrameTsList
|
|
|
# List of decoded PyAV frames; this is the return type of `read_keyframes`.
FrameList = List[av.frame.Frame]

# Callable that maps a tensor to a tensor; used as the type of the optional
# `transform` accepted by `VideoKeyframeDataset`.
FrameTransform = Callable[[torch.Tensor], torch.Tensor]
|
|
|
|
|
def _scan_keyframes(container, video_fpath: str, video_stream_idx: int) -> FrameTsList:
    """
    Walks an already opened container and collects keyframe timestamps of
    the requested video stream. The caller owns the container and is
    responsible for closing it.
    """
    logger = logging.getLogger(__name__)
    stream = container.streams.video[video_stream_idx]
    keyframes = []
    pts = -1
    # Forward seeks are requested (backward=False), yet the packet obtained
    # after a seek may carry a timestamp that is not past the previous one.
    # Such "bad seeks" are tolerated a limited number of times in a row; the
    # counter is reset after every successful step.
    max_backward_seeks = 2
    tolerance_backward_seeks = max_backward_seeks
    while True:
        try:
            container.seek(pts + 1, backward=False, any_frame=False, stream=stream)
        except av.AVError as e:
            # Logged at debug level only and treated as regular termination
            # (presumably the seek target is past the end of the stream -
            # TODO confirm): return whatever has been collected so far
            logger.debug(
                f"List keyframes: Error seeking video file {video_fpath}, "
                f"video stream {video_stream_idx}, pts {pts + 1}, AV error: {e}"
            )
            return keyframes
        except OSError as e:
            logger.warning(
                f"List keyframes: Error seeking video file {video_fpath}, "
                f"video stream {video_stream_idx}, pts {pts + 1}, OS error: {e}"
            )
            return []
        # Use a default so that an exhausted demuxer ends the traversal
        # instead of leaking a StopIteration to the caller
        packet = next(container.demux(video=video_stream_idx), None)
        if packet is None:
            return keyframes
        if packet.pts is not None and packet.pts <= pts:
            logger.warning(
                f"Video file {video_fpath}, stream {video_stream_idx}: "
                f"bad seek for packet {pts + 1} (got packet {packet.pts}), "
                f"tolerance {tolerance_backward_seeks}."
            )
            tolerance_backward_seeks -= 1
            if tolerance_backward_seeks == 0:
                return []
            # retry the seek from the next timestamp
            pts += 1
            continue
        tolerance_backward_seeks = max_backward_seeks
        pts = packet.pts
        if pts is None:
            # a packet without a timestamp ends the traversal
            return keyframes
        if packet.is_keyframe:
            keyframes.append(pts)


def list_keyframes(video_fpath: str, video_stream_idx: int = 0) -> FrameTsList:
    """
    Traverses all keyframes of a video file. Returns a list of keyframe
    timestamps. Timestamps are counts in timebase units.

    Args:
        video_fpath (str): Video file path
        video_stream_idx (int): Video stream index (default: 0)
    Returns:
        List[int]: list of keyframe timestamps (timestamp is a count in timebase
            units). An empty list is returned if the container can not be
            opened (OSError / RuntimeError), if an OS error occurs while
            seeking, or if too many consecutive bad seeks are encountered.
            If an AV error occurs while seeking, the timestamps collected so
            far are returned.
    """
    logger = logging.getLogger(__name__)
    try:
        with PathManager.open(video_fpath, "rb") as io:
            container = av.open(io, mode="r")
            try:
                return _scan_keyframes(container, video_fpath, video_stream_idx)
            finally:
                # release the container on every exit path
                container.close()
    except OSError as e:
        logger.warning(
            f"List keyframes: Error opening video file container {video_fpath}, " f"OS error: {e}"
        )
    except RuntimeError as e:
        logger.warning(
            f"List keyframes: Error opening video file container {video_fpath}, "
            f"Runtime error: {e}"
        )
    return []
|
|
|
|
|
def read_keyframes(
    video_fpath: str, keyframes: FrameTsList, video_stream_idx: int = 0
) -> FrameList:
    """
    Reads keyframe data from a video file.

    Args:
        video_fpath (str): Video file path
        keyframes (List[int]): List of keyframe timestamps (as counts in
            timebase units to be used in container seek operations)
        video_stream_idx (int): Video stream index (default: 0)
    Returns:
        List[Frame]: list of frames that correspond to the specified timestamps.
            Reading stops at the first timestamp for which seeking or decoding
            fails; the frames read up to that point are returned. An empty
            list is returned if the container can not be opened
            (OSError / RuntimeError).
    """
    logger = logging.getLogger(__name__)
    try:
        with PathManager.open(video_fpath, "rb") as io:
            container = av.open(io)
            try:
                stream = container.streams.video[video_stream_idx]
                frames = []
                for pts in keyframes:
                    try:
                        container.seek(pts, any_frame=False, stream=stream)
                        # decode from the requested stream, i.e. the same one
                        # the seek above was performed on
                        frame = next(container.decode(video=video_stream_idx))
                        frames.append(frame)
                    except av.AVError as e:
                        logger.warning(
                            f"Read keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts}, AV error: {e}"
                        )
                        break
                    except OSError as e:
                        logger.warning(
                            f"Read keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts}, OS error: {e}"
                        )
                        break
                    except StopIteration:
                        # the decoder produced no frame after the seek
                        logger.warning(
                            f"Read keyframes: Error decoding frame from {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts}"
                        )
                        break
                return frames
            finally:
                # single place that releases the container on every exit path
                container.close()
    except OSError as e:
        logger.warning(
            f"Read keyframes: Error opening video file container {video_fpath}, OS error: {e}"
        )
    except RuntimeError as e:
        logger.warning(
            f"Read keyframes: Error opening video file container {video_fpath}, Runtime error: {e}"
        )
    return []
|
|
|
|
|
def video_list_from_file(video_list_fpath: str, base_path: Optional[str] = None):
    """
    Build a list of video file paths from a plain text file that holds one
    entry per line.

    Args:
        video_list_fpath (str): path to a plain text file with the list of videos
        base_path (str): base path passed to `maybe_prepend_base_path` together
            with every entry from the video list (default: None)
    Returns:
        list with one element per line of the file: the result of
        `maybe_prepend_base_path` for the whitespace-stripped line
    """
    with PathManager.open(video_list_fpath, "r") as list_file:
        return [
            maybe_prepend_base_path(base_path, str(entry.strip())) for entry in list_file
        ]
|
|
|
|
|
def read_keyframe_helper_data(fpath: str):
    """
    Read keyframe data from a file in CSV format: the header should contain
    "video_id" and "keyframes" fields. Value specifications are:
      video_id: int
      keyframes: list(int)
    Example of contents:
      video_id,keyframes
      2,"[1,11,21,31,41,51,61,71,81]"

    Reading is best-effort: any error (unreadable file, missing header field,
    malformed value, duplicate video ID) is logged as a warning and stops the
    parsing; entries read before the error are still returned.

    Args:
        fpath (str): File containing keyframe data

    Return:
        video_id_to_keyframes (dict: int -> list(int)): for a given video ID it
          contains a list of keyframes for that video
    """
    video_id_to_keyframes = {}
    try:
        with PathManager.open(fpath, "r") as io:
            csv_reader = csv.reader(io)
            header = next(csv_reader)
            video_id_idx = header.index("video_id")
            keyframes_idx = header.index("keyframes")
            for row in csv_reader:
                if not row:
                    # csv.reader yields an empty list for a blank line
                    continue
                video_id = int(row[video_id_idx])
                # Explicit raise instead of `assert`, so that the check is
                # not stripped when Python runs with -O
                if video_id in video_id_to_keyframes:
                    raise ValueError(
                        f"Duplicate keyframes entry for video {video_id} in {fpath}"
                    )
                # drop the enclosing brackets; what remains is either empty
                # or a comma separated list of integers
                keyframes_str = row[keyframes_idx].strip()[1:-1].strip()
                video_id_to_keyframes[video_id] = (
                    [int(v) for v in keyframes_str.split(",")] if keyframes_str else []
                )
    except Exception as e:
        # deliberate catch-all: helper data is optional, so failures are
        # reported and whatever has been parsed so far is returned
        logger = logging.getLogger(__name__)
        logger.warning(f"Error reading keyframe helper data from {fpath}: {e}")
    return video_id_to_keyframes
|
|
|
|
|
class VideoKeyframeDataset(Dataset):
    """
    Dataset that provides keyframes for a set of videos.
    """

    # Placeholder returned as "images" when no keyframes or frames are available
    _EMPTY_FRAMES = torch.empty((0, 3, 1, 1))

    def __init__(
        self,
        video_list: List[str],
        category_list: Union[str, List[str], None] = None,
        frame_selector: Optional[FrameSelector] = None,
        transform: Optional[FrameTransform] = None,
        keyframe_helper_fpath: Optional[str] = None,
    ):
        """
        Dataset constructor

        Args:
            video_list (List[str]): list of paths to video files
            category_list (Union[str, List[str], None]): list of animal categories for each
                video file. If it is a string, or None, this applies to all videos
            frame_selector (Callable: KeyFrameList -> KeyFrameList):
                selects keyframes to process, keyframes are given by
                packet timestamps in timebase counts. If None, all keyframes
                are selected (default: None)
            transform (Callable: torch.Tensor -> torch.Tensor):
                transforms a batch of RGB images (tensors of size [B, 3, H, W]),
                returns a tensor of the same size. If None, no transform is
                applied (default: None)
            keyframe_helper_fpath (Optional[str]): path to a CSV file with
                precomputed keyframes, read with `read_keyframe_helper_data`.
                If None, keyframes are always listed from the video files
                (default: None)
        """
        # isinstance (rather than an exact type comparison) so that list
        # subclasses are treated as per-video category lists as well
        if isinstance(category_list, list):
            self.category_list = category_list
        else:
            self.category_list = [category_list] * len(video_list)
        assert len(video_list) == len(
            self.category_list
        ), "length of video and category lists must be equal"
        self.video_list = video_list
        self.frame_selector = frame_selector
        self.transform = transform
        self.keyframe_helper_data = (
            read_keyframe_helper_data(keyframe_helper_fpath)
            if keyframe_helper_fpath is not None
            else None
        )

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Gets selected keyframes from a given video

        Args:
            idx (int): video index in the video list file
        Returns:
            A dictionary containing two keys:
                images (torch.Tensor): float tensor that contains keyframes
                    data. Without a transform, frames are stacked and
                    permuted to channels-first layout, i.e. [N, 3, H, W],
                    with the channel order of the decoded RGB frames
                    reversed; otherwise the size is defined by the
                    transform. An empty tensor of size [0, 3, 1, 1] is
                    returned if no frames could be read
                categories (List[str]): categories of the frames (a list
                    with the single category of the video, or an empty list
                    if no frames could be read)
        """
        categories = [self.category_list[idx]]
        fpath = self.video_list[idx]
        # Helper data is looked up by the dataset index `idx`; keyframes are
        # listed from the video itself when there is no entry for it
        if self.keyframe_helper_data is None or idx not in self.keyframe_helper_data:
            keyframes = list_keyframes(fpath)
        else:
            keyframes = self.keyframe_helper_data[idx]
        if not keyframes:
            return {"images": self._EMPTY_FRAMES, "categories": []}
        if self.frame_selector is not None:
            keyframes = self.frame_selector(keyframes)
        frames = read_keyframes(fpath, keyframes)
        if not frames:
            return {"images": self._EMPTY_FRAMES, "categories": []}
        frames = np.stack([frame.to_rgb().to_ndarray() for frame in frames])
        frames = torch.as_tensor(frames, device=torch.device("cpu"))
        # reverse the order of the last (channel) axis: RGB -> BGR
        frames = frames[..., [2, 1, 0]]
        # channels-last -> channels-first
        frames = frames.permute(0, 3, 1, 2).float()
        if self.transform is not None:
            frames = self.transform(frames)
        return {"images": frames, "categories": categories}

    def __len__(self):
        """Returns the number of videos in the dataset."""
        return len(self.video_list)
|
|
|