File size: 4,932 Bytes
e6bc942
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""Protein-sequence data + MNTP masking collator.

We stream a small slice of a public protein-sequence set (UniRef-style). The collator does BERT-style 80/10/10 masking and produces
`labels` with -100 everywhere except masked positions, so MNTP loss is only
taken on masked tokens (predicted from the preceding position in the model).
"""
from __future__ import annotations

import torch
from torch.utils.data import Dataset


class ProteinSeqDataset(Dataset):
    """Map-style dataset that tokenizes raw sequence strings on access.

    Nothing is tokenized up front: each ``__getitem__`` call runs the
    tokenizer on one sequence, truncating it to ``max_length``.
    """

    # Fields of the tokenizer output that are forwarded to the collator.
    _FIELDS = ("input_ids", "attention_mask")

    def __init__(self, sequences: list[str], tokenizer, max_length: int = 256):
        self.max_length = max_length
        self.tok = tokenizer
        self.sequences = sequences

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Tokenize sequence ``idx`` and return its ids and attention mask."""
        text = self.sequences[idx]
        # return_tensors=None is passed through unchanged; padding and tensor
        # conversion are left to the collator.
        encoded = self.tok(
            text,
            truncation=True,
            max_length=self.max_length,
            return_tensors=None,
        )
        return {field: encoded[field] for field in self._FIELDS}


class MNTPCollator:
    """Pad a batch and apply 80/10/10 masking; labels=-100 except masked.

    Each example is a dict with list-valued ``input_ids`` and
    ``attention_mask``. Examples are right-padded to the longest one in the
    batch. Every attended, non-special position is selected with probability
    ``mlm_probability``; of the selected positions about 80% are replaced by
    the mask id, about 10% by a random id drawn from ``[0, len(tokenizer))``
    and the rest are left unchanged. ``labels`` holds the original id at the
    selected positions and -100 everywhere else.
    """

    def __init__(self, tokenizer, mlm_probability: float = 0.15):
        self.tok = tokenizer
        self.p = mlm_probability
        self.pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        # Fall back to the *resolved* pad id so mask_id can never be None
        # (a tokenizer may define neither a mask token nor a pad token).
        self.mask_id = (
            tokenizer.mask_token_id
            if tokenizer.mask_token_id is not None
            else self.pad_id
        )
        self.vocab_size = len(tokenizer)
        # Ids that must never be selected for masking (BOS/EOS/CLS/...).
        # Tokenizers that do not expose `all_special_ids` get no exclusions.
        special = getattr(tokenizer, "all_special_ids", None) or ()
        self.special_ids = sorted({int(i) for i in special if i is not None})

    def __call__(self, examples):
        """Collate ``examples`` into padded, masked tensors.

        Returns a dict with ``input_ids`` (corrupted), ``attention_mask`` and
        ``labels``, each a long tensor of shape ``(len(examples), maxlen)``.
        Raises ``ValueError`` (from ``max``) if ``examples`` is empty.
        """
        maxlen = max(len(e["input_ids"]) for e in examples)
        input_ids, attn = [], []
        for e in examples:
            ids = e["input_ids"]
            pad = maxlen - len(ids)
            input_ids.append(ids + [self.pad_id] * pad)
            attn.append(e["attention_mask"] + [0] * pad)
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attn = torch.tensor(attn, dtype=torch.long)

        labels = input_ids.clone()
        # float() keeps the probability tensor floating-point even when an
        # integer probability (e.g. 1) was configured.
        prob = torch.full(labels.shape, float(self.p))
        prob.masked_fill_(attn == 0, 0.0)
        # Special tokens carry no residue information; corrupting them only
        # damages the input, so they are excluded from selection.
        for special_id in self.special_ids:
            prob.masked_fill_(input_ids == special_id, 0.0)
        masked = torch.bernoulli(prob).bool()
        labels[~masked] = -100

        # 80% -> [MASK]
        repl = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked
        input_ids[repl] = self.mask_id
        # 10% -> random token (remaining of the masked set, half of them)
        rand = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked & ~repl
        input_ids[rand] = torch.randint(self.vocab_size, labels.shape, dtype=torch.long)[rand]
        # remaining 10% kept unchanged

        return {"input_ids": input_ids, "attention_mask": attn, "labels": labels}


