ratishsp's picture
Bidirectional ProGen2 LoRA adapter + 9-task benchmark + code
e6bc942
Raw
History Blame Contribute Delete
4.93 kB
"""Protein-sequence data + MNTP masking collator.
We stream a small slice of a public protein-sequence dataset (UniRef-style).
The MNTP (masked next-token prediction) collator applies BERT-style 80/10/10
masking and produces `labels` that are -100 everywhere except at masked
positions, so the MNTP loss is computed only on masked tokens (each one is
predicted by the model from the preceding position).
"""
from __future__ import annotations
import torch
from torch.utils.data import Dataset
class ProteinSeqDataset(Dataset):
    """Map-style dataset that tokenizes one protein sequence per item.

    Tokenization happens lazily on access (in ``__getitem__``) with truncation
    to ``max_length`` tokens. Only ``input_ids`` and ``attention_mask`` are
    kept from the tokenizer output. They are left unpadded and untensorized
    (``return_tensors=None``) so a collator can pad the batch. The collators
    in this module concatenate them with Python lists.
    """

    def __init__(self, sequences: list[str], tokenizer, max_length: int = 256):
        self.sequences = sequences
        self.tok = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        encoded = self.tok(
            self.sequences[idx],
            truncation=True,
            max_length=self.max_length,
            return_tensors=None,
        )
        kept_fields = ("input_ids", "attention_mask")
        return {field: encoded[field] for field in kept_fields}
class MNTPCollator:
    """Pad a batch and apply BERT-style 80/10/10 masking for MNTP.

    Each non-padding position (``attention_mask == 1``) is selected
    independently with probability ``mlm_probability``. ``labels`` holds the
    original token id at selected positions and ``-100`` everywhere else, so
    a loss with ``ignore_index=-100`` only sees selected tokens. Selected
    positions are then corrupted in ``input_ids``:

    * ~80% are replaced with ``mask_id``;
    * ~10% are replaced with a uniformly random id in ``[0, len(tokenizer))``
      (this range may include special tokens);
    * the remaining ~10% are left unchanged.

    ``labels`` is aligned with ``input_ids``; no shift is applied here.
    NOTE(review): any next-token shift is expected to happen in the
    model/loss. Confirm against the training code.
    """

    def __init__(self, tokenizer, mlm_probability: float = 0.15):
        """Store masking settings derived from ``tokenizer``.

        Raises:
            ValueError: if ``mlm_probability`` is not within [0, 1].
        """
        if not 0.0 <= mlm_probability <= 1.0:
            raise ValueError(
                f"mlm_probability must be in [0, 1], got {mlm_probability!r}"
            )
        self.tok = tokenizer
        # Cast to float: an int such as 1 would make torch.full build an
        # integer tensor, which torch.bernoulli rejects.
        self.p = float(mlm_probability)
        # Padding id; 0 when the tokenizer defines no pad token.
        self.pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        # Replacement id for the 80% branch. Fall back to the already-defaulted
        # pad id, not the raw pad_token_id. Otherwise a tokenizer with neither a
        # mask nor a pad token yields mask_id=None, and `input_ids[repl] = None`
        # raises a TypeError.
        self.mask_id = (
            tokenizer.mask_token_id
            if tokenizer.mask_token_id is not None
            else self.pad_id
        )
        self.vocab_size = len(tokenizer)

    def _pad(self, examples):
        """Right-pad ids with ``pad_id`` and masks with 0 to the longest example.

        Returns two ``(batch, maxlen)`` long tensors: ``(input_ids, attention_mask)``.
        """
        maxlen = max(len(e["input_ids"]) for e in examples)
        input_ids, attn = [], []
        for e in examples:
            pad = maxlen - len(e["input_ids"])
            input_ids.append(list(e["input_ids"]) + [self.pad_id] * pad)
            attn.append(list(e["attention_mask"]) + [0] * pad)
        return (
            torch.tensor(input_ids, dtype=torch.long),
            torch.tensor(attn, dtype=torch.long),
        )

    def __call__(self, examples):
        """Collate ``examples`` into padded, masked tensors.

        Args:
            examples: sequence of dicts with ``input_ids`` and
                ``attention_mask`` token lists.

        Returns:
            Dict with ``input_ids`` (corrupted), ``attention_mask`` and
            ``labels``, each a ``(batch, maxlen)`` long tensor.
        """
        input_ids, attn = self._pad(examples)
        labels = input_ids.clone()

        # Choose positions to predict; padding is never selected.
        prob = torch.full(labels.shape, self.p)
        prob.masked_fill_(attn == 0, 0.0)
        masked = torch.bernoulli(prob).bool()
        labels[~masked] = -100

        # 80% of the selected positions -> mask_id
        repl = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked
        input_ids[repl] = self.mask_id
        # 10% -> random id (half of the selected positions not already replaced)
        rand = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked & ~repl
        input_ids[rand] = torch.randint(self.vocab_size, labels.shape, dtype=torch.long)[rand]
        # remaining 10% are left unchanged
        return {"input_ids": input_ids, "attention_mask": attn, "labels": labels}
