"""Protein-sequence data + MNTP masking collator.

We stream a small slice of a public protein-sequence set (UniRef-style).

The MNTP collator does BERT-style 80/10/10 masking and produces `labels` with
-100 everywhere except masked positions, so MNTP loss is only taken on masked
tokens (predicted from the preceding position in the model). A second,
mask-free collator is provided for the SimCSE stage.
"""

from __future__ import annotations

import torch
from torch.utils.data import Dataset


def _pad_batch(examples, pad_id: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Right-pad a list of examples to the longest one in the batch.

    Each example is a mapping with list-valued "input_ids" and
    "attention_mask". Token ids are padded with `pad_id`, attention masks
    with 0.

    Returns:
        `(input_ids, attention_mask)` as long tensors of shape
        (batch, longest sequence in the batch).

    Raises:
        ValueError: if `examples` is empty.
    """
    if not examples:
        raise ValueError("cannot collate an empty batch")
    maxlen = max(len(e["input_ids"]) for e in examples)
    input_ids, attn = [], []
    for e in examples:
        ids = e["input_ids"]
        pad = maxlen - len(ids)
        input_ids.append(ids + [pad_id] * pad)
        attn.append(e["attention_mask"] + [0] * pad)
    return (
        torch.tensor(input_ids, dtype=torch.long),
        torch.tensor(attn, dtype=torch.long),
    )


class ProteinSeqDataset(Dataset):
    """Map-style dataset that tokenizes one protein sequence per item.

    Tokenization happens lazily in `__getitem__`; items are returned unpadded
    (plain Python lists) and are padded later by a collator.
    """

    def __init__(self, sequences: list[str], tokenizer, max_length: int = 256):
        self.sequences = sequences
        self.tok = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        # return_tensors=None keeps the encoding as plain lists so the
        # collators can pad with list concatenation.
        enc = self.tok(
            self.sequences[idx],
            truncation=True,
            max_length=self.max_length,
            return_tensors=None,
        )
        return {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}


class MNTPCollator:
    """Pad a batch and apply 80/10/10 masking; labels=-100 except masked."""

    def __init__(self, tokenizer, mlm_probability: float = 0.15):
        self.tok = tokenizer
        # Coerce to float: an int probability (0 or 1) would make torch.full
        # build an integer tensor, which torch.bernoulli cannot sample from.
        self.p = float(mlm_probability)
        self.pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        # Without a dedicated mask token, the pad id doubles as the mask.
        # Use the resolved pad_id (never None) rather than the raw
        # tokenizer attribute, which may itself be missing.
        self.mask_id = (
            tokenizer.mask_token_id if tokenizer.mask_token_id is not None else self.pad_id
        )
        self.vocab_size = len(tokenizer)

    def __call__(self, examples):
        """Collate `examples` into padded, masked tensors.

        Returns a dict with "input_ids" (corrupted), "attention_mask" and
        "labels" (original ids at selected positions, -100 elsewhere).
        """
        input_ids, attn = _pad_batch(examples, self.pad_id)
        labels = input_ids.clone()

        # Select positions with probability p; padding is never selected.
        prob = torch.full(labels.shape, self.p)
        prob.masked_fill_(attn == 0, 0.0)
        masked = torch.bernoulli(prob).bool()
        labels[~masked] = -100

        # 80% -> [MASK]
        repl = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked
        input_ids[repl] = self.mask_id

        # 10% -> random token (remaining of the masked set, half of them)
        rand = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked & ~repl
        input_ids[rand] = torch.randint(self.vocab_size, labels.shape, dtype=torch.long)[rand]

        # remaining 10% kept unchanged
        return {"input_ids": input_ids, "attention_mask": attn, "labels": labels}


class CleanCollator:
    """Pad a batch with NO masking — for the SimCSE stage.

    LLM2Vec's unsupervised SimCSE runs on clean (unmasked) sentences; the two
    positive views come purely from dropout, not from corruption. So the
    contrastive stage must NOT see MNTP-masked input (that was the bug in our
    joint objective). No `labels` are produced — SimCSE doesn't use them.
    """

    def __init__(self, tokenizer):
        self.pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    def __call__(self, examples):
        """Collate `examples` into padded "input_ids" and "attention_mask"."""
        input_ids, attn = _pad_batch(examples, self.pad_id)
        return {"input_ids": input_ids, "attention_mask": attn}


def load_sequences(n: int, hf_dataset: str | None, hf_config: str | None,
                   text_column: str, seed: int = 0) -> list[str]:
    """Stream `n` protein sequences from an HF dataset, or fall back to a tiny
    synthetic set (offline/compute-node safety).

    Args:
        n: number of sequences wanted; `n <= 0` yields an empty list.
        hf_dataset: dataset name passed to `datasets.load_dataset`, or None to
            go straight to synthetic sequences.
        hf_config: dataset config name passed to `datasets.load_dataset`.
        text_column: key of the sequence string in each streamed example.
        seed: seed for the synthetic fallback.

    Returns:
        Up to `n` sequences. A stream that ends early returns what it
        yielded; a stream that fails or yields nothing usable falls back to
        `n` synthetic sequences.
    """
    # Without this guard the streaming loop below appends before it checks
    # the count, so n <= 0 would return one sequence instead of none.
    if n <= 0:
        return []
    if hf_dataset is None:
        return _synthetic(n, seed)
    try:
        from datasets import load_dataset

        ds = load_dataset(hf_dataset, hf_config, split="train", streaming=True)
        seqs = []
        for ex in ds:
            s = ex.get(text_column)
            if s:  # skip rows with a missing or empty sequence
                seqs.append(s)
                if len(seqs) >= n:
                    break
        if seqs:
            return seqs
    except Exception as exc:  # pragma: no cover - network/offline guard
        print(f"[data] streaming failed ({exc!r}); using synthetic sequences")
    return _synthetic(n, seed)


# Alphabet used for the synthetic fallback sequences.
_AA = "ACDEFGHIKLMNPQRSTVWY"


def _synthetic(n: int, seed: int) -> list[str]:
    """Return `n` seeded random strings over `_AA`, each 40-219 chars long."""
    g = torch.Generator().manual_seed(seed)
    out = []
    for _ in range(n):
        length = int(torch.randint(40, 220, (1,), generator=g).item())
        idx = torch.randint(len(_AA), (length,), generator=g)
        # One bulk conversion to Python ints instead of indexing the string
        # with a 0-dim tensor per character.
        out.append("".join(_AA[i] for i in idx.tolist()))
    return out