yevvonlim commited on
Commit
0c085bd
·
verified ·
1 Parent(s): 0bc07df

Add files using upload-large-folder tool

Browse files
attention_mask.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+
5
+
6
+ def _make_causal_mask(
7
+ attention_mask: torch.Tensor, dtype: torch.dtype, device: torch.device
8
+ ):
9
+ """
10
+ Make causal mask used for bi-directional self-attention.
11
+ """
12
+ bsz, tgt_len = attention_mask.shape
13
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
14
+ mask_cond = torch.arange(mask.size(-1), device=device)
15
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
16
+ mask = mask.to(dtype)
17
+
18
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)
19
+
20
+
21
+ def _make_2dvison_mask(column_mask, dtype: torch.dtype, device: torch.device):
22
+ """
23
+ """
24
+ bsz, seq_length = column_mask.shape
25
+ cross_mask = torch.zeros((bsz, 1, seq_length, seq_length), dtype=dtype, device=device)
26
+
27
+ # 找到连续的 1 的区间
28
+ start = None
29
+ for bsz_idx in range(bsz):
30
+ for i in range(seq_length):
31
+ if column_mask[bsz_idx, i] == 1:
32
+ if start is None:
33
+ start = i
34
+ else:
35
+ if start is not None:
36
+ # 填充区间
37
+ cross_mask[bsz_idx, 0, start:i, start:i] = 1
38
+ start = None
39
+
40
+ # 处理最后一个区间
41
+ if start is not None:
42
+ cross_mask[bsz_idx, 0, start:seq_length, start:seq_length] = 1
43
+
44
+ return cross_mask
45
+
46
+
47
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
48
+ """
49
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
50
+ """
51
+ bsz, src_len = mask.size()
52
+ tgt_len = tgt_len if tgt_len is not None else src_len
53
+
54
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
55
+
56
+ inverted_mask = 1.0 - expanded_mask
57
+
58
+ return inverted_mask.masked_fill_(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
59
+
60
+
61
def make_mask(attention_mask: torch.Tensor, dtype: torch.dtype=None, device: torch.device=None, mode: str="default", vision_mask: torch.Tensor=None, ):
    """Build the 4-D attention mask used by the model.

    Args:
        attention_mask: `[bsz, seq_len]` padding mask.
        dtype: output dtype; defaults to `attention_mask.dtype`.
        device: output device; defaults to `attention_mask.device`.
        mode: "default" returns `attention_mask` unchanged; "bidirectional"
            combines the causal mask with the padding mask and re-opens
            bidirectional attention inside vision-token runs.
        vision_mask: `[bsz, seq_len]` 0/1 flags marking vision tokens;
            required for any non-default mode.

    Returns:
        The raw 2-D mask ("default") or a `[bsz, 1, seq_len, seq_len]`
        additive mask ("bidirectional").

    Raises:
        NotImplementedError: for any other mode.
    """
    # Fast path first: the original computed the expanded and causal masks
    # unconditionally and then threw them away in "default" mode.
    if mode == "default":
        return attention_mask

    if dtype is None:
        dtype = attention_mask.dtype
    if device is None:
        device = attention_mask.device

    assert vision_mask is not None, "vision_mask is None"
    if mode == "bidirectional":
        # Padding mask + causal mask, both additive (-inf style).
        expanded_attn_mask = _expand_mask(attention_mask, dtype).to(device)
        causal_mask = _make_causal_mask(attention_mask, dtype, device).to(device)
        # Zero out the block-diagonal vision regions so vision tokens attend
        # to each other bidirectionally.
        vision_mask_2d = _make_2dvison_mask(vision_mask.to(device), dtype, device)
        mask = expanded_attn_mask + causal_mask
        return mask.clone().masked_fill_(vision_mask_2d.to(torch.bool), 0)

    raise NotImplementedError(f"mode {mode} is not implemented")
aux_vision.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import CLIPVisionModel, AutoModel
4
+
5
+ from .configuration_vora import VoRAConfig
6
+ from .eva_model import EVAVisionTransformer
7
+ import loguru
8
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learned per-channel scale."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Divide by the RMS over the last dimension (eps guards against 0).
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalise in float32 for numerical stability, then cast back to the
        # input dtype before applying the learned scale.
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight

    def extra_repr(self) -> str:
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
23
+
24
+
25
class CosineLoss(nn.Module):
    """`1 - cosine_similarity` loss, optionally resampling `input` to `target`'s grid.

    Similarity is computed along dim=1. `reduction` may be 'mean', 'sum', or
    anything else for no reduction (element-wise loss returned).
    """

    def __init__(self, reduction='mean'):
        super(CosineLoss, self).__init__()
        self.reduction = reduction

    @staticmethod
    def interpolate_tokens_2d(teacher_tokens, target_size):
        """Bilinearly resample a (batch, H, W, C) token grid to `target_size`.

        Returns a (batch, new_H * new_W, C) tensor.

        Bug fix: this @staticmethod previously declared `self` as its first
        parameter, so the two-argument call in `forward` raised a TypeError
        (`target_size` was never bound).
        """
        # (B, H, W, C) -> (B, C, H, W) for F.interpolate.
        teacher_tokens = teacher_tokens.permute(0, 3, 1, 2)
        interpolated = torch.nn.functional.interpolate(
            teacher_tokens, size=target_size, mode='bilinear', align_corners=True
        ).flatten(2)  # flatten height and width dimensions -> (B, C, hw)
        return interpolated.permute(0, 2, 1)  # -> (B, hw, C)

    def forward(self, input: torch.Tensor, target: torch.Tensor, input_shape=None, target_shape=None) -> torch.Tensor:
        if input_shape is not None and target_shape is not None:
            # Reshape flat tokens into a 2-D grid and resample to the target grid.
            input = input.reshape((input.shape[0], ) + input_shape + (-1, ))
            input = self.interpolate_tokens_2d(input, target_shape)

        cos_sim = nn.functional.cosine_similarity(input, target, dim=1)
        loss = 1 - cos_sim

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
54
+
55
+
56
class AuxVision(nn.Module):
    # Frozen auxiliary vision encoder plus one trainable projection head per
    # encoder layer.  The heads map LLM hidden states (at vision-token
    # positions) into the vision encoder's feature space; a cosine loss aligns
    # them with the frozen encoder's hidden states of the same layer index.
    def __init__(self,
        config: VoRAConfig = None,
    ):
        super().__init__()
        self.skip_aux_cls = config.skip_aux_cls # whether to skip the cls token in ViT
        # ---------------- Setup Aux Model ----------------
        # support jina clip encoder
        if 'jina' in config.aux_vision.lower() and 'clip' in config.aux_vision.lower():
            # Hard-coded EVA ViT configuration matching the jina-clip image
            # encoder; in this branch `config.aux_vision` is a local
            # checkpoint path rather than an HF model id.
            cfg = {
                "img_size": 512,
                "num_classes": 1024,
                "embed_dim": 1024,
                "patch_size": 14,
                "depth": 24,
                "qkv_bias": True,
                "naiveswiglu": True,
                "num_heads": 16,
                "patch_dropout":0, # disable patch dropout
                "subln": True,
                "mlp_ratio": 2.66666,
                "use_mean_pooling": False,
            }
            self.aux_model = EVAVisionTransformer(**cfg)
            # strict=False: the checkpoint may omit some modules (e.g. head);
            # NOTE(review): mismatched keys are silently ignored here.
            self.aux_model.load_state_dict(torch.load(config.aux_vision, map_location='cpu', weights_only=True), strict=False)
            vision_hidden_size = 1024
            num_hidden_layers = 24


        elif 'clip' in config.aux_vision.lower():
            # Any Hugging Face CLIP vision tower.
            self.aux_model = CLIPVisionModel.from_pretrained(config.aux_vision)
            vision_hidden_size = self.aux_model.vision_model.config.hidden_size
            num_hidden_layers = self.aux_model.vision_model.config.num_hidden_layers

        else:
            # Fallback: any HF vision model exposing hidden_size and
            # num_hidden_layers on its config.
            self.aux_model = AutoModel.from_pretrained(config.aux_vision, trust_remote_code=True)
            vision_hidden_size = self.aux_model.config.hidden_size
            num_hidden_layers = self.aux_model.config.num_hidden_layers
        # Freeze the auxiliary tower; only the heads below are trained.
        for name, param in self.aux_model.named_parameters():
            param.requires_grad = False
        # -------------------------------------------------

        # ---------------- Setup Aux Heads ----------------
        # One projection head per encoder layer, registered under the
        # attribute name `aux_layer_{i}` so it appears in state_dict.
        self.aux_layers = list(range(num_hidden_layers))
        for layer_id in self.aux_layers:
            self.add_module(f"aux_layer_{layer_id}", self.build_aux_layer(config.hidden_size, vision_hidden_size))
        # -------------------------------------------------

        self.loss_function = CosineLoss()
        self.loss_keys = [f"loss_aux_layer_{layer_id}" for layer_id in self.aux_layers]

    def build_aux_layer(self, llm_hidden_size, vit_hidden_size):
        # RMSNorm over the LLM width followed by a bias-free projection into
        # the vision encoder's feature width.
        return nn.Sequential(
            RMSNorm(llm_hidden_size),
            nn.Linear(
                llm_hidden_size,
                vit_hidden_size,
                bias=False,
            )
        )

    def forward(self, frames, llm_hidden_states, vision_mask):
        """Compute per-layer cosine alignment losses.

        Args:
            frames: input images/frames for the frozen vision tower.
            llm_hidden_states: per-layer LLM hidden states, indexable by
                layer index; entry shape assumed (bsz, seq, llm_hidden) —
                TODO confirm against caller.
            vision_mask: 0/1 mask selecting vision-token positions in the
                LLM sequence.

        Returns:
            dict mapping "loss_aux_layer_{i}" to a scalar loss tensor.
        """
        vision_hidden_states = self.aux_model(frames, output_hidden_states=True).hidden_states
        losses = {}
        for layer_idx in self.aux_layers:
            # Project the LLM states at vision-token positions into ViT space.
            aux_hidden_states = getattr(self, f"aux_layer_{layer_idx}")(llm_hidden_states[layer_idx][vision_mask == 1])
            # Optionally drop the ViT CLS token before comparison.
            start_id = 1 if self.skip_aux_cls else 0
            try:
                aux_loss = self.loss_function(vision_hidden_states[layer_idx][:, start_id:].reshape(aux_hidden_states.shape), aux_hidden_states)
            except Exception as e:
                # Shape mismatches here usually mean vision_mask and the
                # tower's token count disagree; log context, then re-raise.
                loguru.logger.error(f"Aux Vision loss function error: {e} occured at layer {layer_idx}")
                loguru.logger.error(f"Aux Vision aux_hidden_states: {aux_hidden_states.shape}, vision_hidden_states: {vision_hidden_states[layer_idx][:, start_id:].reshape(aux_hidden_states.shape).shape}")
                raise e
            losses[f"loss_aux_layer_{layer_idx}"] = aux_loss
        return losses
config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "VoRAForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_vora.VoRAConfig",
9
+ "AutoModelForCausalLM": "modeling_vora.VoRAForCausalLM"
10
+ },
11
+ "aux_vision": "/workspace/VoRAParse/output/jina-clip/image-encoder.pt",
12
+ "bos_token_id": 151643,
13
+ "eos_token_id": 151645,
14
+ "head_dim": 128,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 4096,
17
+ "image_size": 512,
18
+ "initializer_range": 0.02,
19
+ "intermediate_size": 12288,
20
+ "llm": "Qwen/Qwen3-8B",
21
+ "lora": {
22
+ "layers": 24,
23
+ "r": 1024,
24
+ "target_modules": [
25
+ "self_attn.q_proj",
26
+ "self_attn.k_proj",
27
+ "self_attn.v_proj",
28
+ "self_attn.o_proj",
29
+ "mlp.up_proj",
30
+ "mlp.gate_proj",
31
+ "mlp.down_proj"
32
+ ]
33
+ },
34
+ "max_position_embeddings": 40960,
35
+ "max_window_layers": 36,
36
+ "model_type": "vora",
37
+ "num_attention_heads": 32,
38
+ "num_hidden_layers": 36,
39
+ "num_key_value_heads": 8,
40
+ "patch_size": 14,
41
+ "reuse_aux_vision_embedding_layers": "",
42
+ "rms_norm_eps": 1e-06,
43
+ "rope_scaling": null,
44
+ "rope_theta": 1000000,
45
+ "skip_aux_cls": false,
46
+ "sliding_window": null,
47
+ "tie_word_embeddings": false,
48
+ "torch_dtype": "bfloat16",
49
+ "transformers_version": "4.51.3",
50
+ "use_cache": true,
51
+ "use_sliding_window": false,
52
+ "vision_attention_mask": "bidirectional",
53
+ "vision_embedding": "AIMv2Embedding",
54
+ "vision_embedding_intermediate_size": 1536,
55
+ "vocab_size": 151936
56
+ }
configuration_vora.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+ __all__ = ["VoRAConfig"]
6
+
7
+
8
class VoRAConfig(PretrainedConfig):
    """Configuration class for VoRA models.

    Stores the LLM backbone id, auxiliary vision encoder settings, LoRA
    hyper-parameters, and vision-embedding options; all remaining keyword
    arguments are forwarded to `PretrainedConfig`.
    """

    model_type = "vora"
    _auto_class = "AutoConfig"

    def __init__(
        self,
        llm: str = "",
        aux_vision: str = "",
        skip_aux_cls: bool = False,
        reuse_aux_vision_embedding_layers: str = "",
        lora: dict = None,
        image_size: int = 448,
        vision_embedding: str = "AIMv2",
        vision_embedding_intermediate_size: int = 1536,
        patch_size: int = 14,
        vision_attention_mask: str = "bidirectional",
        rms_norm_eps: float = 1e-5,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.llm = llm
        self.aux_vision = aux_vision
        self.skip_aux_cls = skip_aux_cls
        self.reuse_aux_vision_embedding_layers = reuse_aux_vision_embedding_layers
        # Mutable-default fix: the original declared `lora: dict = {}`, so
        # every instance built with the default shared ONE dict object.
        # Normalising None to a fresh dict preserves the observable default.
        self.lora = {} if lora is None else lora
        self.image_size = image_size
        self.vision_embedding = vision_embedding
        self.vision_embedding_intermediate_size = vision_embedding_intermediate_size
        self.patch_size = patch_size
        self.vision_attention_mask = vision_attention_mask
        self.rms_norm_eps = rms_norm_eps
eva_model.py ADDED
@@ -0,0 +1,768 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Adapted from EVA CLIP
3
+ # https://github.com/baaivision/EVA/tree/master/EVA-CLIP/rei/eva_clip
4
+ # --------------------------------------------------------
5
+
6
+ import math
7
+ import os
8
+ import warnings
9
+ from functools import partial
10
+ from easydict import EasyDict as edict
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as f
14
+ import loguru
15
+
16
try:
    warnings.filterwarnings('ignore', category=FutureWarning, module='timm')
    from timm.models.layers import drop_path as timm_drop_path
    from timm.models.layers import to_2tuple, trunc_normal_
except ImportError:
    # Fix: the original wrote `except ImportError or ModuleNotFoundError:`,
    # which evaluates the boolean expression to plain ImportError anyway
    # (and ModuleNotFoundError subclasses ImportError) — spell it plainly.
    # Newer timm releases moved these helpers to `timm.layers`.
    from timm.layers import drop_path as timm_drop_path, to_2tuple, trunc_normal_

from .rope_embeddings import VisionRotaryEmbeddingFast

if os.getenv('ENV_TYPE') == 'deepspeed':
    try:
        from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
    except ImportError:
        from torch.utils.checkpoint import checkpoint
else:
    from torch.utils.checkpoint import checkpoint

try:
    import xformers.ops as xops
except ImportError:
    # xformers is optional; attention falls back to the plain implementation.
    xops = None
37
+
38
+
39
class PatchDropout(nn.Module):
    """
    Randomly keep a subset of patch tokens during training (PatchDropout).

    https://arxiv.org/abs/2212.00794
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        assert 0 <= prob < 1.0
        self.prob = prob  # fraction of tokens to drop
        self.exclude_first_token = exclude_first_token  # exclude CLS token

    def forward(self, x):
        # Identity at eval time or when dropout is disabled.
        # NOTE(review): this path returns only `x`, while the training path
        # below returns `(x, patch_indices_keep)` — callers must handle both
        # arities; confirm before changing.
        if not self.training or self.prob == 0.0:
            return x

        if self.exclude_first_token:
            # Split off the CLS token so it is never dropped.
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch = x.size()[0]
        num_tokens = x.size()[1]

        # (batch, 1) row indices for the advanced-indexing gather below.
        batch_indices = torch.arange(batch)
        batch_indices = batch_indices[..., None]

        keep_prob = 1 - self.prob
        # Always keep at least one token.
        num_patches_keep = max(1, int(num_tokens * keep_prob))

        # Random scores per token; topk picks a uniform random subset.
        # NOTE(review): `rand` and `batch_indices` are created on the CPU
        # regardless of x.device — indices get transferred on every call.
        rand = torch.randn(batch, num_tokens)
        patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices

        x = x[batch_indices, patch_indices_keep]

        if self.exclude_first_token:
            # Re-attach the CLS token at position 0.
            x = torch.cat((cls_tokens, x), dim=1)

        return x, patch_indices_keep
77
+
78
+
79
class DropPath(nn.Module):
    """Stochastic depth: randomly drop the whole residual branch per sample."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegate to timm's implementation; only active in training mode.
        return timm_drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f'p={self.drop_prob}'
92
+
93
+
94
class Mlp(nn.Module):
    """Transformer feed-forward block: fc1 -> act -> (optional sub-LN) -> fc2 -> dropout."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        drop=0.0,
        subln=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        # Sub-LayerNorm (EVA "subln") sits between the two projections;
        # Identity when disabled.
        self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        # (The BERT-style dropout after the activation is deliberately omitted,
        # matching the original implementation.)
        hidden = self.ffn_ln(hidden)
        return self.drop(self.fc2(hidden))
126
+
127
+
128
class SwiGLU(nn.Module):
    """SwiGLU feed-forward: (SiLU(w1 x) * w2 x) -> optional sub-LN -> w3 -> dropout."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.SiLU,
        drop=0.0,
        norm_layer=nn.LayerNorm,
        subln=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.w1 = nn.Linear(in_features, hidden_features)  # gate projection
        self.w2 = nn.Linear(in_features, hidden_features)  # value projection
        self.act = act_layer()
        # Sub-LayerNorm over the gated hidden state; Identity when disabled.
        self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
        self.w3 = nn.Linear(hidden_features, out_features)  # output projection
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        gated = self.act(self.w1(x)) * self.w2(x)
        return self.drop(self.w3(self.ffn_ln(gated)))
160
+
161
+
162
class Attention(nn.Module):
    # Multi-head self-attention with several optional mechanisms:
    #  - separate q/k/v projections plus an inner LayerNorm ("subln"),
    #  - a learned relative position bias over a fixed window,
    #  - rotary embeddings applied to all tokens except the leading token,
    #  - xformers memory-efficient attention ("xattn").
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        window_size=None,
        attn_head_dim=None,
        xattn=False,
        rope=None,
        subln=False,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.subln = subln
        if self.subln:
            # Separate bias-free projections; q/v biases are the standalone
            # parameters below (k deliberately has no bias).
            self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
            self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
            self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
        else:
            self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)

        if qkv_bias:
            # Only q and v get learnable biases; k's bias is fixed at zero.
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            # Learned relative position bias table indexed by the pairwise
            # relative coordinates of tokens in a (Wh, Ww) window, with three
            # extra rows for CLS-to-token / token-to-CLS / CLS-to-CLS.
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (
                2 * window_size[1] - 1
            ) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads)
            )  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token 2 cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            # NOTE(review): meshgrid is called without indexing=..., which
            # emits a deprecation warning on newer torch (behaves as 'ij').
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = (
                coords_flatten[:, :, None] - coords_flatten[:, None, :]
            )  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(
                1, 2, 0
            ).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = torch.zeros(
                size=(window_size[0] * window_size[1] + 1,) * 2,
                dtype=relative_coords.dtype,
            )
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            # The three reserved table rows for the CLS interactions:
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer('relative_position_index', relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        # Inner LayerNorm on the concatenated head outputs (EVA "subln").
        self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
        # self.proj = nn.Linear(all_head_dim, all_head_dim)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.xattn = xattn
        self.xattn_drop = attn_drop

        self.rope = rope

    def forward(self, x, rel_pos_bias=None, attn_mask=None):
        # x: (B, N, dim).  attn_mask: per-token keep mask of shape (B, N) —
        # presumably a padding mask; TODO confirm against callers.
        b, n, _ = x.shape
        if self.subln:
            # Separate projections; k has no bias by construction.
            q = f.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
            k = f.linear(input=x, weight=self.k_proj.weight, bias=None)
            v = f.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)

            q = q.reshape(b, n, self.num_heads, -1).permute(
                0, 2, 1, 3
            )  # B, num_heads, N, C
            k = k.reshape(b, n, self.num_heads, -1).permute(0, 2, 1, 3)
            v = v.reshape(b, n, self.num_heads, -1).permute(0, 2, 1, 3)
        else:
            # Fused projection with a composite bias whose k-slice is zeros.
            qkv_bias = None
            if self.q_bias is not None:
                qkv_bias = torch.cat(
                    (
                        self.q_bias,
                        torch.zeros_like(self.v_bias, requires_grad=False),
                        self.v_bias,
                    )
                )

            qkv = f.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
            qkv = qkv.reshape(b, n, 3, self.num_heads, -1).permute(
                2, 0, 3, 1, 4
            )  # 3, B, num_heads, N, C
            q, k, v = qkv[0], qkv[1], qkv[2]

        if self.rope:
            # Rotary embeddings are applied to every token except index 0
            # (the CLS token), which keeps its raw projection.
            # slightly fast impl
            q_t = q[:, :, 1:, :]
            ro_q_t = self.rope(q_t)
            q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)

            k_t = k[:, :, 1:, :]
            ro_k_t = self.rope(k_t)
            k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)

        if self.xattn:
            if xops is None:
                raise ValueError(
                    "Can't use xattn without xformers. Please 'pip install xformers'"
                )
            # xformers expects (B, N, num_heads, C) layout.
            q = q.permute(0, 2, 1, 3)  # B, num_heads, N, C -> B, N, num_heads, C
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)

            # NOTE(review): rel_pos_bias / attn_mask are NOT applied on the
            # xattn path — only the plain path below honours them.
            x = xops.memory_efficient_attention(
                q,
                k,
                v,
                p=self.xattn_drop,
                scale=self.scale,
            )
            x = x.reshape(b, n, -1)
            x = self.inner_attn_ln(x)
            x = self.proj(x)
            x = self.proj_drop(x)
        else:
            # Plain scaled dot-product attention.
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)

            if self.relative_position_bias_table is not None:
                relative_position_bias = self.relative_position_bias_table[
                    self.relative_position_index.view(-1)
                ].view(
                    self.window_size[0] * self.window_size[1] + 1,
                    self.window_size[0] * self.window_size[1] + 1,
                    -1,
                )  # Wh*Ww,Wh*Ww,nH
                relative_position_bias = relative_position_bias.permute(
                    2, 0, 1
                ).contiguous()  # nH, Wh*Ww, Wh*Ww
                attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)

            if rel_pos_bias is not None:
                # Shared (externally computed) relative position bias.
                attn = attn + rel_pos_bias.type_as(attn)

            if attn_mask is not None:
                # Mask out dropped keys with -inf before the softmax.
                attn_mask = attn_mask.bool()
                attn = attn.masked_fill(~attn_mask[:, None, None, :], float('-inf'))

            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

            x = (attn @ v).transpose(1, 2).reshape(b, n, -1)
            x = self.inner_attn_ln(x)
            x = self.proj(x)
            x = self.proj_drop(x)
        return x
