| | """Dataset preparation code.""" |
| |
|
| | import torch |
| |
|
| |
|
| | from collections import defaultdict |
| | from dataclasses import dataclass |
| |
|
| | from torch.nn.utils.rnn import pad_sequence |
| | from torch.utils.data import Dataset, DataLoader |
| | from datasets import load_dataset |
| |
|
| |
|
class DataRegistry:
    """Name-to-factory registry for dataset builder functions."""

    _registry = {}

    @classmethod
    def register(cls, name):
        """Decorator: store the wrapped factory under *name* and return it."""
        def _wrap(factory):
            cls._registry[name] = factory
            return factory
        return _wrap

    @classmethod
    def get(cls, name):
        """Return the factory registered under *name* (KeyError if absent)."""
        return cls._registry[name]
| |
|
| |
|
| | @dataclass |
| | class DataBlob: |
| | train_loader: DataLoader |
| | val_loader: DataLoader |
| | test_loader: DataLoader |
| | label2id: dict[int, str] | None = None |
| |
|
| |
|
def build_label_mapping(loader: DataLoader) -> dict[str, int]:
    """Assign a unique integer id to every entity label seen in *loader*.

    Id 0 is reserved for the outside label "O". Each batch is expected to
    carry a "gold_labels" entry: a list (one dict per sentence) mapping
    (start, end) span tuples to label strings, as produced by
    entity_collate_fn.

    Fixes two defects in the original:
    - a spurious ``for item in batch`` loop iterated the batch dict's keys
      and reprocessed the same gold labels once per key;
    - only the first label of each sentence was registered
      (``list(annotation.values())[0]``), dropping all other labels.
    """
    label_to_id = {"O": 0}
    for batch in loader:
        for annotation in batch["gold_labels"]:
            # Register every label in the sentence, not just the first one.
            for label in annotation.values():
                if label not in label_to_id:
                    label_to_id[label] = len(label_to_id)
    return label_to_id
| |
|
| |
|
class LitBankEntityDataset(Dataset):
    """Sentence-level entity examples over a flattened HF split.

    Each row must carry "sentence" (token list), "entity_spans"
    (inclusive (start, end) pairs or None) and "entity_labels".
    """

    def __init__(self, hf_dataset):
        self.dataset = hf_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        words = record["sentence"]
        entity_spans = record["entity_spans"] or []

        # Binary indicator per token: 1 where some entity span begins.
        start_flags = torch.zeros(len(words), dtype=torch.long)
        for begin, _ in entity_spans:
            if begin < len(words):  # guard against overrunning spans
                start_flags[begin] = 1

        return {
            "sentence": words,
            "starts": start_flags,
            "entity_spans": entity_spans,
            "entity_labels": record["entity_labels"] or [],
            "task_id": 1,
        }
| |
|
| |
|
class LitBankMentionDataset(Dataset):
    """Sentence-level mention examples over a flattened HF split.

    Each row must carry "sentence" (token list) and "mentions"
    (inclusive (start, end) pairs or None).
    """

    def __init__(self, hf_dataset):
        self.dataset = hf_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        words = record["sentence"]
        mention_spans = [] if record["mentions"] is None else record["mentions"]

        size = len(words)
        start_flags = torch.zeros(size, dtype=torch.long)
        pair_matrix = torch.zeros((size, size), dtype=torch.long)

        for begin, end in mention_spans:
            # Skip spans that overrun the sentence.
            if begin < size and end < size:
                start_flags[begin] = 1
                pair_matrix[begin, end] = 1

        return {
            "tokens": words,
            "starts": start_flags,
            "span_labels": pair_matrix,
            "task_id": 0,
        }
| |
|
| |
|
def mentions_by_sentence(example):
    """Group every coref mention of *example* under its sentence index.

    Keys are stringified sentence indices (flatten_to_sentences looks them
    up with str(i)); values are lists of (start, end) tuples.
    """
    grouped = defaultdict(list)
    for chain in example["coref_chains"]:
        for sent_idx, begin, end in chain:
            grouped[str(sent_idx)].append((begin, end))
    example["mentions"] = grouped
    return example
| |
|
| |
|
def flatten_to_sentences(batch):
    """Explode document-level rows into one row per sentence.

    Each output row pairs a sentence with the mentions recorded for it in
    the document's "mentions" dict (keyed by str(sentence_index)); missing
    entries and None dicts yield an empty mention list.
    """
    out = {"sentence": [], "mentions": []}

    for doc_sents, doc_mentions in zip(batch["sentences"], batch["mentions"]):
        lookup = doc_mentions if doc_mentions is not None else {}
        for idx, sent in enumerate(doc_sents):
            out["sentence"].append(sent)
            out["mentions"].append(lookup.get(str(idx), []))

    return out
