from dataclasses import dataclass, field @dataclass(frozen=True) class LabelRemap: """Maps non-standard label strings to canonical equivalents.""" mapping: dict[str, str] = field( default_factory=lambda: {"very positive": "positive"} ) @dataclass(frozen=True) class EntityTypes: """Accepted entity type values.""" types: tuple[str, ...] = ("company", "location") @dataclass(frozen=True) class SentimentLabels: """3-class sentiment label schema.""" classes: tuple[str, ...] = ("negative", "neutral", "positive") label2id: dict[str, int] = field(default_factory=dict) id2label: dict[int, str] = field(default_factory=dict) def __post_init__(self): object.__setattr__( self, "label2id", {s: i for i, s in enumerate(self.classes)} ) object.__setattr__( self, "id2label", {i: s for i, s in enumerate(self.classes)} ) @property def num_labels(self) -> int: return len(self.classes) @dataclass(frozen=True) class BinaryLabels: """Binary (yes / no) label schema used by qa_b mode.""" classes: tuple[str, ...] = ("no", "yes") label2id: dict[str, int] = field(default_factory=dict) id2label: dict[int, str] = field(default_factory=dict) def __post_init__(self): object.__setattr__( self, "label2id", {s: i for i, s in enumerate(self.classes)} ) object.__setattr__( self, "id2label", {i: s for i, s in enumerate(self.classes)} ) @property def num_labels(self) -> int: return len(self.classes) # ── Training modes ──────────────────────────────────────────────────────────── @dataclass(frozen=True) class MarkerMode: """Single-sequence mode: wrap entity with [E]...[/E] special tokens. Seg A: "[E] Google [/E] had strong earnings but Microsoft missed." Label: 3-way (negative=0, neutral=1, positive=2) """ name: str = "marker" entity_start: str = "[E]" entity_end: str = "[/E]" labels: SentimentLabels = field(default_factory=SentimentLabels) @dataclass(frozen=True) class QaMMode: """Sentence-pair QA-M mode (Sun et al. 2019). Seg A: "Google had strong earnings but Microsoft missed." Seg B: "What do you think of the sentiment of the company Google ?" Label: 3-way (negative=0, neutral=1, positive=2) """ name: str = "qa_m" question_template: str = "What do you think of the sentiment of the {entity_type} {entity} ?" labels: SentimentLabels = field(default_factory=SentimentLabels) @dataclass(frozen=True) class QaBMode: """Sentence-pair QA-B mode (Sun et al. 2019), binary. Seg A: "Google had strong earnings but Microsoft missed." Seg B: "The polarity of the company Google is positive ." Label: binary (no=0, yes=1) Three forward passes per entity at inference; highest P(yes) wins. """ name: str = "qa_b" hypothesis_template: str = "The polarity of the {entity_type} {entity} is {sentiment} ." labels: BinaryLabels = field(default_factory=BinaryLabels) sentiment_labels: SentimentLabels = field(default_factory=SentimentLabels) SENTIMENT_LABELS = SentimentLabels() BINARY_LABELS = BinaryLabels() LABEL_REMAP = LabelRemap() ENTITY_TYPES = EntityTypes() MARKER_MODE = MarkerMode() QA_M_MODE = QaMMode() QA_B_MODE = QaBMode() MODES: dict[str, MarkerMode | QaMMode | QaBMode] = { "marker": MARKER_MODE, "qa_m": QA_M_MODE, "qa_b": QA_B_MODE, }