341
+
342
+
343
class Block(nn.Module):
    # Transformer block supporting {pre-norm, post-norm} x {with, without
    # LayerScale (gamma_1/gamma_2)}, stochastic depth on both residual
    # branches, and either a GELU MLP or a SwiGLU feed-forward.
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        init_values=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        window_size=None,
        attn_head_dim=None,
        xattn=False,
        rope=None,
        postnorm=False,
        subln=False,
        naiveswiglu=False,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            window_size=window_size,
            attn_head_dim=attn_head_dim,
            xattn=xattn,
            rope=rope,
            subln=subln,
            norm_layer=norm_layer,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better
        # than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        if naiveswiglu:
            self.mlp = SwiGLU(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                subln=subln,
                norm_layer=norm_layer,
            )
        else:
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                subln=subln,
                drop=drop,
            )

        # LayerScale: learnable per-channel residual scaling, enabled only
        # when init_values > 0.
        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(
                init_values * torch.ones((dim,)), requires_grad=True
            )
            self.gamma_2 = nn.Parameter(
                init_values * torch.ones((dim,)), requires_grad=True
            )
        else:
            self.gamma_1, self.gamma_2 = None, None

        self.postnorm = postnorm

    def forward(self, x, rel_pos_bias=None, attn_mask=None):
        # Four residual variants, selected by (LayerScale on/off, postnorm).
        if self.gamma_1 is None:
            if self.postnorm:
                # post-norm: normalise the sublayer OUTPUT before the residual add.
                x = x + self.drop_path(
                    self.norm1(
                        self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
                    )
                )
                x = x + self.drop_path(self.norm2(self.mlp(x)))
            else:
                # pre-norm: normalise the sublayer INPUT.
                x = x + self.drop_path(
                    self.attn(
                        self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask
                    )
                )
                x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            if self.postnorm:
                # post-norm with LayerScale applied after normalisation.
                x = x + self.drop_path(
                    self.gamma_1
                    * self.norm1(
                        self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
                    )
                )
                x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
            else:
                # pre-norm with LayerScale applied to the branch output.
                x = x + self.drop_path(
                    self.gamma_1
                    * self.attn(
                        self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask
                    )
                )
                x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x
