Spaces:
Paused
Paused
| # ************************************************************************* | |
| # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- | |
| # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- | |
| # ytedance Inc.. | |
| # ************************************************************************* | |
| # Copyright 2022 ByteDance and/or its affiliates. | |
| # | |
| # Copyright (2022) PV3D Authors | |
| # | |
| # ByteDance, its affiliates and licensors retain all intellectual | |
| # property and proprietary rights in and to this material, related | |
| # documentation and any modifications thereto. Any use, reproduction, | |
| # disclosure or distribution of this material and related documentation | |
| # without an express license agreement from ByteDance or | |
| # its affiliates is strictly prohibited. | |
| import av, gc | |
| import torch | |
| import warnings | |
| import numpy as np | |
# Module-wide count of VideoReader._occasional_gc() invocations; it is a
# global, so it is shared by every VideoReader instance in the process.
_CALLED_TIMES = 0
# _occasional_gc() runs gc.collect() once every this many invocations.
_GC_COLLECTION_INTERVAL = 20
# Restrict PyAV logging to the ERROR level so that warnings are not reported.
av.logging.set_level(av.logging.ERROR)
class VideoReader():
    """
    Simple wrapper around PyAV that exposes a few useful functions for
    dealing with video reading. PyAV is a pythonic binding for the ffmpeg libraries.

    The reader can be used as a context manager; leaving the ``with`` block
    closes the underlying container.

    Acknowledgement: Codes are borrowed from Bruno Korbar
    """

    def __init__(self, video, num_frames=float("inf"), decode_lossy=False, audio_resample_rate=None, bi_frame=False):
        """
        Arguments:
            video: path or bytes-like/file-like source handed to ``av.open``.
            num_frames: maximum number of frames returned by one read; the
                default (infinity) reads until the end of the stream.
            decode_lossy (bool): enable ``thread_type='AUTO'`` decoding, which
                is faster but may drop frames (a RuntimeWarning is emitted).
            audio_resample_rate: if given, an ``av.AudioResampler`` with this
                rate is created and stored in ``self.resampler``.
            bi_frame (bool): make ``sample()`` return only two randomly chosen
                frames of the sampled window.
        """
        self.container = av.open(video)
        self.num_frames = num_frames
        self.bi_frame = bi_frame
        self.resampler = None
        if audio_resample_rate is not None:
            self.resampler = av.AudioResampler(rate=audio_resample_rate)

        if self.container.streams.video:
            self.video_stream = self.container.streams.video[0]
            # enable multi-threaded video decoding
            if decode_lossy:
                warnings.warn('VideoReader| thread_type==AUTO can yield potential frame dropping!', RuntimeWarning)
                self.video_stream.thread_type = 'AUTO'
        else:
            self.video_stream = None
        self.fps = self._get_video_frame_rate()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def close(self):
        """Close the underlying container; calling it more than once is safe."""
        container = getattr(self, "container", None)
        if container is not None:
            container.close()
            self.container = None

    def seek(self, pts, backward=True, any_frame=False):
        """Seek the container to ``pts``, expressed relative to the video stream."""
        self.container.seek(pts, any_frame=any_frame, backward=backward, stream=self.video_stream)

    def _occasional_gc(self):
        # there are a lot of reference cycles in PyAV, so need to manually call
        # the garbage collector from time to time
        global _CALLED_TIMES
        _CALLED_TIMES += 1
        if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
            gc.collect()

    @staticmethod
    def _offset_fraction(frame_index, total_num_frames):
        """Return ``frame_index / total_num_frames``, or 0.0 when the total is 0."""
        if not total_num_frames:
            return 0.0
        return frame_index / total_num_frames

    @staticmethod
    def _frames_to_array(frames):
        """Convert an iterable of decoded frames to one uint8 RGB ndarray."""
        return np.array([np.uint8(f.to_rgb().to_ndarray()) for f in frames])

    def _read_video(self, offset):
        """
        Decode up to ``self.num_frames`` frames starting at a relative position.

        Arguments:
            offset (float): start position as a fraction of the container
                duration (0 is the beginning).
        Returns:
            list of decoded frames whose timestamp is at or after the target.
        Raises:
            RuntimeError: the container reports no duration and ``offset`` is
                non-zero, so the target time cannot be computed.
        """
        self._occasional_gc()

        duration = self.container.duration
        if duration is None:
            if offset:
                raise RuntimeError('container duration is unknown; cannot seek to a fractional offset')
            duration = 0
        pts = duration * offset
        time_ = pts / float(av.time_base)
        self.container.seek(int(pts))

        video_frames = []
        for frame in self._iter_frames():
            # Frames without a timestamp cannot be compared with the target.
            if frame.pts is None:
                continue
            # The seek may land before the target, so earlier frames are dropped.
            if frame.pts * frame.time_base >= time_:
                video_frames.append(frame)
                if len(video_frames) >= self.num_frames:
                    break
        return video_frames

    def _iter_frames(self):
        """Yield every frame decoded from the video stream, in demux order."""
        if self.video_stream is None:
            return
        for packet in self.container.demux(self.video_stream):
            for frame in packet.decode():
                yield frame

    def _estimate_num_frames(self):
        """
        Return the number of video frames as an int without decoding anything.

        Uses the frame count reported by the stream; when that is 0 it is
        estimated as ``fps * duration``. Returns 0 when there is no video
        stream, no open container, or no usable duration.
        """
        stream = self.video_stream
        if stream is None or self.container is None:
            return 0
        if stream.frames:
            return int(stream.frames)
        if stream.duration is not None and stream.time_base is not None:
            seconds = float(stream.duration * stream.time_base)
        elif self.container.duration is not None:
            seconds = self.container.duration / float(av.time_base)
        else:
            return 0
        return int(round(self.fps * seconds))

    def _compute_video_stats(self):
        """
        Returns:
            tuple ``(start_pts, time_base, num_of_frames)``: the pts of the
            first decoded frame, the pts difference between the first two
            decoded frames, and the (possibly estimated) frame count. The
            defaults ``0`` and ``512`` are kept for values that could not be
            measured, e.g. when there is no video stream.
        """
        start_pts, time_base = 0, 512
        num_of_frames = self._estimate_num_frames()
        if self.video_stream is None or self.container is None:
            # Always return a 3-tuple: every caller unpacks three values.
            return start_pts, time_base, num_of_frames

        self.seek(0, backward=False)
        first_pts = None
        for frame in self.container.decode(video=0):
            if frame.pts is None:
                continue
            if first_pts is None:
                first_pts = start_pts = frame.pts
            else:
                time_base = frame.pts - first_pts
                break
        return start_pts, time_base, num_of_frames

    def _get_video_frame_rate(self):
        """Return the guessed frame rate of the video stream, or 0.0 if unknown."""
        if self.video_stream is None:
            return 0.0
        rate = self.video_stream.guessed_rate
        return float(rate) if rate is not None else 0.0

    def sample(self, debug=False):
        """
        Read a window of frames starting at a random frame offset.

        Arguments:
            debug: currently unused; kept for interface compatibility.
        Returns:
            dict with ``"frames"`` (uint8 RGB array) and ``"frame_idx"``. With
            ``bi_frame`` only two frames of the window are kept and
            ``"real_t"`` holds their positions normalised by the window length.
        Raises:
            RuntimeError: there is no open container or no video stream.
            IndexError: ``bi_frame`` is set but no frame could be decoded.
        """
        if self.container is None or self.video_stream is None:
            raise RuntimeError('video stream not found')

        total_num_frames = self._estimate_num_frames()
        # int(): torch.randint needs an integer bound even if num_frames is a float.
        max_start = int(max(1, total_num_frames - self.num_frames - 1))
        offset = torch.randint(max_start, [1]).item()
        video_frames = self._frames_to_array(
            self._read_video(self._offset_fraction(offset, total_num_frames))
        )

        sample = {"frames": video_frames, "frame_idx": [offset]}
        if not self.bi_frame:
            return sample

        # The window is the set of frames actually decoded; it equals
        # num_frames unless the video ends early or num_frames is unbounded.
        window = len(video_frames)
        if window == 0:
            raise IndexError('no frames decoded; cannot pick a frame pair')
        # Beta(2, 1) is biased towards late frames, Beta(1, 2) towards early
        # ones. Scalars are drawn (not size-1 arrays) and the index is clamped
        # to the last decoded frame.
        first, second = sorted(
            min(int(np.random.beta(a, b) * window), window - 1)
            for a, b in ((2, 1), (1, 2))
        )
        span = max(window - 1, 1)
        sample["frames"] = video_frames[[first, second]]
        sample["real_t"] = torch.tensor([first / span, second / span], dtype=torch.float32)
        sample["frame_idx"] = [offset + first, offset + second]
        return sample

    def read_frames(self, frame_indices):
        """
        Read the two frames at ``frame_indices[0]`` and ``frame_indices[1]``.

        The decode window is inclusive of both indices, and ``self.num_frames``
        is restored afterwards so later ``sample()``/``read()`` calls keep the
        configured window size.

        Returns:
            uint8 RGB array holding the first and the last frame of the window.
        Raises:
            IndexError: no frame could be decoded.
        """
        first_index, last_index = frame_indices[0], frame_indices[1]
        configured_num_frames = self.num_frames
        self.num_frames = last_index - first_index + 1
        try:
            video_frames = self._read_video(
                self._offset_fraction(first_index, self.get_num_frames())
            )
        finally:
            self.num_frames = configured_num_frames
        if not video_frames:
            raise IndexError('no frames decoded for the requested indices')
        return self._frames_to_array([video_frames[0], video_frames[-1]])

    def read(self):
        """Read up to ``self.num_frames`` frames from the start as a uint8 RGB array."""
        return self._frames_to_array(self._read_video(0))

    def get_num_frames(self):
        """Return the (possibly estimated) number of frames in the video stream."""
        return self._estimate_num_frames()