File size: 15,572 Bytes
be5f706
 
 
 
 
 
 
 
e63569d
be6a29a
be5f706
 
be6a29a
be5f706
8c50d16
 
 
be5f706
 
be6a29a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be5f706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be6a29a
be5f706
 
 
 
be6a29a
 
be5f706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e63569d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be6a29a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e63569d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be5f706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c50d16
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
"""
PyTorch Dataset for anime filename token classification.

Loads JSONL data (tokens + BIO labels) and converts to model inputs.
Handles token-ID conversion, label encoding, padding, and truncation.
"""

import json
from collections import Counter
import numpy as np
import torch
from torch.utils.data import Dataset
from typing import Dict, List, Optional, Sequence, Tuple

from .config import Config
from .label_repairs import repair_sequel_season_labels
from .tokenizer import AnimeTokenizer


def encode_token_classification_values(
    item: Dict,
    tokenizer: AnimeTokenizer,
    label2id: Dict[str, int],
    max_length: int,
    apply_label_repairs: bool = True,
    vocab: Optional[Dict[str, int]] = None,
) -> Tuple[List[int], List[bool], List[int]]:
    """
    Encode one JSONL item into fixed-length token-classification values.

    The sequence is framed as ``[CLS] tokens... [SEP]``. If it is longer than
    ``max_length``, tokens are dropped from the tail while [SEP] is kept as the
    last position. It is then padded up to ``max_length``.

    Args:
        item: One JSONL record. Tokens and labels are obtained from it via
            training_labels_for_tokenizer.
        tokenizer: Supplies the special-token ids (cls/sep/pad/unk) and, when
            ``vocab`` is not given, the vocabulary.
        label2id: Label string -> id. Labels missing from the mapping are
            encoded as id 0.
        max_length: Total output length, including [CLS] and [SEP]. Must be >= 2.
        apply_label_repairs: Forwarded to training_labels_for_tokenizer.
        vocab: Optional token -> id mapping. Defaults to
            ``tokenizer.get_vocab()``; passing one lets callers reuse a single
            mapping across many items.

    Returns:
        ``(input_ids, attention_mask, label_ids)``, each exactly ``max_length``
        long. ``label_ids`` is -100 at [CLS], [SEP] and padding positions.
        ``attention_mask`` is True for real positions and False for padding.

    Raises:
        ValueError: If ``max_length < 2``, or if the resolved tokens and labels
            differ in length. Either case would otherwise produce rows that are
            misaligned or longer than ``max_length``.
    """
    if max_length < 2:
        raise ValueError(
            f"max_length must be at least 2 to hold [CLS] and [SEP], got {max_length}"
        )

    tokens, labels = training_labels_for_tokenizer(item, tokenizer, apply_label_repairs)
    if len(tokens) != len(labels):
        # input_ids and label_ids are built independently below, so a length
        # mismatch would silently shift every label relative to its token.
        raise ValueError(
            f"token/label length mismatch ({len(tokens)} tokens vs "
            f"{len(labels)} labels) for item {item.get('filename')!r}"
        )

    token_vocab = vocab if vocab is not None else tokenizer.get_vocab()
    # Fall back to id 1 when the tokenizer defines no [UNK] id.
    unk_id = tokenizer.unk_token_id if tokenizer.unk_token_id is not None else 1

    input_ids = [tokenizer.cls_token_id]
    input_ids.extend(token_vocab.get(token, unk_id) for token in tokens)
    input_ids.append(tokenizer.sep_token_id)

    # -100 marks positions that carry no label ([CLS]/[SEP]/padding).
    label_ids: List[int] = [-100]
    label_ids.extend(label2id.get(label, 0) for label in labels)
    label_ids.append(-100)

    if len(input_ids) > max_length:
        # Drop tokens from the tail but keep [SEP] as the final position.
        input_ids = input_ids[:max_length - 1] + [input_ids[-1]]
        label_ids = label_ids[:max_length - 1] + [label_ids[-1]]

    real_len = len(input_ids)
    pad_len = max_length - real_len
    input_ids += [tokenizer.pad_token_id] * pad_len
    label_ids += [-100] * pad_len
    attention_mask = [True] * real_len + [False] * pad_len

    return input_ids, attention_mask, label_ids


