import io
import os
import sys
from functools import partial
import math
import torchvision.transforms as TT
from webds import MetaDistributedWebDataset
import random
from fractions import Fraction
from typing import Union, Optional, Dict, Any, Tuple
from torchvision.io.video import av
import numpy as np
import torch
from torchvision.io import _video_opt
from torchvision.io.video import _check_av_available, _read_from_stream, _align_audio_frames
from torchvision.transforms.functional import center_crop, resize
from torchvision.transforms import InterpolationMode
import decord
from decord import VideoReader
from torch.utils.data import Dataset


def read_video(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
    output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """
    Reads a video from a file, returning both the video frames and the audio frames.

    Args:
        filename (str): path to the video file
        start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The start presentation time of the video
        end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The end presentation time
        pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
            either 'pts' or 'sec'. Defaults to 'pts'.
        output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".

    Returns:
        vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
    """
    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

    _check_av_available()

    if end_pts is None:
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")

    info = {}
    video_frames = []
    audio_frames = []
    audio_timebase = _video_opt.default_timebase

    with av.open(filename, metadata_errors="ignore") as container:
        if container.streams.audio:
            audio_timebase = container.streams.audio[0].time_base
        if container.streams.video:
            video_frames = _read_from_stream(
                container,
                start_pts,
                end_pts,
                pts_unit,
                container.streams.video[0],
                {"video": 0},
            )
            video_fps = container.streams.video[0].average_rate
            # guard against potentially corrupted files
            if video_fps is not None:
                info["video_fps"] = float(video_fps)

        if container.streams.audio:
            audio_frames = _read_from_stream(
                container,
                start_pts,
                end_pts,
                pts_unit,
                container.streams.audio[0],
                {"audio": 0},
            )
            info["audio_fps"] = container.streams.audio[0].rate

    # stack the decoded video frames (as torchvision's read_video does); fall back to an
    # empty tensor when the file has no video stream
    vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
    aframes_list = [frame.to_ndarray() for frame in audio_frames]

    if vframes_list:
        vframes = torch.as_tensor(np.stack(vframes_list))
    else:
        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

    if aframes_list:
        aframes = np.concatenate(aframes_list, 1)
        aframes = torch.as_tensor(aframes)
        if pts_unit == "sec":
            start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
            if end_pts != float("inf"):
                end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
        aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)

    if output_format == "TCHW":
        # [T, H, W, C] --> [T, C, H, W]
        vframes = vframes.permute(0, 3, 1, 2)

    return vframes, aframes, info
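

# Example usage (illustrative; "clip.mp4" is a hypothetical local file): decode a clip
# using seconds-based timestamps and get frames in [T, C, H, W] order.
#
#     vframes, aframes, info = read_video("clip.mp4", pts_unit="sec", output_format="TCHW")
#     print(vframes.shape, aframes.shape, info.get("video_fps"))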


def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
    # arr: [T, C, H, W]; image_size: (height, width). Resize so the frames cover the
    # target resolution, then crop to exactly (height, width).
    if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
        arr = resize(
            arr,
            size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
            interpolation=InterpolationMode.BICUBIC,
        )
    else:
        arr = resize(
            arr,
            size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
            interpolation=InterpolationMode.BICUBIC,
        )

    h, w = arr.shape[2], arr.shape[3]
    arr = arr.squeeze(0)

    delta_h = h - image_size[0]
    delta_w = w - image_size[1]

    if reshape_mode == "random" or reshape_mode == "none":
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    elif reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:
        raise NotImplementedError
    arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
    return arr
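

# Example (illustrative): a dummy 1920x1080 clip of 4 frames, center-cropped to 480x720.
#
#     dummy = torch.zeros(4, 3, 1080, 1920, dtype=torch.uint8)
#     out = resize_for_rectangle_crop(dummy, image_size=[480, 720], reshape_mode="center")
#     # out.shape == torch.Size([4, 3, 480, 720])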


def pad_last_frame(tensor, num_frames):
    # T, H, W, C
    if len(tensor) < num_frames:
        pad_length = num_frames - len(tensor)
        # Use the last frame to pad instead of zero
        last_frame = tensor[-1]
        pad_tensor = last_frame.unsqueeze(0).expand(pad_length, *tensor.shape[1:])
        padded_tensor = torch.cat([tensor, pad_tensor], dim=0)
        return padded_tensor
    else:
        return tensor[:num_frames]


def load_video(
    video_data,
    sampling="uniform",
    duration=None,
    num_frames=4,
    wanted_fps=None,
    actual_fps=None,
    skip_frms_num=0.0,
    nb_read_frames=None,
):
    decord.bridge.set_bridge("torch")
    vr = VideoReader(uri=video_data, height=-1, width=-1)
    if nb_read_frames is not None:
        ori_vlen = nb_read_frames
    else:
        ori_vlen = min(int(duration * actual_fps) - 1, len(vr))

    # latest frame index from which num_frames frames at wanted_fps can still be read,
    # while also leaving skip_frms_num frames at the end of the clip
    max_seek = int(ori_vlen - skip_frms_num - num_frames / wanted_fps * actual_fps)
    start = random.randint(int(skip_frms_num), max_seek + 1)
    end = int(start + num_frames / wanted_fps * actual_fps)
    n_frms = num_frames

    if sampling == "uniform":
        indices = np.arange(start, end, (end - start) / n_frms).astype(int)
    else:
        raise NotImplementedError

    # get_batch -> T, H, W, C
    temp_frms = vr.get_batch(np.arange(start, end))
    assert temp_frms is not None
    tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
    tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]

    return pad_last_frame(tensor_frms, num_frames)
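

# Example (illustrative, hypothetical values): sample 16 frames at 8 fps from a ~10 s,
# 30 fps clip held in memory as bytes.
#
#     with open("clip.mp4", "rb") as f:
#         frames = load_video(io.BytesIO(f.read()), duration=10.0, num_frames=16,
#                             wanted_fps=8, actual_fps=30.0, skip_frms_num=3)
#     # frames: uint8 tensor of shape [16, H, W, C]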


import threading


def load_video_with_timeout(*args, **kwargs):
    video_container = {}

    def target_function():
        video = load_video(*args, **kwargs)
        video_container["video"] = video

    thread = threading.Thread(target=target_function)
    thread.start()
    timeout = 20
    thread.join(timeout)

    if thread.is_alive():
        # NOTE: the worker thread itself is not killed on timeout; it is simply abandoned.
        print("Loading video timed out")
        raise TimeoutError
    return video_container.get("video", None).contiguous()


def process_video(
    video_path,
    image_size=None,
    duration=None,
    num_frames=4,
    wanted_fps=None,
    actual_fps=None,
    skip_frms_num=0.0,
    nb_read_frames=None,
):
    """
    video_path: str or io.BytesIO.
    image_size: target (height, width) of the output frames; None keeps the decoded size.
    duration: known duration in seconds, used to seek directly to the sampled start. TODO: bypass if unknown.
    num_frames: number of frames to sample.
    wanted_fps: target sampling fps.
    actual_fps: fps of the source video.
    skip_frms_num: ignore the first and the last few frames, avoiding transitions.
    """
    video = load_video_with_timeout(
        video_path,
        duration=duration,
        num_frames=num_frames,
        wanted_fps=wanted_fps,
        actual_fps=actual_fps,
        skip_frms_num=skip_frms_num,
        nb_read_frames=nb_read_frames,
    )

    # --- copy and modify the image process ---
    video = video.permute(0, 3, 1, 2)  # [T, C, H, W]

    # resize
    if image_size is not None:
        video = resize_for_rectangle_crop(video, image_size, reshape_mode="center")

    return video
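

# Example (illustrative, hypothetical values): same sampling as above, but resized and
# center-cropped to a fixed resolution.
#
#     clip = process_video("clip.mp4", image_size=[480, 720], duration=10.0,
#                          num_frames=16, wanted_fps=8, actual_fps=30.0, skip_frms_num=3)
#     # clip: frames of shape [16, 3, 480, 720], not yet normalized to [-1, 1]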


def process_fn_video(src, image_size, fps, num_frames, skip_frms_num=0.0, txt_key="caption"):
    while True:
        r = next(src)
        if "mp4" in r:
            video_data = r["mp4"]
        elif "avi" in r:
            video_data = r["avi"]
        else:
            print("No video data found")
            continue

        if txt_key not in r:
            txt = ""
        else:
            txt = r[txt_key]

        if isinstance(txt, bytes):
            txt = txt.decode("utf-8")
        else:
            txt = str(txt)

        duration = r.get("duration", None)
        if duration is not None:
            duration = float(duration)
        else:
            continue

        actual_fps = r.get("fps", None)
        if actual_fps is not None:
            actual_fps = float(actual_fps)
        else:
            continue

        # skip clips that are too short to yield num_frames frames at the wanted fps
        required_frames = num_frames / fps * actual_fps + 2 * skip_frms_num
        required_duration = num_frames / fps + 2 * skip_frms_num / actual_fps
        if duration is not None and duration < required_duration:
            continue

        try:
            frames = process_video(
                io.BytesIO(video_data),
                num_frames=num_frames,
                wanted_fps=fps,
                image_size=image_size,
                duration=duration,
                actual_fps=actual_fps,
                skip_frms_num=skip_frms_num,
            )
            frames = (frames - 127.5) / 127.5  # normalize uint8 pixels to [-1, 1]
        except Exception as e:
            print(e)
            continue

        item = {
            "mp4": frames,
            "txt": txt,
            "num_frames": num_frames,
            "fps": fps,
        }
        yield item


class VideoDataset(MetaDistributedWebDataset):
    def __init__(
        self,
        path,
        image_size,
        num_frames,
        fps,
        skip_frms_num=0.0,
        nshards=sys.maxsize,
        seed=1,
        meta_names=None,
        shuffle_buffer=1000,
        include_dirs=None,
        txt_key="caption",
        **kwargs,
    ):
        if seed == -1:
            seed = random.randint(0, 1000000)
        if meta_names is None:
            meta_names = []

        if path.startswith(";"):
            path, include_dirs = path.split(";", 1)
        super().__init__(
            path,
            partial(
                process_fn_video, num_frames=num_frames, image_size=image_size, fps=fps, skip_frms_num=skip_frms_num
            ),
            seed,
            meta_names=meta_names,
            shuffle_buffer=shuffle_buffer,
            nshards=nshards,
            include_dirs=include_dirs,
        )

    @classmethod
    def create_dataset_function(cls, path, args, **kwargs):
        return cls(path, **kwargs)
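

# Illustrative construction (the path and sizes are hypothetical; sharding and shuffling of
# the underlying webdataset tar files are handled by MetaDistributedWebDataset):
#
#     train_ds = VideoDataset(
#         path="/data/videos/{00000..00099}.tar",
#         image_size=[480, 720],
#         num_frames=49,
#         fps=8,
#         skip_frms_num=3.0,
#     )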


class SFTDataset(Dataset):
    def __init__(self, data_dir, video_size, fps, max_num_frames, skip_frms_num=3):
        """
        data_dir: directory walked recursively for .mp4 files; captions are looked up in a
            parallel "labels" tree as .txt files.
        video_size: target (height, width) of the output frames.
        fps: target sampling fps.
        max_num_frames: maximum number of frames per sample.
        skip_frms_num: ignore the first and the last few frames, avoiding transitions.
        """
        super(SFTDataset, self).__init__()

        self.video_size = video_size
        self.fps = fps
        self.max_num_frames = max_num_frames
        self.skip_frms_num = skip_frms_num

        self.video_paths = []
        self.captions = []

        for root, dirnames, filenames in os.walk(data_dir):
            for filename in filenames:
                if filename.endswith(".mp4"):
                    video_path = os.path.join(root, filename)
                    self.video_paths.append(video_path)

                    caption_path = video_path.replace(".mp4", ".txt").replace("videos", "labels")
                    if os.path.exists(caption_path):
                        with open(caption_path, "r") as f:
                            caption = f.read().splitlines()[0]
                    else:
                        caption = ""
                    self.captions.append(caption)
    def __getitem__(self, index):
        decord.bridge.set_bridge("torch")

        video_path = self.video_paths[index]
        vr = VideoReader(uri=video_path, height=-1, width=-1)
        actual_fps = vr.get_avg_fps()
        ori_vlen = len(vr)

        if ori_vlen / actual_fps * self.fps > self.max_num_frames:
            # clip is long enough: sample max_num_frames frames at the target fps
            num_frames = self.max_num_frames
            start = int(self.skip_frms_num)
            end = int(start + num_frames / self.fps * actual_fps)
            end_safety = min(int(start + num_frames / self.fps * actual_fps), int(ori_vlen))
            indices = np.arange(start, end, (end - start) // num_frames).astype(int)
            temp_frms = vr.get_batch(np.arange(start, end_safety))
            assert temp_frms is not None
            tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
            tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
        else:
            if ori_vlen > self.max_num_frames:
                # more frames than needed, but too short to sample at the target fps: subsample uniformly
                num_frames = self.max_num_frames
                start = int(self.skip_frms_num)
                end = int(ori_vlen - self.skip_frms_num)
                indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
                temp_frms = vr.get_batch(np.arange(start, end))
                assert temp_frms is not None
                tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
                tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
            else:
                # short clip: keep all frames, truncated to a length of the form 4k + 1

                def nearest_smaller_4k_plus_1(n):
                    remainder = n % 4
                    if remainder == 0:
                        return n - 3
                    else:
                        return n - remainder + 1

                start = int(self.skip_frms_num)
                end = int(ori_vlen - self.skip_frms_num)
                num_frames = nearest_smaller_4k_plus_1(end - start)  # 3D VAE requires the number of frames to be 4k + 1
                end = int(start + num_frames)
                temp_frms = vr.get_batch(np.arange(start, end))
                assert temp_frms is not None
                tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms

        tensor_frms = pad_last_frame(
            tensor_frms, self.max_num_frames
        )  # the length of indices may be less than num_frames due to rounding error
        tensor_frms = tensor_frms.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
        tensor_frms = resize_for_rectangle_crop(tensor_frms, self.video_size, reshape_mode="center")
        tensor_frms = (tensor_frms - 127.5) / 127.5  # normalize to [-1, 1]

        item = {
            "mp4": tensor_frms,
            "txt": self.captions[index],
            "num_frames": num_frames,
            "fps": self.fps,
        }
        return item

    def __len__(self):
        return len(self.video_paths)

    @classmethod
    def create_dataset_function(cls, path, args, **kwargs):
        return cls(data_dir=path, **kwargs)
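

# Illustrative usage (the directory layout is hypothetical: .mp4 files under a "videos"
# tree with matching .txt captions under a parallel "labels" tree):
#
#     dataset = SFTDataset(data_dir="/data/sft", video_size=[480, 720], fps=8,
#                          max_num_frames=49, skip_frms_num=3)
#     loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)
#     batch = next(iter(loader))
#     # batch["mp4"]: [1, max_num_frames, 3, 480, 720] in [-1, 1]; batch["txt"]: list with one caption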