zR commited on
Commit ·
8ba56be
1
Parent(s): 7031540
fix #350
Browse files
visual.py
CHANGED
|
@@ -6,6 +6,7 @@ from transformers.activations import ACT2FN
|
|
| 6 |
import math
|
| 7 |
from torch.nn import LayerNorm
|
| 8 |
|
|
|
|
| 9 |
def standard_attention(query_layer, key_layer, value_layer, scaling_attention_score=True):
|
| 10 |
if scaling_attention_score:
|
| 11 |
query_layer = query_layer / math.sqrt(query_layer.shape[-1])
|
|
@@ -16,11 +17,12 @@ def standard_attention(query_layer, key_layer, value_layer, scaling_attention_sc
|
|
| 16 |
context_layer = torch.matmul(attention_probs, value_layer)
|
| 17 |
return context_layer
|
| 18 |
|
|
|
|
| 19 |
def attention_fn_default(query_layer, key_layer, value_layer, scaling_attention_score=True):
|
| 20 |
if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score:
|
| 21 |
# Pytorch 2.0 attention uses very much memory if attention_mask is float, and has NaN bug if attention_mask is None.
|
| 22 |
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 23 |
-
query_layer, key_layer, value_layer,
|
| 24 |
attn_mask=None,
|
| 25 |
dropout_p=0.,
|
| 26 |
is_causal=False
|
|
@@ -31,10 +33,12 @@ def attention_fn_default(query_layer, key_layer, value_layer, scaling_attention_
|
|
| 31 |
query_layer, key_layer, value_layer, scaling_attention_score=scaling_attention_score
|
| 32 |
)
|
| 33 |
|
|
|
|
| 34 |
class PatchEmbedding(nn.Module):
|
| 35 |
def __init__(self, config):
|
| 36 |
super().__init__()
|
| 37 |
-
self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size,
|
|
|
|
| 38 |
self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
|
| 39 |
self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)
|
| 40 |
|
|
@@ -62,11 +66,11 @@ class Attention(nn.Module):
|
|
| 62 |
qkv = self.query_key_value(x)
|
| 63 |
qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, H, L, D
|
| 64 |
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 65 |
-
|
| 66 |
out = attention_fn_default(
|
| 67 |
q, k, v
|
| 68 |
)
|
| 69 |
-
output = self.dense(out.transpose(1, 2).
|
| 70 |
output = self.output_dropout(output)
|
| 71 |
return output
|
| 72 |
|
|
@@ -105,7 +109,9 @@ class TransformerLayer(nn.Module):
|
|
| 105 |
attention_output = self.input_layernorm(self.attention(attention_input))
|
| 106 |
hidden_states = attention_input + attention_output
|
| 107 |
mlp_input = hidden_states
|
| 108 |
-
|
|
|
|
|
|
|
| 109 |
output = mlp_input + mlp_output
|
| 110 |
return output
|
| 111 |
|
|
@@ -147,7 +153,8 @@ class EVA2CLIPModel(nn.Module):
|
|
| 147 |
self.patch_embedding = PatchEmbedding(vision_config)
|
| 148 |
self.transformer = Transformer(vision_config)
|
| 149 |
self.linear_proj = GLU(config, in_features=config.hidden_size)
|
| 150 |
-
self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, out_channels=config.hidden_size, kernel_size=2,
|
|
|
|
| 151 |
self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
| 152 |
self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
| 153 |
self.scaling_factor = vision_config.scaling_factor
|
|
@@ -158,14 +165,16 @@ class EVA2CLIPModel(nn.Module):
|
|
| 158 |
x = x[:, 1:]
|
| 159 |
|
| 160 |
b, s, h = x.shape
|
| 161 |
-
grid_size = int(s**0.5)
|
| 162 |
x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
|
| 163 |
x = self.conv(x)
|
| 164 |
|
| 165 |
x = x.flatten(2).transpose(1, 2)
|
| 166 |
x = self.linear_proj(x)
|
| 167 |
-
|
| 168 |
-
|
|
|
|
|
|
|
| 169 |
x = torch.cat((boi, x, eoi), dim=1)
|
| 170 |
x = x / self.scaling_factor
|
| 171 |
return x
|
|
|
|
| 6 |
import math
|
| 7 |
from torch.nn import LayerNorm
|
| 8 |
|
| 9 |
+
|
| 10 |
def standard_attention(query_layer, key_layer, value_layer, scaling_attention_score=True):
|
| 11 |
if scaling_attention_score:
|
| 12 |
query_layer = query_layer / math.sqrt(query_layer.shape[-1])
|
|
|
|
| 17 |
context_layer = torch.matmul(attention_probs, value_layer)
|
| 18 |
return context_layer
|
| 19 |
|
| 20 |
+
|
| 21 |
def attention_fn_default(query_layer, key_layer, value_layer, scaling_attention_score=True):
|
| 22 |
if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score:
|
| 23 |
# Pytorch 2.0 attention uses very much memory if attention_mask is float, and has NaN bug if attention_mask is None.
|
| 24 |
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 25 |
+
query_layer, key_layer, value_layer,
|
| 26 |
attn_mask=None,
|
| 27 |
dropout_p=0.,
|
| 28 |
is_causal=False
|
|
|
|
| 33 |
query_layer, key_layer, value_layer, scaling_attention_score=scaling_attention_score
|
| 34 |
)
|
| 35 |
|
| 36 |
+
|
| 37 |
class PatchEmbedding(nn.Module):
|
| 38 |
def __init__(self, config):
|
| 39 |
super().__init__()
|
| 40 |
+
self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size,
|
| 41 |
+
stride=config.patch_size)
|
| 42 |
self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
|
| 43 |
self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)
|
| 44 |
|
|
|
|
| 66 |
qkv = self.query_key_value(x)
|
| 67 |
qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, H, L, D
|
| 68 |
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 69 |
+
|
| 70 |
out = attention_fn_default(
|
| 71 |
q, k, v
|
| 72 |
)
|
| 73 |
+
output = self.dense(out.transpose(1, 2).view(B, L, -1))
|
| 74 |
output = self.output_dropout(output)
|
| 75 |
return output
|
| 76 |
|
|
|
|
| 109 |
attention_output = self.input_layernorm(self.attention(attention_input))
|
| 110 |
hidden_states = attention_input + attention_output
|
| 111 |
mlp_input = hidden_states
|
| 112 |
+
|
| 113 |
+
# https://github.com/THUDM/GLM-4/issues/350
|
| 114 |
+
mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)).to(mlp_input.device)
|
| 115 |
output = mlp_input + mlp_output
|
| 116 |
return output
|
| 117 |
|
|
|
|
| 153 |
self.patch_embedding = PatchEmbedding(vision_config)
|
| 154 |
self.transformer = Transformer(vision_config)
|
| 155 |
self.linear_proj = GLU(config, in_features=config.hidden_size)
|
| 156 |
+
self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, out_channels=config.hidden_size, kernel_size=2,
|
| 157 |
+
stride=2)
|
| 158 |
self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
| 159 |
self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
| 160 |
self.scaling_factor = vision_config.scaling_factor
|
|
|
|
| 165 |
x = x[:, 1:]
|
| 166 |
|
| 167 |
b, s, h = x.shape
|
| 168 |
+
grid_size = int(s ** 0.5)
|
| 169 |
x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
|
| 170 |
x = self.conv(x)
|
| 171 |
|
| 172 |
x = x.flatten(2).transpose(1, 2)
|
| 173 |
x = self.linear_proj(x)
|
| 174 |
+
|
| 175 |
+
# https://github.com/THUDM/GLM-4/issues/350
|
| 176 |
+
boi = self.boi.expand(x.shape[0], -1, -1).to(x.device)
|
| 177 |
+
eoi = self.eoi.expand(x.shape[0], -1, -1).to(x.device)
|
| 178 |
x = torch.cat((boi, x, eoi), dim=1)
|
| 179 |
x = x / self.scaling_factor
|
| 180 |
return x
|