File size: 12,126 Bytes
ca23a08
 
ba006b9
 
 
 
0ba7b45
ba006b9
 
 
 
 
 
0ba7b45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c67463
0ba7b45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba006b9
 
16cd637
 
 
ba006b9
16cd637
 
ba006b9
16cd637
 
 
 
ba006b9
16cd637
 
 
 
 
ba006b9
16cd637
 
 
 
 
 
 
ba006b9
 
ca23a08
ba006b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ba7b45
ba006b9
 
 
16cd637
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba006b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca23a08
ba006b9
 
 
 
 
 
ca23a08
ba006b9
 
 
 
 
ca23a08
ba006b9
 
 
 
caa85fe
ba006b9
 
 
16cd637
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ba7b45
ca23a08
 
 
0ba7b45
ca23a08
ba006b9
ca23a08
ba006b9
 
 
 
 
 
 
 
 
 
0ba7b45
ca23a08
 
 
0ba7b45
 
 
ba006b9
 
 
 
 
0ba7b45
16cd637
 
0ba7b45
16cd637
 
 
0ba7b45
16cd637
 
 
 
 
 
 
0ba7b45
16cd637
 
 
0ba7b45
 
 
16cd637
 
 
 
0ba7b45
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
"""Dataset preparation code."""

import torch


from collections import defaultdict
from dataclasses import dataclass

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset


class DataRegistry:
    """Module-wide name -> factory lookup for dataset construction functions."""

    _registry = {}

    @classmethod
    def register(cls, name):
        """Return a decorator that files the wrapped factory under *name*."""
        def add_to_registry(func):
            cls._registry[name] = func
            return func
        return add_to_registry

    @classmethod
    def get(cls, name):
        """Fetch a previously registered factory (raises KeyError if unknown)."""
        return cls._registry[name]


@dataclass
class DataBlob:
    train_loader: DataLoader
    val_loader: DataLoader
    test_loader: DataLoader
    label2id: dict[int, str] | None = None


def build_label_mapping(loader: DataLoader) -> dict[str, int]:
    """Scan every batch of *loader* and assign a stable integer id per label.

    Id 0 is reserved for the background label "O".  Each batch is expected to
    carry a ``"gold_labels"`` entry: one ``{(start, end): label}`` dict per
    sentence, as produced by ``entity_collate_fn``.

    Returns:
        Mapping from label string to class id, in first-seen order.
    """
    label_to_id = {"O": 0}
    for batch in loader:
        for annotation in batch["gold_labels"]:
            # Register every span's label, not just the first one per
            # sentence, and read "gold_labels" exactly once per batch.
            for label in annotation.values():
                if label not in label_to_id:
                    label_to_id[label] = len(label_to_id)
    return label_to_id


class LitBankEntityDataset(Dataset):
    """Per-sentence entity-detection view over a flattened LitBank split."""

    def __init__(self, hf_dataset):
        self.dataset = hf_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        tokens = record["sentence"]
        # ArrowDataset may hand back None for empty lists; normalize.
        spans = record["entity_spans"] or []

        # Binary indicator over tokens: 1 wherever an entity span begins.
        start_mask = torch.zeros(len(tokens), dtype=torch.long)
        for span_start, _ in spans:
            if span_start < len(tokens):
                start_mask[span_start] = 1

        return {
            "sentence": tokens,
            "starts": start_mask,
            "entity_spans": spans,
            "entity_labels": record["entity_labels"] or [],
            "task_id": 1
        }


