"""Bidirectional adaptation of ProGen2 (LLM2Vec-style recipe).

ProGen2 (`hugohrban/progen2-base`, ProGenForCausalLM, GPT-J-style) applies
causality inside `ProGenAttention._attn` via a triangular `bias` buffer:

    causal_mask = self.bias[..., kq, :kl]
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask   # padding mask (kept)

Setting every layer's `bias` buffer to all-True neutralises the causal
`torch.where` while preserving the padding-mask path -> full bidirectional
attention, with no rewrite of the module's forward.

Objectives (configurable, following the proposal):
  * MNTP   - masked next-token prediction: a masked position i is predicted
             from the hidden state of the preceding token i-1.
  * SimCSE - dropout-based unsupervised contrastive (InfoNCE, in-batch negs).
  * joint  - MNTP + lambda * SimCSE in one step.
"""
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from transformers.pytorch_utils import Conv1D

# Training objectives understood by `BidirProGen`.
_OBJECTIVES = frozenset({"mntp", "simcse", "joint"})


def make_bidirectional(model: nn.Module) -> int:
    """Replace every triangular causal `bias` buffer with all-True.

    A module is patched when its `bias` attribute is a 4-D boolean tensor;
    float `bias` parameters (e.g. on `nn.Linear`) are left alone.

    Returns:
        The number of attention modules patched (sanity check).

    Raises:
        RuntimeError: if no module carries such a buffer.
    """
    patched = 0
    for module in model.modules():
        bias = getattr(module, "bias", None)
        # ProGenAttention.bias is a bool buffer shaped (1, 1, n_pos, n_pos).
        if isinstance(bias, torch.Tensor) and bias.dtype == torch.bool and bias.dim() == 4:
            module.bias = torch.ones_like(bias)
            patched += 1
    if patched == 0:
        raise RuntimeError(
            "No causal bias buffers found - ProGen internals may have changed; "
            "inspect ProGenAttention for the causal mask."
        )
    return patched


def find_lora_targets(model: nn.Module) -> list[str]:
    """Auto-detect attention/MLP projection leaf names for LoRA (robust to renames).

    Catches both `nn.Linear` (ProGen2, Llama, ...) and `Conv1D` (GPT-2's
    c_attn/c_proj/c_fc) so the same recipe works on text decoders. `lm_head`
    is excluded (it's the output head, not a representation projection).

    Returns:
        Sorted, de-duplicated leaf names (last component of the dotted
        module path) of every matching module.
    """
    leaf_names = {
        name.split(".")[-1]
        for name, mod in model.named_modules()
        if isinstance(mod, (nn.Linear, Conv1D)) and not name.endswith("lm_head")
    }
    return sorted(leaf_names)


