import json
import os
from dataclasses import asdict

import fire
import torch
from transformers import AutoProcessor
from peft import LoraConfig, get_peft_model
from trl import get_kbit_device_map, get_quantization_config

from hpsv3.model.qwen2vl_trainer import (
    Qwen2VLRewardModelBT,
    VLMRewardTrainer,
    compute_multi_attr_accuracy,
    PartialEmbeddingUpdateCallback,
)
from hpsv3.dataset.pairwise_dataset import PairwiseOriginalDataset
from hpsv3.dataset.data_collator_qwen import QWen2VLDataCollator
from hpsv3.utils.parser import (
    DataConfig,
    ModelConfig,
    PEFTLoraConfig,
    TrainingConfig,
    parse_args_with_yaml,
)
from hpsv3.utils.training_utils import (
    find_target_linear_names,
    load_model_from_checkpoint,
)
from hpsv3.model.differentiable_image_processor import Qwen2VLImageProcessor

try:
    import flash_attn
except ImportError:
    flash_attn = None
    print("Flash Attention is not installed. Falling back to SDPA.")


def create_model_and_processor(
    model_config,
    peft_lora_config,
    training_args,
    cache_dir=None,
    differentiable=False,
):
    # Resolve the torch dtype ("auto"/None pass through; otherwise look up the
    # attribute, e.g. "bfloat16" -> torch.bfloat16).
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        use_cache=False,
    )

    # Create the processor and fix the padding side.
    processor = AutoProcessor.from_pretrained(
        model_config.model_name_or_path, padding_side="right", cache_dir=cache_dir
    )
    if differentiable:
        # Swap in the differentiable image processor.
        processor.image_processor = Qwen2VLImageProcessor()

    special_token_ids = None
    if model_config.use_special_tokens:
        # Register a dedicated reward token and remember its id.
        special_tokens = ["<|Reward|>"]
        processor.tokenizer.add_special_tokens(
            {"additional_special_tokens": special_tokens}
        )
        special_token_ids = processor.tokenizer.convert_tokens_to_ids(special_tokens)

    model = Qwen2VLRewardModelBT.from_pretrained(
        model_config.model_name_or_path,
        output_dim=model_config.output_dim,
        reward_token=model_config.reward_token,
        special_token_ids=special_token_ids,
        torch_dtype=torch_dtype,
        attn_implementation=(
            "flash_attention_2"
            if not training_args.disable_flash_attn2 and flash_attn is not None
            else "sdpa"
        ),
        cache_dir=cache_dir,
        rm_head_type=model_config.rm_head_type,
        rm_head_kwargs=model_config.rm_head_kwargs,
        **model_kwargs,
    )

    if model_config.use_special_tokens:
        model.resize_token_embeddings(len(processor.tokenizer))

    if training_args.bf16:
        model.to(torch.bfloat16)
    if training_args.fp16:
        model.to(torch.float16)
    # The regression head always runs in fp32, even with a half-precision backbone.
    model.rm_head.to(torch.float32)

    # Optionally wrap the model with LoRA adapters.
    if peft_lora_config.lora_enable:
        target_modules = find_target_linear_names(
            model,
            num_lora_modules=peft_lora_config.num_lora_modules,
            lora_namespan_exclude=peft_lora_config.lora_namespan_exclude,
        )
        peft_config = LoraConfig(
            target_modules=target_modules,
            r=peft_lora_config.lora_r,
            lora_alpha=peft_lora_config.lora_alpha,
            lora_dropout=peft_lora_config.lora_dropout,
            task_type=peft_lora_config.lora_task_type,
            use_rslora=peft_lora_config.use_rslora,
            bias="none",
            modules_to_save=peft_lora_config.lora_modules_to_save,
        )
        model = get_peft_model(model, peft_config)
    else:
        peft_config = None

    model.config.tokenizer_padding_side = processor.tokenizer.padding_side
    model.config.pad_token_id = processor.tokenizer.pad_token_id

    return model, processor, peft_config
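
# A minimal usage sketch of the factory above. The values are illustrative
# stand-ins, not from any shipped config; in this script the three config
# objects come out of parse_args_with_yaml in train() below.
#
#     model_config.model_name_or_path = "Qwen/Qwen2-VL-7B-Instruct"  # assumed HF id
#     peft_lora_config.lora_enable = True   # wrap target linear layers with LoRA
#     training_args.bf16 = True             # bf16 backbone; rm_head stays in fp32
#     model, processor, peft_config = create_model_and_processor(
#         model_config=model_config,
#         peft_lora_config=peft_lora_config,
#         training_args=training_args,
#     )
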

def save_configs_to_json(data_config, training_args, model_config, peft_lora_config):
    """Save all configurations to a JSON file."""
    config_dict = {
        "data_config": asdict(data_config),
        "training_args": asdict(training_args),
        "model_config": asdict(model_config),
        "peft_lora_config": asdict(peft_lora_config),
    }
    # Drop fields that only describe the local device.
    del config_dict["training_args"]["local_rank"]
    del config_dict["training_args"]["_n_gpu"]

    save_path = os.path.join(training_args.output_dir, "model_config.json")
    os.makedirs(training_args.output_dir, exist_ok=True)
    print(f"===> Saving configs to {save_path}")
    with open(save_path, "w") as f:
        json.dump(config_dict, f, indent=4)


def set_requires_grad(parameters, requires_grad):
    for p in parameters:
        p.requires_grad = requires_grad


def train(config, local_rank=0, debug=False):
    # local_rank and debug are accepted from the launcher but unused here.

    ## ===> Step 1: Parse arguments
    (data_config, training_args, model_config, peft_lora_config), config_path = (
        parse_args_with_yaml(
            (DataConfig, TrainingConfig, ModelConfig, PEFTLoraConfig),
            config,
            is_train=True,
        )
    )
    # Name the run after the config file, e.g. "configs/foo.yaml" -> ".../foo".
    training_args.output_dir = os.path.join(
        training_args.output_dir, config.split("/")[-1].split(".")[0]
    )
    training_args.logging_dir = training_args.output_dir

    # Validate the LoRA config.
    assert not (
        peft_lora_config.lora_enable and model_config.freeze_llm
    ), "When using LoRA, the LLM should not be frozen. If you want to freeze the LLM, please disable LoRA."
    if not peft_lora_config.lora_enable:
        assert not peft_lora_config.vision_lora, (
            "Error: peft_lora_config.vision_lora is enabled, but "
            "peft_lora_config.lora_enable is not."
        )
    else:
        if peft_lora_config.lora_namespan_exclude is None:
            peft_lora_config.lora_namespan_exclude = []
        if not peft_lora_config.vision_lora:
            # Keep LoRA out of the vision tower unless vision_lora is requested.
            peft_lora_config.lora_namespan_exclude += ["visual"]

    ## ===> Step 2: Load model and configure
    model, processor, peft_config = create_model_and_processor(
        model_config=model_config,
        peft_lora_config=peft_lora_config,
        training_args=training_args,
    )

    # Optionally resume weights from an earlier checkpoint.
    if training_args.load_from_pretrained is not None:
        model, checkpoint_step = load_model_from_checkpoint(
            model,
            training_args.load_from_pretrained,
            training_args.load_from_pretrained_step,
        )

    model.train()

    # With LoRA enabled the base model is wrapped, so configure model.model.
    if peft_lora_config.lora_enable:
        model_to_configure = model.model
    else:
        model_to_configure = model

    # Set requires_grad for the LLM; the input embeddings stay frozen.
    set_requires_grad(
        model_to_configure.model.parameters(), not model_config.freeze_llm
    )
    set_requires_grad(model_to_configure.model.embed_tokens.parameters(), False)

    if not peft_lora_config.vision_lora:
        # Set requires_grad for the visual encoder and its merger.
        set_requires_grad(
            model_to_configure.visual.parameters(),
            not model_config.freeze_vision_tower,
        )
        set_requires_grad(
            model_to_configure.visual.merger.parameters(), model_config.tune_merger
        )
        if model_config.trainable_visual_layers:
            # Trainable layers are counted from the end of model.visual.blocks;
            # set -1 to unfreeze all of them.
            assert model_config.trainable_visual_layers <= len(
                model_to_configure.visual.blocks
            ), "trainable_visual_layers should be less than or equal to the number of visual blocks"
            freeze_layer_num = (
                len(model_to_configure.visual.blocks)
                - model_config.trainable_visual_layers
                if model_config.trainable_visual_layers > 0
                else 0
            )
            for index, layer in enumerate(model_to_configure.visual.blocks):
                set_requires_grad(layer.parameters(), index >= freeze_layer_num)
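
            # Worked example of the rule above (block count illustrative): with
            # 32 visual blocks and trainable_visual_layers=2, freeze_layer_num
            # = 32 - 2 = 30, so blocks 0..29 stay frozen and blocks 30..31
            # train. A negative value (the documented -1) yields
            # freeze_layer_num = 0, unfreezing every block, while 0 skips this
            # branch entirely and leaves the freeze_vision_tower setting as-is.
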
    # The regression head is always trained.
    set_requires_grad(model_to_configure.rm_head.parameters(), True)

    ## ===> Step 3: Load dataset and configure
    train_dataset = PairwiseOriginalDataset(
        data_config.train_json_list,
        data_config.soft_label,
        data_config.confidence_threshold,
    )
    # One eval dataset per named test split; each entry is (split_name, json_list).
    test_set_dict = {}
    for item in data_config.test_json_list:
        test_set_dict[item[0]] = PairwiseOriginalDataset(
            item[1],
            data_config.soft_label,
            data_config.confidence_threshold,
        )

    print(f"===> Selected {len(train_dataset)} samples for training.")
    for key, value in test_set_dict.items():
        print(f"===> Selected {len(value)} samples for {key} testing.")

    # WORLD_SIZE is the number of distributed processes (one per GPU).
    num_gpu = int(os.environ.get("WORLD_SIZE", 1))
    data_collator = QWen2VLDataCollator(
        processor,
        max_pixels=data_config.max_pixels,
        min_pixels=data_config.min_pixels,
        with_instruction=data_config.with_instruction,
        use_special_tokens=model_config.use_special_tokens,
    )
    compute_metrics = compute_multi_attr_accuracy

    # Convert per-epoch intervals into optimizer steps based on the effective
    # batch size across all processes.
    actual_batch_size = (
        training_args.per_device_train_batch_size
        * training_args.gradient_accumulation_steps
        * num_gpu
    )
    total_steps = (
        training_args.num_train_epochs * len(train_dataset) // actual_batch_size
    )
    if training_args.save_epochs is not None:
        training_args.save_steps = round(
            training_args.save_epochs * len(train_dataset) / actual_batch_size
        )
    if training_args.eval_epochs is not None:
        training_args.eval_steps = round(
            training_args.eval_epochs * len(train_dataset) / actual_batch_size
        )
    if training_args.logging_epochs is not None:
        training_args.logging_steps = round(
            training_args.logging_epochs * len(train_dataset) / actual_batch_size
        )

    if training_args.local_rank in (-1, 0):
        print(f"===> Using {num_gpu} GPUs.")
        print(f"===> Total Batch Size: {actual_batch_size}")
        print(f"===> Training Epochs: {training_args.num_train_epochs}")
        print(f"===> Total Steps: {total_steps}")
        print(f"===> Save Steps: {training_args.save_steps}")
        print(f"===> Eval Steps: {training_args.eval_steps}")
        print(f"===> Logging Steps: {training_args.logging_steps}")

    ## ===> Step 4: Save configs for re-checking
    if training_args.local_rank in (-1, 0):
        save_configs_to_json(data_config, training_args, model_config, peft_lora_config)
        print(train_dataset)

    ## ===> Step 5: Start training!
    special_token_ids = model.special_token_ids
    callbacks = []
    if special_token_ids is not None:
        # Limit embedding updates to the special reward tokens.
        callbacks.append(PartialEmbeddingUpdateCallback(special_token_ids))

    trainer = VLMRewardTrainer(
        model=model,
        compute_metrics=compute_metrics,
        data_collator=data_collator,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=(test_set_dict if training_args.conduct_eval else None),
        peft_config=peft_config,
        callbacks=callbacks,
        loss_type=model_config.loss_type,
        loss_hyperparameters=model_config.loss_hyperparameters,
        tokenizer=processor.tokenizer,
        tied_threshold=data_config.tied_threshold,
        visualization_steps=training_args.visualization_steps,
        max_viz_samples=training_args.max_viz_samples,
    )
    trainer.train()

    # Save final weights and config on the main process only.
    if training_args.local_rank in (-1, 0):
        model_state_dict = model.state_dict()
        torch.save(
            model_state_dict,
            os.path.join(training_args.output_dir, "final_model.pth"),
        )
        model.config.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    fire.Fire(train)
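
# Launch sketch (file and config names are hypothetical; python-fire maps CLI
# flags onto train()'s arguments):
#
#     # single process
#     python train.py --config configs/my_experiment.yaml
#
#     # multi-GPU: torchrun sets WORLD_SIZE, which the script reads above to
#     # derive the effective batch size and the per-epoch step conversions
#     torchrun --nproc_per_node=8 train.py --config configs/my_experiment.yaml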