class LitBankMentionDataset(Dataset):
    """Per-sentence mention-detection view over a flattened LitBank split."""

    def __init__(self, hf_dataset):
        self.dataset = hf_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        tokens = record["sentence"]
        # ArrowDataset stores empty lists as None; normalize to [].
        mentions = record["mentions"] or []

        size = len(tokens)
        start_mask = torch.zeros(size, dtype=torch.long)
        pair_matrix = torch.zeros((size, size), dtype=torch.long)

        for begin, end in mentions:
            # Drop out-of-range annotations; LitBank's `end` is often
            # inclusive, so both indices must fall inside the sentence.
            if begin < size and end < size:
                start_mask[begin] = 1
                pair_matrix[begin, end] = 1

        return {
            "tokens": tokens,
            "starts": start_mask,
            "span_labels": pair_matrix,
            "task_id": 0,
        }


def mentions_by_sentence(example):
    """Group a document's coreference mentions by their sentence index.

    Adds a "mentions" field shaped {str(sentence_idx): [(start, end), ...]}.
    Keys are stringified because ArrowDataset only accepts str/bytes keys.
    """
    grouped = defaultdict(list)
    for chain in example["coref_chains"]:
        for sent_idx, begin, end in chain:
            grouped[str(sent_idx)].append((begin, end))
    example["mentions"] = grouped
    return example


def flatten_to_sentences(batch):
    """Explode document-level rows into one row per sentence.

    Pairs every sentence with the mentions stored under its (stringified)
    index; sentences without an entry — or documents whose mentions dict was
    saved as None — get an empty list.
    """
    flattened = {"sentence": [], "mentions": []}

    for doc_sentences, doc_mentions in zip(batch["sentences"], batch["mentions"]):
        # Some versions of `datasets` persist empty dicts as None.
        doc_mentions = doc_mentions or {}
        for idx, sentence in enumerate(doc_sentences):
            flattened["sentence"].append(sentence)
            flattened["mentions"].append(doc_mentions.get(str(idx), []))

    return flattened


def extract_spans_from_bio(sentence_tokens):
    """Convert per-token BIO tags into (start, end) spans and their labels.

    Each element of *sentence_tokens* is a dict whose "bio_tags" list may be
    empty (treated as "O").  Spans use inclusive start/end indices; the two
    returned lists are parallel.  Orphan "I-" tags (no open span) are ignored.

    Returns:
        Tuple of (spans, labels).
    """
    spans = []
    labels = []
    current_span = None

    for i, token_data in enumerate(sentence_tokens):
        tag = token_data["bio_tags"][0] if token_data["bio_tags"] else "O"

        if tag.startswith("B-"):
            if current_span:
                spans.append(tuple(current_span))
            # Split on the first hyphen only, so labels that themselves
            # contain "-" (e.g. "ART-WORK") are kept intact.
            label = tag.split("-", 1)[1]
            current_span = [i, i]  # inclusive start/end
            labels.append(label)

        elif tag.startswith("I-") and current_span:
            current_span[1] = i  # inclusive extension

        else:
            if current_span:
                spans.append(tuple(current_span))
                current_span = None

    # Close a span that runs to the end of the sentence.
    if current_span:
        spans.append(tuple(current_span))

    return spans, labels

def flatten_entities(batch):
    """Explode document-level entity annotations into per-sentence rows.

    Emits three parallel columns: the tokens of each sentence, the
    BIO-derived entity spans, and the matching labels.
    """
    flattened = {
        "sentence": [],
        "entity_spans": [],
        "entity_labels": []
    }
    for document in batch["entities"]:
        for sentence_tokens in document:
            spans, labels = extract_spans_from_bio(sentence_tokens)
            flattened["sentence"].append([t["token"] for t in sentence_tokens])
            flattened["entity_spans"].append(spans)
            flattened["entity_labels"].append(labels)
    return flattened


