English
music
File size: 1,162 Bytes
bd2d17d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
import torch

from transformers import Trainer
from typing import Optional

# Step granularity for checkpoint folder names: SALMONNTrainer._save_checkpoint
# rounds the global step down to a multiple of this value before building the
# folder name, so saves within the same interval share one folder.
LOG_INTERVAL = 5000

def get_state(model):
    """Collect the state-dict entries that should go into a checkpoint.

    Iterates over ``model.state_dict()`` and keeps an entry when either

    * its name resolves to a parameter whose ``requires_grad`` is truthy, or
    * ``model.get_parameter(name)`` raises ``AttributeError``, meaning the
      name does not resolve to a parameter at all (e.g. a buffer).

    Parameters with ``requires_grad`` unset (frozen weights) are left out.

    Args:
        model: Object exposing ``state_dict()`` and ``get_parameter(name)``,
            such as a ``torch.nn.Module``.

    Returns:
        A new dict mapping each kept state-dict key to the value that
        ``state_dict()`` returned for it.
    """
    trainable_state_dict = {}
    for name, value in model.state_dict().items():
        try:
            keep = model.get_parameter(name).requires_grad
        except AttributeError:
            # Not a parameter: there is no requires_grad flag to consult, so
            # keep the entry. Only AttributeError is handled so that unrelated
            # failures (and KeyboardInterrupt/SystemExit) are not swallowed.
            keep = True
        if keep:
            trainable_state_dict[name] = value
    return trainable_state_dict


class SALMONNTrainer(Trainer):
    """``Trainer`` subclass that writes a reduced checkpoint.

    ``_save_checkpoint`` is overridden so that a checkpoint consists of a
    single file holding only the entries selected by ``get_state`` (trainable
    parameters plus non-parameter state-dict entries) instead of the full
    model state.
    """

    def _save_checkpoint(self, model, trial, metrics=None):
        """Write the ``get_state`` subset of ``self.model`` to a checkpoint folder.

        The folder name is ``PREFIX_CHECKPOINT_DIR`` followed by the global
        step rounded down to a multiple of ``LOG_INTERVAL``; it is created
        under the run directory returned by ``self._get_output_dir``.

        Args:
            model: Unused here; the weights are taken from ``self.model``.
            trial: Forwarded to ``self._get_output_dir`` to resolve the run
                directory.
            metrics: Unused; accepted so the signature matches callers that
                pass it.
        """
        from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

        rounded_step = (self.state.global_step // LOG_INTERVAL) * LOG_INTERVAL
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{rounded_step}"
        output_dir = os.path.join(self._get_output_dir(trial=trial), checkpoint_folder)

        # Collect the weights on every process: only the disk write below is
        # restricted, in case gathering the state dict is a collective
        # operation under a sharded wrapper.
        weight_to_save = get_state(self.model)

        # Only the process designated by the training arguments writes to
        # disk; otherwise every rank would write the same file concurrently.
        if not self.args.should_save:
            return

        os.makedirs(output_dir, exist_ok=True)
        # NOTE(review): the file name spelling ("salomnn") is kept byte-for-byte;
        # presumably loading code elsewhere expects this exact name -- confirm
        # before renaming.
        torch.save(weight_to_save, os.path.join(output_dir, "salomnn_7b.bin"))

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Delegate unchanged to the parent class's ``_save``."""
        super()._save(output_dir, state_dict)