|
|
import time |
|
|
import torch |
|
|
import cv2 |
|
|
|
|
|
|
|
|
def count_frames(video_path):
    """Count the number of decodable frames in a video by reading it end to end.

    Decoding every frame is slower than reading container metadata (e.g.
    ``CAP_PROP_FRAME_COUNT``) but returns the number of frames that can
    actually be decoded, which can differ when the container metadata is
    wrong or the stream is truncated.

    Args:
        video_path (str): path to a video file readable by OpenCV.

    Returns:
        int: number of frames successfully decoded.

    Raises:
        IOError: if the video cannot be opened.
    """
    video = cv2.VideoCapture(video_path)
    # Fail loudly instead of silently returning 0 for an unreadable path.
    if not video.isOpened():
        video.release()
        raise IOError(f"Failed to open video: {video_path}")
    try:
        actual_frame_count = 0
        while True:
            ret, _frame = video.read()
            if not ret:
                # read() returns False at end-of-stream (or on a decode error).
                break
            actual_frame_count += 1
    finally:
        # Release the capture even if decoding raises mid-loop.
        video.release()
    return actual_frame_count
|
|
|
|
|
def smart_nframes(
    ele: dict,
    total_frames: int,
    video_fps: int | float,
) -> int:
    """Decide how many frames to sample from a video for model inputs.

    Exactly one of ``fps`` / ``nframes`` may appear in ``ele``:

    - ``nframes``: the explicit number of frames to extract.
    - ``fps``: the sampling rate to extract frames at. When ``fps`` is used,
      ``min_frames`` / ``max_frames`` optionally bound the result.

    Args:
        ele (dict): video configuration (see keys above).
        total_frames (int): original total number of frames in the video.
        video_fps (int | float): original fps of the video.

    Raises:
        ValueError: if the result falls outside [FRAME_FACTOR, total_frames].

    Returns:
        int: the number of frames to use for model inputs, a multiple of
        FRAME_FACTOR.
    """
    assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
    if "nframes" in ele:
        # Explicit frame count: just snap it to the nearest FRAME_FACTOR multiple.
        nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
    else:
        target_fps = ele.get("fps", FPS)
        lower = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
        upper = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
        # Number of frames needed to sample the whole duration at target_fps.
        nframes = total_frames / video_fps * target_fps
        if nframes > total_frames:
            logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
        # Clamp into [lower, upper] and never exceed the source frame count,
        # then round down to a FRAME_FACTOR multiple.
        nframes = min(max(nframes, lower), upper, total_frames)
        nframes = floor_by_factor(nframes, FRAME_FACTOR)
    if not (FRAME_FACTOR <= nframes <= total_frames):
        raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
    return nframes
|
|
|
|
|
def _read_video_decord(
    ele: dict,
) -> tuple[torch.Tensor, float]:
    """Read and uniformly subsample a video using decord.VideoReader.

    Args:
        ele (dict): a dict contains the configuration of video.
            support keys:
                - video: the path of video. support "file://", "http://", "https://" and local path.
                - video_start: the start time of video (not supported yet).
                - video_end: the end time of video (not supported yet).
                - fps / nframes / min_frames / max_frames: forwarded to
                  smart_nframes to pick the sample count.

    Raises:
        NotImplementedError: if `video_start` or `video_end` is provided.

    Returns:
        tuple[torch.Tensor, float]: the video tensor with shape (T, C, H, W)
            and the effective sampling fps of the returned frames.
    """
    # Local import so decord is only required when this reader is used.
    import decord

    video_path = ele["video"]
    vr = decord.VideoReader(video_path)
    if 'video_start' in ele or 'video_end' in ele:
        raise NotImplementedError("not support start_pts and end_pts in decord for now.")
    # The container metadata count (len(vr)) is deliberately replaced by the
    # frame count obtained from decoding the full stream — presumably because
    # metadata over-reports for some videos; TODO confirm with data owner.
    total_frames = count_frames(video_path)
    video_fps = vr.get_avg_fps()
    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    # Evenly spaced indices covering the decodable range [0, total_frames - 1].
    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
    video = vr.get_batch(idx).asnumpy()
    # decord returns (T, H, W, C); the model expects (T, C, H, W).
    video = torch.tensor(video).permute(0, 3, 1, 2)
    # Effective fps of the sampled clip; max(..., 1e-6) guards a zero-frame video.
    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
    return video, sample_fps
|
|
|
|
|
|
|
|
# Smoke test: load one sample clip and report the sampled tensor shape and fps.
# Guarded so that importing this module does not trigger video decoding.
if __name__ == "__main__":
    ele_example = {
        'video': "/home/world_model/egoexo4d/keystep_train_takes-cut/georgiatech_cooking_14_02_2/aria02_214-1_0000030.mp4"
    }

    video, sample_fps = _read_video_decord(ele_example)
    print(video.shape)
    print(sample_fps)