Update implementation
Browse files- selective_vit.py +107 -41
selective_vit.py
CHANGED
|
@@ -63,7 +63,7 @@ class SoftMaskedMultiheadAttention(nn.Module):
|
|
| 63 |
nn.init.constant_(self.out_proj.bias, 0.)
|
| 64 |
|
| 65 |
|
| 66 |
-
def
|
| 67 |
attn_mask=None, average_attn_weights=True):
|
| 68 |
batch_size, tgt_len, embed_dim = query.size()
|
| 69 |
batch_size, src_len, _ = key.size()
|
|
@@ -120,7 +120,6 @@ class SoftMaskedMultiheadAttention(nn.Module):
|
|
| 120 |
cu_seq_q, cu_seq_k,
|
| 121 |
max_q, max_k,
|
| 122 |
attn_mask=None,
|
| 123 |
-
is_causal=False,
|
| 124 |
):
|
| 125 |
"""
|
| 126 |
FlashAttention-compatible soft-masked attention using varlen_attn
|
|
@@ -182,7 +181,6 @@ class SoftMaskedMultiheadAttention(nn.Module):
|
|
| 182 |
cu_seq_k=cu_seq_k,
|
| 183 |
max_q=max_q,
|
| 184 |
max_k=max_k,
|
| 185 |
-
is_causal=is_causal,
|
| 186 |
scale=scale_attn,
|
| 187 |
)
|
| 188 |
|
|
@@ -193,9 +191,9 @@ class SoftMaskedMultiheadAttention(nn.Module):
|
|
| 193 |
|
| 194 |
return out
|
| 195 |
|
| 196 |
-
def forward(self, query, key, value, method="
|
| 197 |
-
if method == '
|
| 198 |
-
out = self.
|
| 199 |
elif method == "fa":
|
| 200 |
out = self.flash_forward(query, key, value, **kwargs)
|
| 201 |
else:
|
|
@@ -262,7 +260,7 @@ class EncoderBlock(nn.Module):
|
|
| 262 |
x = self.embed(x)
|
| 263 |
x = self.norm1(x)
|
| 264 |
# Apply attention mechanism
|
| 265 |
-
attn_output = self.self_attn(x, x, x, attn_mask=mask if not skip_masks else None, method="
|
| 266 |
# Add & Norm
|
| 267 |
x = x + self.path_drop(attn_output)
|
| 268 |
x = self.norm2(x)
|
|
@@ -277,42 +275,111 @@ class EncoderBlock(nn.Module):
|
|
| 277 |
return x
|
| 278 |
|
| 279 |
def flash_forward(self, x, mask, skip_masks=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
binary_mask = mask >= self.mask_threshold
|
| 281 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
|
|
|
|
|
|
|
|
|
| 287 |
|
| 288 |
-
|
| 289 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
-
x = self.embed(x)
|
| 292 |
-
x = self.norm1(x)
|
| 293 |
-
# Apply flash attention mechanism
|
| 294 |
attn_output = self.self_attn(
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
mlp_output = self.mlp(
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
if
|
| 312 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
-
x_out = x1.clone()
|
| 315 |
-
x_out[binary_mask] = x_out[binary_mask] + x
|
| 316 |
return x_out
|
| 317 |
|
| 318 |
def get_groups(self, mask, full=False):
|
|
@@ -328,7 +395,7 @@ class EncoderBlock(nn.Module):
|
|
| 328 |
groups[-1][0].append(ii)
|
| 329 |
return groups
|
| 330 |
|
| 331 |
-
def
|
| 332 |
# Step 1: Threshold the mask without in-place ops
|
| 333 |
mask_thresholded = mask * (mask >= self.mask_threshold)
|
| 334 |
# Step 2: Prepare output tensor (copy of x)
|
|
@@ -369,11 +436,11 @@ class EncoderBlock(nn.Module):
|
|
| 369 |
x = self.flash_forward(x, attn_mask, skip_masks)
|
| 370 |
else:
|
| 371 |
warnings.warn(
|
| 372 |
-
"Flash Attention requirements not met, falling back to
|
| 373 |
category=UserWarning,
|
| 374 |
stacklevel=2,
|
| 375 |
)
|
| 376 |
-
x = self.
|
| 377 |
else:
|
| 378 |
x = self.forward_common(x, attn_mask, skip_masks)
|
| 379 |
return x, attn_mask
|
|
@@ -522,7 +589,6 @@ class VisionTransformer(nn.Module):
|
|
| 522 |
dis_cls_token = hidden_states[:, 1]
|
| 523 |
dis_logits = self.dis_head(dis_cls_token)
|
| 524 |
|
| 525 |
-
# Inference-time averaging (same as original)
|
| 526 |
if not self.training:
|
| 527 |
logits = (logits + dis_logits) / 2
|
| 528 |
|
|
|
|
| 63 |
nn.init.constant_(self.out_proj.bias, 0.)
|
| 64 |
|
| 65 |
|
| 66 |
+
def eager_forward(self, query, key, value, key_padding_mask=None,
|
| 67 |
attn_mask=None, average_attn_weights=True):
|
| 68 |
batch_size, tgt_len, embed_dim = query.size()
|
| 69 |
batch_size, src_len, _ = key.size()
|
|
|
|
| 120 |
cu_seq_q, cu_seq_k,
|
| 121 |
max_q, max_k,
|
| 122 |
attn_mask=None,
|
|
|
|
| 123 |
):
|
| 124 |
"""
|
| 125 |
FlashAttention-compatible soft-masked attention using varlen_attn
|
|
|
|
| 181 |
cu_seq_k=cu_seq_k,
|
| 182 |
max_q=max_q,
|
| 183 |
max_k=max_k,
|
|
|
|
| 184 |
scale=scale_attn,
|
| 185 |
)
|
| 186 |
|
|
|
|
| 191 |
|
| 192 |
return out
|
| 193 |
|
| 194 |
+
def forward(self, query, key, value, method="eager", **kwargs):
|
| 195 |
+
if method == 'eager':
|
| 196 |
+
out = self.eager_forward(query, key, value, **kwargs)
|
| 197 |
elif method == "fa":
|
| 198 |
out = self.flash_forward(query, key, value, **kwargs)
|
| 199 |
else:
|
|
|
|
| 260 |
x = self.embed(x)
|
| 261 |
x = self.norm1(x)
|
| 262 |
# Apply attention mechanism
|
| 263 |
+
attn_output = self.self_attn(x, x, x, attn_mask=mask if not skip_masks else None, method="eager")
|
| 264 |
# Add & Norm
|
| 265 |
x = x + self.path_drop(attn_output)
|
| 266 |
x = self.norm2(x)
|
|
|
|
| 275 |
return x
|
| 276 |
|
| 277 |
def flash_forward(self, x, mask, skip_masks=False):
    """Run the block on selected tokens through the varlen ("fa") attention path.

    Tokens are packed into a flat ``[total_tokens, C]`` tensor, passed through
    embed -> norm1 -> attention -> norm2 -> MLP -> project -> norm3, optionally
    scaled by their mask score, and added onto the input as a residual.

    One of three packing strategies is chosen from the per-sample kept counts
    (number of scores >= ``self.mask_threshold``):

    * dense: RMS kept length > 90% of N -> every token is processed;
    * selective: longest kept length > 32 -> only kept tokens are packed;
    * top-k: otherwise -> the ``k = max kept length`` highest-scoring tokens
      of every sample are packed, so all sequences have equal length.

    Args:
        x: Input tokens, unpacked as ``[B, N, C]``.
        mask: Per-token scores, indexed as ``[B, N]``.
        skip_masks: If true, no ``attn_mask`` is passed to attention and the
            block output is not scaled by the scores.

    Returns:
        Tensor shaped like ``x``: the input plus the block output scattered
        back to the processed positions. ``x`` itself is returned when no
        score reaches the threshold.
    """
    B, N, C = x.shape
    x_res = x  # residual branch

    binary_mask = mask >= self.mask_threshold

    # Nothing selected: skip the length statistics entirely.
    if not binary_mask.any():
        return x

    seq_lengths = binary_mask.sum(dim=1, dtype=torch.int32)
    rms_len = seq_lengths.float().square().mean().sqrt().item()
    max_len = seq_lengths.amax().item()

    if (rms_len / N) > 0.90:
        # Dense: process all B * N tokens; no gather/scatter needed.
        x_sel = x.flatten(0, 1)
        flat_idx = None
        sel_mask = None if skip_masks else mask.flatten(0, 1)
        cu_seqlens = torch.arange(
            0, (B + 1) * N, step=N, dtype=torch.int32, device=x.device
        )
        # Every packed sequence has length N here, so the maximum sequence
        # length handed to attention must be N, not the largest kept count
        # (which may be smaller than N).
        max_seqlen = N
    elif max_len > 32:
        # Selective: pack only tokens at or above the threshold.
        idx = binary_mask.nonzero(as_tuple=False)
        b_idx = idx[:, 0]
        t_idx = idx[:, 1]
        flat_idx = b_idx * N + t_idx

        x_sel = x[b_idx, t_idx]
        sel_mask = None if skip_masks else mask[b_idx, t_idx]

        # Cumulative kept lengths delimit each sample inside the packed tensor.
        cu_seqlens = torch.zeros(
            binary_mask.shape[0] + 1, dtype=torch.int, device=binary_mask.device
        )
        cu_seqlens[1:] = seq_lengths.cumsum(-1)
        max_seqlen = max_len
    else:
        # Top-k: short kept lengths; take k tokens per sample so the packed
        # layout is uniform while keeping the varlen interface.
        # NOTE(review): samples with fewer than k kept tokens are padded with
        # their next-highest-scoring (below-threshold) tokens, which are then
        # processed like kept ones - confirm this is intended.
        k = max_len
        top_vals, top_idx = mask.topk(k, dim=1, largest=True, sorted=False)  # [B, k]
        b_idx = torch.arange(B, device=mask.device)[:, None].expand_as(top_idx)
        flat_idx = (b_idx * N + top_idx).reshape(-1)

        gather_idx = top_idx.unsqueeze(-1).expand(-1, -1, C)  # [B, k, C]
        x_sel = x.gather(1, gather_idx).flatten(0, 1)  # [B * k, C]
        sel_mask = None if skip_masks else top_vals.flatten(0, 1)

        cu_seqlens = torch.arange(
            0, (B + 1) * k, step=k, dtype=torch.int32, device=x.device
        )
        max_seqlen = k

    # Block
    x_sel = self.embed(x_sel)
    x_sel = self.norm1(x_sel)

    attn_output = self.self_attn(
        x_sel, x_sel, x_sel,
        cu_seq_q=cu_seqlens,
        cu_seq_k=cu_seqlens,
        max_q=max_seqlen,
        max_k=max_seqlen,
        attn_mask=None if skip_masks else sel_mask,
        method="fa",
    )

    x_sel = x_sel + self.path_drop(attn_output)
    x_sel = self.norm2(x_sel)

    mlp_output = self.mlp(x_sel)
    x_sel = self.path_drop(self.project(x_sel + mlp_output))
    x_sel = self.norm3(x_sel)

    # Scale the block output by the per-token score.
    if sel_mask is not None:
        x_sel.mul_(sel_mask.unsqueeze(-1))

    # Scatter back directly into the residual output.
    if flat_idx is None:
        x_out = x_res + x_sel.view(*x_res.shape)
    else:
        flat_out = x_res.reshape(B * N, C)
        # With autograd on, accumulate into a copy. With autograd off the
        # reshape may be a view, so index_add_ can write into x's own storage.
        if torch.is_grad_enabled():
            flat_out = flat_out.clone()
        flat_out.index_add_(0, flat_idx, x_sel)
        x_out = flat_out.view(B, N, C)

    return x_out
