File size: 18,802 Bytes
2d7e335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
"""
AAM Diffusion LLM — Tokenizer

Sentence-level + subword BPE hybrid tokenizer designed specifically
for AAM's sentence arrangement task.

Unlike standard tokenizers (GPT-2 BPE, SentencePiece) that tokenize
at the subword level, AAM's tokenizer is designed with SENTENCE
ARRANGEMENT in mind:

1. Sentences are the primary unit of generation (not individual tokens)
2. Within sentences, subword BPE handles individual words
3. Special tokens for graph structure (evidence, anomaly, confidence)
4. Sentence boundary markers for the diffusion model

The tokenizer maintains two levels:
- Sentence level: Where sentences begin/end, for the diffusion model
  to arrange and revise non-sequentially
- Token level: Subword units within sentences, for detailed generation

Analogy: Jin Soun does not think word by word; Jin Soun thinks in
SENTENCES. "Thief = Diancang pair. Ju Jangmok = cover."
Each sentence is already complete; what gets arranged is the ORDER of the sentences.
"""

from __future__ import annotations

import json
import re
import unicodedata
from collections import Counter
from pathlib import Path
from typing import Optional

from diffusion_llm.config.model_config import TokenizerConfig


# Reserved tokens. A token's position in this list is the vocabulary ID it
# receives when a tokenizer is initialised, so the order must stay stable.
SPECIAL_TOKENS = [
    # --- sequence control ---
    "<pad>",          # 0: padding
    "<bos>",          # 1: beginning of sequence
    "<eos>",          # 2: end of sequence
    # --- diffusion ---
    "<mask>",         # 3
    "<noise>",        # 4
    # --- sentence structure ---
    "<sent>",         # 5: sentence boundary
    # --- graph-structure section markers ---
    "<evidence>",     # 6
    "<anomaly>",      # 7
    "<confidence>",   # 8
    "<reasoning>",    # 9
    "<composition>",  # 10
    "<temporal>",     # 11
    # --- fallback ---
    "<unk>",          # 12: unknown token
]


