XAFT committed on
Commit
b5fdd99
·
verified ·
1 Parent(s): 190ac94

Update implementation

Browse files
Files changed (1) hide show
  1. selective_vit.py +107 -41
selective_vit.py CHANGED
@@ -63,7 +63,7 @@ class SoftMaskedMultiheadAttention(nn.Module):
63
  nn.init.constant_(self.out_proj.bias, 0.)
64
 
65
 
66
- def naive_forward(self, query, key, value, key_padding_mask=None,
67
  attn_mask=None, average_attn_weights=True):
68
  batch_size, tgt_len, embed_dim = query.size()
69
  batch_size, src_len, _ = key.size()
@@ -120,7 +120,6 @@ class SoftMaskedMultiheadAttention(nn.Module):
120
  cu_seq_q, cu_seq_k,
121
  max_q, max_k,
122
  attn_mask=None,
123
- is_causal=False,
124
  ):
125
  """
126
  FlashAttention-compatible soft-masked attention using varlen_attn
@@ -182,7 +181,6 @@ class SoftMaskedMultiheadAttention(nn.Module):
182
  cu_seq_k=cu_seq_k,
183
  max_q=max_q,
184
  max_k=max_k,
185
- is_causal=is_causal,
186
  scale=scale_attn,
187
  )
188
 
@@ -193,9 +191,9 @@ class SoftMaskedMultiheadAttention(nn.Module):
193
 
194
  return out
195
 
196
- def forward(self, query, key, value, method="naive", **kwargs):
197
- if method == 'naive':
198
- out = self.naive_forward(query, key, value, **kwargs)
199
  elif method == "fa":
200
  out = self.flash_forward(query, key, value, **kwargs)
201
  else:
@@ -262,7 +260,7 @@ class EncoderBlock(nn.Module):
262
  x = self.embed(x)
263
  x = self.norm1(x)
264
  # Apply attention mechanism
265
- attn_output = self.self_attn(x, x, x, attn_mask=mask if not skip_masks else None, method="naive")
266
  # Add & Norm
267
  x = x + self.path_drop(attn_output)
268
  x = self.norm2(x)
@@ -277,42 +275,111 @@ class EncoderBlock(nn.Module):
277
  return x
278
 
279
  def flash_forward(self, x, mask, skip_masks=False):
 
 
 
 
 
 
 
280
  binary_mask = mask >= self.mask_threshold
281
- sel_mask = mask[binary_mask]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
- seq_lengths = binary_mask.sum(1)
284
- cum_lengths = torch.zeros(binary_mask.shape[0]+1, dtype=torch.int, device=binary_mask.device)
285
- cum_lengths[1:] = seq_lengths.cumsum(-1)
286
- max_len = seq_lengths.amax()
 
 
 
287
 
288
- x1 = x
289
- x = x[binary_mask]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
 
291
- x = self.embed(x)
292
- x = self.norm1(x)
293
- # Apply flash attention mechanism
294
  attn_output = self.self_attn(
295
- x, x, x,
296
- cu_seq_q=cum_lengths,
297
- cu_seq_k=cum_lengths,
298
- max_q=max_len,
299
- max_k=max_len,
300
- attn_mask=sel_mask if not skip_masks else None,
301
- method="fa"
302
- )
303
- # Add & Norm
304
- x = x + self.path_drop(attn_output)
305
- x = self.norm2(x)
306
- # Feed-forward network
307
- mlp_output = self.mlp(x)
308
- # Add & Norm
309
- x = self.path_drop(self.project(x + mlp_output))
310
- x = self.norm3(x)
311
- if mask is not None:
312
- x = x * sel_mask.unsqueeze(-1)
 
 
 
 
 
 
 
 
 
 
 
313
 
314
- x_out = x1.clone()
315
- x_out[binary_mask] = x_out[binary_mask] + x
316
  return x_out
317
 
318
  def get_groups(self, mask, full=False):
@@ -328,7 +395,7 @@ class EncoderBlock(nn.Module):
328
  groups[-1][0].append(ii)
329
  return groups
330
 
331
- def naive_forward(self, x, mask, full=False, skip_masks=False):
332
  # Step 1: Threshold the mask without in-place ops
333
  mask_thresholded = mask * (mask >= self.mask_threshold)
334
  # Step 2: Prepare output tensor (copy of x)
@@ -369,11 +436,11 @@ class EncoderBlock(nn.Module):
369
  x = self.flash_forward(x, attn_mask, skip_masks)
370
  else:
371
  warnings.warn(
372
- "Flash Attention requirements not met, falling back to naive attention.",
373
  category=UserWarning,
374
  stacklevel=2,
375
  )
376
- x = self.naive_forward(x, attn_mask, full, skip_masks)
377
  else:
378
  x = self.forward_common(x, attn_mask, skip_masks)
379
  return x, attn_mask
@@ -522,7 +589,6 @@ class VisionTransformer(nn.Module):
522
  dis_cls_token = hidden_states[:, 1]
523
  dis_logits = self.dis_head(dis_cls_token)
524
 
525
- # Inference-time averaging (same as original)
526
  if not self.training:
527
  logits = (logits + dis_logits) / 2
528
 
 
63
  nn.init.constant_(self.out_proj.bias, 0.)
64
 
65
 
66
+ def eager_forward(self, query, key, value, key_padding_mask=None,
67
  attn_mask=None, average_attn_weights=True):
68
  batch_size, tgt_len, embed_dim = query.size()
69
  batch_size, src_len, _ = key.size()
 
120
  cu_seq_q, cu_seq_k,
121
  max_q, max_k,
122
  attn_mask=None,
 
123
  ):
124
  """
