# Source: "Bidirectional ProGen2 LoRA adapter + 9-task benchmark + code"
# by ratishsp (commit e6bc942, 8.91 kB).
"""Bidirectional adaptation of ProGen2 (LLM2Vec-style recipe).
ProGen2 (`hugohrban/progen2-base`, ProGenForCausalLM, GPT-J-style) applies
causality inside `ProGenAttention._attn` via a triangular `bias` buffer:
causal_mask = self.bias[..., kq, :kl]
attn_weights = torch.where(causal_mask, attn_weights, masked_bias)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask # padding mask (kept)
Setting every layer's `bias` buffer to all-True neutralises the causal
`torch.where` while preserving the padding-mask path -> full bidirectional
attention, with no rewrite of the module's forward.
Objectives (selected with the `objective` argument, following the proposal):
* MNTP - masked next-token prediction: the token at a masked position i is
predicted from the model output at the preceding position i-1.
* SimCSE - dropout-based unsupervised contrastive learning (InfoNCE with
in-batch negatives).
* joint - MNTP + lambda * SimCSE in one step (lambda is `simcse_weight`).
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from transformers.pytorch_utils import Conv1D
def make_bidirectional(model: nn.Module) -> int:
    """Overwrite every causal-mask buffer in ``model`` with an all-True mask.

    A module qualifies when its ``bias`` attribute is a 4-D ``torch.bool``
    tensor (ProGenAttention's triangular mask, shaped (1, 1, n_pos, n_pos)).
    The float ``bias`` parameters of Linear/Conv1D layers fail the dtype test
    and are left alone.

    Returns:
        The number of modules that were patched.

    Raises:
        RuntimeError: when no module qualified, i.e. the attention internals
            no longer match what this patch expects.
    """
    def _looks_like_causal_mask(candidate) -> bool:
        return (
            isinstance(candidate, torch.Tensor)
            and candidate.dtype == torch.bool
            and candidate.dim() == 4
        )

    targets = [
        mod for mod in model.modules()
        if _looks_like_causal_mask(getattr(mod, "bias", None))
    ]
    if not targets:
        raise RuntimeError(
            "No causal bias buffers found - ProGen internals may have changed; "
            "inspect ProGenAttention for the causal mask."
        )
    for mod in targets:
        # ones_like keeps shape, dtype and device. Assigning to a registered
        # buffer name goes through nn.Module.__setattr__, which stores the new
        # tensor as the same buffer.
        mod.bias = torch.ones_like(mod.bias)
    return len(targets)
def find_lora_targets(model: nn.Module) -> list[str]:
    """Collect the leaf names of projection layers to use as LoRA targets.

    Detection is by type rather than by a hard-coded name list, so renamed
    layers are still found. Both ``nn.Linear`` (ProGen2, Llama, ...) and
    ``Conv1D`` (GPT-2's c_attn/c_proj/c_fc) count, so the same recipe works on
    text decoders. Any module whose qualified name ends in ``lm_head`` is
    skipped, because it is the output head rather than a representation
    projection.

    Returns:
        The sorted, de-duplicated last path components of the matching modules.
    """
    projection_types = (nn.Linear, Conv1D)
    leaf_names = {
        qualified.rsplit(".", 1)[-1]
        for qualified, submodule in model.named_modules()
        if isinstance(submodule, projection_types)
        and not qualified.endswith("lm_head")
    }
    return sorted(leaf_names)
def mean_pool(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token states over the positions selected by ``attention_mask``.

    Args:
        hidden: per-token states; expected (B, T, H), matching how the mask is
            broadcast over the last dim and reduced over dim 1.
        attention_mask: (B, T) weights, normally 1 for real tokens and 0 for
            padding.

    Returns:
        (B, H) pooled states in ``hidden.dtype``. A row whose mask sums to zero
        pools to zeros instead of NaN.
    """
    # Accumulate in at least float32. In fp16 the 1e-9 clamp underflows to 0,
    # so an all-padding row would give 0/0 = NaN. In bf16, token counts above
    # 256 are not exactly representable. promote_types keeps float64 inputs
    # in float64.
    acc_dtype = torch.promote_types(hidden.dtype, torch.float32)
    mask = attention_mask.unsqueeze(-1).to(acc_dtype)
    summed = (hidden.to(acc_dtype) * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return (summed / counts).to(hidden.dtype)
class BidirProGen(nn.Module):
    """Training wrapper around a (bidirectional, LoRA-wrapped) ProGen2.

    Computes one of three objectives, chosen by ``objective``:

    * ``"mntp"``   - masked next-token prediction (requires ``labels``);
    * ``"simcse"`` - unsupervised dropout-based SimCSE, at full weight;
    * ``"joint"``  - ``mntp + simcse_weight * simcse``.

    Args:
        base_model: callable with ``input_ids``, ``attention_mask`` and
            ``output_hidden_states=True``; its output must expose ``.logits``
            and ``.hidden_states``.
        objective: one of ``"mntp"``, ``"simcse"``, ``"joint"``.
        simcse_weight: weight on the SimCSE term; only used for ``"joint"``.
        temperature: divisor applied to the SimCSE cosine similarities.

    Raises:
        ValueError: if ``objective`` is not a supported name.
    """

    def __init__(self, base_model, objective: str = "joint", simcse_weight: float = 0.1,
                 temperature: float = 0.05):
        super().__init__()
        # An explicit check rather than `assert`, so validation survives `python -O`.
        if objective not in {"mntp", "simcse", "joint"}:
            raise ValueError(
                f"objective must be 'mntp', 'simcse' or 'joint', got {objective!r}"
            )
        self.model = base_model
        self.objective = objective
        self.simcse_weight = simcse_weight
        self.temperature = temperature

    def _backbone(self, input_ids, attention_mask):
        """Run one forward pass and return ``(logits, last_hidden_state)``.

        Each call is a separate pass, so with dropout active two calls on the
        same batch produce different hidden states.
        """
        out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        return out.logits, out.hidden_states[-1]

    def mntp_loss(self, logits, labels):
        """Compute the masked next-token prediction loss.

        The label at position ``i`` is scored against the logits at ``i - 1``
        (logits shifted left against labels). Positions labelled ``-100`` are
        ignored.

        Returns:
            Mean cross-entropy over supervised positions. If the batch has no
            supervised position, returns a graph-connected zero; a plain
            ``"mean"`` reduction would return NaN and poison the gradients.
        """
        # Upcast: log-softmax in bf16/fp16 loses precision.
        shift_logits = logits[:, :-1, :].float()
        flat_labels = labels[:, 1:].reshape(-1)
        total = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            flat_labels,
            ignore_index=-100,
            reduction="sum",
        )
        # Same value as reduction="mean" whenever at least one label is
        # supervised; the clamp only matters for all-ignored batches.
        n_supervised = (flat_labels != -100).sum().clamp(min=1)
        return total / n_supervised

    def simcse_loss(self, input_ids, attention_mask):
        """Compute a symmetric InfoNCE loss between two dropout views of the batch.

        Two independent passes give the positive pair (row i vs row i); all
        other rows in the batch act as negatives. Dropout must be active,
        otherwise both views are identical.
        """
        _, h1 = self._backbone(input_ids, attention_mask)
        _, h2 = self._backbone(input_ids, attention_mask)
        # Similarities in float32: dividing by a small temperature amplifies
        # low-precision rounding error in the logits.
        z1 = F.normalize(mean_pool(h1, attention_mask).float(), dim=-1)
        z2 = F.normalize(mean_pool(h2, attention_mask).float(), dim=-1)
        sim = z1 @ z2.t() / self.temperature
        targets = torch.arange(sim.size(0), device=sim.device)
        return 0.5 * (F.cross_entropy(sim, targets) + F.cross_entropy(sim.t(), targets))

    def forward(self, input_ids, attention_mask, labels=None, **_):
        """Compute the configured objective.

        Extra keyword arguments are accepted and ignored.

        Returns:
            ``{"loss": scalar tensor, "logs": {name: detached component loss}}``.

        Raises:
            ValueError: if the objective includes MNTP and ``labels`` is None.
        """
        use_mntp = self.objective in {"mntp", "joint"}
        use_simcse = self.objective in {"simcse", "joint"}
        if use_mntp and labels is None:
            raise ValueError(
                f"labels are required for objective {self.objective!r}"
            )
        loss = input_ids.new_zeros((), dtype=torch.float32)
        logs = {}
        if use_mntp:
            logits, _ = self._backbone(input_ids, attention_mask)
            l_mntp = self.mntp_loss(logits, labels)
            loss = loss + l_mntp
            logs["mntp"] = l_mntp.detach()
        if use_simcse:
            l_sim = self.simcse_loss(input_ids, attention_mask)
            # simcse_weight only balances the *joint* loss; a standalone SimCSE
            # stage trains the contrastive loss at full weight.
            w = self.simcse_weight if self.objective == "joint" else 1.0
            loss = loss + w * l_sim
            logs["simcse"] = l_sim.detach()
        return {"loss": loss, "logs": logs}
def set_dropout(model: nn.Module, p: float) -> int:
    """Force every ``nn.Dropout`` in ``model`` to probability ``p``.

    LLM2Vec's SimCSE stage explicitly enables dropout (default 0.1) so the two
    forward passes of the same sequence differ. That dropout *is* the only
    augmentation forming the positive pair. Some decoders (check ProGen2!) ship
    with dropout=0, which would make the two views identical and the
    contrastive signal degenerate. Call this BEFORE LoRA-wrapping so LoRA's own
    dropout is left untouched.

    Args:
        model: module tree to update in place; matching is by type.
        p: new drop probability, in [0, 1].

    Returns:
        Number of ``nn.Dropout`` modules updated.

    Raises:
        ValueError: if ``p`` is outside [0, 1] (or NaN). Assigning ``.p``
            directly bypasses the range check in ``nn.Dropout.__init__``, so it
            is repeated here, before any module is modified.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    dropouts = [mod for mod in model.modules() if isinstance(mod, nn.Dropout)]
    for mod in dropouts:
        mod.p = p
    return len(dropouts)
def load_bidir_progen(model_name: str, objective: str, lora_r: int, lora_alpha: int,
                      lora_dropout: float, simcse_weight: float, temperature: float,
                      dtype: torch.dtype = torch.bfloat16,
                      init_adapter: str | None = None,
                      attn_dropout: float | None = None):
    """Load a decoder LM, make it bidirectional, attach (or resume) LoRA, wrap.

    Args:
        model_name: checkpoint passed to ``AutoModelForCausalLM.from_pretrained``
            (loaded with ``trust_remote_code=True`` and eager attention).
        objective, simcse_weight, temperature: forwarded to ``BidirProGen``.
        lora_r, lora_alpha, lora_dropout: hyperparameters of a FRESH adapter;
            ignored when ``init_adapter`` is given (the saved config is used).
        dtype: weight dtype for loading.
        init_adapter: path to an existing LoRA adapter to CONTINUE training
            (this is how the SimCSE stage starts from the MNTP checkpoint, per
            LLM2Vec). When None, a fresh LoRA is initialised.
        attn_dropout: if set, force all ``nn.Dropout`` modules of the base
            model to this probability (the SimCSE stage needs it).

    Returns:
        ``(wrapped_model, info)``, where ``info`` holds ``patched_layers``,
        ``lora_targets`` (always a list), ``dropout_set`` and
        ``resumed_adapter``.
    """
    from peft import LoraConfig, PeftModel, get_peft_model
    # ProGen2's remote modeling code predates transformers>=5, so it never sets
    # the `all_tied_weights_keys` attribute that the weight loader now accesses
    # unconditionally (modeling_utils._move_missing_keys_from_meta_to_device).
    # ProGen (GPT-J-style) does not tie its lm_head, so an empty mapping is
    # correct. The guard skips the patch when the class already defines the
    # attribute. NOTE(review): this sets a single dict shared by every
    # PreTrainedModel subclass process-wide.
    import transformers.modeling_utils as _mu
    if "all_tied_weights_keys" not in vars(_mu.PreTrainedModel):
        _mu.PreTrainedModel.all_tied_weights_keys = {}
    # transformers 4.44.2 names this kwarg `torch_dtype` (renamed to `dtype` in 5.x);
    # passing `dtype` here would fall through to ProGenForCausalLM.__init__ and raise.
    # attn_implementation="eager" forces the attention path that reads the triangular
    # `bias` buffer make_bidirectional() flips. ProGen2 is eager-only anyway, but
    # text models (GPT-2) default to SDPA, whose causal masking ignores that buffer
    # — so without eager the bidirectional patch would silently do nothing.
    base = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True, torch_dtype=dtype,
        attn_implementation="eager",
    )
    n_patched = make_bidirectional(base)
    # Must run before any LoRA wrapping so LoRA's own dropout keeps lora_dropout.
    n_drop = set_dropout(base, attn_dropout) if attn_dropout is not None else 0
    if init_adapter is not None:
        # Resume the existing adapter (e.g. SimCSE stage continuing from MNTP).
        base = PeftModel.from_pretrained(base, init_adapter, is_trainable=True)
        saved_targets = base.peft_config["default"].target_modules
        # PEFT allows target_modules to be a regex string; sorted() on a str
        # would split it into single characters, so keep it whole.
        if isinstance(saved_targets, str):
            targets = [saved_targets]
        else:
            targets = sorted(saved_targets or ())
    else:
        targets = find_lora_targets(base)
        lora_cfg = LoraConfig(
            r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
            target_modules=targets, bias="none", task_type="FEATURE_EXTRACTION",
        )
        base = get_peft_model(base, lora_cfg)
    wrapped = BidirProGen(base, objective=objective, simcse_weight=simcse_weight,
                          temperature=temperature)
    return wrapped, {"patched_layers": n_patched, "lora_targets": targets,
                     "dropout_set": n_drop, "resumed_adapter": init_adapter}