"""Self-contained model class for binomial-shannon-2. Distributed alongside the weights on HuggingFace Hub so anyone can do: from transformers import AutoTokenizer, AutoModel tok = AutoTokenizer.from_pretrained("BinomialTechnologies/binomial-shannon-2") model = AutoModel.from_pretrained("BinomialTechnologies/binomial-shannon-2", trust_remote_code=True) Imports only from `transformers` + `torch` — no project-internal dependencies. Module names match the training checkpoint so weights load verbatim. Architecture: shared encoder ↓ (CLS + masked mean pool concatenated) ↓ 2-way router (ticker | macro) ↓ ticker head bank (19 outputs) + macro head bank (35 outputs) Router mode_prob over {ticker, macro} Ticker bank event(10) + tone + implied_direction + novelty + claim(4) + specificity + materiality (19, = shannon-1) Macro bank topic(18) + directional_read + severity(5) + novelty_macro(3) + claim_macro(4) + hawkish_dovish(5) (35) """ from __future__ import annotations from dataclasses import dataclass from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoConfig from transformers.modeling_utils import PreTrainedModel from transformers.modeling_outputs import ModelOutput from .configuration_shannon2 import ( Shannon2Config, EVENTS, CLAIM_TYPES, TOPICS, SEVERITY_BUCKETS, NOVELTY_BUCKETS_MACRO, CLAIM_TYPES_MACRO, HAWKISH_DOVISH_BUCKETS, ) N_EVENTS = len(EVENTS) N_CLAIMS_TICKER = len(CLAIM_TYPES) N_TOPICS = len(TOPICS) N_SEVERITY = len(SEVERITY_BUCKETS) N_NOVELTY_MACRO = len(NOVELTY_BUCKETS_MACRO) N_CLAIMS_MACRO = len(CLAIM_TYPES_MACRO) N_HD = len(HAWKISH_DOVISH_BUCKETS) MODE_TICKER = 0 MODE_MACRO = 1 @dataclass class Shannon2Output(ModelOutput): mode_logits: Optional[torch.Tensor] = None # ticker bank event_logits: Optional[torch.Tensor] = None tone: Optional[torch.Tensor] = None implied_direction: Optional[torch.Tensor] = None novelty: Optional[torch.Tensor] = None claim_logits: Optional[torch.Tensor] = None specificity: Optional[torch.Tensor] = None materiality_if_true: Optional[torch.Tensor] = None # macro bank topic_logits: Optional[torch.Tensor] = None directional_read: Optional[torch.Tensor] = None severity_logits: Optional[torch.Tensor] = None novelty_macro_logits: Optional[torch.Tensor] = None claim_macro_logits: Optional[torch.Tensor] = None hawkish_dovish_logits: Optional[torch.Tensor] = None class Shannon2MultiHead(PreTrainedModel): config_class = Shannon2Config base_model_prefix = "shannon2" _tied_weights_keys: list = [] all_tied_weights_keys: dict = {} def __init__(self, config: Shannon2Config) -> None: super().__init__(config) self.config = config # Rebuild the encoder from the bundled config so loading works offline. if hasattr(config, "encoder_config") and config.encoder_config: from transformers.models.auto.configuration_auto import CONFIG_MAPPING mtype = config.encoder_config.get("model_type") if mtype and mtype in CONFIG_MAPPING: enc_cfg = CONFIG_MAPPING[mtype].from_dict(config.encoder_config) else: enc_cfg = AutoConfig.from_pretrained(config.encoder_name_or_path) else: enc_cfg = AutoConfig.from_pretrained(config.encoder_name_or_path) if config.max_position_embeddings > getattr(enc_cfg, "max_position_embeddings", 8192): enc_cfg.max_position_embeddings = config.max_position_embeddings from transformers import AutoModel as _AutoModel self.encoder = _AutoModel.from_config(enc_cfg, attn_implementation="sdpa") H = enc_cfg.hidden_size head_in = 2 * H h1, h2 = config.head_h1, config.head_h2 d = config.dropout self.dropout = nn.Dropout(d) def _mlp(out_dim: int) -> nn.Sequential: return nn.Sequential( nn.Linear(head_in, h1), nn.GELU(), nn.Dropout(d), nn.Linear(h1, h2), nn.GELU(), nn.Dropout(d), nn.Linear(h2, out_dim), ) # Router self.head_router = _mlp(2) # Ticker bank self.head_event = _mlp(N_EVENTS) self.head_tone = _mlp(1) self.head_implied_direction = _mlp(1) self.head_novelty = _mlp(1) self.head_claim = _mlp(N_CLAIMS_TICKER) self.head_specificity = _mlp(1) self.head_materiality = _mlp(1) # Macro bank self.head_topic = _mlp(N_TOPICS) self.head_directional_read = _mlp(1) self.head_severity = _mlp(N_SEVERITY) self.head_novelty_macro = _mlp(N_NOVELTY_MACRO) self.head_claim_macro = _mlp(N_CLAIMS_MACRO) self.head_hawkish_dovish = _mlp(N_HD) def _pool(self, last_hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: cls = last_hidden[:, 0] m = attention_mask.unsqueeze(-1).to(last_hidden.dtype) mean_pool = (last_hidden * m).sum(1) / m.sum(1).clamp(min=1.0) return self.dropout(torch.cat([cls, mean_pool], dim=-1)) def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Shannon2Output: enc = self.encoder(input_ids=input_ids, attention_mask=attention_mask) pooled = self._pool(enc.last_hidden_state, attention_mask) return Shannon2Output( mode_logits=self.head_router(pooled), event_logits=self.head_event(pooled), tone=self.head_tone(pooled).squeeze(-1), implied_direction=self.head_implied_direction(pooled).squeeze(-1), novelty=self.head_novelty(pooled).squeeze(-1), claim_logits=self.head_claim(pooled), specificity=self.head_specificity(pooled).squeeze(-1), materiality_if_true=self.head_materiality(pooled).squeeze(-1), topic_logits=self.head_topic(pooled), directional_read=self.head_directional_read(pooled).squeeze(-1), severity_logits=self.head_severity(pooled), novelty_macro_logits=self.head_novelty_macro(pooled), claim_macro_logits=self.head_claim_macro(pooled), hawkish_dovish_logits=self.head_hawkish_dovish(pooled), ) @torch.no_grad() def predict(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, mention_threshold: float = 0.5) -> dict: out = self.forward(input_ids=input_ids, attention_mask=attention_mask) ev_prob = torch.sigmoid(out.event_logits) return { "mode_prob": F.softmax(out.mode_logits, dim=-1), # ticker "event_prob": ev_prob, "event_mentioned": (ev_prob >= mention_threshold).float(), "tone": out.tone.clamp(-1.0, 1.0), "implied_direction": out.implied_direction.clamp(-1.0, 1.0), "novelty": out.novelty.clamp(0.0, 1.0), "claim_prob": F.softmax(out.claim_logits, dim=-1), "specificity": out.specificity.clamp(0.0, 1.0), "materiality_if_true": out.materiality_if_true.clamp(0.0, 1.0), # macro "topic_prob": F.softmax(out.topic_logits, dim=-1), "directional_read": out.directional_read.clamp(-1.0, 1.0), "severity_prob": F.softmax(out.severity_logits, dim=-1), "novelty_macro_prob": F.softmax(out.novelty_macro_logits, dim=-1), "claim_macro_prob": F.softmax(out.claim_macro_logits, dim=-1), "hawkish_dovish_prob": F.softmax(out.hawkish_dovish_logits, dim=-1), } def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): if hasattr(self.encoder, "gradient_checkpointing_enable"): self.encoder.gradient_checkpointing_enable( gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) def gradient_checkpointing_disable(self): if hasattr(self.encoder, "gradient_checkpointing_disable"): self.encoder.gradient_checkpointing_disable()