lhallee committed on
Commit
327bd59
·
verified ·
1 Parent(s): 9790c44

Upload modeling_fastesm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_fastesm.py +33 -3
modeling_fastesm.py CHANGED
@@ -156,6 +156,26 @@ def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[list[str]],
156
  return _collate_fn
157
 
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  class EmbeddingMixin:
160
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
161
  raise NotImplementedError
@@ -243,7 +263,7 @@ class EmbeddingMixin:
243
 
244
  def embed_dataset(
245
  self,
246
- sequences: List[str],
247
  tokenizer: Optional[PreTrainedTokenizerBase] = None,
248
  batch_size: int = 2,
249
  max_len: int = 512,
@@ -256,6 +276,7 @@ class EmbeddingMixin:
256
  save: bool = True,
257
  sql_db_path: str = 'embeddings.db',
258
  save_path: str = 'embeddings.pth',
 
259
  **kwargs,
260
  ) -> Optional[dict[str, torch.Tensor]]:
261
  """
@@ -264,7 +285,15 @@ class EmbeddingMixin:
264
  Supports two modes:
265
  - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
266
  - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.
 
 
 
267
  """
 
 
 
 
 
268
  sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
269
  sequences = sorted(sequences, key=len, reverse=True)
270
  hidden_size = self.config.hidden_size
@@ -645,8 +674,9 @@ def get_attention_mask(
645
  flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
646
  return attention_mask_2d, None, flex_block_mask
647
 
648
- # SDPA / manual
649
- attention_mask_4d = attention_mask_2d[:, None, :, None] & attention_mask_2d[:, None, None, :]
 
650
  return attention_mask_2d, attention_mask_4d, None
651
 
652
 
 
156
  return _collate_fn
157
 
158
 
159
+ def parse_fasta(fasta_path: str) -> List[str]:
160
+ assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}"
161
+ sequences = []
162
+ current_seq = []
163
+ with open(fasta_path, 'r') as f:
164
+ for line in f:
165
+ line = line.strip()
166
+ if not line:
167
+ continue
168
+ if line.startswith('>'):
169
+ if current_seq:
170
+ sequences.append(''.join(current_seq))
171
+ current_seq = []
172
+ else:
173
+ current_seq.append(line)
174
+ if current_seq:
175
+ sequences.append(''.join(current_seq))
176
+ return sequences
177
+
178
+
179
  class EmbeddingMixin:
180
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
181
  raise NotImplementedError
 
263
 
264
  def embed_dataset(
265
  self,
266
+ sequences: Optional[List[str]] = None,
267
  tokenizer: Optional[PreTrainedTokenizerBase] = None,
268
  batch_size: int = 2,
269
  max_len: int = 512,
 
276
  save: bool = True,
277
  sql_db_path: str = 'embeddings.db',
278
  save_path: str = 'embeddings.pth',
279
+ fasta_path: Optional[str] = None,
280
  **kwargs,
281
  ) -> Optional[dict[str, torch.Tensor]]:
282
  """
 
285
  Supports two modes:
286
  - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
287
  - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.
288
+
289
+ Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via
290
+ `fasta_path`, or both (the two sources are combined). At least one must be provided.
291
  """
292
+ if fasta_path is not None:
293
+ fasta_sequences = parse_fasta(fasta_path)
294
+ sequences = list(sequences or []) + fasta_sequences
295
+ assert sequences is not None and len(sequences) > 0, \
296
+ "Must provide at least one sequence via `sequences` or `fasta_path`."
297
  sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
298
  sequences = sorted(sequences, key=len, reverse=True)
299
  hidden_size = self.config.hidden_size
 
674
  flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
675
  return attention_mask_2d, None, flex_block_mask
676
 
677
+ # SDPA / manual — only mask the key dimension so padding query positions attend to
678
+ # real keys and produce valid (non-NaN) outputs instead of NaN from softmax(-inf,...,-inf).
679
+ attention_mask_4d = attention_mask_2d[:, None, None, :]
680
  return attention_mask_2d, attention_mask_4d, None
681
 
682