| | import csv |
| | import gc |
| | import io |
| | import json |
| | import math |
| | import os |
| | import random |
| | from contextlib import contextmanager |
| | from random import shuffle |
| | from threading import Thread |
| |
|
| | import albumentations |
| | import cv2 |
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | import torchvision.transforms as transforms |
| | from decord import VideoReader |
| | from einops import rearrange |
| | from func_timeout import FunctionTimedOut, func_timeout |
| | from packaging import version as pver |
| | from PIL import Image |
| | from safetensors.torch import load_file |
| | from torch.utils.data import BatchSampler, Sampler |
| | from torch.utils.data.dataset import Dataset |
| |
|
| | from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager, |
| | custom_meshgrid, get_random_mask, get_relative_pose, |
| | get_video_reader_batch, padding_image, process_pose_file, |
| | process_pose_params, ray_condition, resize_frame, |
| | resize_image_with_target_area) |
| |
|
| |
|
class ImageVideoSampler(BatchSampler):
    """A sampler wrapper that never mixes images and videos in one batch.

    Indices drawn from ``sampler`` are routed into one bucket per content
    type.  The type is read from the ``'type'`` key of
    ``dataset.dataset[idx]`` and defaults to ``'image'``.  A bucket is
    emitted as a batch as soon as it holds ``batch_size`` indices.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.  It must
            expose a ``dataset`` attribute that is indexable by the values
            produced by ``sampler`` and yields dict-like records.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the trailing
            batches whose size would be less than ``batch_size``.  If
            ``False``, each non-empty bucket is emitted as a final, smaller
            batch once the base sampler is exhausted.

    Raises:
        TypeError: If ``sampler`` is not a ``Sampler`` instance.
        ValueError: If ``batch_size`` is not a positive integer.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 drop_last: bool = False
                 ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        # One pending-index list per content type.
        self.bucket = {'image': [], 'video': []}

    def __iter__(self):
        for idx in self.sampler:
            content_type = self.dataset.dataset[idx].get('type', 'image')
            bucket = self.bucket[content_type]
            bucket.append(idx)

            # Only the bucket that just grew can have become full.
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]

        # The base sampler is exhausted: honour ``drop_last`` and always reset
        # the buckets so leftovers cannot leak into the next iteration.
        for bucket in self.bucket.values():
            if bucket and not self.drop_last:
                yield bucket[:]
            del bucket[:]
|
| |
|
class ImageVideoDataset(Dataset):
    """Mixed image/video dataset driven by a CSV or JSON annotation file.

    Every annotation record is a mapping with at least ``'file_path'`` and
    ``'text'``; an optional ``'type'`` key equal to ``'video'`` marks video
    records, anything else is handled as an image.

    Args:
        ann_path (str): Path of the annotation file (``.csv`` or ``.json``).
        data_root (str | None): If given, joined in front of every
            ``file_path``.
        video_sample_size (int | sequence): Target size for video frames; an
            int ``s`` is expanded to ``(s, s)``.
        video_sample_stride (int): Stride between sampled frames.
        video_sample_n_frames (int): Maximum number of frames per clip.
        image_sample_size (int | sequence): Target size for images; an int
            ``s`` is expanded to ``(s, s)``.
        video_repeat (int): When ``> 0`` the record list is rebuilt as all
            non-video records followed by the video records repeated
            ``video_repeat`` times.
        text_drop_ratio (float): Probability of replacing the text by ``''``.
        enable_bucket (bool): When ``True`` samples are returned as NumPy
            arrays without the tensor conversion / crop / normalize
            transforms.
        video_length_drop_start (float): Fraction used for the lower bound of
            the random clip start.
        video_length_drop_end (float): Fraction of the video length that may
            be sampled from.
        enable_inpaint (bool): When ``True`` (and ``enable_bucket`` is
            ``False``) ``mask``, ``mask_pixel_values`` and
            ``clip_pixel_values`` are added to every sample.
        return_file_name (bool): When ``True`` the base name of the file is
            added to every sample as ``'file_name'``.

    Raises:
        ValueError: If ``ann_path`` has an unsupported extension.
    """

    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.0,
        video_length_drop_end=1.0,
        enable_inpaint=False,
        return_file_name=False,
    ):
        # Loading the list of annotation records.
        print(f"loading annotations from {ann_path} ...")
        dataset = self._load_annotations(ann_path)

        self.data_root = data_root

        # Optionally over-sample the video records.
        self.dataset = self._repeat_videos(dataset, video_repeat)
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        self.return_file_name = return_file_name

        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    @staticmethod
    def _load_annotations(ann_path):
        """Read the annotation records from a ``.csv`` or ``.json`` file."""
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                return list(csv.DictReader(csvfile))
        if ann_path.endswith('.json'):
            # Use a context manager so the handle is closed deterministically.
            with open(ann_path, 'r') as jsonfile:
                return json.load(jsonfile)
        raise ValueError(
            f"Unsupported annotation file {ann_path!r}: expected a .csv or .json file."
        )

    @staticmethod
    def _repeat_videos(dataset, video_repeat):
        """Return ``dataset`` with its video records repeated ``video_repeat`` times."""
        if video_repeat <= 0:
            return dataset
        images = [data for data in dataset if data.get('type', 'image') != 'video']
        videos = [data for data in dataset if data.get('type', 'image') == 'video']
        return images + videos * video_repeat

    def get_batch(self, idx):
        """Load one record.

        Args:
            idx (int): Record index; wrapped modulo the dataset size.

        Returns:
            tuple: ``(pixel_values, text, data_type, path)`` where
            ``data_type`` is ``'video'`` or ``'image'`` and ``path`` is the
            file path that was opened.

        Raises:
            ValueError: If the video yields no frames, the read times out or
                frame extraction fails.
        """
        data_info = self.dataset[idx % len(self.dataset)]

        if data_info.get('type', 'image') == 'video':
            video_id, text = data_info['file_path'], data_info['text']

            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError("No Frames in video.")

                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = [
                        resize_frame(frame, self.larger_side_of_image_and_video)
                        for frame in pixel_values
                    ]
                    pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.") from e

                if not self.enable_bucket:
                    # (frames, H, W, C) array -> (frames, C, H, W) tensor, then rescale.
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    # The reader is not needed any more; drop our reference to it.
                    del video_reader
                    pixel_values = self.video_transforms(pixel_values)

                # Random use no text generation
                if random.random() < self.text_drop_ratio:
                    text = ''
            return pixel_values, text, 'video', video_dir
        else:
            image_path, text = data_info['file_path'], data_info['text']
            if self.data_root is not None:
                image_path = os.path.join(self.data_root, image_path)
            image = Image.open(image_path).convert('RGB')
            if not self.enable_bucket:
                image = self.image_transforms(image).unsqueeze(0)
            else:
                image = np.expand_dims(np.array(image), 0)
            if random.random() < self.text_drop_ratio:
                text = ''
            return image, text, 'image', image_path

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return the sample at ``idx``, retrying with random indices on failure.

        A replacement index is only accepted when its record has the same
        ``type`` as the originally requested one.  The loop never gives up,
        so it does not terminate if no record of that type can be loaded.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, name, data_type, file_path = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx
                if self.return_file_name:
                    sample["file_name"] = os.path.basename(file_path)
                break
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length - 1)

        if self.enable_inpaint and not self.enable_bucket:
            mask = get_random_mask(pixel_values.size())
            # Masked positions are filled with -1.
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            # First frame, channels last, mapped back from [-1, 1] to [0, 255].
            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

        return sample
|
| |
|
class ImageVideoControlDataset(Dataset):
    """Image/video dataset that also returns a control signal per sample.

    Every annotation record is a mapping with ``'file_path'``, ``'text'`` and
    ``'control_file_path'``; an optional ``'type'`` key equal to ``'video'``
    marks video records (anything else is handled as an image) and an
    optional ``'object_file_path'`` list names subject images.

    Args:
        ann_path (str): Path of the annotation file (``.csv`` or ``.json``).
        data_root (str | None): If given, joined in front of ``file_path``
            and ``control_file_path``.
        video_sample_size (int | sequence): Target size for video frames; an
            int ``s`` is expanded to ``(s, s)``.
        video_sample_stride (int): Stride between sampled frames.
        video_sample_n_frames (int): Maximum number of frames per clip.
        image_sample_size (int | sequence): Target size for images; an int
            ``s`` is expanded to ``(s, s)``.
        video_repeat (int): When ``> 0`` the record list is rebuilt as all
            non-video records followed by the video records repeated
            ``video_repeat`` times.
        text_drop_ratio (float): Probability of replacing the text by ``''``.
        enable_bucket (bool): When ``True`` samples are returned as NumPy
            arrays without the tensor conversion / crop / normalize
            transforms.
        video_length_drop_start (float): Fraction used for the lower bound of
            the random clip start.
        video_length_drop_end (float): Fraction of the video length that may
            be sampled from.
        enable_inpaint (bool): When ``True`` (and ``enable_bucket`` is
            ``False``) ``mask``, ``mask_pixel_values`` and
            ``clip_pixel_values`` are added to every sample.
        enable_camera_info (bool): When ``True`` a ``.txt`` control file of a
            video record is parsed with ``process_pose_file`` instead of
            being decoded as a control video, and
            ``'control_camera_values'`` is added to every sample.
        return_file_name (bool): When ``True`` the base name of
            ``file_path`` is added to every sample as ``'file_name'``.
        enable_subject_info (bool): When ``True`` up to four subject images
            are loaded from ``object_file_path``.
        padding_subject_info (bool): When ``True`` subject images go through
            ``padding_image`` and are stacked into one array; otherwise they
            go through ``resize_image_with_target_area`` and are returned as
            a list.

    Raises:
        ValueError: If ``ann_path`` has an unsupported extension.
    """

    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.1,
        video_length_drop_end=0.9,
        enable_inpaint=False,
        enable_camera_info=False,
        return_file_name=False,
        enable_subject_info=False,
        padding_subject_info=True,
    ):
        # Loading the list of annotation records.
        print(f"loading annotations from {ann_path} ...")
        dataset = self._load_annotations(ann_path)

        self.data_root = data_root

        # Optionally over-sample the video records.
        self.dataset = self._repeat_videos(dataset, video_repeat)
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        self.enable_camera_info = enable_camera_info
        self.return_file_name = return_file_name
        self.enable_subject_info = enable_subject_info
        self.padding_subject_info = padding_subject_info

        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        if self.enable_camera_info:
            # Same geometry as the video transforms, but without normalization.
            self.video_transforms_camera = transforms.Compose(
                [
                    transforms.Resize(min(self.video_sample_size)),
                    transforms.CenterCrop(self.video_sample_size)
                ]
            )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    @staticmethod
    def _load_annotations(ann_path):
        """Read the annotation records from a ``.csv`` or ``.json`` file."""
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                return list(csv.DictReader(csvfile))
        if ann_path.endswith('.json'):
            # Use a context manager so the handle is closed deterministically.
            with open(ann_path, 'r') as jsonfile:
                return json.load(jsonfile)
        raise ValueError(
            f"Unsupported annotation file {ann_path!r}: expected a .csv or .json file."
        )

    @staticmethod
    def _repeat_videos(dataset, video_repeat):
        """Return ``dataset`` with its video records repeated ``video_repeat`` times."""
        if video_repeat <= 0:
            return dataset
        images = [data for data in dataset if data.get('type', 'image') != 'video']
        videos = [data for data in dataset if data.get('type', 'image') == 'video']
        return images + videos * video_repeat

    def _zeros_like(self, pixel_values):
        """Return an all-zero placeholder of the same kind as ``pixel_values``."""
        if not self.enable_bucket:
            return torch.zeros_like(pixel_values)
        return np.zeros_like(pixel_values)

    def _visual_size(self, pixel_values):
        """Return ``(height, width)`` read from the shape of ``pixel_values``."""
        if not self.enable_bucket:
            return pixel_values.shape[-2:]
        return pixel_values.shape[1:3]

    def _load_subject_images(self, data_info, visual_width, visual_height, convert_rgb):
        """Load up to four randomly chosen subject images of one record.

        ``convert_rgb`` preserves the original per-branch behaviour: only the
        image branch converted subject images to RGB before processing.
        NOTE(review): confirm whether the video branch should convert too.
        """
        # Shuffle a copy so the annotation record itself is never mutated.
        subject_paths = list(data_info.get('object_file_path', []))
        shuffle(subject_paths)

        subject_images = []
        for subject_path in subject_paths[:4]:
            subject_image = Image.open(subject_path)
            if convert_rgb:
                subject_image = subject_image.convert('RGB')

            if self.padding_subject_info:
                img = padding_image(subject_image, visual_width, visual_height)
            else:
                img = resize_image_with_target_area(subject_image, 1024 * 1024)

            # Random horizontal flip.
            if random.random() < 0.5:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
            subject_images.append(np.array(img))

        if self.padding_subject_info:
            return np.array(subject_images)
        return subject_images

    def get_batch(self, idx):
        """Load one record together with its control data.

        Args:
            idx (int): Record index; wrapped modulo the dataset size.

        Returns:
            tuple: ``(pixel_values, control_pixel_values, subject_image,
            control_camera_values, text, data_type)``.  ``subject_image`` is
            ``None`` unless ``enable_subject_info`` is set;
            ``control_camera_values`` is ``None`` for images and whenever no
            camera file is used.

        Raises:
            ValueError: If a video yields no frames, a read times out or
                frame extraction fails.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        video_id, text = data_info['file_path'], data_info['text']

        if data_info.get('type', 'image') == 'video':
            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                # Remember the frame count: the camera branch below needs it
                # after the reader reference has been dropped.
                video_frame_count = len(video_reader)

                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(video_frame_count * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError("No Frames in video.")

                video_length = int(self.video_length_drop_end * video_frame_count)
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = [
                        resize_frame(frame, self.larger_side_of_image_and_video)
                        for frame in pixel_values
                    ]
                    pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.") from e

                if not self.enable_bucket:
                    # (frames, H, W, C) array -> (frames, C, H, W) tensor, then rescale.
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    # The reader is not needed any more; drop our reference to it.
                    del video_reader
                    pixel_values = self.video_transforms(pixel_values)

                # Random use no text generation
                if random.random() < self.text_drop_ratio:
                    text = ''

            control_video_id = data_info['control_file_path']
            if control_video_id is not None and self.data_root is not None:
                control_video_id = os.path.join(self.data_root, control_video_id)

            if self.enable_camera_info:
                # A missing control path is treated like a non-camera file.
                if control_video_id is not None and control_video_id.lower().endswith('.txt'):
                    control_pixel_values = self._zeros_like(pixel_values)
                    if not self.enable_bucket:
                        control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0])
                        control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous()
                        # NOTE(review): unlike the bucket path below, the result
                        # is not indexed with batch_index here - confirm intent.
                        control_camera_values = F.interpolate(control_camera_values, size=(video_frame_count, control_camera_values.size(3)), mode='bilinear', align_corners=True)
                        control_camera_values = self.video_transforms_camera(control_camera_values)
                    else:
                        control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True)
                        control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0)
                        control_camera_values = F.interpolate(control_camera_values, size=(video_frame_count, control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0]
                        # Keep only the entries of the sampled frames.
                        control_camera_values = np.array([control_camera_values[index] for index in batch_index])
                else:
                    control_pixel_values = self._zeros_like(pixel_values)
                    control_camera_values = None
            else:
                if control_video_id is not None:
                    with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
                        try:
                            # Read the same frame indices as for the main video.
                            sample_args = (control_video_reader, batch_index)
                            control_pixel_values = func_timeout(
                                VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                            )
                            resized_frames = [
                                resize_frame(frame, self.larger_side_of_image_and_video)
                                for frame in control_pixel_values
                            ]
                            control_pixel_values = np.array(resized_frames)
                        except FunctionTimedOut:
                            raise ValueError(f"Read {idx} timeout.")
                        except Exception as e:
                            raise ValueError(f"Failed to extract frames from video. Error is {e}.") from e

                        if not self.enable_bucket:
                            control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
                            control_pixel_values = control_pixel_values / 255.
                            del control_video_reader
                            control_pixel_values = self.video_transforms(control_pixel_values)
                else:
                    control_pixel_values = self._zeros_like(pixel_values)
                control_camera_values = None

            if self.enable_subject_info:
                visual_height, visual_width = self._visual_size(pixel_values)
                subject_image = self._load_subject_images(
                    data_info, visual_width, visual_height, convert_rgb=False
                )
            else:
                subject_image = None

            return pixel_values, control_pixel_values, subject_image, control_camera_values, text, "video"
        else:
            image_path, text = data_info['file_path'], data_info['text']
            if self.data_root is not None:
                image_path = os.path.join(self.data_root, image_path)
            image = Image.open(image_path).convert('RGB')
            if not self.enable_bucket:
                image = self.image_transforms(image).unsqueeze(0)
            else:
                image = np.expand_dims(np.array(image), 0)

            if random.random() < self.text_drop_ratio:
                text = ''

            control_image_id = data_info['control_file_path']
            if self.data_root is not None:
                control_image_id = os.path.join(self.data_root, control_image_id)

            control_image = Image.open(control_image_id).convert('RGB')
            if not self.enable_bucket:
                control_image = self.image_transforms(control_image).unsqueeze(0)
            else:
                control_image = np.expand_dims(np.array(control_image), 0)

            if self.enable_subject_info:
                visual_height, visual_width = self._visual_size(image)
                subject_image = self._load_subject_images(
                    data_info, visual_width, visual_height, convert_rgb=True
                )
            else:
                subject_image = None

            return image, control_image, subject_image, None, text, 'image'

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return the sample at ``idx``, retrying with random indices on failure.

        A replacement index is only accepted when its record has the same
        ``type`` as the originally requested one.  The loop never gives up,
        so it does not terminate if no record of that type can be loaded.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, control_pixel_values, subject_image, control_camera_values, name, data_type = self.get_batch(idx)

                sample["pixel_values"] = pixel_values
                sample["control_pixel_values"] = control_pixel_values
                sample["subject_image"] = subject_image
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx

                if self.return_file_name:
                    sample["file_name"] = os.path.basename(data_info_local['file_path'])

                if self.enable_camera_info:
                    sample["control_camera_values"] = control_camera_values
                break
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length - 1)

        if self.enable_inpaint and not self.enable_bucket:
            mask = get_random_mask(pixel_values.size())
            # Masked positions are filled with 0.
            mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            # First frame, channels last, mapped back from [-1, 1] to [0, 255].
            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

        return sample
