| |
|
|
| import logging |
| from dataclasses import dataclass |
| from functools import partial |
| from typing import Any, Dict, Optional |
|
|
| from core.data.data_collators import MllmPaddingCollator |
| from core.data.data_mixer import DatasetMixer, PersistentDataLoader |
| from core.data.preprocessor import VisionPreprocessor |
| from core.transforms.image_transform import get_image_transform |
| from core.transforms.region_transform import get_region_transform |
| from core.transforms.video_transform import get_video_transform |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
@dataclass
class DataloadArgs:
    """Settings read by ``get_dataloader`` when building the training loader.

    Fields left as ``None`` are passed through unchanged to the transform,
    mixer and loader factories.
    Some fields are not read by ``get_dataloader`` itself:
    ``tokenizer_path``, ``tokenizer_name``, ``conversation_format``,
    ``patch_size`` and ``seq_len``. Presumably other code consumes them;
    verify against callers.
    """

    # Dataset mixture spec handed to DatasetMixer as its first argument.
    # Presumably comma-separated "<dataset>:<weight>" entries; TODO confirm
    # against DatasetMixer's parser.
    datamix: str = (
        "dummy_image:1,dummy_multi_image:1,dummy_image_region:1,"
        "dummy_video:1,dummy_text:1,"
        "dummy_stc_RDCap:1,dummy_stc_RCap:1,dummy_stc_RTLoc:1"
    )
    # Batch size forwarded to PersistentDataLoader.
    batch_size: int = 2
    # Seed forwarded to DatasetMixer.
    seed: int = 42
    # Resolution used by both the image and the video transform.
    image_res: Optional[int] = None
    # Tile limit for the image transform.
    max_num_tiles: Optional[int] = None
    # Image-transform mode selector.
    vision_input_type: Optional[str] = None
    # Loader worker count. It is also compared against the saved count
    # when resuming from checkpointed positions.
    num_workers: Optional[int] = None
    tokenizer_path: Optional[str] = None
    tokenizer_name: Optional[str] = None
    conversation_format: Optional[str] = None
    patch_size: Optional[int] = None
    seq_len: Optional[int] = None
    # Frame cap forwarded to VisionPreprocessor.
    max_video_frames: Optional[int] = None
    # Flag forwarded to MllmPaddingCollator.
    show_first_batch: Optional[bool] = False
|
|
|
|
def get_rank_position(positions, rank, workers, world_size):
    """Return the saved loader state for ``rank`` if it can still be reused.

    Args:
        positions: Saved loader positions, or ``None``. Expected to be a
            mapping keyed by rank. It also holds the ``"num_workers"`` and
            ``"world_size"`` values it was saved with.
        rank: Rank whose saved entry is requested.
        workers: Current number of dataloader workers.
        world_size: Current world size.

    Returns:
        ``positions[rank]`` when an entry for ``rank`` exists and both
        ``workers`` and ``world_size`` match the saved values. Otherwise
        ``None``; a warning is logged when the saved configuration differs.

    Raises:
        KeyError: if ``positions`` has an entry for ``rank`` but lacks the
            ``"num_workers"``/``"world_size"`` metadata.
    """
    # No saved state (or none for this rank): caller starts from scratch.
    if positions is None or rank not in positions:
        return None

    saved_workers = positions["num_workers"]
    saved_world_size = positions["world_size"]
    layout_changed = saved_workers != workers or saved_world_size != world_size
    if not layout_changed:
        return positions[rank]

    # The worker/world layout differs from when the state was saved. Returning
    # None discards the saved positions, which is what the warning text calls
    # a dataloader reset.
    logger.warning(
        f"Checkpoint resumed with different number of total dataloader workers. Dataloaders have been reset. "
        f"num_workers: {saved_workers} -> {workers}, "
        f"world_size: {saved_world_size} -> {world_size}"
    )
    return None
|
|
|
|
def _make_preprocessor_factory(args, tokenizer):
    """Return a ``VisionPreprocessor`` factory bound to per-modality transforms."""
    resolution = args.image_res
    modality_transforms = {
        "image": get_image_transform(
            vision_input_type=args.vision_input_type,
            image_res=resolution,
            max_num_tiles=args.max_num_tiles,
        ),
        "video": get_video_transform(image_res=resolution),
        "region": get_region_transform(),
    }
    # A partial rather than an instance: the class itself is not constructed
    # here. Presumably DatasetMixer instantiates it; TODO confirm.
    return partial(
        VisionPreprocessor,
        transform=modality_transforms,
        tokenizer=tokenizer,
        max_video_frames=args.max_video_frames,
    )


def get_dataloader(
    args,
    dp_rank,
    dp_world_size,
    dataset_configs: Dict[str, Any],
    tokenizer=None,
    positions=None,
):
    """Build the persistent multimodal dataloader for one data-parallel rank.

    Args:
        args: Config object exposing the ``DataloadArgs`` attributes read here:
            datamix, seed, batch_size, num_workers, image_res, max_num_tiles,
            vision_input_type, max_video_frames and show_first_batch.
        dp_rank: Data-parallel rank. Passed to ``DatasetMixer`` as
            ``global_rank`` and used to select this rank's resume state.
        dp_world_size: Data-parallel world size. Passed to ``DatasetMixer`` as
            ``world_size`` and checked against the saved state.
        dataset_configs: Forwarded unchanged to ``DatasetMixer``.
        tokenizer: Forwarded to both the preprocessor and the padding collator.
        positions: Saved loader positions, or ``None`` to start fresh. Only
            this rank's entry is passed on, via ``get_rank_position``.

    Returns:
        The configured ``PersistentDataLoader``.
    """
    preprocessor_factory = _make_preprocessor_factory(args, tokenizer)

    mixer = DatasetMixer(
        args.datamix,
        global_rank=dp_rank,
        world_size=dp_world_size,
        seed=args.seed,
        preprocessors=[preprocessor_factory],
        dataset_configs=dataset_configs,
    )

    collator = MllmPaddingCollator(
        tokenizer,
        show_first_batch=args.show_first_batch,
    )
    # None when there is nothing to resume or the worker/world layout changed.
    resume_state = get_rank_position(
        positions, dp_rank, args.num_workers, dp_world_size
    )

    return PersistentDataLoader(
        mixer,
        args.batch_size,
        args.num_workers,
        collate_fn=collator,
        positions=resume_state,
    )
|
|