# sv-task: src/schemas/labels.py
# Last commit 50aa44f ("schemas") by lamossta
from dataclasses import dataclass, field
@dataclass(frozen=True)
class LabelRemap:
    """Normalisation table for label strings.

    Each key is a non-standard label string; its value is the canonical
    label it should be rewritten to.
    """

    # default_factory builds a fresh dict per instance, so instances never
    # share (and cannot accidentally mutate) one another's table.
    mapping: dict[str, str] = field(
        default_factory=lambda: dict([("very positive", "positive")])
    )
@dataclass(frozen=True)
class EntityTypes:
    """Closed set of entity-type strings that the schema accepts."""

    types: tuple[str, ...] = (
        "company",
        "location",
    )
@dataclass(frozen=True)
class _LabelSchema:
    """Shared base for label schemas: an ordered set of class names plus the
    derived ``label2id`` / ``id2label`` lookup tables.

    ``label2id`` and ``id2label`` are always rebuilt from ``classes`` in
    ``__post_init__``; any value passed for them is ignored (the parameters
    are kept only so existing keyword callers keep working). They are
    excluded from equality and hashing: they are pure functions of
    ``classes``, and their ``dict`` values would otherwise make these frozen
    instances unhashable.

    Raises:
        ValueError: if ``classes`` contains the same name more than once
            (the two lookup tables would otherwise disagree).
    """

    classes: tuple[str, ...] = ()
    label2id: dict[str, int] = field(default_factory=dict, compare=False)
    id2label: dict[int, str] = field(default_factory=dict, compare=False)

    def __post_init__(self) -> None:
        # Normalise to a tuple so lists/other iterables are accepted and the
        # instance stays hashable.
        classes = tuple(self.classes)
        if len(set(classes)) != len(classes):
            raise ValueError(f"Duplicate label names in classes: {classes!r}")
        # The dataclass is frozen, so bypass its __setattr__ guard.
        object.__setattr__(self, "classes", classes)
        object.__setattr__(
            self, "label2id", {name: idx for idx, name in enumerate(classes)}
        )
        object.__setattr__(self, "id2label", dict(enumerate(classes)))

    @property
    def num_labels(self) -> int:
        """Number of classes in the schema."""
        return len(self.classes)


@dataclass(frozen=True)
class SentimentLabels(_LabelSchema):
    """3-class sentiment label schema."""

    classes: tuple[str, ...] = ("negative", "neutral", "positive")


@dataclass(frozen=True)
class BinaryLabels(_LabelSchema):
    """Binary (yes / no) label schema used by qa_b mode."""

    classes: tuple[str, ...] = ("no", "yes")
# ── Training modes ────────────────────────────────────────────────────────────
@dataclass(frozen=True)
class MarkerMode:
    """Single-sequence training mode that brackets the target entity with
    special marker tokens.

    Example input (segment A only):
        "[E] Google [/E] had strong earnings but Microsoft missed."
    Target: 3-way sentiment (negative=0, neutral=1, positive=2).
    """

    # Mode identifier (also its key in the module-level MODES registry).
    name: str = "marker"

    # Marker tokens placed immediately before / after the entity mention.
    entity_start: str = "[E]"
    entity_end: str = "[/E]"

    # 3-way sentiment label schema for this mode.
    labels: SentimentLabels = field(default_factory=SentimentLabels)
@dataclass(frozen=True)
class QaMMode:
    """Sentence-pair "QA-M" mode (Sun et al., 2019).

    Segment A is the raw sentence; segment B is an auxiliary question built
    from ``question_template``:
        A: "Google had strong earnings but Microsoft missed."
        B: "What do you think of the sentiment of the company Google ?"
    Target: 3-way sentiment (negative=0, neutral=1, positive=2).
    """

    name: str = "qa_m"

    # str.format placeholders: {entity_type} (e.g. "company") and {entity}.
    question_template: str = (
        "What do you think of the sentiment of the "
        "{entity_type} {entity} ?"
    )

    labels: SentimentLabels = field(default_factory=SentimentLabels)
@dataclass(frozen=True)
class QaBMode:
    """Sentence-pair "QA-B" mode (Sun et al., 2019), framed as a binary task.

    Segment B is a hypothesis asserting one candidate polarity:
        A: "Google had strong earnings but Microsoft missed."
        B: "The polarity of the company Google is positive ."
    Target: binary (no=0, yes=1) - does the hypothesis hold?
    At inference each entity takes three forward passes; the candidate with
    the highest P(yes) wins.
    """

    name: str = "qa_b"

    # str.format placeholders: {entity_type}, {entity} and {sentiment}.
    hypothesis_template: str = (
        "The polarity of the {entity_type} {entity} "
        "is {sentiment} ."
    )

    # Binary yes/no label schema for the classification target.
    labels: BinaryLabels = field(default_factory=BinaryLabels)
    # 3-way sentiment schema; presumably the source of the {sentiment}
    # candidates tried at inference - confirm against the caller.
    sentiment_labels: SentimentLabels = field(default_factory=SentimentLabels)
# Module-level default instances of each schema.
SENTIMENT_LABELS = SentimentLabels()
BINARY_LABELS = BinaryLabels()
LABEL_REMAP = LabelRemap()
ENTITY_TYPES = EntityTypes()

# Module-level default instances of each training mode.
MARKER_MODE = MarkerMode()
QA_M_MODE = QaMMode()
QA_B_MODE = QaBMode()

# Registry of training modes keyed by each mode's own ``name`` field
# ("marker", "qa_m", "qa_b"), so a key can never drift from its mode.
MODES: dict[str, MarkerMode | QaMMode | QaBMode] = {
    mode.name: mode for mode in (MARKER_MODE, QA_M_MODE, QA_B_MODE)
}