import os
from typing import Optional

import torch
import torch.nn as nn
from transformers import Trainer


def unwrap_model(model: nn.Module) -> nn.Module:
    """
    Recursively unwraps a model from potential containers (as used in distributed training).

    Args:
        model (`torch.nn.Module`): The model to unwrap.
    """
    # Wrappers such as DistributedDataParallel expose the wrapped model via `.module`.
    if hasattr(model, "module"):
        return unwrap_model(model.module)
    return model
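

# Example (illustrative, not part of the original module): containers like
# `torch.nn.parallel.DistributedDataParallel` hold the underlying model in
# `.module`, which `unwrap_model` strips recursively.
#
#   ddp_model = nn.parallel.DistributedDataParallel(model)  # hypothetical setup
#   assert unwrap_model(ddp_model) is model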


class PointLLMTrainer(Trainer):
    """Trainer that additionally checkpoints the multimodal adapter weights
    (point-cloud projector and embedding layers) when `tune_mm_mlp_adapter` is set."""

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            _state_dict = state_dict
            if _state_dict is None:
                # Unwrap containers (e.g. DistributedDataParallel) so the keys
                # match the plain model's state dict.
                model_to_save = unwrap_model(self.model)
                _state_dict = model_to_save.state_dict()
            # Keep only the adapter-related weights: the point-cloud projector
            # and the token-embedding layers.
            weight_to_save = {}
            keys_to_match = ['point_proj', 'embed_tokens', 'embed_in']
            for k, v in _state_dict.items():
                if any(key_match in k for key_match in keys_to_match):
                    weight_to_save[k] = v
            # Intermediate checkpoints live in `.../checkpoint-XXXX`; collect their
            # adapter weights under a shared `point_proj` folder alongside them.
            current_folder = os.path.basename(output_dir)
            parent_folder = os.path.dirname(output_dir)
            if current_folder.startswith('checkpoint-'):
                mm_projector_folder = os.path.join(parent_folder, "point_proj")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                torch.save(weight_to_save, os.path.join(output_dir, 'point_proj.bin'))
        # Fall through to the default save so the full checkpoint is still written.
        super()._save(output_dir, state_dict)
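

# Illustrative sketch (not from the original module): the adapter weights written
# by `_save` can be restored with a non-strict load. The filename and the `model`
# variable below are assumptions for the example.
#
#   adapter_weights = torch.load('point_proj.bin', map_location='cpu')
#   model.load_state_dict(adapter_weights, strict=False)  # fills only matching keys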