"""
bert_ordinal.py
---------------
BERT-based ordinal regression model, fully integrated with the HuggingFace
Transformers API:

    model.save_pretrained("my-checkpoint/")
    model = BertOrdinal.from_pretrained("my-checkpoint/")

Architecture
------------
1. A (optionally frozen) BERT backbone.
2. A projection head on the [CLS] token:
       Linear(hidden_size → hidden_dim) → ReLU → Dropout(p) → Linear(hidden_dim → 1)
   producing a single latent score s ∈ ℝ.
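3. K-1 learnable ``raw_threshold`` parameters, mapped to monotonically
   increasing thresholds θ_1 < θ_2 < ... < θ_{K-1} via cumsum(softplus(·)).
4. Cumulative-link probabilities, for j = 1, ..., K-1:
       P(Y ≤ j | x) = σ(θ_j − s)
   Per-class probabilities follow as differences of consecutive cumulative
   probabilities.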
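
Steps 3 and 4 above can be sketched in plain Python for a single example.
This is a minimal illustration, not this module's implementation: the helper
names ``softplus``, ``sigmoid``, ``thresholds`` and ``class_probs`` are
hypothetical, the model itself presumably works on tensors, and per-class
probabilities are assumed to be differences of consecutive cumulative
probabilities (the standard cumulative-link construction).

```python
import math
from typing import List


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)), which is positive for every x.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x: float) -> float:
    # Logistic function, branching on sign to avoid overflow in math.exp.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def thresholds(raw: List[float]) -> List[float]:
    # cumsum(softplus(raw)): each increment is positive, so the K-1
    # thresholds increase monotonically whatever the raw values are.
    out, total = [], 0.0
    for r in raw:
        total += softplus(r)
        out.append(total)
    return out


def class_probs(score: float, raw: List[float]) -> List[float]:
    # Cumulative probabilities P(Y <= j | x) = sigmoid(theta_j - s) for
    # j = 1..K-1, padded with 0 below the first class and 1 above the last.
    cum = [0.0] + [sigmoid(t - score) for t in thresholds(raw)] + [1.0]
    # P(Y = j | x) = P(Y <= j | x) - P(Y <= j-1 | x), for j = 1..K.
    return [cum[j] - cum[j - 1] for j in range(1, len(cum))]


# K = 3 classes, hence 2 raw thresholds (illustrative values).
print(class_probs(score=0.0, raw=[0.0, 0.0]))
```

With both raw thresholds at 0 the thresholds are (ln 2, 2 ln 2), so a score
of s = 0 gives P(Y ≤ 1) = 2/3 and P(Y ≤ 2) = 0.8, i.e. class probabilities of
roughly (0.667, 0.133, 0.2).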

Usage
-----
    from transformers import AutoTokenizer
    from bert_ordinal import BertOrdinalConfig, BertOrdinal

    # ── Create from scratch ──────────────────────────────────────────────────
    cfg = BertOrdinalConfig(
        bert_model_name="bert-base-uncased",
        num_classes=3,
        hidden_dim=128,
        dropout=0.1,
        freeze_bert=True,
    )
    model = BertOrdinal(cfg)
    tokenizer = AutoTokenizer.from_pretrained(cfg.bert_model_name)

    # ── Save ────────────────────────────────────────────────────────────────
    model.save_pretrained("my-checkpoint/")
    tokenizer.save_pretrained("my-checkpoint/")   # keep tokenizer alongside

    # ── Reload ──────────────────────────────────────────────────────────────
    model     = BertOrdinal.from_pretrained("my-checkpoint/")
    tokenizer = AutoTokenizer.from_pretrained("my-checkpoint/")
"""

from __future__ import annotations
from typing import Optional
from transformers import PretrainedConfig


# ---------------------------------------------------------------------------
# 1. Config: subclass PretrainedConfig for full HF serialisation
# ---------------------------------------------------------------------------

class BertOrdinalConfig(PretrainedConfig):
    """
    Configuration for :class:`BertOrdinal`.

    Because this inherits from :class:`~transformers.PretrainedConfig`,
    ``save_pretrained`` writes a ``config.json`` that ``from_pretrained``
    can read back without any extra bookkeeping.

    Parameters
    ----------
    bert_model_name : str
        HuggingFace model name or local path for the BERT backbone.
    num_classes : int
        Number of ordinal classes K.  Creates K-1 learnable thresholds.
    hidden_dim : int
        Inner dimension of the projection head.
    dropout : float
        Dropout probability inside the projection head.
    freeze_bert : bool
        Freeze backbone weights at construction time.
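    loss_reduction : str
        Reduction applied to the per-example losses: ``'mean'`` or ``'sum'``.
    hidden_size : int, optional
        Hidden size of the BERT backbone.  Leave as ``None``; it is filled in
        by :class:`BertOrdinal` after loading BERT and stored so that
        ``from_pretrained`` can rebuild the head offline.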
    """

    # Tells HF which class owns this config (written into config.json).
    model_type = "bert_ordinal"
    problem_type = "single_label_classification"

    def __init__(
        self,
        bert_model_name: str  = "allenai/scibert_scivocab_uncased",
        num_classes:     int  = 3,
        hidden_dim:      int  = 256,
        dropout:         float = 0.1,
        freeze_bert:     bool = True,
        loss_reduction:  str  = "mean",
        # hidden_size is set automatically by the model after loading BERT;
        # it is stored here so from_pretrained can rebuild the head offline.
        hidden_size:     Optional[int] = None,
        **kwargs,
    ) -> None:
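        # PretrainedConfig.__init__ sets self.problem_type from kwargs (None by
        # default), which would shadow the class attribute; default it here.
        kwargs.setdefault("problem_type", type(self).problem_type)
        super().__init__(**kwargs)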
        self.bert_model_name = bert_model_name
        self.num_classes     = num_classes
        self.hidden_dim      = hidden_dim
        self.dropout         = dropout
        self.freeze_bert     = freeze_bert
        self.loss_reduction  = loss_reduction
        self.hidden_size     = hidden_size   # filled in by BertOrdinal.__init__

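        # auto_map tells the Auto* classes which "module.Class" to import from
        # the checkpoint when loading with trust_remote_code=True; the module
        # names must match the .py files shipped alongside config.json.
        self.auto_map = {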
            "AutoConfig": "configuration_bert_ordinal.BertOrdinalConfig",
            "AutoModel": "modeling_bert_ordinal.BertOrdinal",
            "AutoModelForSequenceClassification": "modeling_bert_ordinal.BertOrdinal",
        }