# MoTIF/utils/core/transforms/video_transform.py
# Hugging Face page header (not code): "P4ddyki's picture",
# "Upload folder using huggingface_hub", commit 3cf4fff (verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
import os
from logging import getLogger
from typing import Callable, List, Tuple
import torch
from PIL import Image
from torchcodec.decoders import VideoDecoder
from torchvision.utils import draw_bounding_boxes
from core.transforms.image_transform import ImageTransform
# Module-level logger; getLogger() is called without a name, which yields the
# root logger.
logger = getLogger()
def get_video_transform(
    image_res: int = 224,
    normalize_img: bool = True,
) -> "VideoTransform":
    """Build the video transform used to turn a video clip into model input.

    Args:
        image_res: Value forwarded to ``VideoTransform`` as ``size``.
        normalize_img: Flag forwarded to ``VideoTransform`` unchanged.

    Returns:
        A single ``VideoTransform`` instance (a callable). Note that only the
        transform is returned, not a ``(transform, int)`` tuple.
    """
    return VideoTransform(
        size=image_res,
        normalize_img=normalize_img,
    )
class VideoTransform(ImageTransform):
    """Decode frames from a video and run them through the image transform.

    The class reuses the preprocessing of ``ImageTransform`` (its ``__call__``
    and ``_transform_torch_tensor``) and adds frame sampling with torchcodec
    plus optional bounding-box overlays.
    """

    def __init__(
        self,
        size: int = 224,
        normalize_img: bool = True,
    ) -> None:
        """Forward ``size`` and ``normalize_img`` to ``ImageTransform``."""
        super().__init__(
            size=size,
            normalize_img=normalize_img,
        )

    def __call__(self, video_info: tuple, sampling_fps: int = 1):
        """Load, optionally annotate, and transform the frames of one clip.

        Args:
            video_info: 5-tuple ``(video_path, max_frames, s, e, bbox_map)``.
                ``s`` / ``e`` are the clip start / end in seconds (or None).
                ``bbox_map`` may be falsy; otherwise it maps ``str(frame_index)``
                to a bounding box (or None) for that frame of the video.
                Presumably a box is ``[xmin, ymin, xmax, ymax]`` as expected by
                torchvision's ``draw_bounding_boxes`` -- confirm with callers.
            sampling_fps: Target sampling rate passed to ``load_video``.

        Returns:
            Whatever ``ImageTransform._transform_torch_tensor`` returns for
            the stacked frames.

        Raises:
            FileNotFoundError: If ``video_path`` does not exist.
        """
        video_path, max_frames, s, e, bbox_map = video_info
        loaded = self.load_video(
            video_path,
            max_frames=max_frames,
            sampling_fps=sampling_fps,
            s=s,
            e=e,
        )
        if loaded is None:
            # load_video signals a missing file with None; fail with a clear
            # error instead of an opaque tuple-unpacking TypeError.
            raise FileNotFoundError(f"Video file not found: {video_path}")
        frames, sample_pos = loaded

        if bbox_map:
            # Re-key the boxes from absolute video frame index to the position
            # of that frame inside the sampled batch.
            bbox_dict_map = {}
            for idx_pos, pos in enumerate(sample_pos):
                bbox = bbox_map.get(str(pos))
                if bbox is not None:
                    bbox_dict_map[idx_pos] = bbox
            if bbox_dict_map:
                frames = self.draw_bounding_boxes(frames, bbox_dict_map)

        return super()._transform_torch_tensor(frames)

    def _process_multiple_images(self, image_paths: List[str]):
        """Open each path as an RGB image and transform them as one batch.

        Returns the same ``(tensor, (w, h))`` pair as
        ``_process_multiple_images_pil``.
        """
        images = []
        for path in image_paths:
            # convert() yields an independent image, so the underlying file
            # can be closed as soon as the conversion is done.
            with Image.open(path) as img:
                images.append(img.convert("RGB"))
        return self._process_multiple_images_pil(images)

    def _process_multiple_images_pil(self, images: List[Image.Image]):
        """Transform PIL images one by one and concatenate along dim 0.

        Each image goes through ``ImageTransform.__call__``, which is unpacked
        here as ``(tensor, (w, h))``.

        Returns:
            ``(batch, (w, h))`` where ``batch`` is ``torch.cat`` of the
            per-image tensors and ``(w, h)`` comes from the last image only.
        """
        processed_images = []
        for image in images:
            # Explicit loop: zero-argument super() is not usable inside a
            # comprehension on older Python versions.
            tensor, (w, h) = super().__call__(image)
            processed_images.append(tensor)
        return torch.cat(processed_images, dim=0), (w, h)

    def load_video(self, video_path, max_frames=16, sampling_fps=1, s=None, e=None):
        """Decode up to ``max_frames`` frames of a clip with torchcodec.

        Frames are first picked at roughly ``sampling_fps`` inside the
        ``[s, e]`` window. If that yields more than ``max_frames`` positions
        they are thinned out uniformly; if it yields fewer, the window is
        sampled uniformly instead (capped at the number of frames available).

        Args:
            video_path (str): Path to the video file.
            max_frames (int, optional): Upper bound on returned frames.
                Defaults to 16.
            sampling_fps (int, optional): Sampling rate; truncated to an int
                and required to be >= 1. Defaults to 1.
            s (float, optional): Clip start in seconds. None means frame 0.
            e (float, optional): Clip end in seconds. None means last frame.

        Returns:
            None if ``video_path`` does not exist. Otherwise a tuple
            ``(frames, sample_pos)``: ``frames`` is the ``.data`` tensor of the
            batch returned by ``VideoDecoder.get_frames_at`` and ``sample_pos``
            is the list of decoded frame indices, in order.

        Raises:
            ValueError: If ``int(sampling_fps)`` is smaller than 1.
        """
        if not os.path.exists(video_path):
            logger.warning("Video file not found: %s", video_path)
            return None

        sample_fps = int(sampling_fps)
        if sample_fps < 1:
            raise ValueError(
                f"sampling_fps must be >= 1 after int conversion, got {sampling_fps!r}"
            )

        decoder = VideoDecoder(video_path, device="cpu")
        decoder_metadata = decoder.metadata
        fps = decoder_metadata.average_fps
        total_frames = decoder_metadata.num_frames
        last_frame = total_frames - 1

        start_frame = 0 if s is None else int(s * fps)
        end_frame = last_frame if e is None else int(e * fps)
        if start_frame > end_frame:
            start_frame, end_frame = end_frame, start_frame
        # Clamp after the swap so neither bound can point outside the video.
        start_frame = min(max(start_frame, 0), last_frame)
        end_frame = min(max(end_frame, 0), last_frame)
        if start_frame == end_frame and end_frame < last_frame:
            # Zero-length window: include the next frame when one exists.
            end_frame = start_frame + 1

        # Never let the stride drop to 0 (sampling faster than the video fps),
        # which would make range() raise.
        t_stride = max(1, int(round(float(fps) / sample_fps)))
        all_pos = list(range(start_frame, end_frame + 1, t_stride))

        if len(all_pos) > max_frames:
            sample_idxs = self.uniform_sample(len(all_pos), max_frames)
            sample_pos = [all_pos[i] for i in sample_idxs]
        elif len(all_pos) < max_frames:
            # Too few strided positions: sample the raw window directly.
            total_clip_frames = end_frame - start_frame + 1
            num_samples = min(max_frames, total_clip_frames)
            sample_idxs = self.uniform_sample(total_clip_frames, num_samples)
            sample_pos = [start_frame + idx for idx in sample_idxs]
        else:
            sample_pos = all_pos

        all_frames = decoder.get_frames_at(indices=sample_pos).data
        return all_frames, sample_pos

    def uniform_sample(self, m, n):
        """Return ``n`` indices spread evenly over ``range(m)``.

        The first index is 0 and, for ``n > 1``, the last is ``m - 1``.
        With ``n == 1`` the result is ``[0]``; with ``n == 0`` it is ``[]``.
        Requires ``n <= m``.
        """
        assert n <= m, f"cannot sample {n} indices from {m} positions"
        stride = (m - 1) / (n - 1) if n > 1 else 0
        return [int(round(i * stride)) for i in range(n)]

    def draw_bounding_boxes(self, frames, all_bboxes):
        """Return a copy of ``frames`` with one red box drawn per listed frame.

        Args:
            frames: Tensor unpacked as four dimensions; the first one is
                iterated as the frame axis (the original comment assumes
                ``(N, C, H, W)``).
            all_bboxes: Mapping from frame position in ``frames`` to a single
                bounding box. Positions without an entry are left untouched.

        Returns:
            A cloned tensor; the input ``frames`` is not modified.
        """
        num_frames, _, _, _ = frames.shape
        frames_with_bbox = frames.clone()
        for i in range(num_frames):
            if i not in all_bboxes:
                continue
            # torchvision expects a (num_boxes, 4) tensor, hence the extra
            # list level around the single box.
            bbox_tensor = torch.tensor([all_bboxes[i]], dtype=torch.float32)
            # Inside a method body this name resolves to the torchvision
            # function imported at module level, not to this method.
            frames_with_bbox[i] = draw_bounding_boxes(
                frames_with_bbox[i],
                boxes=bbox_tensor,
                colors=(255, 0, 0),  # red
                width=4,
            )
        return frames_with_bbox