# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc

import av
import cv2
import numpy as np
import torch  # noqa: F401  # isort: skip
import torchvision  # noqa: F401  # isort: skip

# Import decord with graceful fallback
try:
    import decord  # noqa: F401

    DECORD_AVAILABLE = True
except ImportError:
    DECORD_AVAILABLE = False

# torchcodec is treated as unavailable on RuntimeError as well as ImportError.
try:
    import torchcodec

    TORCHCODEC_AVAILABLE = True
except (ImportError, RuntimeError):
    TORCHCODEC_AVAILABLE = False


def _decord_reader(video_path: str, video_backend_kwargs: dict):
    """Open ``video_path`` with ``decord.VideoReader``; raise ImportError if decord is missing."""
    if not DECORD_AVAILABLE:
        raise ImportError("decord is not available.")
    return decord.VideoReader(video_path, **video_backend_kwargs)


def _torchcodec_decoder(video_path: str):
    """Create a CPU torchcodec decoder that returns frames in NHWC dimension order.

    Raises ImportError if torchcodec is missing. Backend kwargs are not forwarded to torchcodec.
    """
    if not TORCHCODEC_AVAILABLE:
        raise ImportError("torchcodec is not available.")
    return torchcodec.decoders.VideoDecoder(
        video_path, device="cpu", dimension_order="NHWC", num_ffmpeg_threads=0
    )


def _open_opencv_capture(video_path: str, video_backend_kwargs: dict) -> cv2.VideoCapture:
    """Open ``video_path`` with OpenCV; raise ValueError (after releasing it) if it cannot be opened."""
    cap = cv2.VideoCapture(video_path, **video_backend_kwargs)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Unable to open video file: {video_path}")
    return cap


def _read_opencv_frames(cap: cv2.VideoCapture, indices: list[int] | np.ndarray) -> np.ndarray:
    """Seek to each frame index on an open capture and stack the decoded frames.

    The caller owns ``cap`` and must release it (including when this raises).
    """
    frames = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if not ret:
            raise ValueError(f"Unable to read frame at index {idx}")
        frames.append(frame)
    return np.array(frames)


def _nearest_frame_indices(
    frame_starts: np.ndarray, timestamps: list[float] | np.ndarray
) -> np.ndarray:
    """For each requested timestamp, return the index of the frame whose start time is closest.

    ``frame_starts`` is a 1-D array with one start time (seconds) per frame.
    """
    # (num_frames, 1) - (num_timestamps,) broadcasts to (num_frames, num_timestamps).
    distances = np.abs(frame_starts[:, np.newaxis] - np.asarray(timestamps))
    return distances.argmin(axis=0)


def _close_torchvision_reader(reader) -> None:
    """Best-effort release of the decoder handles held by a torchvision ``VideoReader``."""
    if hasattr(reader, "_c"):
        reader._c = None
    container = getattr(reader, "container", None)
    if container is not None:
        container.close()
        reader.container = None


def _resize_frames(frames, resize_size: tuple[int, int] | None):
    """Resize every frame with ``cv2.resize``; return ``frames`` unchanged when no size is given."""
    if resize_size is None:
        return frames
    return np.array([cv2.resize(frame, resize_size) for frame in frames])