449
+
450
+
451
class PatchEmbed(nn.Module):
    """Image to Patch Embedding via a single strided convolution."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid_h = img_size[0] // patch_size[0]
        grid_w = img_size[1] // patch_size[1]

        self.img_size = img_size
        self.patch_size = patch_size
        self.patch_shape = (grid_h, grid_w)
        self.num_patches = grid_h * grid_w

        # Non-overlapping patches: kernel == stride == patch size.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x, **_):
        target_dtype = self.proj.weight.dtype
        _, __, h, w = x.shape
        # FIXME look at relaxing size constraints
        assert h == self.img_size[0] and w == self.img_size[1], (
            f"Input image size ({h}*{w}) doesn't match model "
            f'({self.img_size[0]}*{self.img_size[1]}).'
        )
        # (B, C, H, W) -> (B, embed_dim, gh, gw) -> (B, gh*gw, embed_dim)
        return self.proj(x.to(dtype=target_dtype)).flatten(2).transpose(1, 2)
478
+
479
+
480
class RelativePositionBias(nn.Module):
    # Shared learned relative position bias for a (Wh, Ww) token window plus a
    # CLS token; returned as (num_heads, Wh*Ww+1, Wh*Ww+1) and added to
    # attention logits by the caller.
    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        # Distinct relative offsets, plus three reserved entries for
        # CLS-to-token / token-to-CLS / CLS-to-CLS interactions.
        self.num_relative_distance = (2 * window_size[0] - 1) * (
            2 * window_size[1] - 1
        ) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        # NOTE(review): meshgrid without indexing=... warns on newer torch
        # (behaves as 'ij').
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = (
            coords_flatten[:, :, None] - coords_flatten[:, None, :]
        )  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0
        ).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        # Row offsets are scaled so (dy, dx) maps to a unique flat index.
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = torch.zeros(
            size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
        )
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # Reserved table rows for the CLS interactions:
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer('relative_position_index', relative_position_index)

    def forward(self):
        # Gather the learned biases via the precomputed index map.
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1] + 1,
            self.window_size[0] * self.window_size[1] + 1,
            -1,
        )  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
525
+
526
+
527
class EVAVisionTransformer(nn.Module):
    """Vision Transformer with support for patch or hybrid CNN input stage.

    EVA-CLIP style ViT: patch embedding + cls token, optional absolute
    position embedding, optional (shared) relative position bias, optional
    2D rotary embedding (rope), and a stack of transformer ``Block``s.
    ``forward_features`` returns the per-block hidden states (patch tokens
    only) plus the final sequence and cls token, wrapped in an ``edict``.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=0,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_values=None,
        patch_dropout=0.0,
        use_abs_pos_emb=True,
        use_rel_pos_bias=False,
        use_shared_rel_pos_bias=False,
        rope=False,
        use_mean_pooling=True,
        init_scale=0.001,
        grad_checkpointing=False,
        xattn=False,
        postnorm=False,
        pt_hw_seq_len=16,
        intp_freq=False,
        naiveswiglu=False,
        subln=False,
        proj_type=None,
    ):
        super().__init__()
        self.image_size = img_size
        self.num_classes = num_classes
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            # +1 for the cls token prepended in forward_features.
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            # One bias table shared by every block.
            self.rel_pos_bias = RelativePositionBias(
                window_size=self.patch_embed.patch_shape, num_heads=num_heads
            )
        else:
            self.rel_pos_bias = None

        if rope:
            # Half of each head's channels get 2D rotary embedding.
            half_head_dim = embed_dim // num_heads // 2
            hw_seq_len = img_size // patch_size
            self.rope = VisionRotaryEmbeddingFast(
                dim=half_head_dim,
                pt_seq_len=pt_hw_seq_len,
                ft_seq_len=hw_seq_len if intp_freq else None,
                patch_dropout=patch_dropout,
            )
        else:
            self.rope = None

        self.naiveswiglu = naiveswiglu

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    init_values=init_values,
                    window_size=self.patch_embed.patch_shape
                    if use_rel_pos_bias
                    else None,
                    xattn=xattn,
                    rope=self.rope,
                    postnorm=postnorm,
                    subln=subln,
                    naiveswiglu=naiveswiglu,
                )
                for i in range(depth)
            ]
        )
        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
        self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
        if (num_classes == embed_dim) and (proj_type is None):
            self.head = nn.Identity()
        elif proj_type == 'linear':
            self.head = nn.Linear(embed_dim, num_classes, bias=qkv_bias)
        elif proj_type == 'mlp':
            # NOTE(review): this branch assigns self.proj, not self.head, and
            # the (num_classes != embed_dim, proj_type=None) case leaves
            # self.head undefined — the isinstance(self.head, ...) check below
            # would then raise AttributeError. Confirm which configs are used.
            hidden_size = (embed_dim + num_classes) // 2
            self.proj = nn.Sequential(
                nn.Linear(embed_dim, hidden_size, bias=qkv_bias),
                nn.GELU(),
                nn.Linear(hidden_size, num_classes, bias=qkv_bias),
            )

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=0.02)

        trunc_normal_(self.cls_token, std=0.02)

        self.apply(self._init_weights)
        self.fix_init_weight()

        if isinstance(self.head, nn.Linear):
            trunc_normal_(self.head.weight, std=0.02)
            self.head.weight.data.mul_(init_scale)
            if qkv_bias:
                self.head.bias.data.mul_(init_scale)

        # setting a patch_dropout of 0. would mean it is disabled and this function
        # would be the identity fn
        self.patch_dropout = (
            PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
        )

        self.grad_checkpointing = grad_checkpointing

    def fix_init_weight(self):
        """Depth-dependent rescaling of residual-branch output projections
        (as in the BEiT/EVA initialization scheme)."""
        def rescale(param, _layer_id):
            param.div_(math.sqrt(2.0 * _layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            if self.naiveswiglu:
                rescale(layer.mlp.w3.weight.data, layer_id + 1)
            else:
                rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def get_cast_dtype(self) -> torch.dtype:
        # Representative parameter dtype, used by callers to cast inputs.
        return self.blocks[0].mlp.fc2.weight.dtype

    @staticmethod
    def _init_weights(m):
        """Default init: truncated-normal linear weights, zero biases,
        unit LayerNorm weights."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        return len(self.blocks)

    def lock(self, unlocked_groups=0, *_, **__):
        """Freeze all parameters. Partial unlocking is not implemented."""
        assert (
            unlocked_groups == 0
        ), 'partial locking not currently supported for this model'
        for param in self.parameters():
            param.requires_grad = False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameters excluded from weight decay by the optimizer setup.
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, *_, **__):
        """Replace the classification head for a new number of classes."""
        self.num_classes = num_classes
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

    def forward_features(self, x, return_all_features=False):
        """Run the backbone and collect intermediate states.

        Returns an ``edict`` with:
          - hidden_states: list of patch-token tensors (cls stripped), one
            entry for the embedding output plus one per block;
          - last_hidden_state: final sequence including the cls token;
          - cls_token: final cls embedding.
        """
        x = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(
            batch_size, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        # a patch_dropout of 0. would mean it is disabled and this function
        # would do nothing but return what was passed in
        if self.rope is not None:
            if self.training and not isinstance(self.patch_dropout, nn.Identity):
                # Patch dropout returns the kept indices so rope can gather
                # the matching rotary frequencies.
                # NOTE(review): rebinding self.rope.forward with partial on
                # every call wraps it repeatedly across forward passes —
                # confirm this is intended.
                x, patch_indices_keep = self.patch_dropout(x)
                self.rope.forward = partial(
                    self.rope.forward, patch_indices_keep=patch_indices_keep
                )
            else:
                self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
                x = self.patch_dropout(x)
        else:
            x = self.patch_dropout(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        hidden_states = [x[:, 1:]]
        for blk in self.blocks:
            if self.grad_checkpointing:
                # NOTE(review): this passes (rel_pos_bias,) as a positional
                # tuple, unlike the keyword call below — verify Block accepts it.
                x = checkpoint(blk, x, (rel_pos_bias,))
            else:
                x = blk(x, rel_pos_bias=rel_pos_bias)

            hidden_states.append(x[:, 1:])

        return edict(
            {
                'hidden_states': hidden_states,
                'last_hidden_state': x,
                'cls_token': x[:, 0],
            }
        )

    def forward(self, x, return_all_features=False, **kwargs):
        # Extra kwargs are accepted for interface compatibility and ignored.
        return self.forward_features(x, return_all_features=return_all_features)
latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step5200
lora.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import types
3
+ import math
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ QWEN2_TARGET_MODULES = [
9
+ "self_attn.q_proj",
10
+ "self_attn.k_proj",
11
+ "self_attn.v_proj",
12
+ "self_attn.o_proj",
13
+ "mlp.up_proj",
14
+ "mlp.gate_proj",
15
+ "mlp.down_proj",
16
+ ]
17
+
18
+
19
class LoRALayer(nn.Linear):
    """An ``nn.Linear`` augmented with a low-rank (LoRA) adapter.

    The base linear output is computed as usual and the low-rank path
    ``lora_B(lora_A(x))`` is added on top. ``lora_B`` is zero-initialised,
    so at construction the layer computes exactly the same function as a
    plain ``nn.Linear`` with the same weights.

    Args:
        in_features: input dimension of the base linear layer.
        out_features: output dimension of the base linear layer.
        r: adapter rank; a negative value disables the adapter and the
            layer behaves as a plain linear.
        **kwargs: may contain ``bias`` (bool). Other keys (the full
            ``lora_params`` dict forwarded by ``_find_and_replace``,
            e.g. ``target_modules``) are ignored.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 1024,
        **kwargs
    ):
        # Bug fix: the ``bias`` kwarg used to be silently dropped, so a bias
        # parameter was always created even when the module being wrapped had
        # none. Honour it here (default True, matching nn.Linear).
        nn.Linear.__init__(self, in_features, out_features, bias=kwargs.get("bias", True))
        # we elimate lora_alpha here bc we find it unnecessary in VoRA
        if r < 0:
            # Adapter disabled: route forward() through the plain linear path.
            self.forward = self.naive_forward
        else:
            self.lora_A = nn.Linear(in_features, r, bias=False)
            self.lora_B = nn.Linear(r, out_features, bias=False)
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            # Zero init => the adapter is a no-op until trained.
            nn.init.zeros_(self.lora_B.weight)

    def forward(self, x: torch.Tensor):
        """Base linear output plus the low-rank adapter delta."""
        intermediate = F.linear(x, self.weight, bias=self.bias)
        result = intermediate + self.lora_B(self.lora_A(x))
        return result

    def naive_forward(self, x: torch.Tensor):
        """Plain linear forward, used when the adapter is disabled (r < 0)."""
        return F.linear(x, self.weight, bias=self.bias)
44
+
45
+
46
+ def _get_submodules(self, key):
47
+ parent = self.get_submodule(".".join(key.split(".")[:-1]))
48
+ target_name = key.split(".")[-1]
49
+ target = self.get_submodule(key)
50
+ return parent, target, target_name
51
+
52
+
53
def _find_and_replace(self, lora_params):
    """Swap every submodule named in ``lora_params['target_modules']`` for a
    :class:`LoRALayer` that reuses the original linear's weights."""
    for module_key in lora_params["target_modules"]:
        parent, old_linear, child_name = self._get_submodules(module_key)
        has_bias = old_linear.bias is not None
        adapter = LoRALayer(
            old_linear.in_features,
            old_linear.out_features,
            bias=has_bias,
            **lora_params
        )
        self._replace_module(parent, child_name, adapter, old_linear)
66
+
67
+
68
+ def _replace_module(self, parent_module, child_name, new_module, old_module):
69
+ setattr(parent_module, child_name, new_module)
70
+ new_module.weight = old_module.weight
71
+ if old_module.bias is not None:
72
+ new_module.bias = old_module.bias
73
+ if getattr(old_module, "state", None) is not None:
74
+ new_module.state = old_module.state
75
+ new_module.to(old_module.weight.device)
76
+
77
+
78
def apply_lora(llm, lora_params=None):
    """Inject LoRA adapter layers into the target linear modules of ``llm``.

    Each selected decoder layer has its target ``nn.Linear`` submodules
    replaced in place by :class:`LoRALayer`, reusing the original (frozen)
    weights. The helper functions are bound onto each layer object with
    ``types.MethodType`` so any HF-style model with ``model.layers`` works.

    Args:
        llm: a HuggingFace causal LM exposing ``config.num_hidden_layers``
            and ``model.layers`` (Llama/Qwen2 style).
        lora_params: dict with keys ``layers`` ("all" or an int N meaning
            the first N layers), ``r`` (adapter rank) and ``target_modules``
            (dotted submodule names relative to a decoder layer). Defaults
            to rank-1024 adapters on all Qwen2-style projections.

    Raises:
        ValueError: if ``layers`` is a string other than "all".
        TypeError: if ``layers`` is neither a string nor an int.
    """
    # Avoid the shared-mutable-default pitfall: build the default per call.
    if lora_params is None:
        lora_params = {"layers": "all", "r": 1024, "target_modules": QWEN2_TARGET_MODULES}
    llm_num_layers = llm.config.num_hidden_layers
    layers_spec = lora_params.get("layers", "all")

    # -------------------- validation check ---------------------
    if isinstance(layers_spec, str):
        # Bug fix: previously any string other than "all" fell through
        # unvalidated and the loop below iterated over its characters.
        if layers_spec.lower() != "all":
            raise ValueError(f"layers must be an integer or 'all', got {layers_spec!r}")
        total_layers = list(range(llm_num_layers))
    elif isinstance(layers_spec, int):
        total_layers = list(range(layers_spec))
    else:
        raise TypeError("total_layers must be an integer or 'all'")
    # -------------------- validation check ---------------------

    # -------------------- replace llm layers ---------------------
    for i in total_layers:
        llm_layer = llm.model.layers[i]
        # Bind helpers as methods so they can call get_submodule() on the layer.
        llm_layer._get_submodules = types.MethodType(_get_submodules, llm_layer)
        llm_layer._find_and_replace = types.MethodType(_find_and_replace, llm_layer)
        llm_layer._replace_module = types.MethodType(_replace_module, llm_layer)
        llm_layer._find_and_replace(lora_params)
98
+
99
+
100
if __name__ == "__main__":
    # Smoke test: load a local Llama checkpoint, inject LoRA layers, and
    # print the resulting module tree to eyeball the replacements.
    # NOTE(review): the hard-coded checkpoint path only exists on the
    # original author's machine; adjust before running.
    from transformers import LlamaForCausalLM, CLIPVisionModel, AutoModel
    llama = LlamaForCausalLM.from_pretrained("/mnt/bn/wh-data/data/models/llama2_7b_hf_chat")
    apply_lora(llama)
    print(llama)
model-00001-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3044c43e513388c043d65715ff65b2bb698b1ecb8e0f4f7473cd43c7b66a82b9
3
+ size 4999235192
model-00002-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c1874fdb1f01ba18924426aaed32499e136e161c4cdb2d6bb556ef9b4c4d4cee
3
+ size 4992139816
model-00003-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81b074b3c6706d5c560b3d5801abd9b3dace504c2ec352f886635ab2314cc415
3
+ size 4912324776
model-00004-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c78ab40f2c92347e41b819c62d45080d4263e2e69070d1fbea700387a59a835
3
+ size 3959630128
model-00005-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f737b07b8709cad6da08c5266d82e93a7b2b467649d836ff6288038cf9bc94f
3
+ size 2073391616
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_vora.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+ from transformers import (
4
+ AutoModelForCausalLM,
5
+ AutoTokenizer,
6
+ PreTrainedModel,
7
+ PretrainedConfig,
8
+ )
9
+
10
+ import loguru
11
+ from .attention_mask import make_mask
12
+ from .configuration_vora import VoRAConfig
13
+ from .vision_embedding import * # hacking, let transformers find vision_embedding
14
+ from . import vision_embedding as VB
15
+ from .lora import apply_lora
16
+ from .vora_generation_utils import (
17
+ VoraGenerationMixin,
18
+ custom_prepare_4d_causal_attention_mask_with_cache_position,
19
+ )
20
+
21
+ try:
22
+ from utils import logging
23
+ except:
24
+ from transformers.utils import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
class VoRAForCausalLM(PreTrainedModel):
    """VoRA: a causal LM with vision adapters injected directly into the LLM.

    Wraps a HuggingFace causal LM, optionally freezes it and adds LoRA
    adapters, and prepends patch-embedded vision tokens into the text
    sequence wherever a vision placeholder token occurs. An optional
    auxiliary vision head adds extra losses on the LLM hidden states.
    """
    config_class = VoRAConfig
    _auto_class = 'AutoModelForCausalLM'
    supports_gradient_checkpointing = True
    supports_report_metrics: bool = True

    def __init__(self, config: PretrainedConfig = VoRAConfig()):
        # NOTE(review): a mutable default config instance is shared across
        # calls — confirm VoRAConfig() is never mutated by __init__.
        super().__init__(config)
        self.config = config
        # -------------- Setup LLM ---------------------
        self.llm = AutoModelForCausalLM.from_pretrained(config.llm)
        self.tokenizer = AutoTokenizer.from_pretrained(config.llm)
        # Dynamically mix VoraGenerationMixin into the loaded LLM class so
        # generation handles the custom 4D attention masks.
        self.llm.__class__ = type(self.llm.__class__.__name__, (self.llm.__class__, VoraGenerationMixin), {})
        self.llm.model._prepare_4d_causal_attention_mask_with_cache_position = staticmethod(custom_prepare_4d_causal_attention_mask_with_cache_position)

        self.config.update(self.llm.config.to_dict())

        # -------------- Setup LoRA -------------------
        if config.lora:
            # Freeze the base LLM; only the injected adapters stay trainable.
            for _, param in self.llm.named_parameters():
                param.requires_grad = False
            apply_lora(self.llm, config.lora)
        # ----------------------------------------------

        # ------------ Setup Vision Embedding ----------
        self.vision_embedding = getattr(VB, config.vision_embedding)(self.config)  # setup after llm so that we know the hiddensize
        # ----------------------------------------------

        # ------------- Setup Aux Vision ---------------
        self.enable_aux_vision = False
        if config.aux_vision:
            from .aux_vision import AuxVision
            self.enable_aux_vision = True
            self.aux_vision = AuxVision(self.config)
            if config.reuse_aux_vision_embedding_layers:
                # Initialise the trainable vision embedding from the aux
                # model's corresponding layers (non-strict load).
                weights = getattr(self.aux_vision.aux_model, config.reuse_aux_vision_embedding_layers).state_dict()
                msg = self.vision_embedding.load_state_dict(weights, strict=False)
                msg = self.vision_embedding.patchifier.load_state_dict(weights, strict=False)
                logger.info(f"Loaded aux vision weights: {msg}")
        # ----------------------------------------------
        # print trainable prameters and total parameters so that we can check if we are loading the correct model
        logger.info("Trainable parameters:")
        for name, param in self.named_parameters():
            if param.requires_grad:
                logger.info(f"{name}: {param.numel()}")
        logger.info(f"Total parameters: {sum(p.numel() for p in self.parameters())}")

    def detach_and_gather_loss(self, loss, dtype, device):
        """Average a scalar loss across all distributed ranks for logging.

        Returns a Python float; falls back to ``loss.item()`` when not
        running under ``torch.distributed``. (``dtype`` is currently unused.)
        """
        if not dist.is_initialized():
            return loss.item()
        gathered_loss = [torch.tensor(0.0, dtype=loss.dtype).to(device) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_loss, loss.detach().clone())
        avg_gathered_loss = torch.mean(torch.stack(gathered_loss))
        return avg_gathered_loss.item()

    def _encode_vision(self, images, n_frames):
        """Embed a flat batch of frames and regroup them per sample.

        Returns (per-sample vision embeddings, matching all-ones attention
        masks, matching all ``-100`` label tensors, spatial image shape).
        """
        # TODO: we need a more elegant way here to deal with mixed image and pure text training
        if images.size(0) > 0:
            vision_embeds = self.vision_embedding(images)
        else:
            # FIXME: hacking for deepspeed training
            # we feed a dummy image tensor (1, 3, H, W) into vision_encoder when training a pure-text batch
            images = images.new_zeros((1, *images.shape[1:]))
            vision_embeds = self.vision_embedding(images)[0:0]
        # Split the flat frame batch back into per-sample groups.
        vision_embeds = vision_embeds.split(n_frames, dim=0)
        attention_mask = [torch.ones(feature.size()[:-1], dtype=torch.long).to(feature.device) for feature in vision_embeds]
        # Vision tokens never contribute to the LM loss (label -100).
        vision_targets = [torch.ones(feature.size(), dtype=torch.long).to(feature.device).fill_(-100) for feature in attention_mask]

        image_shapes = images.shape[-2:]

        return vision_embeds, attention_mask, vision_targets, image_shapes

    def _concat_embedding(self, vision_encode_out, batch, vision_placeholder_index, left_padding=False):
        """ concat vision and text

        Splices each sample's vision embeddings into the token-embedding
        sequence at every placeholder position, pads all samples to the
        batch max length, and builds the final (possibly 4D) attention mask
        via ``make_mask``. Returns (inputs_embeds, attention_mask, targets,
        vision_mask) where vision_mask is 1 at vision-token positions.
        """

        vision_embeds, vision_atts, vision_targets, _ = vision_encode_out

        input_embeds = []
        attention_mask = []
        targets = []
        vision_mask = []  # set vision token as 1, text token as 0

        for cur_batch_idx, cur_input_ids in enumerate(batch["input_ids"]):
            cur_vision_embeds = vision_embeds[cur_batch_idx]
            cur_vision_attn = vision_atts[cur_batch_idx]
            cur_vision_targets = vision_targets[cur_batch_idx]
            cur_attn_masks = batch["attention_mask"][cur_batch_idx]

            image_token_indices = torch.where(cur_input_ids == vision_placeholder_index)[0]
            cur_image_num = len(image_token_indices)
            # Sentinel end index so the slice after the last image works.
            image_token_indices = list(image_token_indices) + [cur_input_ids.shape[0]]

            cur_input_embeds = []
            cur_attention_mask = []
            cur_target = []
            cur_vision_mask = []

            # convert text before 1st <image> to embedding
            image_token_index = image_token_indices[0]

            cur_input_embeds.append(
                self.llm.get_input_embeddings()(cur_input_ids[:image_token_index]),
            )
            cur_attention_mask.append(
                cur_attn_masks[:image_token_index],
            )
            cur_vision_mask.append(
                torch.zeros_like(cur_attn_masks[:image_token_index]).to(cur_attn_masks.device),
            )
            if "labels" in batch:
                cur_target.append(
                    batch["labels"][cur_batch_idx, :image_token_index],
                )

            # Mode 1: one placeholder stands for a whole video (all frames).
            if batch.get("vison_placeholder_mode", 0) == 1:
                assert cur_image_num <= 1, "multiple video input is not supported"
                cur_vision_embeds = cur_vision_embeds.unsqueeze(0)
                cur_vision_attn = cur_vision_attn.unsqueeze(0)
                cur_vision_targets = cur_vision_targets.unsqueeze(0)
            assert cur_image_num == len(cur_vision_embeds), \
                f"Size mismatch! cur_image_num: {cur_image_num}, len(cur_vision_embeds): {len(cur_vision_embeds)} {len(cur_vision_embeds)} \
                in {batch['prompt'][cur_batch_idx]} & {batch['gt'][cur_batch_idx]} & {batch['input_ids'][cur_batch_idx]}"
            # convert each <image> xxx group into embedding
            # .relu() clamps negative ids (e.g. -100 labels leaked into ids)
            # to 0 so the embedding lookup cannot index out of range.
            text_embedding = self.llm.get_input_embeddings()(cur_input_ids.relu())
            for i in range(0, cur_image_num):
                image_token_index = image_token_indices[i]
                cur_input_embeds.extend([
                    cur_vision_embeds[i],
                    text_embedding[image_token_index+1:image_token_indices[i+1]]
                ])
                cur_attention_mask.extend([
                    cur_vision_attn[i],
                    cur_attn_masks[image_token_index+1:image_token_indices[i+1]]
                ])
                cur_vision_mask.extend([
                    torch.ones_like(cur_vision_attn[i]).to(cur_vision_attn[i].device),
                    torch.zeros_like(cur_attn_masks[image_token_index+1:image_token_indices[i+1]]).to(cur_vision_attn[i].device),
                ])
                if "labels" in batch:
                    cur_target.extend([
                        cur_vision_targets[i],
                        batch["labels"][cur_batch_idx, image_token_index+1:image_token_indices[i+1]],
                    ])

            input_embeds.append(torch.cat(cur_input_embeds))
            attention_mask.append(torch.cat(cur_attention_mask))
            vision_mask.append(torch.cat(cur_vision_mask))
            if "labels" in batch:
                targets.append(torch.cat(cur_target))

        # padding
        n_tokens = [embed.shape[0] for embed in input_embeds]

        max_token = max(n_tokens)

        for i in range(len(input_embeds)):
            if max_token > n_tokens[i]:
                self.pad_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
                pad_token = torch.tensor([self.pad_id] * (max_token - n_tokens[i]))
                pad_embedding = self.llm.get_input_embeddings()(pad_token.to(batch["attention_mask"][i].device))
                pad_attention = torch.zeros(pad_embedding.shape[0], dtype=torch.long).to(batch["attention_mask"][i].device)
                pad_targets = torch.ones(pad_attention.size(), dtype=torch.long).to(batch["attention_mask"][i].device).fill_(-100)

                if left_padding:
                    input_embeds[i] = torch.cat([pad_embedding, input_embeds[i]])
                    attention_mask[i] = torch.cat([pad_attention, attention_mask[i]])
                    vision_mask[i] = torch.cat([pad_attention, vision_mask[i]])
                    if "labels" in batch:
                        targets[i] = torch.cat([pad_targets, targets[i]])
                else:
                    input_embeds[i] = torch.cat([input_embeds[i], pad_embedding])
                    attention_mask[i] = torch.cat([attention_mask[i], pad_attention])
                    vision_mask[i] = torch.cat([vision_mask[i], pad_attention])
                    if "labels" in batch:
                        targets[i] = torch.cat([targets[i], pad_targets])

        inputs_embeds = torch.stack(input_embeds, dim=0).type(self.llm.dtype)
        attention_mask = torch.stack(attention_mask, dim=0)
        vision_mask = torch.stack(vision_mask, dim=0).to(attention_mask.device)

        # NOTE(review): without labels, targets stays an empty list (not a
        # tensor) — callers must tolerate that.
        if len(targets) > 0:
            targets = torch.stack(targets, dim=0)

        attention_mask = make_mask(
            attention_mask,
            mode=self.config.vision_attention_mask,
            vision_mask=vision_mask,
            dtype=inputs_embeds.dtype
        )

        return inputs_embeds, attention_mask, targets, vision_mask

    def forward(self, **batch):
        """Training/eval forward: build multimodal embeddings, run the LLM,
        and (optionally) add auxiliary vision losses. Logs per-loss metrics
        via ``report_metrics``."""
        # -------------- Vision/Text Embedding ----------
        vision_placeholder_index = batch.pop("vision_placeholder_index")
        images, n_frames = batch["frames"], batch["n_frames"]
        vision_encode_out = self._encode_vision(images, n_frames)
        inputs_embeds, attention_mask, targets, vision_mask = self._concat_embedding(
            vision_encode_out, batch, vision_placeholder_index)
        # -----------------------------------------------

        outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=targets,
            return_dict=True,
            output_hidden_states=True,
        )

        llm_loss = outputs.loss
        device = llm_loss.device
        dtype = llm_loss.dtype

        metrics = {}

        metrics["llm_loss"] = self.detach_and_gather_loss(llm_loss, dtype, device)
        if self.enable_aux_vision:
            if images.size(0) > 0:
                aux_losses = self.aux_vision(images, outputs.hidden_states, vision_mask)
            else:
                # FIXME: hacking for deepspeed training
                aux_losses = {key: torch.tensor(0., dtype=dtype).to(device) for key in self.aux_vision.loss_keys}

            # Average the individual auxiliary losses and add to the LM loss.
            aux_loss = torch.tensor(0., dtype=dtype).to(device)
            n_aux = 0
            for _aux_key, _aux_loss in aux_losses.items():
                aux_loss += _aux_loss
                n_aux += 1
                metrics[_aux_key] = self.detach_and_gather_loss(_aux_loss, dtype, device)
            aux_loss /= n_aux

            outputs.loss = aux_loss + llm_loss
        metrics["total_loss"] = self.detach_and_gather_loss(outputs.loss, dtype, device)
        self.report_metrics(**metrics)

        return outputs

    def generate(self, batch, **generate_params):
        """Autoregressive generation with vision tokens spliced into the prompt.

        NOTE(review): uses right padding (left_padding=False); for batched
        generation left padding is usually required — confirm batch size 1
        or verify downstream handling.
        """
        with torch.amp.autocast(
            enabled=(self.device != torch.device("cpu")),
            device_type=self.device.type,
        ):
            # get vision token
            vision_placeholder_index = batch.pop("vision_placeholder_index")

            # get vision features
            images, n_frames = batch["frames"], batch["n_frames"]
            vision_encode_out = self._encode_vision(images, n_frames)

            inputs_embeds, attention_mask, _, _ = self._concat_embedding(
                vision_encode_out, batch, vision_placeholder_index, left_padding=False)

            outputs = self.llm.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_attentions=True,
                **generate_params
            )

        return outputs