def mean_pool(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average `hidden` over dim 1, weighting each position by `attention_mask`.

    Expects `hidden` as (B, T, H) and `attention_mask` as (B, T) - TODO confirm
    against callers. The mask is cast to `hidden.dtype`; the per-row weight sum
    is clamped to 1e-9 so a row with an all-zero mask does not divide by zero.
    """
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts


class BidirProGen(nn.Module):
    """Wraps a (bidirectional, LoRA-wrapped) ProGen2 for MNTP / SimCSE / joint."""

    def __init__(self, base_model, objective: str = "joint",
                 simcse_weight: float = 0.1, temperature: float = 0.05):
        """Store the wrapped model and the loss configuration.

        Args:
            base_model: model called as
                `model(input_ids=..., attention_mask=..., output_hidden_states=True)`
                whose output exposes `.logits` and `.hidden_states`.
            objective: one of "mntp", "simcse", "joint".
            simcse_weight: multiplier on the SimCSE term, used for "joint" only.
            temperature: divisor applied to the SimCSE similarity matrix.

        Raises:
            ValueError: if `objective` is not a known objective.
        """
        super().__init__()
        # Explicit raise rather than `assert`, which is stripped under `python -O`.
        if objective not in _OBJECTIVES:
            raise ValueError(
                f"objective must be one of {sorted(_OBJECTIVES)}, got {objective!r}"
            )
        self.model = base_model
        self.objective = objective
        self.simcse_weight = simcse_weight
        self.temperature = temperature

    def _backbone(self, input_ids, attention_mask):
        """Run the wrapped model once; return `(logits, last_hidden_state)`.

        The second element is `hidden_states[-1]`, i.e. (B, T, H) from the
        inner transformer.
        """
        out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        return out.logits, out.hidden_states[-1]

    def mntp_loss(self, logits, labels):
        """Cross-entropy of the label at position i against the logits at i-1.

        Positions labelled -100 are ignored. When no shifted position carries a
        real label, a zero that is still attached to the autograd graph is
        returned instead of the NaN that mean-reduced cross-entropy produces
        over zero targets.
        """
        # Predict masked token i from position i-1: shift logits left vs labels.
        shift_logits = logits[:, :-1, :]
        shift_labels = labels[:, 1:]
        if not (shift_labels != -100).any():
            # Nothing to predict in this batch: contribute no loss and no
            # gradient rather than poisoning the optimiser step with NaN.
            return shift_logits.sum() * 0.0
        return F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            ignore_index=-100,
        )

    def simcse_loss(self, input_ids, attention_mask):
        """Symmetric InfoNCE between two forward passes over the same batch.

        The two passes only differ if dropout is active in the wrapped model.
        """
        # Two independent dropout passes -> positive pair; in-batch negatives.
        _, h1 = self._backbone(input_ids, attention_mask)
        _, h2 = self._backbone(input_ids, attention_mask)
        z1 = F.normalize(mean_pool(h1, attention_mask), dim=-1)
        z2 = F.normalize(mean_pool(h2, attention_mask), dim=-1)
        sim = z1 @ z2.t() / self.temperature
        # Row i's positive is column i; score both directions and average.
        targets = torch.arange(sim.size(0), device=sim.device)
        return 0.5 * (F.cross_entropy(sim, targets) + F.cross_entropy(sim.t(), targets))

    def forward(self, input_ids, attention_mask, labels=None, **_):
        """Compute the configured loss.

        Extra keyword arguments are accepted and ignored.

        Returns:
            `{"loss": total, "logs": {...}}` where `logs` holds the detached
            per-objective terms under "mntp" and/or "simcse".

        Raises:
            ValueError: if the objective includes MNTP but `labels` is None.
        """
        needs_mntp = self.objective in {"mntp", "joint"}
        if needs_mntp and labels is None:
            raise ValueError(
                f"objective {self.objective!r} requires `labels` for the MNTP loss"
            )
        loss = input_ids.new_zeros((), dtype=torch.float32)
        logs = {}
        if needs_mntp:
            logits, _ = self._backbone(input_ids, attention_mask)
            l_mntp = self.mntp_loss(logits, labels)
            loss = loss + l_mntp
            logs["mntp"] = l_mntp.detach()
        if self.objective in {"simcse", "joint"}:
            l_sim = self.simcse_loss(input_ids, attention_mask)
            # simcse_weight only balances the *joint* loss; a standalone SimCSE
            # stage trains the contrastive loss at full weight.
            w = self.simcse_weight if self.objective == "joint" else 1.0
            loss = loss + w * l_sim
            logs["simcse"] = l_sim.detach()
        return {"loss": loss, "logs": logs}


def set_dropout(model: nn.Module, p: float) -> int:
    """Force every nn.Dropout to probability `p` (architecture-agnostic).

    LLM2Vec's SimCSE stage explicitly enables dropout (default 0.1) so the two
    forward passes of the same sequence differ — that dropout *is* the only
    augmentation forming the positive pair. Some decoders (check ProGen2!) ship
    with dropout=0, which would make the two views identical and the contrastive
    signal degenerate. Call this BEFORE LoRA-wrapping so LoRA's own dropout is
    left untouched.

    Returns:
        The number of `nn.Dropout` modules updated.

    Raises:
        ValueError: if `p` is outside [0, 1].
    """
    # Fail here rather than at the first forward pass: assigning `.p` directly
    # bypasses the range check nn.Dropout performs in its constructor.
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"dropout probability must be in [0, 1], got {p}")
    n = 0
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.p = p
            n += 1
    return n


def load_bidir_progen(model_name: str, objective: str, lora_r: int, lora_alpha: int,
                      lora_dropout: float, simcse_weight: float, temperature: float,
                      dtype: torch.dtype = torch.bfloat16,
                      init_adapter: str | None = None,
                      attn_dropout: float | None = None):
    """Load a decoder LM, make it bidirectional, attach (or resume) LoRA, wrap.

    `init_adapter`: path to an existing LoRA adapter to CONTINUE training (this
    is how the SimCSE stage starts from the MNTP checkpoint, per LLM2Vec). When
    None, a fresh LoRA is initialised. In the resume case `lora_r`,
    `lora_alpha` and `lora_dropout` are not used; the adapter's saved config
    applies.
    `attn_dropout`: if set, force all dropout to this prob (SimCSE stage needs it).

    Returns:
        `(wrapped, info)` where `wrapped` is a `BidirProGen` and `info` is a
        dict with keys "patched_layers", "lora_targets", "dropout_set" and
        "resumed_adapter".
    """
    from peft import LoraConfig, get_peft_model, PeftModel

    # ProGen2's remote modeling code predates transformers>=5, so it never sets
    # the `all_tied_weights_keys` attribute that the weight loader now accesses
    # unconditionally (modeling_utils._move_missing_keys_from_meta_to_device).
    # ProGen (GPT-J-style) does not tie its lm_head, so an empty mapping is correct.
    import transformers.modeling_utils as _mu
    if "all_tied_weights_keys" not in vars(_mu.PreTrainedModel):
        _mu.PreTrainedModel.all_tied_weights_keys = {}

    # transformers 4.44.2 names this kwarg `torch_dtype` (renamed to `dtype` in 5.x);
    # passing `dtype` here would fall through to ProGenForCausalLM.__init__ and raise.
    # attn_implementation="eager" forces the attention path that reads the triangular
    # `bias` buffer make_bidirectional() flips. ProGen2 is eager-only anyway, but
    # text models (GPT-2) default to SDPA, whose causal masking ignores that buffer
    # — so without eager the bidirectional patch would silently do nothing.
    base = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True, torch_dtype=dtype,
        attn_implementation="eager",
    )
    n_patched = make_bidirectional(base)
    n_drop = set_dropout(base, attn_dropout) if attn_dropout is not None else 0

    if init_adapter is not None:
        # Resume the existing adapter (e.g. SimCSE stage continuing from MNTP).
        base = PeftModel.from_pretrained(base, init_adapter, is_trainable=True)
        saved_targets = base.peft_config["default"].target_modules
        # `target_modules` may be a single string (regex / shorthand) instead
        # of a collection; sorting a string would split it into characters.
        if isinstance(saved_targets, str):
            targets = [saved_targets]
        else:
            targets = sorted(saved_targets)
    else:
        targets = find_lora_targets(base)
        lora_cfg = LoraConfig(
            r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
            target_modules=targets, bias="none", task_type="FEATURE_EXTRACTION",
        )
        base = get_peft_model(base, lora_cfg)

    wrapped = BidirProGen(base, objective=objective,
                          simcse_weight=simcse_weight, temperature=temperature)
    return wrapped, {"patched_layers": n_patched, "lora_targets": targets,
                     "dropout_set": n_drop, "resumed_adapter": init_adapter}