| |
|
| |
|
def extract_spans_from_bio(sentence_tokens):
    """Convert per-token BIO tags into inclusive (start, end) spans + labels.

    Each element of *sentence_tokens* is a dict whose "bio_tags" entry is a
    list; only the first tag is used, with "O" as fallback for an empty list.
    Returns (spans, labels) in parallel order.

    Fix: the label was extracted with ``tag.split("-")[1]``, which truncates
    any label that itself contains a hyphen; ``split("-", 1)`` keeps the
    full label.
    """
    spans = []
    labels = []
    current_span = None

    for i, token_data in enumerate(sentence_tokens):
        tag = token_data["bio_tags"][0] if token_data["bio_tags"] else "O"

        if tag.startswith("B-"):
            # A new entity starts; close any span still open.
            if current_span:
                spans.append(tuple(current_span))
            # maxsplit=1 keeps hyphenated labels intact.
            label = tag.split("-", 1)[1]
            current_span = [i, i]
            labels.append(label)

        elif tag.startswith("I-") and current_span:
            # Extend the open span through this token.
            current_span[1] = i

        else:
            # "O" (or a stray "I-") terminates any open span.
            if current_span:
                spans.append(tuple(current_span))
            current_span = None

    # Flush a span that runs to the end of the sentence.
    if current_span:
        spans.append(tuple(current_span))

    return spans, labels
| |
|
def flatten_entities(batch):
    """Explode document-level BIO entity annotations into per-sentence rows.

    Each document in batch["entities"] is a list of sentences, each sentence
    a list of {"token": ..., "bio_tags": [...]} dicts; spans/labels are
    recovered with extract_spans_from_bio.
    """
    out = {"sentence": [], "entity_spans": [], "entity_labels": []}

    for doc_sentences in batch["entities"]:
        for sentence_tokens in doc_sentences:
            spans, labels = extract_spans_from_bio(sentence_tokens)
            out["sentence"].append([t["token"] for t in sentence_tokens])
            out["entity_spans"].append(spans)
            out["entity_labels"].append(labels)

    return out
| |
|
| |
|
def collate_fn(batch):
    """Pad mention-detection examples into one batch.

    Start vectors are padded with -1 (used to derive token_mask, then
    zeroed); span matrices are zero-padded to the batch max length.
    span_loss_mask restricts the span loss to (row, col) cells where both
    tokens are real, col >= row, and the row token is a gold start.
    """
    sentences = [ex["tokens"] for ex in batch]
    max_len = max(map(len, sentences))

    span_mats = []
    for ex in batch:
        n = len(ex["starts"])
        mat = torch.zeros((max_len, max_len), dtype=torch.long)
        mat[:n, :n] = ex["span_labels"]
        span_mats.append(mat)

    starts = pad_sequence(
        [ex["starts"] for ex in batch], batch_first=True, padding_value=-1
    )
    token_mask = starts != -1
    starts[starts == -1] = 0  # padding positions are "not a start"

    spans = torch.stack(span_mats)

    pair_valid = token_mask.unsqueeze(2) & token_mask.unsqueeze(1)
    upper = torch.triu(torch.ones((max_len, max_len), dtype=torch.bool))
    span_loss_mask = pair_valid & upper & starts.unsqueeze(2).bool()

    return {
        "sentences": sentences,
        "starts": starts,
        "spans": spans,
        "token_mask": token_mask,
        "span_loss_mask": span_loss_mask,
        "task_id": torch.tensor([ex["task_id"] for ex in batch]),
    }
| |
|
| |
|
def entity_collate_fn(batch):
    """Pad entity-detection examples into one batch, keeping label strings.

    Produces the same padded tensors and masks as collate_fn, plus
    "gold_labels": one {(start, end): label} dict per sentence. Spans whose
    indices fall outside the batch max length are dropped from both the
    span matrix and the label dict.
    """
    sentences = [ex["sentence"] for ex in batch]
    max_len = max(map(len, sentences))

    span_mats = []
    gold_label_maps = []
    for ex in batch:
        mat = torch.zeros((max_len, max_len), dtype=torch.long)
        label_map = {}
        for (s, e), name in zip(ex["entity_spans"], ex["entity_labels"]):
            if s < max_len and e < max_len:
                mat[s, e] = 1
                label_map[(s, e)] = name
        span_mats.append(mat)
        gold_label_maps.append(label_map)

    padded = pad_sequence(
        [ex["starts"] for ex in batch], batch_first=True, padding_value=-1
    )
    token_mask = padded != -1

    starts = padded.clone()
    starts[starts == -1] = 0  # padding positions are "not a start"

    spans = torch.stack(span_mats)

    pair_valid = token_mask.unsqueeze(2) & token_mask.unsqueeze(1)
    upper = torch.triu(torch.ones((max_len, max_len), dtype=torch.bool))
    span_loss_mask = pair_valid & upper & starts.unsqueeze(2).bool()

    return {
        "sentences": sentences,
        "starts": starts,
        "spans": spans,
        "gold_labels": gold_label_maps,
        "token_mask": token_mask,
        "span_loss_mask": span_loss_mask,
        "task_id": torch.tensor([ex["task_id"] for ex in batch]),
    }
