Feature Extraction
PEFT
Safetensors
protein
protein-language-model
embeddings
lora
llm2vec
progen2
bidirectional
Instructions to use ratishsp/progen2-base-bidirectional-llm2vec with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use ratishsp/progen2-base-bidirectional-llm2vec with PEFT:
Task type is invalid.
- Notebooks
- Google Colab
- Kaggle
File size: 8,905 Bytes
"""Bidirectional adaptation of ProGen2 (LLM2Vec-style recipe).

ProGen2 (`hugohrban/progen2-base`, ProGenForCausalLM, GPT-J-style) applies
causality inside `ProGenAttention._attn` via a triangular `bias` buffer:

    causal_mask = self.bias[..., kq, :kl]
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask   # padding mask (kept)

Setting every layer's `bias` buffer to all-True neutralises the causal
`torch.where` while preserving the padding-mask path -> full bidirectional
attention, with no rewrite of the module's forward.

Objectives (configurable, following the proposal):
  * MNTP   - masked next-token prediction: a masked position i is predicted
             from the hidden state of the preceding token i-1.
  * SimCSE - dropout-based unsupervised contrastive (InfoNCE, in-batch negs).
  * joint  - MNTP + lambda * SimCSE in one step.
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from transformers.pytorch_utils import Conv1D
def make_bidirectional(model: nn.Module) -> int:
    """Overwrite every causal-mask `bias` buffer in `model` with all-True.

    A submodule is patched when its `bias` attribute is a 4-D boolean tensor
    (ProGenAttention.bias is a bool buffer shaped (1, 1, n_pos, n_pos)).
    Float `bias` parameters of ordinary layers do not match and are left alone.

    Returns:
        The number of modules whose `bias` was replaced (sanity check).

    Raises:
        RuntimeError: if no module carried a matching `bias` tensor.
    """

    def _looks_like_causal_mask(candidate) -> bool:
        # All three conditions must hold: tensor, boolean dtype, rank 4.
        if not isinstance(candidate, torch.Tensor):
            return False
        return candidate.dtype == torch.bool and candidate.dim() == 4

    count = 0
    for layer in model.modules():
        mask = getattr(layer, "bias", None)
        if not _looks_like_causal_mask(mask):
            continue
        # Same shape/dtype/device as the triangular mask, but every entry True.
        layer.bias = torch.ones_like(mask)
        count += 1

    if not count:
        raise RuntimeError(
            "No causal bias buffers found - ProGen internals may have changed; "
            "inspect ProGenAttention for the causal mask."
        )
    return count
def find_lora_targets(model: nn.Module) -> list[str]:
    """Collect the leaf names of projection layers to use as LoRA targets.

    Every submodule that is an `nn.Linear` (ProGen2, Llama, ...) or a `Conv1D`
    (GPT-2's c_attn/c_proj/c_fc) contributes the last component of its dotted
    module name, so the same recipe works across decoder families. Modules
    whose qualified name ends in `lm_head` are skipped: that is the output
    head, not a representation projection.

    Returns:
        The distinct leaf names, sorted alphabetically.
    """
    projection_types = (nn.Linear, Conv1D)
    leaf_names = {
        qualified.rsplit(".", 1)[-1]
        for qualified, layer in model.named_modules()
        if isinstance(layer, projection_types) and not qualified.endswith("lm_head")
    }
    return sorted(leaf_names)
def mean_pool(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
summed = (hidden * mask).sum(dim=1)
counts = mask.sum(dim=1).clamp(min=1e-9)
return summed / counts
class BidirProGen(nn.Module):
    """Wraps a (bidirectional, LoRA-wrapped) ProGen2 for MNTP / SimCSE / joint.

    Args:
        base_model: callable model invoked as
            ``base_model(input_ids=..., attention_mask=..., output_hidden_states=True)``;
            its output must expose ``.logits`` and ``.hidden_states``.
        objective: ``"mntp"``, ``"simcse"`` or ``"joint"``.
        simcse_weight: multiplier for the SimCSE term; applied only when
            ``objective == "joint"``.
        temperature: divisor applied to the similarity matrix in SimCSE.

    Raises:
        ValueError: if ``objective`` is not one of the supported names.
    """

    _OBJECTIVES = ("mntp", "simcse", "joint")

    def __init__(self, base_model, objective: str = "joint", simcse_weight: float = 0.1,
                 temperature: float = 0.05):
        super().__init__()
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if objective not in self._OBJECTIVES:
            raise ValueError(
                f"objective must be one of {self._OBJECTIVES}, got {objective!r}"
            )
        self.model = base_model
        self.objective = objective
        self.simcse_weight = simcse_weight
        self.temperature = temperature

    def _backbone(self, input_ids, attention_mask):
        """Run the wrapped model once.

        Returns:
            A ``(logits, last_hidden)`` pair, where ``last_hidden`` is the final
            entry of the model's ``hidden_states`` (expected (B, T, H)).
        """
        out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        return out.logits, out.hidden_states[-1]

    def mntp_loss(self, logits, labels):
        """Cross-entropy for masked next-token prediction.

        The logits at position i-1 are scored against the label at position i,
        so the last logit and the first label are dropped. Labels equal to
        -100 are ignored.

        Returns:
            The mean loss over non-ignored targets, or a zero that is still
            attached to the autograd graph when there are no such targets.
        """
        # Predict masked token i from position i-1: shift logits left vs labels.
        shift_logits = logits[:, :-1, :]
        shift_labels = labels[:, 1:]
        flat_logits = shift_logits.reshape(-1, shift_logits.size(-1))
        flat_labels = shift_labels.reshape(-1)
        if not bool((flat_labels != -100).any()):
            # Mean-reduced cross-entropy over zero targets is NaN, which would
            # poison the optimiser step. Return 0 while keeping a graph edge to
            # the logits so backward() still works on this batch.
            return flat_logits.sum() * 0.0
        return F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)

    def simcse_loss(self, input_ids, attention_mask):
        """Symmetric InfoNCE between two forward passes of the same batch.

        The two passes differ only through whatever stochasticity (dropout) is
        active in the model; each sequence's second view is its positive and
        the other sequences in the batch are negatives.
        """
        # Two independent dropout passes -> positive pair; in-batch negatives.
        _, h1 = self._backbone(input_ids, attention_mask)
        _, h2 = self._backbone(input_ids, attention_mask)
        z1 = F.normalize(mean_pool(h1, attention_mask), dim=-1)
        z2 = F.normalize(mean_pool(h2, attention_mask), dim=-1)
        sim = z1 @ z2.t() / self.temperature
        # Row i's positive sits on the diagonal.
        targets = torch.arange(sim.size(0), device=sim.device)
        return 0.5 * (F.cross_entropy(sim, targets) + F.cross_entropy(sim.t(), targets))

    def forward(self, input_ids, attention_mask, labels=None, **_):
        """Compute the configured objective.

        Args:
            input_ids: token ids passed to the wrapped model.
            attention_mask: padding mask passed to the wrapped model and used
                for mean pooling in SimCSE.
            labels: MNTP targets (-100 = ignore); required for the ``"mntp"``
                and ``"joint"`` objectives.
            **_: extra batch keys are accepted and ignored.

        Returns:
            ``{"loss": total_loss, "logs": {...}}`` where ``logs`` holds the
            detached ``"mntp"`` and/or ``"simcse"`` components.

        Raises:
            ValueError: if the objective needs ``labels`` and none were given.
        """
        needs_mntp = self.objective in {"mntp", "joint"}
        if needs_mntp and labels is None:
            # Fail early with a clear message, before any forward pass is spent.
            raise ValueError(
                f"objective {self.objective!r} requires `labels` for the MNTP loss"
            )

        # float32 scalar accumulator on the inputs' device.
        loss = input_ids.new_zeros((), dtype=torch.float32)
        logs = {}
        if needs_mntp:
            logits, _ = self._backbone(input_ids, attention_mask)
            l_mntp = self.mntp_loss(logits, labels)
            loss = loss + l_mntp
            logs["mntp"] = l_mntp.detach()
        if self.objective in {"simcse", "joint"}:
            l_sim = self.simcse_loss(input_ids, attention_mask)
            # simcse_weight only balances the *joint* loss; a standalone SimCSE
            # stage trains the contrastive loss at full weight.
            w = self.simcse_weight if self.objective == "joint" else 1.0
            loss = loss + w * l_sim
            logs["simcse"] = l_sim.detach()
        return {"loss": loss, "logs": logs}
def set_dropout(model: nn.Module, p: float) -> int:
    """Set the probability of every `nn.Dropout` inside `model` to `p`.

    Why this exists: LLM2Vec's SimCSE stage explicitly enables dropout
    (default 0.1) so that two forward passes over the same sequence differ;
    that dropout noise *is* the only augmentation forming the positive pair.
    Some decoders (check ProGen2!) ship with dropout=0, which would make the
    two views identical and the contrastive signal degenerate.

    Call this BEFORE LoRA-wrapping so LoRA's own dropout is left untouched.

    Returns:
        How many `nn.Dropout` modules were updated.
    """
    dropout_layers = [layer for layer in model.modules() if isinstance(layer, nn.Dropout)]
    for layer in dropout_layers:
        layer.p = p
    return len(dropout_layers)
def load_bidir_progen(model_name: str, objective: str, lora_r: int, lora_alpha: int,
                      lora_dropout: float, simcse_weight: float, temperature: float,
                      dtype: torch.dtype = torch.bfloat16,
                      init_adapter: str | None = None,
                      attn_dropout: float | None = None):
    """Load a decoder LM, make it bidirectional, attach (or resume) LoRA, wrap.

    Args:
        model_name: identifier/path handed to `AutoModelForCausalLM.from_pretrained`.
        objective: training objective forwarded to `BidirProGen`.
        lora_r, lora_alpha, lora_dropout: hyper-parameters for a FRESH LoRA.
            They are not used when `init_adapter` is given; the resumed
            adapter is loaded as saved.
        simcse_weight, temperature: forwarded to `BidirProGen`.
        dtype: dtype the base weights are loaded in.
        init_adapter: path to an existing LoRA adapter to CONTINUE training
            (this is how the SimCSE stage starts from the MNTP checkpoint, per
            LLM2Vec). When None, a fresh LoRA is initialised.
        attn_dropout: if set, force all dropout to this prob (SimCSE stage
            needs it). Applied before LoRA is attached.

    Returns:
        `(wrapped, info)` where `wrapped` is the `BidirProGen` module and
        `info` is a dict with keys `patched_layers`, `lora_targets`,
        `dropout_set` and `resumed_adapter`.
    """
    from peft import LoraConfig, get_peft_model, PeftModel

    # ProGen2's remote modeling code predates transformers>=5, so it never sets
    # the `all_tied_weights_keys` attribute that the weight loader now accesses
    # unconditionally (modeling_utils._move_missing_keys_from_meta_to_device).
    # ProGen (GPT-J-style) does not tie its lm_head, so an empty mapping is correct.
    import transformers.modeling_utils as _mu
    if "all_tied_weights_keys" not in vars(_mu.PreTrainedModel):
        _mu.PreTrainedModel.all_tied_weights_keys = {}

    # transformers 4.44.2 names this kwarg `torch_dtype` (renamed to `dtype` in 5.x);
    # passing `dtype` here would fall through to ProGenForCausalLM.__init__ and raise.
    # attn_implementation="eager" forces the attention path that reads the triangular
    # `bias` buffer make_bidirectional() flips. ProGen2 is eager-only anyway, but
    # text models (GPT-2) default to SDPA, whose causal masking ignores that buffer
    # - so without eager the bidirectional patch would silently do nothing.
    base = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True, torch_dtype=dtype,
        attn_implementation="eager",
    )

    # Order matters: patch the masks and base-model dropout first, so that the
    # dropout override never touches the dropout layers LoRA adds afterwards.
    n_patched = make_bidirectional(base)
    n_drop = set_dropout(base, attn_dropout) if attn_dropout is not None else 0

    if init_adapter is not None:
        # Resume the existing adapter (e.g. SimCSE stage continuing from MNTP).
        base = PeftModel.from_pretrained(base, init_adapter, is_trainable=True)
        saved_targets = base.peft_config["default"].target_modules
        # PEFT allows `target_modules` to be a single regex string; sorting a
        # string would split it into characters, so report it as one entry.
        if isinstance(saved_targets, str):
            targets = [saved_targets]
        else:
            targets = sorted(saved_targets)
    else:
        targets = find_lora_targets(base)
        lora_cfg = LoraConfig(
            r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
            target_modules=targets, bias="none", task_type="FEATURE_EXTRACTION",
        )
        base = get_peft_model(base, lora_cfg)

    wrapped = BidirProGen(base, objective=objective, simcse_weight=simcse_weight,
                          temperature=temperature)
    return wrapped, {"patched_layers": n_patched, "lora_targets": targets,
                     "dropout_set": n_drop, "resumed_adapter": init_adapter}
|