|
|
|
|
|
|
|
|
import logging |
|
|
from dataclasses import dataclass |
|
|
from functools import partial |
|
|
from typing import Any, Dict, Optional |
|
|
|
|
|
from core.data.data_collators import MllmPaddingCollator |
|
|
from core.data.data_mixer import DatasetMixer, PersistentDataLoader |
|
|
from core.data.preprocessor import VisionPreprocessor |
|
|
from core.transforms.image_transform import get_image_transform |
|
|
from core.transforms.region_transform import get_region_transform |
|
|
from core.transforms.video_transform import get_video_transform |
|
|
|
|
|
# Module-level logger, named after this module via ``__name__``; used by
# get_rank_position to warn when saved dataloader state is discarded.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
@dataclass
class DataloadArgs:
    """Options describing how a dataloader should be built.

    The field names match the attributes that ``get_dataloader`` reads from
    its ``args`` parameter. Fields that ``get_dataloader`` does not read are
    marked as such below; they are presumably consumed elsewhere.
    """

    # --- dataset mixture -------------------------------------------------
    # Passed as the first positional argument to DatasetMixer. The default is
    # a comma-separated list of ``name:number`` entries; presumably dataset
    # name and mixing weight -- confirm against DatasetMixer.
    datamix: str = (
        "dummy_image:1,dummy_multi_image:1,dummy_image_region:1,dummy_video:1,"
        "dummy_text:1,dummy_stc_RDCap:1,dummy_stc_RCap:1,dummy_stc_RTLoc:1"
    )
    # Second positional argument to PersistentDataLoader.
    batch_size: int = 2
    # Forwarded to DatasetMixer as ``seed``.
    seed: int = 42

    # --- vision preprocessing --------------------------------------------
    # Forwarded to both get_image_transform and get_video_transform.
    image_res: Optional[int] = None
    # Forwarded to get_image_transform only.
    max_num_tiles: Optional[int] = None
    # Forwarded to get_image_transform only.
    vision_input_type: Optional[str] = None

    # --- loading -----------------------------------------------------------
    # Third positional argument to PersistentDataLoader; also compared with
    # the "num_workers" value stored in saved positions when resuming.
    num_workers: Optional[int] = None

    # --- not read by get_dataloader in this file ---------------------------
    tokenizer_path: Optional[str] = None
    tokenizer_name: Optional[str] = None
    conversation_format: Optional[str] = None
    patch_size: Optional[int] = None
    seq_len: Optional[int] = None

    # --- video / debugging -------------------------------------------------
    # Forwarded to VisionPreprocessor as ``max_video_frames``.
    max_video_frames: Optional[int] = None
    # Forwarded to MllmPaddingCollator as ``show_first_batch``.
    show_first_batch: Optional[bool] = False
|
|
|
|
|
|
|
|
def get_rank_position(positions, rank, workers, world_size):
    """Pick this rank's saved dataloader position, if it is still usable.

    Args:
        positions: Saved state mapping, or ``None``. Besides one entry per
            rank it is expected to carry ``"num_workers"`` and
            ``"world_size"`` entries describing the layout it was saved with.
        rank: Key of the entry to look up in ``positions``.
        workers: Number of dataloader workers in the current run.
        world_size: World size of the current run.

    Returns:
        ``positions[rank]`` when the rank is present and the saved layout
        matches the current one; otherwise ``None``.

    Raises:
        KeyError: If the rank is present but ``positions`` lacks the
            ``"num_workers"`` or ``"world_size"`` entry.
    """
    # Nothing saved at all, or nothing saved for this particular rank.
    if positions is None or rank not in positions:
        return None

    # Saved layout matches the current one: hand back the stored entry.
    if positions["num_workers"] == workers and positions["world_size"] == world_size:
        return positions[rank]

    # Layout changed since the checkpoint was written; the stored entry is
    # discarded and the caller gets no resume position.
    logger.warning(
        f"Checkpoint resumed with different number of total dataloader workers. Dataloaders have been reset. "
        f"num_workers: {positions['num_workers']} -> {workers}, "
        f"world_size: {positions['world_size']} -> {world_size}"
    )
    return None
|
|
|
|
|
|
|
|
def get_dataloader(
    args,
    dp_rank,
    dp_world_size,
    dataset_configs: Dict[str, Any],
    tokenizer=None,
    positions=None,
):
    """Assemble the ``PersistentDataLoader`` for one data-parallel rank.

    Construction order: image/video/region transforms, a
    ``VisionPreprocessor`` factory, the ``DatasetMixer``, the padding
    collator, the resume position, and finally the loader itself.

    Args:
        args: Object exposing ``vision_input_type``, ``image_res``,
            ``max_num_tiles``, ``max_video_frames``, ``datamix``, ``seed``,
            ``batch_size``, ``num_workers`` and ``show_first_batch``
            (the field names of ``DataloadArgs``).
        dp_rank: Passed to ``DatasetMixer`` as ``global_rank`` and used as the
            lookup key into ``positions``.
        dp_world_size: Passed to ``DatasetMixer`` as ``world_size`` and
            compared with the world size stored in ``positions``.
        dataset_configs: Forwarded unchanged to ``DatasetMixer``.
        tokenizer: Given to both the preprocessor factory and the collator.
        positions: Optional saved loader state; filtered through
            ``get_rank_position`` before being handed to the loader.

    Returns:
        The constructed ``PersistentDataLoader``.
    """
    # Read the vision-related options once, up front.
    input_type, resolution, tile_cap, frame_cap = (
        args.vision_input_type,
        args.image_res,
        args.max_num_tiles,
        args.max_video_frames,
    )

    # One transform per modality key.
    image_tf = get_image_transform(
        vision_input_type=input_type,
        image_res=resolution,
        max_num_tiles=tile_cap,
    )
    video_tf = get_video_transform(image_res=resolution)
    region_tf = get_region_transform()
    transforms_by_modality = {
        "image": image_tf,
        "video": video_tf,
        "region": region_tf,
    }

    # A factory, not an instance: VisionPreprocessor is not called here, only
    # pre-bound with its keyword arguments and handed to the mixer.
    make_preprocessor = partial(
        VisionPreprocessor,
        transform=transforms_by_modality,
        tokenizer=tokenizer,
        max_video_frames=frame_cap,
    )

    mixed_dataset = DatasetMixer(
        args.datamix,
        global_rank=dp_rank,
        world_size=dp_world_size,
        seed=args.seed,
        preprocessors=[make_preprocessor],
        dataset_configs=dataset_configs,
    )

    batch_size, worker_count = args.batch_size, args.num_workers
    collator = MllmPaddingCollator(
        tokenizer,
        show_first_batch=args.show_first_batch,
    )
    # None unless saved state exists for this rank with a matching layout.
    resume_state = get_rank_position(
        positions, dp_rank, worker_count, dp_world_size
    )

    return PersistentDataLoader(
        mixed_dataset,
        batch_size,
        worker_count,
        collate_fn=collator,
        positions=resume_state,
    )
|
|
|