class CleanCollator:
    """Pad a batch with NO masking — for the SimCSE stage.

    Unsupervised SimCSE (as used by LLM2Vec) trains on clean, uncorrupted
    inputs: the two positive views differ only through dropout. Feeding this
    stage MNTP-masked input was the bug in the earlier joint objective, so
    this collator only pads. It emits no `labels`, since SimCSE does not use
    them."""

    def __init__(self, tokenizer):
        pad_token = tokenizer.pad_token_id
        # Tokenizers without a pad token fall back to id 0.
        self.pad_id = 0 if pad_token is None else pad_token

    def __call__(self, examples):
        """Right-pad every example to the longest one and return long tensors."""
        width = max(len(ex["input_ids"]) for ex in examples)
        padded_ids = []
        padded_mask = []
        for ex in examples:
            gap = width - len(ex["input_ids"])
            padded_ids.append(ex["input_ids"] + [self.pad_id] * gap)
            # Padding positions are marked as not attended.
            padded_mask.append(ex["attention_mask"] + [0] * gap)
        batch = {}
        batch["input_ids"] = torch.tensor(padded_ids, dtype=torch.long)
        batch["attention_mask"] = torch.tensor(padded_mask, dtype=torch.long)
        return batch


def load_sequences(n: int, hf_dataset: str | None, hf_config: str | None,
                   text_column: str, seed: int = 0) -> list[str]:
    """Stream up to `n` protein sequences from an HF dataset, or fall back to a
    tiny synthetic set (offline/compute-node safety).

    Args:
        n: Maximum number of sequences to return. Values <= 0 yield an empty
            list without touching the network.
        hf_dataset: Dataset name passed to `datasets.load_dataset`; `None`
            selects the synthetic set directly.
        hf_config: Dataset configuration name (may be `None`).
        text_column: Key of each streamed example that holds the sequence;
            examples where it is missing or empty are skipped.
        seed: Seed for the synthetic fallback only. The streamed data is
            taken in the order the dataset yields it.

    Returns:
        The first `n` non-empty sequences of the "train" split (fewer if the
        stream ends early). If streaming fails or yields nothing, `n`
        synthetic sequences are returned instead.
    """
    # Without this guard the streaming loop appends one sequence before it
    # checks the limit, so a request for zero sequences would return one.
    if n <= 0:
        return []
    if hf_dataset is None:
        return _synthetic(n, seed)
    try:
        from datasets import load_dataset

        ds = load_dataset(hf_dataset, hf_config, split="train", streaming=True)
        seqs = []
        for ex in ds:
            s = ex.get(text_column)
            if s:
                seqs.append(s)
                if len(seqs) >= n:
                    break
        if seqs:
            return seqs
    except Exception as exc:  # pragma: no cover - network/offline guard
        # Deliberate best-effort: any failure (missing package, no network,
        # bad dataset name) degrades to synthetic data instead of crashing.
        print(f"[data] streaming failed ({exc!r}); using synthetic sequences")
    return _synthetic(n, seed)


# Alphabet the synthetic generator samples from: 20 one-letter residue codes.
_AA = "ACDEFGHIKLMNPQRSTVWY"


def _synthetic(n: int, seed: int) -> list[str]:
    """Return `n` random strings over `_AA`, reproducible for a given `seed`.

    Each string has a length drawn uniformly from [40, 220) and characters
    drawn uniformly from `_AA`, all from one seeded torch generator.
    """
    rng = torch.Generator().manual_seed(seed)
    sequences = []
    for _ in range(n):
        # Draw order (length first, then residues) fixes the random stream.
        (length,) = torch.randint(40, 220, (1,), generator=rng).tolist()
        residues = torch.randint(len(_AA), (length,), generator=rng).tolist()
        sequences.append("".join(map(_AA.__getitem__, residues)))
    return sequences