def get_frames_by_indices(
    video_path: str,
    indices: list[int] | np.ndarray,
    video_backend: str = "decord",
    video_backend_kwargs: dict | None = None,
) -> np.ndarray:
    """Get frames from a video at the given frame indices.

    Args:
        video_path (str): Path to the video file.
        indices (list[int] | np.ndarray): Frame indices to decode.
        video_backend (str, optional): "decord", "torchcodec" or "opencv". Defaults to "decord".
        video_backend_kwargs (dict, optional): Extra keyword arguments forwarded to
            ``decord.VideoReader`` / ``cv2.VideoCapture``; ignored by torchcodec.
            Defaults to None (no extra arguments).

    Returns:
        np.ndarray: Frames stacked along axis 0, in the order of ``indices``.

    Raises:
        ImportError: If the selected backend's package is not installed.
        ValueError: (opencv) If the file cannot be opened or a frame cannot be read.
        NotImplementedError: If ``video_backend`` is not supported.
    """
    kwargs = {} if video_backend_kwargs is None else video_backend_kwargs

    if video_backend == "decord":
        vr = _decord_reader(video_path, kwargs)
        return vr.get_batch(indices).asnumpy()

    if video_backend == "torchcodec":
        decoder = _torchcodec_decoder(video_path)
        # torchcodec declares list[int]; convert so numpy arrays are accepted as well.
        return decoder.get_frames_at(indices=[int(i) for i in indices]).data.numpy()

    if video_backend == "opencv":
        # NOTE(review): cv2 returns frames in BGR channel order by default; other backends
        # may return RGB — confirm callers expect this.
        cap = _open_opencv_capture(video_path, kwargs)
        try:
            return _read_opencv_frames(cap, indices)
        finally:
            cap.release()

    raise NotImplementedError(f"Video backend {video_backend} not implemented")


def _get_frames_by_timestamps_torchvision(
    video_path: str, timestamps: list[float] | np.ndarray
) -> np.ndarray:
    """Decode the frame whose pts is closest to each timestamp using torchvision's pyav reader.

    Timestamps for which no frame is decoded are skipped, so the result may hold fewer
    than ``len(timestamps)`` frames. Raises ValueError if no frame is decoded at all.
    """
    torchvision.set_video_backend("pyav")
    loaded_frames = []
    reader = None
    try:
        reader = torchvision.io.VideoReader(video_path, "video")
        for target_ts in timestamps:
            # Reset reader state, then scan forward from the seek position.
            reader.seek(target_ts, keyframes_only=True)
            closest_frame = None
            closest_diff = float("inf")
            for frame in reader:
                diff = abs(frame["pts"] - target_ts)
                if closest_frame is None or diff < closest_diff:
                    closest_frame, closest_diff = frame, diff
                else:
                    # The time difference started to increase: stop searching.
                    break

            if closest_frame is None:
                continue
            frame_data = closest_frame["data"]
            if isinstance(frame_data, torch.Tensor):
                frame_data = frame_data.cpu().numpy()
            loaded_frames.append(frame_data)
            # Drop the frame reference immediately so it can be freed.
            del closest_frame
    finally:
        # Thoroughly clean up decoder resources, then force garbage collection.
        if reader is not None:
            _close_torchvision_reader(reader)
        gc.collect()

    if not loaded_frames:
        raise ValueError(
            f"No frames could be decoded from {video_path} at the requested timestamps"
        )
    # Move axis 1 (presumably channels) to the end.
    return np.array(loaded_frames).transpose(0, 2, 3, 1)


