"""
Dataset factory class (OnlineFeatures) and DataLoader wrapper (WebDataLoader).
"""
from PIL import Image
import os
import webdataset as wds
from data.transforms import DatasetFactory
from data.web_dataset import WebDatasetDataset
class WebDataLoader:
    """
    Encapsulates a unified interface for WebDataset and wds.WebLoader.
    """

    @staticmethod
    def _resolve_num_workers(dataset):
        """Return ``dataset.num_workers``, treating a missing or ``None`` value as 0.

        ``OnlineFeatures`` forwards ``num_workers=None`` by default; comparing
        ``None > 0`` (or handing ``None`` to the loader) would raise ``TypeError``.
        """
        num_workers = getattr(dataset, 'num_workers', None)
        return 0 if num_workers is None else num_workers

    @staticmethod
    def create(dataset, vl_chat_processor=None, device=None,
               pin_memory=True, persistent_workers=True):
        """Bind optional processor/device to *dataset* and wrap it in a ``wds.WebLoader``.

        Args:
            dataset: Object exposing ``num_workers`` plus ``set_vl_chat_processor``
                / ``set_device`` (only called when the matching argument is given).
            vl_chat_processor: If not None, forwarded to ``dataset.set_vl_chat_processor``.
            device: If not None, forwarded to ``dataset.set_device``.
            pin_memory: Passed through to the loader.
            persistent_workers: Passed through only when there is at least one worker.

        Returns:
            A ``wds.WebLoader`` over *dataset*.
        """
        if vl_chat_processor is not None:
            dataset.set_vl_chat_processor(vl_chat_processor)
        if device is not None:
            dataset.set_device(device)
        num_workers = WebDataLoader._resolve_num_workers(dataset)
        return wds.WebLoader(
            dataset,
            # batch_size=None disables the loader's own batching; presumably the
            # dataset already yields batches (it is built with batch_size) — confirm.
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            # PyTorch rejects persistent_workers=True when num_workers == 0.
            persistent_workers=persistent_workers if num_workers > 0 else False,
        )