def encode_token_classification_item(
    item: Dict,
    tokenizer: AnimeTokenizer,
    label2id: Dict[str, int],
    max_length: int,
    apply_label_repairs: bool = True,
    vocab: Optional[Dict[str, int]] = None,
) -> Dict[str, torch.Tensor]:
    """
    Encode one JSONL item and wrap the result as PyTorch tensors.

    This is a thin tensor wrapper around encode_token_classification_values.

    Returns:
        A dict with ``input_ids`` (long), ``attention_mask`` (bool) and
        ``labels`` (long) tensors.
    """
    ids, mask, label_ids = encode_token_classification_values(
        item,
        tokenizer,
        label2id,
        max_length,
        apply_label_repairs=apply_label_repairs,
        vocab=vocab,
    )

    return dict(
        input_ids=torch.tensor(ids, dtype=torch.long),
        attention_mask=torch.tensor(mask, dtype=torch.bool),
        labels=torch.tensor(label_ids, dtype=torch.long),
    )


class AnimeItemsDataset(Dataset):
    """
    Map-style dataset over JSONL records that are already loaded in memory.

    Records are encoded lazily: every ``__getitem__`` call runs
    encode_token_classification_item on a single record. The tokenizer
    vocabulary is fetched once at construction and reused for each item.
    """

    def __init__(
        self,
        data: Sequence[Dict],
        tokenizer: AnimeTokenizer,
        label2id: Dict[str, int],
        max_length: int = 64,
        apply_label_repairs: bool = True,
    ):
        # Encoding configuration shared by every item.
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_length = max_length
        self.apply_label_repairs = apply_label_repairs
        # Vocabulary cached once instead of being requested per sample.
        self.vocab = tokenizer.get_vocab()
        # Raw records; indexed directly by __getitem__.
        self.data = data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        record = self.data[idx]
        return encode_token_classification_item(
            record,
            self.tokenizer,
            self.label2id,
            self.max_length,
            apply_label_repairs=self.apply_label_repairs,
            vocab=self.vocab,
        )


class EncodedAnimeDataset(Dataset):
    """
    Dataset that encodes every record once, up front.

    All records are encoded in ``__init__`` into three ``(len(data), max_length)``
    tensors (``input_ids`` int64, ``attention_mask`` bool, ``labels`` int64).
    The tensors are placed on ``device`` (CPU when omitted), so ``__getitem__``
    only indexes rows and training workers do no token work.
    """

    def __init__(
        self,
        data: Sequence[Dict],
        tokenizer: AnimeTokenizer,
        label2id: Dict[str, int],
        max_length: int = 64,
        device: Optional[torch.device] = None,
        apply_label_repairs: bool = True,
    ):
        placement = device or torch.device("cpu")
        shared_vocab = tokenizer.get_vocab()
        shape = (len(data), max_length)

        # Pre-filled with the same pad / no-label values the encoder uses.
        ids_buf = np.full(shape, tokenizer.pad_token_id, dtype=np.int64)
        mask_buf = np.zeros(shape, dtype=np.bool_)
        label_buf = np.full(shape, -100, dtype=np.int64)

        for row, record in enumerate(data):
            encoded_ids, encoded_mask, encoded_labels = encode_token_classification_values(
                record,
                tokenizer,
                label2id,
                max_length,
                apply_label_repairs=apply_label_repairs,
                vocab=shared_vocab,
            )
            ids_buf[row] = encoded_ids
            mask_buf[row] = encoded_mask
            label_buf[row] = encoded_labels

        self.input_ids = torch.from_numpy(ids_buf).to(placement)
        self.attention_mask = torch.from_numpy(mask_buf).to(placement)
        self.labels = torch.from_numpy(label_buf).to(placement)

    def __len__(self) -> int:
        num_rows = self.input_ids.shape[0]
        return num_rows

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        return dict(
            input_ids=self.input_ids[idx],
            attention_mask=self.attention_mask[idx],
            labels=self.labels[idx],
        )