+ return outputs
rng_state_0.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78d3f197f6c6558fa8056324f1563ab9e957255f5a1a959362aa4eed7a9545db
3
+ size 15984
rng_state_1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c1a9c65c2869356282cad6b4a0f7dff7f4dd68ab3d9d216c72b7d6cb524f860
3
+ size 15984
rng_state_2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:896febe768e17bae5022a95960c041f6425783774ec8859d99d3b149063b1bf9
3
+ size 15984
rng_state_3.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eac482d57e966585467c8ef44dae2869bf7e5d92886f69c11ed7bccc34c07efe
3
+ size 15984
rng_state_4.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1f27d227a20dc320ac283e0938fb2f6e5b475829a583f8c44d1a16a8c828307
3
+ size 15984
rng_state_5.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d05a7106aaeaec4b81704e3f4a998b5123cf9342a6733bd9fd2d578e99108c3b
3
+ size 15984
rng_state_6.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b94120d8d88502ec8d8b623ec7550315caca003b44fcffbb5767ab0de91baefe
3
+ size 15984
rng_state_7.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:332e4d901be380f740b5d8578f7b80ef1865c7fba83bc288c8a35852205cc668
3
+ size 15984
rope_embeddings.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Adapted from EVA CLIP
3
+ # https://github.com/baaivision/EVA/tree/master/EVA-CLIP/rei/eva_clip
4
+ # --------------------------------------------------------
5
+
6
+ from math import pi
7
+
8
+ import torch
9
+ from einops import rearrange, repeat
10
+ from torch import nn
11
+
12
+
13
def broadcast(tensors, dim=-1):
    """Expand all tensors to a common shape on every axis except ``dim``,
    then concatenate them along ``dim``.

    All tensors must have the same rank; on each non-concat axis the sizes
    may differ only between 1 and a single common value (broadcasting).
    """
    ranks = {t.dim() for t in tensors}
    assert len(ranks) == 1, 'tensors must all have the same number of dimensions'
    rank = ranks.pop()
    if dim < 0:
        dim += rank
    shapes = [list(t.shape) for t in tensors]
    sizes_per_axis = list(zip(*shapes))
    assert all(
        len(set(sizes_per_axis[axis])) <= 2
        for axis in range(rank) if axis != dim
    ), 'invalid dimensions for broadcastable concatentation'
    # Target size per axis is the max across tensors (1s get expanded).
    target = [max(sizes) for sizes in sizes_per_axis]
    expanded = []
    for t in tensors:
        shape = list(target)
        shape[dim] = t.shape[dim]  # keep each tensor's own size on the concat axis
        expanded.append(t.expand(*shape))
    return torch.cat(expanded, dim=dim)
