Feature Extraction
PEFT
Safetensors
protein
protein-language-model
embeddings
lora
llm2vec
progen2
bidirectional
Instructions to use ratishsp/progen2-base-bidirectional-llm2vec with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use ratishsp/progen2-base-bidirectional-llm2vec with PEFT:
Task type is invalid.
- Notebooks
- Google Colab
- Kaggle
File size: 4,932 Bytes
e6bc942 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 | """Protein-sequence data + MNTP masking collator.
We stream a small slice of a public protein-sequence set (UniRef-style). The collator does BERT-style 80/10/10 masking and produces
`labels` with -100 everywhere except masked positions, so MNTP loss is only
taken on masked tokens (predicted from the preceding position in the model).
"""
from __future__ import annotations
import torch
from torch.utils.data import Dataset
class ProteinSeqDataset(Dataset):
    """Map-style dataset yielding one tokenized protein sequence per index.

    Each item is a dict with exactly two keys, ``input_ids`` and
    ``attention_mask``, taken from the tokenizer output. The tokenizer is
    called with ``truncation=True`` and ``max_length`` and with
    ``return_tensors=None``; padding is left to the batch collators (which
    concatenate these values with ``+``, so they are presumably plain
    Python lists — TODO confirm for non-HF tokenizers).

    Args:
        sequences: raw amino-acid strings, one per item.
        tokenizer: callable tokenizer (assumed Hugging Face-style).
        max_length: truncation length passed through to the tokenizer.
    """

    def __init__(self, sequences: list[str], tokenizer, max_length: int = 256):
        self.sequences = sequences
        self.tok = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        text = self.sequences[idx]
        encoded = self.tok(
            text,
            truncation=True,
            max_length=self.max_length,
            return_tensors=None,
        )
        # Drop any extra tokenizer outputs; collators only need these two.
        wanted = ("input_ids", "attention_mask")
        return {key: encoded[key] for key in wanted}
class MNTPCollator:
    """Pad a batch and apply 80/10/10 masking; labels=-100 except masked.

    Sequences are right-padded to the longest one in the batch. Among
    non-padding positions (``attention_mask != 0``), each is selected with
    probability ``mlm_probability``. Of the selected positions, roughly 80%
    are replaced by ``mask_id``, 10% by a uniformly random id in
    ``[0, len(tokenizer))`` (this range may include special-token ids), and
    10% are left unchanged. ``labels`` keep the original id at selected
    positions and are -100 everywhere else (including padding), so the MNTP
    loss is taken only on selected tokens.

    Args:
        tokenizer: object exposing ``mask_token_id``, ``pad_token_id`` and
            ``__len__`` (assumed to be a Hugging Face tokenizer).
        mlm_probability: per-token selection probability.
    """

    def __init__(self, tokenizer, mlm_probability: float = 0.15):
        self.tok = tokenizer
        self.p = mlm_probability
        self.pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        # Without a dedicated mask token, reuse the resolved pad id. Falling back
        # to the raw ``pad_token_id`` could leave mask_id as None, and assigning
        # None into a LongTensor raises.
        self.mask_id = (
            tokenizer.mask_token_id
            if tokenizer.mask_token_id is not None
            else self.pad_id
        )
        self.vocab_size = len(tokenizer)

    def __call__(self, examples):
        """Collate ``examples`` into padded, masked tensors.

        Args:
            examples: non-empty sequence of dicts whose ``input_ids`` and
                ``attention_mask`` values are lists of ints.

        Returns:
            dict with ``input_ids``, ``attention_mask`` and ``labels``, each a
            2-D ``torch.long`` tensor of shape (batch, longest length).
        """
        maxlen = max(len(e["input_ids"]) for e in examples)
        input_ids, attn = [], []
        for e in examples:
            ids = e["input_ids"]
            pad = maxlen - len(ids)
            input_ids.append(ids + [self.pad_id] * pad)
            attn.append(e["attention_mask"] + [0] * pad)
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attn = torch.tensor(attn, dtype=torch.long)

        labels = input_ids.clone()
        prob = torch.full(labels.shape, self.p)
        prob.masked_fill_(attn == 0, 0.0)  # padding is never selected
        masked = torch.bernoulli(prob).bool()
        # If nothing in the batch was selected, every label would be -100 and a
        # mean-reduced cross-entropy over zero targets is NaN. Force one real
        # position. This draws extra randomness only in this degenerate case,
        # so normal-path sampling is unchanged.
        if self.p > 0 and not bool(masked.any()):
            valid = (attn != 0).nonzero()
            if len(valid) > 0:
                pick = int(torch.randint(len(valid), (1,)).item())
                row, col = valid[pick].tolist()
                masked[row, col] = True
        labels[~masked] = -100

        # 80% -> [MASK]
        repl = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked
        input_ids[repl] = self.mask_id
        # 10% -> random token (half of the selected positions not masked above)
        rand = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked & ~repl
        input_ids[rand] = torch.randint(self.vocab_size, labels.shape, dtype=torch.long)[rand]
        # remaining 10% kept unchanged
        return {"input_ids": input_ids, "attention_mask": attn, "labels": labels}