125
  FlashAttention-compatible soft-masked attention using varlen_attn
 
181
  cu_seq_k=cu_seq_k,
182
  max_q=max_q,
183
  max_k=max_k,
 
184
  scale=scale_attn,
185
  )
186
 
 
191
 
192
  return out
193
 
194
+ def forward(self, query, key, value, method="eager", **kwargs):
195
+ if method == 'eager':
196
+ out = self.eager_forward(query, key, value, **kwargs)
197
  elif method == "fa":
198
  out = self.flash_forward(query, key, value, **kwargs)
199
  else:
 
260
  x = self.embed(x)
261
  x = self.norm1(x)
262
  # Apply attention mechanism
263
+ attn_output = self.self_attn(x, x, x, attn_mask=mask if not skip_masks else None, method="eager")
264
  # Add & Norm
265
  x = x + self.path_drop(attn_output)
266
  x = self.norm2(x)
 
275
  return x
276
 
277
  def flash_forward(self, x, mask, skip_masks=False):
278
+ # x: [B, N, C]
279
+ # mask: [B, N]
280
+
281
+ B, N, C = x.shape
282
+
283
+ x_res = x # residual
284
+
285
  binary_mask = mask >= self.mask_threshold
286
+ seq_lengths = binary_mask.sum(dim=1, dtype=torch.int32)
287
+ mean_len = seq_lengths.float().square().mean().sqrt().item()
288
+ max_len = seq_lengths.amax().item()
289
+ min_len = seq_lengths.amin().item()
290
+
291
+ # Early exit if nothing selected
292
+ if not binary_mask.any():
293
+ return x
294
+
295
+ # Check if nonselective or topk would be easier
296
+ if ((mean_len / x.shape[1]) > 0.90):
297
+ x_sel = x.flatten(0, 1)
298
+ flat_idx = None
299
+
300
+ if not skip_masks:
301
+ sel_mask = mask.flatten(0, 1)
302
+ else:
303
+ sel_mask = None
304
 
305
+ cu_seqlens = torch.arange(0, (B + 1) * N, step=N, dtype=torch.int32, device=x.device)
306
+ elif max_len > 32:
307
+ # Regular selective model
308
+ idx = binary_mask.nonzero(as_tuple=False)
309
+ b_idx = idx[:, 0]
310
+ t_idx = idx[:, 1]
311
+ flat_idx = b_idx * N + t_idx
312
 
