Feature Extraction
PEFT
Safetensors
protein
protein-language-model
embeddings
lora
llm2vec
progen2
bidirectional
Instructions to use ratishsp/progen2-base-bidirectional-llm2vec with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use ratishsp/progen2-base-bidirectional-llm2vec with PEFT:
Task type is invalid.
- Notebooks
- Google Colab
- Kaggle
| """Protein-sequence data + MNTP masking collator. | |
| We stream a small slice of a public protein-sequence set (UniRef-style). The collator does BERT-style 80/10/10 masking and produces | |
| `labels` with -100 everywhere except masked positions, so MNTP loss is only | |
| taken on masked tokens (predicted from the preceding position in the model). | |
| """ | |
| from __future__ import annotations | |
| import torch | |
| from torch.utils.data import Dataset | |
class ProteinSeqDataset(Dataset):
    """Dataset that tokenizes one protein sequence per index, on demand.

    Nothing is tokenized up front: each ``__getitem__`` call runs the
    tokenizer on a single sequence. No padding is requested here, so
    batching is left to a padding collator.
    """

    def __init__(self, sequences: list[str], tokenizer, max_length: int = 256):
        """Store the raw sequences, the tokenizer, and the truncation length."""
        self.max_length = max_length
        self.tok = tokenizer
        self.sequences = sequences

    def __len__(self):
        """Return how many sequences the dataset holds."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Tokenize sequence ``idx`` and return its ids and attention mask.

        The tokenizer is called with truncation to ``max_length`` and
        ``return_tensors=None``; only ``input_ids`` and ``attention_mask``
        are forwarded, any other tokenizer output is dropped.
        """
        sequence = self.sequences[idx]
        encoded = self.tok(
            sequence,
            truncation=True,
            max_length=self.max_length,
            return_tensors=None,
        )
        return {key: encoded[key] for key in ("input_ids", "attention_mask")}
class MNTPCollator:
    """Pad a batch and apply 80/10/10 masking; labels=-100 except masked.

    Each non-padding position is selected for prediction with probability
    ``mlm_probability``. Of the selected positions, about 80% are replaced by
    the mask id, about 10% by a random token id, and the rest keep their
    original id. ``labels`` carries the original id at selected positions and
    -100 everywhere else.
    """

    def __init__(self, tokenizer, mlm_probability: float = 0.15):
        """Resolve the pad/mask ids and vocabulary size from ``tokenizer``.

        Args:
            tokenizer: Object exposing ``pad_token_id``, ``mask_token_id``
                and ``__len__`` (vocabulary size).
            mlm_probability: Per-token probability of being selected.
        """
        self.tok = tokenizer
        self.p = mlm_probability
        pad_id = tokenizer.pad_token_id
        self.pad_id = pad_id if pad_id is not None else 0
        # Fall back to the *resolved* pad id: the tokenizer's raw pad id may
        # itself be None, and assigning None into a tensor raises.
        mask_id = tokenizer.mask_token_id
        self.mask_id = mask_id if mask_id is not None else self.pad_id
        self.vocab_size = len(tokenizer)

    def __call__(self, examples):
        """Collate ``examples`` into padded, masked tensors.

        Args:
            examples: Sequence of dicts with list-valued ``input_ids`` and
                ``attention_mask``.

        Returns:
            Dict with ``input_ids`` (corrupted), ``attention_mask`` and
            ``labels``, each a ``torch.long`` tensor of shape
            ``(len(examples), longest_example)``.
        """
        maxlen = max(len(e["input_ids"]) for e in examples)
        input_ids, attn = [], []
        for e in examples:
            ids = e["input_ids"]
            pad = maxlen - len(ids)
            input_ids.append(ids + [self.pad_id] * pad)
            attn.append(e["attention_mask"] + [0] * pad)
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attn = torch.tensor(attn, dtype=torch.long)

        labels = input_ids.clone()
        # float() keeps the probability tensor floating-point even if the
        # caller passed an integer probability (bernoulli needs floats).
        prob = torch.full(labels.shape, float(self.p))
        # Padding positions must never be selected.
        prob.masked_fill_(attn == 0, 0.0)
        masked = torch.bernoulli(prob).bool()
        labels[~masked] = -100

        # 80% of selected positions -> mask id
        repl = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked
        input_ids[repl] = self.mask_id

        # Half of the remaining 20% (i.e. ~10% overall) -> random token id
        rand = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked & ~repl
        input_ids[rand] = torch.randint(self.vocab_size, labels.shape, dtype=torch.long)[rand]

        # The remaining ~10% of selected positions keep their original id.
        return {"input_ids": input_ids, "attention_mask": attn, "labels": labels}