class CleanCollator:
    """Right-pad a batch with NO masking, for the SimCSE stage.

    LLM2Vec's unsupervised SimCSE trains on clean (unmasked) sentences: the
    two positive views differ only through dropout, not through input
    corruption. The contrastive stage must therefore never receive
    MNTP-masked input (feeding it masked input was the bug in our joint
    objective). Only ``input_ids`` and ``attention_mask`` are returned;
    SimCSE needs no ``labels``.
    """

    def __init__(self, tokenizer):
        pad_token = tokenizer.pad_token_id
        self.pad_id = 0 if pad_token is None else pad_token

    def __call__(self, examples):
        width = max(len(ex["input_ids"]) for ex in examples)
        # Both lists are extended by the same amount, computed from input_ids.
        padded_ids = [
            ex["input_ids"] + [self.pad_id] * (width - len(ex["input_ids"]))
            for ex in examples
        ]
        padded_mask = [
            ex["attention_mask"] + [0] * (width - len(ex["input_ids"]))
            for ex in examples
        ]
        return {
            "input_ids": torch.tensor(padded_ids, dtype=torch.long),
            "attention_mask": torch.tensor(padded_mask, dtype=torch.long),
        }
def load_sequences(n: int, hf_dataset: str | None, hf_config: str | None,
text_column: str, seed: int = 0) -> list[str]:
"""Stream `n` protein sequences from an HF dataset, or fall back to a tiny
synthetic set (offline/compute-node safety)."""
if hf_dataset is None:
return _synthetic(n, seed)
try:
from datasets import load_dataset
ds = load_dataset(hf_dataset, hf_config, split="train", streaming=True)
seqs = []
for ex in ds:
s = ex.get(text_column)
if s:
seqs.append(s)
if len(seqs) >= n:
break
if seqs:
return seqs
except Exception as exc: # pragma: no cover - network/offline guard
print(f"[data] streaming failed ({exc!r}); using synthetic sequences")
return _synthetic(n, seed)
_AA = "ACDEFGHIKLMNPQRSTVWY"


def _synthetic(n: int, seed: int) -> list[str]:
    """Generate `n` random amino-acid strings, reproducibly from `seed`.

    Each string has a length drawn uniformly from [40, 220) and residues drawn
    uniformly from the 20 letters in `_AA`. A private torch Generator is used,
    so the global RNG is untouched. Per sequence it draws the length first,
    then the residues.
    """
    rng = torch.Generator().manual_seed(seed)

    def _one() -> str:
        size = int(torch.randint(40, 220, (1,), generator=rng).item())
        letters = torch.randint(len(_AA), (size,), generator=rng).tolist()
        return "".join(map(_AA.__getitem__, letters))

    return [_one() for _ in range(n)]
|