|
| 384 |
|
| 385 |
def get_groups(self, mask, full=False):
|
|
|
|
| 395 |
groups[-1][0].append(ii)
|
| 396 |
return groups
|
| 397 |
|
| 398 |
+
def eager_forward(self, x, mask, full=False, skip_masks=False):
|
| 399 |
# Step 1: Threshold the mask without in-place ops
|
| 400 |
mask_thresholded = mask * (mask >= self.mask_threshold)
|
| 401 |
# Step 2: Prepare output tensor (copy of x)
|
|
|
|
| 436 |
x = self.flash_forward(x, attn_mask, skip_masks)
|
| 437 |
else:
|
| 438 |
warnings.warn(
|
| 439 |
+
"Flash Attention requirements not met, falling back to eager attention.",
|
| 440 |
category=UserWarning,
|
| 441 |
stacklevel=2,
|
| 442 |
)
|
| 443 |
+
x = self.eager_forward(x, attn_mask, full, skip_masks)
|
| 444 |
else:
|
| 445 |
x = self.forward_common(x, attn_mask, skip_masks)
|
| 446 |
return x, attn_mask
|
|
|
|
| 589 |
dis_cls_token = hidden_states[:, 1]
|
| 590 |
dis_logits = self.dis_head(dis_cls_token)
|
| 591 |
|
|
|
|
| 592 |
if not self.training:
|
| 593 |
logits = (logits + dis_logits) / 2
|
| 594 |
|