# Source: MoTIF/utils/core/data/dataloader.py
# Uploaded by P4ddyki via huggingface_hub (commit 3cf4fff, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Optional
from core.data.data_collators import MllmPaddingCollator
from core.data.data_mixer import DatasetMixer, PersistentDataLoader
from core.data.preprocessor import VisionPreprocessor
from core.transforms.image_transform import get_image_transform
from core.transforms.region_transform import get_region_transform
from core.transforms.video_transform import get_video_transform
logger = logging.getLogger(__name__)
@dataclass
class DataloadArgs:
    """Settings used to assemble the multimodal training dataloader.

    ``datamix`` is a comma-separated list of ``name:weight`` entries (see the
    default value); it is handed to ``DatasetMixer`` unchanged. Fields that
    default to ``None`` are expected to be filled in by the caller's config.
    ``tokenizer_path``, ``tokenizer_name``, ``conversation_format``,
    ``patch_size`` and ``seq_len`` are not read by ``get_dataloader`` in this
    module -- presumably consumed elsewhere (TODO confirm).
    """

    # --- dataset mixture / sampling ---
    datamix: str = (
        "dummy_image:1,dummy_multi_image:1,dummy_image_region:1,dummy_video:1,"
        "dummy_text:1,dummy_stc_RDCap:1,dummy_stc_RCap:1,dummy_stc_RTLoc:1"
    )
    batch_size: int = 2
    seed: int = 42

    # --- vision preprocessing ---
    image_res: Optional[int] = None
    max_num_tiles: Optional[int] = None
    vision_input_type: Optional[str] = None

    # --- loader workers ---
    num_workers: Optional[int] = None

    # --- text / tokenization ---
    tokenizer_path: Optional[str] = None
    tokenizer_name: Optional[str] = None
    conversation_format: Optional[str] = None

    # --- model-geometry related ---
    patch_size: Optional[int] = None
    seq_len: Optional[int] = None
    max_video_frames: Optional[int] = None

    # --- debugging: forwarded to the collator ---
    show_first_batch: Optional[bool] = False
def get_rank_position(positions, rank, workers, world_size):
    """Return this rank's saved dataloader state, or ``None`` to start fresh.

    Args:
        positions: Saved state mapping, or ``None``. When present it must hold
            ``"num_workers"`` and ``"world_size"`` entries, plus one entry per
            rank (keyed by the same value passed as ``rank``).
        rank: Data-parallel rank whose state is requested.
        workers: Current number of dataloader workers.
        world_size: Current data-parallel world size.

    Returns:
        ``positions[rank]`` when the saved worker count and world size match
        the current ones; otherwise ``None``. A mismatch also logs a warning.
    """
    # Nothing saved, or nothing saved for this particular rank.
    if positions is None or rank not in positions:
        return None

    saved_workers = positions["num_workers"]
    saved_world_size = positions["world_size"]
    topology_changed = saved_workers != workers or saved_world_size != world_size
    if not topology_changed:
        return positions[rank]

    # Saved positions are only meaningful for the exact worker/rank layout
    # that produced them, so discard them rather than resume incorrectly.
    logger.warning(
        f"Checkpoint resumed with different number of total dataloader workers. Dataloaders have been reset. "
        f"num_workers: {saved_workers} -> {workers}, "
        f"world_size: {saved_world_size} -> {world_size}"
    )
    return None
def get_dataloader(
    args,
    dp_rank,
    dp_world_size,
    dataset_configs: Dict[str, Any],
    tokenizer=None,
    positions=None,
):
    """Build the ``PersistentDataLoader`` for one data-parallel rank.

    Args:
        args: A ``DataloadArgs``-like object. This function reads
            ``vision_input_type``, ``image_res``, ``max_num_tiles``,
            ``max_video_frames``, ``datamix``, ``seed``, ``batch_size``,
            ``num_workers`` and ``show_first_batch``.
        dp_rank: Data-parallel rank of this process.
        dp_world_size: Total number of data-parallel ranks.
        dataset_configs: Forwarded unchanged to ``DatasetMixer``.
        tokenizer: Passed to both the vision preprocessor and the collator.
        positions: Saved dataloader state from a checkpoint, or ``None``.
            It is filtered through ``get_rank_position`` before use.

    Returns:
        The constructed ``PersistentDataLoader``.
    """
    res = args.image_res

    # One transform per modality, keyed by modality name.
    per_modality_transforms = {
        "image": get_image_transform(
            vision_input_type=args.vision_input_type,
            image_res=res,
            max_num_tiles=args.max_num_tiles,
        ),
        "video": get_video_transform(image_res=res),
        "region": get_region_transform(),
    }

    # Bind the shared arguments now. DatasetMixer receives this partial rather
    # than a built VisionPreprocessor; presumably it instantiates it itself
    # (TODO confirm in core.data.data_mixer).
    make_preprocessor = partial(
        VisionPreprocessor,
        transform=per_modality_transforms,
        tokenizer=tokenizer,
        max_video_frames=args.max_video_frames,
    )

    mixed_dataset = DatasetMixer(
        args.datamix,
        global_rank=dp_rank,
        world_size=dp_world_size,
        seed=args.seed,
        preprocessors=[make_preprocessor],
        dataset_configs=dataset_configs,
    )

    collator = MllmPaddingCollator(tokenizer, show_first_batch=args.show_first_batch)

    # None means "start from scratch". That happens when there is no saved
    # state, no entry for this rank, or a different worker/world-size layout.
    resume_state = get_rank_position(
        positions, dp_rank, args.num_workers, dp_world_size
    )

    return PersistentDataLoader(
        mixed_dataset,
        args.batch_size,
        args.num_workers,
        collate_fn=collator,
        positions=resume_state,
    )