litwell committed on
Commit
0743691
·
verified ·
1 Parent(s): f8e18e6

Upload models/src/training/train_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/src/training/train_utils.py +69 -0
models/src/training/train_utils.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import torch
3
+ import logging
4
+
5
+
6
+ def maybe_zero_3(param, ignore_status=False, name=None):
7
+ from deepspeed import zero
8
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
9
+ if hasattr(param, "ds_id"):
10
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
11
+ if not ignore_status:
12
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
13
+ with zero.GatheredParameters([param]):
14
+ param = param.data.detach().cpu().clone()
15
+ else:
16
+ param = param.detach().cpu().clone()
17
+ return param
18
+
19
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Select the LoRA (and optionally bias) entries from named parameters,
    materializing each as a gathered CPU copy (handles ZeRO-3 partitioning).

    Args:
        named_params: Iterable of ``(name, tensor)`` pairs, e.g. the result
            of ``model.named_parameters()``.
        bias: Which bias terms to keep alongside LoRA weights — one of
            ``"none"``, ``"all"``, or ``"lora_only"``.

    Returns:
        Dict mapping parameter name to a detached CPU copy of its tensor.

    Raises:
        NotImplementedError: If ``bias`` is not one of the supported modes.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # The bias paired with this LoRA weight lives on the same
                # module: "<module>.lora_X" -> "<module>.bias".
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # BUGFIX: the original iterated the dict directly (yielding keys
        # only, which fails to unpack into (k, t)) and compared a stale
        # `bias_name` left over from the previous loop. Iterate items and
        # keep each bias whose own name was recorded above, matching
        # peft.utils.get_peft_model_state_dict semantics.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    # Materialize full CPU copies (gathers ZeRO-3 partitioned params).
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return
43
+
44
+
45
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Collect the non-LoRA entries of the named parameters as CPU tensors.

    Args:
        named_params: Iterable of ``(name, tensor)`` pairs.
        require_grad_only: When True, keep only tensors with
            ``requires_grad`` set (i.e. trainable parameters).

    Returns:
        Dict mapping name to a gathered, CPU-resident copy of the tensor.
    """
    selected = {name: tensor for name, tensor in named_params if "lora_" not in name}
    if require_grad_only:
        # Drop frozen parameters; only trainable state is worth saving here.
        selected = {
            name: tensor for name, tensor in selected.items() if tensor.requires_grad
        }
    return {
        name: maybe_zero_3(tensor, ignore_status=True).cpu()
        for name, tensor in selected.items()
    }
51
+
52
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """Collect the model state dict and write it to ``output_dir``."""
    # Under DeepSpeed the trainer must consolidate sharded weights itself,
    # so delegate to its own save path after draining pending GPU work.
    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        # Copy every tensor to CPU so serialization does not pin GPU memory.
        cpu_state_dict = {name: tensor.cpu() for name, tensor in state_dict.items()}
        # Drop the GPU-backed dict before writing to keep peak memory low.
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
        trainer.model.config.save_pretrained(output_dir)