| |
|
| |
|
class ImageVideoSafetensorsDataset(Dataset):
    """Dataset whose items are safetensors files listed in a JSON annotation.

    Args:
        ann_path (str): Path of a ``.json`` file holding a sequence of
            records, each with a ``'file_path'`` entry.
        data_root (str | None): If given, joined in front of every
            ``file_path``.

    Raises:
        ValueError: If ``ann_path`` does not end with ``.json``.
    """

    def __init__(
        self,
        ann_path,
        data_root=None,
    ):
        # Loading the list of annotation records.
        print(f"loading annotations from {ann_path} ...")
        if not ann_path.endswith('.json'):
            # Fail here instead of with a NameError on the unbound ``dataset``.
            raise ValueError(
                f"Unsupported annotation file {ann_path!r}: expected a .json file."
            )
        with open(ann_path, 'r') as jsonfile:
            dataset = json.load(jsonfile)

        self.data_root = data_root
        self.dataset = dataset
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Load and return the contents of the file listed at ``idx``."""
        if self.data_root is None:
            path = self.dataset[idx]["file_path"]
        else:
            path = os.path.join(self.data_root, self.dataset[idx]["file_path"])
        state_dict = load_file(path)
        return state_dict
|
| |
|
class TextDataset(Dataset):
    """Text-only dataset read from a JSON file of records with a ``'text'`` key.

    Args:
        ann_path (str): Path of the JSON annotation file.
        text_drop_ratio (float): Probability of replacing the text by ``''``.
    """

    def __init__(self, ann_path, text_drop_ratio=0.0):
        print(f"loading annotations from {ann_path} ...")
        with open(ann_path, 'r') as f:
            self.dataset = json.load(f)
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        self.text_drop_ratio = text_drop_ratio

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``{'text': ..., 'idx': ...}`` for ``idx``.

        If the record cannot be read, a random index is tried instead, so the
        returned ``'idx'`` may differ from the requested one.  The loop never
        gives up, so it does not terminate if no record is readable.
        """
        while True:
            try:
                item = self.dataset[idx]
                text = item['text']

                # Randomly drop text (for classifier-free guidance)
                if random.random() < self.text_drop_ratio:
                    text = ''

                sample = {
                    "text": text,
                    "idx": idx
                }
                return sample

            except Exception as e:
                print(f"Error at index {idx}: {e}, retrying with random index...")
                # The upper bound of np.random.randint is exclusive, so pass
                # the length itself to make every index reachable.
                idx = np.random.randint(0, self.length)