"""Helpers for sampling frames from a video with decord, using ffprobe for keyframes."""

import gc
import random
import shutil
import subprocess
from contextlib import contextmanager
from typing import List, Optional, Tuple

import numpy as np
from decord import VideoReader
from PIL import Image

ALL_FRAME_SAMPLE_METHODS = [
    "mid",
    "uniform",
    "random",
    "stride",
    "first",
    "last",
    "keyframe",
    "keyframe+first",
    "keyframe+last",
]


@contextmanager
def video_reader(*args, **kwargs):
    """Yield a decord ``VideoReader`` and release it when the block exits.

    All positional and keyword arguments are forwarded to ``VideoReader``.
    The reader is deleted and a garbage collection pass is forced on exit
    (including on error) to work around decord's memory leak.
    """
    vr = VideoReader(*args, **kwargs)
    try:
        yield vr
    finally:
        del vr
        gc.collect()


def get_keyframe_index(video_path):
    """Extract the frame index list of I-frames using ffprobe.

    In general, the first frame in a video should be an I-frame. The frame
    indices obtained by counting picture types are more accurate than
    ``pts_time * avg_fps``.

    Args:
        video_path: Path of the video handed to ffprobe.

    Returns:
        A tuple ``(keyframe_index_list, frame_count)`` where
        ``keyframe_index_list`` holds the 0-based indices of the I-frames and
        ``frame_count`` is the number of I/P/B frames that were counted.
        The ffprobe exit status is not inspected here; when ffprobe prints
        nothing the result is ``([], 0)`` and the caller is expected to
        compare ``frame_count`` against the real frame count.

    Raises:
        AssertionError: If ffprobe is not found in the system path.
    """
    # Raised explicitly (rather than via ``assert``) so the check is not
    # stripped under ``python -O``; the exception type is kept for callers.
    if shutil.which("ffprobe") is None:
        raise AssertionError("Please install ffprobe and make sure it is in the system path.")

    command = [
        "ffprobe",
        "-v", "quiet",
        "-select_streams", "v:0",
        "-show_entries", "frame=pict_type",
        "-of", "csv=p=0",
        video_path,
    ]
    result = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,
    )

    counted_types = {"I", "P", "B"}
    keyframe_index_list = []
    frame_index = 0
    for raw_line in result.stdout.split("\n"):
        # The csv writer may leave a trailing comma and/or whitespace
        # around the picture type; drop both in one pass.
        pict_type = raw_line.strip(", \t\r")
        if pict_type not in counted_types:
            continue
        if pict_type == "I":
            keyframe_index_list.append(frame_index)
        frame_index += 1

    return keyframe_index_list, frame_index


def extract_frames(
    video_path: str,
    sample_method: str = "mid",
    num_sampled_frames: int = 1,
    sample_stride: Optional[int] = None,
    **kwargs
) -> Optional[Tuple[List[int], List[Image.Image]]]:
    """Sample frames from a video and return them as PIL images.

    Args:
        video_path: Path of the video to read.
        sample_method: One of ``ALL_FRAME_SAMPLE_METHODS``:
            ``"mid"`` (middle frame), ``"uniform"`` (``num_sampled_frames``
            evenly spaced frames), ``"random"`` (``num_sampled_frames`` frames
            from a randomly placed clip), ``"stride"`` (every
            ``sample_stride``-th frame), ``"first"``, ``"last"``,
            ``"keyframe"`` (all I-frames), ``"keyframe+first"`` (I-frames plus
            the frame at the one-second mark when no keyframe falls within
            the first second) and ``"keyframe+last"`` (I-frames plus the last
            frame).
        num_sampled_frames: Number of frames for ``"uniform"`` and
            ``"random"``. Must be at least 1.
        sample_stride: Frame step for ``"random"`` and ``"stride"``. Must be
            at least 1 when given; ``None`` is treated as a stride of 1.
        **kwargs: Forwarded to decord's ``VideoReader``.

    Returns:
        A tuple ``(frame_indices, frames)`` with the sampled frame indices as
        plain ``int`` values and the matching ``PIL.Image`` objects.

    Raises:
        ValueError: On invalid arguments, when the ffprobe frame count does
            not match the decoded frame count, or when ``"keyframe+first"``
            is requested for a video with no frame at the one-second mark.
        AssertionError: If a keyframe method is used and ffprobe is missing.
    """
    if num_sampled_frames < 1:
        raise ValueError("The num_sampled_frames must be greater than or equal to 1.")
    if sample_stride is not None and sample_stride < 1:
        raise ValueError("The sample_stride must be greater than or equal to 1.")
    if sample_stride is not None and sample_method not in ["random", "stride"]:
        raise ValueError("The sample_method must be random or stride when sample_stride is specified.")
    # Reject unknown methods before paying the cost of opening the video.
    if sample_method not in ALL_FRAME_SAMPLE_METHODS:
        raise ValueError(f"The sample_method must be within {ALL_FRAME_SAMPLE_METHODS}.")

    # ``None`` previously crashed "random" with a TypeError while "stride"
    # silently used a step of 1; use 1 consistently for both.
    stride = 1 if sample_stride is None else sample_stride

    with video_reader(video_path, num_threads=2, **kwargs) as vr:
        num_frames = len(vr)

        if sample_method == "mid":
            frame_indices = [num_frames // 2]
        elif sample_method == "uniform":
            frame_indices = np.linspace(0, num_frames, num_sampled_frames, endpoint=False, dtype=int)
        elif sample_method == "random":
            # Clip spans the requested frames at the given stride, truncated
            # to the video length, and is placed at a random start.
            clip_length = min(num_frames, (num_sampled_frames - 1) * stride + 1)
            start_idx = random.randint(0, num_frames - clip_length)
            frame_indices = np.linspace(
                start_idx, start_idx + clip_length - 1, num_sampled_frames, dtype=int
            )
        elif sample_method == "stride":
            frame_indices = np.arange(0, num_frames, stride)
        elif sample_method == "first":
            frame_indices = [0]
        elif sample_method == "last":
            frame_indices = [num_frames - 1]
        else:
            # All remaining methods are keyframe based.
            frame_indices, counted_frames = get_keyframe_index(video_path)
            # Validate before indexing into the list: a failed ffprobe run
            # yields an empty list, which used to surface as an IndexError.
            if counted_frames != num_frames:
                raise ValueError(
                    f"The keyframe index list is not accurate. Please check the video {video_path}."
                )

            if sample_method == "keyframe+first":
                # keyframe + the first second
                avg_fps = vr.get_avg_fps()
                one_second_idx = int(1 * avg_fps)
                if len(frame_indices) <= 1 or frame_indices[1] > 1 * avg_fps:
                    # Valid indices are 0..num_frames-1, so the one-second
                    # frame must be strictly below num_frames.
                    if one_second_idx >= num_frames:
                        raise ValueError(f"The duration of {video_path} is less than 1s.")
                    frame_indices.insert(1, one_second_idx)
            elif sample_method == "keyframe+last":
                # keyframe + the last frame
                last_idx = num_frames - 1
                if not frame_indices or frame_indices[-1] != last_idx:
                    frame_indices.append(last_idx)

        # Normalise numpy integers to plain ints to match the annotation.
        frame_indices = [int(idx) for idx in frame_indices]
        decoded = vr.get_batch(frame_indices).asnumpy()
        sampled_frame_list = [Image.fromarray(frame) for frame in decoded]

    return frame_indices, sampled_frame_list