def collate_fn(batch):
    """Collate mention-detection items into padded batch tensors plus masks.

    Pads the 1D start targets with -1 (then zeroes them for the loss),
    zero-pads each item's (N, N) span matrix up to the batch max length,
    and builds a combined span-loss mask restricting supervision to real
    tokens, upper-triangular (end >= start) positions, and rows whose
    token is a gold start.
    """
    sentences = [example["tokens"] for example in batch]
    longest = max(len(tokens) for tokens in sentences)

    start_tensors = [example["starts"] for example in batch]
    padded_spans = []
    for example in batch:
        size = len(example["starts"])
        canvas = torch.zeros((longest, longest), dtype=torch.long)
        canvas[:size, :size] = example["span_labels"]
        padded_spans.append(canvas)

    # 1D padding for token classification; -1 marks padding positions.
    starts_padded = pad_sequence(start_tensors, batch_first=True, padding_value=-1)
    token_mask = starts_padded != -1
    starts_padded[starts_padded == -1] = 0

    # 2D targets for token-pair classification: (B, N, N).
    spans_padded = torch.stack(padded_spans)
    # Outer product of the 1D mask with itself: (B, N, N) validity grid.
    valid_len_mask = token_mask.unsqueeze(2) & token_mask.unsqueeze(1)
    # Keep only pairs with end index >= start index.
    upper_tri_mask = torch.triu(
        torch.ones((longest, longest), dtype=torch.bool),
        diagonal=0,
    )
    # Rows whose token is a gold start: (B, N, 1), broadcast over columns.
    is_start_mask = starts_padded.unsqueeze(2).bool()
    # Conjunction of all constraints (attention-mask style).
    span_loss_mask = valid_len_mask & upper_tri_mask & is_start_mask

    return {
        "sentences": sentences,  # list[list[str]]
        "starts": starts_padded,  # (B, N) - Targets for start classifier
        "spans": spans_padded,  # (B, N, N) - Targets for span classifier
        "token_mask": token_mask,  # (B, N) - For 1D loss
        "span_loss_mask": span_loss_mask,  # (B, N, N) - For 2D loss
        "task_id": torch.tensor([example["task_id"] for example in batch]),
    }


def entity_collate_fn(batch):
    """Collate entity items: pad 1D starts and 2D span targets, build masks.

    Besides the padded tensors, each sentence keeps a {(start, end): label}
    dict under "gold_labels" so classifier heads can be trained/evaluated
    against the string labels later.  The span-loss mask combines real-token
    validity, end >= start, and gold-start rows.
    """
    # Tokens live under the 'sentence' key for the entity task.
    sentences = [example["sentence"] for example in batch]
    longest = max(len(tokens) for tokens in sentences)

    start_tensors = []
    span_matrices = []
    gold_label_maps = []

    for example in batch:
        start_tensors.append(example["starts"])

        # Binary (N, N) matrix marking gold (start, end) cells.
        matrix = torch.zeros((longest, longest), dtype=torch.long)
        labels_for_item = {}
        # entity_spans and entity_labels are parallel lists.
        for (begin, end), label in zip(example["entity_spans"], example["entity_labels"]):
            if begin < longest and end < longest:
                matrix[begin, end] = 1
                labels_for_item[(begin, end)] = label

        span_matrices.append(matrix)
        gold_label_maps.append(labels_for_item)

    # Pad with -1 so the mask can distinguish padding from label 0.
    starts_padded = pad_sequence(start_tensors, batch_first=True, padding_value=-1)
    token_mask = starts_padded != -1

    # Loss targets must not contain the -1 sentinel.
    starts_targets = starts_padded.clone()
    starts_targets[starts_targets == -1] = 0

    spans_padded = torch.stack(span_matrices)

    valid_len_mask = token_mask.unsqueeze(2) & token_mask.unsqueeze(1)
    upper_tri_mask = torch.triu(torch.ones((longest, longest), dtype=torch.bool), 0)
    is_start_mask = starts_targets.unsqueeze(2).bool()
    span_loss_mask = valid_len_mask & upper_tri_mask & is_start_mask

    return {
        "sentences": sentences,
        "starts": starts_targets,
        "spans": spans_padded,
        "gold_labels": gold_label_maps,
        "token_mask": token_mask,
        "span_loss_mask": span_loss_mask,
        "task_id": torch.tensor([example["task_id"] for example in batch])
    }


