import time import torch import cv2 def count_frames(video_path): # 打开视频文件 video = cv2.VideoCapture(video_path) # 统计实际读取到的帧数 actual_frame_count = 0 while True: ret, frame = video.read() if not ret: break actual_frame_count += 1 # 释放视频对象 video.release() return actual_frame_count def smart_nframes( ele: dict, total_frames: int, video_fps: int | float, ) -> int: """calculate the number of frames for video used for model inputs. Args: ele (dict): a dict contains the configuration of video. support either `fps` or `nframes`: - nframes: the number of frames to extract for model inputs. - fps: the fps to extract frames for model inputs. - min_frames: the minimum number of frames of the video, only used when fps is provided. - max_frames: the maximum number of frames of the video, only used when fps is provided. total_frames (int): the original total number of frames of the video. video_fps (int | float): the original fps of the video. Raises: ValueError: nframes should in interval [FRAME_FACTOR, total_frames]. Returns: int: the number of frames for video used for model inputs. """ assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`" if "nframes" in ele: nframes = round_by_factor(ele["nframes"], FRAME_FACTOR) else: fps = ele.get("fps", FPS) min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR) max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR) nframes = total_frames / video_fps * fps if nframes > total_frames: logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]") nframes = min(min(max(nframes, min_frames), max_frames), total_frames) nframes = floor_by_factor(nframes, FRAME_FACTOR) if not (FRAME_FACTOR <= nframes and nframes <= total_frames): raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.") return nframes def _read_video_decord( ele: dict, ) -> (torch.Tensor, float): """read video using decord.VideoReader Args: ele (dict): a dict contains the configuration of video. support keys: - video: the path of video. support "file://", "http://", "https://" and local path. - video_start: the start time of video. - video_end: the end time of video. Returns: torch.Tensor: the video tensor with shape (T, C, H, W). """ import decord video_path = ele["video"] st = time.time() import pdb; pdb.set_trace() vr = decord.VideoReader(video_path) # TODO: support start_pts and end_pts if 'video_start' in ele or 'video_end' in ele: raise NotImplementedError("not support start_pts and end_pts in decord for now.") actual_frame_count = count_frames(video_path) total_frames, video_fps = len(vr), vr.get_avg_fps() total_frames = actual_frame_count #logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s") nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() video = vr.get_batch(idx).asnumpy() video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format sample_fps = nframes / max(total_frames, 1e-6) * video_fps return video, sample_fps ele_example = { 'video': "/home/world_model/egoexo4d/keystep_train_takes-cut/georgiatech_cooking_14_02_2/aria02_214-1_0000030.mp4" } video, sample_fps = _read_video_decord(ele_example) print(video.shape) print(sample_fps)