| | |
| |
|
| | import os |
| | from logging import getLogger |
| | from typing import Callable, List, Tuple |
| |
|
| | import torch |
| | from PIL import Image |
| | from torchcodec.decoders import VideoDecoder |
| | from torchvision.utils import draw_bounding_boxes |
| |
|
| | from core.transforms.image_transform import ImageTransform |
| |
|
# Module-scoped logger; records still propagate to the root logger's handlers.
logger = getLogger(__name__)
| |
|
| |
|
def get_video_transform(
    image_res: int = 224,
    normalize_img: bool = True,
) -> "VideoTransform":
    """Build a ``VideoTransform`` with the given resolution and normalization flag.

    Args:
        image_res: Forwarded as ``VideoTransform(size=...)``.
        normalize_img: Forwarded as ``VideoTransform(normalize_img=...)``.

    Returns:
        The configured ``VideoTransform`` instance itself (not a tuple), which is
        called with a ``video_info`` tuple.
    """
    return VideoTransform(
        size=image_res,
        normalize_img=normalize_img,
    )
| |
|
| |
|
class VideoTransform(ImageTransform):
    """``ImageTransform`` extended to decode, sample and annotate video clips.

    Frames are decoded with torchcodec's ``VideoDecoder``, optionally annotated
    with bounding boxes, then passed to the parent's
    ``_transform_torch_tensor``.
    """

    def __init__(
        self,
        size: int = 224,
        normalize_img: bool = True,
    ) -> None:
        """Forward ``size`` and ``normalize_img`` unchanged to ``ImageTransform``."""
        super().__init__(
            size=size,
            normalize_img=normalize_img,
        )

    def __call__(self, video_info: tuple, sampling_fps: int = 1):
        """Decode, optionally annotate, and transform one video clip.

        Args:
            video_info: ``(video_path, max_frames, s, e, bbox_map)``. ``s``/``e``
                are start/end times in seconds (or ``None``). ``bbox_map`` maps
                ``str(absolute_frame_index)`` to a box; a falsy value skips drawing.
            sampling_fps: Target sampling rate, forwarded to ``load_video``.

        Returns:
            The result of ``ImageTransform._transform_torch_tensor`` on the
            sampled (and possibly annotated) frames.

        Raises:
            FileNotFoundError: If ``video_path`` does not exist.
        """
        video_path, max_frames, s, e, bbox_map = video_info

        loaded = self.load_video(
            video_path,
            max_frames=max_frames,
            sampling_fps=sampling_fps,
            s=s,
            e=e,
        )
        # load_video signals a missing file by returning None; report it
        # clearly instead of failing on tuple unpacking.
        if loaded is None:
            raise FileNotFoundError(f"Video file not found: {video_path}")
        frames, sample_pos = loaded

        if bbox_map:
            # bbox_map is keyed by absolute frame index (as a string); re-key by
            # position inside the sampled clip, dropping frames with no box.
            bbox_dict_map = {
                idx_pos: bbox_map[str(pos)]
                for idx_pos, pos in enumerate(sample_pos)
                if str(pos) in bbox_map and bbox_map[str(pos)] is not None
            }
            if bbox_dict_map:
                frames = self.draw_bounding_boxes(frames, bbox_dict_map)

        return super()._transform_torch_tensor(frames)

    def _process_multiple_images(self, image_paths: List[str]):
        """Load image files as RGB and transform them into one concatenated batch.

        Args:
            image_paths: Paths of the images to load.

        Returns:
            The same ``(batch, (w, h))`` pair as ``_process_multiple_images_pil``.
        """
        images = []
        for path in image_paths:
            # Close each file promptly; convert() returns a fully loaded image
            # that no longer depends on the open file handle.
            with Image.open(path) as img:
                images.append(img.convert("RGB"))
        return self._process_multiple_images_pil(images)

    def _process_multiple_images_pil(self, images: List[Image.Image]):
        """Transform PIL images with the parent transform and concatenate them.

        Args:
            images: Non-empty list of PIL images.

        Returns:
            ``(batch, (w, h))``: the per-image tensors concatenated along dim 0,
            plus the ``(w, h)`` reported by the parent for the LAST image.
            NOTE(review): assumes every image yields the same size — confirm.

        Raises:
            ValueError: If ``images`` is empty.
        """
        if not images:
            raise ValueError("Expected at least one image, got an empty list.")
        processed_images = []
        for image in images:
            tensor, (w, h) = super().__call__(image)
            processed_images.append(tensor)
        return torch.cat(processed_images, dim=0), (w, h)

    def load_video(self, video_path, max_frames=16, sampling_fps=1, s=None, e=None):
        """Decode a sampled subset of frames from a video file with torchcodec.

        Args:
            video_path (str): Path to the video file.
            max_frames (int, optional): Maximum number of frames to return. Defaults to 16.
            sampling_fps (int, optional): Target sampling rate; truncated with ``int()``.
                Defaults to 1.
            s (float, optional): Start time in seconds. Defaults to None (video start).
            e (float, optional): End time in seconds. Defaults to None (video end).

        Returns:
            ``(frames, sample_pos)``: ``frames`` is the ``.data`` of
            ``VideoDecoder.get_frames_at`` for the sampled indices, and
            ``sample_pos`` is the list of absolute frame indices sampled.
            Returns ``None`` if ``video_path`` does not exist.

        Raises:
            ValueError: If ``sampling_fps`` truncates to a value below 1, or the
                decoder reports no fps / frame count.
        """
        if not os.path.exists(video_path):
            return None

        sample_fps = int(sampling_fps)
        if sample_fps <= 0:
            raise ValueError(
                f"sampling_fps must be >= 1 after int() truncation, got {sampling_fps!r}"
            )

        decoder = VideoDecoder(video_path, device="cpu")
        decoder_metadata = decoder.metadata
        fps = decoder_metadata.average_fps
        total_frames = decoder_metadata.num_frames
        # These metadata fields may be missing or zero for some streams; every
        # index computation below depends on them.
        if not fps or not total_frames:
            raise ValueError(
                f"Cannot determine fps/frame count for video: {video_path} "
                f"(fps={fps!r}, num_frames={total_frames!r})"
            )

        last_frame = total_frames - 1
        start_frame = 0 if s is None else int(s * fps)
        end_frame = last_frame if e is None else int(e * fps)
        # Clamp both ends into [0, last_frame] so the decoder is never asked
        # for a frame that does not exist (e.g. `s` past the end of the video).
        start_frame = min(max(start_frame, 0), last_frame)
        end_frame = min(max(end_frame, 0), last_frame)

        if start_frame > end_frame:
            start_frame, end_frame = end_frame, start_frame
        elif start_frame == end_frame:
            # Widen a zero-length window to two frames without stepping past
            # the last frame; a single-frame video stays a single frame.
            if end_frame < last_frame:
                end_frame += 1
            elif start_frame > 0:
                start_frame -= 1

        # Rounding can give 0 when sampling faster than the native fps; a
        # stride of 0 is invalid for range(), so take every frame instead.
        t_stride = max(1, int(round(float(fps) / sample_fps)))

        all_pos = list(range(start_frame, end_frame + 1, t_stride))
        if len(all_pos) > max_frames:
            # Too many strided candidates: thin them out evenly.
            sample_idxs = self.uniform_sample(len(all_pos), max_frames)
            sample_pos = [all_pos[i] for i in sample_idxs]
        elif len(all_pos) < max_frames:
            # Too few strided candidates: fall back to evenly spaced sampling
            # over every frame in the window, capped at the window length.
            total_clip_frames = end_frame - start_frame + 1
            num_samples = min(max_frames, total_clip_frames)
            sample_idxs = self.uniform_sample(total_clip_frames, num_samples)
            sample_pos = [start_frame + idx for idx in sample_idxs]
        else:
            sample_pos = all_pos

        all_frames = decoder.get_frames_at(indices=sample_pos).data
        return all_frames, sample_pos

    def uniform_sample(self, m, n):
        """Return ``n`` evenly spaced indices into ``range(m)``.

        For ``n > 1`` the first index is 0 and the last is ``m - 1``; for
        ``n == 1`` the result is ``[0]``; for ``n == 0`` it is empty.

        Raises:
            ValueError: If ``n > m`` (cannot pick more indices than exist).
        """
        if n > m:
            raise ValueError(f"Cannot sample {n} indices from {m} items.")
        stride = (m - 1) / (n - 1) if n > 1 else 0
        return [int(round(i * stride)) for i in range(n)]

    def draw_bounding_boxes(self, frames, all_bboxes):
        """Return a copy of ``frames`` with red boxes drawn on selected frames.

        Args:
            frames: 4-D tensor; each ``frames[i]`` is passed to torchvision's
                ``draw_bounding_boxes``, which expects a ``(C, H, W)`` image.
            all_bboxes: Mapping from frame position (0-based within ``frames``)
                to one box ``[x1, y1, x2, y2]`` or a sequence of such boxes.
                Positions outside ``range(len(frames))`` are ignored.

        Returns:
            A clone of ``frames`` with boxes drawn (width 4). The input tensor is
            not modified.
        """
        num_frames, _, _, _ = frames.shape
        frames_with_bbox = frames.clone()
        for i in range(num_frames):
            if i not in all_bboxes:
                continue
            # reshape(-1, 4) turns one box into shape (1, 4) and also accepts
            # several boxes per frame, or a box already stored as a tensor.
            bbox_tensor = torch.as_tensor(
                all_bboxes[i], dtype=torch.float32
            ).reshape(-1, 4)
            frames_with_bbox[i] = draw_bounding_boxes(
                frames_with_bbox[i],
                boxes=bbox_tensor,
                colors=(255, 0, 0),
                width=4,
            )
        return frames_with_bbox
| |
|