import json

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

from src.schemas.labels import SENTIMENT_LABELS


def load_data(path: str) -> list[dict]:
    """Load one JSON object per line from a JSONL file."""
    with open(path) as f:
        return [json.loads(line) for line in f]
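
# Each JSONL record is expected to look like this (shape inferred from the
# fields this module reads; the authoritative schema lives upstream in the
# augmentation step):
#
#   {"id": ..., "entities": [
#       {"entity_id": ..., "entity_text": ..., "entity_type": ..., "label": ...,
#        "positions": [
#            {"position_text": ..., "length": ..., "marker_text": ...,
#             "entity_centered_window": ..., "qa_m_question": ...,
#             "qa_b_hypotheses": {<sentiment>: <hypothesis>, ...}}]}]}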


def deduplicate_positions(samples: list[dict]) -> list[dict]:
    """Select one position per entity.

    Prefers the position whose position_text matches entity_text exactly
    (case-insensitive). If none matches, selects the longest position.
    """
    out = []
    for s in samples:
        new_entities = []
        for e in s["entities"]:
            positions = e["positions"]
            if not positions:
                new_entities.append(e)
                continue
            exact = [
                p for p in positions
                if p["position_text"].lower() == e["entity_text"].lower()
            ]
            if exact:
                best = max(exact, key=lambda p: p["length"])
            else:
                best = max(positions, key=lambda p: p["length"])
            new_entities.append({**e, "positions": [best]})
        out.append({**s, "entities": new_entities})
    return out
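
# Sketch of the selection rule (toy values):
#
#   entity_text = "Acme Corp"
#   positions   = [{"position_text": "Acme", "length": 4, ...},
#                  {"position_text": "acme corp", "length": 9, ...}]
#
# The second position is kept: it equals entity_text case-insensitively, so
# the exact-match branch wins before the longest-position fallback applies.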


def flatten_to_examples(
    samples: list[dict],
    mode: str,
) -> list[dict]:
    """Flatten augmented data to one example per (entity, position) pair.

    Reads pre-computed fields from the augmented JSONL:
        marker -> seg_a = marker_text,            seg_b = None
        qa_m   -> seg_a = entity_centered_window, seg_b = qa_m_question
        qa_b   -> 3 binary examples per position using qa_b_hypotheses
    """
    sentiments = list(SENTIMENT_LABELS.classes)
    label2id = SENTIMENT_LABELS.label2id
    examples = []
    for s in samples:
        for e in s["entities"]:
            # Entities with an unknown or missing label yield examples
            # without a "label" key (e.g., for inference).
            label_str = e.get("label")
            base = {
                "sample_id": s["id"],
                "entity_id": e["entity_id"],
                "entity_text": e["entity_text"],
                "entity_type": e["entity_type"],
            }
            for p in e["positions"]:
                if mode == "marker":
                    ex = {**base, "seg_a": p["marker_text"], "seg_b": None}
                    if label_str in label2id:
                        ex["label"] = label2id[label_str]
                    examples.append(ex)
                elif mode == "qa_m":
                    ex = {
                        **base,
                        "seg_a": p["entity_centered_window"],
                        "seg_b": p["qa_m_question"],
                    }
                    if label_str in label2id:
                        ex["label"] = label2id[label_str]
                    examples.append(ex)
                elif mode == "qa_b":
                    for sentiment in sentiments:
                        ex = {
                            **base,
                            "seg_a": p["entity_centered_window"],
                            "seg_b": p["qa_b_hypotheses"][sentiment],
                            "sentiment": sentiment,
                        }
                        if label_str in label2id:
                            ex["label"] = 1 if sentiment == label_str else 0
                        examples.append(ex)
                else:
                    raise ValueError(f"Unknown mode: {mode!r}")
    return examples
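
# Shape of a flattened example in "marker" mode (field values illustrative):
#
#   {"sample_id": ..., "entity_id": ..., "entity_text": ..., "entity_type": ...,
#    "seg_a": <marker_text>, "seg_b": None,
#    "label": <id from SENTIMENT_LABELS.label2id>}
#
# In "qa_b" mode each position instead expands into one example per sentiment
# class, labeled 1 for the true class and 0 otherwise.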


def split_data(
    examples: list[dict], val_frac: float, test_frac: float, seed: int = 42
) -> tuple[list[dict], list[dict], list[dict]]:
    """Split at the *sample* level.

    All examples derived from the same sample (every entity and position)
    land in the same split, so no sample leaks across train/val/test.
    """
    sample_ids = np.array(list({e["sample_id"] for e in examples}))
    remaining_ids, test_ids = train_test_split(
        sample_ids, test_size=test_frac, random_state=seed
    )
    # Rescale val_frac so it stays a fraction of the *original* data rather
    # than of the post-test remainder.
    val_frac_adj = val_frac / (1.0 - test_frac)
    train_ids, val_ids = train_test_split(
        remaining_ids, test_size=val_frac_adj, random_state=seed
    )
    train_set = set(train_ids)
    val_set = set(val_ids)
    test_set = set(test_ids)
    return (
        [e for e in examples if e["sample_id"] in train_set],
        [e for e in examples if e["sample_id"] in val_set],
        [e for e in examples if e["sample_id"] in test_set],
    )
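
# Worked example of the rescaling above: with val_frac=0.1 and test_frac=0.1,
# val_frac_adj = 0.1 / 0.9 ≈ 0.111, i.e. 11.1% of the remaining 90% of sample
# ids, which is the requested 10% of the full set.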


class EntitySentimentDataset(Dataset):
    """Tokenizes flattened examples as single segments or segment pairs."""

    def __init__(self, examples: list[dict], tokenizer, max_len: int):
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict:
        ex = self.examples[idx]
        seg_a = ex["seg_a"]
        seg_b = ex["seg_b"]
        if seg_b is None:
            enc = self.tokenizer(
                seg_a,
                max_length=self.max_len,
                truncation=True,
                padding="max_length",
                return_tensors="pt",
            )
        else:
            # Truncate only the context window (seg_a); the question or
            # hypothesis in seg_b is short and should survive intact.
            enc = self.tokenizer(
                seg_a, seg_b,
                max_length=self.max_len,
                truncation="only_first",
                padding="max_length",
                return_tensors="pt",
            )
        item = {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
        }
        # Unlabeled examples (inference) produce no "labels" key.
        if "label" in ex:
            item["labels"] = torch.tensor(ex["label"], dtype=torch.long)
        return item
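
# A minimal usage sketch, assuming a HuggingFace tokenizer (the checkpoint
# name is only an example):
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("bert-base-uncased")
#   ds = EntitySentimentDataset(train_examples, tok, max_len=128)
#   loader = torch.utils.data.DataLoader(ds, batch_size=32, shuffle=True)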


class DeduplicatedEntitySentimentDataset(EntitySentimentDataset):
    """Like EntitySentimentDataset but with one position per entity.

    Applies deduplicate_positions before flattening, so each entity
    contributes examples from exactly one position (a single example in
    marker/qa_m mode, one per sentiment class in qa_b mode).
    """

    def __init__(self, samples: list[dict], mode: str, tokenizer, max_len: int):
        deduped = deduplicate_positions(samples)
        examples = flatten_to_examples(deduped, mode=mode)
        super().__init__(examples, tokenizer, max_len)
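

if __name__ == "__main__":
    # Smoke test wiring the full pipeline together. The data path, checkpoint,
    # and hyperparameters below are placeholders, not values this module ships.
    from transformers import AutoTokenizer

    samples = load_data("data/augmented.jsonl")
    examples = flatten_to_examples(deduplicate_positions(samples), mode="marker")
    train, val, test = split_data(examples, val_frac=0.1, test_frac=0.1)
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    dataset = EntitySentimentDataset(train, tokenizer, max_len=128)
    print(len(train), len(val), len(test), dataset[0]["input_ids"].shape)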