"""
bert_ordinal.py
---------------
BERT-based ordinal regression model, fully integrated with the HuggingFace
Transformers API:
model.save_pretrained("my-checkpoint/")
model = BertOrdinal.from_pretrained("my-checkpoint/")
Architecture
------------
1. A (optionally frozen) BERT backbone.
2. A projection head on the [CLS] token:
   Linear(hidden_size → hidden_dim) → ReLU → Dropout(p) → Linear(hidden_dim → 1)
   producing a single latent score s ∈ ℝ.
3. K-1 learnable raw_threshold parameters enforcing monotonicity via
   cumsum(softplus(·)).
4. Cumulative-link probabilities:
       P(Y ≤ j | x) = σ(θ_j − s)
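For example (illustrative numbers only), with K = 3, thresholds θ = (0.5, 1.3)
and latent score s = 0.2:
    P(Y ≤ 0) = σ(0.5 − 0.2) ≈ 0.574     P(Y ≤ 1) = σ(1.3 − 0.2) ≈ 0.750
    P(Y = 0) ≈ 0.574    P(Y = 1) ≈ 0.176    P(Y = 2) ≈ 0.250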
Usage
-----
from bert_ordinal import BertOrdinalConfig, BertOrdinal
# ── Create from scratch ──────────────────────────────────────────────────
cfg = BertOrdinalConfig(
bert_model_name="bert-base-uncased",
num_classes=3,
hidden_dim=128,
dropout=0.1,
freeze_bert=True,
)
model = BertOrdinal(cfg)
# ── Save ─────────────────────────────────────────────────────────────────
model.save_pretrained("my-checkpoint/")
tokenizer.save_pretrained("my-checkpoint/") # keep tokenizer alongside
# ── Reload ───────────────────────────────────────────────────────────────
model = BertOrdinal.from_pretrained("my-checkpoint/")
tokenizer = AutoTokenizer.from_pretrained("my-checkpoint/")
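# ── Predict (illustrative sketch; batch keys match the BERT tokenizer) ───
batch = tokenizer(["an example sentence"], return_tensors="pt")
preds = model.predict(**batch)   # (B,) predicted class indices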
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from .configuration_bert_ordinal import BertOrdinalConfig
# ---------------------------------------------------------------------------
# 1. Output dataclass
# ---------------------------------------------------------------------------
@dataclass
class BertOrdinalOutput(ModelOutput):
"""
Return type of :class:`BertOrdinal`.
Attributes
----------
loss : torch.Tensor or None
Ordinal cross-entropy loss (scalar). Present only when ``labels``
are supplied.
logits : torch.Tensor (B,)
Raw latent score from the projection head.
predictions : torch.Tensor (B,)
        Predicted class index (argmax of ``class_probs``).
cum_probs : torch.Tensor (B, K-1)
        Cumulative probabilities P(Y ≤ j | x).
class_probs : torch.Tensor (B, K)
Per-class probabilities P(Y = j | x).
"""
loss: Optional[torch.Tensor] = None
logits: Optional[torch.Tensor] = None
predictions: Optional[torch.Tensor] = None
cum_probs: Optional[torch.Tensor] = None
class_probs: Optional[torch.Tensor] = None
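# Illustrative use of the fields above (hypothetical `model`, `batch`, `labels`):
#   out = model(**batch, labels=labels)
#   out.loss.backward()                        # training step
#   conf = out.class_probs.max(dim=-1).values  # confidence of each prediction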
# ---------------------------------------------------------------------------
# 2. Model – subclass PreTrainedModel for save / from_pretrained
# ---------------------------------------------------------------------------
class BertOrdinal(PreTrainedModel):
"""
BERT encoder with an ordinal-regression head.
Fully compatible with the HuggingFace checkpoint API::
model.save_pretrained("my-checkpoint/")
model = BertOrdinal.from_pretrained("my-checkpoint/")
What gets saved
~~~~~~~~~~~~~~~
``save_pretrained`` writes two files:
    * ``config.json`` – the full :class:`BertOrdinalConfig` (including
      ``bert_model_name``, ``hidden_size``, thresholds shape, …).
    * ``model.safetensors`` (or ``pytorch_model.bin``) – a **single flat
state_dict** containing both the BERT backbone weights and the
head/threshold parameters.
``from_pretrained`` reconstructs the model from the config (which
already has ``hidden_size`` cached), loads the state_dict, and
    re-applies the ``freeze_bert`` setting – no internet access needed
after the first save.
"""
config_class = BertOrdinalConfig
def __init__(self, config: BertOrdinalConfig) -> None:
super().__init__(config)
K = config.num_classes
        # ── 1. BERT backbone ────────────────────────────────────────────────
        # If hidden_size is already in the config (i.e. we are being called
        # from from_pretrained after a save), build the backbone from its
        # architecture config alone instead of re-downloading weights;
        # from_pretrained will overwrite with the saved state_dict anyway.
        if getattr(config, "hidden_size", None) is not None:
            self.bert = AutoModel.from_config(
                AutoConfig.from_pretrained(config.bert_model_name)
            )
        else:
            self.bert = AutoModel.from_pretrained(config.bert_model_name)
hidden_size: int = self.bert.config.hidden_size
# Cache so the head can be rebuilt offline after save_pretrained.
config.hidden_size = hidden_size
if config.freeze_bert:
for param in self.bert.parameters():
param.requires_grad = False
        # ── 2. Projection head ──────────────────────────────────────────────
self.head = nn.Sequential(
nn.Linear(hidden_size, config.hidden_dim),
nn.ReLU(),
nn.Dropout(config.dropout),
nn.Linear(config.hidden_dim, 1),
)
self._init_head()
        # ── 3. Ordinal thresholds ───────────────────────────────────────────
        # K-1 raw values; monotonicity enforced via cumsum(softplus(·)).
        self.raw_thresholds = nn.Parameter(torch.zeros(K - 1))
        with torch.no_grad():
            # Aim for roughly evenly spaced initial thresholds by inverting
            # the parameterisation: softplus⁻¹(x) = log(expm1(x)). The clamp
            # keeps every gap positive, since cumsum(softplus(·)) can only
            # produce increasing, strictly positive thresholds.
            targets = torch.linspace(-1.0, 1.0, K - 1)
            diffs = torch.cat([targets[:1], targets[1:] - targets[:-1]])
            self.raw_thresholds.copy_(
                torch.log(torch.expm1(diffs.clamp(min=1e-3)))
            )
# Finalises weight init bookkeeping required by PreTrainedModel.
self.post_init()
# -----------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------
def _init_head(self) -> None:
for m in self.head.modules():
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
nn.init.zeros_(m.bias)
@property
def thresholds(self) -> torch.Tensor:
"""Monotone thresholds ΞΈβ β€ β¦ β€ ΞΈ_{K-1} (shape: K-1)."""
return torch.cumsum(F.softplus(self.raw_thresholds), dim=0)
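    # Illustrative check of the parameterisation (not executed at import):
    #   raw = torch.tensor([-1.0, 0.5, 2.0])
    #   torch.cumsum(F.softplus(raw), 0)  # tensor([0.3133, 1.2874, 3.4143])
    # i.e. strictly increasing thresholds regardless of the raw values.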
# -----------------------------------------------------------------------
# Forward
# -----------------------------------------------------------------------
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
**kwargs,
) -> BertOrdinalOutput:
"""
Parameters
----------
input_ids : (B, L)
attention_mask : (B, L)
token_type_ids : (B, L) optional
        labels : (B,) long – class indices in {0, …, K-1}
Returns
-------
BertOrdinalOutput
"""
        # ── Encode ──────────────────────────────────────────────────────────
bert_kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
if token_type_ids is not None:
bert_kwargs["token_type_ids"] = token_type_ids
cls_repr = self.bert(**bert_kwargs).last_hidden_state[:, 0, :] # (B, H)
        # ── Latent score ────────────────────────────────────────────────────
score = self.head(cls_repr).squeeze(-1) # (B,)
        # ── Cumulative probs P(Y ≤ j) = σ(θ_j − score) ──────────────────────
cum_logits = self.thresholds.unsqueeze(0) - score.unsqueeze(1) # (B, K-1)
cum_probs = torch.sigmoid(cum_logits) # (B, K-1)
        # ── Class probs P(Y = j) = P(Y ≤ j) − P(Y ≤ j-1) ────────────────────
B, dev = cum_probs.size(0), cum_probs.device
F_ = torch.cat(
[torch.zeros(B, 1, device=dev), cum_probs, torch.ones(B, 1, device=dev)],
dim=1,
) # (B, K+1)
class_probs = (F_[:, 1:] - F_[:, :-1]).clamp(min=1e-9) # (B, K)
        # ── Predictions ─────────────────────────────────────────────────────
predictions = class_probs.argmax(dim=-1) # (B,)
        # ── Loss ────────────────────────────────────────────────────────────
loss: Optional[torch.Tensor] = None
if labels is not None:
loss = ordinal_cross_entropy(
class_probs, labels, reduction=self.config.loss_reduction
)
return BertOrdinalOutput(
loss=loss,
logits=score,
predictions=predictions,
cum_probs=cum_probs,
class_probs=class_probs,
)
# -----------------------------------------------------------------------
# Convenience
# -----------------------------------------------------------------------
@torch.no_grad()
def predict(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Return predicted class indices (no loss computed)."""
return self.forward(input_ids, attention_mask, token_type_ids).predictions
# ---------------------------------------------------------------------------
# Loss function
# ---------------------------------------------------------------------------
def ordinal_cross_entropy(
class_probs: torch.Tensor,
labels: torch.Tensor,
reduction: str = "mean",
) -> torch.Tensor:
"""
Ordinal cross-entropy.
Parameters
----------
    class_probs : (B, K) – P(Y=j|x), clamped > 0
    labels : (B,) – ground-truth indices in {0, …, K-1}
reduction : 'mean' | 'sum' | 'none'
"""
return F.nll_loss(torch.log(class_probs), labels, reduction=reduction)
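# ---------------------------------------------------------------------------
# Smoke test – illustrative only. Run from within the package (the relative
# config import above requires it); `bert-base-uncased` must be available
# locally or downloadable.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    cfg = BertOrdinalConfig(
        bert_model_name="bert-base-uncased",
        num_classes=3,
        hidden_dim=128,
        dropout=0.1,
        freeze_bert=True,
    )
    model = BertOrdinal(cfg).eval()
    tokenizer = AutoTokenizer.from_pretrained(cfg.bert_model_name)

    batch = tokenizer(
        ["great movie", "terrible movie"], padding=True, return_tensors="pt"
    )
    out = model(**batch, labels=torch.tensor([2, 0]))
    print("loss:", out.loss.item())                  # scalar ordinal CE loss
    print("predictions:", out.predictions.tolist())  # class indices in {0, 1, 2}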