30
+
31
+
32
def rotate_half(x):
    """Rotate interleaved channel pairs: (x1, x2) -> (-x2, x1).

    The last dimension is treated as flattened (d, 2) pairs, matching the
    interleaved frequency layout produced by the rotary embedding modules.
    """
    pairs = x.unflatten(-1, (-1, 2))
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    return rotated.flatten(-2)
37
+
38
+
39
class VisionRotaryEmbedding(nn.Module):
    """2D rotary position embedding (RoPE) for vision tokens.

    Precomputes per-position cos/sin tables over an (ft_seq_len x
    ft_seq_len) grid — rescaled to the pretraining coordinate range
    pt_seq_len — with separate frequency halves for the height and width
    axes, and applies the rotation to a channel slice of the input.
    """
    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        super().__init__()
        # NOTE(review): `if custom_freqs:` relies on truthiness — a
        # multi-element tensor here would raise; confirm expected type.
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            # Standard RoPE inverse-frequency schedule.
            freqs = 1.0 / (
                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
            )
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # Fine-tune grid positions rescaled into the pretraining range,
        # so a longer ft_seq_len interpolates rather than extrapolates.
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        # Duplicate each frequency for the (cos, sin) channel pairs that
        # rotate_half expects (interleaved layout).
        freqs_h = torch.einsum('..., f -> ... f', t, freqs)
        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r=2)

        freqs_w = torch.einsum('..., f -> ... f', t, freqs)
        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r=2)

        # Combine: every (h, w) grid cell carries the h-axis frequencies
        # followed by the w-axis frequencies along the channel dim.
        freqs = broadcast((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)

        # Non-persistent: recomputed on construction, not saved in checkpoints.
        self.register_buffer('freqs_cos', freqs.cos(), persistent=False)
        self.register_buffer('freqs_sin', freqs.sin(), persistent=False)

    def forward(self, t, start_index=0):
        """Apply the rotary rotation to channels [start_index, start_index +
        rot_dim) of ``t``; channels outside the slice pass through unchanged."""
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        assert rot_dim <= t.shape[-1], (
            f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in '
            f'all the positions {rot_dim}'
        )
        t_left, t, t_right = (
            t[..., :start_index],
            t[..., start_index:end_index],
            t[..., end_index:],
        )
        # Classic RoPE identity: x * cos + rotate_half(x) * sin.
        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)

        return torch.cat((t_left, t, t_right), dim=-1)