class CleanCollator:
    """Pad a batch without any masking, for the SimCSE stage.

    Unsupervised SimCSE in LLM2Vec operates on clean (unmasked) inputs: the
    two positive views differ only through dropout, never through token
    corruption. The contrastive stage therefore must not receive
    MNTP-masked input (feeding it masked input was the bug in the earlier
    joint objective). No ``labels`` key is emitted because SimCSE does not
    use one.
    """

    def __init__(self, tokenizer):
        """Take the pad id from ``tokenizer``, defaulting to 0 when unset."""
        pad_token_id = tokenizer.pad_token_id
        self.pad_id = 0 if pad_token_id is None else pad_token_id

    def __call__(self, examples):
        """Right-pad every example to the longest one and tensorize.

        Returns a dict with ``input_ids`` and ``attention_mask`` as
        ``torch.long`` tensors; padded mask positions are 0.
        """
        width = max(len(ex["input_ids"]) for ex in examples)
        # Padding for both fields is sized from the length of input_ids.
        deficits = [width - len(ex["input_ids"]) for ex in examples]
        padded_ids = [
            ex["input_ids"] + [self.pad_id] * missing
            for ex, missing in zip(examples, deficits)
        ]
        padded_mask = [
            ex["attention_mask"] + [0] * missing
            for ex, missing in zip(examples, deficits)
        ]
        batch = {}
        batch["input_ids"] = torch.tensor(padded_ids, dtype=torch.long)
        batch["attention_mask"] = torch.tensor(padded_mask, dtype=torch.long)
        return batch
def load_sequences(n: int, hf_dataset: str | None, hf_config: str | None,
                   text_column: str, seed: int = 0) -> list[str]:
    """Stream `n` protein sequences from an HF dataset, or fall back to a tiny
    synthetic set (offline/compute-node safety).

    Args:
        n: Maximum number of sequences to return; ``n <= 0`` yields ``[]``
            without touching the network.
        hf_dataset: Dataset name passed to ``datasets.load_dataset``; ``None``
            selects the synthetic generator directly.
        hf_config: Dataset configuration name, forwarded as-is.
        text_column: Key of the sequence string in each streamed example.
        seed: Seed for the synthetic fallback.

    Returns:
        Up to ``n`` streamed sequences (fewer if the stream runs out), or
        ``n`` synthetic sequences when streaming is disabled, fails, or
        yields no usable rows.
    """
    if n <= 0:
        # Nothing requested: do not open a remote stream just to discard it.
        return []
    if hf_dataset is None:
        return _synthetic(n, seed)

    seqs: list[str] = []
    try:
        # Imported lazily so the synthetic path works without `datasets`.
        from datasets import load_dataset

        ds = load_dataset(hf_dataset, hf_config, split="train", streaming=True)
        for ex in ds:
            s = ex.get(text_column)
            if s:
                seqs.append(s)
                if len(seqs) >= n:
                    break
    except Exception as exc:  # pragma: no cover - network/offline guard
        # Deliberately broad: any streaming failure degrades to synthetic data.
        print(f"[data] streaming failed ({exc!r}); using synthetic sequences")
        return _synthetic(n, seed)

    if seqs:
        return seqs
    # Make the fallback visible; a wrong `text_column` would otherwise train
    # on random sequences without any notice.
    print(f"[data] no usable {text_column!r} values in {hf_dataset!r}; "
          "using synthetic sequences")
    return _synthetic(n, seed)
| _AA = "ACDEFGHIKLMNPQRSTVWY" | |
| def _synthetic(n: int, seed: int) -> list[str]: | |
| g = torch.Generator().manual_seed(seed) | |
| out = [] | |
| for _ in range(n): | |
| length = int(torch.randint(40, 220, (1,), generator=g).item()) | |
| idx = torch.randint(len(_AA), (length,), generator=g) | |
| out.append("".join(_AA[i] for i in idx)) | |
| return out | |