Wi11Chan committed on
Commit
76a1ce2
·
1 Parent(s): eeac6f1

Upload ViTForSemanticSegmentation

Browse files
Files changed (2) hide show
  1. modeling_vit.py +12 -4
  2. pytorch_model.bin +1 -1
modeling_vit.py CHANGED
@@ -206,7 +206,8 @@ class ViTSelfAttention(nn.Module):
206
  def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
207
  new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
208
  x = x.view(new_x_shape)
209
- return x.permute(0, 2, 1, 3)
 
210
 
211
  def forward(
212
  self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
@@ -245,10 +246,17 @@ class ViTSelfAttention(nn.Module):
245
  # query_layer, key_layer, value_layer, attn_bias=xops.LowerTriangularMask(), p=self.dropout_prob
246
  # )
247
 
248
- from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
249
- context_layer = flash_attn_func(query_layer, key_layer, value_layer, dropout_p=self.dropout_prob)
 
 
 
 
 
 
250
 
251
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
 
252
  new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
253
  context_layer = context_layer.view(new_context_layer_shape)
254
  outputs = (context_layer,)
 
206
  def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
207
  new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
208
  x = x.view(new_x_shape)
209
+ # return x.permute(0, 2, 1, 3)
210
+ return x
211
 
212
  def forward(
213
  self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
 
246
  # query_layer, key_layer, value_layer, attn_bias=xops.LowerTriangularMask(), p=self.dropout_prob
247
  # )
248
 
249
+ # from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
250
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_func
251
+ myseq = torch.tensor([0, query_layer.shape[1]], dtype=torch.int32, device=query_layer.device)
252
+ # myseq = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
253
+ # device=qkv.device)
254
+ context_layer = flash_attn_unpadded_func(query_layer.squeeze(), key_layer.squeeze(), value_layer.squeeze(),
255
+ cu_seqlens_q=myseq, cu_seqlens_k=myseq, max_seqlen_q=query_layer.shape[1], max_seqlen_k=query_layer.shape[1],
256
+ dropout_p=self.dropout_prob)
257
 
258
+ # context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
259
+ context_layer = context_layer.unsqueeze(0).contiguous()
260
  new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
261
  context_layer = context_layer.view(new_context_layer_shape)
262
  outputs = (context_layer,)
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4b20e2f83b85cbd7bc4f8ee1a756bd503550dea7b76a5b0a16584263f69bbf3b
3
  size 345082557
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:312d9a0cddd14deb1a040a7006433acdf9a934e8ff4b9f84b85b14e0e34f610b
3
  size 345082557