# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import io
import json
import random

import pyarrow.parquet as pq
from PIL import Image

from .data_utils import pil_img2rgb
from .distributed_iterable_dataset import DistributedIterableDataset
from .parquet_utils import get_parquet_data_paths, init_arrow_pf_fs

# Guard against decompression bombs: PIL warns above this pixel count
# and raises an error at twice it.
Image.MAX_IMAGE_PIXELS = 20_000_000


class T2IIterableDataset(DistributedIterableDataset):
    def __init__(
        self, dataset_name, transform, tokenizer, data_dir_list, num_used_data,
        local_rank=0, world_size=1, num_workers=8, data_status=None,
    ):
        """
        data_dir_list: list of data directories that contain parquet files
        num_used_data: list of the number of data paths to sample from each data directory
        data_status: optional per-worker resume state; see __iter__
        """
        super().__init__(dataset_name, local_rank, world_size, num_workers)
        self.transform = transform
        self.tokenizer = tokenizer
        self.data_status = data_status
        self.data_paths = self.get_data_paths(data_dir_list, num_used_data)
        self.set_epoch()

    def get_data_paths(self, data_dir_list, num_used_data):
        return get_parquet_data_paths(data_dir_list, num_used_data)
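
    # Resume contract (inferred from the code below, not documented upstream):
    # `data_status[worker_id]` holds the last consumed position as
    # [parquet_idx, row_group_id, row_idx], the same triple each yielded sample
    # reports under data_indexes["data_indexes"]; iteration resumes at the row
    # immediately after the recorded one.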
    def __iter__(self):
        data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
        if self.data_status is not None:
            parquet_start_id = self.data_status[worker_id][0]
            row_group_start_id = self.data_status[worker_id][1]
            row_start_id = self.data_status[worker_id][2] + 1
        else:
            parquet_start_id = 0
            row_group_start_id = 0
            row_start_id = 0
        transform_stride = self.transform.stride

        print(
            f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
            f"resuming data at parquet#{parquet_start_id}, rg#{row_group_start_id}, row#{row_start_id}"
        )
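
        # The shard is replayed indefinitely; the resume offsets apply only to
        # the first pass and are reset to 0 at the end of each loop level below.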
        while True:
            data_paths_per_worker_ = data_paths_per_worker[parquet_start_id:]
            for parquet_idx, parquet_file_path in enumerate(data_paths_per_worker_, start=parquet_start_id):
                fs = init_arrow_pf_fs(parquet_file_path)
                with fs.open_input_file(parquet_file_path) as f:
                    fr = pq.ParquetFile(f)
                    row_group_ids = list(range(fr.num_row_groups))
                    row_group_ids_ = row_group_ids[row_group_start_id:]
                    for row_group_id in row_group_ids_:
                        df = fr.read_row_group(row_group_id).to_pandas()
                        df = df.iloc[row_start_id:]
                        for row_idx, row in df.iterrows():
                            num_tokens = 0
                            try:
                                image_byte = row['image']
                                image = pil_img2rgb(Image.open(io.BytesIO(image_byte)))
                            except Exception as e:
                                print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
                                continue
                            image_tensor = self.transform(image)
                            height, width = image_tensor.shape[1:]
                            # Count image tokens at the transform's patch stride.
                            num_tokens += width * height // transform_stride ** 2
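                            # e.g. a 512x512 tensor with stride 16 adds
                            # 512*512 // 16**2 = 1024 tokens (illustrative
                            # numbers; the real stride comes from `transform`).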

                            try:
                                caption_dict = row['captions']
                                caption_dict = json.loads(caption_dict)
                            except Exception as e:
                                print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
                                continue
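
                            # `captions` decodes to a dict mapping caption source
                            # to text, e.g. {"short": "...", "dense": "..."}
                            # (illustrative keys); one caption is sampled uniformly.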
                            caps_token = [self.tokenizer.encode(v) for _, v in caption_dict.items()]
                            if len(caps_token) == 0:
                                print(f'no caption in rg#{row_group_id}, {parquet_file_path}')
                                caption_token = self.tokenizer.encode(' ')
                            else:
                                caption_token = random.choice(caps_token)

                            sequence_plan, text_ids_list = [], []
                            text_ids = caption_token
                            num_tokens += len(caption_token)
                            text_ids_list.append(text_ids)
                            sequence_plan.append({
                                'type': 'text',
                                'enable_cfg': 1,
                                'loss': 0,
                                'special_token_loss': 0,
                                'special_token_label': None,
                            })
                            sequence_plan.append({
                                'type': 'vae_image',
                                'enable_cfg': 0,
                                'loss': 1,
                                'special_token_loss': 0,
                                'special_token_label': None,
                            })
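
                            # The plan reads [text condition -> image target]: the
                            # caption may be dropped for classifier-free guidance
                            # (enable_cfg=1) and takes no loss, while the VAE image
                            # is the supervised target (loss=1).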
                            sample = dict(
                                image_tensor_list=[image_tensor],
                                text_ids_list=text_ids_list,
                                num_tokens=num_tokens,
                                sequence_plan=sequence_plan,
                                data_indexes={
                                    "data_indexes": [parquet_idx, row_group_id, row_idx],
                                    "worker_id": worker_id,
                                    "dataset_name": self.dataset_name,
                                }
                            )
                            yield sample
                        # Offsets only matter on the first (resumed) pass.
                        row_start_id = 0
                    row_group_start_id = 0
            parquet_start_id = 0
            print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")