class AnimeDataset(AnimeItemsDataset):
    """
    Dataset for anime filename token classification.

    Loads pre-tokenized data from JSONL files and prepares model inputs.
    Each sample has:
        - input_ids: token IDs with [CLS] prefix and [SEP] suffix
        - attention_mask: True for real tokens, False for padding (bool tensor)
        - labels: integer label IDs, -100 for special/padding tokens
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: AnimeTokenizer,
        label2id: Dict[str, int],
        max_length: int = 64,
        apply_label_repairs: bool = True,
    ):
        """
        Args:
            data_path: Path to JSONL file with tokens and labels. Blank lines
                are skipped. The whole file is read eagerly here.
            tokenizer: AnimeTokenizer instance.
            label2id: Mapping from label string to integer ID.
            max_length: Maximum sequence length (including special tokens).
            apply_label_repairs: Forwarded to AnimeItemsDataset. This matches
                the option already offered by AnimeItemsDataset and
                EncodedAnimeDataset; defaults to True as before.
        """
        data: List[Dict] = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    data.append(json.loads(line))
        super().__init__(data, tokenizer, label2id, max_length, apply_label_repairs)


def align_tokens_for_tokenizer(
    tokens: List[str],
    labels: List[str],
    tokenizer: AnimeTokenizer,
) -> Tuple[List[str], List[str]]:
    """
    Align pre-labeled JSONL samples to the selected tokenizer.

    The existing datasets store regex-tokenized samples. When the tokenizer's
    ``tokenizer_variant`` is ``"char"``, each original token is re-split with
    ``tokenizer.tokenize`` while preserving BIO spans. The first piece keeps the
    token's label, and the remaining pieces of an entity become ``I-X``.
    Non-entity labels are copied to every piece. For any other variant, or when
    the attribute is missing, ``tokens`` and ``labels`` are returned unchanged.

    Tokens that tokenize to no pieces are dropped. An ``I-X`` label that does
    not continue an X entity in the output is emitted as ``B-X``, so the result
    stays valid IOB2. This happens, for example, when the token that carried
    ``B-X`` was dropped.

    Pairs are consumed with ``zip``, so extra items in the longer list are ignored.
    """
    if getattr(tokenizer, "tokenizer_variant", "regex") != "char":
        return tokens, labels

    aligned_tokens: List[str] = []
    aligned_labels: List[str] = []
    # Entity of the most recently emitted label (None after a non-entity label).
    previous_entity: Optional[str] = None

    for token, label in zip(tokens, labels):
        pieces = tokenizer.tokenize(token)
        if not pieces:
            continue

        if label.startswith(("B-", "I-")):
            entity: Optional[str] = label.split("-", 1)[1]
            first_label = label
            if label.startswith("I-") and previous_entity != entity:
                # Orphan I- tag: nothing of this entity precedes it.
                first_label = "B-" + entity
            continuation = "I-" + entity
        else:
            entity = None
            first_label = continuation = label

        aligned_tokens.extend(pieces)
        aligned_labels.append(first_label)
        aligned_labels.extend([continuation] * (len(pieces) - 1))
        previous_entity = entity

    return aligned_tokens, aligned_labels


def labels_for_tokenizer(
    item: Dict,
    tokenizer: AnimeTokenizer,
) -> Tuple[List[str], List[str]]:
    """
    Return tokens and labels in the exact tokenizer space used by the model.

    Older DMHY weak-label files store a post-processed token sequence where
    group/title brackets may be expanded even though AnimeTokenizer keeps the
    same bracketed text as one inference token. If the raw filename is present,
    project those weak labels back to character spans and then onto the current
    tokenizer output. This keeps train/eval/inference preprocessing identical.

    Resolution order:
      1. Source tokens/labels come from ``repair_sequel_season_labels(item)``.
      2. If there is no filename, align them with align_tokens_for_tokenizer.
      3. If the row is tagged with this tokenizer's variant, its tokens already
         equal ``tokenizer.tokenize(filename)``, and it has one label per token,
         return it unchanged.
      4. Otherwise project labels through the filename's characters.
      5. If projection fails, fall back to align_tokens_for_tokenizer.
    """
    filename = item.get("filename")
    source_tokens, source_labels, _repairs = repair_sequel_season_labels(item)
    tokenizer_variant = getattr(tokenizer, "tokenizer_variant", "regex")

    if not filename:
        return align_tokens_for_tokenizer(source_tokens, source_labels, tokenizer)

    # Current char datasets are already in the exact inference token space.
    # Avoid re-scanning every filename during training. The label-count check
    # keeps a malformed row from returning misaligned pairs; input and label ids
    # are built independently during encoding, so a mismatch would shift them.
    if (
        item.get("tokenizer_variant") == tokenizer_variant
        and len(source_labels) == len(source_tokens)
        and source_tokens == tokenizer.tokenize(filename)
    ):
        return source_tokens, source_labels

    projected = project_labels_from_filename(
        filename=filename,
        source_tokens=source_tokens,
        source_labels=source_labels,
        tokenizer=tokenizer,
    )
    if projected is not None:
        return projected

    # Fall back to the legacy behavior for synthetic fixtures or malformed rows.
    return align_tokens_for_tokenizer(source_tokens, source_labels, tokenizer)


def training_labels_for_tokenizer(
    item: Dict,
    tokenizer: AnimeTokenizer,
    apply_label_repairs: bool,
) -> Tuple[List[str], List[str]]:
    """
    Resolve tokens/labels for training, with a fast path for authoritative rows.

    The stored ``tokens``/``labels`` are returned as-is only when all of these hold:
      - label repairs are disabled,
      - the row's ``tokenizer_variant`` matches the tokenizer's,
      - tokens and labels have equal length,
      - for the char variant, the tokens are exactly the filename's characters
        (or the row has no filename).

    Every other row goes through labels_for_tokenizer.
    """
    variant = getattr(tokenizer, "tokenizer_variant", "regex")
    stored_is_authoritative = (
        not apply_label_repairs and item.get("tokenizer_variant") == variant
    )
    if not stored_is_authoritative:
        return labels_for_tokenizer(item, tokenizer)

    stored_tokens = item.get("tokens", [])
    stored_labels = item.get("labels", [])
    name = item.get("filename")
    if len(stored_tokens) == len(stored_labels) and (
        variant != "char" or name is None or stored_tokens == list(str(name))
    ):
        return stored_tokens, stored_labels
    return labels_for_tokenizer(item, tokenizer)


def token_offsets_in_text(text: str, tokens: List[str]) -> Optional[List[Tuple[int, int]]]:
    """
    Locate each token in ``text`` with a greedy left-to-right search.

    Each search starts where the previous match ended, so tokens must appear
    in order and without overlap. An empty token yields a zero-width span at
    the current position.

    Returns:
        One half-open ``(start, end)`` span per token, or None when a token
        cannot be found after the previous match.
    """
    spans: List[Tuple[int, int]] = []
    position = 0
    for piece in tokens:
        if piece == "":
            spans.append((position, position))
            continue
        found = text.find(piece, position)
        if found == -1:
            return None
        position = found + len(piece)
        spans.append((found, position))
    return spans


def project_source_labels_to_chars(
    text: str,
    source_tokens: List[str],
    source_labels: List[str],
) -> Optional[List[str]]:
    """
    Spread source-token BIO labels onto the characters of ``text``.

    Returns:
        One entity name per character of ``text`` ("O" for characters outside
        any entity). Returns None if the token and label counts differ, or if
        the tokens cannot be located in ``text`` left to right.
    """
    if len(source_tokens) != len(source_labels):
        return None
    spans = token_offsets_in_text(text, source_tokens)
    if spans is None:
        return None

    per_char = ["O"] * len(text)
    text_len = len(per_char)
    for token, label, (lo, hi) in zip(source_tokens, source_labels, spans):
        if not label.startswith(("B-", "I-")):
            continue
        _, entity = label.split("-", 1)

        # Bracketed single-token metadata in older data often includes the
        # brackets in the token. Keep container punctuation as O so a tokenizer
        # that splits brackets can learn cleaner boundaries.
        wrapped = len(token) >= 2 and token[0] in "[【(《" and token[-1] in "]】)》"
        if wrapped:
            lo, hi = lo + 1, hi - 1

        for pos in range(max(lo, 0), min(hi, text_len)):
            per_char[pos] = entity
    return per_char


def labels_from_char_projection(
    text: str,
    target_tokens: List[str],
    char_entities: List[str],
) -> Optional[List[str]]:
    """
    Assign legal IOB2 labels to target tokens from per-character entities.

    Each token takes the most frequent non-"O" entity among its characters
    (ties go to the entity seen first). It is tagged I- when that entity
    continues the previous token's entity, and B- otherwise. A token with no
    entity characters is "O".

    Returns:
        One label per target token, or None if the tokens cannot be located
        in ``text`` left to right.
    """
    spans = token_offsets_in_text(text, target_tokens)
    if spans is None:
        return None

    result: List[str] = []
    current: Optional[str] = None
    limit = len(char_entities)
    for lo, hi in spans:
        covered = [
            char_entities[p]
            for p in range(lo, hi)
            if 0 <= p < limit and char_entities[p] != "O"
        ]
        if covered:
            winner = Counter(covered).most_common(1)[0][0]
            tag = "I" if winner == current else "B"
            result.append(f"{tag}-{winner}")
            current = winner
        else:
            result.append("O")
            current = None
    return result


def project_labels_from_filename(
    filename: str,
    source_tokens: List[str],
    source_labels: List[str],
    tokenizer: AnimeTokenizer,
) -> Optional[Tuple[List[str], List[str]]]:
    """
    Re-tokenize ``filename`` and carry the weak BIO labels over to the new tokens.

    The labels go through a per-character projection: source labels are first
    mapped to characters, then characters back to ``tokenizer.tokenize(filename)``.

    Returns:
        ``(target_tokens, target_labels)``, or None when the source tokens
        cannot be aligned to the filename.
    """
    projected_chars = project_source_labels_to_chars(filename, source_tokens, source_labels)
    if projected_chars is None:
        return None

    retokenized = tokenizer.tokenize(filename)
    relabeled = labels_from_char_projection(filename, retokenized, projected_chars)
    if relabeled is not None and len(relabeled) == len(retokenized):
        return retokenized, relabeled
    return None


def _write_jsonl(path: str, items: List[Dict]) -> None:
    """Write ``items`` to ``path`` as UTF-8 JSON Lines, one object per line."""
    with open(path, 'w', encoding='utf-8') as f:
        for item in items:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')


def create_datasets(
    data_path: str,
    tokenizer: AnimeTokenizer,
    config: Config,
) -> tuple:
    """
    Create train and validation datasets from a JSONL file.

    The first ``int(len(data) * config.train_split)`` records become the
    training split and the rest become the validation split. Records are not
    shuffled.

    Each split is written to a JSONL file inside a fresh, private temporary
    directory and loaded through AnimeDataset. AnimeDataset reads its file
    completely in ``__init__``, so the directory is deleted before returning.
    A unique directory per call keeps concurrent runs from overwriting each
    other's splits through shared, predictable paths, and leaves no files behind.

    Returns:
        (train_dataset, eval_dataset)
    """
    import os
    import tempfile

    # Load all data to determine split
    with open(data_path, 'r', encoding='utf-8') as f:
        all_data = [json.loads(line) for line in f if line.strip()]

    split_idx = int(len(all_data) * config.train_split)
    train_data = all_data[:split_idx]
    eval_data = all_data[split_idx:]

    with tempfile.TemporaryDirectory(prefix="anime_split_") as tmp_dir:
        train_file = os.path.join(tmp_dir, "anime_train.jsonl")
        eval_file = os.path.join(tmp_dir, "anime_eval.jsonl")
        _write_jsonl(train_file, train_data)
        _write_jsonl(eval_file, eval_data)

        train_dataset = AnimeDataset(
            data_path=train_file,
            tokenizer=tokenizer,
            label2id=config.label2id,
            max_length=config.max_seq_length,
        )
        eval_dataset = AnimeDataset(
            data_path=eval_file,
            tokenizer=tokenizer,
            label2id=config.label2id,
            max_length=config.max_seq_length,
        )

    return train_dataset, eval_dataset


if __name__ == "__main__":
    # Smoke test: build a tiny vocab, load the synthetic fixture and print the
    # first encoded sample.
    smoke_cfg = Config()

    smoke_tok = AnimeTokenizer()
    # Build a minimal vocab
    smoke_tok.build_vocab([
        ["[ANi]", "test", "S2", "-", "03"],
        ["[Baha]", "anime", "01"],
    ])

    smoke_ds = AnimeDataset(
        data_path="data/synthetic.jsonl",
        tokenizer=smoke_tok,
        label2id=smoke_cfg.label2id,
        max_length=smoke_cfg.max_seq_length,
    )

    ds_size = len(smoke_ds)
    print(f"Dataset size: {ds_size}")
    if ds_size:
        first = smoke_ds[0]
        for key in ("input_ids", "attention_mask", "labels"):
            print(f"{key} shape: {first[key].shape}")
        for key in ("input_ids", "labels", "attention_mask"):
            print(f"{key}: {first[key].tolist()}")