Rihong committed on
Commit
c4f85ed
·
verified ·
1 Parent(s): 56a2e58

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. vit.py +12 -5
README.md CHANGED
@@ -63,4 +63,4 @@ hf upload Rihong/VideoChat2_Infinity_Mistral_7B_hf ./lmms_eval/baselines/infty_v
63
  ## References
64
 
65
  - [VideoChat2 (Ask-Anything)](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2)
66
- - [Infinite-Video](https://github.com/deep-spin/Infinite-Video)
 
63
  ## References
64
 
65
  - [VideoChat2 (Ask-Anything)](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat2)
66
+ - [Infinite-Video](https://github.com/deep-spin/Infinite-Video)
vit.py CHANGED
@@ -88,13 +88,20 @@ class Attention(nn.Module):
88
  qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
89
  q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
90
 
91
- q = q * self.scale
92
- attn = (q @ k.transpose(-2, -1))
 
93
 
94
- attn = attn.softmax(dim=-1)
95
- attn = self.attn_drop(attn)
 
 
 
 
 
 
 
96
 
97
- x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
98
  x = self.proj(x)
99
  x = self.proj_drop(x)
100
  return x
 
88
  qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
89
  q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
90
 
91
+ # Memory-inefficient attention implementation, replaced by F.scaled_dot_product_attention below
92
+ # q = q * self.scale
93
+ # attn = (q @ k.transpose(-2, -1))
94
 
95
+ # attn = attn.softmax(dim=-1)
96
+ # attn = self.attn_drop(attn)
97
+
98
+ # x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
99
+
100
+ # Use F.scaled_dot_product_attention, which can dispatch to fused memory-efficient kernels such as FlashAttention (the `scale` argument requires PyTorch >= 2.1)
101
+ dropout_p = self.attn_drop.p if self.training else 0.0
102
+ x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, scale=self.scale)
103
+ x = x.transpose(1, 2).reshape(B, N, -1)
104
 
 
105
  x = self.proj(x)
106
  x = self.proj_drop(x)
107
  return x