import json

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

from src.schemas.labels import SENTIMENT_LABELS


def load_data(path: str) -> list[dict]:
    """Read a JSONL file into a list of sample dicts (one JSON object per line)."""
    with open(path) as f:
        return [json.loads(line) for line in f]


def deduplicate_positions(samples: list[dict]) -> list[dict]:
    """Select one position per entity.

    Prefers the position whose position_text matches entity_text exactly
    (case-insensitive). If none matches, selects the longest position.
    """
    out = []
    for s in samples:
        new_entities = []
        for e in s["entities"]:
            positions = e["positions"]
            if not positions:
                # Nothing to deduplicate; keep the entity as-is.
                new_entities.append(e)
                continue
            exact = [
                p
                for p in positions
                if p["position_text"].lower() == e["entity_text"].lower()
            ]
            if exact:
                best = max(exact, key=lambda p: p["length"])
            else:
                best = max(positions, key=lambda p: p["length"])
            new_entities.append({**e, "positions": [best]})
        out.append({**s, "entities": new_entities})
    return out


def flatten_to_examples(
    samples: list[dict],
    mode: str,
) -> list[dict]:
    """Flatten augmented data to one example per (entity, position) pair.

    Reads pre-computed fields from the augmented JSONL:
        marker -> seg_a = marker_text, seg_b = None
        qa_m   -> seg_a = entity_centered_window, seg_b = qa_m_question
        qa_b   -> 3 binary examples per position using qa_b_hypotheses
    """
    sentiments = list(SENTIMENT_LABELS.classes)
    label2id = SENTIMENT_LABELS.label2id
    examples = []
    for s in samples:
        for e in s["entities"]:
            label_str = e.get("label")
            base = {
                "sample_id": s["id"],
                "entity_id": e["entity_id"],
                "entity_text": e["entity_text"],
                "entity_type": e["entity_type"],
            }
            for p in e["positions"]:
                if mode == "marker":
                    ex = {**base, "seg_a": p["marker_text"], "seg_b": None}
                    if label_str in label2id:
                        ex["label"] = label2id[label_str]
                    examples.append(ex)
                elif mode == "qa_m":
                    ex = {
                        **base,
                        "seg_a": p["entity_centered_window"],
                        "seg_b": p["qa_m_question"],
                    }
                    if label_str in label2id:
                        ex["label"] = label2id[label_str]
                    examples.append(ex)
                elif mode == "qa_b":
                    # One binary (yes/no) example per candidate sentiment.
                    for sentiment in sentiments:
                        ex = {
                            **base,
                            "seg_a": p["entity_centered_window"],
                            "seg_b": p["qa_b_hypotheses"][sentiment],
                            "sentiment": sentiment,
                        }
                        if label_str in label2id:
                            ex["label"] = 1 if sentiment == label_str else 0
                        examples.append(ex)
                else:
                    raise ValueError(f"Unknown mode: {mode!r}")
    return examples


def split_data(
    examples: list[dict], val_frac: float, test_frac: float, seed: int = 42
) -> tuple[list[dict], list[dict], list[dict]]:
    """Split at the *sample* level, so all examples derived from the same
    sample land in the same split (avoids leakage between splits).
    """
    # Sort the ids so the split is deterministic across runs: plain set
    # iteration order depends on string hash randomization.
    sample_ids = np.array(sorted({e["sample_id"] for e in examples}))
    remaining_ids, test_ids = train_test_split(
        sample_ids, test_size=test_frac, random_state=seed
    )
    # After carving off test_frac, only (1 - test_frac) of the samples
    # remain, so rescale val_frac to keep the intended overall fraction.
    val_frac_adj = val_frac / (1.0 - test_frac)
    train_ids, val_ids = train_test_split(
        remaining_ids, test_size=val_frac_adj, random_state=seed
    )
    train_set = set(train_ids)
    val_set = set(val_ids)
    test_set = set(test_ids)
    return (
        [e for e in examples if e["sample_id"] in train_set],
        [e for e in examples if e["sample_id"] in val_set],
        [e for e in examples if e["sample_id"] in test_set],
    )


class EntitySentimentDataset(Dataset):
    def __init__(self, examples: list[dict], tokenizer, max_len: int):
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict:
        ex = self.examples[idx]
        seg_a = ex["seg_a"]
        seg_b = ex["seg_b"]
        if seg_b is None:
            enc = self.tokenizer(
                seg_a,
                max_length=self.max_len,
                truncation=True,
                padding="max_length",
                return_tensors="pt",
            )
        else:
            # Truncate only the context window (seg_a) so the question /
            # hypothesis (seg_b) is never cut off.
            enc = self.tokenizer(
                seg_a,
                seg_b,
                max_length=self.max_len,
                truncation="only_first",
                padding="max_length",
                return_tensors="pt",
            )
        item = {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
        }
        if "label" in ex:
            item["labels"] = torch.tensor(ex["label"], dtype=torch.long)
        return item


class DeduplicatedEntitySentimentDataset(EntitySentimentDataset):
    """Like EntitySentimentDataset but with one position per entity.

    Applies deduplicate_positions before flattening, so each entity
    contributes exactly one training example.
    """

    def __init__(self, samples: list[dict], mode: str, tokenizer, max_len: int):
        deduped = deduplicate_positions(samples)
        examples = flatten_to_examples(deduped, mode=mode)
        super().__init__(examples, tokenizer, max_len)
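

# Usage sketch: a minimal end-to-end wiring of the pieces above, assuming a
# Hugging Face tokenizer. The JSONL path and the "bert-base-uncased"
# checkpoint are illustrative assumptions, not part of this module; swap in
# the real augmented data file and model name.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    samples = load_data("data/augmented.jsonl")  # hypothetical path
    examples = flatten_to_examples(samples, mode="qa_m")
    train, val, test = split_data(examples, val_frac=0.1, test_frac=0.1)

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    train_ds = EntitySentimentDataset(train, tokenizer, max_len=128)
    loader = DataLoader(train_ds, batch_size=32, shuffle=True)

    # Each batch carries input_ids, attention_mask, and (if labeled) labels.
    batch = next(iter(loader))
    print({k: tuple(v.shape) for k, v in batch.items()})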