class AamTokenizer:
    """AAM Sentence-Level + Subword BPE Hybrid Tokenizer.

    This tokenizer is specifically designed for the AAM Diffusion LLM:
    - It understands sentence boundaries (<sent> tokens)
    - It has special tokens for graph structure
    - It uses BPE for subword tokenization within sentences
    - It can encode/decode both plain text and graph-conditioned text

    Usage:
        tokenizer = AamTokenizer()
        tokenizer.train(texts, vocab_size=28000)

        # Encode text
        ids = tokenizer.encode("Berdasarkan analisis, pencuri adalah Diancang.")

        # Decode back
        text = tokenizer.decode(ids)

        # With graph structure tokens
        ids = tokenizer.encode_with_structure(
            "Pencuri = Diancang pair",
            evidence_nodes=["hefei", "diancang"],
            anomalies=["no external pill consumption"],
        )
    """

    def __init__(self, config: Optional[TokenizerConfig] = None):
        """Create an untrained tokenizer that knows only the special tokens.

        Args:
            config: Tokenizer configuration. A default ``TokenizerConfig``
                is created when this is None (or otherwise falsy).
        """
        self.config = config or TokenizerConfig()

        # Nothing has been learned yet; ``train`` (or ``load``) flips this.
        self._is_trained = False

        # Token <-> ID tables, seeded with the reserved special tokens.
        self.vocab: dict[str, int] = {}
        self.id_to_token: dict[int, str] = {}
        self._init_special_tokens()

        # Learned merge rules mapped to their rank, plus a per-word memo of
        # BPE results (stored as space-joined symbol strings).
        self.merges: dict[tuple[str, str], int] = {}
        self._bpe_cache: dict[str, str] = {}

        # Sentence splitter: whitespace that follows ., ! or ?, or the
        # (possibly empty) whitespace right after a newline.
        self._sentence_pattern = re.compile(r'(?<=[.!?])\s+|(?<=\n)\s*')
        # Word splitter: runs of word characters, or one non-space symbol.
        self._word_pattern = re.compile(r'\w+|[^\w\s]')

    def _init_special_tokens(self) -> None:
        """Register every entry of ``SPECIAL_TOKENS`` under its list index."""
        # Both tables receive the same (index, token) pairing.
        self.vocab.update(
            (token, index) for index, token in enumerate(SPECIAL_TOKENS)
        )
        self.id_to_token.update(enumerate(SPECIAL_TOKENS))

    @property
    def pad_id(self) -> int:
        return self.vocab[self.config.pad_token]

    @property
    def bos_id(self) -> int:
        return self.vocab[self.config.bos_token]

    @property
    def eos_id(self) -> int:
        return self.vocab[self.config.eos_token]

    @property
    def mask_id(self) -> int:
        return self.vocab[self.config.mask_token]

    @property
    def noise_id(self) -> int:
        return self.vocab[self.config.noise_token]

    @property
    def sent_id(self) -> int:
        return self.vocab[self.config.sentence_boundary_token]

    @property
    def unk_id(self) -> int:
        """ID of ``<unk>``, or the last special-token index if it is not in the vocabulary."""
        fallback = len(SPECIAL_TOKENS) - 1
        return self.vocab.get("<unk>", fallback)

    @property
    def vocab_size(self) -> int:
        """Current vocabulary size."""
        return len(self.vocab)

    @property
    def is_trained(self) -> bool:
        """Whether the tokenizer has been trained."""
        return self._is_trained

    def train(
        self,
        texts: list[str],
        vocab_size: Optional[int] = None,
    ) -> None:
        """Train the BPE tokenizer on a corpus.

        Args:
            texts: List of training texts.
            vocab_size: Target vocabulary size. Uses config if None.
        """
        target_vocab = vocab_size or self.config.bpe_vocab_size

        # Step 1: Pre-tokenize into words
        word_freqs: Counter = Counter()
        for text in texts:
            words = self._pre_tokenize(text)
            for word in words:
                word_freqs[word] += 1

        # Step 2: Initialize character-level vocabulary
        char_vocab: set[str] = set()
        for word in word_freqs:
            for char in word:
                char_vocab.add(char)

        # Add character tokens to vocabulary
        for char in sorted(char_vocab):
            if char not in self.vocab:
                idx = len(self.vocab)
                self.vocab[char] = idx
                self.id_to_token[idx] = char

        # Step 3: Split words into character sequences
        word_splits: dict[str, list[str]] = {}
        for word in word_freqs:
            word_splits[word] = list(word)
            # Add end-of-word marker
            if len(word_splits[word]) > 1:
                word_splits[word][-1] = word_splits[word][-1] + "</w>"

        # Step 4: Learn BPE merges
        n_merges = target_vocab - len(self.vocab)
        for i in range(n_merges):
            # Count pairs
            pair_freqs: Counter = Counter()
            for word, freq in word_freqs.items():
                symbols = word_splits.get(word, [])
                for j in range(len(symbols) - 1):
                    pair = (symbols[j], symbols[j + 1])
                    pair_freqs[pair] += freq

            if not pair_freqs:
                break

            # Find most frequent pair
            best_pair = pair_freqs.most_common(1)[0][0]

            # Record merge
            self.merges[best_pair] = i

            # Apply merge
            new_symbol = best_pair[0] + best_pair[1]
            for word in word_splits:
                symbols = word_splits[word]
                new_symbols = []
                j = 0
                while j < len(symbols):
                    if (
                        j < len(symbols) - 1
                        and symbols[j] == best_pair[0]
                        and symbols[j + 1] == best_pair[1]
                    ):
                        new_symbols.append(new_symbol)
                        j += 2
                    else:
                        new_symbols.append(symbols[j])
                        j += 1
                word_splits[word] = new_symbols

            # Add merged token to vocabulary
            if new_symbol not in self.vocab:
                idx = len(self.vocab)
                self.vocab[new_symbol] = idx
                self.id_to_token[idx] = new_symbol

        self._is_trained = True
        self._bpe_cache.clear()

    def _pre_tokenize(self, text: str) -> list[str]:
        """Pre-tokenize text into words.

        Args:
            text: Input text.

        Returns:
            List of words.
        """
        # Normalize unicode
        text = unicodedata.normalize("NFC", text)
        # Split into words and punctuation
        words = self._word_pattern.findall(text.lower())
        return words

    def _bpe_encode(self, word: str) -> list[str]:
        """Apply BPE to a single word.

        Args:
            word: Input word (lowercase).

        Returns:
            List of BPE tokens.
        """
        if word in self._bpe_cache:
            return self._bpe_cache[word].split()

        # Start with character-level split
        symbols = list(word)
        if len(symbols) > 1:
            symbols[-1] = symbols[-1] + "</w>"

        # Apply merges in order
        while len(symbols) > 1:
            # Find the pair with the lowest merge rank
            best_pair = None
            best_rank = float("inf")

            for i in range(len(symbols) - 1):
                pair = (symbols[i], symbols[i + 1])
                rank = self.merges.get(pair, float("inf"))
                if rank < best_rank:
                    best_rank = rank
                    best_pair = pair

            if best_pair is None or best_rank == float("inf"):
                break

            # Apply merge
            new_symbol = best_pair[0] + best_pair[1]
            new_symbols = []
            i = 0
            while i < len(symbols):
                if (
                    i < len(symbols) - 1
                    and symbols[i] == best_pair[0]
                    and symbols[i + 1] == best_pair[1]
                ):
                    new_symbols.append(new_symbol)
                    i += 2
                else:
                    new_symbols.append(symbols[i])
                    i += 1
            symbols = new_symbols

        # Cache result
        self._bpe_cache[word] = " ".join(symbols)
        return symbols

    def encode(self, text: str, add_special: bool = True) -> list[int]:
        """Encode text to token IDs.

        The encoding process:
        1. Split text into sentences
        2. Insert sentence boundary tokens between sentences
        3. BPE-encode each word within sentences
        4. Add BOS/EOS tokens if requested

        Args:
            text: Input text.
            add_special: Whether to add BOS/EOS tokens.

        Returns:
            List of token IDs.
        """
        ids = []

        if add_special:
            ids.append(self.bos_id)

        # Split into sentences
        sentences = self._split_sentences(text)

        for i, sentence in enumerate(sentences):
            if i > 0:
                ids.append(self.sent_id)  # Sentence boundary

            # Tokenize words in the sentence
            words = self._pre_tokenize(sentence)
            for word in words:
                if self._is_trained:
                    bpe_tokens = self._bpe_encode(word)
                    for token in bpe_tokens:
                        if token in self.vocab:
                            ids.append(self.vocab[token])
                        else:
                            ids.append(self.unk_id)
                else:
                    # Fallback: character-level encoding
                    for char in word:
                        if char in self.vocab:
                            ids.append(self.vocab[char])
                        else:
                            ids.append(self.unk_id)

        if add_special:
            ids.append(self.eos_id)

        return ids

    def encode_with_structure(
        self,
        text: str,
        evidence_nodes: Optional[list[str]] = None,
        compositions: Optional[list[str]] = None,
        anomalies: Optional[list[str]] = None,
        reasoning_steps: Optional[list[str]] = None,
        confidence: Optional[float] = None,
    ) -> list[int]:
        """Encode text with graph structure tokens.

        Adds structural tokens that represent the graph conditioning,
        so the model knows what kind of evidence/anomalies it's
        generating from.

        Args:
            text: The narrative text.
            evidence_nodes: List of evidence node labels.
            compositions: List of composition descriptions.
            anomalies: List of anomaly descriptions.
            reasoning_steps: List of reasoning step descriptions.
            confidence: Overall confidence score.

        Returns:
            List of token IDs with structure tokens.
        """
        ids = [self.bos_id]

        # Evidence section
        if evidence_nodes:
            ids.append(self.vocab["<evidence>"])
            for node in evidence_nodes:
                node_ids = self.encode(node, add_special=False)
                ids.extend(node_ids)
            ids.append(self.vocab["<evidence>"])  # Close section

        # Anomaly section
        if anomalies:
            ids.append(self.vocab["<anomaly>"])
            for anomaly in anomalies:
                anom_ids = self.encode(anomaly, add_special=False)
                ids.extend(anom_ids)
            ids.append(self.vocab["<anomaly>"])

        # Reasoning section
        if reasoning_steps:
            ids.append(self.vocab["<reasoning>"])
            for step in reasoning_steps:
                step_ids = self.encode(step, add_special=False)
                ids.extend(step_ids)
                ids.append(self.sent_id)
            ids.append(self.vocab["<reasoning>"])

        # Confidence
        if confidence is not None:
            ids.append(self.vocab["<confidence>"])
            # Encode confidence as a token (discretized)
            conf_bucket = min(int(confidence * 10), 9)
            conf_token = f"<conf_{conf_bucket}>"
            if conf_token in self.vocab:
                ids.append(self.vocab[conf_token])

        # Composition section
        if compositions:
            ids.append(self.vocab["<composition>"])
            for comp in compositions:
                comp_ids = self.encode(comp, add_special=False)
                ids.extend(comp_ids)
                ids.append(self.sent_id)
            ids.append(self.vocab["<composition>"])

        # Main narrative
        narrative_ids = self.encode(text, add_special=False)
        ids.extend(narrative_ids)

        ids.append(self.eos_id)
        return ids

    def decode(self, ids: list[int], skip_special: bool = False) -> str:
        """Decode token IDs back to text.

        Args:
            ids: List of token IDs.
            skip_special: Whether to skip special tokens in output.

        Returns:
            Decoded text string.
        """
        special_ids = set()
        if skip_special:
            for token in SPECIAL_TOKENS:
                if token in self.vocab:
                    special_ids.add(self.vocab[token])

        tokens = []
        for id_ in ids:
            if skip_special and id_ in special_ids:
                continue
            if id_ in self.id_to_token:
                tokens.append(self.id_to_token[id_])
            else:
                tokens.append("<unk>")

        # Join and clean up BPE tokens
        text = "".join(tokens)
        text = text.replace("</w>", " ")
        # Clean up sentence boundaries
        text = text.replace("<sent>", ". ")
        # Clean up multiple spaces
        text = re.sub(r'\s+', ' ', text).strip()

        return text

    def _split_sentences(self, text: str) -> list[str]:
        """Split text into sentences.

        Args:
            text: Input text.

        Returns:
            List of sentence strings.
        """
        sentences = self._sentence_pattern.split(text)
        return [s.strip() for s in sentences if s.strip()]

    def pad_sequence(
        self,
        ids: list[int],
        max_len: int,
        pad_id: Optional[int] = None,
    ) -> list[int]:
        """Pad a sequence to max_len.

        Args:
            ids: Token IDs.
            max_len: Target length.
            pad_id: Padding token ID. Uses config if None.

        Returns:
            Padded sequence.
        """
        padding_id = pad_id if pad_id is not None else self.pad_id
        if len(ids) >= max_len:
            return ids[:max_len]
        return ids + [padding_id] * (max_len - len(ids))

    def get_sentence_boundaries(self, ids: list[int]) -> list[int]:
        """Find sentence boundary positions in a token sequence.

        This is used by the diffusion model to identify which tokens
        belong to which sentence, enabling non-sequential generation
        and revision at the sentence level.

        Args:
            ids: Token IDs.

        Returns:
            List of indices where sentence boundaries occur.
        """
        boundaries = []
        for i, id_ in enumerate(ids):
            if id_ == self.sent_id:
                boundaries.append(i)
        return boundaries

    def save(self, path: str | Path) -> None:
        """Save tokenizer to file.

        Args:
            path: Output file path (JSON).
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        data = {
            "config": {
                "bpe_vocab_size": self.config.bpe_vocab_size,
                "max_sentences": self.config.max_sentences,
                "sentence_boundary_token": self.config.sentence_boundary_token,
                "pad_token": self.config.pad_token,
                "bos_token": self.config.bos_token,
                "eos_token": self.config.eos_token,
                "mask_token": self.config.mask_token,
                "noise_token": self.config.noise_token,
                "min_frequency": self.config.min_frequency,
            },
            "vocab": self.vocab,
            "merges": {f"{k[0]}|||{k[1]}": v for k, v in self.merges.items()},
            "is_trained": self._is_trained,
        }

        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, path: str | Path) -> AamTokenizer:
        """Rebuild a tokenizer from a JSON file.

        Args:
            path: Input file path (JSON).

        Returns:
            Loaded AamTokenizer.
        """
        with open(path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)

        # A missing "config" section means default configuration values.
        instance = cls(config=TokenizerConfig(**payload.get("config", {})))

        # The stored vocabulary replaces the freshly initialised tables.
        stored_vocab = payload["vocab"]
        instance.vocab = stored_vocab
        instance.id_to_token = {
            int(index): token for token, index in stored_vocab.items()
        }

        # Merge keys are stored as "left|||right"; turn them back into pairs.
        restored: dict[tuple[str, str], int] = {}
        for joined, rank in payload.get("merges", {}).items():
            pieces = joined.split("|||")
            restored[(pieces[0], pieces[1])] = rank
        instance.merges = restored

        instance._is_trained = payload.get("is_trained", False)
        return instance

    def __len__(self) -> int:
        return self.vocab_size

    def __repr__(self) -> str:
        status = "trained" if self._is_trained else "untrained"
        return (
            f"AamTokenizer(vocab_size={self.vocab_size}, "
            f"merges={len(self.merges)}, status={status})"
        )