class OnlineFeatures(DatasetFactory):
    """
    The dataset factory class uses WebDatasetDataset to load data from tar files.

    It owns a training and a test ``WebDatasetDataset``, lazily wraps each in a
    loader via ``WebDataLoader.create``, and optionally indexes a directory of
    paired input/output images used for visualization.
    """

    # Visualization image suffixes; compared against the lower-cased file name.
    _VIS_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, vis_image_root=None, train_tar_pattern=None, test_tar_pattern=None,
                 task='visual_instruction', cfg=False, resolution=256,
                 shuffle_buffer=300, resampled=True, split_data_by_node=True, estimated_samples_per_shard=1000,
                 vl_chat_processor=None, device=None, fid_stat_path=None,
                 num_workers=None, batch_size=None, test_batch_size=None, test_num_workers=None, sampling_weights=None,
                 **kwargs):
        """
        Args:
            vis_image_root: Root directory of the visualization images; expected to
                contain ``input/`` and ``output/`` sub-directories.
            train_tar_pattern: Training set tar file path pattern, supports braceexpand (required).
            test_tar_pattern: Test set tar file path pattern, supports braceexpand.
            task: Task type (stored as ``self.task``).
            cfg: Whether to use classifier-free guidance; only False is supported.
            resolution: Image resolution.
            shuffle_buffer: Size of the WebDataset shuffle buffer (training set only).
            resampled: Whether to use resampled mode (training set only).
            split_data_by_node: Whether to distribute shards across multiple nodes.
            estimated_samples_per_shard: Estimated number of samples per shard.
            vl_chat_processor: VLChatProcessor instance (optional).
            device: torch.device or string (optional).
            fid_stat_path: Path returned by ``fid_stat`` (optional).
            num_workers: num_workers of the DataLoader (optional).
            batch_size: Training set batch size (optional).
            test_batch_size: Test set batch size; falls back to ``batch_size``.
            test_num_workers: Test set num_workers; falls back to ``num_workers``.
            sampling_weights: List of sampling weights (optional, training set only).
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``train_tar_pattern`` is None or ``cfg`` is truthy.
        """
        super().__init__()
        self.task = task
        self.vis_image_root = vis_image_root
        self.fid_stat_path = fid_stat_path
        if train_tar_pattern is None:
            raise ValueError("train_tar_pattern must be provided")
        # Validate before building any dataset. A bare ``assert`` would be stripped
        # under ``python -O`` and silently accept an unsupported configuration.
        if cfg:
            raise ValueError("cfg=True (classifier-free guidance) is not supported by OnlineFeatures")
        print(f'Creating WebDataset with pattern: {train_tar_pattern}')
        self.train = WebDatasetDataset(
            tar_pattern=train_tar_pattern,
            resolution=resolution,
            shuffle_buffer=shuffle_buffer,
            resampled=resampled,
            split_data_by_node_flag=split_data_by_node,
            estimated_samples_per_shard=estimated_samples_per_shard,
            vl_chat_processor=vl_chat_processor,
            device=device,
            num_workers=num_workers,
            batch_size=batch_size,
            sampling_weights=sampling_weights
        )
        # The test split inherits the training batch size / worker count unless overridden.
        test_batch_size_to_use = test_batch_size if test_batch_size is not None else batch_size
        test_num_workers_to_use = test_num_workers if test_num_workers is not None else num_workers
        # NOTE(review): test_tar_pattern defaults to None and is passed through
        # unchecked; whether WebDatasetDataset tolerates None is not visible here.
        # The remaining flags are fixed evaluation settings (no resampling, no
        # shuffling); see WebDatasetDataset for the exact meaning of each.
        self.test = WebDatasetDataset(
            tar_pattern=test_tar_pattern,
            resolution=resolution,
            shuffle_buffer=100,
            resampled=False,
            split_data_by_node_flag=split_data_by_node,
            allow_shared_shards=False,
            estimated_samples_per_shard=estimated_samples_per_shard,
            vl_chat_processor=vl_chat_processor,
            device=device,
            num_workers=test_num_workers_to_use,
            batch_size=test_batch_size_to_use,
            force_simple_mode=True,
            enable_shuffle=False,
            partial=True
        )
        self.resolution = resolution
        self.vis_image_paths = []
        self.vis_output_paths = []
        self._scan_vis_images(vis_image_root)
        # Loaders are built lazily by the properties below and cached here.
        self._train_dataloader = None
        self._test_dataloader = None

    def set_vl_chat_processor(self, vl_chat_processor):
        """Bind a VLChatProcessor to both splits and invalidate the cached DataLoaders."""
        self.train.set_vl_chat_processor(vl_chat_processor)
        self.test.set_vl_chat_processor(vl_chat_processor)
        self._train_dataloader = None
        self._test_dataloader = None

    def set_device(self, device):
        """Bind a target device to both splits and invalidate the cached DataLoaders."""
        self.train.set_device(device)
        self.test.set_device(device)
        self._train_dataloader = None
        self._test_dataloader = None

    @property
    def train_dataloader(self):
        """Loader over ``self.train``, created on first access and cached."""
        if self._train_dataloader is None:
            self._train_dataloader = WebDataLoader.create(self.train)
        return self._train_dataloader

    @property
    def test_dataloader(self):
        """Loader over ``self.test``, created on first access and cached."""
        if self._test_dataloader is None:
            self._test_dataloader = WebDataLoader.create(self.test)
        return self._test_dataloader

    @classmethod
    def _collect_images(cls, directory, label):
        """Recursively list image files under *directory*, sorted by full path.

        Prints scan progress using *label* ('input' / 'output'); returns ``[]``
        with a warning when *directory* does not exist.
        """
        if not os.path.exists(directory):
            print(f"Warning: {label} directory does not exist: {directory}")
            return []
        print(f"Scanning {label} images in: {directory}")
        found = sorted(
            os.path.join(root, filename)
            for root, _dirs, files in os.walk(directory)
            for filename in files
            if filename.lower().endswith(cls._VIS_IMAGE_EXTENSIONS)
        )
        print(f"Found {len(found)} {label} images")
        return found

    def _scan_vis_images(self, vis_image_root):
        """Index visualization images under ``<root>/input`` and ``<root>/output``.

        Both sub-directories are walked recursively. When both contain images,
        the two lists are reduced to pairs whose file stems (basename without
        extension) match, ordered by stem. Images are not opened here.

        Args:
            vis_image_root: Root directory; if falsy or missing, nothing is indexed.
        """
        # Reset so a repeated scan replaces, rather than duplicates, earlier results.
        self.vis_image_paths = []
        self.vis_output_paths = []
        if not vis_image_root or not os.path.exists(vis_image_root):
            if vis_image_root:
                print(f"Warning: vis_image_root does not exist: {vis_image_root}")
            return
        self.vis_image_paths = self._collect_images(os.path.join(vis_image_root, 'input'), 'input')
        self.vis_output_paths = self._collect_images(os.path.join(vis_image_root, 'output'), 'output')
        if self.vis_image_paths and self.vis_output_paths:
            def stem(path):
                return os.path.splitext(os.path.basename(path))[0]

            # Keyed by stem only: if two files (e.g. in different sub-folders)
            # share a stem, the later one in sorted path order wins.
            input_map = {stem(p): p for p in self.vis_image_paths}
            output_map = {stem(p): p for p in self.vis_output_paths}
            matched_keys = sorted(input_map.keys() & output_map.keys())
            self.vis_image_paths = [input_map[key] for key in matched_keys]
            self.vis_output_paths = [output_map[key] for key in matched_keys]
            print(f"Matched {len(self.vis_image_paths)} input-output image pairs")
        print("Images will be loaded on-demand when needed.")

    @staticmethod
    def _load_images_parallel(paths, max_workers=8):
        """Open *paths* as RGB PIL images, in input order, using up to *max_workers* threads.

        A file that fails to load is replaced by a 384x384 black placeholder
        (with a printed warning) so the result always has ``len(paths)`` items.
        """
        import concurrent.futures
        if not paths:
            return []

        def load_image(image_path):
            try:
                # The context manager closes the file handle; convert() returns
                # an independent copy, so the result stays valid after closing.
                with Image.open(image_path) as pil_img:
                    return pil_img.convert("RGB")
            except Exception as e:
                # Deliberate best effort: one bad file must not abort visualization.
                print(f"Warning: Failed to load image {image_path}: {e}")
                return Image.new('RGB', (384, 384), color='black')

        workers = min(len(paths), max_workers)
        if workers > 1:
            with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
                # executor.map preserves the order of *paths*.
                return list(executor.map(load_image, paths))
        return [load_image(path) for path in paths]

    def get_vis_images_as_pil(self, max_images=None):
        """Load the first *max_images* visualization input images (all when None)."""
        # Slicing with None keeps everything; 0 now yields [] instead of all images.
        return self._load_images_parallel(self.vis_image_paths[:max_images])

    def get_vis_output_images_as_pil(self, max_images=None):
        """Load the first *max_images* visualization output images (all when None)."""
        return self._load_images_parallel(self.vis_output_paths[:max_images])

    @property
    def data_shape(self):
        """Return ``(4, 64, 64)`` when resolution is 512, otherwise ``(4, 32, 32)``.

        NOTE(review): presumably a latent shape at 1/8 of the image resolution;
        resolutions other than 512 all fall back to the 256 shape — confirm.
        """
        if self.resolution == 512:
            return 4, 64, 64
        else:
            return 4, 32, 32

    @property
    def fid_stat(self):
        """Return ``fid_stat_path`` if set, else a default path.

        NOTE(review): the default looks like a placeholder, not a real file;
        pass ``fid_stat_path`` explicitly.
        """
        if self.fid_stat_path:
            return self.fid_stat_path
        return '/path/to/fid_stats_mscoco256_val.npz'