95
+
96
+
97
class VisionRotaryEmbeddingFast(nn.Module):
    """Faster 2D rotary embedding with precomputed flattened tables.

    Like :class:`VisionRotaryEmbedding` but stores the cos/sin tables
    flattened to (grid_h * grid_w, channels), so application is a single
    element-wise multiply over the full feature dimension. Also supports
    gathering per-sample kept positions when patch dropout removed tokens.
    """
    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
        patch_dropout=0.0,
    ):
        super().__init__()
        # NOTE(review): `if custom_freqs:` relies on truthiness — a
        # multi-element tensor would raise; confirm expected type.
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            # Standard RoPE inverse-frequency schedule.
            freqs = 1.0 / (
                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
            )
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # Fine-tune grid positions rescaled into the pretraining range.
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        # Interleave for the (x1, x2) pair layout expected by rotate_half,
        # then combine the h-axis and w-axis tables over the 2D grid.
        freqs = torch.einsum('..., f -> ... f', t, freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r=2)
        freqs = broadcast((freqs[:, None, :], freqs[None, :, :]), dim=-1)

        # Flatten the (h, w) grid to one token axis: (h*w, channels).
        # NOTE(review): this assumes the model's token order matches a
        # row-major flatten of the grid — confirm against PatchEmbed.
        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        self.patch_dropout = patch_dropout

        # Non-persistent: recomputed on construction, not checkpointed.
        self.register_buffer('freqs_cos', freqs_cos, persistent=False)
        self.register_buffer('freqs_sin', freqs_sin, persistent=False)

    def forward(self, t, patch_indices_keep=None):
        """Rotate ``t`` with the precomputed tables.

        When ``patch_indices_keep`` is given (after patch dropout), the
        tables are gathered per sample so each kept token pairs with the
        frequencies of its original grid position.
        """
        if patch_indices_keep is not None:
            batch = t.size()[0]
            batch_indices = torch.arange(batch)
            batch_indices = batch_indices[..., None]

            # Tile the tables to (batch, tokens, heads, channels) so they
            # can be indexed per sample.
            freqs_cos = repeat(
                self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]
            )
            freqs_sin = repeat(
                self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]
            )

            # Select each sample's kept token positions, then move the
            # token axis next to channels to line up with t's layout.
            freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
            freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
            freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
            freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')

            return t * freqs_cos + rotate_half(t) * freqs_sin

        # No dropout: broadcast the shared (tokens, channels) tables.
        return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
scheduler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f05512bf38916e185cca93d0ada0f63479a3d982044c9a30eec1c58ba2ff27e3
3
+ size 1064
trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95a8c73db4abbe7f04d20fc397e5ca2a49e6027339b0470fc61067769542260c
3
+ size 7032
vision_embedding.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .configuration_vora import VoRAConfig
5
+
6
+
7
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias).

    Normalizes the last dimension by its RMS (computed in float32 for
    stability, then cast back) and applies a learned per-channel gain.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # x / sqrt(mean(x^2) + eps), expressed via rsqrt.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute in float32, cast back to the input dtype, then scale.
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight

    def extra_repr(self) -> str:
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
22
+
23
+
24
class AIMv2PatchEmbed(nn.Module):
    """Convolutional patchifier (AIMv2 style): non-overlapping patch_size
    convolution over RGB input, flattened to a token sequence and RMS-normed."""

    def __init__(self, config: VoRAConfig):
        super().__init__()
        patch = (config.patch_size, config.patch_size)
        self.proj = nn.Conv2d(
            3,
            config.vision_embedding_intermediate_size,
            kernel_size=patch,
            stride=patch,
        )
        self.norm = RMSNorm(config.vision_embedding_intermediate_size, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feature_map = self.proj(x)                        # (B, C, H', W')
        tokens = feature_map.flatten(2).transpose(1, 2)   # (B, H'*W', C)
        return self.norm(tokens)
39
+
40
+
41
class AIMv2Embedding(nn.Module):
    """AIMv2-style vision embedding: patchify, add a learned position
    embedding, and project tokens to the LLM hidden size."""

    def __init__(self,
                 config: VoRAConfig = None,
                 ):
        super().__init__()
        hidden_size = config.hidden_size
        # Number of patches for the configured (square) training resolution.
        num_patches = (config.image_size // config.patch_size) ** 2
        self.config = config

        self.patchifier = AIMv2PatchEmbed(config)
        # Learned position embedding for the training grid, zero-initialised.
        self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.vision_embedding_intermediate_size)))
        # Bias-free projection into the LLM embedding space.
        self.out_proj = nn.Linear(config.vision_embedding_intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        h_token = H // self.config.patch_size
        w_token = W // self.config.patch_size
        tokens = self.patchifier(x)
        _, N, _ = tokens.shape
        pos_embed = self.pos_embed.to(tokens.device)

        if N <= pos_embed.size(1):
            # Input no larger than the trained grid: take the first N slots.
            # NOTE(review): for a smaller input with a different aspect
            # ratio this row-major truncation mis-aligns positions spatially
            # — confirm interpolation isn't wanted here too.
            tokens = tokens + pos_embed[:, :N]
        else:
            # Larger input: bilinearly resize the trained square grid
            # (sqrt(P) x sqrt(P)) to the current (h_token, w_token) grid.
            pos_embed = pos_embed.view(1, int(pos_embed.size(1)**0.5), int(pos_embed.size(1)**0.5), -1).permute(0, 3, 1, 2)
            pos_embed = nn.functional.interpolate(pos_embed, size=(h_token, w_token), mode='bilinear', align_corners=False).permute(0, 2, 3, 1)
            pos_embed = pos_embed.view(1, N, pos_embed.size(-1))
            tokens = tokens + pos_embed

        return self.out_proj(tokens)
vora_generation_utils.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ from transformers import GenerationMixin
5
+ from transformers.cache_utils import Cache
6
+ from transformers.utils import ModelOutput
7
+
8
+
9
class VoraGenerationMixin(GenerationMixin):
    """GenerationMixin variant that tolerates a pre-built 4-D additive attention mask.

    HF generation utilities assume a 2-D ``(batch, seq)`` attention mask.  VoRA
    instead carries a fully materialized 4-D mask, so the two hooks below
    translate to/from the 2-D representation around the parent implementation.
    """

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        if attention_mask is None or attention_mask.ndim != 4:
            # Plain 2-D mask (or none): defer entirely to the parent class.
            return super().prepare_inputs_for_generation(
                input_ids,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                cache_position=cache_position,
                **kwargs,
            )

        # A key position counts as attendable iff at least one query row carries
        # an additive value of exactly 0 there (non-zero means masked out).
        mask_2d = (attention_mask[:, 0, :, :] == 0).any(dim=1).long().to(attention_mask.device)
        model_input = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=mask_2d,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )
        # Restore the full 4-D mask for the actual forward pass.
        model_input['attention_mask'] = attention_mask
        return model_input

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        mask = model_kwargs.get("attention_mask")
        if mask is None or mask.ndim != 4:
            return super()._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder, num_new_tokens=num_new_tokens
            )

        # Hide the 4-D mask from the parent (which would try to pad a 2-D one).
        del model_kwargs["attention_mask"]
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder, num_new_tokens=num_new_tokens
        )

        bs, _, seq_len, tgt_len = mask.shape
        min_dtype = torch.finfo(mask.dtype).min
        # Grow the mask by one key column (masked for all existing queries) and
        # one query row (zeros: the new token attends to everything, itself included).
        extra_col = mask.new_full((bs, 1, seq_len, 1), min_dtype)
        extra_row = mask.new_zeros((bs, 1, 1, tgt_len + 1))
        model_kwargs["attention_mask"] = torch.cat(
            [torch.cat([mask, extra_col], dim=-1), extra_row], dim=2
        )
        return model_kwargs