def debug_print_entity_batch(batch):
    """Pretty-print one collated batch: each sentence, its task, its entities."""
    sentences = batch["sentences"]
    gold = batch["gold_labels"]
    task_ids = batch["task_id"]

    print(f"--- Batch Debug (Size: {len(sentences)}) ---")

    for idx, (tokens, span_to_label) in enumerate(zip(sentences, gold)):
        task_name = "Entity" if task_ids[idx] == 1 else "Mention"
        print(f"\n[Sentence {idx}] Task: {task_name}")
        print(f"Text: {' '.join(tokens)}")

        if not span_to_label:
            print("  No entities found.")
            continue

        print("  Entities:")
        for (start, end), label in span_to_label.items():
            # The token slice treats 'end' as exclusive.
            entity_text = " ".join(tokens[start:end])
            print(f"    - [{label}] '{entity_text}' (indices: {start}:{end})")


@DataRegistry.register("litbank_mentions")
def make_litbank(
    repo_id: str = "coref-data/litbank_raw",
    tag: str = "split_0",
    batch_size: int = 4,
) -> DataBlob:
    """Build train/val/test loaders for sentence-level mention detection.

    Downloads the LitBank split, regroups coref mentions by sentence,
    flattens documents into sentence-level rows, and wraps each split in
    a DataLoader using ``collate_fn``.

    Args:
        repo_id: HuggingFace dataset repository id.
        tag: Cross-validation split configuration name.
        batch_size: Batch size shared by all three loaders.

    Returns:
        DataBlob with the three loaders (no label mapping for this task).
    """
    litbank = load_dataset(repo_id, tag)
    sentence_data = litbank.map(mentions_by_sentence).map(
        flatten_to_sentences, batched=True, remove_columns=litbank["train"].column_names
    )

    # Report how many training sentences carry no mention at all.
    train_split = sentence_data["train"]
    empty_count = sum(
        1 for i in range(len(train_split)) if not train_split[i]["mentions"]
    )
    print(f"Training sentences without mentions: {empty_count}.")

    loaders = {}
    for split, shuffle in (("train", True), ("validation", False), ("test", False)):
        loaders[split] = DataLoader(
            LitBankMentionDataset(sentence_data[split]),
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
        )

    # Sanity check: materialize one batch so collation errors surface here.
    next(iter(loaders["train"]))
    return DataBlob(loaders["train"], loaders["validation"], loaders["test"])


@DataRegistry.register("litbank_entities")
def make_litbank_entity(
    repo_id: str = "coref-data/litbank_raw",
    tag: str = "split_0",
    batch_size: int = 4,
) -> DataBlob:
    """Build train/val/test loaders plus a label vocabulary for entity tagging.

    Downloads the LitBank split, flattens BIO-tagged documents into
    sentence-level rows with (start, end, label) spans, and wraps each split
    in a DataLoader using ``entity_collate_fn``.

    Args:
        repo_id: HuggingFace dataset repository id.
        tag: Cross-validation split configuration name.
        batch_size: Batch size shared by all three loaders.

    Returns:
        DataBlob with the three loaders and the label -> id mapping built
        from the training loader.
    """
    litbank = load_dataset(repo_id, tag)
    entities_data = litbank.map(
        flatten_entities,
        batched=True,
        remove_columns=litbank["train"].column_names,
    )

    loaders = {}
    for split, shuffle in (("train", True), ("validation", False), ("test", False)):
        loaders[split] = DataLoader(
            LitBankEntityDataset(entities_data[split]),
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=entity_collate_fn,
        )

    # Sanity check: materialize one batch so collation errors surface here.
    next(iter(loaders["train"]))
    label2id = build_label_mapping(loaders["train"])
    return DataBlob(loaders["train"], loaders["validation"], loaders["test"], label2id)