lamossta committed on
Commit
50aa44f
·
1 Parent(s): 399f588
src/schemas/data.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field, fields
2
+
3
+
4
@dataclass(frozen=True)
class Position:
    """A single character span within a sample's text.

    Instances are immutable (``frozen=True``). Every field has a default, so
    ``Position()`` is valid and yields an empty span.
    """

    # Text of the span.
    # NOTE(review): presumably the substring of the sample text covered by
    # this span -- confirm against the data loader.
    position_text: str = ""
    # Size of the span.
    # NOTE(review): unit (characters vs. bytes) is not visible here -- confirm.
    length: int = 0
    # Start of the span.
    # NOTE(review): presumably an index into the sample text -- confirm.
    offset: int = 0
11
+
12
+
13
@dataclass(frozen=True)
class Entity:
    """A mention group for one entity within a sample.

    Attribute rebinding is blocked (``frozen=True``), but ``positions`` is a
    plain list and can still be mutated in place. Because that list takes
    part in the generated ``__hash__``, calling ``hash()`` on an instance
    raises ``TypeError``.
    """

    entity_id: str = ""
    entity_text: str = ""
    # NOTE(review): accepted values are not enforced here -- confirm against
    # the entity-type schema used elsewhere in the project.
    entity_type: str = ""
    # Fresh list per instance (default_factory), never shared between objects.
    # NOTE(review): elements are presumably Position instances -- confirm.
    positions: list = field(default_factory=list)
    label: str = ""
22
+
23
+
24
@dataclass(frozen=True)
class Sample:
    """A single news article with its annotated entities.

    Attribute rebinding is blocked (``frozen=True``), but ``entities`` is a
    plain list and can still be mutated in place. Because that list takes
    part in the generated ``__hash__``, calling ``hash()`` on an instance
    raises ``TypeError``.
    """

    id: str = ""
    text: str = ""
    # Fresh list per instance (default_factory), never shared between objects.
    # NOTE(review): elements are presumably Entity instances -- confirm.
    entities: list = field(default_factory=list)
31
+
32
+
33
+ def required_keys(cls) -> set[str]:
34
+ """Return the set of required top-level keys for a data dataclass."""
35
+ return {f.name for f in fields(cls)}
src/schemas/labels.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+
3
+
4
@dataclass(frozen=True)
class LabelRemap:
    """Maps non-standard label strings to canonical equivalents.

    The attribute itself cannot be rebound (``frozen=True``), but the dict it
    holds is an ordinary mutable ``dict``. Each instance gets its own copy of
    the default mapping.
    """

    # Keys are the non-standard labels, values the canonical labels.
    mapping: dict[str, str] = field(
        default_factory=lambda: {"very positive": "positive"}
    )
11
+
12
+
13
@dataclass(frozen=True)
class EntityTypes:
    """Accepted entity type values.

    Stored as a tuple, so the default instance is immutable and hashable.
    """

    types: tuple[str, ...] = ("company", "location")
18
+
19
+
20
@dataclass(frozen=True)
class SentimentLabels:
    """3-class sentiment label schema.

    ``label2id`` and ``id2label`` are always derived from ``classes`` in
    ``__post_init__``; any value passed for them to the constructor is
    overwritten. They remain constructor parameters only so existing call
    signatures keep working.

    Both derived dicts are excluded from comparison and hashing
    (``compare=False``). Equality is unaffected because they are a pure
    function of ``classes``, and instances become hashable; with the dicts
    included, ``hash()`` raised ``TypeError``.
    """

    classes: tuple[str, ...] = ("negative", "neutral", "positive")
    # Derived lookup tables; rebuilt from ``classes`` on every construction.
    label2id: dict[str, int] = field(default_factory=dict, compare=False)
    id2label: dict[int, str] = field(default_factory=dict, compare=False)

    def __post_init__(self):
        """Build both lookup tables from the position of each class name."""
        # The dataclass is frozen, so normal attribute assignment raises;
        # object.__setattr__ is the documented way to initialise fields here.
        object.__setattr__(
            self, "label2id", {s: i for i, s in enumerate(self.classes)}
        )
        object.__setattr__(
            self, "id2label", {i: s for i, s in enumerate(self.classes)}
        )

    @property
    def num_labels(self) -> int:
        """Number of classes in the schema."""
        return len(self.classes)
39
+
40
+
41
@dataclass(frozen=True)
class BinaryLabels:
    """Binary (yes / no) label schema used by qa_b mode.

    ``label2id`` and ``id2label`` are always derived from ``classes`` in
    ``__post_init__``; any value passed for them to the constructor is
    overwritten. They remain constructor parameters only so existing call
    signatures keep working.

    Both derived dicts are excluded from comparison and hashing
    (``compare=False``). Equality is unaffected because they are a pure
    function of ``classes``, and instances become hashable; with the dicts
    included, ``hash()`` raised ``TypeError``.
    """

    classes: tuple[str, ...] = ("no", "yes")
    # Derived lookup tables; rebuilt from ``classes`` on every construction.
    label2id: dict[str, int] = field(default_factory=dict, compare=False)
    id2label: dict[int, str] = field(default_factory=dict, compare=False)

    def __post_init__(self):
        """Build both lookup tables from the position of each class name."""
        # The dataclass is frozen, so normal attribute assignment raises;
        # object.__setattr__ is the documented way to initialise fields here.
        object.__setattr__(
            self, "label2id", {s: i for i, s in enumerate(self.classes)}
        )
        object.__setattr__(
            self, "id2label", {i: s for i, s in enumerate(self.classes)}
        )

    @property
    def num_labels(self) -> int:
        """Number of classes in the schema."""
        return len(self.classes)
