|
|
|
|
|
|
|
|
import os |
|
|
from logging import getLogger |
|
|
from typing import Callable, List, Tuple |
|
|
|
|
|
import torch |
|
|
from PIL import Image |
|
|
from torchcodec.decoders import VideoDecoder |
|
|
from torchvision.utils import draw_bounding_boxes |
|
|
|
|
|
from core.transforms.image_transform import ImageTransform |
|
|
|
|
|
logger = getLogger() |
|
|
|
|
|
|
|
|
def get_video_transform(
    image_res: int = 224,
    normalize_img: bool = True,
) -> Callable:
    """Build the video preprocessing transform.

    Args:
        image_res: Target spatial resolution passed to the transform.
        normalize_img: Whether the underlying image transform normalizes
            pixel values.

    Returns:
        A callable ``VideoTransform`` instance.

    NOTE(review): the original annotation was ``Tuple[Callable, int]``, but
    only the transform object is returned; the annotation is corrected here
    (runtime behavior is unchanged).
    """
    transforms = VideoTransform(
        size=image_res,
        normalize_img=normalize_img,
    )
    return transforms
|
|
|
|
|
|
|
|
class VideoTransform(ImageTransform):
    """Decodes a video with torchcodec, optionally draws bounding boxes on the
    sampled frames, and applies the parent ``ImageTransform`` preprocessing.
    """

    def __init__(
        self,
        size: int = 224,
        normalize_img: bool = True,
    ) -> None:
        """
        Args:
            size: Target spatial resolution of each frame.
            normalize_img: Forwarded to ``ImageTransform``; whether pixel
                values are normalized.
        """
        super().__init__(
            size=size,
            normalize_img=normalize_img,
        )

    def __call__(self, video_info: tuple, sampling_fps: int = 1):
        """Load, annotate, and transform a video clip.

        Args:
            video_info: Tuple of ``(video_path, max_frames, start_sec,
                end_sec, bbox_map)``. ``bbox_map`` maps *stringified absolute
                frame indices* to a bounding box (or ``None``).
            sampling_fps: Desired sampling rate in frames per second.

        Returns:
            The result of ``ImageTransform._transform_torch_tensor`` applied
            to the sampled (and possibly annotated) frame tensor.
        """
        video_path, max_frames, s, e, bbox_map = video_info

        frames, sample_pos = self.load_video(
            video_path,
            max_frames=max_frames,
            sampling_fps=sampling_fps,
            s=s,
            e=e,
        )

        if bbox_map:
            # Re-key boxes from absolute frame index to the frame's position
            # within the sampled batch, dropping frames with no annotation.
            bbox_dict_map = {}
            for idx_pos, pos in enumerate(sample_pos):
                if str(pos) in bbox_map and bbox_map[str(pos)] is not None:
                    bbox_dict_map[idx_pos] = bbox_map[str(pos)]
            if bbox_dict_map:
                frames = self.draw_bounding_boxes(frames, bbox_dict_map)

        return super()._transform_torch_tensor(frames)

    def _process_multiple_images(self, image_paths: List[str]):
        """Open each image file as RGB and transform the batch.

        Args:
            image_paths: Non-empty list of image file paths.

        Returns:
            Tuple of (stacked tensor of processed images, ``(w, h)`` of the
            LAST image — the images are implicitly assumed to share one size).

        Raises:
            ValueError: If ``image_paths`` is empty.
        """
        if not image_paths:
            raise ValueError("image_paths must be non-empty")
        images = [Image.open(path).convert("RGB") for path in image_paths]
        # Delegate to the PIL variant — the original duplicated this loop.
        return self._process_multiple_images_pil(images)

    def _process_multiple_images_pil(self, images: List[Image.Image]):
        """Transform already-opened PIL images.

        Args:
            images: Non-empty list of PIL images.

        Returns:
            Tuple of (stacked tensor of processed images, ``(w, h)`` of the
            LAST image processed).

        Raises:
            ValueError: If ``images`` is empty.
        """
        if not images:
            raise ValueError("images must be non-empty")
        processed_images = []
        for image in images:
            image, (w, h) = super().__call__(image)
            processed_images.append(image)
        processed_images = torch.cat(processed_images, dim=0)
        return processed_images, (w, h)

    def load_video(self, video_path, max_frames=16, sampling_fps=1, s=None, e=None):
        """Decode a video with torchcodec and uniformly sample frames.

        Args:
            video_path (str): Path to the video file.
            max_frames (int, optional): Maximum number of frames to extract.
                Defaults to 16.
            sampling_fps (int, optional): Sampling frame rate. Defaults to 1.
            s (float, optional): Clip start time in seconds. Defaults to None
                (start of video).
            e (float, optional): Clip end time in seconds. Defaults to None
                (end of video).

        Returns:
            Tuple of (frame tensor from ``decoder.get_frames_at(...).data`` —
            presumably (N, C, H, W) uint8, TODO confirm against torchcodec —
            and the list of sampled absolute frame indices), or ``None`` when
            the file does not exist.
        """
        if not os.path.exists(video_path):
            # Preserve the historical "missing file -> None" contract, but
            # make the failure visible instead of silently returning.
            logger.warning("Video file not found: %s", video_path)
            return

        decoder = VideoDecoder(video_path, device="cpu")
        decoder_metadata = decoder.metadata
        fps = decoder_metadata.average_fps
        total_frames = decoder_metadata.num_frames

        # Convert the requested [s, e] window (seconds) into frame indices,
        # clamped to the valid range.
        start_frame = 0 if s is None else int(s * fps)
        end_frame = total_frames - 1 if e is None else int(e * fps)
        end_frame = min(end_frame, total_frames - 1)

        if start_frame > end_frame:
            # Tolerate a reversed (e < s) window.
            start_frame, end_frame = end_frame, start_frame
        elif start_frame == end_frame:
            # Widen a degenerate window by one frame, but never past the last
            # valid index (the original could index one frame out of range).
            end_frame = min(start_frame + 1, total_frames - 1)

        sample_fps = int(sampling_fps)
        # Stride between consecutive sampled frames. Clamp to >= 1 so the
        # range() below never receives a zero step (possible when
        # sampling_fps is more than about twice the video fps).
        t_stride = max(1, int(round(float(fps) / sample_fps)))

        all_pos = list(range(start_frame, end_frame + 1, t_stride))
        if len(all_pos) > max_frames:
            # Too many candidates: thin them uniformly down to max_frames.
            sample_idxs = self.uniform_sample(len(all_pos), max_frames)
            sample_pos = [all_pos[i] for i in sample_idxs]
        elif len(all_pos) < max_frames:
            # Too few at the requested rate: resample uniformly over the
            # whole clip instead, capped by the clip length.
            total_clip_frames = end_frame - start_frame + 1
            if total_clip_frames < max_frames:
                max_frames = total_clip_frames
            sample_idxs = self.uniform_sample(total_clip_frames, max_frames)
            sample_pos = [start_frame + idx for idx in sample_idxs]
        else:
            sample_pos = all_pos

        all_frames = decoder.get_frames_at(indices=sample_pos)
        all_frames = all_frames.data

        return all_frames, sample_pos

    def uniform_sample(self, m, n):
        """Return ``n`` indices spread uniformly over ``range(m)``.

        Requires ``n <= m``; both endpoints (0 and m-1) are included when
        ``n > 1``.
        """
        assert n <= m
        stride = (m - 1) / (n - 1) if n > 1 else 0
        return [int(round(i * stride)) for i in range(n)]

    def draw_bounding_boxes(self, frames, all_bboxes):
        """Draw one red box on each annotated frame.

        Args:
            frames: Frame tensor, presumably (N, C, H, W) uint8 as required by
                torchvision's ``draw_bounding_boxes`` — TODO confirm.
            all_bboxes: Mapping from frame position within ``frames`` to a
                single box ``[x1, y1, x2, y2]``.

        Returns:
            A copy of ``frames`` with the boxes drawn; the input is untouched.
        """
        N, _, _, _ = frames.shape
        frames_with_bbox = frames.clone()
        for i in range(N):
            if i in all_bboxes:
                bbox = all_bboxes[i]
                bbox_tensor = torch.tensor([bbox], dtype=torch.float32)
                # NOTE: this invokes torchvision's module-level
                # draw_bounding_boxes (no ``self.``), which this method
                # shadows by name.
                frames_with_bbox[i] = draw_bounding_boxes(
                    frames_with_bbox[i],
                    boxes=bbox_tensor,
                    colors=(255, 0, 0),
                    width=4,
                )
        return frames_with_bbox
|
|
|