# XTuner / MMEngine training config: LoRA fine-tuning of InternVL2-4B on the
# DiagrammaticReasoning VQA data.
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import AutoTokenizer
from xtuner.dataset import ConcatDataset
from xtuner.dataset.map_fns import template_map_fn_factory
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine.hooks import DatasetInfoHook
from xtuner.engine.runner import TrainLoop
from xtuner.utils import PROMPT_TEMPLATE

from projects.InternVL.collect_fns import internvl_collate_fn
from projects.InternVL.internvl import InternVL_vlm
from projects.lisa.datasets.vqa_dataset import LLaVADataset
from projects.llava_sam2.datasets import ReferSegmDataset
from projects.llava_sam2.models.preprocess.image_resize import DirectResize

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
path = './pretrained/internvl/InternVL2-4B'

# Data
image_folder = './data/DiagrammaticReasoning/'
data_file = './data//DiagrammaticReasoning/train.json'
# InternVL2-4B is paired here with the Phi-3 chat template.
prompt_template = PROMPT_TEMPLATE.phi3_chat
# NOTE(review): `max_length` is not passed to any dataset config in this
# file — confirm whether LLaVADataset should receive it.
max_length = 8192

# Scheduler & Optimizer
batch_size = 4  # per_device
accumulative_counts = 4
dataloader_num_workers = 4
max_epochs = 1
optim_type = AdamW
# Official recipe: "1024 -> 4e-5" (presumably a global batch size of 1024
# with lr 4e-5). Lower-lr alternative previously used: lr = 1e-6.
lr = 4e-5
betas = (0.9, 0.999)
weight_decay = 0.05
max_norm = 1  # grad clip
warmup_ratio = 0.05  # fraction of training spent in linear warmup

# Save
save_steps = 1000
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Tokenizer is built lazily by the runner from this dict; right padding is
# requested explicitly.
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=path,
    trust_remote_code=True,
    padding_side='right')

# NOTE(review): not referenced by the model/dataset configs in this file —
# confirm whether it is still needed.
extra_image_processor = dict(
    type=DirectResize,
    target_length=1024,
)

#######################################################################
#            PART 2  Model & Tokenizer & Image Processor              #
#######################################################################
# The LLM and vision encoder are frozen; only the LoRA adapter configured
# on the LLM is trainable.
model = dict(
    type=InternVL_vlm,
    model_path=path,
    freeze_llm=True,
    freeze_visual_encoder=True,
    llm_lora=dict(
        type=LoraConfig,
        r=128,
        lora_alpha=256,
        lora_dropout=0.05,
        bias='none',
        task_type='CAUSAL_LM'),
)

#######################################################################
#                     PART 3  Dataset & Dataloader                    #
#######################################################################
################## image chat
llava_vqa_dataset = dict(
    type=LLaVADataset,
    tokenizer=tokenizer,
    data_path=data_file,
    prompt_template=prompt_template,
    special_tokens=None,
    image_folder=image_folder,
)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=llava_vqa_dataset,
    sampler=dict(
        type=LengthGroupedSampler,
        length_property='modality_length',
        # Groups are sized to one full optimizer step (per-device batch
        # times gradient-accumulation steps), presumably so that samples
        # of similar length land in the same accumulated step.
        per_device_batch_size=batch_size * accumulative_counts),
    collate_fn=dict(type=internvl_collate_fn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer: bf16 AMP with gradient clipping and gradient accumulation
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='bfloat16')

# learning policy: linear warmup over the first `warmup_ratio` of training,
# then cosine decay to `eta_min`; epoch boundaries are converted to
# iteration boundaries by `convert_to_iter_based=True`.
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Custom hooks: only xtuner's DatasetInfoHook is registered here (it is given
# the tokenizer, presumably to report dataset/tokenization info at start-up).
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint every `save_steps` iterations, keeping at most
    # `save_total_limit` of them.
    # NOTE(review): with save_optimizer=False, optimizer state is not in the
    # checkpoint, so `resume=True` would not restore it — confirm intent.
    checkpoint=dict(
        type=CheckpointHook,
        save_optimizer=False,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)