import json
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from src.schemas.labels import SENTIMENT_LABELS


def load_data(path: str) -> list[dict]:
    """Read a JSONL file into a list of records, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
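

# Illustrative shape of one augmented JSONL record, inferred from the field
# accesses below (values and any fields beyond those read here are
# assumptions, not a schema definition):
# {
#   "id": "...",
#   "entities": [
#     {
#       "entity_id": "...", "entity_text": "...", "entity_type": "...",
#       "label": "positive",           # optional; absent for unlabeled data
#       "positions": [
#         {
#           "position_text": "...", "length": 7,
#           "marker_text": "...",               # used by mode="marker"
#           "entity_centered_window": "...",    # used by qa_m and qa_b
#           "qa_m_question": "...",             # used by mode="qa_m"
#           "qa_b_hypotheses": {"positive": "...", ...}  # used by mode="qa_b"
#         }
#       ]
#     }
#   ]
# }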


def deduplicate_positions(samples: list[dict]) -> list[dict]:
    """Select one position per entity.

    Prefers a position whose position_text equals entity_text
    case-insensitively; if none matches, falls back to the longest
    position. Entities with no positions pass through unchanged.
    """
    out = []
    for s in samples:
        new_entities = []
        for e in s["entities"]:
            positions = e["positions"]
            if not positions:
                new_entities.append(e)
                continue

            exact = [
                p for p in positions
                if p["position_text"].lower() == e["entity_text"].lower()
            ]

            # Exact matches all share the same text (hence the same length),
            # so the first one works; otherwise take the longest position.
            if exact:
                best = exact[0]
            else:
                best = max(positions, key=lambda p: p["length"])

            new_entities.append({**e, "positions": [best]})
        out.append({**s, "entities": new_entities})
    return out
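

# Example (hypothetical data): if an entity "Acme Corp" has positions with
# position_text values ["Acme", "acme corp"], only the case-insensitive
# exact match "acme corp" is kept; with no exact match, the longest
# position would win.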


def flatten_to_examples(
    samples: list[dict],
    mode: str,
) -> list[dict]:
    """Flatten augmented data to one example per (entity, position) pair.

    Reads pre-computed fields from the augmented JSONL:
        marker -> seg_a = marker_text, seg_b = None
        qa_m   -> seg_a = entity_centered_window, seg_b = qa_m_question
        qa_b   -> one binary example per sentiment class, built from
                  qa_b_hypotheses (so each position yields several examples)

    A "label" key is attached only when the entity's label string is one
    of the known sentiment classes.
    """
    if mode not in {"marker", "qa_m", "qa_b"}:
        raise ValueError(f"Unknown mode: {mode!r}")

    sentiments = list(SENTIMENT_LABELS.classes)
    label2id = SENTIMENT_LABELS.label2id
    examples = []

    for s in samples:
        for e in s["entities"]:
            label_str = e.get("label")

            base = {
                "sample_id": s["id"],
                "entity_id": e["entity_id"],
                "entity_text": e["entity_text"],
                "entity_type": e["entity_type"],
            }

            for p in e["positions"]:
                if mode == "marker":
                    ex = {**base, "seg_a": p["marker_text"], "seg_b": None}
                    if label_str in label2id:
                        ex["label"] = label2id[label_str]
                    examples.append(ex)

                elif mode == "qa_m":
                    ex = {
                        **base,
                        "seg_a": p["entity_centered_window"],
                        "seg_b": p["qa_m_question"],
                    }
                    if label_str in label2id:
                        ex["label"] = label2id[label_str]
                    examples.append(ex)

                else:  # mode == "qa_b" (validated at the top)
                    for sentiment in sentiments:
                        ex = {
                            **base,
                            "seg_a": p["entity_centered_window"],
                            "seg_b": p["qa_b_hypotheses"][sentiment],
                            "sentiment": sentiment,
                        }
                        if label_str in label2id:
                            ex["label"] = 1 if sentiment == label_str else 0
                        examples.append(ex)

    return examples
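

# For one labeled position, mode="marker" or mode="qa_m" yields a single
# example such as {"sample_id": ..., "seg_a": ..., "seg_b": ..., "label": 2}
# (values illustrative only), while mode="qa_b" yields one example per
# sentiment class with a binary label marking the gold sentiment.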


def split_data(
    examples: list[dict], val_frac: float, test_frac: float, seed: int = 42
) -> tuple[list[dict], list[dict], list[dict]]:
    """Split at the *sample* level, so every example derived from a given
    sample lands in the same split and no text leaks across splits.
    """
    # Sort for determinism: set iteration order depends on hash seeds.
    sample_ids = np.array(sorted({e["sample_id"] for e in examples}))

    remaining_ids, test_ids = train_test_split(
        sample_ids, test_size=test_frac, random_state=seed
    )
    # Rescale val_frac to a fraction of the *remaining* ids so the overall
    # train/val/test proportions come out as requested.
    val_frac_adj = val_frac / (1.0 - test_frac)
    train_ids, val_ids = train_test_split(
        remaining_ids, test_size=val_frac_adj, random_state=seed
    )

    train_set = set(train_ids)
    val_set = set(val_ids)
    test_set = set(test_ids)

    return (
        [e for e in examples if e["sample_id"] in train_set],
        [e for e in examples if e["sample_id"] in val_set],
        [e for e in examples if e["sample_id"] in test_set],
    )
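

# Worked through: with val_frac=0.1 and test_frac=0.2, the second split
# takes 0.1 / (1 - 0.2) = 0.125 of the remaining 80% of sample ids, i.e.
# 10% of the total, leaving 70% for training.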


class EntitySentimentDataset(Dataset):
    """Tokenizes flattened examples as single sequences or sequence pairs."""

    def __init__(self, examples: list[dict], tokenizer, max_len: int):
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict:
        ex = self.examples[idx]
        seg_a = ex["seg_a"]
        seg_b = ex["seg_b"]

        if seg_b is None:
            enc = self.tokenizer(
                seg_a,
                max_length=self.max_len,
                truncation=True,
                padding="max_length",
                return_tensors="pt",
            )
        else:
            # Truncate only the context (seg_a) so the question/hypothesis
            # in seg_b is never cut off.
            enc = self.tokenizer(
                seg_a, seg_b,
                max_length=self.max_len,
                truncation="only_first",
                padding="max_length",
                return_tensors="pt",
            )

        item = {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
        }
        if "label" in ex:
            item["labels"] = torch.tensor(ex["label"], dtype=torch.long)
        return item
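

# Typical wiring (a sketch; assumes a Hugging Face tokenizer, and the
# checkpoint name is an assumption, not something this module prescribes):
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("bert-base-uncased")
#   ds = EntitySentimentDataset(train_examples, tok, max_len=128)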


class DeduplicatedEntitySentimentDataset(EntitySentimentDataset):
    """Like EntitySentimentDataset, but with one position per entity.

    Applies deduplicate_positions before flattening, so each entity
    contributes examples from exactly one position (a single example in
    the marker and qa_m modes, one per sentiment class in qa_b).
    """

    def __init__(self, samples: list[dict], mode: str, tokenizer, max_len: int):
        deduped = deduplicate_positions(samples)
        examples = flatten_to_examples(deduped, mode=mode)
        super().__init__(examples, tokenizer, max_len)
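

if __name__ == "__main__":
    # Minimal smoke-test wiring (a sketch): the path, checkpoint, fractions,
    # and batch size below are placeholder assumptions, not project settings.
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    samples = load_data("data/augmented.jsonl")  # hypothetical path
    examples = flatten_to_examples(deduplicate_positions(samples), mode="qa_m")
    train, val, test = split_data(examples, val_frac=0.1, test_frac=0.2)

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    loader = DataLoader(
        EntitySentimentDataset(train, tok, max_len=128),
        batch_size=16,
        shuffle=True,
    )
    print(f"train={len(train)} val={len(val)} test={len(test)} "
          f"batches={len(loader)}")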