def get_frames_by_timestamps(
    video_path: str,
    timestamps: list[float] | np.ndarray,
    video_backend: str = "decord",
    video_backend_kwargs: dict | None = None,
) -> np.ndarray:
    """Get frames from a video at specified timestamps.

    Args:
        video_path (str): Path to the video file.
        timestamps (list[float] | np.ndarray): Timestamps to retrieve frames for, in seconds.
        video_backend (str, optional): "decord", "torchcodec", "opencv" or "torchvision_av".
            Defaults to "decord".
        video_backend_kwargs (dict, optional): Extra keyword arguments forwarded to
            ``decord.VideoReader`` / ``cv2.VideoCapture``; ignored by the other backends.
            Defaults to None (no extra arguments).

    Returns:
        np.ndarray: Frames at the specified timestamps, stacked along axis 0.

    Raises:
        ImportError: If the selected backend's package is not installed.
        ValueError: If the video cannot be opened/read (opencv, torchvision_av).
        NotImplementedError: If ``video_backend`` is not supported.
    """
    kwargs = {} if video_backend_kwargs is None else video_backend_kwargs

    if video_backend == "decord":
        # For some GPUs, AV format data cannot be read
        vr = _decord_reader(video_path, kwargs)
        # Retrieve the timestamps for each frame in the video; column 0 holds the
        # frame's start time in seconds.
        frame_ts: np.ndarray = vr.get_frame_timestamp(range(len(vr)))
        indices = _nearest_frame_indices(frame_ts[:, 0], timestamps)
        return vr.get_batch(indices).asnumpy()

    if video_backend == "torchcodec":
        decoder = _torchcodec_decoder(video_path)
        # torchcodec declares list[float]; convert so numpy arrays are accepted as well.
        seconds = [float(t) for t in timestamps]
        return decoder.get_frames_played_at(seconds=seconds).data.numpy()

    if video_backend == "opencv":
        # NOTE(review): cv2 returns frames in BGR channel order by default — confirm
        # callers expect this.
        cap = _open_opencv_capture(video_path, kwargs)
        try:
            num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            if fps <= 0:
                raise ValueError(f"Unable to determine FPS of video file: {video_path}")
            # Assume a constant frame rate: frame i starts at i / fps seconds.
            frame_ts = np.arange(num_frames) / fps
            indices = _nearest_frame_indices(frame_ts, timestamps)
            return _read_opencv_frames(cap, indices)
        finally:
            cap.release()

    if video_backend == "torchvision_av":
        return _get_frames_by_timestamps_torchvision(video_path, timestamps)

    raise NotImplementedError(f"Video backend {video_backend} not implemented")


def get_all_frames(
    video_path: str,
    video_backend: str = "decord",
    video_backend_kwargs: dict | None = None,
    resize_size: tuple[int, int] | None = None,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
    """Get all frames from a video.

    Args:
        video_path (str): Path to the video file.
        video_backend (str, optional): "decord", "torchcodec", "pyav" or "torchvision_av".
            Defaults to "decord".
        video_backend_kwargs (dict, optional): Keyword arguments for the video backend
            (only forwarded to ``decord.VideoReader``). Defaults to None.
        resize_size (tuple[int, int], optional): Passed as ``dsize`` to ``cv2.resize``
            (OpenCV's (width, height) convention). Defaults to None (no resizing).

    Returns:
        np.ndarray: All decoded frames stacked along axis 0. For the "torchcodec" backend a
        tuple ``(frames, pts_seconds)`` is returned instead.

    Raises:
        ImportError: If the selected backend's package is not installed.
        NotImplementedError: If ``video_backend`` is not supported.
    """
    kwargs = {} if video_backend_kwargs is None else video_backend_kwargs

    if video_backend == "decord":
        vr = _decord_reader(video_path, kwargs)
        frames = vr.get_batch(range(len(vr))).asnumpy()
    elif video_backend == "torchcodec":
        decoder = _torchcodec_decoder(video_path)
        batch = decoder.get_frames_at(indices=list(range(len(decoder))))
        # This backend also returns per-frame presentation times; resize here so that
        # resize_size is honored for every backend.
        return _resize_frames(batch.data.numpy(), resize_size), batch.pts_seconds.numpy()
    elif video_backend == "pyav":
        # The context manager closes the container once decoding finishes.
        with av.open(video_path) as container:
            frames = np.array(
                [frame.to_ndarray(format="rgb24") for frame in container.decode(video=0)]
            )
    elif video_backend == "torchvision_av":
        # set backend and reader
        torchvision.set_video_backend("pyav")
        reader = torchvision.io.VideoReader(video_path, "video")
        try:
            frames = np.array([frame["data"].numpy() for frame in reader])
        finally:
            _close_torchvision_reader(reader)
        # Move axis 1 (presumably channels) to the end.
        frames = frames.transpose(0, 2, 3, 1)
    else:
        raise NotImplementedError(f"Video backend {video_backend} not implemented")

    # resize frames if specified
    return _resize_frames(frames, resize_size)