Feature Extraction
PEFT
Safetensors
protein
protein-language-model
embeddings
lora
llm2vec
progen2
bidirectional
Instructions to use ratishsp/progen2-base-bidirectional-llm2vec with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use ratishsp/progen2-base-bidirectional-llm2vec with PEFT:
Task type is invalid.
- Notebooks
- Google Colab
- Kaggle
| """Bidirectional adaptation of ProGen2 (LLM2Vec-style recipe). | |
| ProGen2 (`hugohrban/progen2-base`, ProGenForCausalLM, GPT-J-style) applies | |
| causality inside `ProGenAttention._attn` via a triangular `bias` buffer: | |
| causal_mask = self.bias[..., kq, :kl] | |
| attn_weights = torch.where(causal_mask, attn_weights, masked_bias) | |
| if attention_mask is not None: | |
| attn_weights = attn_weights + attention_mask # padding mask (kept) | |
| Setting every layer's `bias` buffer to all-True neutralises the causal | |
| `torch.where` while preserving the padding-mask path -> full bidirectional | |
| attention, with no rewrite of the module's forward. | |
| Objectives (configurable, following the proposal): | |
| * MNTP - masked next-token prediction: a masked position i is predicted | |
| from the hidden state of the preceding token i-1. | |
| * SimCSE - dropout-based unsupervised contrastive (InfoNCE, in-batch negs). | |
| * joint - MNTP + lambda * SimCSE in one step. | |
| """ | |
| from __future__ import annotations | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import AutoModelForCausalLM | |
| from transformers.pytorch_utils import Conv1D | |
def make_bidirectional(model: nn.Module) -> int:
    """Turn causal attention bidirectional by making every causal ``bias`` all-True.

    A module qualifies when it carries a ``bias`` attribute that is a 4-D
    boolean tensor (ProGen2's triangular causal buffer). That buffer is
    re-bound to an all-True tensor of the same shape, dtype and device, so a
    ``torch.where(bias, scores, masked)`` selection keeps every score. Linear /
    Conv1D float biases are never touched because they are not 4-D bool.

    Args:
        model: module tree, patched in place.

    Returns:
        Number of modules whose ``bias`` was replaced (a sanity-check count).

    Raises:
        RuntimeError: if no qualifying ``bias`` tensor exists anywhere in ``model``.
    """
    def _is_causal_mask(candidate) -> bool:
        return (
            isinstance(candidate, torch.Tensor)
            and candidate.dtype == torch.bool
            and candidate.dim() == 4
        )

    pairs = ((mod, getattr(mod, "bias", None)) for mod in model.modules())
    hits = [(mod, buf) for mod, buf in pairs if _is_causal_mask(buf)]

    if not hits:
        raise RuntimeError(
            "No causal bias buffers found - ProGen internals may have changed; "
            "inspect ProGenAttention for the causal mask."
        )

    for mod, buf in hits:
        # Re-binding through nn.Module.__setattr__ keeps it registered as a buffer.
        mod.bias = torch.ones_like(buf)
    return len(hits)
def find_lora_targets(model: nn.Module) -> list[str]:
    """Auto-detect attention/MLP projection leaf names for LoRA (robust to renames).

    Catches both ``nn.Linear`` (ProGen2, Llama, ...) and ``Conv1D`` (GPT-2's
    c_attn/c_proj/c_fc) so the same recipe works on text decoders. The output
    head is excluded (it is the LM head, not a representation projection):
    any module whose qualified name ends in ``lm_head`` is skipped, and so is
    the exact module object returned by ``model.get_output_embeddings()`` when
    the model provides that hook. This covers heads with other names
    (e.g. ``embed_out``).

    Args:
        model: the model to scan, before LoRA wrapping.

    Returns:
        Sorted, de-duplicated leaf module names, for
        ``LoraConfig(target_modules=...)``.
    """
    head = None
    get_head = getattr(model, "get_output_embeddings", None)
    if callable(get_head):
        try:
            head = get_head()
        except (AttributeError, NotImplementedError):
            # Hook present but unimplemented (possible in remote-code models):
            # rely on the name-based exclusion alone.
            head = None

    names = set()
    for name, mod in model.named_modules():
        if not isinstance(mod, (nn.Linear, Conv1D)):
            continue
        if mod is head or name.endswith("lm_head"):
            continue
        names.add(name.rsplit(".", 1)[-1])
    return sorted(names)