68
+
69
+
70
def custom_prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """Build a ``(batch, 1, sequence_length, target_length)`` additive causal mask.

    A 4-D *attention_mask* is assumed to already be in inverted/additive form
    and is merely sliced to the requested window.  Otherwise a causal mask is
    constructed from *cache_position* and combined with the optional 2-D
    padding mask (padded key positions become ``finfo(dtype).min``).
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # Pre-built additive mask: no inversion needed, just take the window.
        return attention_mask[:, :, -sequence_length:, -target_length:]

    min_dtype = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    if sequence_length != 1:
        mask = torch.triu(mask, diagonal=1)
    # Keep the mask only for key positions strictly after each query's cache slot.
    future = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    mask = (mask * future)[None, None, :, :].expand(batch_size, 1, -1, -1)

    if attention_mask is None:
        return mask

    mask = mask.clone()  # copy to contiguous memory for in-place edit
    key_len = attention_mask.shape[-1]
    # A position is fully masked iff causal part is 0 AND the padding mask is 0.
    fully_masked = (mask[:, :, :, :key_len] + attention_mask[:, None, None, :]) == 0
    mask[:, :, :, :key_len] = mask[:, :, :, :key_len].masked_fill(fully_masked, min_dtype)
    return mask
zero_to_fp32.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
@dataclass
class zero_model_state:
    """Per-rank slice of a DeepSpeed model-state checkpoint.

    Carries the metadata needed to reconstruct a consolidated fp32 state_dict:
    fp32 buffers, parameter shape groups, shared-parameter aliases, and the
    frozen-parameter shapes/fragments (may be ``None`` when nothing is frozen).
    """
    # The original annotated fields with ``dict()`` *instances*; use the types.
    buffers: dict                 # name -> fp32 buffer tensor
    param_shapes: dict            # per-group {name: shape} of trainable params
    shared_params: list           # [alias, target] name pairs
    ds_version: int               # deepspeed version that wrote the checkpoint
    frozen_param_shapes: dict     # {name: shape} of frozen params, or None
    frozen_param_fragments: dict  # {name: tensor fragment} of frozen params, or None
47
+
48
+
49
# Gate for the verbose tracing prints sprinkled through the reconstruction code;
# set to a non-zero value to enable them.
debug = 0

# load to cpu — shards are materialized on CPU to avoid GPU memory pressure
device = torch.device('cpu')
53
+
54
+
55
def atoi(text):
    """Convert *text* to int when it is purely numeric; otherwise return it unchanged."""
    if text.isdigit():
        return int(text)
    return text
57
+
58
+
59
def natural_keys(text):
    """Sort key for human ("natural") ordering: digit runs compare numerically.

    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    """
    # Split keeps the digit groups; convert those to ints so "10" sorts after "2".
    return [int(chunk) if chunk.isdigit() else chunk
            for chunk in re.split(r'(\d+)', text)]
66
+
67
+
68
def get_model_state_file(checkpoint_dir, zero_stage):
    """Return the path of the single rank-0 model-states file for this checkpoint.

    Args:
        checkpoint_dir: directory holding the per-rank checkpoint files.
        zero_stage: the ZeRO stage recorded in the checkpoint (1, 2 or 3).

    Raises:
        FileNotFoundError: if the directory or the expected file is missing.
        ValueError: if *zero_stage* is not a supported stage.
    """
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")

    # there should be only one file
    if zero_stage <= 2:
        file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
    elif zero_stage == 3:
        file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
    else:
        # Previously any other stage fell through to an UnboundLocalError on
        # `file`; fail explicitly instead.
        raise ValueError(f"unknown zero stage {zero_stage}")

    if not os.path.exists(file):
        raise FileNotFoundError(f"can't find model states file at '{file}'")

    return file
82
+
83
+
84
def get_checkpoint_files(checkpoint_dir, glob_pattern):
    """Return all files matching *glob_pattern* under *checkpoint_dir*, naturally sorted.

    Raises FileNotFoundError when nothing matches.
    """
    # XXX: need to test that this simple glob rule works for multi-node setup too
    matches = glob.glob(os.path.join(checkpoint_dir, glob_pattern))
    matches.sort(key=natural_keys)

    if not matches:
        raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")

    return matches
92
+
93
+
94
def get_optim_files(checkpoint_dir):
    """Return the naturally-sorted optimizer-state shard files (one per rank)."""
    return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
96
+
97
+
98
def get_model_state_files(checkpoint_dir):
    """Return the naturally-sorted model-state shard files (one per rank)."""
    return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
100
+
101
+
102
def parse_model_states(files):
    """Load each ``*_model_states.pt`` shard and distill it into a zero_model_state.

    Only the metadata needed for reconstruction is kept: fp32-restored buffers,
    parameter shape groups, shared-parameter aliases and frozen-parameter info.
    The trainable weights themselves live in the optimizer shards.

    Args:
        files: paths of the per-rank ``*_model_states.pt`` files.

    Returns:
        list of ``zero_model_state``, one per input file, in order.

    Raises:
        ValueError: if a file does not look like a DeepSpeed model-state checkpoint.
    """
    zero_model_states = []
    for file in files:
        state_dict = torch.load(file, map_location=device, weights_only=False)

        if BUFFER_NAMES not in state_dict:
            raise ValueError(f"{file} is not a model state checkpoint")
        buffer_names = state_dict[BUFFER_NAMES]
        if debug:
            print("Found buffers:", buffer_names)

        # recover just the buffers while restoring them to fp32 if they were saved in fp16
        buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
        param_shapes = state_dict[PARAM_SHAPES]

        frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
        if debug and frozen_param_shapes is not None:
            print(f"Found frozen_param_shapes: {frozen_param_shapes}")

        # handle shared (aliased) params, stored as [alias, target] pairs
        shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]

        ds_version = state_dict.get(DS_VERSION, None)
        frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)

        # NOTE: the original also accumulated a `param_names` list from
        # param_shapes + frozen_param_shapes, but never used it; dropped as dead code.
        z_model_state = zero_model_state(buffers=buffers,
                                         param_shapes=param_shapes,
                                         shared_params=shared_params,
                                         ds_version=ds_version,
                                         frozen_param_shapes=frozen_param_shapes,
                                         frozen_param_fragments=frozen_param_fragments)
        zero_model_states.append(z_model_state)

    return zero_model_states
146
+
147
+
148
def parse_optim_states(files, ds_checkpoint_dir):
    """Load the per-rank ``*_optim_states.pt`` shards and extract the fp32 partitions.

    Args:
        files: naturally-sorted optimizer shard paths (one per rank).
        ds_checkpoint_dir: checkpoint directory (used only in error messages).

    Returns:
        ``(zero_stage, world_size, fp32_flat_groups)`` where
        ``fp32_flat_groups[rank]`` holds that rank's flat fp32 master-weight
        partition(s).

    Raises:
        ValueError: if the shards are not a ZeRO checkpoint, the shard count
            does not match the recorded partition count, or the stage is unknown.
    """
    total_files = len(files)
    state_dicts = []
    for f in tqdm(files, desc='Loading checkpoint shards'):
        # mmap=True avoids pulling the whole shard into RAM up front
        state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
        # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
        # and also handle the case where it was already removed by another helper script
        state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
        state_dicts.append(state_dict)

    if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
        raise ValueError(f"{files[0]} is not a zero checkpoint")
    zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
    world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]

    # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
    # parameters can be different from data parallelism for non-expert parameters. So we can just
    # use the max of the partition_count to get the dp world_size.

    if type(world_size) is list:
        world_size = max(world_size)

    if world_size != total_files:
        raise ValueError(
            f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
            "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
        )

    # the groups are named differently in each stage
    if zero_stage <= 2:
        fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
    elif zero_stage == 3:
        fp32_groups_key = FP32_FLAT_GROUPS
    else:
        raise ValueError(f"unknown zero stage {zero_stage}")

    fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
    return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
    """Reconstruct the consolidated fp32 state_dict from a DeepSpeed checkpoint dir.

    Args:
        ds_checkpoint_dir: path to the deepspeed checkpoint folder (where the optimizer files are)
        exclude_frozen_parameters: skip merging frozen parameters when True
    """
    print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")

    optim_files = get_optim_files(ds_checkpoint_dir)
    zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
    print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")

    zero_model_states = parse_model_states(get_model_state_files(ds_checkpoint_dir))
    print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')

    # Dispatch on stage; parse_optim_states already rejected unknown stages.
    if zero_stage <= 2:
        return _get_fp32_state_dict_from_zero2_checkpoint(
            world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters)
    elif zero_stage == 3:
        return _get_fp32_state_dict_from_zero3_checkpoint(
            world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters)
213
+
214
+
215
def _zero2_merge_frozen_params(state_dict, zero_model_states):
    """Copy frozen (untrained) parameters into *state_dict*.

    Under ZeRO-1/2 frozen params are not partitioned, so rank 0's fragments are
    already the complete tensors.
    """
    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
    if not frozen_param_shapes:
        # Nothing frozen in this checkpoint (None or empty).
        return

    frozen_param_fragments = zero_model_states[0].frozen_param_fragments

    if debug:
        num_elem = sum(s.numel() for s in frozen_param_shapes.values())
        print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

    wanted_params = len(frozen_param_shapes)
    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
    avail_numel = sum(p.numel() for p in frozen_param_fragments.values())
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in frozen_param_shapes.items():
        total_params += 1
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        # Rank 0's fragment is the whole tensor for ZeRO-1/2 frozen params.
        state_dict[name] = frozen_param_fragments[name]

        if debug:
            print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    """Reassemble trainable parameters from ZeRO-1/2 flat partitions into *state_dict*.

    For each param group, the per-rank flat fp32 partitions are concatenated
    into one long vector, then sliced back into individual parameters using the
    recorded shapes.  Leftover elements are NCCL alignment padding, which the
    final sanity check accounts for.
    """
    # Shapes are identical on every rank; rank 0's metadata is authoritative.
    param_shapes = zero_model_states[0].param_shapes

    # Reconstruction protocol:
    #
    # XXX: document this

    if debug:
        for i in range(world_size):
            for j in range(len(fp32_flat_groups[0])):
                print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")

    # XXX: memory usage doubles here (zero2)
    num_param_groups = len(fp32_flat_groups[0])
    merged_single_partition_of_fp32_groups = []
    for i in range(num_param_groups):
        # Concatenate group i's partition from every rank into one flat vector.
        merged_partitions = [sd[i] for sd in fp32_flat_groups]
        full_single_fp32_vector = torch.cat(merged_partitions, 0)
        merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
    avail_numel = sum(
        [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])

    if debug:
        wanted_params = sum([len(shapes) for shapes in param_shapes])
        wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
        # not asserting if there is a mismatch due to possible padding
        print(f"Have {avail_numel} numels to process.")
        print(f"Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    total_numel = 0
    total_params = 0
    for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
        offset = 0
        avail_numel = full_single_fp32_vector.numel()
        for name, shape in shapes.items():

            # Shapes may be torch.Size (has .numel) or plain tuples/lists.
            unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
            total_numel += unpartitioned_numel
            total_params += 1

            if debug:
                print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
            # Carve this parameter out of the flat vector and restore its shape.
            state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
            offset += unpartitioned_numel

        # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
        # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
        # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
        # live optimizer object, so we are checking that the numbers are within the right range
        align_to = 2 * world_size

        def zero2_align(x):
            # Round x up to the nearest multiple of align_to.
            return align_to * math.ceil(x / align_to)

        if debug:
            print(f"original offset={offset}, avail_numel={avail_numel}")

        offset = zero2_align(offset)
        avail_numel = zero2_align(avail_numel)

        if debug:
            print(f"aligned offset={offset}, avail_numel={avail_numel}")

        # Sanity check
        if offset != avail_numel:
            raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                               exclude_frozen_parameters):
    """Assemble the full fp32 state_dict for a ZeRO-1/2 checkpoint."""
    state_dict = OrderedDict()

    # Buffers first (already restored to fp32 by parse_model_states).
    buffers = zero_model_states[0].buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    if not exclude_frozen_parameters:
        _zero2_merge_frozen_params(state_dict, zero_model_states)

    _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # Re-alias shared parameters onto their targets.
    for alias, target in zero_model_states[0].shared_params:
        if target in state_dict:
            state_dict[alias] = state_dict[target]

    return state_dict
346
+
347
+
348
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    """Return ``(numel stored per rank, padding elements on the last rank)``."""
    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
    leftover = unpartitioned_numel % world_size
    padding_numel = world_size - leftover if leftover else 0
    return partitioned_numel, padding_numel
353
+
354
+
355
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
    """Reassemble frozen parameters from their per-rank fragments (ZeRO-3)."""
    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
    if not frozen_param_shapes:
        # Nothing frozen in this checkpoint (None or empty).
        return

    if debug:
        for i in range(world_size):
            num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
            print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

    wanted_params = len(frozen_param_shapes)
    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
    avail_numel = sum(p.numel() for p in zero_model_states[0].frozen_param_fragments.values()) * world_size
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in frozen_param_shapes.items():
        total_params += 1
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        # Each rank holds one fragment; concatenate them, then drop the
        # alignment padding at the tail before restoring the shape.
        fragments = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
        state_dict[name] = torch.cat(fragments, 0).narrow(0, 0, unpartitioned_numel).view(shape)

        if debug:
            partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(
                unpartitioned_numel, world_size)
            print(
                f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
class GatheredTensor:
    """
    A pseudo tensor that lazily stitches a parameter back together from the
    per-rank flat partitions.  More memory efficient when there are multiple
    param groups, since nothing is materialized until ``contiguous()``.
    """

    def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
        self.flat_groups = flat_groups                  # [rank][group] -> flat tensor
        self.flat_groups_offset = flat_groups_offset    # cumulative group boundaries
        self.offset = offset                            # start of this param in the flat space
        self.partitioned_numel = partitioned_numel      # numel held per rank
        self.shape = shape
        self.dtype = self.flat_groups[0][0].dtype

    def contiguous(self):
        """
        Merge the partitioned weights from every rank into one real tensor.
        """
        end_idx = self.offset + self.partitioned_numel
        chunks = []

        for rank_groups in self.flat_groups:
            # Locate the group(s) whose flat span overlaps [offset, end_idx).
            start_group_id = None
            end_group_id = None
            for group_id in range(len(self.flat_groups_offset)):
                if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
                    start_group_id = group_id
                if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
                    end_group_id = group_id
                    break
            # Slice this rank's contribution out of each overlapping group.
            for group_id in range(start_group_id, end_group_id + 1):
                base = self.flat_groups_offset[group_id]
                lo = self.offset - base
                hi = min(end_idx, self.flat_groups_offset[group_id + 1]) - base
                chunks.append(rank_groups[group_id][lo:hi])

        # Concatenate across ranks, drop tail padding, restore the shape.
        merged = torch.cat(chunks, dim=0)
        return merged[:self.shape.numel()].view(self.shape).contiguous()
435
+
436
+
437
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    """Insert lazy GatheredTensor entries for every trainable ZeRO-3 parameter.

    Under ZeRO-3 each parameter is split element-wise across ranks, so the
    partitions must be zipped back together at each parameter boundary; actual
    materialization is deferred to GatheredTensor.contiguous().
    """
    param_shapes = zero_model_states[0].param_shapes
    avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size

    # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
    # param, re-consolidating each param, while dealing with padding if any

    # merge list of dicts, preserving order
    param_shapes = {k: v for d in param_shapes for k, v in d.items()}

    if debug:
        for i in range(world_size):
            print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")

    wanted_params = len(param_shapes)
    wanted_numel = sum(shape.numel() for shape in param_shapes.values())
    # not asserting if there is a mismatch due to possible padding
    # NOTE(review): this overwrites the per-group sum computed above with
    # `fp32_flat_groups[0].numel()`, but elsewhere (the debug print and the
    # cumsum below) fp32_flat_groups[0] is treated as a *list* of flat tensors,
    # which has no .numel() — looks like a leftover from an older single-flat-
    # tensor layout; verify against the checkpoint format actually produced.
    avail_numel = fp32_flat_groups[0].numel() * world_size
    print(f"Trainable params: Have {avail_numel} numels to process.")
    print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    offset = 0
    total_numel = 0
    total_params = 0
    # Cumulative element offsets of each param group within a rank's flat space.
    flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
    for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel
        total_params += 1
        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

        # memory efficient tensor — materialized only on .contiguous()
        tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
        state_dict[name] = tensor
        offset += partitioned_numel

    offset *= world_size

    # Sanity check
    if offset != avail_numel:
        raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                               exclude_frozen_parameters):
    """Assemble the full fp32 state_dict for a ZeRO-3 checkpoint (trainable entries are lazy)."""
    state_dict = OrderedDict()

    # Buffers first (already restored to fp32 by parse_model_states).
    buffers = zero_model_states[0].buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    if not exclude_frozen_parameters:
        _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)

    _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # Re-alias shared parameters onto their targets.
    for alias, target in zero_model_states[0].shared_params:
        if target in state_dict:
            state_dict[alias] = state_dict[target]

    return state_dict
511
+
512
+
513
def to_torch_tensor(state_dict, return_empty_tensor=False):
    """
    Convert a (possibly lazy GatheredTensor) state_dict into real torch tensors.

    Entries that alias the same underlying object stay aliased in the output.
    With ``return_empty_tensor=True`` only shape/dtype placeholders are created
    (used for memory-efficient shard planning).
    """
    torch_state_dict = {}
    first_name_for = {}  # id(tensor) -> name it was first converted under
    for name, tensor in state_dict.items():
        key = id(tensor)
        if key in first_name_for:  # shared tensors
            torch_state_dict[name] = torch_state_dict[first_name_for[key]]
            continue
        first_name_for[key] = name
        if return_empty_tensor:
            torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
        else:
            torch_state_dict[name] = tensor.contiguous()
    return torch_state_dict
531
+
532
+
533
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
                                             tag=None,
                                             exclude_frozen_parameters=False,
                                             lazy_mode=False):
    """
    Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict
    that can be loaded with ``load_state_dict()`` and used for training without
    DeepSpeed, or shared with others (e.g. via a model hub).

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will
          attempt to load the tag from the 'latest' file. e.g., ``global_step14``
        - ``exclude_frozen_parameters``: exclude frozen parameters
        - ``lazy_mode``: return a dict of pseudo tensors instead of torch tensors, which is more
          memory efficient; convert each to a torch tensor via ``.contiguous()``.

    Returns:
        - pytorch ``state_dict``

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        # do the training and checkpoint saving
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
        model = model.cpu() # move to cpu
        model.load_state_dict(state_dict)
        # submit to model hub or save the model to share with others

    In this example the ``model`` will no longer be usable in the deepspeed
    context of the same application — you will need to re-initialize the
    deepspeed engine, since ``model.load_state_dict(state_dict)`` removes all
    the deepspeed magic from it.  If you want it all done for you, use
    ``load_state_dict_from_zero_checkpoint`` instead.

    The above may not work if the application lacks sufficient free CPU memory;
    use the offline ``zero_to_fp32.py`` script saved with the checkpoint, or
    load the state_dict in lazy mode ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
        for name, lazy_tensor in state_dict.item():
            tensor = lazy_tensor.contiguous() # to cpu
            print(name, tensor)
            # del tensor to release memory if it no longer in use
    """
    if tag is None:
        # Resolve the tag from the 'latest' marker file next to the shards.
        latest_path = os.path.join(checkpoint_dir, 'latest')
        if not os.path.isfile(latest_path):
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        with open(latest_path, 'r') as fd:
            tag = fd.read().strip()

    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
    if not os.path.isdir(ds_checkpoint_dir):
        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")

    state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
    return state_dict if lazy_mode else to_torch_tensor(state_dict)
596
+
597
+
598
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
                                               output_dir,
                                               max_shard_size="5GB",
                                               safe_serialization=False,
                                               tag=None,
                                               exclude_frozen_parameters=False):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``output_dir``: directory to the pytorch fp32 state_dict output files
        - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
        - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
        - ``exclude_frozen_parameters``: exclude frozen parameters
    """

    # Dependency pre-check: fail early with an actionable hint if an optional
    # third-party dependency required by the chosen options is missing.
    if safe_serialization:
        try:
            from safetensors.torch import save_file
        except ImportError:
            print('If you want to use `safe_serialization`, please `pip install safetensors`')
            raise
    if max_shard_size is not None:
        try:
            from huggingface_hub import split_torch_state_dict_into_shards
        except ImportError:
            print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
            raise

    # Convert zero checkpoint to state_dict.
    # lazy_mode=True keeps the entries unmaterialized so each shard can be
    # converted to real torch tensors one at a time below.
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
                                                          tag,
                                                          exclude_frozen_parameters,
                                                          lazy_mode=True)

    # Shard the model if it is too big.
    weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
    if max_shard_size is not None:
        filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
        # A memory-efficient approach for sharding: plan the shard split using
        # empty (shape/dtype-only) tensors so no real weights are materialized yet.
        empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
        state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
                                                              filename_pattern=filename_pattern,
                                                              max_shard_size=max_shard_size)
    else:
        # No sharding requested: emulate the huggingface_hub split result with a
        # single-file mapping so the save loop below works unchanged.
        from collections import namedtuple
        StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
        state_dict_split = StateDictSplit(is_sharded=False,
                                          filename_to_tensors={weights_name: list(state_dict.keys())})

    # Save the model shard by shard.
    os.makedirs(output_dir, exist_ok=True)
    filename_to_tensors = state_dict_split.filename_to_tensors.items()
    for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
        shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
        # Materialize only this shard's tensors.
        shard_state_dict = to_torch_tensor(shard_state_dict)
        output_path = os.path.join(output_dir, shard_file)
        if safe_serialization:
            save_file(shard_state_dict, output_path, metadata={"format": "pt"})
        else:
            torch.save(shard_state_dict, output_path)
        # Release the memory of the current shard before materializing the next
        # one, keeping peak CPU memory bounded by a single shard.
        for tensor_name in list(shard_state_dict.keys()):
            del state_dict[tensor_name]
            del shard_state_dict[tensor_name]
        del shard_state_dict
        gc.collect()

    # Save index if sharded
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
        save_index_file = os.path.join(output_dir, save_index_file)
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)
683
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
    """
    1. Put the provided model to cpu
    2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
    3. Load it into the provided model

    Args:
        - ``model``: the model object to update
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``

    Returns:
        - ``model``: the modified model

    Make sure you have plenty of CPU memory available before you call this function. If you don't
    have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
    conveniently placed for you in the checkpoint folder.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
        # submit to model hub or save the model to share with others

    Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
    of the same application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.

    """
    # Plain strings (not f-strings): these messages have no placeholders.
    logger.info("Extracting fp32 weights")
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    logger.info("Overwriting model with fp32 weights")
    model = model.cpu()
    # strict=False: the consolidated state_dict may omit parameters that were
    # frozen or otherwise excluded from the ZeRO checkpoint.
    model.load_state_dict(state_dict, strict=False)

    return model
