schemas
Browse files- src/schemas/data.py +35 -0
- src/schemas/labels.py +122 -0
- src/schemas/requests.py +22 -0
- src/schemas/responses.py +14 -0
src/schemas/data.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field, fields
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@dataclass(frozen=True)
|
| 5 |
+
class Position:
|
| 6 |
+
"""A single character span within a sample's text."""
|
| 7 |
+
|
| 8 |
+
position_text: str = ""
|
| 9 |
+
length: int = 0
|
| 10 |
+
offset: int = 0
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass(frozen=True)
|
| 14 |
+
class Entity:
|
| 15 |
+
"""A mention group for one entity within a sample."""
|
| 16 |
+
|
| 17 |
+
entity_id: str = ""
|
| 18 |
+
entity_text: str = ""
|
| 19 |
+
entity_type: str = ""
|
| 20 |
+
positions: list = field(default_factory=list)
|
| 21 |
+
label: str = ""
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass(frozen=True)
|
| 25 |
+
class Sample:
|
| 26 |
+
"""A single news article with its annotated entities."""
|
| 27 |
+
|
| 28 |
+
id: str = ""
|
| 29 |
+
text: str = ""
|
| 30 |
+
entities: list = field(default_factory=list)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def required_keys(cls) -> set[str]:
|
| 34 |
+
"""Return the set of required top-level keys for a data dataclass."""
|
| 35 |
+
return {f.name for f in fields(cls)}
|
src/schemas/labels.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@dataclass(frozen=True)
|
| 5 |
+
class LabelRemap:
|
| 6 |
+
"""Maps non-standard label strings to canonical equivalents."""
|
| 7 |
+
|
| 8 |
+
mapping: dict[str, str] = field(
|
| 9 |
+
default_factory=lambda: {"very positive": "positive"}
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass(frozen=True)
|
| 14 |
+
class EntityTypes:
|
| 15 |
+
"""Accepted entity type values."""
|
| 16 |
+
|
| 17 |
+
types: tuple[str, ...] = ("company", "location")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass(frozen=True)
|
| 21 |
+
class SentimentLabels:
|
| 22 |
+
"""3-class sentiment label schema."""
|
| 23 |
+
|
| 24 |
+
classes: tuple[str, ...] = ("negative", "neutral", "positive")
|
| 25 |
+
label2id: dict[str, int] = field(default_factory=dict)
|
| 26 |
+
id2label: dict[int, str] = field(default_factory=dict)
|
| 27 |
+
|
| 28 |
+
def __post_init__(self):
|
| 29 |
+
object.__setattr__(
|
| 30 |
+
self, "label2id", {s: i for i, s in enumerate(self.classes)}
|
| 31 |
+
)
|
| 32 |
+
object.__setattr__(
|
| 33 |
+
self, "id2label", {i: s for i, s in enumerate(self.classes)}
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
@property
|
| 37 |
+
def num_labels(self) -> int:
|
| 38 |
+
return len(self.classes)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass(frozen=True)
|
| 42 |
+
class BinaryLabels:
|
| 43 |
+
"""Binary (yes / no) label schema used by qa_b mode."""
|
| 44 |
+
|
| 45 |
+
classes: tuple[str, ...] = ("no", "yes")
|
| 46 |
+
label2id: dict[str, int] = field(default_factory=dict)
|
| 47 |
+
id2label: dict[int, str] = field(default_factory=dict)
|
| 48 |
+
|
| 49 |
+
def __post_init__(self):
|
| 50 |
+
object.__setattr__(
|
| 51 |
+
self, "label2id", {s: i for i, s in enumerate(self.classes)}
|
| 52 |
+
)
|
| 53 |
+
object.__setattr__(
|
| 54 |
+
self, "id2label", {i: s for i, s in enumerate(self.classes)}
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
@property
|
| 58 |
+
def num_labels(self) -> int:
|
| 59 |
+
return len(self.classes)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ββ Training modes ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 63 |
+
|
| 64 |
+
@dataclass(frozen=True)
|
| 65 |
+
class MarkerMode:
|
| 66 |
+
"""Single-sequence mode: wrap entity with [E]...[/E] special tokens.
|
| 67 |
+
|
| 68 |
+
Seg A: "[E] Google [/E] had strong earnings but Microsoft missed."
|
| 69 |
+
Label: 3-way (negative=0, neutral=1, positive=2)
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
name: str = "marker"
|
| 73 |
+
entity_start: str = "[E]"
|
| 74 |
+
entity_end: str = "[/E]"
|
| 75 |
+
labels: SentimentLabels = field(default_factory=SentimentLabels)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@dataclass(frozen=True)
|
| 79 |
+
class QaMMode:
|
| 80 |
+
"""Sentence-pair QA-M mode (Sun et al. 2019).
|
| 81 |
+
|
| 82 |
+
Seg A: "Google had strong earnings but Microsoft missed."
|
| 83 |
+
Seg B: "What do you think of the sentiment of the company Google ?"
|
| 84 |
+
Label: 3-way (negative=0, neutral=1, positive=2)
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
name: str = "qa_m"
|
| 88 |
+
question_template: str = "What do you think of the sentiment of the {entity_type} {entity} ?"
|
| 89 |
+
labels: SentimentLabels = field(default_factory=SentimentLabels)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@dataclass(frozen=True)
|
| 93 |
+
class QaBMode:
|
| 94 |
+
"""Sentence-pair QA-B mode (Sun et al. 2019), binary.
|
| 95 |
+
|
| 96 |
+
Seg A: "Google had strong earnings but Microsoft missed."
|
| 97 |
+
Seg B: "The polarity of the company Google is positive ."
|
| 98 |
+
Label: binary (no=0, yes=1)
|
| 99 |
+
|
| 100 |
+
Three forward passes per entity at inference; highest P(yes) wins.
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
name: str = "qa_b"
|
| 104 |
+
hypothesis_template: str = "The polarity of the {entity_type} {entity} is {sentiment} ."
|
| 105 |
+
labels: BinaryLabels = field(default_factory=BinaryLabels)
|
| 106 |
+
sentiment_labels: SentimentLabels = field(default_factory=SentimentLabels)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
SENTIMENT_LABELS = SentimentLabels()
|
| 110 |
+
BINARY_LABELS = BinaryLabels()
|
| 111 |
+
LABEL_REMAP = LabelRemap()
|
| 112 |
+
ENTITY_TYPES = EntityTypes()
|
| 113 |
+
|
| 114 |
+
MARKER_MODE = MarkerMode()
|
| 115 |
+
QA_M_MODE = QaMMode()
|
| 116 |
+
QA_B_MODE = QaBMode()
|
| 117 |
+
|
| 118 |
+
MODES: dict[str, MarkerMode | QaMMode | QaBMode] = {
|
| 119 |
+
"marker": MARKER_MODE,
|
| 120 |
+
"qa_m": QA_M_MODE,
|
| 121 |
+
"qa_b": QA_B_MODE,
|
| 122 |
+
}
|
src/schemas/requests.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Literal
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class PositionInput(BaseModel):
|
| 7 |
+
position_text: str
|
| 8 |
+
length: int
|
| 9 |
+
offset: int
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class EntityInput(BaseModel):
|
| 13 |
+
entity_id: int
|
| 14 |
+
entity_text: str
|
| 15 |
+
entity_type: Literal["company", "location"]
|
| 16 |
+
positions: list[PositionInput]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class SampleInput(BaseModel):
|
| 20 |
+
id: int
|
| 21 |
+
text: str
|
| 22 |
+
entities: list[EntityInput]
|
src/schemas/responses.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Literal
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class EntityOutput(BaseModel):
|
| 7 |
+
entity_id: int
|
| 8 |
+
entity_text: str
|
| 9 |
+
classification: Literal["positive", "negative", "neutral"]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SampleOutput(BaseModel):
|
| 13 |
+
id: int
|
| 14 |
+
entities: list[EntityOutput]
|