def mean_pool(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average ``hidden`` over the sequence axis, counting only unmasked positions.

    Args:
        hidden: ``(B, T, H)`` hidden states.
        attention_mask: ``(B, T)`` mask. Entries multiply the hidden states, so a
            0/1 padding mask gives a plain mean over the kept tokens.

    Returns:
        ``(B, H)`` pooled embeddings in ``hidden.dtype``. A row whose mask is all
        zeros pools to zeros.
    """
    # Accumulate in float32, then cast back. In bfloat16 the token count is
    # inexact above 256. In float16 the 1e-9 floor below underflows to 0, which
    # turns an all-padding row into 0/0 = NaN.
    mask = attention_mask.unsqueeze(-1).to(torch.float32)
    summed = (hidden.to(torch.float32) * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return (summed / counts).to(hidden.dtype)
class BidirProGen(nn.Module):
    """Wraps a (bidirectional, LoRA-wrapped) ProGen2 for MNTP / SimCSE / joint training.

    Args:
        base_model: module called with ``input_ids``, ``attention_mask`` and
            ``output_hidden_states=True``. Its output must expose ``.logits`` and
            ``.hidden_states``.
        objective: ``"mntp"``, ``"simcse"`` or ``"joint"``.
        simcse_weight: weight applied to the SimCSE loss in ``"joint"`` mode only.
        temperature: InfoNCE temperature dividing the cosine similarities.

    Raises:
        ValueError: if ``objective`` is not one of the supported names.
    """

    _OBJECTIVES = frozenset({"mntp", "simcse", "joint"})

    def __init__(self, base_model, objective: str = "joint", simcse_weight: float = 0.1,
                 temperature: float = 0.05):
        super().__init__()
        # Explicit check rather than `assert`, so validation survives `python -O`.
        if objective not in self._OBJECTIVES:
            raise ValueError(
                f"objective must be one of {sorted(self._OBJECTIVES)}, got {objective!r}"
            )
        self.model = base_model
        self.objective = objective
        self.simcse_weight = simcse_weight
        self.temperature = temperature

    def _backbone(self, input_ids, attention_mask):
        """Run the wrapped model once.

        Returns:
            ``(logits, last_hidden)``: the model's logits and the final entry of
            its ``hidden_states`` (``(B, T, H)``).
        """
        out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        return out.logits, out.hidden_states[-1]

    def mntp_loss(self, logits, labels):
        """Masked next-token prediction loss.

        The label at position ``i`` is predicted from the logits at ``i - 1``, so
        logits are shifted left relative to labels. Label ``-100`` marks positions
        to ignore.

        Returns:
            Mean cross-entropy over non-ignored targets, computed in float32. When
            no target survives (nothing masked in the batch), this returns 0.0
            instead of the NaN a plain ``reduction="mean"`` would produce.
        """
        shift_logits = logits[:, :-1, :].float()
        flat_labels = labels[:, 1:].reshape(-1)
        total = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            flat_labels,
            ignore_index=-100,
            reduction="sum",
        )
        # Same value as reduction="mean" whenever at least one target exists.
        n_targets = (flat_labels != -100).sum().clamp(min=1)
        return total / n_targets

    def simcse_loss(self, input_ids, attention_mask):
        """Unsupervised SimCSE (InfoNCE) loss with in-batch negatives.

        The batch is encoded twice; with dropout active, the two passes differ and
        matching rows form the positive pairs. Pooling and similarities run in
        float32: at ``temperature=0.05`` the logits reach +/-20, where bfloat16
        resolution is coarse.
        """
        _, h1 = self._backbone(input_ids, attention_mask)
        _, h2 = self._backbone(input_ids, attention_mask)
        z1 = F.normalize(mean_pool(h1.float(), attention_mask), dim=-1)
        z2 = F.normalize(mean_pool(h2.float(), attention_mask), dim=-1)
        sim = z1 @ z2.t() / self.temperature
        targets = torch.arange(sim.size(0), device=sim.device)
        return 0.5 * (F.cross_entropy(sim, targets) + F.cross_entropy(sim.t(), targets))

    def forward(self, input_ids, attention_mask, labels=None, **_):
        """Compute the configured objective on one batch.

        Args:
            input_ids, attention_mask: model inputs, shared by both objectives.
            labels: MNTP targets aligned with ``input_ids`` (``-100`` = ignore).
                Required for ``"mntp"`` and ``"joint"``; unused for ``"simcse"``.
            **_: extra batch fields, ignored.

        Returns:
            ``{"loss": total, "logs": {...}}``. ``logs`` holds the detached loss of
            each active objective.

        Raises:
            ValueError: if MNTP is active and ``labels`` is None.
        """
        use_mntp = self.objective in {"mntp", "joint"}
        use_simcse = self.objective in {"simcse", "joint"}
        if use_mntp and labels is None:
            raise ValueError(f"objective={self.objective!r} requires `labels` for MNTP")

        loss = input_ids.new_zeros((), dtype=torch.float32)
        logs = {}
        if use_mntp:
            logits, _ = self._backbone(input_ids, attention_mask)
            l_mntp = self.mntp_loss(logits, labels)
            loss = loss + l_mntp
            logs["mntp"] = l_mntp.detach()
        if use_simcse:
            l_sim = self.simcse_loss(input_ids, attention_mask)
            # simcse_weight only balances the *joint* loss; a standalone SimCSE
            # stage trains the contrastive loss at full weight.
            w = self.simcse_weight if self.objective == "joint" else 1.0
            loss = loss + w * l_sim
            logs["simcse"] = l_sim.detach()
        return {"loss": loss, "logs": logs}
def set_dropout(model: nn.Module, p: float) -> int:
    """Force every nn.Dropout to probability `p` (architecture-agnostic).

    LLM2Vec's SimCSE stage explicitly enables dropout (default 0.1) so the two
    forward passes of the same sequence differ. That dropout *is* the only
    augmentation forming the positive pair. Some decoders (check ProGen2!) ship
    with dropout=0, which would make the two views identical and the
    contrastive signal degenerate. Call this BEFORE LoRA-wrapping so LoRA's own
    dropout is left untouched.

    Args:
        model: module tree, modified in place.
        p: new dropout probability, in ``[0, 1]``.

    Returns:
        Number of ``nn.Dropout`` modules updated.

    Raises:
        ValueError: if ``p`` is outside ``[0, 1]`` (or NaN). The check runs before
            any module is modified. Assigning ``.p`` directly bypasses the range
            check nn.Dropout performs in its constructor.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    dropouts = [m for m in model.modules() if isinstance(m, nn.Dropout)]
    for m in dropouts:
        m.p = p
    return len(dropouts)
def load_bidir_progen(model_name: str, objective: str, lora_r: int, lora_alpha: int,
                      lora_dropout: float, simcse_weight: float, temperature: float,
                      dtype: torch.dtype = torch.bfloat16,
                      init_adapter: str | None = None,
                      attn_dropout: float | None = None):
    """Load a decoder LM, make it bidirectional, attach (or resume) LoRA, wrap.

    Pipeline:
    1. Load with eager attention.
    2. Flip the causal ``bias`` buffers (``make_bidirectional``).
    3. Optionally force dropout (``set_dropout``).
    4. Attach a fresh LoRA or resume an existing one.
    5. Wrap in ``BidirProGen``.

    Args:
        model_name: hub id or local path for ``AutoModelForCausalLM.from_pretrained``.
        objective, simcse_weight, temperature: forwarded to ``BidirProGen``.
        lora_r, lora_alpha, lora_dropout: hyper-parameters for a FRESH adapter.
            They are ignored when ``init_adapter`` is given; the saved adapter
            config is used instead.
        dtype: dtype the base weights are loaded in.
        init_adapter: path to an existing LoRA adapter to CONTINUE training (this
            is how the SimCSE stage starts from the MNTP checkpoint, per LLM2Vec).
            When None, a fresh LoRA is initialised.
        attn_dropout: if set, force ALL ``nn.Dropout`` modules of the base model
            (not only attention dropout) to this probability. The SimCSE stage
            needs this.

    Returns:
        ``(wrapped, info)``. ``info`` holds ``patched_layers``, ``lora_targets``
        (always a list), ``dropout_set`` and ``resumed_adapter``.

    NOTE(review): ``from_pretrained`` normally returns the model in eval mode.
    Dropout, and therefore the SimCSE views, only takes effect after the caller
    (or its trainer) calls ``.train()``.
    """
    from peft import LoraConfig, get_peft_model, PeftModel
    # ProGen2's remote modeling code predates transformers>=5, so it never sets
    # the `all_tied_weights_keys` attribute that the weight loader now accesses
    # unconditionally (modeling_utils._move_missing_keys_from_meta_to_device).
    # ProGen (GPT-J-style) does not tie its lm_head, so an empty mapping is correct.
    import transformers.modeling_utils as _mu
    if "all_tied_weights_keys" not in vars(_mu.PreTrainedModel):
        _mu.PreTrainedModel.all_tied_weights_keys = {}
    # transformers 4.44.2 names this kwarg `torch_dtype` (renamed to `dtype` in 5.x);
    # passing `dtype` here would fall through to ProGenForCausalLM.__init__ and raise.
    # attn_implementation="eager" forces the attention path that reads the triangular
    # `bias` buffer make_bidirectional() flips. ProGen2 is eager-only anyway, but
    # text models (GPT-2) default to SDPA, whose causal masking ignores that buffer
    # — so without eager the bidirectional patch would silently do nothing.
    base = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True, torch_dtype=dtype,
        attn_implementation="eager",
    )
    n_patched = make_bidirectional(base)
    # Runs before LoRA is attached, so LoRA's own dropout keeps `lora_dropout`.
    n_drop = set_dropout(base, attn_dropout) if attn_dropout is not None else 0
    if init_adapter is not None:
        # Resume the existing adapter (e.g. SimCSE stage continuing from MNTP).
        base = PeftModel.from_pretrained(base, init_adapter, is_trainable=True)
        saved = base.peft_config["default"].target_modules
        # target_modules may be stored as a regex string. sorted() on a str would
        # split it into single characters, so report it as a one-element list.
        if isinstance(saved, str):
            targets = [saved]
        else:
            targets = sorted(saved or ())
    else:
        targets = find_lora_targets(base)
        lora_cfg = LoraConfig(
            r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
            target_modules=targets, bias="none", task_type="FEATURE_EXTRACTION",
        )
        base = get_peft_model(base, lora_cfg)
    wrapped = BidirProGen(base, objective=objective, simcse_weight=simcse_weight,
                          temperature=temperature)
    return wrapped, {"patched_layers": n_patched, "lora_targets": targets,
                     "dropout_set": n_drop, "resumed_adapter": init_adapter}