Upload embedding_mixin.py with huggingface_hub
Browse files — embedding_mixin.py (+30 −1)
embedding_mixin.py
CHANGED
|
@@ -155,6 +155,26 @@ def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[list[str]],
|
|
| 155 |
return _collate_fn
|
| 156 |
|
| 157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
class EmbeddingMixin:
|
| 159 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 160 |
raise NotImplementedError
|
|
@@ -242,7 +262,7 @@ class EmbeddingMixin:
|
|
| 242 |
|
| 243 |
def embed_dataset(
|
| 244 |
self,
|
| 245 |
-
sequences: List[str],
|
| 246 |
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
| 247 |
batch_size: int = 2,
|
| 248 |
max_len: int = 512,
|
|
@@ -255,6 +275,7 @@ class EmbeddingMixin:
|
|
| 255 |
save: bool = True,
|
| 256 |
sql_db_path: str = 'embeddings.db',
|
| 257 |
save_path: str = 'embeddings.pth',
|
|
|
|
| 258 |
**kwargs,
|
| 259 |
) -> Optional[dict[str, torch.Tensor]]:
|
| 260 |
"""
|
|
@@ -263,7 +284,15 @@ class EmbeddingMixin:
|
|
| 263 |
Supports two modes:
|
| 264 |
- Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
|
| 265 |
- Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.
|
|
|
|
|
|
|
|
|
|
| 266 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
|
| 268 |
sequences = sorted(sequences, key=len, reverse=True)
|
| 269 |
hidden_size = self.config.hidden_size
|
|
|
|
| 155 |
return _collate_fn
|
| 156 |
|
| 157 |
|
| 158 |
+
def parse_fasta(fasta_path: str) -> List[str]:
    """Parse a FASTA file and return the sequences it contains.

    Header lines (starting with ``>``) delimit records; every non-empty,
    non-header line between two headers is concatenated into a single
    sequence string. Blank lines anywhere in the file are ignored.

    Args:
        fasta_path: Path to an existing FASTA file.

    Returns:
        List of sequence strings, in file order.

    Raises:
        FileNotFoundError: If ``fasta_path`` does not exist.
    """
    if not os.path.exists(fasta_path):
        # Raise explicitly instead of using `assert`: asserts are stripped
        # under `python -O`, which would turn a missing file into an opaque
        # error from open() below.
        raise FileNotFoundError(f"FASTA file does not exist: {fasta_path}")
    sequences: List[str] = []
    current_seq: List[str] = []
    with open(fasta_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            if line.startswith('>'):
                # New record header: flush the sequence accumulated so far.
                if current_seq:
                    sequences.append(''.join(current_seq))
                current_seq = []
            else:
                current_seq.append(line)
    # Flush the final record (FASTA files do not end with a header line).
    if current_seq:
        sequences.append(''.join(current_seq))
    return sequences
|
| 176 |
+
|
| 177 |
+
|
| 178 |
class EmbeddingMixin:
|
| 179 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 180 |
raise NotImplementedError
|
|
|
|
| 262 |
|
| 263 |
def embed_dataset(
|
| 264 |
self,
|
| 265 |
+
sequences: Optional[List[str]] = None,
|
| 266 |
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
| 267 |
batch_size: int = 2,
|
| 268 |
max_len: int = 512,
|
|
|
|
| 275 |
save: bool = True,
|
| 276 |
sql_db_path: str = 'embeddings.db',
|
| 277 |
save_path: str = 'embeddings.pth',
|
| 278 |
+
fasta_path: Optional[str] = None,
|
| 279 |
**kwargs,
|
| 280 |
) -> Optional[dict[str, torch.Tensor]]:
|
| 281 |
"""
|
|
|
|
| 284 |
Supports two modes:
|
| 285 |
- Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
|
| 286 |
- Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.
|
| 287 |
+
|
| 288 |
+
Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via
|
| 289 |
+
`fasta_path`, or both (the two sources are combined). At least one must be provided.
|
| 290 |
"""
|
| 291 |
+
if fasta_path is not None:
|
| 292 |
+
fasta_sequences = parse_fasta(fasta_path)
|
| 293 |
+
sequences = list(sequences or []) + fasta_sequences
|
| 294 |
+
assert sequences is not None and len(sequences) > 0, \
|
| 295 |
+
"Must provide at least one sequence via `sequences` or `fasta_path`."
|
| 296 |
sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
|
| 297 |
sequences = sorted(sequences, key=len, reverse=True)
|
| 298 |
hidden_size = self.config.hidden_size
|