binomial-marks-1 / configuration_marks.py
ilayibrahimzadeh's picture
Initial publish of binomial-marks-1
f7b715f verified
"""Configuration class for binomial-marks-1.

Distributed alongside the model on the Hugging Face Hub so
`AutoConfig.from_pretrained(repo, trust_remote_code=True)` works.

Defines the label vocabularies ``TOPICS`` and ``TONES`` and the
``MarksConfig`` class (``model_type = "marks"``).
"""
from __future__ import annotations

import copy

from transformers.configuration_utils import PretrainedConfig
# Topic label names, in order; MarksConfig uses this tuple as the default
# value of its ``topics`` argument.
TOPICS = (
    "guidance",
    "revenue_growth",
    "margins",
    "demand",
    "buybacks",
    "dividends",
    "m_and_a",
    "headcount",
    "macro_exposure",
    "competition",
)
# Tone label names, in order; MarksConfig uses this tuple as the default
# value of its ``tones`` argument.
TONES = (
    "mgmt_confidence",
    "mgmt_defensiveness",
    "analyst_skepticism",
)
class MarksConfig(PretrainedConfig):
    """Config for MarksMultiHead.

    Holds the head spec plus the underlying encoder's config
    (ModernBERT-large by default). The encoder config is kept as a nested
    attribute so HF tooling can serialize it together with the head settings.

    Args:
        encoder_name_or_path: Hub id or local path of the encoder.
        encoder_config: Encoder config values. ``None`` or empty becomes
            ``{}``; otherwise a deep copy is stored so later edits to this
            config do not leak back into the caller's object (and vice versa).
        max_position_embeddings: Maximum number of positions stored on the
            config.
        marks_rope_strategy: RoPE extension strategy; the documented values
            are ``"yarn"``, ``"ntk"`` and ``"none"``. Not validated here.
        original_max_position: Position count before extension (presumably
            the encoder's native context length -- confirm in modeling code).
        head_dim_ratio: Head hidden size is ``H // head_dim_ratio``
            (``H`` presumably being the encoder hidden size).
        dropout: Dropout probability.
        topic_score_range: Topic score range as a ``(low, high)`` pair.
        tone_score_range: Tone score range as a ``(low, high)`` pair.
        topics: Topic label names (defaults to ``TOPICS``).
        tones: Tone label names (defaults to ``TONES``).
        loss_weights: Per-loss weight mapping. ``None`` or empty falls back
            to the built-in defaults; otherwise a shallow copy is stored.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "marks"

    def __init__(
        self,
        encoder_name_or_path: str = "answerdotai/ModernBERT-large",
        encoder_config: dict | None = None,
        max_position_embeddings: int = 16384,
        # NOTE: named `marks_rope_strategy` (not `rope_scaling`) to avoid
        # collision with PretrainedConfig.rope_scaling which transformers
        # tries to validate as a dict shape.
        marks_rope_strategy: str = "yarn",  # "yarn" | "ntk" | "none"
        original_max_position: int = 8192,
        head_dim_ratio: int = 4,  # head hidden = H // ratio
        dropout: float = 0.1,
        topic_score_range: tuple[float, float] = (-2.0, 2.0),
        tone_score_range: tuple[float, float] = (1.0, 5.0),
        topics: tuple[str, ...] = TOPICS,
        tones: tuple[str, ...] = TONES,
        loss_weights: dict | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.encoder_name_or_path = encoder_name_or_path
        # Deep copy: the caller's (possibly nested) config must not alias
        # ours. Falsy input keeps the original fallback to an empty dict.
        self.encoder_config = (
            copy.deepcopy(encoder_config) if encoder_config else {}
        )
        self.max_position_embeddings = max_position_embeddings
        self.marks_rope_strategy = marks_rope_strategy
        self.original_max_position = original_max_position
        self.head_dim_ratio = head_dim_ratio
        self.dropout = dropout
        # Sequences are stored as lists, presumably so the in-memory config
        # matches what a JSON round trip produces (JSON has no tuple type).
        self.topic_score_range = list(topic_score_range)
        self.tone_score_range = list(tone_score_range)
        self.topics = list(topics)
        self.tones = list(tones)
        # Copy so mutating config.loss_weights never edits a caller-owned
        # (possibly shared) dict; falsy input selects the defaults.
        self.loss_weights = (
            dict(loss_weights)
            if loss_weights
            else {
                "topic_mentioned": 0.5,
                "topic_score": 1.5,
                "tone_scores": 0.2,
            }
        )