+
721
+
722
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint_dir",
                        type=str,
                        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
    parser.add_argument("output_dir",
                        type=str,
                        # NOTE: trailing space added — implicit string concatenation
                        # previously produced "files(e.g. ..." in --help output.
                        help="directory to the pytorch fp32 state_dict output files "
                        "(e.g. path/checkpoint-12-output/)")
    parser.add_argument(
        "--max_shard_size",
        type=str,
        default="5GB",
        # NOTE: separators added between the concatenated fragments — the help
        # text previously rendered as "of sizelower", "`5MB`We default", and
        # "instanceswithout CPU OOM".
        help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size "
        "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`). "
        "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances "
        "without CPU OOM issues.")
    parser.add_argument(
        "--safe_serialization",
        default=False,
        action='store_true',
        help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
    parser.add_argument("-t",
                        "--tag",
                        type=str,
                        default=None,
                        help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
    parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
    args = parser.parse_args()

    # Mirror the CLI flag into the module-level `debug` switch — presumably read
    # by helper functions earlier in this file; confirm against their usage.
    debug = args.debug

    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
                                               args.output_dir,
                                               max_shard_size=args.max_shard_size,
                                               safe_serialization=args.safe_serialization,
                                               tag=args.tag,
                                               exclude_frozen_parameters=args.exclude_frozen_parameters)