lhallee committed
Commit ace7bb4 · verified · 1 Parent(s): a1aaef3

Upload embedding_mixin.py with huggingface_hub

Files changed (1): embedding_mixin.py (+30 -1)
embedding_mixin.py CHANGED
@@ -155,6 +155,26 @@ def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[list[str]],
     return _collate_fn
 
 
+def parse_fasta(fasta_path: str) -> List[str]:
+    assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}"
+    sequences = []
+    current_seq = []
+    with open(fasta_path, 'r') as f:
+        for line in f:
+            line = line.strip()
+            if not line:
+                continue
+            if line.startswith('>'):
+                if current_seq:
+                    sequences.append(''.join(current_seq))
+                    current_seq = []
+            else:
+                current_seq.append(line)
+    if current_seq:
+        sequences.append(''.join(current_seq))
+    return sequences
+
+
 class EmbeddingMixin:
     def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         raise NotImplementedError
@@ -242,7 +262,7 @@ class EmbeddingMixin:
 
     def embed_dataset(
         self,
-        sequences: List[str],
+        sequences: Optional[List[str]] = None,
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
         batch_size: int = 2,
         max_len: int = 512,
@@ -255,6 +275,7 @@ class EmbeddingMixin:
         save: bool = True,
         sql_db_path: str = 'embeddings.db',
         save_path: str = 'embeddings.pth',
+        fasta_path: Optional[str] = None,
         **kwargs,
     ) -> Optional[dict[str, torch.Tensor]]:
         """
@@ -263,7 +284,15 @@ class EmbeddingMixin:
         Supports two modes:
         - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
        - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.
+
+        Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via
+        `fasta_path`, or both (the two sources are combined). At least one must be provided.
         """
+        if fasta_path is not None:
+            fasta_sequences = parse_fasta(fasta_path)
+            sequences = list(sequences or []) + fasta_sequences
+        assert sequences is not None and len(sequences) > 0, \
+            "Must provide at least one sequence via `sequences` or `fasta_path`."
         sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
         sequences = sorted(sequences, key=len, reverse=True)
         hidden_size = self.config.hidden_size
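
For reference, the added parse_fasta helper joins wrapped sequence lines under each '>' header and returns only the sequences; headers are discarded. A minimal usage sketch, assuming a hypothetical proteins.fasta file and that the helper is importable from this module:

# proteins.fasta (hypothetical contents):
# >seq1
# MKTAYIAKQR
# QISFVKSHFS
# >seq2
# GSHMSLF
from embedding_mixin import parse_fasta  # adjust the import to your local module layout

sequences = parse_fasta("proteins.fasta")
print(sequences)  # ['MKTAYIAKQRQISFVKSHFS', 'GSHMSLF'] (wrapped lines joined, headers dropped)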
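
And a sketch of calling the updated embed_dataset once this commit is applied. The checkpoint name and keyword values below are placeholders rather than anything the diff establishes; only the sequences/fasta_path handling shown in the hunk above comes from the commit:

from transformers import AutoModel, AutoTokenizer

# Placeholder checkpoint; use whichever model class actually mixes in EmbeddingMixin.
model = AutoModel.from_pretrained("lhallee/placeholder-model", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("lhallee/placeholder-model")

# Sequences may now come from an in-memory list, a FASTA file, or both;
# the two sources are concatenated, deduplicated, and sorted by length inside embed_dataset.
embeddings = model.embed_dataset(
    sequences=["MKTAYIAKQR", "GSHMSLF"],  # optional as of this commit
    fasta_path="proteins.fasta",          # new parameter added by this commit
    tokenizer=tokenizer,                  # tokenizer mode (ESM2/ESM++)
    batch_size=2,
    max_len=512,
)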