class CleanCollator:
    """Collate a batch by right-padding only: no masking and no ``labels``.

    Intended for the SimCSE stage. In LLM2Vec-style unsupervised SimCSE the
    two positive views of a sentence differ only through dropout, so the
    contrastive stage must see clean (unmasked) tokens. Feeding it
    MNTP-masked inputs was the bug in the earlier joint objective.
    """

    def __init__(self, tokenizer):
        pad_token_id = tokenizer.pad_token_id
        # Tokenizers without a pad token are padded with id 0.
        self.pad_id = 0 if pad_token_id is None else pad_token_id

    def __call__(self, examples):
        """Pad ``input_ids`` (with ``pad_id``) and ``attention_mask`` (with 0)
        up to the longest ``input_ids`` in the batch; return long tensors."""
        longest = max(len(example["input_ids"]) for example in examples)
        # Padding amount is driven by each example's input_ids length,
        # and the same amount is applied to its attention mask.
        fills = [longest - len(example["input_ids"]) for example in examples]
        padded_ids = [
            example["input_ids"] + [self.pad_id] * fill
            for example, fill in zip(examples, fills)
        ]
        padded_mask = [
            example["attention_mask"] + [0] * fill
            for example, fill in zip(examples, fills)
        ]
        ids_tensor = torch.tensor(padded_ids, dtype=torch.long)
        mask_tensor = torch.tensor(padded_mask, dtype=torch.long)
        return {"input_ids": ids_tensor, "attention_mask": mask_tensor}
def load_sequences(n: int, hf_dataset: str | None, hf_config: str | None,
                   text_column: str, seed: int = 0) -> list[str]:
    """Return up to ``n`` protein sequences.

    Streams the ``train`` split of the Hugging Face dataset ``hf_dataset``
    (config ``hf_config``) and collects non-empty values of ``text_column``
    until ``n`` have been gathered or the stream ends. Falls back to
    ``_synthetic(n, seed)`` (offline/compute-node safety) when:

    * ``hf_dataset`` is None,
    * importing ``datasets`` or streaming raises any exception, or
    * the stream yields no non-empty ``text_column`` value.

    The last two cases print a ``[data]`` message.

    Args:
        n: number of sequences wanted; ``n <= 0`` returns ``[]``.
        hf_dataset: dataset name for ``datasets.load_dataset``, or None.
        hf_config: dataset config name passed through to ``load_dataset``.
        text_column: field of each streamed example holding the sequence.
        seed: seed for the synthetic fallback only.

    Returns:
        At most ``n`` strings. The streamed path may return fewer than ``n``
        if the dataset runs out first.
    """
    # Guard: the streaming loop checks the count only after an append, so
    # n <= 0 used to return one sequence instead of none.
    if n <= 0:
        return []
    if hf_dataset is None:
        return _synthetic(n, seed)
    try:
        from datasets import load_dataset

        ds = load_dataset(hf_dataset, hf_config, split="train", streaming=True)
        seqs: list[str] = []
        for ex in ds:
            s = ex.get(text_column)
            if s:
                seqs.append(s)
                if len(seqs) >= n:
                    break
    except Exception as exc:  # pragma: no cover - network/offline guard
        # Deliberate best-effort: any download/stream failure degrades to
        # synthetic data rather than crashing the job.
        print(f"[data] streaming failed ({exc!r}); using synthetic sequences")
        return _synthetic(n, seed)
    if seqs:
        return seqs
    # Previously silent: a wrong `text_column` (or an empty split) fell through
    # to random synthetic data with no indication.
    print(f"[data] no non-empty {text_column!r} values found in {hf_dataset!r}; "
          "using synthetic sequences")
    return _synthetic(n, seed)
_AA = "ACDEFGHIKLMNPQRSTVWY"
def _synthetic(n: int, seed: int) -> list[str]:
g = torch.Generator().manual_seed(seed)
out = []
for _ in range(n):
length = int(torch.randint(40, 220, (1,), generator=g).item())
idx = torch.randint(len(_AA), (length,), generator=g)
out.append("".join(_AA[i] for i in idx))
return out