60
+
61
+
62
+ # ── Training modes ────────────────────────────────────────────────────────────
63
+
64
@dataclass(frozen=True)
class MarkerMode:
    """Single-sequence mode: wrap entity with [E]...[/E] special tokens.

    Seg A: "[E] Google [/E] had strong earnings but Microsoft missed."
    Label: 3-way (negative=0, neutral=1, positive=2)
    """

    # Identifier of this mode.
    name: str = "marker"
    # Marker strings placed before and after the entity mention.
    entity_start: str = "[E]"
    entity_end: str = "[/E]"
    # Label schema for this mode; a fresh SentimentLabels per instance.
    labels: SentimentLabels = field(default_factory=SentimentLabels)
76
+
77
+
78
@dataclass(frozen=True)
class QaMMode:
    """Sentence-pair QA-M mode (Sun et al. 2019).

    Seg A: "Google had strong earnings but Microsoft missed."
    Seg B: "What do you think of the sentiment of the company Google ?"
    Label: 3-way (negative=0, neutral=1, positive=2)
    """

    # Identifier of this mode.
    name: str = "qa_m"
    # Template for segment B; uses the ``{entity_type}`` and ``{entity}``
    # placeholders.
    question_template: str = (
        "What do you think of the sentiment of the {entity_type} {entity} ?"
    )
    # Label schema for this mode; a fresh SentimentLabels per instance.
    labels: SentimentLabels = field(default_factory=SentimentLabels)
90
+
91
+
92
@dataclass(frozen=True)
class QaBMode:
    """Sentence-pair QA-B mode (Sun et al. 2019), binary.

    Seg A: "Google had strong earnings but Microsoft missed."
    Seg B: "The polarity of the company Google is positive ."
    Label: binary (no=0, yes=1)

    Three forward passes per entity at inference; highest P(yes) wins.
    """

    # Identifier of this mode.
    name: str = "qa_b"
    # Template for segment B; uses the ``{entity_type}``, ``{entity}`` and
    # ``{sentiment}`` placeholders.
    hypothesis_template: str = (
        "The polarity of the {entity_type} {entity} is {sentiment} ."
    )
    # Binary yes/no schema the model is trained on in this mode.
    labels: BinaryLabels = field(default_factory=BinaryLabels)
    # 3-class schema kept alongside the binary one.
    # NOTE(review): presumably supplies the ``{sentiment}`` values, one
    # hypothesis per class -- confirm against the inference code.
    sentiment_labels: SentimentLabels = field(default_factory=SentimentLabels)
107
+
108
+
109
# Shared default instances of each schema.
SENTIMENT_LABELS = SentimentLabels()
BINARY_LABELS = BinaryLabels()
LABEL_REMAP = LabelRemap()
ENTITY_TYPES = EntityTypes()

# Shared default instances of each training mode.
MARKER_MODE = MarkerMode()
QA_M_MODE = QaMMode()
QA_B_MODE = QaBMode()

# Registry of training modes keyed by mode name. Keys are taken from each
# mode's own ``name`` field so the registry key and the mode's name cannot
# drift apart; with the defaults this yields "marker", "qa_m", "qa_b".
MODES: dict[str, MarkerMode | QaMMode | QaBMode] = {
    mode.name: mode for mode in (MARKER_MODE, QA_M_MODE, QA_B_MODE)
}
src/schemas/requests.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
class PositionInput(BaseModel):
    """Request schema for one span of an entity mention.

    All three fields are required (no defaults are declared).
    """

    # NOTE(review): presumably the mention text as it appears in the sample
    # -- confirm with the API contract.
    position_text: str
    # NOTE(review): no lower bound is enforced here; negative values pass
    # validation. Confirm whether that is intended.
    length: int
    offset: int
10
+
11
+
12
class EntityInput(BaseModel):
    """Request schema for one entity and its mention positions.

    All fields are required (no defaults are declared).
    """

    # NOTE(review): typed as int here, while the ``Entity`` dataclass in
    # src/schemas/data.py declares ``entity_id`` as str -- confirm which one
    # is intended.
    entity_id: int
    entity_text: str
    # Only these two literal values are accepted.
    entity_type: Literal["company", "location"]
    positions: list[PositionInput]
17
+
18
+
19
class SampleInput(BaseModel):
    """Request schema for one text sample and the entities to classify.

    All fields are required (no defaults are declared).
    """

    # NOTE(review): typed as int here, while the ``Sample`` dataclass in
    # src/schemas/data.py declares ``id`` as str -- confirm which one is
    # intended.
    id: int
    text: str
    entities: list[EntityInput]
src/schemas/responses.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
class EntityOutput(BaseModel):
    """Response schema for the classification of one entity.

    All fields are required (no defaults are declared).
    """

    entity_id: int
    entity_text: str
    # Only these three literal values are accepted.
    classification: Literal["positive", "negative", "neutral"]
10
+
11
+
12
class SampleOutput(BaseModel):
    """Response schema for one sample: its id and per-entity results.

    Both fields are required (no defaults are declared).
    """

    id: int
    entities: list[EntityOutput]