"""Self-contained model class for binomial-marks-1.

Distributed alongside the weights on HuggingFace Hub so anyone can do:

    from transformers import AutoTokenizer, AutoModel
    tok   = AutoTokenizer.from_pretrained("BinomialTechnologies/binomial-marks-1")
    model = AutoModel.from_pretrained("BinomialTechnologies/binomial-marks-1",
                                       trust_remote_code=True)

This file imports only from `transformers` + `torch` -- no project-internal
dependencies.

Architecture:
    ModernBERT-large encoder (with optional YaRN RoPE extension to 16k)
        ↓ (CLS + masked mean pool concatenated)
        ↓ (3 × MLP heads)
    23 outputs:
        10 × topic_mentioned (binary classification, sigmoid → BCE loss)
        10 × topic_score     (regression in [-2, +2] after clamp at inference)
         3 × tone_score      (regression in [1, 5] after clamp at inference)
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import ModelOutput

# Relative import -- HF's `trust_remote_code` loader bundles sibling .py
# files together and resolves these without the symbol being "installed".
from .configuration_marks import MarksConfig, TOPICS, TONES


# ---------------------------------------------------------------------------
# YaRN RoPE extension -- per-dim ramp; applied after model load
# ---------------------------------------------------------------------------

def _yarn_inv_freq(
    head_dim: int,
    base: float,
    scale: float,
    original_max_position: int,
    beta_fast: float = 32.0,
    beta_slow: float = 1.0,
    device=None,
    dtype=torch.float32,
) -> torch.Tensor:
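    """Compute per-dimension inverse RoPE frequencies with the YaRN ramp.

    Dimensions that complete many rotations within the original context window
    keep the extrapolated (original) frequencies; slow-rotating dimensions are
    interpolated (divided by ``scale``); the ramp blends linearly in between.
    With ``scale <= 1`` the original frequencies are returned unchanged.
    """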
    if scale <= 1.0:
        return 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device, dtype=dtype) / head_dim))
    inv_freq_extrap = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device, dtype=dtype) / head_dim))
    inv_freq_interp = inv_freq_extrap / scale
    wavelengths = 2.0 * math.pi / inv_freq_extrap
    L = original_max_position
    ramp = (L / wavelengths - beta_slow) / (beta_fast - beta_slow)
    ramp = ramp.clamp(0.0, 1.0)
    return inv_freq_interp * (1.0 - ramp) + inv_freq_extrap * ramp


def _apply_yarn_to_modernbert(encoder, new_max_position: int,
                               original_max_position: int = 8192,
                               beta_fast: float = 32.0, beta_slow: float = 1.0):
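    """Rewrite ModernBERT's rotary inv-freq buffers in place with the YaRN schedule.

    Only rotary modules exposing a ``full_attention_inv_freq`` buffer of the
    expected size (the global-attention layers) are touched; anything else is
    left as-is. No-op when ``new_max_position`` equals the native context.
    """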
    if new_max_position == original_max_position:
        return
    scale = new_max_position / original_max_position
    cfg = encoder.config
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    global_base = float(getattr(cfg, "global_rope_theta", getattr(cfg, "rope_theta", 10000.0)))

    rotary_modules = [
        m for _, m in encoder.named_modules()
        if m.__class__.__name__ == "ModernBertRotaryEmbedding"
    ]
    for mod in rotary_modules:
        full_buf = getattr(mod, "full_attention_inv_freq", None)
        if full_buf is None or full_buf.numel() != head_dim // 2:
            continue
        new_inv = _yarn_inv_freq(
            head_dim=head_dim, base=global_base, scale=scale,
            original_max_position=original_max_position,
            beta_fast=beta_fast, beta_slow=beta_slow,
            device=full_buf.device, dtype=full_buf.dtype,
        )
        full_buf.data.copy_(new_inv)


# ---------------------------------------------------------------------------
# Output dataclass
# ---------------------------------------------------------------------------

@dataclass
class MarksOutput(ModelOutput):
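    """Forward-pass outputs; the loss fields are populated only when labels are given."""
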
    loss: Optional[torch.Tensor] = None
    loss_components: Optional[dict] = None
    topic_mentioned_logits: Optional[torch.Tensor] = None   # (B, 10)
    topic_score: Optional[torch.Tensor] = None              # (B, 10)
    tone_score: Optional[torch.Tensor] = None               # (B,  3)


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------

class MarksMultiHead(PreTrainedModel):
    """Multi-head ModernBERT-large fine-tuned for earnings-call NLP scoring.

    23 outputs per call:
      * topic_mentioned (binary, 10 dims)
      * topic_score     (regression in [-2, +2], 10 dims)
      * tone_score      (regression in [1, 5], 3 dims)
    """

    config_class = MarksConfig
    base_model_prefix = "encoder"
    supports_gradient_checkpointing = True

    def __init__(self, config: MarksConfig):
        super().__init__(config)
        self.n_topics = len(config.topics)
        self.n_tones  = len(config.tones)

        # Encoder -- built from config (so we don't redownload base weights;
        # weights come from this repo's safetensors).
        if config.encoder_config:
            if hasattr(AutoConfig, "from_dict"):
                enc_cfg = AutoConfig.from_dict(config.encoder_config)
            else:
                enc_cfg = AutoConfig.for_model(**config.encoder_config)
        else:
            enc_cfg = AutoConfig.from_pretrained(config.encoder_name_or_path)

        # Override the encoder ctx to the trained value (16384 for our v1).
        enc_cfg.max_position_embeddings = config.max_position_embeddings

        # Initialize encoder with config-only constructor (random init); the
        # PreTrainedModel.from_pretrained caller will restore real weights
        # from this repo's safetensors.
        self.encoder = AutoModel.from_config(enc_cfg)
        H = enc_cfg.hidden_size

        # Head input is CLS + mean pool concatenated → 2H.
        head_in = 2 * H
        head_hidden = H // config.head_dim_ratio

        def _mlp(out_dim: int) -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(head_in, head_hidden),
                nn.GELU(),
                nn.Dropout(config.dropout),
                nn.Linear(head_hidden, out_dim),
            )

        self.dropout = nn.Dropout(config.dropout)
        self.head_topic_mentioned = _mlp(self.n_topics)
        self.head_topic_score     = _mlp(self.n_topics)
        self.head_tone_score      = _mlp(self.n_tones)

        # Loss weights (used only if labels are passed for fine-tuning).
        self._loss_weights = config.loss_weights

        # Apply YaRN to encoder (idempotent if max_position == native).
        if config.marks_rope_strategy == "yarn":
            _apply_yarn_to_modernbert(
                self.encoder,
                new_max_position=config.max_position_embeddings,
                original_max_position=config.original_max_position,
            )
        # NTK is applied inside encoder config; nothing to do here.

        self.post_init()

    # -------------------------------------------------------------------------
    # Forward
    # -------------------------------------------------------------------------
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        topic_mentioned: Optional[torch.Tensor] = None,
        topic_score:     Optional[torch.Tensor] = None,
        tone_score:      Optional[torch.Tensor] = None,
        **kwargs,
    ) -> MarksOutput:
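        """Encode the input and score all heads.

        ``topic_mentioned``, ``topic_score`` and ``tone_score`` are optional
        label tensors; when provided (fine-tuning), the weighted multi-task
        loss is computed and returned alongside the raw predictions.
        """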

        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = out.last_hidden_state                      # (B, T, H)

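        # Pool: [CLS] token state concatenated with the attention-masked mean
        # over token states.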
        cls = last_hidden[:, 0]                                  # (B, H)
        m = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
        mean_pool = (last_hidden * m).sum(1) / m.sum(1).clamp(min=1.0)  # (B, H)
        pooled = self.dropout(torch.cat([cls, mean_pool], dim=-1))      # (B, 2H)

        tm_logits = self.head_topic_mentioned(pooled)
        ts_pred   = self.head_topic_score(pooled)
        tn_pred   = self.head_tone_score(pooled)

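        # Weighted multi-task loss, computed in fp32 and only when label
        # tensors are supplied (fine-tuning); plain inference skips this block.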
        loss, components = None, {}
        if (topic_mentioned is not None
                and topic_score is not None
                and tone_score is not None):
            tm_logits_fp = tm_logits.float()
            ts_pred_fp = ts_pred.float()
            tn_pred_fp = tn_pred.float()
            tm_t = topic_mentioned.float()
            ts_t = topic_score.float()
            tn_t = tone_score.float()

            l_tm = F.binary_cross_entropy_with_logits(tm_logits_fp, tm_t)
            l_ts = F.mse_loss(ts_pred_fp, ts_t)
            l_tn = F.mse_loss(tn_pred_fp, tn_t)
            components = {
                "topic_mentioned": l_tm.detach(),
                "topic_score":     l_ts.detach(),
                "tone_scores":     l_tn.detach(),
            }
            w = self._loss_weights
            loss = (
                w["topic_mentioned"] * l_tm
                + w["topic_score"]   * l_ts
                + w["tone_scores"]   * l_tn
            )

        return MarksOutput(
            loss=loss,
            loss_components=components or None,
            topic_mentioned_logits=tm_logits,
            topic_score=ts_pred,
            tone_score=tn_pred,
        )

    # -------------------------------------------------------------------------
    # Convenience predict
    # -------------------------------------------------------------------------
    @torch.no_grad()
    def predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        mention_threshold: float = 0.5,
    ) -> dict:
        """Run a forward pass and return clamped + masked predictions.

        Returns a dict with:
          topic_mentioned       (B, 10) hard 0/1
          topic_mentioned_prob  (B, 10) sigmoid confidence
          topic_score           (B, 10) clamped to [-2, +2], zeroed where mentioned=0
          tone_score            (B,  3) clamped to [1, 5]
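
        Illustrative usage (a sketch; assumes the tokenizer shipped in this
        repo, loaded via ``AutoTokenizer.from_pretrained``)::

            batch = tok(texts, return_tensors="pt", padding=True, truncation=True)
            preds = model.predict(batch["input_ids"], batch["attention_mask"])
            preds["topic_score"]   # (B, 10); zero where the topic is not mentioned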
        """
        out = self.forward(input_ids=input_ids, attention_mask=attention_mask)
        prob = torch.sigmoid(out.topic_mentioned_logits)
        mentioned = (prob >= mention_threshold).float()
        ts_lo, ts_hi = self.config.topic_score_range
        tn_lo, tn_hi = self.config.tone_score_range
        ts = out.topic_score.clamp(ts_lo, ts_hi) * mentioned
        tn = out.tone_score.clamp(tn_lo, tn_hi)
        return {
            "topic_mentioned":      mentioned,
            "topic_mentioned_prob": prob,
            "topic_score":          ts,
            "tone_score":           tn,
        }

    # -------------------------------------------------------------------------
    # Gradient checkpointing -- delegate to encoder
    # -------------------------------------------------------------------------
    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        if hasattr(self.encoder, "gradient_checkpointing_enable"):
            self.encoder.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
            )

    def gradient_checkpointing_disable(self):
        if hasattr(self.encoder, "gradient_checkpointing_disable"):
            self.encoder.gradient_checkpointing_disable()