# binomial-shannon-2 / modeling_shannon2.py
# ilayibrahimzadeh's picture
# Publish binomial-shannon v2
# c24924a verified
"""Self-contained model class for binomial-shannon-2.
Distributed alongside the weights on the HuggingFace Hub so anyone can do:

    from transformers import AutoTokenizer, AutoModel
    tok = AutoTokenizer.from_pretrained("BinomialTechnologies/binomial-shannon-2")
    model = AutoModel.from_pretrained("BinomialTechnologies/binomial-shannon-2",
                                      trust_remote_code=True)

Imports only from `transformers` + `torch` — no project-internal dependencies.
Module names match the training checkpoint so weights load verbatim.

Architecture:

    shared encoder
        ↓  (CLS + masked mean pool concatenated)
        ↓  2-way router (ticker | macro)
        ↓  ticker head bank (19 outputs) + macro head bank (35 outputs)

    Router       mode_prob over {ticker, macro}
    Ticker bank  event(10) + tone + implied_direction + novelty
                 + claim(4) + specificity + materiality  (19, = shannon-1)
    Macro bank   topic(18) + directional_read + severity(5)
                 + novelty_macro(3) + claim_macro(4) + hawkish_dovish(5)  (35)
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from .configuration_shannon2 import (
Shannon2Config, EVENTS, CLAIM_TYPES, TOPICS, SEVERITY_BUCKETS,
NOVELTY_BUCKETS_MACRO, CLAIM_TYPES_MACRO, HAWKISH_DOVISH_BUCKETS,
)
# Output widths of the categorical heads, measured from the label lists
# imported from the configuration module.

# Ticker bank.
N_EVENTS, N_CLAIMS_TICKER = len(EVENTS), len(CLAIM_TYPES)

# Macro bank.
N_TOPICS, N_SEVERITY = len(TOPICS), len(SEVERITY_BUCKETS)
N_NOVELTY_MACRO, N_CLAIMS_MACRO = len(NOVELTY_BUCKETS_MACRO), len(CLAIM_TYPES_MACRO)
N_HD = len(HAWKISH_DOVISH_BUCKETS)

# Integer ids for the two modes.
# NOTE(review): presumably these index the router's 2-way output — confirm
# against the training code.
MODE_TICKER, MODE_MACRO = 0, 1
@dataclass
class Shannon2Output(ModelOutput):
    """Raw outputs of every head, as returned by ``Shannon2MultiHead.forward``.

    Fields ending in ``_logits`` hold the unnormalised scores of a
    multi-output head. The remaining fields come from single-output heads
    whose trailing dimension is squeezed away in ``forward``. No sigmoid,
    softmax or clamping has been applied to any field; that post-processing
    happens in ``Shannon2MultiHead.predict``.
    """
    # Router head: two scores per example.
    mode_logits: Optional[torch.Tensor] = None
    # ticker bank
    # Scored independently per event (``predict`` applies a sigmoid).
    event_logits: Optional[torch.Tensor] = None
    # Single-output heads; ``predict`` clamps tone and implied_direction to
    # [-1, 1] and novelty to [0, 1].
    tone: Optional[torch.Tensor] = None
    implied_direction: Optional[torch.Tensor] = None
    novelty: Optional[torch.Tensor] = None
    # Scored jointly (``predict`` applies a softmax).
    claim_logits: Optional[torch.Tensor] = None
    # Single-output heads; ``predict`` clamps both to [0, 1].
    specificity: Optional[torch.Tensor] = None
    materiality_if_true: Optional[torch.Tensor] = None
    # macro bank
    # All ``*_logits`` below are turned into probabilities with a softmax in
    # ``predict``; directional_read is a single-output head clamped to [-1, 1].
    topic_logits: Optional[torch.Tensor] = None
    directional_read: Optional[torch.Tensor] = None
    severity_logits: Optional[torch.Tensor] = None
    novelty_macro_logits: Optional[torch.Tensor] = None
    claim_macro_logits: Optional[torch.Tensor] = None
    hawkish_dovish_logits: Optional[torch.Tensor] = None
class Shannon2MultiHead(PreTrainedModel):
    """Shared encoder feeding a 2-way router head and two banks of MLP heads.

    The encoder's last hidden state is pooled into one vector per example
    (CLS token concatenated with the attention-masked mean). That vector is
    fed to every head on every call: the router, the ticker bank and the
    macro bank are all evaluated unconditionally, and the router's logits are
    simply returned next to the other outputs. No head is gated here.

    Attribute names (``encoder``, ``head_*``) are part of the checkpoint
    layout and must not be renamed.
    """

    config_class = Shannon2Config
    base_model_prefix = "shannon2"
    # NOTE(review): empty tied-weight declarations; presumably present so the
    # weight-tying bookkeeping of different transformers releases finds the
    # attribute it expects — confirm before removing either one.
    _tied_weights_keys: list = []
    all_tied_weights_keys: dict = {}

    def __init__(self, config: Shannon2Config) -> None:
        """Build the encoder skeleton and all heads from ``config``.

        Args:
            config: Model configuration. Reads ``encoder_config`` (optional
                dict), ``encoder_name_or_path``, ``max_position_embeddings``,
                ``head_h1``, ``head_h2`` and ``dropout``.
        """
        super().__init__(config)
        self.config = config

        # Rebuild the encoder from the bundled config so loading works offline.
        # Only when no usable bundled config exists do we fall back to
        # resolving ``encoder_name_or_path`` through AutoConfig.
        enc_cfg = None
        bundled = getattr(config, "encoder_config", None)
        if bundled:
            from transformers.models.auto.configuration_auto import CONFIG_MAPPING
            mtype = bundled.get("model_type")
            if mtype and mtype in CONFIG_MAPPING:
                enc_cfg = CONFIG_MAPPING[mtype].from_dict(bundled)
        if enc_cfg is None:
            enc_cfg = AutoConfig.from_pretrained(config.encoder_name_or_path)

        # Widen the encoder's position limit when this model's config asks for
        # more than the encoder config provides (8192 is assumed when the
        # encoder config carries no such field).
        if config.max_position_embeddings > getattr(enc_cfg, "max_position_embeddings", 8192):
            enc_cfg.max_position_embeddings = config.max_position_embeddings

        from transformers import AutoModel as _AutoModel
        self.encoder = _AutoModel.from_config(enc_cfg, attn_implementation="sdpa")

        H = enc_cfg.hidden_size
        head_in = 2 * H  # _pool concatenates the CLS vector with the mean pool
        h1, h2 = config.head_h1, config.head_h2
        d = config.dropout
        self.dropout = nn.Dropout(d)

        def _mlp(out_dim: int) -> nn.Sequential:
            """Return a 3-layer MLP head mapping ``head_in`` to ``out_dim``."""
            return nn.Sequential(
                nn.Linear(head_in, h1), nn.GELU(), nn.Dropout(d),
                nn.Linear(h1, h2), nn.GELU(), nn.Dropout(d),
                nn.Linear(h2, out_dim),
            )

        # Router
        self.head_router = _mlp(2)
        # Ticker bank
        self.head_event = _mlp(N_EVENTS)
        self.head_tone = _mlp(1)
        self.head_implied_direction = _mlp(1)
        self.head_novelty = _mlp(1)
        self.head_claim = _mlp(N_CLAIMS_TICKER)
        self.head_specificity = _mlp(1)
        self.head_materiality = _mlp(1)
        # Macro bank
        self.head_topic = _mlp(N_TOPICS)
        self.head_directional_read = _mlp(1)
        self.head_severity = _mlp(N_SEVERITY)
        self.head_novelty_macro = _mlp(N_NOVELTY_MACRO)
        self.head_claim_macro = _mlp(N_CLAIMS_MACRO)
        self.head_hawkish_dovish = _mlp(N_HD)

    def _pool(self, last_hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Pool token states into one vector per example.

        Args:
            last_hidden: Encoder output indexed as (batch, token, hidden).
            attention_mask: Per-token mask indexed as (batch, token); tokens
                with a zero entry are excluded from the mean.

        Returns:
            Dropout applied to ``[CLS ; masked mean]``, i.e. a tensor whose
            last dimension is twice the hidden size.
        """
        cls = last_hidden[:, 0]
        m = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
        # Clamp the token count so a fully-masked row yields zeros, not NaN.
        mean_pool = (last_hidden * m).sum(1) / m.sum(1).clamp(min=1.0)
        return self.dropout(torch.cat([cls, mean_pool], dim=-1))

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                **kwargs) -> Shannon2Output:
        """Run the encoder and every head.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Per-token mask. When omitted, every position is
                treated as a real token (an all-ones mask is used for both the
                encoder and the mean pool).
            **kwargs: Accepted and ignored, so callers may pass extra
                tokenizer outputs without error.

        Returns:
            A ``Shannon2Output`` with raw head outputs. Single-output heads
            have their trailing dimension squeezed away.
        """
        if attention_mask is None:
            # _pool needs a concrete mask, so materialise the implicit one.
            attention_mask = torch.ones_like(input_ids)
        enc = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self._pool(enc.last_hidden_state, attention_mask)
        return Shannon2Output(
            mode_logits=self.head_router(pooled),
            event_logits=self.head_event(pooled),
            tone=self.head_tone(pooled).squeeze(-1),
            implied_direction=self.head_implied_direction(pooled).squeeze(-1),
            novelty=self.head_novelty(pooled).squeeze(-1),
            claim_logits=self.head_claim(pooled),
            specificity=self.head_specificity(pooled).squeeze(-1),
            materiality_if_true=self.head_materiality(pooled).squeeze(-1),
            topic_logits=self.head_topic(pooled),
            directional_read=self.head_directional_read(pooled).squeeze(-1),
            severity_logits=self.head_severity(pooled),
            novelty_macro_logits=self.head_novelty_macro(pooled),
            claim_macro_logits=self.head_claim_macro(pooled),
            hawkish_dovish_logits=self.head_hawkish_dovish(pooled),
        )

    @torch.no_grad()
    def predict(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                mention_threshold: float = 0.5) -> dict:
        """Run a forward pass without gradients and post-process every head.

        This method does not switch the module to eval mode; call
        ``model.eval()`` first, otherwise the dropout layers stay active.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Per-token mask; see ``forward`` for the default.
            mention_threshold: An event counts as mentioned when its sigmoid
                probability is greater than or equal to this value.

        Returns:
            A dict of tensors: softmax probabilities for the router and the
            categorical heads, sigmoid probabilities plus a 0/1 float
            indicator for events, and clamped values for the single-output
            heads (``tone``, ``implied_direction`` and ``directional_read``
            to [-1, 1]; ``novelty``, ``specificity`` and
            ``materiality_if_true`` to [0, 1]).
        """
        # Call the module itself rather than ``self.forward`` so that any
        # registered forward hooks and pre-hooks run.
        out = self(input_ids=input_ids, attention_mask=attention_mask)
        ev_prob = torch.sigmoid(out.event_logits)
        return {
            "mode_prob": F.softmax(out.mode_logits, dim=-1),
            # ticker
            "event_prob": ev_prob,
            "event_mentioned": (ev_prob >= mention_threshold).float(),
            "tone": out.tone.clamp(-1.0, 1.0),
            "implied_direction": out.implied_direction.clamp(-1.0, 1.0),
            "novelty": out.novelty.clamp(0.0, 1.0),
            "claim_prob": F.softmax(out.claim_logits, dim=-1),
            "specificity": out.specificity.clamp(0.0, 1.0),
            "materiality_if_true": out.materiality_if_true.clamp(0.0, 1.0),
            # macro
            "topic_prob": F.softmax(out.topic_logits, dim=-1),
            "directional_read": out.directional_read.clamp(-1.0, 1.0),
            "severity_prob": F.softmax(out.severity_logits, dim=-1),
            "novelty_macro_prob": F.softmax(out.novelty_macro_logits, dim=-1),
            "claim_macro_prob": F.softmax(out.claim_macro_logits, dim=-1),
            "hawkish_dovish_prob": F.softmax(out.hawkish_dovish_logits, dim=-1),
        }

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Forward the request to the encoder, if it supports checkpointing.

        The heads are never checkpointed; the call is a no-op when the
        encoder does not expose ``gradient_checkpointing_enable``.
        """
        if hasattr(self.encoder, "gradient_checkpointing_enable"):
            self.encoder.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    def gradient_checkpointing_disable(self):
        """Forward the request to the encoder, if it supports checkpointing."""
        if hasattr(self.encoder, "gradient_checkpointing_disable"):
            self.encoder.gradient_checkpointing_disable()