File size: 8,905 Bytes
e6bc942
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""Bidirectional adaptation of ProGen2 (LLM2Vec-style recipe).

ProGen2 (`hugohrban/progen2-base`, ProGenForCausalLM, GPT-J-style) applies
causality inside `ProGenAttention._attn` via a triangular `bias` buffer:

    causal_mask = self.bias[..., kq, :kl]
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask   # padding mask (kept)

Setting every layer's `bias` buffer to all-True neutralises the causal
`torch.where` while preserving the padding-mask path -> full bidirectional
attention, with no rewrite of the module's forward.

Objectives (configurable, following the proposal):
  * MNTP   - masked next-token prediction: a masked position i is predicted
             from the hidden state of the preceding token i-1.
  * SimCSE - dropout-based unsupervised contrastive (InfoNCE, in-batch negs).
  * joint  - MNTP + lambda * SimCSE in one step.
"""
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from transformers.pytorch_utils import Conv1D


def make_bidirectional(model: nn.Module) -> int:
    """Turn every causal-triangle `bias` buffer in `model` into an all-True mask.

    A module is patched when it exposes a `bias` attribute that is a 4-D boolean
    tensor (ProGenAttention registers its causal mask as a bool buffer shaped
    (1, 1, n_pos, n_pos)). Linear-layer `bias` parameters are floating point and
    are therefore skipped.

    Returns:
        How many modules were patched (useful as a sanity check).

    Raises:
        RuntimeError: if no such buffer exists anywhere in `model`.
    """
    def _looks_like_causal_mask(candidate) -> bool:
        return (
            isinstance(candidate, torch.Tensor)
            and candidate.dtype == torch.bool
            and candidate.dim() == 4
        )

    count = 0
    for submodule in model.modules():
        mask = getattr(submodule, "bias", None)
        if not _looks_like_causal_mask(mask):
            continue
        # Assigning to a registered buffer name replaces the buffer in place.
        submodule.bias = torch.ones_like(mask)
        count += 1

    if not count:
        raise RuntimeError(
            "No causal bias buffers found - ProGen internals may have changed; "
            "inspect ProGenAttention for the causal mask."
        )
    return count


def find_lora_targets(model: nn.Module) -> list[str]:
    """Collect the distinct leaf names of every projection layer LoRA should adapt.

    A submodule qualifies if it is an `nn.Linear` (ProGen2, Llama, ...) or a
    transformers `Conv1D` (GPT-2's c_attn/c_proj/c_fc), so the same recipe works
    on text decoders. Anything whose qualified name ends in `lm_head` is skipped:
    it is the output head, not a representation projection.

    Returns:
        Sorted, de-duplicated leaf names (last dotted component of each
        qualified module name), suitable for LoRA `target_modules`.
    """
    projection_types = (nn.Linear, Conv1D)
    leaf_names = {
        qualified.rsplit(".", 1)[-1]
        for qualified, sub in model.named_modules()
        if isinstance(sub, projection_types) and not qualified.endswith("lm_head")
    }
    return sorted(leaf_names)


def mean_pool(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average `hidden` over the sequence axis, weighting positions by the mask.

    Args:
        hidden: hidden states, reduced over dim 1 (presumably (B, T, H)).
        attention_mask: (B, T) mask broadcast over the last axis of `hidden`;
            1 for real tokens, 0 for padding.

    Returns:
        The pooled tensor (B, H) in `hidden.dtype`. Rows whose mask sums to zero
        pool to zeros rather than NaN.
    """
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    # Count tokens and divide in >= float32: bf16 cannot represent every
    # integer above 256 (e.g. 257 -> 256), which would mis-scale the mean, and
    # in fp16 the 1e-9 floor underflows to 0, turning all-pad rows into 0/0.
    acc_dtype = torch.promote_types(hidden.dtype, torch.float32)
    counts = (
        attention_mask.unsqueeze(-1).to(acc_dtype).sum(dim=1).clamp(min=1e-9)
    )
    return (summed.to(acc_dtype) / counts).to(hidden.dtype)


class BidirProGen(nn.Module):
    """Wraps a (bidirectional, LoRA-wrapped) ProGen2 for MNTP / SimCSE / joint.

    Args:
        base_model: module called as ``base_model(input_ids=..., attention_mask=...,
            output_hidden_states=True)``; its output must expose ``.logits`` and
            ``.hidden_states`` (the last entry is used as the representation).
        objective: ``"mntp"``, ``"simcse"`` or ``"joint"``.
        simcse_weight: weight on the SimCSE term, applied only when
            ``objective == "joint"``.
        temperature: divisor applied to the cosine-similarity matrix in SimCSE.

    Raises:
        ValueError: if ``objective`` is not one of the supported names.
    """

    def __init__(self, base_model, objective: str = "joint", simcse_weight: float = 0.1,
                 temperature: float = 0.05):
        super().__init__()
        self.model = base_model
        # Real exception rather than `assert`, which `python -O` strips.
        if objective not in {"mntp", "simcse", "joint"}:
            raise ValueError(
                f"objective must be 'mntp', 'simcse' or 'joint', got {objective!r}"
            )
        self.objective = objective
        self.simcse_weight = simcse_weight
        self.temperature = temperature

    def _backbone(self, input_ids, attention_mask):
        """Run the wrapped model once and return ``(logits, last_hidden_state)``.

        ``last_hidden_state`` is the final entry of the model's ``hidden_states``
        (expected (B, T, H)).
        """
        out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        return out.logits, out.hidden_states[-1]

    def mntp_loss(self, logits, labels):
        """Masked next-token prediction loss.

        The label at position i is predicted from the logits at position i-1, so
        logits drop their last step and labels drop their first. Positions labelled
        -100 are ignored. When no position in the batch carries a label, a
        graph-connected zero is returned instead of the NaN that mean-reduced
        cross-entropy produces over zero targets.
        """
        shift_logits = logits[:, :-1, :]
        shift_labels = labels[:, 1:]
        flat_logits = shift_logits.reshape(-1, shift_logits.size(-1))
        flat_labels = shift_labels.reshape(-1)
        if not bool((flat_labels != -100).any()):
            # 0/0 would otherwise poison the (joint) loss for this step.
            return flat_logits.sum() * 0.0
        return F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)

    def simcse_loss(self, input_ids, attention_mask):
        """Symmetric InfoNCE between two forward passes of the same batch.

        The two views differ only through dropout, so dropout must be active
        (train mode, non-zero p). Matching rows are positives; every other row
        in the batch is a negative.
        """
        _, h1 = self._backbone(input_ids, attention_mask)
        _, h2 = self._backbone(input_ids, attention_mask)
        z1 = F.normalize(mean_pool(h1, attention_mask), dim=-1)
        z2 = F.normalize(mean_pool(h2, attention_mask), dim=-1)
        sim = z1 @ z2.t() / self.temperature
        targets = torch.arange(sim.size(0), device=sim.device)
        return 0.5 * (F.cross_entropy(sim, targets) + F.cross_entropy(sim.t(), targets))

    def forward(self, input_ids, attention_mask, labels=None, **_):
        """Compute the configured objective.

        Returns:
            ``{"loss": scalar tensor, "logs": {"mntp"/"simcse": detached term}}``.

        Raises:
            ValueError: if the MNTP term is active and ``labels`` is None.
        """
        loss = input_ids.new_zeros((), dtype=torch.float32)
        logs = {}
        if self.objective in {"mntp", "joint"}:
            # Fail before spending a forward pass on a batch we cannot score.
            if labels is None:
                raise ValueError(
                    "labels are required for the 'mntp' and 'joint' objectives"
                )
            logits, _ = self._backbone(input_ids, attention_mask)
            l_mntp = self.mntp_loss(logits, labels)
            loss = loss + l_mntp
            logs["mntp"] = l_mntp.detach()
        if self.objective in {"simcse", "joint"}:
            l_sim = self.simcse_loss(input_ids, attention_mask)
            # simcse_weight only balances the *joint* loss; a standalone SimCSE
            # stage trains the contrastive loss at full weight.
            w = self.simcse_weight if self.objective == "joint" else 1.0
            loss = loss + w * l_sim
            logs["simcse"] = l_sim.detach()
        return {"loss": loss, "logs": logs}


