Upload ViTForSemanticSegmentation
Browse files- modeling_vit.py +12 -4
- pytorch_model.bin +1 -1
modeling_vit.py
CHANGED
|
@@ -206,7 +206,8 @@ class ViTSelfAttention(nn.Module):
|
|
| 206 |
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
| 207 |
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 208 |
x = x.view(new_x_shape)
|
| 209 |
-
return x.permute(0, 2, 1, 3)
|
|
|
|
| 210 |
|
| 211 |
def forward(
|
| 212 |
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
|
@@ -245,10 +246,17 @@ class ViTSelfAttention(nn.Module):
|
|
| 245 |
# query_layer, key_layer, value_layer, attn_bias=xops.LowerTriangularMask(), p=self.dropout_prob
|
| 246 |
# )
|
| 247 |
|
| 248 |
-
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
-
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
|
|
|
| 252 |
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 253 |
context_layer = context_layer.view(new_context_layer_shape)
|
| 254 |
outputs = (context_layer,)
|
|
|
|
| 206 |
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
| 207 |
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 208 |
x = x.view(new_x_shape)
|
| 209 |
+
# return x.permute(0, 2, 1, 3)
|
| 210 |
+
return x
|
| 211 |
|
| 212 |
def forward(
|
| 213 |
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
|
|
|
| 246 |
# query_layer, key_layer, value_layer, attn_bias=xops.LowerTriangularMask(), p=self.dropout_prob
|
| 247 |
# )
|
| 248 |
|
| 249 |
+
# from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
|
| 250 |
+
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
|
| 251 |
+
myseq = torch.tensor([0, query_layer.shape[1]], dtype=torch.int32, device=query_layer.device)
|
| 252 |
+
# myseq = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
|
| 253 |
+
# device=qkv.device)
|
| 254 |
+
context_layer = flash_attn_unpadded_func(query_layer.squeeze(), key_layer.squeeze(), value_layer.squeeze(),
|
| 255 |
+
cu_seqlens_q=myseq, cu_seqlens_k=myseq, max_seqlen_q=query_layer.shape[1], max_seqlen_k=query_layer.shape[1],
|
| 256 |
+
dropout_p=self.dropout_prob)
|
| 257 |
|
| 258 |
+
# context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 259 |
+
context_layer = context_layer.unsqueeze(0).contiguous()
|
| 260 |
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 261 |
context_layer = context_layer.view(new_context_layer_shape)
|
| 262 |
outputs = (context_layer,)
|
pytorch_model.bin
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 345082557
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:312d9a0cddd14deb1a040a7006433acdf9a934e8ff4b9f84b85b14e0e34f610b
|
| 3 |
size 345082557
|