Spaces:
Sleeping
Sleeping
| """ | |
| Dataset factory class (OnlineFeatures) and DataLoader wrapper (WebDataLoader). | |
| """ | |
| from PIL import Image | |
| import os | |
| import webdataset as wds | |
| from data.transforms import DatasetFactory | |
| from data.web_dataset import WebDatasetDataset | |
| class WebDataLoader: | |
| """ | |
| encapsulates a unified interface for WebDataset and wds.WebLoader. | |
| """ | |
| def create(dataset, vl_chat_processor=None, device=None, | |
| pin_memory=True, persistent_workers=True): | |
| if vl_chat_processor is not None: | |
| dataset.set_vl_chat_processor(vl_chat_processor) | |
| if device is not None: | |
| dataset.set_device(device) | |
| num_workers = dataset.num_workers | |
| return wds.WebLoader( | |
| dataset, | |
| batch_size=None, | |
| shuffle=False, | |
| num_workers=num_workers, | |
| pin_memory=pin_memory, | |
| persistent_workers=persistent_workers if num_workers > 0 else False, | |
| ) | |
| class OnlineFeatures(DatasetFactory): | |
| """ | |
| The dataset factory class uses WebDatasetDataset to load data from tar files. | |
| """ | |
| def __init__(self, vis_image_root=None, train_tar_pattern=None, test_tar_pattern=None, | |
| task='visual_instruction', cfg=False, resolution=256, | |
| shuffle_buffer=300, resampled=True, split_data_by_node=True, estimated_samples_per_shard=1000, | |
| vl_chat_processor=None, device=None, fid_stat_path=None, | |
| num_workers=None, batch_size=None, test_batch_size=None, test_num_workers=None, sampling_weights=None, | |
| **kwargs): | |
| """ | |
| Args: | |
| vis_image_root: Root directory of the visualization image | |
| train_tar_pattern: Training set tar file path pattern, supports braceexpand | |
| test_tar_pattern: Test set tar file path pattern, supports braceexpand | |
| task: Task type | |
| cfg: Whether to use classifier-free guidance | |
| resolution: Image resolution | |
| shuffle_buffer: Size of the WebDataset shuffle buffer | |
| resampled: Whether to use resampled mode (for distributed training) | |
| split_data_by_node: Whether to distribute shards across multiple nodes | |
| estimated_samples_per_shard: Estimated number of samples per shard | |
| vl_chat_processor: VLChatProcessor instance (optional) | |
| device: torch.device or string (optional) | |
| num_workers: num_workers of the DataLoader (optional) | |
| batch_size: Training set batch size (optional) | |
| test_batch_size: Test set batch size (optional) | |
| test_num_workers: Test set num_workers (optional) | |
| sampling_weights: List of sampling weights (optional) | |
| """ | |
| super().__init__() | |
| self.task = task | |
| self.vis_image_root = vis_image_root | |
| self.fid_stat_path = fid_stat_path | |
| if train_tar_pattern is None: | |
| raise ValueError("train_tar_pattern must be provided") | |
| print(f'Creating WebDataset with pattern: {train_tar_pattern}') | |
| self.train = WebDatasetDataset( | |
| tar_pattern=train_tar_pattern, | |
| resolution=resolution, | |
| shuffle_buffer=shuffle_buffer, | |
| resampled=resampled, | |
| split_data_by_node_flag=split_data_by_node, | |
| estimated_samples_per_shard=estimated_samples_per_shard, | |
| vl_chat_processor=vl_chat_processor, | |
| device=device, | |
| num_workers=num_workers, | |
| batch_size=batch_size, | |
| sampling_weights=sampling_weights | |
| ) | |
| test_batch_size_to_use = test_batch_size if test_batch_size is not None else batch_size | |
| test_num_workers_to_use = test_num_workers if test_num_workers is not None else num_workers | |
| self.test = WebDatasetDataset( | |
| tar_pattern=test_tar_pattern, | |
| resolution=resolution, | |
| shuffle_buffer=100, | |
| resampled=False, | |
| split_data_by_node_flag=split_data_by_node, | |
| allow_shared_shards=False, | |
| estimated_samples_per_shard=estimated_samples_per_shard, | |
| vl_chat_processor=vl_chat_processor, | |
| device=device, | |
| num_workers=test_num_workers_to_use, | |
| batch_size=test_batch_size_to_use, | |
| force_simple_mode=True, | |
| enable_shuffle=False, | |
| partial=True | |
| ) | |
| assert not cfg | |
| self.resolution = resolution | |
| self.vis_image_paths = [] | |
| self.vis_output_paths = [] | |
| self._scan_vis_images(vis_image_root) | |
| self._train_dataloader = None | |
| self._test_dataloader = None | |
| def set_vl_chat_processor(self, vl_chat_processor): | |
| """bind VLChatProcessor, and invalidate the cached DataLoader.""" | |
| self.train.set_vl_chat_processor(vl_chat_processor) | |
| self.test.set_vl_chat_processor(vl_chat_processor) | |
| self._train_dataloader = None | |
| self._test_dataloader = None | |
| def set_device(self, device): | |
| """bind target device, and invalidate the cached DataLoader.""" | |
| self.train.set_device(device) | |
| self.test.set_device(device) | |
| self._train_dataloader = None | |
| self._test_dataloader = None | |
| def train_dataloader(self): | |
| if self._train_dataloader is None: | |
| self._train_dataloader = WebDataLoader.create(self.train) | |
| return self._train_dataloader | |
| def test_dataloader(self): | |
| if self._test_dataloader is None: | |
| self._test_dataloader = WebDataLoader.create(self.test) | |
| return self._test_dataloader | |
| def _scan_vis_images(self, vis_image_root): | |
| valid_extensions = {'.jpg', '.jpeg', '.png', '.JPEG', '.JPG', '.PNG'} | |
| if not vis_image_root or not os.path.exists(vis_image_root): | |
| if vis_image_root: | |
| print(f"Warning: vis_image_root does not exist: {vis_image_root}") | |
| return | |
| input_dir = os.path.join(vis_image_root, 'input') | |
| output_dir = os.path.join(vis_image_root, 'output') | |
| if os.path.exists(input_dir): | |
| print(f"Scanning input images in: {input_dir}") | |
| for root, dirs, files in os.walk(input_dir): | |
| for filename in files: | |
| if any(filename.endswith(ext) for ext in valid_extensions): | |
| self.vis_image_paths.append(os.path.join(root, filename)) | |
| self.vis_image_paths = sorted(self.vis_image_paths) | |
| print(f"Found {len(self.vis_image_paths)} input images") | |
| else: | |
| print(f"Warning: input directory does not exist: {input_dir}") | |
| if os.path.exists(output_dir): | |
| print(f"Scanning output images in: {output_dir}") | |
| for root, dirs, files in os.walk(output_dir): | |
| for filename in files: | |
| if any(filename.endswith(ext) for ext in valid_extensions): | |
| self.vis_output_paths.append(os.path.join(root, filename)) | |
| self.vis_output_paths = sorted(self.vis_output_paths) | |
| print(f"Found {len(self.vis_output_paths)} output images") | |
| else: | |
| print(f"Warning: output directory does not exist: {output_dir}") | |
| if self.vis_image_paths and self.vis_output_paths: | |
| input_map = {os.path.splitext(os.path.basename(p))[0]: p for p in self.vis_image_paths} | |
| output_map = {os.path.splitext(os.path.basename(p))[0]: p for p in self.vis_output_paths} | |
| matched_keys = sorted(set(input_map.keys()) & set(output_map.keys())) | |
| self.vis_image_paths = [input_map[key] for key in matched_keys] | |
| self.vis_output_paths = [output_map[key] for key in matched_keys] | |
| print(f"Matched {len(self.vis_image_paths)} input-output image pairs") | |
| print(f"Images will be loaded on-demand when needed.") | |
| def _load_images_parallel(paths, max_workers=8): | |
| import concurrent.futures | |
| if not paths: | |
| return [] | |
| def load_image(image_path): | |
| try: | |
| pil_img = Image.open(image_path) | |
| pil_img.load() | |
| return pil_img.convert("RGB") | |
| except Exception as e: | |
| print(f"Warning: Failed to load image {image_path}: {e}") | |
| return Image.new('RGB', (384, 384), color='black') | |
| workers = min(len(paths), max_workers) | |
| if workers > 1: | |
| with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: | |
| return list(executor.map(load_image, paths)) | |
| return [load_image(path) for path in paths] | |
| def get_vis_images_as_pil(self, max_images=None): | |
| paths = self.vis_image_paths[:max_images] if max_images else self.vis_image_paths | |
| return self._load_images_parallel(paths) | |
| def get_vis_output_images_as_pil(self, max_images=None): | |
| paths = self.vis_output_paths[:max_images] if max_images else self.vis_output_paths | |
| return self._load_images_parallel(paths) | |
| def data_shape(self): | |
| if self.resolution == 512: | |
| return 4, 64, 64 | |
| else: | |
| return 4, 32, 32 | |
| def fid_stat(self): | |
| if self.fid_stat_path: | |
| return self.fid_stat_path | |
| return '/path/to/fid_stats_mscoco256_val.npz' | |