| from dataclasses import dataclass, field |
|
|
|
|
def _default_label_remap() -> dict[str, str]:
    """Return a fresh copy of the default remap table."""
    return {"very positive": "positive"}


@dataclass(frozen=True)
class LabelRemap:
    """Lookup table that rewrites non-standard label strings to canonical ones.

    Attributes:
        mapping: Non-standard label -> canonical label. By default it holds
            a single entry sending "very positive" to "positive". Each
            instance receives its own dict from the default factory.
    """

    # ``frozen=True`` only blocks rebinding the attribute; the dict object
    # itself stays mutable.
    mapping: dict[str, str] = field(default_factory=_default_label_remap)
|
|
|
|
@dataclass(frozen=True)
class EntityTypes:
    """Container for the entity type values that are accepted.

    Attributes:
        types: Accepted entity type names; defaults to "company" and
            "location".
    """

    types: tuple[str, ...] = (
        "company",
        "location",
    )
|
|
|
|
@dataclass(frozen=True)
class SentimentLabels:
    """3-class sentiment label schema.

    Attributes:
        classes: Label names in id order; a name's position is its id.
        label2id: Name -> id lookup, rebuilt from ``classes`` on construction.
        id2label: Id -> name lookup, rebuilt from ``classes`` on construction.

    ``label2id`` and ``id2label`` are still accepted by the constructor, but
    any value passed for them is overwritten in ``__post_init__``.
    """

    classes: tuple[str, ...] = ("negative", "neutral", "positive")
    # The two lookup dicts are derived from ``classes``. They are excluded
    # from the generated ``__hash__`` because dicts are unhashable and would
    # otherwise make ``hash(instance)`` raise TypeError. Equality still
    # compares them.
    label2id: dict[str, int] = field(default_factory=dict, hash=False)
    id2label: dict[int, str] = field(default_factory=dict, hash=False)

    def __post_init__(self) -> None:
        """Derive both lookup tables from ``classes``."""
        name_to_id = {name: idx for idx, name in enumerate(self.classes)}
        id_to_name = dict(enumerate(self.classes))
        # The dataclass is frozen, so normal attribute assignment raises;
        # object.__setattr__ bypasses the frozen guard during initialisation.
        object.__setattr__(self, "label2id", name_to_id)
        object.__setattr__(self, "id2label", id_to_name)

    @property
    def num_labels(self) -> int:
        """Number of label classes in the schema."""
        return len(self.classes)
|
|
|
|
@dataclass(frozen=True)
class BinaryLabels:
    """Binary (yes / no) label schema used by qa_b mode.

    Attributes:
        classes: Label names in id order; a name's position is its id.
        label2id: Name -> id lookup, rebuilt from ``classes`` on construction.
        id2label: Id -> name lookup, rebuilt from ``classes`` on construction.

    ``label2id`` and ``id2label`` are still accepted by the constructor, but
    any value passed for them is overwritten in ``__post_init__``.
    """

    classes: tuple[str, ...] = ("no", "yes")
    # The two lookup dicts are derived from ``classes``. They are excluded
    # from the generated ``__hash__`` because dicts are unhashable and would
    # otherwise make ``hash(instance)`` raise TypeError. Equality still
    # compares them.
    label2id: dict[str, int] = field(default_factory=dict, hash=False)
    id2label: dict[int, str] = field(default_factory=dict, hash=False)

    def __post_init__(self) -> None:
        """Derive both lookup tables from ``classes``."""
        name_to_id = {name: idx for idx, name in enumerate(self.classes)}
        id_to_name = dict(enumerate(self.classes))
        # The dataclass is frozen, so normal attribute assignment raises;
        # object.__setattr__ bypasses the frozen guard during initialisation.
        object.__setattr__(self, "label2id", name_to_id)
        object.__setattr__(self, "id2label", id_to_name)

    @property
    def num_labels(self) -> int:
        """Number of label classes in the schema."""
        return len(self.classes)
|
|
|
|
| |
|
|
@dataclass(frozen=True)
class MarkerMode:
    """Single-sequence mode that wraps the entity in [E]...[/E] special tokens.

    Example:
        Seg A: "[E] Google [/E] had strong earnings but Microsoft missed."
        Label: 3-way (negative=0, neutral=1, positive=2)

    Attributes:
        name: Mode identifier, "marker".
        entity_start: Token placed before the entity mention.
        entity_end: Token placed after the entity mention.
        labels: Label schema for this mode; a fresh ``SentimentLabels`` by
            default.
    """

    name: str = "marker"
    entity_start: str = "[E]"
    entity_end: str = "[/E]"
    labels: SentimentLabels = field(default_factory=SentimentLabels)
|
|
|
|
@dataclass(frozen=True)
class QaMMode:
    """Sentence-pair QA-M mode (Sun et al. 2019).

    Example:
        Seg A: "Google had strong earnings but Microsoft missed."
        Seg B: "What do you think of the sentiment of the company Google ?"
        Label: 3-way (negative=0, neutral=1, positive=2)

    Attributes:
        name: Mode identifier, "qa_m".
        question_template: Format string for segment B with ``{entity_type}``
            and ``{entity}`` placeholders.
        labels: Label schema for this mode; a fresh ``SentimentLabels`` by
            default.
    """

    name: str = "qa_m"
    # Adjacent literals are concatenated at compile time into one template.
    question_template: str = (
        "What do you think of the sentiment of the "
        "{entity_type} {entity} ?"
    )
    labels: SentimentLabels = field(default_factory=SentimentLabels)
|
|
|
|
@dataclass(frozen=True)
class QaBMode:
    """Sentence-pair QA-B mode (Sun et al. 2019), binary.

    Example:
        Seg A: "Google had strong earnings but Microsoft missed."
        Seg B: "The polarity of the company Google is positive ."
        Label: binary (no=0, yes=1)

    Inference takes three forward passes per entity; the one with the
    highest P(yes) wins.

    Attributes:
        name: Mode identifier, "qa_b".
        hypothesis_template: Format string for segment B with
            ``{entity_type}``, ``{entity}`` and ``{sentiment}`` placeholders.
        labels: Binary schema for the yes/no label; a fresh ``BinaryLabels``
            by default.
        sentiment_labels: 3-class sentiment schema held alongside the binary
            one; a fresh ``SentimentLabels`` by default.
    """

    name: str = "qa_b"
    # Adjacent literals are concatenated at compile time into one template.
    hypothesis_template: str = (
        "The polarity of the {entity_type} {entity} "
        "is {sentiment} ."
    )
    labels: BinaryLabels = field(default_factory=BinaryLabels)
    sentiment_labels: SentimentLabels = field(default_factory=SentimentLabels)
|
|
|
|
# Shared module-level instances of the label schemas and lookup tables,
# each built with its default values.
SENTIMENT_LABELS = SentimentLabels()
BINARY_LABELS = BinaryLabels()
LABEL_REMAP = LabelRemap()
ENTITY_TYPES = EntityTypes()

# One default-configured instance per input-formatting mode.
MARKER_MODE = MarkerMode()
QA_M_MODE = QaMMode()
QA_B_MODE = QaBMode()

# Registry of the modes keyed by each mode's ``name`` ("marker", "qa_m",
# "qa_b"), in that insertion order.
MODES: dict[str, MarkerMode | QaMMode | QaBMode] = {
    mode.name: mode for mode in (MARKER_MODE, QA_M_MODE, QA_B_MODE)
}
|
|