XAFT committed
Commit 1d0ad69 · verified · 1 Parent(s): 30e5071

Update implementation

Files changed (1)
  1. selective_vit.py +8 -9
selective_vit.py CHANGED
@@ -63,7 +63,7 @@ class SoftMaskedMultiheadAttention(nn.Module):
         nn.init.constant_(self.out_proj.bias, 0.)
 
 
-    def naive_forward(self, query, key, value, key_padding_mask=None,
+    def eager_forward(self, query, key, value, key_padding_mask=None,
                       attn_mask=None, average_attn_weights=True):
         batch_size, tgt_len, embed_dim = query.size()
         batch_size, src_len, _ = key.size()
@@ -191,9 +191,9 @@ class SoftMaskedMultiheadAttention(nn.Module):
 
         return out
 
-    def forward(self, query, key, value, method="naive", **kwargs):
-        if method == 'naive':
-            out = self.naive_forward(query, key, value, **kwargs)
+    def forward(self, query, key, value, method="eager", **kwargs):
+        if method == 'eager':
+            out = self.eager_forward(query, key, value, **kwargs)
         elif method == "fa":
             out = self.flash_forward(query, key, value, **kwargs)
         else:
@@ -260,7 +260,7 @@ class EncoderBlock(nn.Module):
         x = self.embed(x)
         x = self.norm1(x)
         # Apply attention mechanism
-        attn_output = self.self_attn(x, x, x, attn_mask=mask if not skip_masks else None, method="naive")
+        attn_output = self.self_attn(x, x, x, attn_mask=mask if not skip_masks else None, method="eager")
         # Add & Norm
         x = x + self.path_drop(attn_output)
         x = self.norm2(x)
@@ -395,7 +395,7 @@ class EncoderBlock(nn.Module):
             groups[-1][0].append(ii)
         return groups
 
-    def naive_forward(self, x, mask, full=False, skip_masks=False):
+    def eager_forward(self, x, mask, full=False, skip_masks=False):
         # Step 1: Threshold the mask without in-place ops
         mask_thresholded = mask * (mask >= self.mask_threshold)
         # Step 2: Prepare output tensor (copy of x)
@@ -436,11 +436,11 @@ class EncoderBlock(nn.Module):
                 x = self.flash_forward(x, attn_mask, skip_masks)
            else:
                warnings.warn(
-                    "Flash Attention requirements not met, falling back to naive attention.",
+                    "Flash Attention requirements not met, falling back to eager attention.",
                    category=UserWarning,
                    stacklevel=2,
                )
-                x = self.naive_forward(x, attn_mask, full, skip_masks)
+                x = self.eager_forward(x, attn_mask, full, skip_masks)
        else:
            x = self.forward_common(x, attn_mask, skip_masks)
        return x, attn_mask
@@ -589,7 +589,6 @@ class VisionTransformer(nn.Module):
         dis_cls_token = hidden_states[:, 1]
         dis_logits = self.dis_head(dis_cls_token)
 
-        # Inference-time averaging (same as original)
         if not self.training:
             logits = (logits + dis_logits) / 2
 
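
For context, a minimal standalone sketch of the eager-vs-flash distinction the new method names refer to (not code from this repository; the helper names, tensor shapes, and tolerance below are illustrative assumptions). The eager path materializes the full attention matrix explicitly, while the "fa" path defers to a fused kernel such as torch.nn.functional.scaled_dot_product_attention:

import torch
import torch.nn.functional as F

def eager_attention(q, k, v):
    # Eager path: explicitly materializes the (tgt_len, src_len) attention matrix.
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

def flash_attention(q, k, v):
    # "fa" path: fused kernel that avoids materializing the full attention matrix.
    return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(2, 12, 197, 64)  # (batch, heads, tokens, head_dim), ViT-B-like shape
assert torch.allclose(eager_attention(q, k, v), flash_attention(q, k, v), atol=1e-5)

The commit itself only renames the unfused path from "naive" to "eager", in line with the terminology PyTorch and Transformers use for the non-fused implementation; the computation is unchanged.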