| |
|
| |
|
def debug_print_entity_batch(batch):
    """Pretty-print an entity batch for manual inspection.

    Expects the dict layout produced by entity_collate_fn ("sentences",
    "gold_labels", "task_id"). Returns None; output goes to stdout.

    Fix: spans are inclusive (start, end) pairs, but the entity text was
    sliced as tokens[start:end], dropping the last token of every entity;
    slice through end + 1.
    """
    sentences = batch["sentences"]
    gold_labels_list = batch["gold_labels"]
    task_ids = batch["task_id"]

    print(f"--- Batch Debug (Size: {len(sentences)}) ---")

    for i, (tokens, labels_dict) in enumerate(zip(sentences, gold_labels_list)):
        task_name = "Entity" if task_ids[i] == 1 else "Mention"
        print(f"\n[Sentence {i}] Task: {task_name}")
        print(f"Text: {' '.join(tokens)}")

        if not labels_dict:
            print(" No entities found.")
            continue

        print(" Entities:")
        for (start, end), label in labels_dict.items():
            # Inclusive span, so include the end token in the slice.
            entity_text = " ".join(tokens[start:end + 1])
            print(f" - [{label}] '{entity_text}' (indices: {start}:{end})")
| |
|
| |
|
@DataRegistry.register("litbank_mentions")
def make_litbank(
    repo_id: str = "coref-data/litbank_raw",
    tag: str = "split_0",
    batch_size: int = 4,
) -> DataBlob:
    """Reformat litbank as a sentence-level mention-detection dataset.

    Fixes vs. original: return annotation said tuple[DataLoader, ...] but a
    DataBlob is returned; the ``try/except e: raise e`` around the sanity
    check was a no-op wrapper; the redundant ``bs = batch_size`` alias is
    gone.

    Args:
        repo_id: HF hub dataset repository to load.
        tag: dataset configuration / split scheme name.
        batch_size: batch size for all three loaders.

    Returns:
        DataBlob with train/val/test loaders (label2id left as None).
    """
    litbank = load_dataset(repo_id, tag)
    sentence_data = litbank.map(mentions_by_sentence).map(
        flatten_to_sentences, batched=True, remove_columns=litbank["train"].column_names
    )

    # Diagnostic: report how many training sentences carry no mentions.
    empty = sum(
        1
        for i in range(len(sentence_data["train"]))
        if not sentence_data["train"][i]["mentions"]
    )
    print(f"Training sentences without mentions: {empty}.")

    train_loader = DataLoader(
        LitBankMentionDataset(sentence_data["train"]),
        batch_size=batch_size, shuffle=True, collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        LitBankMentionDataset(sentence_data["validation"]),
        batch_size=batch_size, shuffle=False, collate_fn=collate_fn,
    )
    test_loader = DataLoader(
        LitBankMentionDataset(sentence_data["test"]),
        batch_size=batch_size, shuffle=False, collate_fn=collate_fn,
    )

    # Fail fast if collation is broken, rather than mid-training.
    next(iter(train_loader))

    return DataBlob(train_loader, val_loader, test_loader)
| |
|
| |
|
@DataRegistry.register("litbank_entities")
def make_litbank_entity(
    repo_id: str = "coref-data/litbank_raw",
    tag: str = "split_0",
    batch_size: int = 4,
) -> DataBlob:
    """Build a sentence-level entity-detection dataset from litbank.

    Fixes vs. original: return annotation said tuple[DataLoader, ...] but a
    DataBlob is returned; the ``try/except e: raise e`` around the sanity
    check was a no-op wrapper; the redundant ``bs = batch_size`` alias is
    gone.

    Args:
        repo_id: HF hub dataset repository to load.
        tag: dataset configuration / split scheme name.
        batch_size: batch size for all three loaders.

    Returns:
        DataBlob whose label2id maps entity label strings to integer ids,
        built from the training loader.
    """
    litbank = load_dataset(repo_id, tag)
    entities_data = litbank.map(
        flatten_entities,
        batched=True,
        remove_columns=litbank["train"].column_names,
    )

    train_loader = DataLoader(
        LitBankEntityDataset(entities_data["train"]),
        batch_size=batch_size, shuffle=True, collate_fn=entity_collate_fn,
    )
    val_loader = DataLoader(
        LitBankEntityDataset(entities_data["validation"]),
        batch_size=batch_size, shuffle=False, collate_fn=entity_collate_fn,
    )
    test_loader = DataLoader(
        LitBankEntityDataset(entities_data["test"]),
        batch_size=batch_size, shuffle=False, collate_fn=entity_collate_fn,
    )

    # Fail fast if collation is broken, rather than mid-training.
    next(iter(train_loader))

    label2id = build_label_mapping(train_loader)
    return DataBlob(train_loader, val_loader, test_loader, label2id)
| |
|