Update implementation

selective_vit.py  (+8, -9)
@@ -63,7 +63,7 @@ class SoftMaskedMultiheadAttention(nn.Module):
         nn.init.constant_(self.out_proj.bias, 0.)
 
 
-    def
+    def eager_forward(self, query, key, value, key_padding_mask=None,
                       attn_mask=None, average_attn_weights=True):
         batch_size, tgt_len, embed_dim = query.size()
         batch_size, src_len, _ = key.size()
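The renamed eager_forward is the explicit (unfused) attention path, as opposed to the fused flash_forward. A minimal, self-contained sketch of what an eager multi-head attention forward of this shape computes, assuming batch-first (B, N, C) tensors and an additive attention mask; the class and helper names below are illustrative, not the file's actual implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class EagerMHA(nn.Module):
    # Illustrative sketch only, not SoftMaskedMultiheadAttention from this file.
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def eager_forward(self, query, key, value, attn_mask=None):
        # Self-attention case (query is key is value), shape (B, N, C).
        B, N, C = query.shape
        qkv = self.qkv_proj(query).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)            # each (B, H, N, head_dim)
        attn = (q @ k.transpose(-2, -1)) / self.head_dim ** 0.5
        if attn_mask is not None:                       # additive mask, broadcast over heads
            attn = attn + attn_mask
        attn = F.softmax(attn, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.out_proj(out)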
@@ -191,9 +191,9 @@ class SoftMaskedMultiheadAttention(nn.Module):
 
         return out
 
-    def forward(self, query, key, value, method="
-        if method == '
-            out = self.
+    def forward(self, query, key, value, method="eager", **kwargs):
+        if method == 'eager':
+            out = self.eager_forward(query, key, value, **kwargs)
         elif method == "fa":
             out = self.flash_forward(query, key, value, **kwargs)
         else:
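The updated dispatcher keys the implementation on a string and forwards any extra keyword arguments to the selected path; both paths compute the same softmax(QK^T / sqrt(d))V, the "fa" branch just uses a fused kernel. A small self-contained sketch of the pattern; the error handling in the final branch is an assumption, since the hunk cuts off before it:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DispatchingAttention(nn.Module):
    # Illustrative single-head sketch of the method-keyed dispatch.
    def eager_forward(self, query, key, value, **kwargs):
        scale = query.size(-1) ** -0.5
        attn = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
        return attn @ value

    def flash_forward(self, query, key, value, **kwargs):
        # Fused scaled-dot-product attention (PyTorch >= 2.0).
        return F.scaled_dot_product_attention(query, key, value)

    def forward(self, query, key, value, method="eager", **kwargs):
        if method == "eager":
            return self.eager_forward(query, key, value, **kwargs)
        elif method == "fa":
            return self.flash_forward(query, key, value, **kwargs)
        raise ValueError(f"unknown attention method: {method!r}")

x = torch.randn(2, 197, 64)
attn = DispatchingAttention()
out_eager = attn(x, x, x, method="eager")   # explicit softmax path
out_flash = attn(x, x, x, method="fa")      # fused path; numerically equivalent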
@@ -260,7 +260,7 @@ class EncoderBlock(nn.Module):
         x = self.embed(x)
         x = self.norm1(x)
         # Apply attention mechanism
-        attn_output = self.self_attn(x, x, x, attn_mask=mask if not skip_masks else None, method="
+        attn_output = self.self_attn(x, x, x, attn_mask=mask if not skip_masks else None, method="eager")
         # Add & Norm
         x = x + self.path_drop(attn_output)
         x = self.norm2(x)
@@ -395,7 +395,7 @@ class EncoderBlock(nn.Module):
                 groups[-1][0].append(ii)
         return groups
 
-    def
+    def eager_forward(self, x, mask, full=False, skip_masks=False):
         # Step 1: Threshold the mask without in-place ops
         mask_thresholded = mask * (mask >= self.mask_threshold)
         # Step 2: Prepare output tensor (copy of x)
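EncoderBlock.eager_forward begins by zeroing mask entries below self.mask_threshold. Multiplying by the boolean comparison, rather than assigning into the tensor, avoids in-place modification so autograd can still differentiate through the surviving mask values. A short sketch; the threshold value and mask shape are assumptions:

import torch

mask = torch.rand(2, 197, requires_grad=True)   # per-token soft mask (shape assumed)
mask_threshold = 0.5                            # assumed stand-in for self.mask_threshold

# (mask >= mask_threshold) is a bool tensor; multiplying promotes it to float,
# zeroing sub-threshold entries without writing into `mask` in place.
mask_thresholded = mask * (mask >= mask_threshold)

mask_thresholded.sum().backward()               # gradients flow only through kept entries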
@@ -436,11 +436,11 @@ class EncoderBlock(nn.Module):
                 x = self.flash_forward(x, attn_mask, skip_masks)
             else:
                 warnings.warn(
-                    "Flash Attention requirements not met, falling back to
+                    "Flash Attention requirements not met, falling back to eager attention.",
                     category=UserWarning,
                     stacklevel=2,
                 )
-                x = self.
+                x = self.eager_forward(x, attn_mask, full, skip_masks)
         else:
             x = self.forward_common(x, attn_mask, skip_masks)
         return x, attn_mask
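When the Flash Attention path cannot be used, the block now warns and drops back to the renamed eager_forward. A hedged sketch of the guard-and-fallback pattern; the specific requirement checks (CUDA device, half precision, a PyTorch build with the fused kernel) are assumptions, not the file's actual conditions:

import warnings
import torch

def flash_requirements_met(x: torch.Tensor) -> bool:
    # Assumed checks; fused kernels typically need CUDA, half precision,
    # and a PyTorch build that provides scaled_dot_product_attention (>= 2.0).
    return (
        x.is_cuda
        and x.dtype in (torch.float16, torch.bfloat16)
        and hasattr(torch.nn.functional, "scaled_dot_product_attention")
    )

def run_block_attention(block, x, attn_mask, full=False, skip_masks=False):
    if flash_requirements_met(x):
        return block.flash_forward(x, attn_mask, skip_masks)
    warnings.warn(
        "Flash Attention requirements not met, falling back to eager attention.",
        category=UserWarning,
        stacklevel=2,
    )
    return block.eager_forward(x, attn_mask, full, skip_masks)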
@@ -589,7 +589,6 @@ class VisionTransformer(nn.Module):
         dis_cls_token = hidden_states[:, 1]
         dis_logits = self.dis_head(dis_cls_token)
 
-        # Inference-time averaging (same as original)
         if not self.training:
             logits = (logits + dis_logits) / 2
 
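The final hunk only removes a comment; the behaviour is unchanged: during training the classification and distillation heads are supervised separately, and at inference their logits are averaged, DeiT-style. A minimal sketch of that head pairing; module names and dimensions are illustrative, not the file's actual VisionTransformer:

import torch
import torch.nn as nn

class TwoHeadClassifier(nn.Module):
    # Illustrative DeiT-style head pair.
    def __init__(self, embed_dim=768, num_classes=1000):
        super().__init__()
        self.head = nn.Linear(embed_dim, num_classes)       # fed by the class token
        self.dis_head = nn.Linear(embed_dim, num_classes)   # fed by the distillation token

    def forward(self, cls_token, dis_cls_token):
        logits = self.head(cls_token)
        dis_logits = self.dis_head(dis_cls_token)
        if not self.training:
            # Inference-time averaging of the two heads.
            logits = (logits + dis_logits) / 2
        return logits, dis_logits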