lhallee committed
Commit 0cd732a · verified · 1 Parent(s): ecb6f10

Upload modeling_esm_plusplus.py with huggingface_hub

Files changed (1): modeling_esm_plusplus.py +11 -3
modeling_esm_plusplus.py CHANGED

@@ -29,7 +29,12 @@ from torch.nn.attention.flex_attention import flex_attention
 from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
 from transformers.modeling_outputs import ModelOutput
 
-from .embedding_mixin import EmbeddingMixin, Pooler
+try:
+    # when used from AutoModel, these are in the same directory
+    from .embedding_mixin import EmbeddingMixin, Pooler
+except:
+    # when running from our repo, these are in the base directory
+    from embedding_mixin import EmbeddingMixin, Pooler
 
 
 def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
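
Note: a minimal sketch of the two load contexts this fallback covers. Only `embedding_mixin`, `EmbeddingMixin`, and `Pooler` come from the diff; the repo id below is hypothetical.

# Illustrative, not part of the commit: the two import contexts.
from transformers import AutoModel

# 1) Remote-code loading: transformers copies modeling_esm_plusplus.py and
#    embedding_mixin.py into a dynamic module package, so the relative
#    import `from .embedding_mixin import ...` resolves.
model = AutoModel.from_pretrained(
    "some-org/some-esm-plusplus-checkpoint",  # hypothetical repo id
    trust_remote_code=True,
)

# 2) Direct execution from a clone of the repo, e.g.
#    `python modeling_esm_plusplus.py`: the file is not inside a package,
#    the relative import raises, and the plain
#    `from embedding_mixin import ...` fallback is used instead.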
@@ -350,10 +355,10 @@ class MultiHeadAttention(nn.Module):
         )
         query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
         query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
+        scale = 1 / math.sqrt(self.d_head)
 
         if output_attentions: # Manual attention computation
-            b, h, l, d = query_BHLD.shape
-            scale = 1 / math.sqrt(d)
+            b, h, l, _ = query_BHLD.shape
             attn_bias = torch.zeros(b, h, l, l, dtype=query_BLD.dtype, device=query_BLD.device)
             if attention_mask is not None:
                 attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
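
Note: a quick self-contained check of what this hunk relies on (assumes PyTorch >= 2.1 for the `scale=` keyword; shapes are arbitrary). `F.scaled_dot_product_attention` already defaults to `1/sqrt(head_dim)`, so passing `scale` explicitly is numerically a no-op; the point is that the manual `output_attentions` path and the fused paths now share one definition.

import math
import torch
import torch.nn.functional as F

b, h, l, d_head = 2, 4, 16, 32
q, k, v = (torch.randn(b, h, l, d_head) for _ in range(3))
scale = 1 / math.sqrt(d_head)

# Explicit scale matches SDPA's default of 1/sqrt(head_dim).
assert torch.allclose(
    F.scaled_dot_product_attention(q, k, v),
    F.scaled_dot_product_attention(q, k, v, scale=scale),
)

# Manual path, mirroring the diff's output_attentions branch.
attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
assert torch.allclose(attn @ v, F.scaled_dot_product_attention(q, k, v), atol=1e-5)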
@@ -377,6 +382,7 @@ class MultiHeadAttention(nn.Module):
                     key_BHLD,
                     value_BHLD,
                     block_mask=flex_block_mask,
+                    scale=scale,
                 )
             except Exception as exc:
                 if not self._warned_flex_fallback:
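
Note: `flex_attention` also accepts a `scale` argument that defaults to `1/sqrt(head_dim)`, so threading the same value through keeps all three attention paths at an identical softmax temperature. A standalone sketch, assuming a PyTorch build where `torch.nn.attention.flex_attention` is available (>= 2.5) and running it uncompiled in eager mode:

import math
import torch
from torch.nn.attention.flex_attention import flex_attention

b, h, l, d_head = 1, 2, 8, 16
q, k, v = (torch.randn(b, h, l, d_head) for _ in range(3))
scale = 1 / math.sqrt(d_head)

# Same temperature as the SDPA call sites; block_mask omitted for brevity.
out = flex_attention(q, k, v, scale=scale)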
@@ -390,6 +396,7 @@ class MultiHeadAttention(nn.Module):
                     key_BHLD,
                     value_BHLD,
                     attn_mask=sdpa_mask,
+                    scale=scale,
                 )
         else:
             context_BHLD = F.scaled_dot_product_attention(
@@ -397,6 +404,7 @@ class MultiHeadAttention(nn.Module):
                 key_BHLD,
                 value_BHLD,
                 attn_mask=sdpa_mask,
+                scale=scale,
             )
 
         context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
 
29
  from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
30
  from transformers.modeling_outputs import ModelOutput
31
 
32
+ try:
33
+ # when used from AutoModel, these are in the same directory
34
+ from .embedding_mixin import EmbeddingMixin, Pooler
35
+ except:
36
+ # when running from our repo, these are in the base directory
37
+ from embedding_mixin import EmbeddingMixin, Pooler
38
 
39
 
40
  def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
 
355
  )
356
  query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
357
  query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
358
+ scale = 1 / math.sqrt(self.d_head)
359
 
360
  if output_attentions: # Manual attention computation
361
+ b, h, l, _ = query_BHLD.shape
 
362
  attn_bias = torch.zeros(b, h, l, l, dtype=query_BLD.dtype, device=query_BLD.device)
363
  if attention_mask is not None:
364
  attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
 
382
  key_BHLD,
383
  value_BHLD,
384
  block_mask=flex_block_mask,
385
+ scale=scale,
386
  )
387
  except Exception as exc:
388
  if not self._warned_flex_fallback:
 
396
  key_BHLD,
397
  value_BHLD,
398
  attn_mask=sdpa_mask,
399
+ scale=scale,
400
  )
401
  else:
402
  context_BHLD = F.scaled_dot_product_attention(
 
404
  key_BHLD,
405
  value_BHLD,
406
  attn_mask=sdpa_mask,
407
+ scale=scale,
408
  )
409
 
410
  context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")