def set_dropout(model: nn.Module, p: float) -> int:
    """Force every nn.Dropout to probability `p` (architecture-agnostic).

    LLM2Vec's SimCSE stage explicitly enables dropout (default 0.1) so the two
    forward passes of the same sequence differ — that dropout *is* the only
    augmentation forming the positive pair. Some decoders (check ProGen2!) ship
    with dropout=0, which would make the two views identical and the contrastive
    signal degenerate. Call this BEFORE LoRA-wrapping so LoRA's own dropout is
    left untouched.

    Args:
        model: module whose `nn.Dropout` submodules (including itself) are updated.
        p: new drop probability, in [0, 1].

    Returns:
        Number of `nn.Dropout` modules updated.

    Raises:
        ValueError: if `p` is outside [0, 1] (or NaN); no module is modified.
    """
    # Writing `.p` directly skips the range check nn.Dropout's constructor does,
    # so validate here, before touching any module.
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    n = 0
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.p = p
            n += 1
    return n


def load_bidir_progen(model_name: str, objective: str, lora_r: int, lora_alpha: int,
                      lora_dropout: float, simcse_weight: float, temperature: float,
                      dtype: torch.dtype = torch.bfloat16,
                      init_adapter: str | None = None,
                      attn_dropout: float | None = None):
    """Load a decoder LM, make it bidirectional, attach (or resume) LoRA, wrap.

    Args:
        model_name: Hub id or local path passed to `AutoModelForCausalLM`.
        objective: forwarded to `BidirProGen` ("mntp", "simcse" or "joint").
        lora_r, lora_alpha, lora_dropout: LoRA hyper-parameters; only used when a
            fresh adapter is created (ignored when `init_adapter` is given).
        simcse_weight, temperature: forwarded to `BidirProGen`.
        dtype: weight dtype for the base model.
        init_adapter: path to an existing LoRA adapter to CONTINUE training (this
            is how the SimCSE stage starts from the MNTP checkpoint, per LLM2Vec).
            When None, a fresh LoRA is initialised.
        attn_dropout: if set, force all dropout to this prob (SimCSE stage needs it).

    Returns:
        `(wrapped_model, info)` where `info` holds the number of patched attention
        layers, the LoRA target module names, the number of dropout modules set,
        and the resumed adapter path (or None).
    """
    from peft import LoraConfig, get_peft_model, PeftModel

    # ProGen2's remote modeling code predates transformers>=5, so it never sets
    # the `all_tied_weights_keys` attribute that the weight loader now accesses
    # unconditionally (modeling_utils._move_missing_keys_from_meta_to_device).
    # ProGen (GPT-J-style) does not tie its lm_head, so an empty mapping is correct.
    import transformers.modeling_utils as _mu
    if "all_tied_weights_keys" not in vars(_mu.PreTrainedModel):
        _mu.PreTrainedModel.all_tied_weights_keys = {}

    # transformers 4.44.2 names this kwarg `torch_dtype` (renamed to `dtype` in 5.x);
    # passing `dtype` here would fall through to ProGenForCausalLM.__init__ and raise.
    # attn_implementation="eager" forces the attention path that reads the triangular
    # `bias` buffer make_bidirectional() flips. ProGen2 is eager-only anyway, but
    # text models (GPT-2) default to SDPA, whose causal masking ignores that buffer
    # — so without eager the bidirectional patch would silently do nothing.
    base = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True, torch_dtype=dtype,
        attn_implementation="eager",
    )
    n_patched = make_bidirectional(base)
    n_drop = set_dropout(base, attn_dropout) if attn_dropout is not None else 0

    if init_adapter is not None:
        # Resume the existing adapter (e.g. SimCSE stage continuing from MNTP).
        base = PeftModel.from_pretrained(base, init_adapter, is_trainable=True)
        saved_targets = base.peft_config["default"].target_modules
        # LoraConfig.target_modules may be a collection of names or a single
        # regex string; sorted() on a str would split it into characters.
        if saved_targets is None:
            targets = []
        elif isinstance(saved_targets, str):
            targets = [saved_targets]
        else:
            targets = sorted(saved_targets)
    else:
        targets = find_lora_targets(base)
        lora_cfg = LoraConfig(
            r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
            target_modules=targets, bias="none", task_type="FEATURE_EXTRACTION",
        )
        base = get_peft_model(base, lora_cfg)
    wrapped = BidirProGen(base, objective=objective, simcse_weight=simcse_weight,
                          temperature=temperature)
    return wrapped, {"patched_layers": n_patched, "lora_targets": targets,
                     "dropout_set": n_drop, "resumed_adapter": init_adapter}