313
+ # Pack selected tokens
314
+ x_sel = x[b_idx, t_idx]
315
+
316
+ if not skip_masks:
317
+ sel_mask = mask[b_idx, t_idx]
318
+ else:
319
+ sel_mask = None
320
+
321
+ # cu_seqlens for varlen FA
322
+ cu_seqlens = torch.zeros(binary_mask.shape[0]+1, dtype=torch.int, device=binary_mask.device)
323
+ cu_seqlens[1:] = seq_lengths.cumsum(-1)
324
+ else:
325
+ # Small kept lengths: use top-k packing, but keep varlen FA interface
326
+ k = max_len
327
+
328
+ # topk over score/mask values
329
+ top_vals, top_idx = mask.topk(k, dim=1, largest=True, sorted=False) # [B, k]
330
+ b_idx = torch.arange(B, device=mask.device)[:, None].expand_as(top_idx)
331
+ flat_idx = (b_idx * N + top_idx).reshape(-1)
332
+
333
+ gather_idx = top_idx.unsqueeze(-1).expand(-1, -1, C) # [B, k, C]
334
+ x_top = x.gather(1, gather_idx) # [B, k, C]
335
+
336
+ # Flatten, then keep only valid entries so packed layout matches varlen FA
337
+ x_sel = x_top.flatten(0, 1)
338
+
339
+ if not skip_masks:
340
+ sel_mask = top_vals.flatten(0, 1)
341
+ else:
342
+ sel_mask = None
343
+
344
+ cu_seqlens = torch.arange(0, (B + 1) * max_len, step=max_len, dtype=torch.int32, device=x.device)
345
+
346
+ cu_seqlens = cu_seqlens.to(torch.int32)
347
+
348
+ # Block
349
+ x_sel = self.embed(x_sel)
350
+ x_sel = self.norm1(x_sel)
351
 
 
 
 
352
  attn_output = self.self_attn(
353
+ x_sel, x_sel, x_sel,
354
+ cu_seq_q=cu_seqlens,
355
+ cu_seq_k=cu_seqlens,
356
+ max_q=max_len,
357
+ max_k=max_len,
358
+ attn_mask=None if skip_masks else sel_mask,
359
+ method="fa",
360
+ )
361
+
362
+ x_sel = x_sel + self.path_drop(attn_output)
363
+ x_sel = self.norm2(x_sel)
364
+
365
+ mlp_output = self.mlp(x_sel)
366
+ x_sel = self.path_drop(self.project(x_sel + mlp_output))
367
+ x_sel = self.norm3(x_sel)
368
+
369
+ if sel_mask is not None:
370
+ x_sel.mul_(sel_mask.unsqueeze(-1))
371
+
372
+ # Scatter back directly into residual output
373
+ if flat_idx is None:
374
+ x_out = x_res + x_sel.view(*x_res.shape)
375
+ else:
376
+ B, N, C = x_res.shape
377
+ flat_out = x_res.reshape(B * N, C)
378
+ if torch.is_grad_enabled():
379
+ flat_out = flat_out.clone()
380
+ flat_out.index_add_(0, flat_idx, x_sel)
381
+ x_out = flat_out.view(B, N, C)
382
 
 
 
383
  return x_out
384
 
385
  def get_groups(self, mask, full=False):
 
395
  groups[-1][0].append(ii)
396
  return groups
397
 
398
+ def eager_forward(self, x, mask, full=False, skip_masks=False):
399
  # Step 1: Threshold the mask without in-place ops
400
  mask_thresholded = mask * (mask >= self.mask_threshold)
401
  # Step 2: Prepare output tensor (copy of x)
 
436
  x = self.flash_forward(x, attn_mask, skip_masks)
437
  else:
438
  warnings.warn(
439
+ "Flash Attention requirements not met, falling back to eager attention.",
440
  category=UserWarning,
441
  stacklevel=2,
442
  )
443
+ x = self.eager_forward(x, attn_mask, full, skip_masks)
444
  else:
445
  x = self.forward_common(x, attn_mask, skip_masks)
446
  return x, attn_mask
 
589
  dis_cls_token = hidden_states[:, 1]
590
  dis_logits = self.dis_head(dis_cls_token)
591
 
 
592
  if not self.training:
593
  logits = (logits + dis_logits) / 2
594