from dataclasses import dataclass, field
import pathlib
from typing import Optional, List

import torch
import transformers
from pointllm.train.pointllm_trainer import PointLLMTrainer

from pointllm import conversation as conversation_lib
from pointllm.model import *
from pointllm.data import make_object_point_data_module

from pointllm.utils import build_logger

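# IGNORE_INDEX marks target positions that are skipped by the language-modeling loss;
# the DEFAULT_* strings below are LLaMA-style special-token values kept from the upstream script.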
IGNORE_INDEX = -100

DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"


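# Argument groups parsed from the command line by transformers.HfArgumentParser in train().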
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="/home/TinyGPT-V/pretrain_weight/phi-new")
    version: Optional[str] = field(default="v1")


@dataclass
class DataArguments:
    data_path: str = field(default="/home/PointLLM/data/objaverse_data", metadata={"help": "Path to the training data."})
    anno_path: str = field(default='/home/PointLLM/data/anno_data/PointLLM_complex_instruction_70K.json', metadata={"help": "Path to the utterance data. If None, will use referit3d by default."})
    use_color: bool = field(default=True, metadata={"help": "Whether to use color."})
    data_debug_num: int = field(default=0, metadata={"help": "Number of samples to use in debug mode. If larger than 0, use debug mode; otherwise use the whole dataset."})
    split_train_val: bool = field(default=False, metadata={"help": "Whether to split train and val."})
    split_ratio: float = field(default=0.9, metadata={"help": "Ratio of train to val."})
    pointnum: int = field(default=8192, metadata={"help": "Number of points."})
    conversation_types: List[str] = field(default_factory=lambda: ["detailed_description", "single_round", "multi_round"],
                                          metadata={"help": "Conversation types to use."})
    is_multimodal: bool = True


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default='/home/PointLLM/trash')
    output_dir: Optional[str] = field(default='/home/PointLLM/trash')

    save_strategy: Optional[str] = field(default='no')
    save_steps: int = field(default=2400)
    optim: str = field(default="adamw_torch")
    dataloader_num_workers: int = field(default=24)

    model_max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    per_device_train_batch_size: int = field(
        default=6, metadata={"help": "Batch size per GPU/TPU/MPS/NPU core/CPU for training."}
    )
    model_debug: bool = field(default=False, metadata={"help": "Whether to use a small debug model built from config only."})
    fix_llm: bool = field(default=True, metadata={"help": "Whether to fix (freeze) the LLM."})
    fix_pointnet: bool = field(default=True, metadata={"help": "Whether to fix (freeze) the PointNet backbone."})

    remove_unused_columns: bool = field(default=False)
    force_fsdp: bool = field(default=False)
    bf16: bool = field(default=True)
    tune_mm_mlp_adapter: bool = field(default=True)
    stage_2: bool = field(default=False)
    pretrained_mm_mlp_adapter: Optional[str] = field(default=None)
    detatch_point_token: bool = field(default=False)

    point_backbone_ckpt: str = field(default="/home/pointllm_weight_2/point_model/point_model.pth")


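# Move the state dict to CPU before handing it to Trainer._save, and only on the
# process that is responsible for saving.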
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """Collects the state dict and dumps it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {
            key: value.cpu()
            for key, value in state_dict.items()
        }
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)


def train():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    training_args.log_level = "info"

    # Hard-coded output directory; the training log is also written here.
    training_args.output_dir = '/home/PointLLM/trash'
    logger = build_logger(__name__, training_args.output_dir + '/train.log')

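    # Build the model. In debug mode the model is instantiated from the config alone
    # (random weights); otherwise pretrained weights are loaded in fp16.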
    if training_args.model_debug:
        config = transformers.AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            torch_dtype=torch.float32
        )
        model = PointLLMLlamaForCausalLM._from_config(config)
    else:
        model = PointLLMLlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            torch_dtype=torch.float16
        )

    model.config.use_cache = False

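    # Optionally freeze the LLM weights; the point projector and point backbone stay
    # trainable so only the multimodal components are updated in this case.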
    if training_args.fix_llm:
        logger.info("LLM is fixed. Fix_llm flag is set to True")
        model.requires_grad_(False)
        model.get_model().fix_llm = True
        model.get_model().point_proj.requires_grad_(True)
        model.get_model().point_backbone.requires_grad_(True)
    else:
        model.get_model().fix_llm = False
        logger.warning("LLM is trainable. Fix_llm flag is set to False")

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
    )

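    # Select the conversation template and padding token; the legacy "v0" prompt
    # format is no longer supported.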
    if model_args.version == "v0" or "v0" in model_args.model_name_or_path:
        raise ValueError("v0 is deprecated.")
    else:
        tokenizer.pad_token = tokenizer.eos_token
        conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1_1"]

    if not training_args.fix_pointnet:
        logger.info("Point backbone is trainable. Fix_pointnet flag is set to False, pointnet grad will be recorded.")
        model.get_model().fix_pointnet = False
    else:
        logger.info("Point backbone is fixed. Fix_pointnet flag is set to True, pointnet grad will not be recorded.")
        model.get_model().fix_pointnet = True
        if not training_args.stage_2:
            logger.info("Set requires_grad of point backbone to False")
            model.get_model().point_backbone.requires_grad_(False)

    if training_args.tune_mm_mlp_adapter:
        logger.info("Point projection layer is trainable.")
    else:
        model.get_model().point_proj.requires_grad_(False)
        logger.info("Point projection layer is fixed.")

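    # Stage 1: load the pretrained point backbone checkpoint and register the point
    # tokens with the tokenizer. Stage 2: configure the tokenizer without
    # re-initializing the embeddings.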
    if not training_args.stage_2:
        print(f"Default point_backbone_ckpt is {training_args.point_backbone_ckpt}.")
        model.get_model().load_point_backbone_checkpoint(training_args.point_backbone_ckpt)
        model.initialize_tokenizer_point_backbone_config(tokenizer=tokenizer, device=training_args.device, fix_llm=training_args.fix_llm)
    else:
        model.initialize_tokenizer_point_backbone_config_wo_embedding(tokenizer=tokenizer)

    point_backbone_config = model.get_model().point_backbone_config

    data_args.point_token_len = point_backbone_config['point_token_len']
    data_args.mm_use_point_start_end = point_backbone_config['mm_use_point_start_end']
    data_args.point_backbone_config = point_backbone_config

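    # FSDP does not officially support partially frozen parameters; warn about it and
    # monkey-patch FSDP so that use_orig_params defaults to True.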
    params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
    if len(params_no_grad) > 0:
        if training_args.fsdp is not None and len(training_args.fsdp) > 0:
            if len(params_no_grad) < 10:
                print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'.format(len(params_no_grad), params_no_grad))
            else:
                print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'.format(len(params_no_grad), ', '.join(params_no_grad[:10])))
            print("[WARNING] Attempting to use FSDP with partially frozen parameters, this is experimental.")
            print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")

            from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

            def patch_FSDP_use_orig_params(func):
                def wrap_func(*args, **kwargs):
                    use_orig_params = kwargs.pop('use_orig_params', True)
                    return func(*args, **kwargs, use_orig_params=use_orig_params)
                return wrap_func

            FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)

    data_module = make_object_point_data_module(tokenizer=tokenizer,
                                                data_args=data_args)

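    # Freeze every parameter of the base model, then selectively re-enable gradients
    # for the point projector and the LayerNorm modules below.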
    for name, param in model.model.named_parameters():
        param.requires_grad = False

    for name, param in model.model.named_parameters():
        if 'point_proj' in name:
            param.requires_grad = True

    for name, param in model.model.named_parameters():
        if 'q_layernorm' in name:
            param.requires_grad = True
        if 'k_layernorm' in name:
            param.requires_grad = True
        if 'post_layernorm' in name:
            param.requires_grad = True
        if 'input_layernorm' in name:
            param.requires_grad = True
        if 'final_layernorm' in name:
            param.requires_grad = True

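    # For each decoder layer, mark the LayerNorm weights trainable and cast them (and
    # their biases, when present) to float32 for numerical stability while the rest of
    # the model trains in reduced precision.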
    for i, layer in enumerate(model.model.layers):
        layer.self_attn.q_layernorm.weight.requires_grad = True
        layer.self_attn.k_layernorm.weight.requires_grad = True
        layer.post_layernorm.weight.requires_grad = True
        layer.input_layernorm.weight.requires_grad = True

        layer.self_attn.q_layernorm.weight.data = layer.self_attn.q_layernorm.weight.data.float()
        layer.self_attn.k_layernorm.weight.data = layer.self_attn.k_layernorm.weight.data.float()
        layer.post_layernorm.weight.data = layer.post_layernorm.weight.data.float()
        layer.input_layernorm.weight.data = layer.input_layernorm.weight.data.float()

        if layer.self_attn.q_layernorm.bias is not None:
            layer.self_attn.q_layernorm.bias.data = layer.self_attn.q_layernorm.bias.data.float()
        if layer.self_attn.k_layernorm.bias is not None:
            layer.self_attn.k_layernorm.bias.data = layer.self_attn.k_layernorm.bias.data.float()
        if layer.input_layernorm.bias is not None:
            layer.input_layernorm.bias.data = layer.input_layernorm.bias.data.float()

    model.model.final_layernorm.weight.requires_grad = True
    model.model.final_layernorm.weight.data = model.model.final_layernorm.weight.data.float()
    if model.model.final_layernorm.bias is not None:
        model.model.final_layernorm.bias.data = model.model.final_layernorm.bias.data.float()

    for name, param in model.model.named_parameters():
        if param.requires_grad:
            logger.info(f"Parameter {name} will be updated.")

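    # Build the trainer; resume from the most recent checkpoint in output_dir if one
    # exists, otherwise start training from scratch.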
    trainer = PointLLMTrainer(model=model,
                              tokenizer=tokenizer,
                              args=training_args,
                              **data_module)

    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer,
                                   output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()