Wolfvin commited on
Commit
2d7e335
·
verified ·
1 Parent(s): cc1beb2

AAM Diffusion LLM v1.0 — The Body of Aphantasic Abstraction Model

Browse files
.gitignore ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AAM Diffusion LLM v1.0 — HuggingFace Repository Files
2
+
3
+ __pycache__/
4
+ *.pyc
5
+ *.pyo
6
+ *.egg-info/
7
+ dist/
8
+ build/
9
+ *.so
10
+ .env
11
+ output/
12
+ aam-diffusion-v1/
13
+ *.log
14
+ .DS_Store
README.md ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - id
4
+ - en
5
+ license: mit
6
+ library_name: pytorch
7
+ tags:
8
+ - diffusion
9
+ - text-generation
10
+ - aam
11
+ - aphantasic-abstraction-model
12
+ - sentence-arrangement
13
+ - graph-conditioned
14
+ - indonesian
15
+ ---
16
+
17
+ # AAM Diffusion LLM v1.0
18
+
19
+ > **"AAM = 1 Pikiran + 1 Tubuh" (1 Mind + 1 Body)**
20
+
21
+ The dedicated "body" of the **Aphantasic Abstraction Model (AAM)** — a small diffusion LLM specifically trained to arrange sentences from structured graph data.
22
+
23
+ ## What is this?
24
+
25
+ This is **NOT** a general-purpose LLM. This is a **SPECIALIZED sentence composer** that:
26
+ - Takes **graph-structured conditioning** as input (evidence nodes, anomalies, reasoning chains, confidence scores)
27
+ - Produces **coherent natural language narratives** through iterative denoising (diffusion process)
28
+ - **Cannot hallucinate** — it can only narrate what the graph knows
29
+
30
+ ### Why Diffusion (Not Autoregressive)?
31
+
32
+ 1. **Non-sequential generation** — Can revise earlier parts while generating later parts, mirroring how thoughts form: vague intuition → clearer pattern → explicit narrative
33
+ 2. **Graph conditioning** — The entire graph structure is encoded as conditioning, not just a text prefix
34
+ 3. **Anti-hallucination by design** — Trained exclusively on Graph→Narrative pairs, the model has no capability to generate information outside the graph conditioning
35
+
36
+ ## Architecture
37
+
38
+ ```
39
+ ┌──────────────────────────────────────────────────────────┐
40
+ │ AAM = 1 Pikiran + 1 Tubuh │
41
+ │ │
42
+ │ Pikiran (Mind) = RSVS Knowledge Graph │
43
+ │ - Structural memory — perfect recall │
44
+ │ - Relational — understands concept connections │
45
+ │ - Confidence scores — knows certainty levels │
46
+ │ │
47
+ │ Tubuh (Body) = AAM Diffusion LLM (This Model) │
48
+ │ ┌─────────────────────────────────────────────┐ │
49
+ │ │ Graph Conditioning Encoder │ │
50
+ │ │ ├─ Evidence Node Encoder │ │
51
+ │ │ ├─ Composition Encoder │ │
52
+ │ │ ├─ Anomaly Encoder │ │
53
+ │ │ ├─ Reasoning Chain Encoder │ │
54
+ │ │ ├─ Confidence Embedding │ │
55
+ │ │ ├─ Temporal Embedding │ │
56
+ │ │ └─ Graph Attention Layers │ │
57
+ │ │ ↓ (cross-attention keys/values) │ │
58
+ │ ├─────────────────────────────────────────────┤ │
59
+ │ │ Diffusion Transformer (Denoiser) │ │
60
+ │ │ ├─ Token Embedding │ │
61
+ │ │ ├─ Timestep Embedding (sinusoidal) │ │
62
+ │ │ ├─ N × TransformerBlock: │ │
63
+ │ │ │ ├─ AdaptiveLayerNorm + Self-Attention │ │
64
+ │ │ │ ├─ AdaptiveLayerNorm + Cross-Attention │ │
65
+ │ │ │ └─ AdaptiveLayerNorm + Feed-Forward │ │
66
+ │ │ └─ Output Projection │ │
67
+ │ │ ↓ (predicted noise) │ │
68
+ │ ├─────────────────────────────────────────────┤ │
69
+ │ │ Noise Scheduler │ │
70
+ │ │ ├─ Forward: x_0 + noise → x_t │ │
71
+ │ │ └─ Reverse: x_t → denoise → x_{t-1} │ │
72
+ │ └─────────────────────────────────────────────┘ │
73
+ │ │
74
+ │ Training: Graph→Narrative pairs │
75
+ │ Inference: Noise → N denoising steps → Narrative │
76
+ └──────────────────────────────────────────────────────────┘
77
+ ```
78
+
79
+ ## Model Details (v1.0 — Trained)
80
+
81
+ | Parameter | Value |
82
+ |-----------|-------|
83
+ | Architecture | Diffusion Transformer with Graph Conditioning |
84
+ | d_model | 64 |
85
+ | n_layers | 2 |
86
+ | n_heads | 4 |
87
+ | d_ff | 128 |
88
+ | **Total Parameters** | **311,670 (311.7K)** |
89
+ | Vocab size | 500 (BPE + special tokens) |
90
+ | Max sequence length | 32 |
91
+ | Diffusion timesteps (train) | 50 |
92
+ | Diffusion timesteps (inference) | 5 |
93
+ | Noise schedule | Cosine |
94
+ | Prediction type | Epsilon (noise prediction) |
95
+ | Sampling method | DDIM |
96
+
97
+ > **Note**: This v1.0 model was trained with minimal parameters (311K) for proof-of-concept on CPU. For production use, scale up to the `base` (170M) or `medium` (300M) configurations provided in the framework.
98
+
99
+ ## Model Sizes (Framework Supports)
100
+
101
+ | Size | d_model | Layers | Heads | Params | Recommended For |
102
+ |------|---------|--------|-------|--------|----------------|
103
+ | tiny | 256 | 4 | 4 | ~25M | Quick testing, debugging |
104
+ | small | 512 | 8 | 8 | ~70M | Development, prototyping |
105
+ | **base** | **768** | **12** | **12** | **~170M** | **Recommended for training** |
106
+ | medium | 1024 | 12 | 16 | ~300M | Final training, best quality |
107
+
108
+ ## Usage
109
+
110
+ ### Quick Start
111
+
112
+ ```python
113
+ from diffusion_llm import AamDiffusionModel, AamTokenizer, AamGenerator, AamDiffusionConfig
114
+
115
+ # Load model
116
+ config = AamDiffusionConfig.from_json("config.json")
117
+ model = AamDiffusionModel.load("model.pt", device="cpu")
118
+ tokenizer = AamTokenizer.load("tokenizer.json")
119
+
120
+ # Create generator
121
+ generator = AamGenerator(model, tokenizer, config)
122
+
123
+ # Generate narrative from graph conditioning
124
+ result = generator.generate(
125
+ trigger="Siapa yang mencuri Snow Plum Pill?",
126
+ evidence_nodes=["Hefei", "Diancang Five Swords", "Ju Jangmok"],
127
+ anomalies=["Tidak ada konsumsi pil baru di pasar gelap"],
128
+ reasoning_steps=["Cross-reference tanggal kejadian", "Deteksi anomali pola"],
129
+ source_trust=0.85,
130
+ )
131
+
132
+ print(result.narrative)
133
+ print(f"Confidence: {result.confidence:.1%}")
134
+ print(f"Steps: {result.n_diffusion_steps}")
135
+ ```
136
+
137
+ ### Training Your Own Model
138
+
139
+ ```python
140
+ from diffusion_llm import AamDiffusionConfig, get_default_config
141
+ from diffusion_llm.training import AamTrainer, GraphNarrativeDataset
142
+ from diffusion_llm.data import DataPipeline
143
+
144
+ # Get config for your desired size
145
+ config = get_default_config("base") # 170M params
146
+
147
+ # Prepare data pipeline
148
+ pipeline = DataPipeline(config)
149
+ tokenizer, train_loader, val_loader = pipeline.prepare()
150
+
151
+ # Create and train model
152
+ model = AamDiffusionModel(config)
153
+ trainer = AamTrainer(config, model, tokenizer, train_loader.dataset, val_loader.dataset)
154
+ trainer.train()
155
+ ```
156
+
157
+ ### Command Line
158
+
159
+ ```bash
160
+ # Train with default config
161
+ python diffusion_llm/scripts/train.py --model_size base
162
+
163
+ # Generate narratives
164
+ python diffusion_llm/scripts/evaluate.py --checkpoint output/best.pt --generate
165
+
166
+ # Export model
167
+ python diffusion_llm/scripts/export.py --checkpoint output/best.pt --output model_export/
168
+ ```
169
+
170
+ ## Philosophy
171
+
172
+ **AAM = 1 Pikiran + 1 Tubuh (1 Mind + 1 Body)**
173
+
174
+ - **Mind** = RSVS Knowledge Graph (structural memory, perfect recall, relational understanding)
175
+ - **Body** = This Diffusion LLM (sentence arranger, graph-conditioned, anti-hallucination)
176
+
177
+ Unlike using a rented LLM (GPT, Claude, etc.) as the "body", this model is **specifically trained for AAM**:
178
+ - It **cannot generate** information not present in the graph conditioning
179
+ - It **arranges sentences** based on structured evidence
180
+ - It uses **diffusion** (non-sequential generation) instead of autoregressive generation
181
+ - It is **small** but **specialized** — like Jin Soun's body in the novel, it may be "third-rate" but it's **his own**
182
+
183
+ > Jin Soun bukan orang yang menyewa tubuh orang lain untuk berbicara.
184
+ > Dia punya tubuh sendiri — lemah, third-rate, tapi MILIKNYA.
185
+ > Karena tubuhnya khusus dilatih untuk mengeksekusi perintah dari
186
+ > pikirannya (bukan pikiran orang lain), outputnya lebih terarah
187
+ > daripada orang yang punya tubuh lebih kuat tapi pikiran lebih lemah.
188
+
189
+ ## Framework Structure
190
+
191
+ ```
192
+ diffusion_llm/
193
+ ├── __init__.py # Public API
194
+ ├── config/
195
+ │ └── model_config.py # All configuration dataclasses
196
+ ├── tokenizer/
197
+ │ └── aam_tokenizer.py # Sentence-level + BPE hybrid tokenizer
198
+ ├── model/
199
+ │ ├── noise_scheduler.py # Forward/reverse diffusion process
200
+ │ ├── graph_encoder.py # Graph conditioning encoder
201
+ │ ├── diffusion_transformer.py # Core denoising transformer
202
+ │ └── aam_diffusion_model.py # Complete model (combines all)
203
+ ├── training/
204
+ │ ├── losses.py # Loss functions (MSE, MAE, Huber, weighted)
205
+ │ ├── dataset.py # GraphNarrative dataset
206
+ │ └── trainer.py # Training loop with AMP, EMA, etc.
207
+ ├── inference/
208
+ │ └── generator.py # Inference pipeline
209
+ ├── data/
210
+ │ ├── synthetic_generator.py # Synthetic training data
211
+ │ └── data_pipeline.py # Data preparation pipeline
212
+ ├── scripts/
213
+ │ ├── train.py # Training entry point
214
+ │ ├── evaluate.py # Evaluation & generation
215
+ │ └── export.py # Model export
216
+ └── tests/
217
+ ├── test_model.py # Model component tests
218
+ └── test_scheduler.py # Noise scheduler tests
219
+ ```
220
+
221
+ ## Training Data Format
222
+
223
+ Data training dalam format JSONL:
224
+
225
+ ```json
226
+ {
227
+ "narrative": "Berdasarkan analisis, Diancang Five Swords mencuri Snow Plum Pill.",
228
+ "trigger": "Siapa yang mencuri Snow Plum Pill?",
229
+ "evidence_nodes": ["Hefei", "Diancang Five Swords", "Ju Jangmok"],
230
+ "compositions": [],
231
+ "confidence_map": {"Hefei": 0.9, "Diancang Five Swords": 0.85},
232
+ "anomalies": ["Tidak ada konsumsi pil baru di pasar gelap"],
233
+ "reasoning_steps": ["Cross-reference tanggal kejadian", "Deteksi anomali pola"],
234
+ "source_trust": 0.85,
235
+ "language": "id"
236
+ }
237
+ ```
238
+
239
+ ## License
240
+
241
+ MIT
242
+
243
+ ## Citation
244
+
245
+ ```bibtex
246
+ @software{aam_diffusion_llm_v1,
247
+ title = {AAM Diffusion LLM: The Body of Aphantasic Abstraction Model},
248
+ author = {AAM Team},
249
+ year = {2026},
250
+ description = {A specialized diffusion LLM for sentence arrangement from graph-structured data},
251
+ url = {https://huggingface.co/aam-diffusion-v1}
252
+ }
253
+ ```
config.json ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model": {
3
+ "d_model": 64,
4
+ "n_layers": 2,
5
+ "n_heads": 4,
6
+ "d_ff": 128,
7
+ "dropout": 0.1,
8
+ "activation": "gelu",
9
+ "max_seq_len": 32,
10
+ "vocab_size": 500,
11
+ "pos_encoding_type": "learned",
12
+ "use_flash_attention": false,
13
+ "norm_type": "layernorm",
14
+ "norm_eps": 1e-06,
15
+ "init_std": 0.02
16
+ },
17
+ "diffusion": {
18
+ "n_timesteps": 50,
19
+ "n_inference_steps": 5,
20
+ "schedule_type": "cosine",
21
+ "beta_start": 0.0001,
22
+ "beta_end": 0.02,
23
+ "prediction_type": "epsilon",
24
+ "sampling_method": "ddim",
25
+ "eta_ddim": 0.0,
26
+ "clip_sample_max": 5.0,
27
+ "clip_sample_min": -5.0,
28
+ "loss_type": "mse",
29
+ "loss_weighting": "none",
30
+ "p2_gamma": 1.0,
31
+ "p2_k": 1.0
32
+ },
33
+ "graph_encoder": {
34
+ "d_graph": 32,
35
+ "n_graph_layers": 1,
36
+ "n_graph_heads": 2,
37
+ "max_evidence_nodes": 3,
38
+ "max_compositions": 2,
39
+ "max_anomalies": 2,
40
+ "max_reasoning_steps": 2,
41
+ "conditioning_method": "cross_attention",
42
+ "embed_confidence": false,
43
+ "embed_temporal": false
44
+ },
45
+ "tokenizer": {
46
+ "bpe_vocab_size": 500,
47
+ "max_sentences": 32,
48
+ "sentence_boundary_token": "<sent>",
49
+ "pad_token": "<pad>",
50
+ "bos_token": "<bos>",
51
+ "eos_token": "<eos>",
52
+ "mask_token": "<mask>",
53
+ "noise_token": "<noise>",
54
+ "evidence_token": "<evidence>",
55
+ "anomaly_token": "<anomaly>",
56
+ "confidence_token": "<confidence>",
57
+ "reasoning_token": "<reasoning>",
58
+ "composition_token": "<composition>",
59
+ "temporal_token": "<temporal>",
60
+ "min_frequency": 2,
61
+ "dropout_rate": 0.0
62
+ },
63
+ "training": {
64
+ "learning_rate": 0.001,
65
+ "weight_decay": 0.01,
66
+ "adam_beta1": 0.9,
67
+ "adam_beta2": 0.999,
68
+ "adam_eps": 1e-08,
69
+ "lr_schedule": "cosine",
70
+ "warmup_steps": 5,
71
+ "batch_size": 2,
72
+ "gradient_accumulation_steps": 4,
73
+ "max_steps": 50,
74
+ "max_epochs": 100,
75
+ "dropout": 0.1,
76
+ "grad_clip_norm": 1.0,
77
+ "use_amp": false,
78
+ "amp_dtype": "bf16",
79
+ "save_every_steps": 5000,
80
+ "eval_every_steps": 1000,
81
+ "keep_last_n_checkpoints": 3,
82
+ "use_ema": true,
83
+ "ema_decay": 0.9999,
84
+ "train_data_path": "",
85
+ "val_data_path": "",
86
+ "num_workers": 0,
87
+ "log_every_steps": 100,
88
+ "wandb_project": "aam-diffusion-llm",
89
+ "wandb_run_name": ""
90
+ },
91
+ "inference": {
92
+ "n_steps": 5,
93
+ "temperature": 1.0,
94
+ "top_k": 50,
95
+ "top_p": 0.95,
96
+ "repetition_penalty": 1.2,
97
+ "max_output_sentences": 16,
98
+ "language": "id"
99
+ },
100
+ "model_name": "aam-diffusion-v1.0",
101
+ "output_dir": "./aam-diffusion-v1",
102
+ "seed": 42,
103
+ "aam_mind_source": "rsvs_graph",
104
+ "aam_body_type": "specialized_diffusion"
105
+ }
diffusion_llm/README.md ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AAM Diffusion LLM Framework
2
+
3
+ > **"AAM = 1 Pikiran + 1 Tubuh" (1 Mind + 1 Body)**
4
+
5
+ Framework khusus untuk melatih Diffusion LLM yang menjadi "tubuh" (body) dari Aphantasic Abstraction Model (AAM). Ini BUKAN LLM umum — ini model yang KHUSUS dilatih untuk menyusun kalimat dari data graph yang terstruktur.
6
+
7
+ ---
8
+
9
+ ## Filosofi
10
+
11
+ ### Kenapa Bukan LLM Umum?
12
+
13
+ Konsep sebelumnya: "tubuh Jin Soun = LLM umum (GPT, Claude, dll.)" — ini **salah besar**.
14
+
15
+ | Aspek | LLM Umum (Sewaan) | AAM Diffusion LLM (Milik Sendiri) |
16
+ |-------|-------------------|-----------------------------------|
17
+ | Input | Prompt teks | Graph conditioning (evidence, anomaly, dll.) |
18
+ | Output | Teks probabilistik | Narrative yang grounded di graph |
19
+ | Hallucination | BISA mengarang | TIDAK BISA — hanya menarasikan apa yang graph ketahui |
20
+ | Tujuan | General purpose | Khusus menyusun kalimat dari graph |
21
+ | Ukuran | 7B-175B params | 100M-500M params |
22
+ | Metode | Autoregressive | Diffusion (non-sequential) |
23
+ | Identitas | Sewaan | MILIK AAM sendiri |
24
+
25
+ ### Kenapa Diffusion (Bukan Autoregressive)?
26
+
27
+ 1. **Non-sequential** — Bisa merevisi bagian awal saat generating bagian akhir. Mirip cara Jin Soun membentuk pikiran: vague → clearer → explicit.
28
+
29
+ 2. **Graph conditioning** — Seluruh graph bisa di-encode sebagai conditioning, bukan hanya prefix. Autoregressive hanya bisa melihat "apa yang sudah di-generate sebelumnya."
30
+
31
+ 3. **Coherent long-form** — Diffusion menghasilkan teks yang lebih koheren untuk narasi panjang karena setiap token "mengetahui" tentang token lain.
32
+
33
+ 4. **Anti-hallucination** — Model dilatih KHUSUS untuk Graph→Narrative, tidak punya kapabilitas mengarang informasi di luar graph.
34
+
35
+ ---
36
+
37
+ ## Arsitektur
38
+
39
+ ```
40
+ ┌──────────────────────────────────────────────────────────┐
41
+ │ AAM = 1 Pikiran + 1 Tubuh │
42
+ │ │
43
+ │ Pikiran (Mind) = RSVS Knowledge Graph │
44
+ │ - Structural memory — mengingat SEMUA │
45
+ │ - Relational — memahami koneksi antar konsep │
46
+ │ - Perfect recall — tidak pernah lupa │
47
+ │ - Confidence scores — tahu apa yang pasti vs ragu │
48
+ │ │
49
+ │ Tubuh (Body) = AAM Diffusion LLM │
50
+ │ ┌─────────────────────────────────────────────┐ │
51
+ │ │ Graph Conditioning Encoder │ │
52
+ │ │ ├─ Evidence Node Encoder │ │
53
+ │ │ ├─ Composition Encoder │ │
54
+ │ │ ├─ Anomaly Encoder │ │
55
+ │ │ ├─ Reasoning Chain Encoder │ │
56
+ │ │ ├─ Confidence Embedding │ │
57
+ │ │ ├─ Temporal Embedding │ │
58
+ │ │ └─ Graph Attention Layers │ │
59
+ │ │ ↓ (cross-attention keys/values) │ │
60
+ │ ├─────────────────────────────────────────────┤ │
61
+ │ │ Diffusion Transformer (Denoiser) │ │
62
+ │ │ ├─ Token Embedding │ │
63
+ │ │ ├─ Timestep Embedding (sinusoidal) │ │
64
+ │ │ ├─ N × TransformerBlock: │ │
65
+ │ │ │ ├─ AdaptiveLayerNorm + Self-Attention │ │
66
+ │ │ │ ├─ AdaptiveLayerNorm + Cross-Attention │ │
67
+ │ │ │ └─ AdaptiveLayerNorm + Feed-Forward │ │
68
+ │ │ └─ Output Projection │ │
69
+ │ │ ↓ (predicted noise) │ │
70
+ │ ├─────────────────────────────────────────────┤ │
71
+ │ │ Noise Scheduler │ │
72
+ │ │ ├─ Forward: x_0 + noise → x_t │ │
73
+ │ │ └─ Reverse: x_t → denoise → x_{t-1} │ │
74
+ │ └─────────────────────────────────────────────┘ │
75
+ │ │
76
+ │ Training: Graph→Narrative pairs │
77
+ │ Inference: Noise → N denoising steps → Narrative │
78
+ └─────────────────────────────���────────────────────────────┘
79
+ ```
80
+
81
+ ---
82
+
83
+ ## Struktur Folder
84
+
85
+ ```
86
+ diffusion_llm/
87
+ ├── __init__.py # Package init with public API
88
+ ├── config/
89
+ │ ├── __init__.py
90
+ │ └── model_config.py # All configuration dataclasses
91
+ ├── tokenizer/
92
+ │ ├── __init__.py
93
+ │ └── aam_tokenizer.py # Sentence-level + BPE hybrid tokenizer
94
+ ├── model/
95
+ │ ├── __init__.py
96
+ │ ├── noise_scheduler.py # Forward/reverse diffusion process
97
+ │ ├── graph_encoder.py # Graph conditioning encoder
98
+ │ ├── diffusion_transformer.py # Core denoising transformer
99
+ │ └── aam_diffusion_model.py # Complete model (combines all)
100
+ ├── training/
101
+ │ ├── __init__.py
102
+ │ ├── losses.py # Loss functions (MSE, MAE, Huber, weighted)
103
+ │ ├── dataset.py # GraphNarrative dataset
104
+ │ └── trainer.py # Training loop with AMP, EMA, etc.
105
+ ├── inference/
106
+ │ ├── __init__.py
107
+ │ └── generator.py # Inference pipeline
108
+ ├── data/
109
+ │ ├── __init__.py
110
+ │ ├── synthetic_generator.py # Synthetic training data
111
+ │ └── data_pipeline.py # Data preparation pipeline
112
+ ├── scripts/
113
+ │ ├── train.py # Training entry point
114
+ │ ├── evaluate.py # Evaluation & generation
115
+ │ └── export.py # Model export
116
+ ├── tests/
117
+ │ ├── __init__.py
118
+ │ ├── test_scheduler.py # Noise scheduler tests
119
+ │ └── test_model.py # Model component tests
120
+ ├── requirements.txt # Python dependencies
121
+ └── README.md # This file
122
+ ```
123
+
124
+ ---
125
+
126
+ ## Quick Start
127
+
128
+ ### 1. Install Dependencies
129
+
130
+ ```bash
131
+ pip install torch numpy pytest
132
+ ```
133
+
134
+ ### 2. Generate Synthetic Data
135
+
136
+ ```python
137
+ from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator
138
+
139
+ generator = SyntheticDataGenerator(seed=42, language="id")
140
+ train_path, val_path = generator.generate_training_split(
141
+ output_dir="./data",
142
+ n_train=10000,
143
+ n_val=500,
144
+ )
145
+ ```
146
+
147
+ ### 3. Train the Model
148
+
149
+ ```bash
150
+ # Quick test with tiny model
151
+ python diffusion_llm/scripts/train.py --model_size tiny --max_steps 100
152
+
153
+ # Full training with base model
154
+ python diffusion_llm/scripts/train.py --model_size base --max_steps 500000
155
+ ```
156
+
157
+ ### 4. Generate Narratives
158
+
159
+ ```bash
160
+ # Generate samples
161
+ python diffusion_llm/scripts/evaluate.py --checkpoint output/best.pt --generate
162
+
163
+ # Interactive mode
164
+ python diffusion_llm/scripts/evaluate.py --checkpoint output/best.pt --interactive
165
+ ```
166
+
167
+ ### 5. Programmatic Usage
168
+
169
+ ```python
170
+ from diffusion_llm import (
171
+ AamDiffusionConfig, get_default_config,
172
+ AamDiffusionModel, AamTokenizer, AamGenerator,
173
+ )
174
+
175
+ # Load model and tokenizer
176
+ config = AamDiffusionConfig.from_json("output/config.json")
177
+ model = AamDiffusionModel.load("output/best.pt")
178
+ tokenizer = AamTokenizer.load("output/data/tokenizer.json")
179
+
180
+ # Create generator
181
+ generator = AamGenerator(model, tokenizer, config)
182
+
183
+ # Generate narrative from graph conditioning
184
+ result = generator.generate(
185
+ trigger="Siapa yang mencuri Snow Plum Pill?",
186
+ evidence_nodes=["Hefei", "Diancang Five Swords", "Ju Jangmok"],
187
+ anomalies=["Tidak ada konsumsi pil baru di pasar gelap"],
188
+ reasoning_steps=["Cross-reference tanggal kejadian", "Deteksi anomali"],
189
+ source_trust=0.85,
190
+ )
191
+
192
+ print(result.narrative)
193
+ print(f"Confidence: {result.confidence:.1%}")
194
+ print(f"Steps: {result.n_diffusion_steps}")
195
+ ```
196
+
197
+ ---
198
+
199
+ ## Model Sizes
200
+
201
+ | Size | d_model | Layers | Heads | Params | Recommended For |
202
+ |------|---------|--------|-------|--------|----------------|
203
+ | tiny | 256 | 4 | 4 | ~25M | Quick testing, debugging |
204
+ | small | 512 | 8 | 8 | ~70M | Development, prototyping |
205
+ | **base** | **768** | **12** | **12** | **~170M** | **Recommended for training** |
206
+ | medium | 1024 | 12 | 16 | ~300M | Final training, best quality |
207
+
208
+ ---
209
+
210
+ ## Konfigurasi
211
+
212
+ ### Model Config
213
+
214
+ ```python
215
+ from diffusion_llm.config.model_config import AamDiffusionConfig, ModelConfig, DiffusionConfig
216
+
217
+ config = AamDiffusionConfig(
218
+ model=ModelConfig(
219
+ d_model=768, # Hidden dimension
220
+ n_layers=12, # Transformer blocks
221
+ n_heads=12, # Attention heads
222
+ d_ff=3072, # Feed-forward dimension
223
+ vocab_size=32000, # Vocabulary size
224
+ max_seq_len=512, # Maximum sequence length
225
+ ),
226
+ diffusion=DiffusionConfig(
227
+ n_timesteps=1000, # Training timesteps
228
+ n_inference_steps=50, # Inference steps (fewer = faster)
229
+ schedule_type="cosine", # Noise schedule
230
+ prediction_type="epsilon", # Predict noise
231
+ sampling_method="ddim", # Fast deterministic sampling
232
+ ),
233
+ )
234
+ ```
235
+
236
+ ### Inference Config
237
+
238
+ ```python
239
+ from diffusion_llm.config.model_config import InferenceConfig
240
+
241
+ inference = InferenceConfig(
242
+ n_steps=50, # Denoising steps
243
+ temperature=1.0, # Sampling temperature
244
+ top_k=50, # Top-k sampling
245
+ max_output_sentences=16, # Max sentences
246
+ language="id", # Output language
247
+ )
248
+ ```
249
+
250
+ ---
251
+
252
+ ## Integrasi dengan AAM Pipeline
253
+
254
+ Framework ini dirancang untuk menjadi "tubuh" dari AAM. Setelah model dilatih,
255
+ integrasi dengan `pipeline.py` sangat mudah:
256
+
257
+ ```python
258
+ # Dalam pipeline.py, ganti fallback:
259
+ from diffusion_llm import AamDiffusionModel, AamTokenizer, AamGenerator
260
+
261
+ class AamPipeline:
262
+ def __init__(self, ...):
263
+ # Load trained diffusion model
264
+ diffusion_config = AamDiffusionConfig.from_json("path/to/config.json")
265
+ diffusion_model = AamDiffusionModel.load("path/to/best.pt")
266
+ diffusion_tokenizer = AamTokenizer.load("path/to/tokenizer.json")
267
+ self.diffusion_llm = AamGenerator(diffusion_model, diffusion_tokenizer, diffusion_config)
268
+ ```
269
+
270
+ ---
271
+
272
+ ## Training Data Format
273
+
274
+ Data training dalam format JSONL, satu contoh per baris:
275
+
276
+ ```json
277
+ {
278
+ "narrative": "Berdasarkan analisis, Diancang Five Swords mencuri Snow Plum Pill menggunakan Ju Jangmok sebagai kambing hitam.",
279
+ "trigger": "Siapa yang mencuri Snow Plum Pill?",
280
+ "evidence_nodes": ["Hefei", "Diancang Five Swords", "Ju Jangmok", "Gyeryong Merchant Guild"],
281
+ "compositions": [],
282
+ "confidence_map": {"Hefei": 0.9, "Diancang Five Swords": 0.85, "Ju Jangmok": 0.7},
283
+ "anomalies": ["Tidak ada konsumsi pil baru di pasar gelap", "Pencuri menghilang tanpa jejak"],
284
+ "reasoning_steps": ["Cross-reference tanggal kejadian", "Deteksi ketidaksesuaian pola", "Pattern completion dari bukti terpisah"],
285
+ "source_trust": 0.85,
286
+ "temporal_context": [],
287
+ "language": "id",
288
+ "source": "synthetic"
289
+ }
290
+ ```
291
+
292
+ ---
293
+
294
+ ## Running Tests
295
+
296
+ ```bash
297
+ # Run all tests
298
+ cd diffusion_llm
299
+ python -m pytest tests/ -v
300
+
301
+ # Run specific test
302
+ python -m pytest tests/test_model.py -v
303
+
304
+ # Run with coverage
305
+ python -m pytest tests/ --cov=diffusion_llm
306
+ ```
307
+
308
+ ---
309
+
310
+ ## Roadmap
311
+
312
+ - [x] **Phase 1: Framework Design** — Arsitektur, config, interface
313
+ - [x] **Phase 2: Core Components** — Noise scheduler, transformer, graph encoder, tokenizer
314
+ - [x] **Phase 3: Training Infrastructure** — Trainer, dataset, loss functions, synthetic data
315
+ - [x] **Phase 4: Inference Pipeline** — Generator, batch generation, interactive mode
316
+ - [ ] **Phase 5: Training Execution** — Train on synthetic data, iterate
317
+ - [ ] **Phase 6: Real Data** — Collect real Graph→Narrative pairs from AAM usage
318
+ - [ ] **Phase 7: Optimization** — Quantization, distillation, flash attention
319
+ - [ ] **Phase 8: Integration** — Plug trained model into AAM pipeline
320
+
321
+ ---
322
+
323
+ ## Analogi Novel
324
+
325
+ > Jin Soun bukan orang yang menyewa tubuh orang lain untuk berbicara.
326
+ > Dia punya tubuh sendiri — lemah, third-rate, tapi MILIKNYA.
327
+ > Karena tubuhnya khusus dilatih untuk mengeksekusi perintah dari
328
+ > pikirannya (bukan pikiran orang lain), outputnya lebih terarah
329
+ > daripada orang yang punya tubuh lebih kuat tapi pikiran lebih lemah.
330
+ >
331
+ > **AAM = 1 pikiran + 1 tubuh. Bukan 1 pikiran + tubuh sewaan.**
diffusion_llm/__init__.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AAM Diffusion LLM Framework — The Body of Aphantasic Abstraction Model
3
+
4
+ "AAM = 1 Pikiran + 1 Tubuh" (1 Mind + 1 Body)
5
+
6
+ Pikiran (Mind) = RSVS Knowledge Graph — structural, relational, perfect memory
7
+ Tubuh (Body) = This Diffusion LLM — generates natural language FROM the graph
8
+
9
+ This is NOT a general-purpose LLM. This is a SPECIALIZED sentence composer
10
+ that takes structured graph data as input and produces coherent, evidence-backed
11
+ narrative output. Think of it as a "vocal cord" for the graph — it can only
12
+ say what the graph knows, but it says it fluently.
13
+
14
+ Why Diffusion?
15
+ - Diffusion models start from noise and iteratively denoise
16
+ - This mirrors how Jin Soun's thoughts form: from vague intuition ->
17
+ clearer pattern -> explicit narrative
18
+ - Unlike autoregressive LLMs (GPT), diffusion models can:
19
+ - Be conditioned on structured input (graph)
20
+ - Revise earlier parts during generation (non-sequential)
21
+ - Produce more coherent long-form text from structure
22
+
23
+ Architecture:
24
+ Input: Graph conditioning (evidence nodes, compositions, confidence, anomalies)
25
+ Process: Iterative denoising from noise
26
+ Output: Natural language narrative grounded in graph structure
27
+
28
+ Analogi: Jin Soun (graph) + tubuhnya (this model).
29
+ Tubuhnya third-rate, tapi karena KHUSUS dilatih untuk
30
+ mengeksekusi perintah dari graph-nya sendiri, outputnya
31
+ lebih terarah daripada LLM umum yang "tidak kenal" graph.
32
+ """
33
+
34
+ __version__ = "0.1.0"
35
+ __author__ = "AAM Team"
36
+
37
+ from diffusion_llm.config.model_config import AamDiffusionConfig, get_default_config
38
+ from diffusion_llm.model.noise_scheduler import NoiseScheduler
39
+ from diffusion_llm.model.graph_encoder import GraphConditioningEncoder
40
+ from diffusion_llm.model.diffusion_transformer import DiffusionTransformer
41
+ from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
42
+ from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
43
+ from diffusion_llm.inference.generator import AamGenerator
44
+ from diffusion_llm.training.trainer import AamTrainer
45
+ from diffusion_llm.training.dataset import GraphNarrativeDataset
46
+ from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator
47
+
48
+ __all__ = [
49
+ "AamDiffusionConfig",
50
+ "get_default_config",
51
+ "NoiseScheduler",
52
+ "GraphConditioningEncoder",
53
+ "DiffusionTransformer",
54
+ "AamDiffusionModel",
55
+ "AamTokenizer",
56
+ "AamGenerator",
57
+ "AamTrainer",
58
+ "GraphNarrativeDataset",
59
+ "SyntheticDataGenerator",
60
+ ]
diffusion_llm/config/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Configuration module for AAM Diffusion LLM."""
2
+
3
+ from diffusion_llm.config.model_config import AamDiffusionConfig, get_default_config
4
+
5
+ __all__ = ["AamDiffusionConfig", "get_default_config"]
diffusion_llm/config/model_config.py ADDED
@@ -0,0 +1,620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AAM Diffusion LLM — Model Configuration
3
+
4
+ Defines all hyperparameters for the diffusion model architecture,
5
+ training process, and inference pipeline.
6
+
7
+ Design Philosophy:
8
+ - Small model (100M-500M params) — specialized, not general
9
+ - Sentence-level tokenization — not subword, because AAM arranges
10
+ sentences, not individual tokens
11
+ - Graph-conditioned — the model MUST receive graph structure as input
12
+ - Non-sequential generation — diffusion, not autoregressive
13
+
14
+ Analogi: Seperti tubuh Jin Soun, model ini kecil tapi KKHUSUS
15
+ dilatih untuk satu tugas: menarasikan dari graph. Tidak perlu
16
+ 7B params kalau tugasku hanya menyusun kalimat dari data yang
17
+ sudah terstruktur.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ from dataclasses import dataclass, field, asdict
24
+ from pathlib import Path
25
+ from typing import Optional
26
+
27
+
28
+ @dataclass
29
+ class ModelConfig:
30
+ """Architecture hyperparameters for the Diffusion Transformer.
31
+
32
+ Target: 100M-500M parameters total.
33
+ Calculation:
34
+ params ≈ d_model^2 * (12 * n_layers) for transformer
35
+ d_model=512, n_layers=8 → ~50M core params
36
+ d_model=768, n_layers=12 → ~170M core params
37
+ d_model=1024, n_layers=12 → ~300M core params
38
+ """
39
+
40
+ # --- Core Transformer ---
41
+ d_model: int = 768
42
+ """Hidden dimension of the transformer."""
43
+
44
+ n_layers: int = 12
45
+ """Number of transformer blocks."""
46
+
47
+ n_heads: int = 12
48
+ """Number of attention heads (d_model must be divisible by n_heads)."""
49
+
50
+ d_ff: int = 3072
51
+ """Feed-forward hidden dimension (typically 4x d_model)."""
52
+
53
+ dropout: float = 0.1
54
+ """Dropout rate for attention and feed-forward layers."""
55
+
56
+ activation: str = "gelu"
57
+ """Activation function: 'gelu' or 'relu'."""
58
+
59
+ # --- Sequence ---
60
+ max_seq_len: int = 512
61
+ """Maximum sequence length (in sentence-level tokens)."""
62
+
63
+ # --- Vocabulary ---
64
+ vocab_size: int = 32000
65
+ """Vocabulary size for the tokenizer.
66
+ Since we use sentence-level tokens + subword BPE hybrid,
67
+ this includes special tokens + subword units.
68
+ """
69
+
70
+ # --- Positional Encoding ---
71
+ pos_encoding_type: str = "rotary"
72
+ """Positional encoding type: 'rotary' (RoPE) or 'learned'."""
73
+
74
+ # --- Attention ---
75
+ use_flash_attention: bool = True
76
+ """Whether to use Flash Attention 2 if available."""
77
+
78
+ # --- Normalization ---
79
+ norm_type: str = "rmsnorm"
80
+ """Normalization type: 'rmsnorm' or 'layernorm'."""
81
+
82
+ norm_eps: float = 1e-6
83
+ """Epsilon for normalization layers."""
84
+
85
+ # --- Initialization ---
86
+ init_std: float = 0.02
87
+ """Standard deviation for weight initialization."""
88
+
89
+ def estimate_params(self) -> str:
90
+ """Estimate total parameter count."""
91
+ # Embedding: vocab_size * d_model
92
+ embed_params = self.vocab_size * self.d_model
93
+ # Per layer: 4 * d_model^2 (QKV + O) + 2 * d_model * d_ff (FF)
94
+ layer_params = 4 * self.d_model ** 2 + 2 * self.d_model * self.d_ff
95
+ total = embed_params + self.n_layers * layer_params
96
+ if total >= 1e9:
97
+ return f"{total / 1e9:.1f}B"
98
+ elif total >= 1e6:
99
+ return f"{total / 1e6:.1f}M"
100
+ else:
101
+ return f"{total / 1e3:.1f}K"
102
+
103
+
104
+ @dataclass
105
+ class DiffusionConfig:
106
+ """Hyperparameters for the diffusion process.
107
+
108
+ The diffusion process works on the latent representation of text:
109
+ 1. Forward: Add Gaussian noise to text embeddings over T timesteps
110
+ 2. Reverse: Learn to denoise step by step
111
+ 3. At inference: Start from pure noise, denoise to coherent text
112
+
113
+ This is DIFFERENT from image diffusion because:
114
+ - We operate in a learned latent space (not pixel space)
115
+ - Text has discrete structure (sentences, not pixels)
116
+ - We use a text-specific noise schedule
117
+ """
118
+
119
+ # --- Noise Schedule ---
120
+ n_timesteps: int = 1000
121
+ """Total number of diffusion timesteps for training."""
122
+
123
+ n_inference_steps: int = 50
124
+ """Number of denoising steps at inference (fewer = faster, less quality)."""
125
+
126
+ schedule_type: str = "cosine"
127
+ """Noise schedule type: 'linear', 'cosine', or 'sigmoid'."""
128
+
129
+ beta_start: float = 1e-4
130
+ """Starting beta for linear schedule."""
131
+
132
+ beta_end: float = 0.02
133
+ """Ending beta for linear schedule."""
134
+
135
+ # --- Noise Prediction ---
136
+ prediction_type: str = "epsilon"
137
+ """What the model predicts: 'epsilon' (noise), 'x0' (clean data),
138
+ or 'v' (velocity). Epsilon prediction is most stable for text."""
139
+
140
+ # --- Sampling ---
141
+ sampling_method: str = "ddim"
142
+ """Sampling method: 'ddpm' (slow, stochastic) or 'ddim' (fast, deterministic)."""
143
+
144
+ eta_ddim: float = 0.0
145
+ """DDIM stochasticity parameter (0 = deterministic, 1 = full stochastic)."""
146
+
147
+ # --- Clipping ---
148
+ clip_sample_max: float = 5.0
149
+ """Maximum value for clipped samples during inference."""
150
+
151
+ clip_sample_min: float = -5.0
152
+ """Minimum value for clipped samples during inference."""
153
+
154
+ # --- Loss ---
155
+ loss_type: str = "mse"
156
+ """Loss function: 'mse' (L2) or 'mae' (L1) or 'huber'."""
157
+
158
+ loss_weighting: str = "min_snr"
159
+ """Loss weighting strategy: 'none', 'min_snr', or 'p2'."""
160
+
161
+ p2_gamma: float = 1.0
162
+ """P2 weighting gamma (only used if loss_weighting='p2')."""
163
+
164
+ p2_k: float = 1.0
165
+ """P2 weighting k (only used if loss_weighting='p2')."""
166
+
167
+
168
+ @dataclass
169
+ class GraphEncoderConfig:
170
+ """Configuration for the Graph Conditioning Encoder.
171
+
172
+ The graph encoder takes structured graph data (evidence nodes,
173
+ compositions, confidence scores, anomalies, reasoning chains)
174
+ and produces a conditioning vector that guides the diffusion process.
175
+
176
+ This is the KEY differentiator from general LLMs:
177
+ the model is conditioned on GRAPH STRUCTURE, not just text prompts.
178
+ """
179
+
180
+ # --- Graph Encoder Architecture ---
181
+ d_graph: int = 512
182
+ """Hidden dimension for graph encoding."""
183
+
184
+ n_graph_layers: int = 4
185
+ """Number of graph attention layers."""
186
+
187
+ n_graph_heads: int = 8
188
+ """Number of attention heads for graph encoding."""
189
+
190
+ # --- Input Dimensions ---
191
+ max_evidence_nodes: int = 50
192
+ """Maximum number of evidence nodes to encode."""
193
+
194
+ max_compositions: int = 20
195
+ """Maximum number of compositions to encode."""
196
+
197
+ max_anomalies: int = 10
198
+ """Maximum number of anomalies to encode."""
199
+
200
+ max_reasoning_steps: int = 15
201
+ """Maximum number of reasoning steps to encode."""
202
+
203
+ # --- Conditioning Injection ---
204
+ conditioning_method: str = "cross_attention"
205
+ """How to inject graph conditioning into the diffusion model:
206
+ 'cross_attention' (separate encoder, cross-attn in transformer)
207
+ 'ada_ln' (adaptive layer norm, conditioning modulates scale/shift)
208
+ 'concat' (concatenate conditioning to input sequence)
209
+ """
210
+
211
+ # --- Confidence Embedding ---
212
+ embed_confidence: bool = True
213
+ """Whether to embed confidence scores as part of the conditioning."""
214
+
215
+ # --- Temporal Embedding ---
216
+ embed_temporal: bool = True
217
+ """Whether to embed temporal context (time-based relationships)."""
218
+
219
+
220
+ @dataclass
221
+ class TokenizerConfig:
222
+ """Configuration for the AAM Sentence-Level Tokenizer.
223
+
224
+ Unlike standard BPE tokenizers that operate at subword level,
225
+ AAM's tokenizer is designed for SENTENCE ARRANGEMENT:
226
+ - Sentences are the primary unit of generation
227
+ - Within sentences, subword BPE handles individual words
228
+ - Special tokens for graph structure (evidence, anomaly, etc.)
229
+ """
230
+
231
+ # --- BPE ---
232
+ bpe_vocab_size: int = 28000
233
+ """Subword BPE vocabulary size (within the total vocab_size)."""
234
+
235
+ # --- Sentence-Level ---
236
+ max_sentences: int = 32
237
+ """Maximum number of sentences in one generation."""
238
+
239
+ sentence_boundary_token: str = "<sent>"
240
+ """Token marking sentence boundaries."""
241
+
242
+ # --- Special Tokens ---
243
+ pad_token: str = "<pad>"
244
+ bos_token: str = "<bos>"
245
+ eos_token: str = "<eos>"
246
+ mask_token: str = "<mask>"
247
+ noise_token: str = "<noise>"
248
+
249
+ # --- Graph-Structure Tokens ---
250
+ evidence_token: str = "<evidence>"
251
+ anomaly_token: str = "<anomaly>"
252
+ confidence_token: str = "<confidence>"
253
+ reasoning_token: str = "<reasoning>"
254
+ composition_token: str = "<composition>"
255
+ temporal_token: str = "<temporal>"
256
+
257
+ # --- Training ---
258
+ min_frequency: int = 2
259
+ """Minimum frequency for BPE merge operations."""
260
+
261
+ dropout_rate: float = 0.0
262
+ """BPE dropout rate (0 = no dropout, regularization during training)."""
263
+
264
+
265
+ @dataclass
266
+ class TrainingConfig:
267
+ """Training hyperparameters and settings."""
268
+
269
+ # --- Optimizer ---
270
+ learning_rate: float = 1e-4
271
+ """Peak learning rate."""
272
+
273
+ weight_decay: float = 0.01
274
+ """Weight decay for AdamW."""
275
+
276
+ adam_beta1: float = 0.9
277
+ """Adam beta1."""
278
+
279
+ adam_beta2: float = 0.999
280
+ """Adam beta2."""
281
+
282
+ adam_eps: float = 1e-8
283
+ """Adam epsilon."""
284
+
285
+ # --- Learning Rate Schedule ---
286
+ lr_schedule: str = "cosine"
287
+ """LR schedule: 'cosine', 'linear', or 'constant'."""
288
+
289
+ warmup_steps: int = 2000
290
+ """Number of warmup steps."""
291
+
292
+ # --- Training ---
293
+ batch_size: int = 32
294
+ """Training batch size (per GPU)."""
295
+
296
+ gradient_accumulation_steps: int = 4
297
+ """Gradient accumulation steps (effective batch = batch_size * this)."""
298
+
299
+ max_steps: int = 500000
300
+ """Maximum training steps."""
301
+
302
+ max_epochs: int = 100
303
+ """Maximum training epochs."""
304
+
305
+ # --- Regularization ---
306
+ dropout: float = 0.1
307
+ """Training dropout rate."""
308
+
309
+ grad_clip_norm: float = 1.0
310
+ """Gradient clipping max norm."""
311
+
312
+ # --- Mixed Precision ---
313
+ use_amp: bool = True
314
+ """Whether to use Automatic Mixed Precision (fp16/bf16)."""
315
+
316
+ amp_dtype: str = "bf16"
317
+ """AMP data type: 'fp16' or 'bf16'."""
318
+
319
+ # --- Checkpointing ---
320
+ save_every_steps: int = 5000
321
+ """Save checkpoint every N steps."""
322
+
323
+ eval_every_steps: int = 1000
324
+ """Evaluate every N steps."""
325
+
326
+ keep_last_n_checkpoints: int = 3
327
+ """Keep only the last N checkpoints."""
328
+
329
+ # --- EMA ---
330
+ use_ema: bool = True
331
+ """Whether to use Exponential Moving Average for inference weights."""
332
+
333
+ ema_decay: float = 0.9999
334
+ """EMA decay rate."""
335
+
336
+ # --- Data ---
337
+ train_data_path: str = ""
338
+ """Path to training data (JSONL format)."""
339
+
340
+ val_data_path: str = ""
341
+ """Path to validation data (JSONL format)."""
342
+
343
+ num_workers: int = 4
344
+ """Number of data loading workers."""
345
+
346
+ # --- Logging ---
347
+ log_every_steps: int = 100
348
+ """Log training metrics every N steps."""
349
+
350
+ wandb_project: str = "aam-diffusion-llm"
351
+ """Weights & Biases project name."""
352
+
353
+ wandb_run_name: str = ""
354
+ """Weights & Biases run name (auto-generated if empty)."""
355
+
356
+
357
+ @dataclass
358
+ class InferenceConfig:
359
+ """Inference-time configuration."""
360
+
361
+ n_steps: int = 50
362
+ """Number of denoising steps (more = better quality, slower)."""
363
+
364
+ temperature: float = 1.0
365
+ """Sampling temperature (1.0 = standard, <1 = more deterministic)."""
366
+
367
+ top_k: int = 50
368
+ """Top-k sampling for token decoding."""
369
+
370
+ top_p: float = 0.95
371
+ """Nucleus sampling threshold."""
372
+
373
+ repetition_penalty: float = 1.2
374
+ """Penalty for repeating tokens."""
375
+
376
+ max_output_sentences: int = 16
377
+ """Maximum number of sentences in output."""
378
+
379
+ language: str = "id"
380
+ """Output language: 'id' (Indonesian) or 'en' (English)."""
381
+
382
+
383
+ @dataclass
384
+ class AamDiffusionConfig:
385
+ """Master configuration for the AAM Diffusion LLM.
386
+
387
+ Combines all sub-configurations into a single object.
388
+ This is the entry point for configuring the entire framework.
389
+ """
390
+
391
+ model: ModelConfig = field(default_factory=ModelConfig)
392
+ diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
393
+ graph_encoder: GraphEncoderConfig = field(default_factory=GraphEncoderConfig)
394
+ tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
395
+ training: TrainingConfig = field(default_factory=TrainingConfig)
396
+ inference: InferenceConfig = field(default_factory=InferenceConfig)
397
+
398
+ # --- Meta ---
399
+ model_name: str = "aam-diffusion-v0.1"
400
+ """Model name for saving/loading."""
401
+
402
+ output_dir: str = "./output"
403
+ """Base output directory."""
404
+
405
+ seed: int = 42
406
+ """Random seed for reproducibility."""
407
+
408
+ # --- AAM Philosophy ---
409
+ aam_mind_source: str = "rsvs_graph"
410
+ """Source of the 'mind' that conditions this 'body'.
411
+ Always 'rsvs_graph' for AAM — the model CANNOT generate
412
+ information not present in the graph conditioning."""
413
+
414
+ aam_body_type: str = "specialized_diffusion"
415
+ """Type of the 'body'. Always 'specialized_diffusion' for AAM.
416
+ This is NOT a general LLM — it only arranges sentences
417
+ based on graph-structured evidence."""
418
+
419
+ def to_dict(self) -> dict:
420
+ """Serialize config to dictionary."""
421
+ return asdict(self)
422
+
423
+ def to_json(self, path: str | Path) -> None:
424
+ """Save config to JSON file."""
425
+ path = Path(path)
426
+ path.parent.mkdir(parents=True, exist_ok=True)
427
+ with open(path, "w", encoding="utf-8") as f:
428
+ json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
429
+
430
+ @classmethod
431
+ def from_json(cls, path: str | Path) -> AamDiffusionConfig:
432
+ """Load config from JSON file."""
433
+ with open(path, "r", encoding="utf-8") as f:
434
+ data = json.load(f)
435
+ return cls(
436
+ model=ModelConfig(**data.get("model", {})),
437
+ diffusion=DiffusionConfig(**data.get("diffusion", {})),
438
+ graph_encoder=GraphEncoderConfig(**data.get("graph_encoder", {})),
439
+ tokenizer=TokenizerConfig(**data.get("tokenizer", {})),
440
+ training=TrainingConfig(**data.get("training", {})),
441
+ inference=InferenceConfig(**data.get("inference", {})),
442
+ model_name=data.get("model_name", "aam-diffusion-v0.1"),
443
+ output_dir=data.get("output_dir", "./output"),
444
+ seed=data.get("seed", 42),
445
+ aam_mind_source=data.get("aam_mind_source", "rsvs_graph"),
446
+ aam_body_type=data.get("aam_body_type", "specialized_diffusion"),
447
+ )
448
+
449
+ def summary(self) -> str:
450
+ """Print a summary of the configuration."""
451
+ lines = [
452
+ "=" * 60,
453
+ f" AAM Diffusion LLM Configuration: {self.model_name}",
454
+ "=" * 60,
455
+ "",
456
+ f" Model Architecture:",
457
+ f" d_model={self.model.d_model}, n_layers={self.model.n_layers}, "
458
+ f"n_heads={self.model.n_heads}",
459
+ f" d_ff={self.model.d_ff}, vocab_size={self.model.vocab_size}",
460
+ f" max_seq_len={self.model.max_seq_len}",
461
+ f" Estimated params: {self.model.estimate_params()}",
462
+ "",
463
+ f" Diffusion Process:",
464
+ f" Timesteps (train)={self.diffusion.n_timesteps}",
465
+ f" Timesteps (inference)={self.diffusion.n_inference_steps}",
466
+ f" Schedule={self.diffusion.schedule_type}",
467
+ f" Prediction={self.diffusion.prediction_type}",
468
+ f" Sampling={self.diffusion.sampling_method}",
469
+ "",
470
+ f" Graph Encoder:",
471
+ f" d_graph={self.graph_encoder.d_graph}",
472
+ f" n_layers={self.graph_encoder.n_graph_layers}",
473
+ f" Conditioning={self.graph_encoder.conditioning_method}",
474
+ f" Max evidence nodes={self.graph_encoder.max_evidence_nodes}",
475
+ "",
476
+ f" Training:",
477
+ f" LR={self.training.learning_rate}",
478
+ f" Batch={self.training.batch_size} x {self.training.gradient_accumulation_steps} accum",
479
+ f" Max steps={self.training.max_steps}",
480
+ f" AMP={self.training.use_amp} ({self.training.amp_dtype})",
481
+ "",
482
+ f" AAM Philosophy:",
483
+ f" Mind = {self.aam_mind_source} (RSVS Knowledge Graph)",
484
+ f" Body = {self.aam_body_type} (This Model)",
485
+ f" Identity = 1 Mind + 1 Body (NOT rented LLM)",
486
+ "",
487
+ "=" * 60,
488
+ ]
489
+ return "\n".join(lines)
490
+
491
+
492
+ def get_default_config(
493
+ model_size: str = "base",
494
+ ) -> AamDiffusionConfig:
495
+ """Get a default configuration for different model sizes.
496
+
497
+ Args:
498
+ model_size: One of 'tiny', 'small', 'base', 'medium'.
499
+ - tiny: ~25M params (for quick testing)
500
+ - small: ~70M params (for development)
501
+ - base: ~170M params (recommended for training)
502
+ - medium: ~300M params (for final training)
503
+
504
+ Returns:
505
+ AamDiffusionConfig with appropriate settings.
506
+ """
507
+ configs = {
508
+ "tiny": AamDiffusionConfig(
509
+ model=ModelConfig(
510
+ d_model=256,
511
+ n_layers=4,
512
+ n_heads=4,
513
+ d_ff=1024,
514
+ vocab_size=16000,
515
+ max_seq_len=256,
516
+ ),
517
+ graph_encoder=GraphEncoderConfig(
518
+ d_graph=256,
519
+ n_graph_layers=2,
520
+ n_graph_heads=4,
521
+ ),
522
+ diffusion=DiffusionConfig(
523
+ n_timesteps=500,
524
+ n_inference_steps=20,
525
+ ),
526
+ training=TrainingConfig(
527
+ batch_size=16,
528
+ learning_rate=3e-4,
529
+ warmup_steps=500,
530
+ max_steps=100000,
531
+ ),
532
+ model_name="aam-diffusion-tiny",
533
+ ),
534
+ "small": AamDiffusionConfig(
535
+ model=ModelConfig(
536
+ d_model=512,
537
+ n_layers=8,
538
+ n_heads=8,
539
+ d_ff=2048,
540
+ vocab_size=24000,
541
+ max_seq_len=384,
542
+ ),
543
+ graph_encoder=GraphEncoderConfig(
544
+ d_graph=384,
545
+ n_graph_layers=4,
546
+ n_graph_heads=8,
547
+ ),
548
+ diffusion=DiffusionConfig(
549
+ n_timesteps=1000,
550
+ n_inference_steps=30,
551
+ ),
552
+ training=TrainingConfig(
553
+ batch_size=24,
554
+ learning_rate=2e-4,
555
+ warmup_steps=1000,
556
+ max_steps=200000,
557
+ ),
558
+ model_name="aam-diffusion-small",
559
+ ),
560
+ "base": AamDiffusionConfig(
561
+ model=ModelConfig(
562
+ d_model=768,
563
+ n_layers=12,
564
+ n_heads=12,
565
+ d_ff=3072,
566
+ vocab_size=32000,
567
+ max_seq_len=512,
568
+ ),
569
+ graph_encoder=GraphEncoderConfig(
570
+ d_graph=512,
571
+ n_graph_layers=4,
572
+ n_graph_heads=8,
573
+ ),
574
+ diffusion=DiffusionConfig(
575
+ n_timesteps=1000,
576
+ n_inference_steps=50,
577
+ ),
578
+ training=TrainingConfig(
579
+ batch_size=32,
580
+ learning_rate=1e-4,
581
+ warmup_steps=2000,
582
+ max_steps=500000,
583
+ ),
584
+ model_name="aam-diffusion-base",
585
+ ),
586
+ "medium": AamDiffusionConfig(
587
+ model=ModelConfig(
588
+ d_model=1024,
589
+ n_layers=12,
590
+ n_heads=16,
591
+ d_ff=4096,
592
+ vocab_size=32000,
593
+ max_seq_len=768,
594
+ ),
595
+ graph_encoder=GraphEncoderConfig(
596
+ d_graph=768,
597
+ n_graph_layers=6,
598
+ n_graph_heads=12,
599
+ ),
600
+ diffusion=DiffusionConfig(
601
+ n_timesteps=1000,
602
+ n_inference_steps=50,
603
+ ),
604
+ training=TrainingConfig(
605
+ batch_size=16,
606
+ learning_rate=5e-5,
607
+ warmup_steps=5000,
608
+ max_steps=1000000,
609
+ ),
610
+ model_name="aam-diffusion-medium",
611
+ ),
612
+ }
613
+
614
+ if model_size not in configs:
615
+ raise ValueError(
616
+ f"Unknown model_size '{model_size}'. "
617
+ f"Choose from: {list(configs.keys())}"
618
+ )
619
+
620
+ return configs[model_size]
diffusion_llm/data/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """Data pipeline module for AAM Diffusion LLM."""
2
+
3
+ from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator
4
+ from diffusion_llm.data.data_pipeline import DataPipeline
5
+
6
+ __all__ = ["SyntheticDataGenerator", "DataPipeline"]
diffusion_llm/data/data_pipeline.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AAM Diffusion LLM — Data Pipeline
3
+
4
+ Orchestrates data preparation: from raw graph data and narratives
5
+ to tokenized, batched training data.
6
+
7
+ The pipeline handles:
8
+ 1. Loading raw graph→narrative pairs
9
+ 2. Generating synthetic data if real data isn't available
10
+ 3. Tokenizing all data
11
+ 4. Creating train/val splits
12
+ 5. Building DataLoaders
13
+
14
+ Analogi: Seperti proses persiapan sebelum Jin Soun berlatih —
15
+ mengumpulkan semua kasus, mengorganisirnya, dan menyiapkan
16
+ data latihan yang terstruktur.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import logging
22
+ from pathlib import Path
23
+ from typing import Optional
24
+
25
+ from torch.utils.data import DataLoader
26
+
27
+ from diffusion_llm.config.model_config import AamDiffusionConfig
28
+ from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
29
+ from diffusion_llm.training.dataset import GraphNarrativeDataset, collate_fn
30
+ from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ class DataPipeline:
36
+ """Data preparation pipeline for AAM Diffusion LLM training.
37
+
38
+ Orchestrates the entire data preparation process:
39
+ 1. Check for existing data
40
+ 2. Generate synthetic data if needed
41
+ 3. Train tokenizer on the data
42
+ 4. Create datasets and dataloaders
43
+
44
+ Usage:
45
+ pipeline = DataPipeline(config)
46
+ tokenizer, train_loader, val_loader = pipeline.prepare()
47
+ """
48
+
49
+ def __init__(self, config: AamDiffusionConfig):
50
+ self.config = config
51
+ self.output_dir = Path(config.output_dir) / "data"
52
+ self.output_dir.mkdir(parents=True, exist_ok=True)
53
+
54
+ def prepare(
55
+ self,
56
+ tokenizer: Optional[AamTokenizer] = None,
57
+ force_regenerate: bool = False,
58
+ ) -> tuple[AamTokenizer, DataLoader, Optional[DataLoader]]:
59
+ """Prepare all data for training.
60
+
61
+ Args:
62
+ tokenizer: Optional pre-trained tokenizer.
63
+ force_regenerate: Whether to regenerate synthetic data.
64
+
65
+ Returns:
66
+ Tuple of (tokenizer, train_loader, val_loader).
67
+ """
68
+ train_path = Path(self.config.training.train_data_path) if self.config.training.train_data_path else None
69
+ val_path = Path(self.config.training.val_data_path) if self.config.training.val_data_path else None
70
+
71
+ # Step 1: Generate synthetic data if no real data
72
+ if not train_path or not train_path.exists() or force_regenerate:
73
+ logger.info("Generating synthetic training data...")
74
+ train_path, val_path = SyntheticDataGenerator.generate_training_split(
75
+ output_dir=self.output_dir,
76
+ n_train=10000,
77
+ n_val=500,
78
+ language=self.config.inference.language,
79
+ seed=self.config.seed,
80
+ )
81
+
82
+ # Step 2: Train tokenizer if not provided
83
+ if tokenizer is None or not tokenizer.is_trained:
84
+ logger.info("Training tokenizer...")
85
+ tokenizer = AamTokenizer()
86
+ # Read training texts for tokenizer training
87
+ texts = self._read_texts(train_path)
88
+ tokenizer.train(texts, vocab_size=self.config.tokenizer.bpe_vocab_size)
89
+ tokenizer.save(self.output_dir / "tokenizer.json")
90
+ logger.info("Tokenizer trained and saved. Vocab size: %d", tokenizer.vocab_size)
91
+
92
+ # Step 3: Create datasets
93
+ logger.info("Creating datasets...")
94
+ train_dataset = GraphNarrativeDataset(
95
+ data_path=train_path,
96
+ tokenizer=tokenizer,
97
+ max_seq_len=self.config.model.max_seq_len,
98
+ max_evidence=self.config.graph_encoder.max_evidence_nodes,
99
+ max_anomalies=self.config.graph_encoder.max_anomalies,
100
+ max_reasoning=self.config.graph_encoder.max_reasoning_steps,
101
+ )
102
+
103
+ val_dataset = None
104
+ if val_path and val_path.exists():
105
+ val_dataset = GraphNarrativeDataset(
106
+ data_path=val_path,
107
+ tokenizer=tokenizer,
108
+ max_seq_len=self.config.model.max_seq_len,
109
+ max_evidence=self.config.graph_encoder.max_evidence_nodes,
110
+ max_anomalies=self.config.graph_encoder.max_anomalies,
111
+ max_reasoning=self.config.graph_encoder.max_reasoning_steps,
112
+ augment=False, # No augmentation for validation
113
+ )
114
+
115
+ # Step 4: Create dataloaders
116
+ train_loader = DataLoader(
117
+ train_dataset,
118
+ batch_size=self.config.training.batch_size,
119
+ shuffle=True,
120
+ num_workers=self.config.training.num_workers,
121
+ collate_fn=collate_fn,
122
+ pin_memory=True,
123
+ )
124
+
125
+ val_loader = None
126
+ if val_dataset:
127
+ val_loader = DataLoader(
128
+ val_dataset,
129
+ batch_size=self.config.training.batch_size,
130
+ shuffle=False,
131
+ num_workers=self.config.training.num_workers,
132
+ collate_fn=collate_fn,
133
+ pin_memory=True,
134
+ )
135
+
136
+ logger.info(
137
+ "Data pipeline ready: %d training examples, %s validation examples",
138
+ len(train_dataset),
139
+ len(val_dataset) if val_dataset else 0,
140
+ )
141
+
142
+ return tokenizer, train_loader, val_loader
143
+
144
+ def _read_texts(self, path: Path) -> list[str]:
145
+ """Read narrative texts from JSONL file for tokenizer training.
146
+
147
+ Args:
148
+ path: Path to JSONL data file.
149
+
150
+ Returns:
151
+ List of narrative texts.
152
+ """
153
+ import json
154
+ texts = []
155
+ if not path.exists():
156
+ return texts
157
+
158
+ with open(path, "r", encoding="utf-8") as f:
159
+ for line in f:
160
+ line = line.strip()
161
+ if not line:
162
+ continue
163
+ try:
164
+ data = json.loads(line)
165
+ # Collect both narratives and evidence for richer tokenizer
166
+ if data.get("narrative"):
167
+ texts.append(data["narrative"])
168
+ if data.get("trigger"):
169
+ texts.append(data["trigger"])
170
+ for ev in data.get("evidence_nodes", []):
171
+ texts.append(ev)
172
+ for anom in data.get("anomalies", []):
173
+ texts.append(anom)
174
+ for step in data.get("reasoning_steps", []):
175
+ texts.append(step)
176
+ except json.JSONDecodeError:
177
+ continue
178
+
179
+ return texts
diffusion_llm/data/synthetic_generator.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AAM Diffusion LLM — Synthetic Data Generator
3
+
4
+ Generates synthetic Graph→Narrative training pairs for
5
+ pre-training the diffusion model before real data is available.
6
+
7
+ The synthetic data follows the AAM pattern:
8
+ - Graph conditioning: evidence, compositions, anomalies, reasoning
9
+ - Target narrative: natural language text that represents the graph data
10
+
11
+ This is essential because:
12
+ 1. We need training data before the model can be used
13
+ 2. The data must follow the Graph→Narrative format specifically
14
+ 3. Synthetic data helps bootstrap the model's ability to
15
+ arrange sentences from structured evidence
16
+
17
+ Analogi: Seperti Jin Soun berlatih dengan kasus-kasus fiktif
18
+ sebelum menghadapi kasus nyata — data sintetis memberikan
19
+ "latihan dasar" sebelum data asli tersedia.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import json
25
+ import logging
26
+ import random
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Optional
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ # --- Templates for synthetic data generation ---
35
+
36
+ # Indonesian narrative templates
37
+ ID_TEMPLATES = {
38
+ "analysis": [
39
+ "Berdasarkan analisis terhadap {trigger}: {evidence_summary}. {reasoning_summary}. Tingkat keyakinan: {confidence_pct}.",
40
+ "Analisis menunjukkan bahwa {trigger} terkait dengan {evidence_summary}. {anomaly_summary}. Kesimpulan: {reasoning_summary}.",
41
+ "Dari data yang tersedia, {trigger} memiliki koneksi ke {evidence_summary}. {reasoning_summary}. Confidence: {confidence_pct}.",
42
+ "Hasil investigasi: {trigger}. Bukti: {evidence_summary}. {anomaly_summary}. {reasoning_summary}.",
43
+ "Temuan: {trigger} berkorelasi dengan {evidence_summary}. Catatan: {anomaly_summary}. Analisis: {reasoning_summary}.",
44
+ ],
45
+ "evidence_summary": [
46
+ "bukti menunjukkan {nodes}",
47
+ "data dari {nodes} mengindikasikan",
48
+ "{nodes} menjadi kunci",
49
+ "informasi dari {nodes} mengarah ke",
50
+ "sumber {nodes} mengkonfirmasi",
51
+ ],
52
+ "anomaly_summary": [
53
+ "Anomali terdeteksi: {anomalies}",
54
+ "Perhatian: {anomalies}",
55
+ "Pola tidak lazim: {anomalies}",
56
+ "Ketidaksesuaian ditemukan: {anomalies}",
57
+ "Terdapat kejanggalan: {anomalies}",
58
+ ],
59
+ "reasoning_summary": [
60
+ "Langkah penalaran: {steps}",
61
+ "Proses deduksi: {steps}",
62
+ "Analisis bertahap: {steps}",
63
+ "Penelusuran logika: {steps}",
64
+ "Rantai penalaran: {steps}",
65
+ ],
66
+ }
67
+
68
+ # English narrative templates
69
+ EN_TEMPLATES = {
70
+ "analysis": [
71
+ "Based on analysis of {trigger}: {evidence_summary}. {reasoning_summary}. Confidence: {confidence_pct}.",
72
+ "Analysis indicates that {trigger} is related to {evidence_summary}. {anomaly_summary}. Conclusion: {reasoning_summary}.",
73
+ "From available data, {trigger} has connections to {evidence_summary}. {reasoning_summary}. Confidence level: {confidence_pct}.",
74
+ "Investigation results: {trigger}. Evidence: {evidence_summary}. {anomaly_summary}. {reasoning_summary}.",
75
+ "Findings: {trigger} correlates with {evidence_summary}. Note: {anomaly_summary}. Analysis: {reasoning_summary}.",
76
+ ],
77
+ "evidence_summary": [
78
+ "evidence shows {nodes}",
79
+ "data from {nodes} indicates",
80
+ "{nodes} are key factors",
81
+ "information from {nodes} points to",
82
+ "sources {nodes} confirm",
83
+ ],
84
+ "anomaly_summary": [
85
+ "Anomaly detected: {anomalies}",
86
+ "Note: {anomalies}",
87
+ "Unusual pattern: {anomalies}",
88
+ "Inconsistency found: {anomalies}",
89
+ "Irregularity observed: {anomalies}",
90
+ ],
91
+ "reasoning_summary": [
92
+ "Reasoning steps: {steps}",
93
+ "Deductive process: {steps}",
94
+ "Step-by-step analysis: {steps}",
95
+ "Logical trace: {steps}",
96
+ "Reasoning chain: {steps}",
97
+ ],
98
+ }
99
+
100
+ # Sample graph data for synthetic generation
101
+ SAMPLE_EVIDENCE_NODES = {
102
+ "id": [
103
+ "Hefei", "Diancang Five Swords", "Ju Jangmok", "Snow Plum Pill",
104
+ "Gyeryong Merchant Guild", "Simhyeon Pavilion", "Martial Alliance",
105
+ "Gu Ilmu", "Jang Hangi", "Blood Serpent Dance Step",
106
+ "taeul_sect", "dark_faction", "hefei_branch",
107
+ ],
108
+ "en": [
109
+ "Hefei", "Diancang Five Swords", "Ju Jangmok", "Snow Plum Pill",
110
+ "Gyeryong Merchant Guild", "Simhyeon Pavilion", "Martial Alliance",
111
+ "Gu Ilmu", "Jang Hangi", "Blood Serpent Dance Step",
112
+ "taeul_sect", "dark_faction", "hefei_branch",
113
+ ],
114
+ }
115
+
116
+ SAMPLE_TRIGGERS = {
117
+ "id": [
118
+ "Siapa yang mencuri Snow Plum Pill?",
119
+ "Analisis pergerakan Diancang Five Swords",
120
+ "Hubungan antara Ju Jangmok dan pencurian",
121
+ "Anomali dalam laporan Hefei",
122
+ "Investigasi inside job di Diancang",
123
+ "Pola konsumsi Snow Plum Pill",
124
+ "Cross-reference kejadian di Hefei",
125
+ "Evaluasi kepercayaan sumber informasi",
126
+ "Prediksi tindakan berikutnya tersangka",
127
+ "Pattern completion dari bukti terpisah",
128
+ ],
129
+ "en": [
130
+ "Who stole the Snow Plum Pill?",
131
+ "Analysis of Diancang Five Swords movements",
132
+ "Connection between Ju Jangmok and the theft",
133
+ "Anomalies in the Hefei reports",
134
+ "Investigation of inside job at Diancang",
135
+ "Pattern of Snow Plum Pill consumption",
136
+ "Cross-referencing events in Hefei",
137
+ "Source trustworthiness evaluation",
138
+ "Predicting next suspect actions",
139
+ "Pattern completion from disparate evidence",
140
+ ],
141
+ }
142
+
143
+ SAMPLE_ANOMALIES = {
144
+ "id": [
145
+ "Tidak ada konsumsi pil baru di pasar gelap",
146
+ "Pencuri menghilang tanpa jejak",
147
+ "Success rate pair lebih tinggi dari biasanya",
148
+ "Misi di-assign dari dalam Diancang sendiri",
149
+ "Ju Jangmok menghilang hari yang sama dengan pencurian",
150
+ "Tidak ada pencuri baru setelah Ju Jangmok menghilang",
151
+ ],
152
+ "en": [
153
+ "No new pill consumption in black market",
154
+ "Thief disappeared without a trace",
155
+ "Pair success rate unusually high",
156
+ "Mission assigned from within Diancang itself",
157
+ "Ju Jangmok disappeared same day as theft",
158
+ "No new thief appeared after Ju Jangmok vanished",
159
+ ],
160
+ }
161
+
162
+ SAMPLE_REASONING_STEPS = {
163
+ "id": [
164
+ "Recall: Ingat semua laporan terkait Hefei",
165
+ "Cross-reference: Bandingkan tanggal kejadian",
166
+ "Filter: Eliminasi yang tidak relevan",
167
+ "Anomaly: Deteksi ketidaksesuaian pola",
168
+ "Pattern: Hubungkan fragmen terpisah",
169
+ "Compose: Susun kesimpulan dari bukti",
170
+ "Predict: Perkirakan tindakan berikutnya",
171
+ "Verify: Cek konsistensi kesimpulan",
172
+ ],
173
+ "en": [
174
+ "Recall: Remember all reports related to Hefei",
175
+ "Cross-reference: Compare event dates",
176
+ "Filter: Eliminate irrelevant data",
177
+ "Anomaly: Detect pattern inconsistency",
178
+ "Pattern: Connect disparate fragments",
179
+ "Compose: Assemble conclusion from evidence",
180
+ "Predict: Estimate next actions",
181
+ "Verify: Check conclusion consistency",
182
+ ],
183
+ }
184
+
185
+
186
+ class SyntheticDataGenerator:
187
+ """Generate synthetic Graph→Narrative training pairs.
188
+
189
+ This generator creates training data that follows the AAM
190
+ pattern: structured graph conditioning → natural language narrative.
191
+
192
+ The generated data covers:
193
+ - Various trigger types (questions, analysis requests)
194
+ - Different numbers of evidence nodes (1-50)
195
+ - Various anomaly patterns
196
+ - Different reasoning chain lengths
197
+ - Confidence distributions
198
+ - Both Indonesian and English
199
+
200
+ Usage:
201
+ generator = SyntheticDataGenerator()
202
+ examples = generator.generate(n=1000, language="id")
203
+ generator.save(examples, "training_data.jsonl")
204
+ """
205
+
206
+ def __init__(
207
+ self,
208
+ seed: int = 42,
209
+ language: str = "id",
210
+ ):
211
+ """Initialize the synthetic data generator.
212
+
213
+ Args:
214
+ seed: Random seed for reproducibility.
215
+ language: Default language for generation.
216
+ """
217
+ self.seed = seed
218
+ self.language = language
219
+ random.seed(seed)
220
+
221
+ def generate(
222
+ self,
223
+ n: int = 1000,
224
+ language: Optional[str] = None,
225
+ min_evidence: int = 2,
226
+ max_evidence: int = 15,
227
+ anomaly_probability: float = 0.6,
228
+ reasoning_probability: float = 0.8,
229
+ ) -> list[dict]:
230
+ """Generate synthetic training examples.
231
+
232
+ Args:
233
+ n: Number of examples to generate.
234
+ language: Language override.
235
+ min_evidence: Minimum evidence nodes per example.
236
+ max_evidence: Maximum evidence nodes per example.
237
+ anomaly_probability: Probability of including anomalies.
238
+ reasoning_probability: Probability of including reasoning steps.
239
+
240
+ Returns:
241
+ List of training example dictionaries.
242
+ """
243
+ lang = language or self.language
244
+ templates = ID_TEMPLATES if lang == "id" else EN_TEMPLATES
245
+ evidence_pool = SAMPLE_EVIDENCE_NODES.get(lang, SAMPLE_EVIDENCE_NODES["en"])
246
+ trigger_pool = SAMPLE_TRIGGERS.get(lang, SAMPLE_TRIGGERS["en"])
247
+ anomaly_pool = SAMPLE_ANOMALIES.get(lang, SAMPLE_ANOMALIES["en"])
248
+ reasoning_pool = SAMPLE_REASONING_STEPS.get(lang, SAMPLE_REASONING_STEPS["en"])
249
+
250
+ examples = []
251
+ for _ in range(n):
252
+ # Random trigger
253
+ trigger = random.choice(trigger_pool)
254
+
255
+ # Random evidence nodes
256
+ n_evidence = random.randint(min_evidence, max_evidence)
257
+ evidence = random.sample(evidence_pool, min(n_evidence, len(evidence_pool)))
258
+
259
+ # Random confidence map
260
+ confidence_map = {
261
+ node: round(random.uniform(0.3, 1.0), 2)
262
+ for node in evidence
263
+ }
264
+
265
+ # Random anomalies
266
+ anomalies = []
267
+ if random.random() < anomaly_probability:
268
+ n_anomalies = random.randint(1, 3)
269
+ anomalies = random.sample(anomaly_pool, min(n_anomalies, len(anomaly_pool)))
270
+
271
+ # Random reasoning steps
272
+ reasoning_steps = []
273
+ if random.random() < reasoning_probability:
274
+ n_steps = random.randint(2, 6)
275
+ reasoning_steps = random.sample(reasoning_pool, min(n_steps, len(reasoning_pool)))
276
+
277
+ # Source trust
278
+ source_trust = round(random.uniform(0.5, 1.0), 2)
279
+
280
+ # Generate narrative from template
281
+ narrative = self._generate_narrative(
282
+ trigger=trigger,
283
+ evidence=evidence,
284
+ anomalies=anomalies,
285
+ reasoning_steps=reasoning_steps,
286
+ confidence_map=confidence_map,
287
+ templates=templates,
288
+ lang=lang,
289
+ )
290
+
291
+ example = {
292
+ "narrative": narrative,
293
+ "trigger": trigger,
294
+ "evidence_nodes": evidence,
295
+ "compositions": [],
296
+ "confidence_map": confidence_map,
297
+ "anomalies": anomalies,
298
+ "reasoning_steps": reasoning_steps,
299
+ "source_trust": source_trust,
300
+ "language": lang,
301
+ "source": "synthetic",
302
+ }
303
+
304
+ examples.append(example)
305
+
306
+ logger.info("Generated %d synthetic examples (language=%s)", n, lang)
307
+ return examples
308
+
309
+ def _generate_narrative(
310
+ self,
311
+ trigger: str,
312
+ evidence: list[str],
313
+ anomalies: list[str],
314
+ reasoning_steps: list[str],
315
+ confidence_map: dict[str, float],
316
+ templates: dict,
317
+ lang: str,
318
+ ) -> str:
319
+ """Generate a narrative from templates.
320
+
321
+ Args:
322
+ trigger: Trigger text.
323
+ evidence: Evidence node labels.
324
+ anomalies: Anomaly descriptions.
325
+ reasoning_steps: Reasoning step descriptions.
326
+ confidence_map: Confidence scores.
327
+ templates: Template dictionary.
328
+ lang: Language code.
329
+
330
+ Returns:
331
+ Generated narrative string.
332
+ """
333
+ # Build narrative parts
334
+ evidence_str = ", ".join(evidence[:5])
335
+ avg_confidence = sum(confidence_map.values()) / max(len(confidence_map), 1)
336
+
337
+ # Fill templates
338
+ evidence_summary = random.choice(templates["evidence_summary"]).format(
339
+ nodes=evidence_str
340
+ )
341
+
342
+ anomaly_summary = ""
343
+ if anomalies:
344
+ anomaly_summary = random.choice(templates["anomaly_summary"]).format(
345
+ anomalies="; ".join(anomalies[:3])
346
+ )
347
+
348
+ reasoning_summary = ""
349
+ if reasoning_steps:
350
+ reasoning_summary = random.choice(templates["reasoning_summary"]).format(
351
+ steps="; ".join(reasoning_steps[:4])
352
+ )
353
+
354
+ # Main narrative
355
+ narrative = random.choice(templates["analysis"]).format(
356
+ trigger=trigger,
357
+ evidence_summary=evidence_summary,
358
+ anomaly_summary=anomaly_summary,
359
+ reasoning_summary=reasoning_summary,
360
+ confidence_pct=f"{avg_confidence:.0%}",
361
+ )
362
+
363
+ return narrative
364
+
365
+ def save(
366
+ self,
367
+ examples: list[dict],
368
+ path: str | Path,
369
+ ) -> None:
370
+ """Save examples to JSONL file.
371
+
372
+ Args:
373
+ examples: List of example dictionaries.
374
+ path: Output file path.
375
+ """
376
+ path = Path(path)
377
+ path.parent.mkdir(parents=True, exist_ok=True)
378
+
379
+ with open(path, "w", encoding="utf-8") as f:
380
+ for example in examples:
381
+ f.write(json.dumps(example, ensure_ascii=False) + "\n")
382
+
383
+ logger.info("Saved %d examples to %s", len(examples), path)
384
+
385
+ @classmethod
386
+ def generate_training_split(
387
+ cls,
388
+ output_dir: str | Path,
389
+ n_train: int = 10000,
390
+ n_val: int = 500,
391
+ language: str = "id",
392
+ seed: int = 42,
393
+ ) -> tuple[Path, Path]:
394
+ """Generate and save train/val splits.
395
+
396
+ Args:
397
+ output_dir: Output directory.
398
+ n_train: Number of training examples.
399
+ n_val: Number of validation examples.
400
+ language: Language for generation.
401
+ seed: Random seed.
402
+
403
+ Returns:
404
+ Tuple of (train_path, val_path).
405
+ """
406
+ output_dir = Path(output_dir)
407
+ output_dir.mkdir(parents=True, exist_ok=True)
408
+
409
+ generator = cls(seed=seed, language=language)
410
+
411
+ # Generate training data
412
+ train_examples = generator.generate(n=n_train, language=language)
413
+ train_path = output_dir / "train.jsonl"
414
+ generator.save(train_examples, train_path)
415
+
416
+ # Generate validation data (different seed)
417
+ val_generator = cls(seed=seed + 1, language=language)
418
+ val_examples = val_generator.generate(n=n_val, language=language)
419
+ val_path = output_dir / "val.jsonl"
420
+ val_generator.save(val_examples, val_path)
421
+
422
+ logger.info(
423
+ "Generated training split: %d train, %d val",
424
+ n_train, n_val,
425
+ )
426
+
427
+ return train_path, val_path
diffusion_llm/inference/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Inference module for AAM Diffusion LLM."""
2
+
3
+ from diffusion_llm.inference.generator import AamGenerator
4
+
5
+ __all__ = ["AamGenerator"]
diffusion_llm/inference/generator.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AAM Diffusion LLM — Inference Generator
3
+
4
+ Generates natural language narratives from graph conditioning
5
+ using the trained diffusion model.
6
+
7
+ The generation process:
8
+ 1. Encode graph conditioning (evidence, anomalies, reasoning)
9
+ 2. Start from pure noise in the latent space
10
+ 3. Iteratively denoise for N steps
11
+ 4. Convert denoised embeddings to token IDs
12
+ 5. Detokenize to natural language text
13
+
14
+ Analogi: Seperti Jin Soun akhirnya "berbicara" — dari
15
+ pikiran yang kabur (noise) menjadi kata-kata yang jelas
16
+ (denoised narrative). Setiap langkah denoising = satu
17
+ langkah lebih dekat ke koherensi.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import logging
23
+ import time
24
+ from dataclasses import dataclass, field
25
+ from typing import Optional
26
+
27
+ import torch
28
+
29
+ from diffusion_llm.config.model_config import AamDiffusionConfig, InferenceConfig
30
+ from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
31
+ from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ @dataclass
37
+ class GenerationResult:
38
+ """Result from a generation call.
39
+
40
+ Contains the generated narrative plus metadata about
41
+ how it was generated, for traceability.
42
+ """
43
+ narrative: str
44
+ """Generated narrative text."""
45
+
46
+ token_ids: list[int] = field(default_factory=list)
47
+ """Generated token IDs."""
48
+
49
+ n_diffusion_steps: int = 0
50
+ """Number of denoising steps used."""
51
+
52
+ generation_time_s: float = 0.0
53
+ """Wall-clock generation time."""
54
+
55
+ model_name: str = ""
56
+ """Name of the model used."""
57
+
58
+ evidence_used: list[str] = field(default_factory=list)
59
+ """Evidence nodes that were provided as conditioning."""
60
+
61
+ confidence: float = 0.0
62
+ """Overall confidence of the generation."""
63
+
64
+ language: str = "id"
65
+ """Output language."""
66
+
67
+ def to_dict(self) -> dict:
68
+ """Serialize to dictionary."""
69
+ return {
70
+ "narrative": self.narrative,
71
+ "n_diffusion_steps": self.n_diffusion_steps,
72
+ "generation_time_s": round(self.generation_time_s, 3),
73
+ "model_name": self.model_name,
74
+ "evidence_used": self.evidence_used,
75
+ "confidence": round(self.confidence, 3),
76
+ "language": self.language,
77
+ }
78
+
79
+
80
+ class AamGenerator:
81
+ """Generate narratives from graph conditioning using the trained model.
82
+
83
+ This is the main inference interface. It takes graph-structured
84
+ data (from the RSVS Knowledge Graph) and produces natural
85
+ language narratives through the diffusion denoising process.
86
+
87
+ Usage:
88
+ # Load model and tokenizer
89
+ config = AamDiffusionConfig.from_json("config.json")
90
+ model = AamDiffusionModel.load("best.pt")
91
+ tokenizer = AamTokenizer.load("tokenizer.json")
92
+
93
+ # Create generator
94
+ generator = AamGenerator(model, tokenizer, config)
95
+
96
+ # Generate narrative
97
+ result = generator.generate(
98
+ trigger="Siapa yang mencuri Snow Plum Pill?",
99
+ evidence_nodes=["hefei", "diancang", "ju_jangmok"],
100
+ anomalies=["no external pill consumption"],
101
+ reasoning_steps=["Diancang pair was in Hefei before theft"],
102
+ )
103
+ print(result.narrative)
104
+
105
+ Args:
106
+ model: Trained AamDiffusionModel.
107
+ tokenizer: Trained AamTokenizer.
108
+ config: AamDiffusionConfig with inference settings.
109
+ """
110
+
111
+ def __init__(
112
+ self,
113
+ model: AamDiffusionModel,
114
+ tokenizer: AamTokenizer,
115
+ config: AamDiffusionConfig,
116
+ ):
117
+ self.model = model
118
+ self.tokenizer = tokenizer
119
+ self.config = config
120
+ self.inference_config = config.inference
121
+
122
+ # Device
123
+ self.device = next(model.parameters()).device
124
+
125
+ # Set model to eval mode
126
+ self.model.eval()
127
+
128
+ @torch.no_grad()
129
+ def generate(
130
+ self,
131
+ trigger: str = "",
132
+ evidence_nodes: Optional[list[str]] = None,
133
+ compositions: Optional[list[str]] = None,
134
+ confidence_map: Optional[dict[str, float]] = None,
135
+ anomalies: Optional[list[str]] = None,
136
+ reasoning_steps: Optional[list[str]] = None,
137
+ source_trust: float = 1.0,
138
+ n_steps: Optional[int] = None,
139
+ temperature: Optional[float] = None,
140
+ language: Optional[str] = None,
141
+ max_sentences: Optional[int] = None,
142
+ ) -> GenerationResult:
143
+ """Generate a narrative from graph conditioning.
144
+
145
+ This is the main generation method. It:
146
+ 1. Tokenizes the graph conditioning data
147
+ 2. Encodes it through the graph encoder
148
+ 3. Starts from noise and iteratively denoises
149
+ 4. Converts the result to text
150
+
151
+ Args:
152
+ trigger: The trigger question or topic.
153
+ evidence_nodes: Evidence node descriptions.
154
+ compositions: Composition descriptions.
155
+ confidence_map: Node confidence scores.
156
+ anomalies: Anomaly descriptions.
157
+ reasoning_steps: Reasoning step descriptions.
158
+ source_trust: Source trust score.
159
+ n_steps: Override number of denoising steps.
160
+ temperature: Override sampling temperature.
161
+ language: Override output language.
162
+ max_sentences: Maximum sentences in output.
163
+
164
+ Returns:
165
+ GenerationResult with the narrative and metadata.
166
+ """
167
+ start_time = time.time()
168
+
169
+ # Use config defaults if not overridden
170
+ n_steps = n_steps or self.inference_config.n_steps
171
+ temperature = temperature or self.inference_config.temperature
172
+ language = language or self.inference_config.language
173
+ max_sentences = max_sentences or self.inference_config.max_output_sentences
174
+
175
+ # --- Step 1: Tokenize graph conditioning ---
176
+ evidence_ids_tensor = None
177
+ evidence_conf_tensor = None
178
+ anomaly_ids_tensor = None
179
+ anomaly_conf_tensor = None
180
+ reasoning_ids_tensor = None
181
+ reasoning_conf_tensor = None
182
+
183
+ if evidence_nodes:
184
+ evidence_ids_list = []
185
+ evidence_conf_list = []
186
+ for node in evidence_nodes[:self.config.graph_encoder.max_evidence_nodes]:
187
+ ids = self.tokenizer.encode(node, add_special=False)
188
+ ids = self.tokenizer.pad_sequence(ids, 32)
189
+ evidence_ids_list.append(ids)
190
+ conf = (confidence_map or {}).get(node, 0.7)
191
+ evidence_conf_list.append(conf)
192
+
193
+ while len(evidence_ids_list) < self.config.graph_encoder.max_evidence_nodes:
194
+ evidence_ids_list.append([0] * 32)
195
+ evidence_conf_list.append(0.0)
196
+
197
+ evidence_ids_tensor = torch.tensor(
198
+ [evidence_ids_list], dtype=torch.long, device=self.device
199
+ )
200
+ evidence_conf_tensor = torch.tensor(
201
+ [evidence_conf_list], dtype=torch.float32, device=self.device
202
+ )
203
+
204
+ if anomalies:
205
+ anomaly_ids_list = []
206
+ for anom in anomalies[:self.config.graph_encoder.max_anomalies]:
207
+ ids = self.tokenizer.encode(anom, add_special=False)
208
+ ids = self.tokenizer.pad_sequence(ids, 32)
209
+ anomaly_ids_list.append(ids)
210
+
211
+ while len(anomaly_ids_list) < self.config.graph_encoder.max_anomalies:
212
+ anomaly_ids_list.append([0] * 32)
213
+
214
+ anomaly_ids_tensor = torch.tensor(
215
+ [anomaly_ids_list], dtype=torch.long, device=self.device
216
+ )
217
+ anomaly_conf_tensor = torch.full(
218
+ (1, self.config.graph_encoder.max_anomalies),
219
+ 0.6, dtype=torch.float32, device=self.device,
220
+ )
221
+
222
+ if reasoning_steps:
223
+ reasoning_ids_list = []
224
+ for step in reasoning_steps[:self.config.graph_encoder.max_reasoning_steps]:
225
+ ids = self.tokenizer.encode(step, add_special=False)
226
+ ids = self.tokenizer.pad_sequence(ids, 32)
227
+ reasoning_ids_list.append(ids)
228
+
229
+ while len(reasoning_ids_list) < self.config.graph_encoder.max_reasoning_steps:
230
+ reasoning_ids_list.append([0] * 32)
231
+
232
+ reasoning_ids_tensor = torch.tensor(
233
+ [reasoning_ids_list], dtype=torch.long, device=self.device
234
+ )
235
+ reasoning_conf_tensor = torch.full(
236
+ (1, self.config.graph_encoder.max_reasoning_steps),
237
+ 0.7, dtype=torch.float32, device=self.device,
238
+ )
239
+
240
+ source_trust_tensor = torch.tensor(
241
+ [source_trust], dtype=torch.float32, device=self.device
242
+ )
243
+
244
+ # --- Step 2: Encode graph conditioning ---
245
+ graph_cond = self.model.graph_encoder(
246
+ evidence_ids=evidence_ids_tensor,
247
+ evidence_confidence=evidence_conf_tensor,
248
+ anomaly_ids=anomaly_ids_tensor,
249
+ anomaly_confidence=anomaly_conf_tensor,
250
+ reasoning_ids=reasoning_ids_tensor,
251
+ reasoning_confidence=reasoning_conf_tensor,
252
+ source_trust=source_trust_tensor,
253
+ )
254
+
255
+ # --- Step 3: Generate via diffusion denoising ---
256
+ shape = (
257
+ 1,
258
+ self.config.model.max_seq_len,
259
+ self.config.model.d_model,
260
+ )
261
+
262
+ denoised = self.model.sample(
263
+ graph_cond=graph_cond,
264
+ n_steps=n_steps,
265
+ method=self.config.diffusion.sampling_method,
266
+ shape=shape,
267
+ device=self.device,
268
+ )
269
+
270
+ # --- Step 4: Convert to tokens ---
271
+ token_ids = self.model.embeddings_to_tokens(
272
+ denoised, temperature=temperature,
273
+ top_k=self.inference_config.top_k,
274
+ )
275
+
276
+ # --- Step 5: Detokenize ---
277
+ token_list = token_ids[0].cpu().tolist()
278
+ narrative = self.tokenizer.decode(token_list, skip_special=True)
279
+
280
+ # Truncate to max sentences
281
+ if max_sentences:
282
+ sentences = self.tokenizer._split_sentences(narrative)
283
+ if len(sentences) > max_sentences:
284
+ narrative = ". ".join(sentences[:max_sentences]) + "."
285
+
286
+ generation_time = time.time() - start_time
287
+
288
+ # Compute average confidence
289
+ avg_confidence = source_trust
290
+ if confidence_map:
291
+ avg_confidence = sum(confidence_map.values()) / len(confidence_map)
292
+
293
+ return GenerationResult(
294
+ narrative=narrative,
295
+ token_ids=token_list,
296
+ n_diffusion_steps=n_steps,
297
+ generation_time_s=generation_time,
298
+ model_name=self.config.model_name,
299
+ evidence_used=evidence_nodes or [],
300
+ confidence=avg_confidence,
301
+ language=language,
302
+ )
303
+
304
+ def generate_batch(
305
+ self,
306
+ triggers: list[str],
307
+ evidence_nodes_list: Optional[list[list[str]]] = None,
308
+ anomalies_list: Optional[list[list[str]]] = None,
309
+ **kwargs,
310
+ ) -> list[GenerationResult]:
311
+ """Generate narratives for multiple triggers.
312
+
313
+ Args:
314
+ triggers: List of trigger questions.
315
+ evidence_nodes_list: List of evidence node lists.
316
+ anomalies_list: List of anomaly lists.
317
+ **kwargs: Additional arguments passed to generate().
318
+
319
+ Returns:
320
+ List of GenerationResult objects.
321
+ """
322
+ results = []
323
+ for i, trigger in enumerate(triggers):
324
+ evidence = evidence_nodes_list[i] if evidence_nodes_list else None
325
+ anomalies = anomalies_list[i] if anomalies_list else None
326
+ result = self.generate(
327
+ trigger=trigger,
328
+ evidence_nodes=evidence,
329
+ anomalies=anomalies,
330
+ **kwargs,
331
+ )
332
+ results.append(result)
333
+ return results
diffusion_llm/model/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model components for AAM Diffusion LLM."""
2
+
3
+ from diffusion_llm.model.noise_scheduler import NoiseScheduler
4
+ from diffusion_llm.model.graph_encoder import GraphConditioningEncoder
5
+ from diffusion_llm.model.diffusion_transformer import DiffusionTransformer
6
+ from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
7
+
8
+ __all__ = [
9
+ "NoiseScheduler",
10
+ "GraphConditioningEncoder",
11
+ "DiffusionTransformer",
12
+ "AamDiffusionModel",
13
+ ]
diffusion_llm/model/aam_diffusion_model.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AAM Diffusion LLM — Complete Model
3
+
4
+ Combines the Diffusion Transformer, Graph Encoder, and Noise Scheduler
5
+ into a single, unified model for training and inference.
6
+
7
+ This is the "body" of AAM — the specialized sentence composer that
8
+ takes graph conditioning as input and produces coherent narratives
9
+ through iterative denoising.
10
+
11
+ Architecture:
12
+ ┌──────────────────────────────────────────────────┐
13
+ │ AAM Diffusion Model (The Body) │
14
+ │ │
15
+ │ Input: │
16
+ │ - Token IDs (text) │
17
+ │ - Graph conditioning (evidence, compositions, │
18
+ │ confidence, anomalies, reasoning chains) │
19
+ │ │
20
+ │ Training Process: │
21
+ │ 1. Tokenize text → embeddings │
22
+ │ 2. Sample random timestep t │
23
+ │ 3. Add noise: x_t = schedule.add_noise(x_0, t) │
24
+ │ 4. Encode graph conditioning │
25
+ │ 5. Predict noise: eps = transformer(x_t, t, c) │
26
+ │ 6. Compute loss: L = MSE(eps, eps_target) │
27
+ │ │
28
+ │ Inference Process: │
29
+ │ 1. Start from pure noise x_T │
30
+ │ 2. Encode graph conditioning │
31
+ │ 3. For t = T, T-1, ..., 1: │
32
+ │ a. Predict noise: eps = transformer(x_t, t) │
33
+ │ b. Denoise: x_{t-1} = schedule.step(eps) │
34
+ │ 4. Decode final x_0 → text tokens │
35
+ │ 5. Detokenize → natural language narrative │
36
+ │ │
37
+ │ Key Constraint: │
38
+ │ The model CANNOT generate information not │
39
+ │ present in the graph conditioning. It can only │
40
+ │ ARRANGE what the graph knows into sentences. │
41
+ │ │
42
+ │ Analogi: Jin Soun (mind/graph) + tubuhnya │
43
+ │ (this model). Tubuhnya hanya bisa mengucapkan │
44
+ │ apa yang dipikirkannya — tidak bisa mengarang. │
45
+ └──────────────────────────────────────────────────┘
46
+
47
+ Analogi: Ini adalah seluruh "tubuh" Jin Soun — bukan hanya
48
+ ototnya (transformer), tapi juga sistem saraf (graph encoder)
49
+ dan kemampuan untuk memperbaiki diri (diffusion denoising).
50
+ """
51
+
52
+ from __future__ import annotations
53
+
54
+ import logging
55
+ from typing import Optional
56
+
57
+ import torch
58
+ import torch.nn as nn
59
+
60
+ from diffusion_llm.config.model_config import AamDiffusionConfig
61
+ from diffusion_llm.model.noise_scheduler import NoiseScheduler
62
+ from diffusion_llm.model.graph_encoder import GraphConditioningEncoder
63
+ from diffusion_llm.model.diffusion_transformer import DiffusionTransformer
64
+
65
+ logger = logging.getLogger(__name__)
66
+
67
+
68
+ class AamDiffusionModel(nn.Module):
69
+ """Complete AAM Diffusion LLM model.
70
+
71
+ Combines:
72
+ - DiffusionTransformer: Core denoising network
73
+ - GraphConditioningEncoder: Encodes graph structure for conditioning
74
+ - NoiseScheduler: Manages the diffusion process
75
+
76
+ This model is designed to be trained on Graph→Narrative pairs,
77
+ where the graph data comes from the RSVS Knowledge Graph and
78
+ the narrative is the target natural language output.
79
+
80
+ Args:
81
+ config: AamDiffusionConfig with all hyperparameters.
82
+ """
83
+
84
+ def __init__(self, config: AamDiffusionConfig):
85
+ super().__init__()
86
+ self.config = config
87
+
88
+ # Core components
89
+ self.noise_scheduler = NoiseScheduler(
90
+ n_timesteps=config.diffusion.n_timesteps,
91
+ schedule_type=config.diffusion.schedule_type,
92
+ beta_start=config.diffusion.beta_start,
93
+ beta_end=config.diffusion.beta_end,
94
+ prediction_type=config.diffusion.prediction_type,
95
+ )
96
+
97
+ self.graph_encoder = GraphConditioningEncoder(
98
+ config=config.graph_encoder,
99
+ vocab_size=config.model.vocab_size,
100
+ )
101
+ # Align graph encoder output dim with transformer's d_model
102
+ self.graph_encoder.set_output_dim(config.model.d_model)
103
+
104
+ self.transformer = DiffusionTransformer(config.model)
105
+
106
+ # Token-to-embedding projection (shared with transformer)
107
+ # The transformer's token_embedding is used for both
108
+ # encoding input text and decoding output text
109
+
110
+ # Output head: project from d_model to vocab_size
111
+ self.lm_head = nn.Linear(
112
+ config.model.d_model, config.model.vocab_size, bias=False
113
+ )
114
+
115
+ # Tie weights between token embedding and LM head
116
+ # This is standard practice and reduces parameter count
117
+ self.lm_head.weight = self.transformer.token_embedding.weight
118
+
119
+ # EMA model (for inference, updated during training)
120
+ self._ema_model: Optional[AamDiffusionModel] = None
121
+ self._ema_decay = config.training.ema_decay
122
+
123
+ logger.info(
124
+ "AamDiffusionModel initialized: %s params, %s",
125
+ self._format_params(self.get_num_params()),
126
+ config.model_name,
127
+ )
128
+
129
+ def forward(
130
+ self,
131
+ token_ids: torch.Tensor,
132
+ timestep: torch.Tensor,
133
+ evidence_ids: Optional[torch.Tensor] = None,
134
+ evidence_confidence: Optional[torch.Tensor] = None,
135
+ evidence_timestamps: Optional[torch.Tensor] = None,
136
+ composition_ids: Optional[torch.Tensor] = None,
137
+ composition_confidence: Optional[torch.Tensor] = None,
138
+ anomaly_ids: Optional[torch.Tensor] = None,
139
+ anomaly_confidence: Optional[torch.Tensor] = None,
140
+ anomaly_timestamps: Optional[torch.Tensor] = None,
141
+ reasoning_ids: Optional[torch.Tensor] = None,
142
+ reasoning_confidence: Optional[torch.Tensor] = None,
143
+ source_trust: Optional[torch.Tensor] = None,
144
+ ) -> torch.Tensor:
145
+ """Forward pass for training.
146
+
147
+ 1. Get clean embeddings from token IDs
148
+ 2. Add noise at the given timestep
149
+ 3. Encode graph conditioning
150
+ 4. Predict noise via transformer
151
+ 5. Return predicted noise (loss computed externally)
152
+
153
+ Args:
154
+ token_ids: Clean text token IDs, shape (batch, seq_len).
155
+ timestep: Random timestep indices, shape (batch,).
156
+ evidence_ids: Evidence node token IDs.
157
+ evidence_confidence: Evidence confidence scores.
158
+ evidence_timestamps: Evidence timestamps.
159
+ composition_ids: Composition token IDs.
160
+ composition_confidence: Composition confidence.
161
+ anomaly_ids: Anomaly token IDs.
162
+ anomaly_confidence: Anomaly confidence.
163
+ anomaly_timestamps: Anomaly timestamps.
164
+ reasoning_ids: Reasoning step token IDs.
165
+ reasoning_confidence: Reasoning confidence.
166
+ source_trust: Source trust score.
167
+
168
+ Returns:
169
+ Predicted noise tensor of shape (batch, seq_len, d_model).
170
+ """
171
+ # Step 1: Get clean embeddings (x_0)
172
+ x_0 = self.transformer.token_embedding(token_ids)
173
+
174
+ # Step 2: Add noise
175
+ noise = torch.randn_like(x_0)
176
+ x_t = self.noise_scheduler.add_noise(x_0, noise, timestep)
177
+
178
+ # Step 3: Encode graph conditioning
179
+ batch_size = token_ids.shape[0]
180
+ graph_cond = self.graph_encoder(
181
+ evidence_ids=evidence_ids,
182
+ evidence_confidence=evidence_confidence,
183
+ evidence_timestamps=evidence_timestamps,
184
+ composition_ids=composition_ids,
185
+ composition_confidence=composition_confidence,
186
+ anomaly_ids=anomaly_ids,
187
+ anomaly_confidence=anomaly_confidence,
188
+ anomaly_timestamps=anomaly_timestamps,
189
+ reasoning_ids=reasoning_ids,
190
+ reasoning_confidence=reasoning_confidence,
191
+ source_trust=source_trust,
192
+ batch_size=batch_size,
193
+ )
194
+
195
+ # Extract cross-attention keys/values from graph conditioning
196
+ graph_keys = graph_cond.get("keys")
197
+ graph_values = graph_cond.get("values")
198
+
199
+ # Step 4: Predict noise via transformer
200
+ predicted = self.transformer(
201
+ x_t=x_t,
202
+ t=timestep,
203
+ graph_keys=graph_keys,
204
+ graph_values=graph_values,
205
+ )
206
+
207
+ return predicted, noise
208
+
209
+ def compute_loss(
210
+ self,
211
+ predicted: torch.Tensor,
212
+ target: torch.Tensor,
213
+ timestep: torch.Tensor,
214
+ ) -> torch.Tensor:
215
+ """Compute diffusion training loss.
216
+
217
+ Supports different loss types and weighting strategies.
218
+
219
+ Args:
220
+ predicted: Model output (predicted noise/x0/v).
221
+ target: Target (actual noise/x0/v).
222
+ timestep: Timestep indices for loss weighting.
223
+
224
+ Returns:
225
+ Scalar loss value.
226
+ """
227
+ # Base loss
228
+ if self.config.diffusion.loss_type == "mse":
229
+ loss = nn.functional.mse_loss(predicted, target, reduction="none")
230
+ elif self.config.diffusion.loss_type == "mae":
231
+ loss = nn.functional.l1_loss(predicted, target, reduction="none")
232
+ elif self.config.diffusion.loss_type == "huber":
233
+ loss = nn.functional.smooth_l1_loss(predicted, target, reduction="none")
234
+ else:
235
+ raise ValueError(f"Unknown loss_type: {self.config.diffusion.loss_type}")
236
+
237
+ # Average over feature dimension
238
+ loss = loss.mean(dim=-1) # (batch, seq_len)
239
+
240
+ # Apply loss weighting
241
+ if self.config.diffusion.loss_weighting == "min_snr":
242
+ loss = self._apply_min_snr_weighting(loss, timestep)
243
+ elif self.config.diffusion.loss_weighting == "p2":
244
+ loss = self._apply_p2_weighting(loss, timestep)
245
+
246
+ # Average over sequence and batch
247
+ return loss.mean()
248
+
249
+ def _apply_min_snr_weighting(
250
+ self,
251
+ loss: torch.Tensor,
252
+ timestep: torch.Tensor,
253
+ gamma: float = 5.0,
254
+ ) -> torch.Tensor:
255
+ """Apply Min-SNR weighting strategy.
256
+
257
+ Weights the loss by min(SNR, gamma) / SNR, where
258
+ SNR = alpha_bar / (1 - alpha_bar).
259
+
260
+ This helps balance the loss across timesteps, preventing
261
+ high-noise steps from dominating.
262
+
263
+ Args:
264
+ loss: Unweighted loss.
265
+ timestep: Timestep indices.
266
+ gamma: SNR clipping value.
267
+
268
+ Returns:
269
+ Weighted loss.
270
+ """
271
+ alpha_bar = self.noise_scheduler.alphas_cumprod.to(loss.device)
272
+ snr = alpha_bar[timestep] / (1 - alpha_bar[timestep] + 1e-8)
273
+ weight = torch.clamp(snr, max=gamma) / (snr + 1e-8)
274
+ # Expand weight to match loss shape
275
+ weight = weight.unsqueeze(-1).expand_as(loss)
276
+ return loss * weight
277
+
278
+ def _apply_p2_weighting(
279
+ self,
280
+ loss: torch.Tensor,
281
+ timestep: torch.Tensor,
282
+ ) -> torch.Tensor:
283
+ """Apply P2 weighting strategy.
284
+
285
+ weight = 1 / (SNR^gamma + k)
286
+
287
+ Args:
288
+ loss: Unweighted loss.
289
+ timestep: Timestep indices.
290
+
291
+ Returns:
292
+ Weighted loss.
293
+ """
294
+ alpha_bar = self.noise_scheduler.alphas_cumprod.to(loss.device)
295
+ snr = alpha_bar[timestep] / (1 - alpha_bar[timestep] + 1e-8)
296
+ gamma = self.config.diffusion.p2_gamma
297
+ k = self.config.diffusion.p2_k
298
+ weight = 1.0 / (snr ** gamma + k)
299
+ weight = weight.unsqueeze(-1).expand_as(loss)
300
+ return loss * weight
301
+
302
+ @torch.no_grad()
303
+ def sample(
304
+ self,
305
+ graph_cond: dict[str, torch.Tensor],
306
+ n_steps: Optional[int] = None,
307
+ method: str = "ddim",
308
+ shape: Optional[tuple[int, ...]] = None,
309
+ device: Optional[torch.device] = None,
310
+ ) -> torch.Tensor:
311
+ """Generate samples via iterative denoising.
312
+
313
+ This is the INFERENCE method — start from pure noise and
314
+ iteratively denoise to produce coherent text embeddings.
315
+
316
+ Args:
317
+ graph_cond: Graph conditioning dict from GraphConditioningEncoder.
318
+ n_steps: Number of denoising steps. Uses config if None.
319
+ method: Sampling method ('ddpm' or 'ddim').
320
+ shape: Shape of the output (batch, seq_len, d_model).
321
+ device: Device to generate on.
322
+
323
+ Returns:
324
+ Denoised embeddings of shape (batch, seq_len, d_model).
325
+ """
326
+ if n_steps is None:
327
+ n_steps = self.config.diffusion.n_inference_steps
328
+ if device is None:
329
+ device = next(self.parameters()).device
330
+ if shape is None:
331
+ shape = (1, self.config.model.max_seq_len, self.config.model.d_model)
332
+
333
+ # Start from pure noise
334
+ x = torch.randn(shape, device=device)
335
+
336
+ # Get graph conditioning
337
+ graph_keys = graph_cond.get("keys")
338
+ graph_values = graph_cond.get("values")
339
+
340
+ if method == "ddpm":
341
+ # Full DDPM sampling
342
+ for t in reversed(range(self.config.diffusion.n_timesteps)):
343
+ t_tensor = torch.full((shape[0],), t, device=device, dtype=torch.long)
344
+ predicted = self.transformer(
345
+ x_t=x, t=t_tensor,
346
+ graph_keys=graph_keys,
347
+ graph_values=graph_values,
348
+ )
349
+ x = self.noise_scheduler.step_ddpm(predicted, x, t_tensor)
350
+
351
+ elif method == "ddim":
352
+ # Fast DDIM sampling
353
+ timesteps = self.noise_scheduler.get_timestep_schedule(n_steps)
354
+ for i in range(len(timesteps) - 1):
355
+ t = timesteps[i]
356
+ t_prev = timesteps[i + 1] if i + 1 < len(timesteps) else 0
357
+
358
+ t_tensor = torch.full((shape[0],), t, device=device, dtype=torch.long)
359
+ predicted = self.transformer(
360
+ x_t=x, t=t_tensor,
361
+ graph_keys=graph_keys,
362
+ graph_values=graph_values,
363
+ )
364
+ x = self.noise_scheduler.step_ddim(
365
+ predicted, x, t, t_prev,
366
+ eta=self.config.diffusion.eta_ddim,
367
+ )
368
+
369
+ return x
370
+
371
+ def embeddings_to_tokens(
372
+ self,
373
+ embeddings: torch.Tensor,
374
+ temperature: float = 1.0,
375
+ top_k: int = 50,
376
+ ) -> torch.Tensor:
377
+ """Convert continuous embeddings to discrete token IDs.
378
+
379
+ This is the final step of generation — project embeddings
380
+ to vocabulary logits and sample tokens.
381
+
382
+ Args:
383
+ embeddings: Denoised embeddings of shape (batch, seq_len, d_model).
384
+ temperature: Sampling temperature.
385
+ top_k: Top-k sampling cutoff.
386
+
387
+ Returns:
388
+ Token IDs of shape (batch, seq_len).
389
+ """
390
+ logits = self.lm_head(embeddings) / temperature
391
+
392
+ # Top-k sampling
393
+ if top_k > 0:
394
+ top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1)
395
+ probs = torch.softmax(top_k_values, dim=-1)
396
+ sampled_indices = torch.multinomial(
397
+ probs.view(-1, top_k), 1
398
+ ).view(logits.shape[0], logits.shape[1])
399
+ token_ids = top_k_indices.gather(
400
+ -1, sampled_indices.unsqueeze(-1)
401
+ ).squeeze(-1)
402
+ else:
403
+ probs = torch.softmax(logits, dim=-1)
404
+ token_ids = torch.argmax(logits, dim=-1)
405
+
406
+ return token_ids
407
+
408
+ def get_num_params(self) -> int:
409
+ """Get total number of parameters."""
410
+ return sum(p.numel() for p in self.parameters())
411
+
412
+ @staticmethod
413
+ def _format_params(n: int) -> str:
414
+ """Format parameter count for display."""
415
+ if n >= 1e9:
416
+ return f"{n / 1e9:.1f}B"
417
+ elif n >= 1e6:
418
+ return f"{n / 1e6:.1f}M"
419
+ elif n >= 1e3:
420
+ return f"{n / 1e3:.1f}K"
421
+ return str(n)
422
+
423
+ def save(self, path: str) -> None:
424
+ """Save model checkpoint.
425
+
426
+ Args:
427
+ path: Output file path.
428
+ """
429
+ torch.save({
430
+ "model_state_dict": self.state_dict(),
431
+ "config": self.config.to_dict(),
432
+ }, path)
433
+ logger.info("Model saved to %s", path)
434
+
435
+ @classmethod
436
+ def load(cls, path: str, device: str = "cpu") -> AamDiffusionModel:
437
+ """Load model from checkpoint.
438
+
439
+ Args:
440
+ path: Checkpoint file path.
441
+ device: Device to load to.
442
+
443
+ Returns:
444
+ Loaded AamDiffusionModel.
445
+ """
446
+ checkpoint = torch.load(path, map_location=device, weights_only=False)
447
+ config_dict = checkpoint.get("config", {})
448
+ if isinstance(config_dict, dict):
449
+ config = AamDiffusionConfig()
450
+ # Try to reconstruct config from dict
451
+ try:
452
+ from diffusion_llm.config.model_config import (
453
+ ModelConfig, DiffusionConfig, GraphEncoderConfig,
454
+ TokenizerConfig, TrainingConfig, InferenceConfig,
455
+ )
456
+ config = AamDiffusionConfig(
457
+ model=ModelConfig(**config_dict.get("model", {})),
458
+ diffusion=DiffusionConfig(**config_dict.get("diffusion", {})),
459
+ graph_encoder=GraphEncoderConfig(**config_dict.get("graph_encoder", {})),
460
+ tokenizer=TokenizerConfig(**config_dict.get("tokenizer", {})),
461
+ training=TrainingConfig(**config_dict.get("training", {})),
462
+ inference=InferenceConfig(**config_dict.get("inference", {})),
463
+ model_name=config_dict.get("model_name", "aam-diffusion-v0.1"),
464
+ output_dir=config_dict.get("output_dir", "./output"),
465
+ seed=config_dict.get("seed", 42),
466
+ )
467
+ except Exception:
468
+ logger.warning("Could not reconstruct config from checkpoint, using defaults")
469
+ else:
470
+ config = config_dict
471
+ model = cls(config)
472
+ model.load_state_dict(checkpoint["model_state_dict"])
473
+ model.to(device)
474
+ logger.info("Model loaded from %s", path)
475
+ return model
diffusion_llm/model/diffusion_transformer.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AAM Diffusion LLM — Diffusion Transformer (Denoiser)
3
+
4
+ The core denoising network. Takes noisy text embeddings and graph
5
+ conditioning, and predicts the noise (or clean data) at each
6
+ diffusion timestep.
7
+
8
+ Architecture:
9
+ Input: Noisy embeddings x_t + timestep t + graph conditioning
10
+ Output: Predicted noise epsilon (or x_0 or v)
11
+
12
+ The transformer uses:
13
+ - Self-attention over the text sequence
14
+ - Cross-attention to graph conditioning (evidence, anomalies, etc.)
15
+ - Timestep embedding (sinusoidal) injected via adaptive layer norm
16
+ - Optional flash attention for efficiency
17
+
18
+ This is the "brainstem" of the body — the core computation that
19
+ transforms noisy signals into coherent patterns.
20
+
21
+ Analogi: Seperti otot Jin Soun yang merespons sinyal dari otak —
22
+ model ini menerima "sinyal noise" dan "instruksi dari graph",
23
+ lalu mengubahnya menjadi gerakan yang koheren (kalimat).
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import math
29
+ from typing import Optional
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+
35
+ from diffusion_llm.config.model_config import ModelConfig
36
+
37
+
38
+ class SinusoidalTimestepEmbedding(nn.Module):
39
+ """Sinusoidal embedding for diffusion timesteps.
40
+
41
+ Maps integer timesteps to d_model-dimensional vectors using
42
+ sinusoidal position encoding, similar to Transformers.
43
+
44
+ This allows the model to know "how noisy" the current input is,
45
+ which is essential for the denoising process.
46
+ """
47
+
48
+ def __init__(self, d_model: int, max_period: int = 10000):
49
+ super().__init__()
50
+ self.d_model = d_model
51
+ self.max_period = max_period
52
+
53
+ # Two-layer MLP to project sinusoidal features
54
+ self.mlp = nn.Sequential(
55
+ nn.Linear(d_model, d_model * 4),
56
+ nn.GELU(),
57
+ nn.Linear(d_model * 4, d_model),
58
+ )
59
+
60
+ def forward(self, t: torch.Tensor) -> torch.Tensor:
61
+ """Embed timesteps.
62
+
63
+ Args:
64
+ t: Timestep indices of shape (batch,).
65
+
66
+ Returns:
67
+ Timestep embeddings of shape (batch, d_model).
68
+ """
69
+ device = t.device
70
+ half_dim = self.d_model // 2
71
+ emb = math.log(self.max_period) / (half_dim - 1)
72
+ emb = torch.exp(torch.arange(half_dim, device=device, dtype=torch.float32) * -emb)
73
+ emb = t.float().unsqueeze(-1) * emb.unsqueeze(0)
74
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
75
+
76
+ if emb.shape[-1] < self.d_model:
77
+ emb = F.pad(emb, (0, self.d_model - emb.shape[-1]))
78
+
79
+ return self.mlp(emb)
80
+
81
+
82
+ class AdaptiveLayerNorm(nn.Module):
83
+ """Adaptive Layer Normalization conditioned on timestep.
84
+
85
+ Instead of fixed scale/shift parameters, this layer norm
86
+ uses the timestep embedding to produce scale and shift:
87
+
88
+ y = (1 + scale(t)) * norm(x) + shift(t)
89
+
90
+ This allows the model to behave differently at different
91
+ noise levels — more "creative" at high noise, more
92
+ "precise" at low noise.
93
+
94
+ Analogi: Jin Soun menyesuaikan intensitas pikirannya
95
+ berdasarkan seberapa kabur situasinya — semakin kabur,
96
+ semakin "kreatif" pendekatannya.
97
+ """
98
+
99
+ def __init__(self, d_model: int, eps: float = 1e-6):
100
+ super().__init__()
101
+ self.norm = nn.LayerNorm(d_model, elementwise_affine=False, eps=eps)
102
+ self.scale_proj = nn.Linear(d_model, d_model)
103
+ self.shift_proj = nn.Linear(d_model, d_model)
104
+
105
+ # Initialize shift to zero, scale to one
106
+ nn.init.zeros_(self.shift_proj.weight)
107
+ nn.init.zeros_(self.shift_proj.bias)
108
+ nn.init.ones_(self.scale_proj.weight)
109
+ nn.init.zeros_(self.scale_proj.bias)
110
+
111
+ def forward(
112
+ self,
113
+ x: torch.Tensor,
114
+ timestep_emb: torch.Tensor,
115
+ ) -> torch.Tensor:
116
+ """Apply adaptive layer norm.
117
+
118
+ Args:
119
+ x: Input tensor of shape (batch, seq_len, d_model).
120
+ timestep_emb: Timestep embedding of shape (batch, d_model).
121
+
122
+ Returns:
123
+ Normalized and modulated tensor.
124
+ """
125
+ normalized = self.norm(x)
126
+ scale = (1 + self.scale_proj(timestep_emb)).unsqueeze(1)
127
+ shift = self.shift_proj(timestep_emb).unsqueeze(1)
128
+ return normalized * scale + shift
129
+
130
+
131
+ class TransformerBlock(nn.Module):
132
+ """Single transformer block with self-attention, cross-attention, and FFN.
133
+
134
+ The block structure:
135
+ 1. Adaptive Layer Norm + Self-Attention
136
+ 2. Adaptive Layer Norm + Cross-Attention (to graph conditioning)
137
+ 3. Adaptive Layer Norm + Feed-Forward Network
138
+
139
+ Each sub-layer has a residual connection.
140
+ """
141
+
142
+ def __init__(
143
+ self,
144
+ d_model: int,
145
+ n_heads: int,
146
+ d_ff: int,
147
+ dropout: float = 0.1,
148
+ norm_eps: float = 1e-6,
149
+ norm_type: str = "rmsnorm",
150
+ use_flash_attention: bool = True,
151
+ ):
152
+ super().__init__()
153
+ self.d_model = d_model
154
+ self.n_heads = n_heads
155
+
156
+ # Norms
157
+ NormClass = nn.RMSNorm if norm_type == "rmsnorm" else nn.LayerNorm
158
+
159
+ # Self-attention
160
+ self.self_attn_norm = AdaptiveLayerNorm(d_model, eps=norm_eps)
161
+ self.self_attn = nn.MultiheadAttention(
162
+ embed_dim=d_model,
163
+ num_heads=n_heads,
164
+ dropout=dropout,
165
+ batch_first=True,
166
+ )
167
+ self.self_attn_dropout = nn.Dropout(dropout)
168
+
169
+ # Cross-attention (to graph conditioning)
170
+ self.cross_attn_norm = AdaptiveLayerNorm(d_model, eps=norm_eps)
171
+ self.cross_attn = nn.MultiheadAttention(
172
+ embed_dim=d_model,
173
+ num_heads=n_heads,
174
+ dropout=dropout,
175
+ batch_first=True,
176
+ kdim=d_model,
177
+ vdim=d_model,
178
+ )
179
+ self.cross_attn_dropout = nn.Dropout(dropout)
180
+
181
+ # Feed-forward
182
+ self.ff_norm = AdaptiveLayerNorm(d_model, eps=norm_eps)
183
+ self.ff = nn.Sequential(
184
+ nn.Linear(d_model, d_ff),
185
+ nn.GELU(),
186
+ nn.Dropout(dropout),
187
+ nn.Linear(d_ff, d_model),
188
+ nn.Dropout(dropout),
189
+ )
190
+
191
+ # Layer scales (optional, helps with deep networks)
192
+ self.self_attn_scale = nn.Parameter(torch.ones(1) * 0.1)
193
+ self.cross_attn_scale = nn.Parameter(torch.ones(1) * 0.1)
194
+ self.ff_scale = nn.Parameter(torch.ones(1) * 0.1)
195
+
196
+ def forward(
197
+ self,
198
+ x: torch.Tensor,
199
+ timestep_emb: torch.Tensor,
200
+ graph_keys: Optional[torch.Tensor] = None,
201
+ graph_values: Optional[torch.Tensor] = None,
202
+ causal_mask: Optional[torch.Tensor] = None,
203
+ ) -> torch.Tensor:
204
+ """Forward pass.
205
+
206
+ Args:
207
+ x: Input sequence of shape (batch, seq_len, d_model).
208
+ timestep_emb: Timestep embedding of shape (batch, d_model).
209
+ graph_keys: Graph conditioning keys for cross-attention,
210
+ shape (batch, n_graph_nodes, d_model).
211
+ graph_values: Graph conditioning values for cross-attention,
212
+ shape (batch, n_graph_nodes, d_model).
213
+ causal_mask: Optional causal mask for self-attention.
214
+
215
+ Returns:
216
+ Output sequence of shape (batch, seq_len, d_model).
217
+ """
218
+ # 1. Self-attention with adaptive layer norm
219
+ normed = self.self_attn_norm(x, timestep_emb)
220
+ attn_out, _ = self.self_attn(
221
+ normed, normed, normed,
222
+ attn_mask=causal_mask,
223
+ need_weights=False,
224
+ )
225
+ x = x + self.self_attn_scale * self.self_attn_dropout(attn_out)
226
+
227
+ # 2. Cross-attention to graph conditioning (if available)
228
+ if graph_keys is not None and graph_values is not None:
229
+ normed = self.cross_attn_norm(x, timestep_emb)
230
+ cross_out, _ = self.cross_attn(
231
+ normed, graph_keys, graph_values,
232
+ need_weights=False,
233
+ )
234
+ x = x + self.cross_attn_scale * self.cross_attn_dropout(cross_out)
235
+
236
+ # 3. Feed-forward with adaptive layer norm
237
+ normed = self.ff_norm(x, timestep_emb)
238
+ ff_out = self.ff(normed)
239
+ x = x + self.ff_scale * ff_out
240
+
241
+ return x
242
+
243
+
244
+ class DiffusionTransformer(nn.Module):
245
+ """Diffusion Transformer — the core denoising network for AAM.
246
+
247
+ This transformer takes:
248
+ - Noisy text embeddings (x_t)
249
+ - Diffusion timestep (t)
250
+ - Graph conditioning (evidence, anomalies, reasoning chains)
251
+
252
+ And predicts the noise that was added (or the clean data,
253
+ depending on prediction_type).
254
+
255
+ Architecture Overview:
256
+ ┌────────────────────────────────────────────────┐
257
+ │ Input Embedding: x_t (noisy) → embedding │
258
+ │ + Positional Encoding (RoPE or learned) │
259
+ │ │
260
+ │ N x TransformerBlock: │
261
+ │ ├─ AdaLN + Self-Attention │
262
+ │ ├─ AdaLN + Cross-Attention (to graph) │
263
+ │ └─ AdaLN + Feed-Forward │
264
+ │ │
265
+ │ Output Projection: → predicted noise │
266
+ └────────────────────────────────────────────────┘
267
+
268
+ Key Features:
269
+ - Adaptive Layer Norm: timestep-conditioned normalization
270
+ - Cross-Attention: graph conditioning guides generation
271
+ - Layer Scales: helps training deep networks
272
+ - RoPE: better length generalization than learned positions
273
+
274
+ Args:
275
+ config: ModelConfig with architecture hyperparameters.
276
+ """
277
+
278
+ def __init__(self, config: ModelConfig):
279
+ super().__init__()
280
+ self.config = config
281
+
282
+ # Input embedding (from token IDs to d_model)
283
+ self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
284
+
285
+ # Timestep embedding
286
+ self.timestep_embedding = SinusoidalTimestepEmbedding(config.d_model)
287
+
288
+ # Positional encoding
289
+ if config.pos_encoding_type == "learned":
290
+ self.position_embedding = nn.Embedding(
291
+ config.max_seq_len, config.d_model
292
+ )
293
+ else:
294
+ # RoPE is applied inside attention (no separate embedding)
295
+ self.position_embedding = None
296
+
297
+ # Transformer blocks
298
+ self.blocks = nn.ModuleList([
299
+ TransformerBlock(
300
+ d_model=config.d_model,
301
+ n_heads=config.n_heads,
302
+ d_ff=config.d_ff,
303
+ dropout=config.dropout,
304
+ norm_eps=config.norm_eps,
305
+ norm_type=config.norm_type,
306
+ use_flash_attention=config.use_flash_attention,
307
+ )
308
+ for _ in range(config.n_layers)
309
+ ])
310
+
311
+ # Final norm
312
+ NormClass = nn.RMSNorm if config.norm_type == "rmsnorm" else nn.LayerNorm
313
+ self.final_norm = NormClass(config.d_model, eps=config.norm_eps)
314
+
315
+ # Output projection (predict noise/x0/v)
316
+ self.output_proj = nn.Linear(config.d_model, config.d_model)
317
+
318
+ # Initialize weights
319
+ self.apply(self._init_weights)
320
+
321
+ def _init_weights(self, module: nn.Module) -> None:
322
+ """Initialize weights with Xavier/GPT-2 style."""
323
+ if isinstance(module, nn.Linear):
324
+ torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
325
+ if module.bias is not None:
326
+ torch.nn.init.zeros_(module.bias)
327
+ elif isinstance(module, nn.Embedding):
328
+ torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
329
+
330
+ def forward(
331
+ self,
332
+ x_t: torch.Tensor,
333
+ t: torch.Tensor,
334
+ token_ids: Optional[torch.Tensor] = None,
335
+ graph_keys: Optional[torch.Tensor] = None,
336
+ graph_values: Optional[torch.Tensor] = None,
337
+ ) -> torch.Tensor:
338
+ """Forward pass: predict noise given noisy input and timestep.
339
+
340
+ Args:
341
+ x_t: Noisy text embeddings of shape (batch, seq_len, d_model).
342
+ If None, token_ids must be provided.
343
+ t: Timestep indices of shape (batch,).
344
+ token_ids: Token IDs of shape (batch, seq_len).
345
+ Used to create embeddings if x_t is not provided directly.
346
+ In training, x_t comes from the noise scheduler.
347
+ graph_keys: Graph conditioning keys for cross-attention,
348
+ shape (batch, n_graph_nodes, d_model).
349
+ graph_values: Graph conditioning values for cross-attention,
350
+ shape (batch, n_graph_nodes, d_model).
351
+
352
+ Returns:
353
+ Predicted noise of shape (batch, seq_len, d_model).
354
+ """
355
+ # Get input embeddings
356
+ if x_t is None and token_ids is not None:
357
+ # Create embeddings from token IDs (used for initial x_0)
358
+ h = self.token_embedding(token_ids)
359
+ elif x_t is not None:
360
+ h = x_t
361
+ else:
362
+ raise ValueError("Either x_t or token_ids must be provided")
363
+
364
+ # Add positional encoding
365
+ if self.position_embedding is not None:
366
+ seq_len = h.shape[1]
367
+ positions = torch.arange(seq_len, device=h.device).unsqueeze(0)
368
+ h = h + self.position_embedding(positions)
369
+
370
+ # Embed timestep
371
+ t_emb = self.timestep_embedding(t)
372
+
373
+ # Pass through transformer blocks
374
+ for block in self.blocks:
375
+ h = block(
376
+ h,
377
+ timestep_emb=t_emb,
378
+ graph_keys=graph_keys,
379
+ graph_values=graph_values,
380
+ )
381
+
382
+ # Final norm and projection
383
+ h = self.final_norm(h)
384
+ output = self.output_proj(h)
385
+
386
+ return output
387
+
388
+ def get_num_params(self) -> int:
389
+ """Get total number of parameters."""
390
+ return sum(p.numel() for p in self.parameters())
391
+
392
+ def get_num_trainable_params(self) -> int:
393
+ """Get number of trainable parameters."""
394
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
diffusion_llm/model/graph_encoder.py ADDED
@@ -0,0 +1,553 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AAM Diffusion LLM — Graph Conditioning Encoder
3
+
4
+ Encodes structured graph data into a conditioning vector that guides
5
+ the diffusion process. This is the KEY differentiator from general LLMs:
6
+ the model is conditioned on GRAPH STRUCTURE, not just text prompts.
7
+
8
+ The graph encoder takes:
9
+ - Evidence nodes (what the graph knows)
10
+ - Compositions (how concepts compose)
11
+ - Confidence scores (how sure the graph is)
12
+ - Anomalies (what doesn't fit)
13
+ - Reasoning chains (how the graph reached conclusions)
14
+ - Temporal context (when events happened)
15
+
16
+ And produces a conditioning representation that the diffusion model
17
+ uses to guide denoising.
18
+
19
+ Analogi: Seperti otak Jin Soun mengirimkan sinyal ke pita suaranya —
20
+ graph memberi "tahu" apa yang harus dikatakan, dan encoder ini
21
+ menerjemahkan "pengetahuan graph" menjadi "instruksi untuk tubuh".
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import math
27
+ from typing import Optional
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+
33
+ from diffusion_llm.config.model_config import GraphEncoderConfig
34
+
35
+
36
+ class ConfidenceEmbedding(nn.Module):
37
+ """Embed confidence scores as continuous values.
38
+
39
+ Maps [0, 1] confidence scores to d_graph-dimensional vectors
40
+ using sinusoidal encoding for smooth interpolation.
41
+
42
+ Analogi: Jin Soun tahu bedanya "aku yakin 100%" vs "mungkin 60%"
43
+ — encoding ini mengajarkan model membedakan juga.
44
+ """
45
+
46
+ def __init__(self, d_graph: int):
47
+ super().__init__()
48
+ self.d_graph = d_graph
49
+ # Learnable projection from scalar to d_graph
50
+ self.projection = nn.Sequential(
51
+ nn.Linear(1, d_graph // 4),
52
+ nn.GELU(),
53
+ nn.Linear(d_graph // 4, d_graph),
54
+ )
55
+
56
+ def forward(self, confidence: torch.Tensor) -> torch.Tensor:
57
+ """Embed confidence scores.
58
+
59
+ Args:
60
+ confidence: Tensor of shape (..., 1) with values in [0, 1].
61
+
62
+ Returns:
63
+ Tensor of shape (..., d_graph).
64
+ """
65
+ if confidence.dim() == 0:
66
+ confidence = confidence.unsqueeze(0)
67
+ if confidence.dim() == 1:
68
+ confidence = confidence.unsqueeze(-1)
69
+ return self.projection(confidence)
70
+
71
+
72
+ class TemporalEmbedding(nn.Module):
73
+ """Embed temporal context as position-aware vectors.
74
+
75
+ Uses sinusoidal positional encoding adapted for timestamps,
76
+ allowing the model to understand time-based relationships.
77
+
78
+ Analogi: Jin Soun mengingat bahwa "kejadian A terjadi 3 hari
79
+ sebelum kejadian B" — temporal embedding mengajarkan model
80
+ memahami hubungan waktu antar kejadian.
81
+ """
82
+
83
+ def __init__(self, d_graph: int, max_period: int = 10000):
84
+ super().__init__()
85
+ self.d_graph = d_graph
86
+ self.max_period = max_period
87
+ self.projection = nn.Sequential(
88
+ nn.Linear(d_graph, d_graph),
89
+ nn.GELU(),
90
+ nn.Linear(d_graph, d_graph),
91
+ )
92
+
93
+ def forward(self, timestamps: torch.Tensor) -> torch.Tensor:
94
+ """Embed timestamps.
95
+
96
+ Args:
97
+ timestamps: Tensor of shape (batch, n_events) with normalized
98
+ timestamps (0 = earliest, 1 = latest).
99
+
100
+ Returns:
101
+ Tensor of shape (batch, n_events, d_graph).
102
+ """
103
+ batch_size, n_events = timestamps.shape
104
+ device = timestamps.device
105
+
106
+ # Sinusoidal encoding
107
+ half_dim = self.d_graph // 2
108
+ emb = math.log(self.max_period) / (half_dim - 1)
109
+ emb = torch.exp(torch.arange(half_dim, device=device, dtype=torch.float32) * -emb)
110
+ emb = timestamps.float().unsqueeze(-1) * emb.unsqueeze(0).unsqueeze(0)
111
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
112
+
113
+ if emb.shape[-1] < self.d_graph:
114
+ # Pad if d_graph is odd
115
+ emb = F.pad(emb, (0, self.d_graph - emb.shape[-1]))
116
+
117
+ return self.projection(emb)
118
+
119
+
120
+ class NodeEncoder(nn.Module):
121
+ """Encode a single evidence node or composition.
122
+
123
+ Each node is represented as:
124
+ - Text embedding (from the tokenizer's vocabulary)
125
+ - Confidence score
126
+ - Optional temporal context
127
+ - Source trust score
128
+
129
+ These are combined into a single d_graph-dimensional vector.
130
+ """
131
+
132
+ def __init__(
133
+ self,
134
+ d_graph: int,
135
+ vocab_size: int = 32000,
136
+ embed_confidence: bool = True,
137
+ embed_temporal: bool = True,
138
+ ):
139
+ super().__init__()
140
+ self.d_graph = d_graph
141
+
142
+ # Text embedding (will be shared with the main model)
143
+ self.text_embed = nn.Embedding(vocab_size, d_graph)
144
+
145
+ # Confidence embedding
146
+ self.use_confidence = embed_confidence
147
+ if embed_confidence:
148
+ self.conf_embed = ConfidenceEmbedding(d_graph)
149
+
150
+ # Temporal embedding
151
+ self.use_temporal = embed_temporal
152
+ if embed_temporal:
153
+ self.temporal_embed = TemporalEmbedding(d_graph)
154
+
155
+ # Fusion layer — always build for max possible inputs
156
+ # At runtime, we may have fewer (e.g., no temporal data provided),
157
+ # so we use a flexible approach: always concatenate all available
158
+ # embeddings and project through a layer that handles the max size.
159
+ self._n_max_inputs = 1 + int(embed_confidence) + int(embed_temporal)
160
+ self.fusion = nn.Sequential(
161
+ nn.Linear(d_graph * self._n_max_inputs, d_graph),
162
+ nn.GELU(),
163
+ nn.LayerNorm(d_graph),
164
+ )
165
+
166
+ def forward(
167
+ self,
168
+ token_ids: torch.Tensor,
169
+ confidence: Optional[torch.Tensor] = None,
170
+ timestamps: Optional[torch.Tensor] = None,
171
+ ) -> torch.Tensor:
172
+ """Encode a batch of evidence nodes.
173
+
174
+ Args:
175
+ token_ids: Token IDs of shape (batch, n_nodes, seq_len).
176
+ confidence: Confidence scores of shape (batch, n_nodes).
177
+ timestamps: Timestamps of shape (batch, n_nodes).
178
+
179
+ Returns:
180
+ Encoded nodes of shape (batch, n_nodes, d_graph).
181
+ """
182
+ # Text embedding: mean pool over sequence length
183
+ text_emb = self.text_embed(token_ids).mean(dim=-2) # (batch, n_nodes, d_graph)
184
+
185
+ embeddings = [text_emb]
186
+
187
+ if self.use_confidence:
188
+ if confidence is not None:
189
+ conf_emb = self.conf_embed(confidence.unsqueeze(-1)) # (batch, n_nodes, d_graph)
190
+ embeddings.append(conf_emb)
191
+ else:
192
+ # Zero-pad to maintain consistent dimension
193
+ embeddings.append(torch.zeros_like(text_emb))
194
+
195
+ if self.use_temporal:
196
+ if timestamps is not None:
197
+ temp_emb = self.temporal_embed(timestamps) # (batch, n_nodes, d_graph)
198
+ embeddings.append(temp_emb)
199
+ else:
200
+ embeddings.append(torch.zeros_like(text_emb))
201
+
202
+ # Fuse all embeddings
203
+ combined = torch.cat(embeddings, dim=-1)
204
+ return self.fusion(combined)
205
+
206
+
207
+ class GraphAttentionLayer(nn.Module):
208
+ """Multi-head attention layer for graph-structured data.
209
+
210
+ Unlike standard self-attention, this operates on graph nodes
211
+ where edges represent structural relationships (compositions,
212
+ evidence links, temporal connections).
213
+
214
+ For now, we use standard multi-head attention over the node
215
+ sequence, as the structural information is already encoded
216
+ in the node features. Future versions can incorporate explicit
217
+ edge structure via graph attention networks (GAT).
218
+ """
219
+
220
+ def __init__(self, d_graph: int, n_heads: int, dropout: float = 0.1):
221
+ super().__init__()
222
+ self.attention = nn.MultiheadAttention(
223
+ embed_dim=d_graph,
224
+ num_heads=n_heads,
225
+ dropout=dropout,
226
+ batch_first=True,
227
+ )
228
+ self.norm = nn.LayerNorm(d_graph)
229
+ self.ff = nn.Sequential(
230
+ nn.Linear(d_graph, d_graph * 4),
231
+ nn.GELU(),
232
+ nn.Dropout(dropout),
233
+ nn.Linear(d_graph * 4, d_graph),
234
+ nn.Dropout(dropout),
235
+ )
236
+ self.norm_ff = nn.LayerNorm(d_graph)
237
+
238
+ def forward(
239
+ self,
240
+ x: torch.Tensor,
241
+ mask: Optional[torch.Tensor] = None,
242
+ ) -> torch.Tensor:
243
+ """Forward pass.
244
+
245
+ Args:
246
+ x: Node features of shape (batch, n_nodes, d_graph).
247
+ mask: Optional attention mask.
248
+
249
+ Returns:
250
+ Updated node features of same shape.
251
+ """
252
+ # Self-attention with residual
253
+ attn_out, _ = self.attention(x, x, x, attn_mask=mask)
254
+ x = self.norm(x + attn_out)
255
+
256
+ # Feed-forward with residual
257
+ ff_out = self.ff(x)
258
+ x = self.norm_ff(x + ff_out)
259
+
260
+ return x
261
+
262
+
263
+ class GraphConditioningEncoder(nn.Module):
264
+ """Encode graph-structured conditioning data for the diffusion model.
265
+
266
+ This encoder takes structured data from the RSVS Knowledge Graph
267
+ and produces conditioning vectors that guide the diffusion process.
268
+
269
+ The encoding process:
270
+ 1. Encode each evidence node (text + confidence + temporal)
271
+ 2. Encode compositions (how concepts relate)
272
+ 3. Encode anomalies (what doesn't fit)
273
+ 4. Encode reasoning chain (step-by-step logic)
274
+ 5. Aggregate via graph attention layers
275
+ 6. Project to conditioning vector for the diffusion model
276
+
277
+ Output modes (conditioning_method):
278
+ - 'cross_attention': Returns (K, V) pairs for cross-attention in transformer
279
+ - 'ada_ln': Returns scale/shift parameters for adaptive layer norm
280
+ - 'concat': Returns a conditioning prefix to concatenate with input
281
+
282
+ Args:
283
+ config: GraphEncoderConfig with hyperparameters.
284
+ vocab_size: Vocabulary size (must match tokenizer).
285
+ """
286
+
287
+ def __init__(
288
+ self,
289
+ config: GraphEncoderConfig,
290
+ vocab_size: int = 32000,
291
+ ):
292
+ super().__init__()
293
+ self.config = config
294
+ self.conditioning_method = config.conditioning_method
295
+
296
+ # Node encoders for different graph element types
297
+ self.evidence_encoder = NodeEncoder(
298
+ d_graph=config.d_graph,
299
+ vocab_size=vocab_size,
300
+ embed_confidence=config.embed_confidence,
301
+ embed_temporal=config.embed_temporal,
302
+ )
303
+
304
+ self.composition_encoder = NodeEncoder(
305
+ d_graph=config.d_graph,
306
+ vocab_size=vocab_size,
307
+ embed_confidence=config.embed_confidence,
308
+ embed_temporal=False, # Compositions don't have temporal info
309
+ )
310
+
311
+ self.anomaly_encoder = NodeEncoder(
312
+ d_graph=config.d_graph,
313
+ vocab_size=vocab_size,
314
+ embed_confidence=True, # Anomalies always have confidence
315
+ embed_temporal=config.embed_temporal,
316
+ )
317
+
318
+ self.reasoning_encoder = NodeEncoder(
319
+ d_graph=config.d_graph,
320
+ vocab_size=vocab_size,
321
+ embed_confidence=True, # Reasoning steps have confidence
322
+ embed_temporal=False,
323
+ )
324
+
325
+ # Source trust embedding
326
+ self.trust_embed = ConfidenceEmbedding(config.d_graph)
327
+
328
+ # Graph attention layers for cross-node interaction
329
+ self.graph_layers = nn.ModuleList([
330
+ GraphAttentionLayer(
331
+ d_graph=config.d_graph,
332
+ n_heads=config.n_graph_heads,
333
+ dropout=0.1,
334
+ )
335
+ for _ in range(config.n_graph_layers)
336
+ ])
337
+
338
+ # Conditioning projection depends on method
339
+ # d_model_out will be set via set_output_dim() or defaults to d_graph
340
+ self._d_model_out = config.d_graph
341
+
342
+ if self.conditioning_method == "cross_attention":
343
+ # Project to (K, V) for cross-attention
344
+ self.key_proj = nn.Linear(config.d_graph, self._d_model_out)
345
+ self.value_proj = nn.Linear(config.d_graph, self._d_model_out)
346
+
347
+ elif self.conditioning_method == "ada_ln":
348
+ # Project to scale and shift for adaptive layer norm
349
+ self.scale_proj = nn.Linear(config.d_graph, self._d_model_out)
350
+ self.shift_proj = nn.Linear(config.d_graph, self._d_model_out)
351
+
352
+ elif self.conditioning_method == "concat":
353
+ # Project to a prefix sequence
354
+ self.concat_proj = nn.Linear(config.d_graph, self._d_model_out)
355
+
356
+ # Global pooling for summary
357
+ self.global_pool_proj = nn.Sequential(
358
+ nn.Linear(config.d_graph, config.d_graph),
359
+ nn.GELU(),
360
+ nn.Linear(config.d_graph, config.d_graph),
361
+ )
362
+
363
+ # Type embeddings for different graph element types
364
+ self.type_embeddings = nn.Embedding(4, config.d_graph)
365
+ # 0 = evidence, 1 = composition, 2 = anomaly, 3 = reasoning
366
+
367
+ def set_output_dim(self, d_model_out: int) -> None:
368
+ """Set the output dimension for the projection layers.
369
+
370
+ This must be called after __init__ if d_graph != d_model
371
+ (which is typically the case when the graph encoder's d_graph
372
+ differs from the transformer's d_model).
373
+
374
+ Args:
375
+ d_model_out: Output dimension (typically the transformer's d_model).
376
+ """
377
+ if d_model_out == self._d_model_out:
378
+ return # No change needed
379
+
380
+ self._d_model_out = d_model_out
381
+
382
+ # Rebuild projection layers with new output dim
383
+ if self.conditioning_method == "cross_attention":
384
+ self.key_proj = nn.Linear(self.config.d_graph, d_model_out)
385
+ self.value_proj = nn.Linear(self.config.d_graph, d_model_out)
386
+ elif self.conditioning_method == "ada_ln":
387
+ self.scale_proj = nn.Linear(self.config.d_graph, d_model_out)
388
+ self.shift_proj = nn.Linear(self.config.d_graph, d_model_out)
389
+ elif self.conditioning_method == "concat":
390
+ self.concat_proj = nn.Linear(self.config.d_graph, d_model_out)
391
+
392
+ def forward(
393
+ self,
394
+ evidence_ids: Optional[torch.Tensor] = None,
395
+ evidence_confidence: Optional[torch.Tensor] = None,
396
+ evidence_timestamps: Optional[torch.Tensor] = None,
397
+ composition_ids: Optional[torch.Tensor] = None,
398
+ composition_confidence: Optional[torch.Tensor] = None,
399
+ anomaly_ids: Optional[torch.Tensor] = None,
400
+ anomaly_confidence: Optional[torch.Tensor] = None,
401
+ anomaly_timestamps: Optional[torch.Tensor] = None,
402
+ reasoning_ids: Optional[torch.Tensor] = None,
403
+ reasoning_confidence: Optional[torch.Tensor] = None,
404
+ source_trust: Optional[torch.Tensor] = None,
405
+ batch_size: Optional[int] = None,
406
+ ) -> dict[str, torch.Tensor]:
407
+ """Encode graph conditioning data.
408
+
409
+ All inputs are optional — the encoder handles missing data gracefully.
410
+
411
+ Args:
412
+ evidence_ids: Evidence node token IDs, shape (batch, n_evidence, seq_len).
413
+ evidence_confidence: Evidence confidence scores, shape (batch, n_evidence).
414
+ evidence_timestamps: Evidence timestamps, shape (batch, n_evidence).
415
+ composition_ids: Composition token IDs, shape (batch, n_compositions, seq_len).
416
+ composition_confidence: Composition confidence, shape (batch, n_compositions).
417
+ anomaly_ids: Anomaly token IDs, shape (batch, n_anomalies, seq_len).
418
+ anomaly_confidence: Anomaly confidence, shape (batch, n_anomalies).
419
+ anomaly_timestamps: Anomaly timestamps, shape (batch, n_anomalies).
420
+ reasoning_ids: Reasoning step token IDs, shape (batch, n_steps, seq_len).
421
+ reasoning_confidence: Reasoning confidence, shape (batch, n_steps).
422
+ source_trust: Source trust score, shape (batch,).
423
+
424
+ Returns:
425
+ Dictionary with conditioning tensors depending on conditioning_method:
426
+ - 'cross_attention': {'keys': ..., 'values': ..., 'global': ...}
427
+ - 'ada_ln': {'scale': ..., 'shift': ..., 'global': ...}
428
+ - 'concat': {'prefix': ..., 'global': ...}
429
+ """
430
+ batch_size_inferred = self._infer_batch_size(
431
+ evidence_ids, composition_ids, anomaly_ids, reasoning_ids
432
+ )
433
+ device = next(self.parameters()).device
434
+
435
+ # Encode each type of graph element
436
+ node_embeddings = []
437
+ type_indices = []
438
+
439
+ # Evidence nodes
440
+ if evidence_ids is not None:
441
+ evidence_emb = self.evidence_encoder(
442
+ evidence_ids, evidence_confidence, evidence_timestamps
443
+ )
444
+ # Add type embedding
445
+ type_emb = self.type_embeddings(
446
+ torch.zeros(evidence_emb.shape[1], dtype=torch.long, device=device)
447
+ )
448
+ evidence_emb = evidence_emb + type_emb.unsqueeze(0)
449
+ node_embeddings.append(evidence_emb)
450
+ type_indices.extend([0] * evidence_emb.shape[1])
451
+
452
+ # Compositions
453
+ if composition_ids is not None:
454
+ comp_emb = self.composition_encoder(
455
+ composition_ids, composition_confidence
456
+ )
457
+ type_emb = self.type_embeddings(
458
+ torch.ones(comp_emb.shape[1], dtype=torch.long, device=device)
459
+ )
460
+ comp_emb = comp_emb + type_emb.unsqueeze(0)
461
+ node_embeddings.append(comp_emb)
462
+ type_indices.extend([1] * comp_emb.shape[1])
463
+
464
+ # Anomalies
465
+ if anomaly_ids is not None:
466
+ anom_emb = self.anomaly_encoder(
467
+ anomaly_ids, anomaly_confidence, anomaly_timestamps
468
+ )
469
+ type_emb = self.type_embeddings(
470
+ torch.full((anom_emb.shape[1],), 2, dtype=torch.long, device=device)
471
+ )
472
+ anom_emb = anom_emb + type_emb.unsqueeze(0)
473
+ node_embeddings.append(anom_emb)
474
+ type_indices.extend([2] * anom_emb.shape[1])
475
+
476
+ # Reasoning steps
477
+ if reasoning_ids is not None:
478
+ reason_emb = self.reasoning_encoder(
479
+ reasoning_ids, reasoning_confidence
480
+ )
481
+ type_emb = self.type_embeddings(
482
+ torch.full((reason_emb.shape[1],), 3, dtype=torch.long, device=device)
483
+ )
484
+ reason_emb = reason_emb + type_emb.unsqueeze(0)
485
+ node_embeddings.append(reason_emb)
486
+ type_indices.extend([3] * reason_emb.shape[1])
487
+
488
+ # If no graph data, return zero conditioning
489
+ if not node_embeddings:
490
+ bsz = batch_size or batch_size_inferred
491
+ dummy = torch.zeros(
492
+ bsz, 1, self.config.d_graph, device=device
493
+ )
494
+ return self._project_conditioning(dummy)
495
+
496
+ # Concatenate all node embeddings
497
+ all_nodes = torch.cat(node_embeddings, dim=1) # (batch, n_total_nodes, d_graph)
498
+
499
+ # Add source trust as a global bias
500
+ if source_trust is not None:
501
+ trust_emb = self.trust_embed(source_trust.unsqueeze(-1)) # (batch, d_graph)
502
+ # Broadcast trust to all nodes
503
+ all_nodes = all_nodes + trust_emb.unsqueeze(1) * 0.1 # Small influence
504
+
505
+ # Apply graph attention layers
506
+ for layer in self.graph_layers:
507
+ all_nodes = layer(all_nodes)
508
+
509
+ # Compute global conditioning (mean pool)
510
+ global_cond = all_nodes.mean(dim=1) # (batch, d_graph)
511
+ global_cond = self.global_pool_proj(global_cond)
512
+
513
+ # Project based on conditioning method
514
+ result = self._project_conditioning(all_nodes)
515
+ result["global"] = global_cond
516
+
517
+ return result
518
+
519
+ def _project_conditioning(
520
+ self, node_features: torch.Tensor
521
+ ) -> dict[str, torch.Tensor]:
522
+ """Project node features to conditioning format.
523
+
524
+ Args:
525
+ node_features: Shape (batch, n_nodes, d_graph).
526
+
527
+ Returns:
528
+ Dictionary with conditioning tensors.
529
+ """
530
+ result = {}
531
+
532
+ if self.conditioning_method == "cross_attention":
533
+ result["keys"] = self.key_proj(node_features)
534
+ result["values"] = self.value_proj(node_features)
535
+
536
+ elif self.conditioning_method == "ada_ln":
537
+ # Use mean-pooled features for scale/shift
538
+ pooled = node_features.mean(dim=1)
539
+ result["scale"] = self.scale_proj(pooled)
540
+ result["shift"] = self.shift_proj(pooled)
541
+
542
+ elif self.conditioning_method == "concat":
543
+ result["prefix"] = self.concat_proj(node_features)
544
+
545
+ return result
546
+
547
+ @staticmethod
548
+ def _infer_batch_size(*tensors) -> int:
549
+ """Infer batch size from the first non-None tensor."""
550
+ for t in tensors:
551
+ if t is not None:
552
+ return t.shape[0]
553
+ return 1
diffusion_llm/model/noise_scheduler.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AAM Diffusion LLM — Noise Scheduler
3
+
4
+ Implements the forward (noising) and reverse (denoising) diffusion process.
5
+
6
+ Forward Process:
7
+ q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
8
+
9
+ Reverse Process:
10
+ p(x_{t-1} | x_t) = N(x_{t-1}; mu_theta(x_t, t), sigma_t^2 * I)
11
+
12
+ This scheduler supports:
13
+ - Linear noise schedule (Ho et al., 2020)
14
+ - Cosine noise schedule (Nichol & Dhariwal, 2021) — recommended
15
+ - Sigmoid noise schedule
16
+
17
+ Analogi: Seperti Jin Soun membentuk pikirannya — dari noise
18
+ (kabur, tidak jelas) menjadi sinyal (pola yang jelas).
19
+ Setiap langkah denoising = satu langkah lebih dekat ke
20
+ kesimpulan yang koheren.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import math
26
+ from typing import Optional
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+
31
+
32
+ class NoiseScheduler(nn.Module):
33
+ """Noise scheduler for the diffusion process.
34
+
35
+ Manages the noise schedule (beta values, alpha values, etc.)
36
+ and provides methods for adding noise and computing posterior
37
+ distributions.
38
+
39
+ Args:
40
+ n_timesteps: Total number of diffusion timesteps.
41
+ schedule_type: Type of noise schedule ('linear', 'cosine', 'sigmoid').
42
+ beta_start: Starting beta for linear schedule.
43
+ beta_end: Ending beta for linear schedule.
44
+ prediction_type: What the model predicts ('epsilon', 'x0', or 'v').
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ n_timesteps: int = 1000,
50
+ schedule_type: str = "cosine",
51
+ beta_start: float = 1e-4,
52
+ beta_end: float = 0.02,
53
+ prediction_type: str = "epsilon",
54
+ ):
55
+ super().__init__()
56
+ self.n_timesteps = n_timesteps
57
+ self.schedule_type = schedule_type
58
+ self.beta_start = beta_start
59
+ self.beta_end = beta_end
60
+ self.prediction_type = prediction_type
61
+
62
+ # Compute and register noise schedule buffers
63
+ betas = self._compute_betas()
64
+ alphas = 1.0 - betas
65
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
66
+ alphas_cumprod_prev = torch.cat(
67
+ [torch.ones(1, dtype=betas.dtype), alphas_cumprod[:-1]]
68
+ )
69
+
70
+ # Register all as buffers (part of model state but not parameters)
71
+ self.register_buffer("betas", betas)
72
+ self.register_buffer("alphas", alphas)
73
+ self.register_buffer("alphas_cumprod", alphas_cumprod)
74
+ self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
75
+
76
+ # For q(x_t | x_0) computation
77
+ self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
78
+ self.register_buffer(
79
+ "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
80
+ )
81
+
82
+ # For posterior q(x_{t-1} | x_t, x_0)
83
+ posterior_variance = (
84
+ betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
85
+ )
86
+ self.register_buffer("posterior_variance", posterior_variance)
87
+ self.register_buffer(
88
+ "posterior_log_variance_clipped",
89
+ torch.log(posterior_variance.clamp(min=1e-20)),
90
+ )
91
+ self.register_buffer(
92
+ "posterior_mean_coef1",
93
+ betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
94
+ )
95
+ self.register_buffer(
96
+ "posterior_mean_coef2",
97
+ (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
98
+ )
99
+
100
+ def _compute_betas(self) -> torch.Tensor:
101
+ """Compute beta schedule.
102
+
103
+ Returns:
104
+ Tensor of shape (n_timesteps,) with beta values.
105
+ """
106
+ if self.schedule_type == "linear":
107
+ return torch.linspace(
108
+ self.beta_start, self.beta_end, self.n_timesteps
109
+ )
110
+ elif self.schedule_type == "cosine":
111
+ return self._cosine_schedule()
112
+ elif self.schedule_type == "sigmoid":
113
+ return self._sigmoid_schedule()
114
+ else:
115
+ raise ValueError(
116
+ f"Unknown schedule_type '{self.schedule_type}'. "
117
+ f"Use 'linear', 'cosine', or 'sigmoid'."
118
+ )
119
+
120
+ def _cosine_schedule(self, s: float = 0.008) -> torch.Tensor:
121
+ """Cosine schedule as proposed in Nichol & Dhariwal 2021.
122
+
123
+ alpha_bar(t) = cos^2((t/T + s) / (1 + s) * pi/2)
124
+ beta(t) = 1 - alpha_bar(t) / alpha_bar(t-1)
125
+
126
+ This schedule avoids too much noise at the end and too
127
+ little at the beginning, leading to more stable training.
128
+
129
+ Args:
130
+ s: Offset to prevent singularity at t=0.
131
+
132
+ Returns:
133
+ Tensor of beta values.
134
+ """
135
+ steps = self.n_timesteps + 1
136
+ t = torch.linspace(0, self.n_timesteps, steps)
137
+ alphas_cumprod = torch.cos(
138
+ ((t / self.n_timesteps) + s) / (1 + s) * math.pi * 0.5
139
+ ) ** 2
140
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
141
+ betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
142
+ return torch.clamp(betas, 0.0001, 0.9999)
143
+
144
+ def _sigmoid_schedule(self) -> torch.Tensor:
145
+ """Sigmoid-based noise schedule.
146
+
147
+ beta(t) = sigmoid(-gamma * (t - T/2) + offset) * (beta_end - beta_start) + beta_start
148
+
149
+ Provides a smooth transition between low and high noise.
150
+ """
151
+ betas = torch.linspace(-6, 6, self.n_timesteps)
152
+ betas = torch.sigmoid(betas) * (self.beta_end - self.beta_start) + self.beta_start
153
+ return torch.clamp(betas, 0.0001, 0.9999)
154
+
155
+ def add_noise(
156
+ self,
157
+ x_0: torch.Tensor,
158
+ noise: torch.Tensor,
159
+ t: torch.Tensor,
160
+ ) -> torch.Tensor:
161
+ """Forward diffusion: add noise to clean data.
162
+
163
+ q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
164
+
165
+ x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
166
+
167
+ Args:
168
+ x_0: Clean data tensor of shape (batch, seq_len, d_model).
169
+ noise: Noise tensor of same shape as x_0.
170
+ t: Timestep indices of shape (batch,).
171
+
172
+ Returns:
173
+ Noisy data x_t of same shape as x_0.
174
+ """
175
+ # Gather schedule values for timesteps
176
+ sqrt_alpha = self._gather(self.sqrt_alphas_cumprod, t, x_0)
177
+ sqrt_one_minus_alpha = self._gather(
178
+ self.sqrt_one_minus_alphas_cumprod, t, x_0
179
+ )
180
+
181
+ return sqrt_alpha * x_0 + sqrt_one_minus_alpha * noise
182
+
183
+ def compute_loss_target(
184
+ self,
185
+ x_0: torch.Tensor,
186
+ noise: torch.Tensor,
187
+ t: torch.Tensor,
188
+ ) -> torch.Tensor:
189
+ """Compute the target for the diffusion loss.
190
+
191
+ Depending on prediction_type:
192
+ - 'epsilon': target = noise (predict the noise that was added)
193
+ - 'x0': target = x_0 (predict the clean data directly)
194
+ - 'v': target = v (velocity prediction, combines both)
195
+
196
+ Args:
197
+ x_0: Clean data.
198
+ noise: Noise that was added.
199
+ t: Timestep indices.
200
+
201
+ Returns:
202
+ Target tensor for loss computation.
203
+ """
204
+ if self.prediction_type == "epsilon":
205
+ return noise
206
+ elif self.prediction_type == "x0":
207
+ return x_0
208
+ elif self.prediction_type == "v":
209
+ # v = sqrt(alpha_bar) * noise - sqrt(1 - alpha_bar) * x_0
210
+ sqrt_alpha = self._gather(self.sqrt_alphas_cumprod, t, x_0)
211
+ sqrt_one_minus_alpha = self._gather(
212
+ self.sqrt_one_minus_alphas_cumprod, t, x_0
213
+ )
214
+ return sqrt_alpha * noise - sqrt_one_minus_alpha * x_0
215
+ else:
216
+ raise ValueError(f"Unknown prediction_type: {self.prediction_type}")
217
+
218
+ def predict_x0_from_epsilon(
219
+ self,
220
+ x_t: torch.Tensor,
221
+ epsilon: torch.Tensor,
222
+ t: torch.Tensor,
223
+ ) -> torch.Tensor:
224
+ """Predict x_0 from the model's epsilon prediction.
225
+
226
+ x_0 = (x_t - sqrt(1 - alpha_bar_t) * epsilon) / sqrt(alpha_bar_t)
227
+
228
+ Args:
229
+ x_t: Noisy data.
230
+ epsilon: Predicted noise.
231
+ t: Timestep indices.
232
+
233
+ Returns:
234
+ Predicted clean data x_0.
235
+ """
236
+ sqrt_alpha = self._gather(self.sqrt_alphas_cumprod, t, x_t)
237
+ sqrt_one_minus_alpha = self._gather(
238
+ self.sqrt_one_minus_alphas_cumprod, t, x_t
239
+ )
240
+ return (x_t - sqrt_one_minus_alpha * epsilon) / sqrt_alpha
241
+
242
+ def predict_x0_from_v(
243
+ self,
244
+ x_t: torch.Tensor,
245
+ v: torch.Tensor,
246
+ t: torch.Tensor,
247
+ ) -> torch.Tensor:
248
+ """Predict x_0 from velocity prediction.
249
+
250
+ Args:
251
+ x_t: Noisy data.
252
+ v: Predicted velocity.
253
+ t: Timestep indices.
254
+
255
+ Returns:
256
+ Predicted clean data x_0.
257
+ """
258
+ sqrt_alpha = self._gather(self.sqrt_alphas_cumprod, t, x_t)
259
+ sqrt_one_minus_alpha = self._gather(
260
+ self.sqrt_one_minus_alphas_cumprod, t, x_t
261
+ )
262
+ return sqrt_alpha * x_t - sqrt_one_minus_alpha * v
263
+
264
+ def posterior_mean(
265
+ self,
266
+ x_0: torch.Tensor,
267
+ x_t: torch.Tensor,
268
+ t: torch.Tensor,
269
+ ) -> torch.Tensor:
270
+ """Compute the posterior mean q(x_{t-1} | x_t, x_0).
271
+
272
+ mu = coef1 * x_0 + coef2 * x_t
273
+
274
+ Args:
275
+ x_0: Predicted or actual clean data.
276
+ x_t: Noisy data at timestep t.
277
+ t: Timestep indices.
278
+
279
+ Returns:
280
+ Posterior mean tensor.
281
+ """
282
+ coef1 = self._gather(self.posterior_mean_coef1, t, x_t)
283
+ coef2 = self._gather(self.posterior_mean_coef2, t, x_t)
284
+ return coef1 * x_0 + coef2 * x_t
285
+
286
+ def step_ddpm(
287
+ self,
288
+ model_output: torch.Tensor,
289
+ x_t: torch.Tensor,
290
+ t: torch.Tensor,
291
+ ) -> torch.Tensor:
292
+ """Single DDPM reverse step: x_t -> x_{t-1}.
293
+
294
+ Args:
295
+ model_output: Model prediction (epsilon, x0, or v).
296
+ x_t: Noisy data at timestep t.
297
+ t: Current timestep indices.
298
+
299
+ Returns:
300
+ Denoised data at timestep t-1.
301
+ """
302
+ # Get predicted x_0
303
+ if self.prediction_type == "epsilon":
304
+ x_0_pred = self.predict_x0_from_epsilon(x_t, model_output, t)
305
+ elif self.prediction_type == "x0":
306
+ x_0_pred = model_output
307
+ elif self.prediction_type == "v":
308
+ x_0_pred = self.predict_x0_from_v(x_t, model_output, t)
309
+ else:
310
+ raise ValueError(f"Unknown prediction_type: {self.prediction_type}")
311
+
312
+ # Clamp x_0 prediction for stability
313
+ x_0_pred = x_0_pred.clamp(-5.0, 5.0)
314
+
315
+ # Compute posterior mean
316
+ mean = self.posterior_mean(x_0_pred, x_t, t)
317
+
318
+ # Add noise (except for t=0)
319
+ if t.min() > 0:
320
+ noise = torch.randn_like(x_t)
321
+ # Get posterior variance
322
+ log_variance = self._gather(
323
+ self.posterior_log_variance_clipped, t, x_t
324
+ )
325
+ noise_scale = torch.exp(0.5 * log_variance)
326
+ return mean + noise_scale * noise
327
+ else:
328
+ return mean
329
+
330
+ def step_ddim(
331
+ self,
332
+ model_output: torch.Tensor,
333
+ x_t: torch.Tensor,
334
+ t: int,
335
+ t_prev: int,
336
+ eta: float = 0.0,
337
+ ) -> torch.Tensor:
338
+ """Single DDIM reverse step: x_t -> x_{t_prev}.
339
+
340
+ DDIM is deterministic when eta=0, allowing fewer steps
341
+ at inference time while maintaining quality.
342
+
343
+ Args:
344
+ model_output: Model prediction.
345
+ x_t: Noisy data at timestep t.
346
+ t: Current timestep (scalar).
347
+ t_prev: Previous timestep (scalar, < t).
348
+ eta: Stochasticity parameter (0 = deterministic).
349
+
350
+ Returns:
351
+ Denoised data at timestep t_prev.
352
+ """
353
+ device = x_t.device
354
+ t_tensor = torch.tensor([t], device=device).expand(x_t.shape[0])
355
+
356
+ # Get predicted x_0
357
+ if self.prediction_type == "epsilon":
358
+ x_0_pred = self.predict_x0_from_epsilon(x_t, model_output, t_tensor)
359
+ elif self.prediction_type == "x0":
360
+ x_0_pred = model_output
361
+ elif self.prediction_type == "v":
362
+ x_0_pred = self.predict_x0_from_v(x_t, model_output, t_tensor)
363
+ else:
364
+ raise ValueError(f"Unknown prediction_type: {self.prediction_type}")
365
+
366
+ x_0_pred = x_0_pred.clamp(-5.0, 5.0)
367
+
368
+ # alpha_bar values
369
+ alpha_t = self.alphas_cumprod[t]
370
+ alpha_prev = self.alphas_cumprod[t_prev] if t_prev >= 0 else torch.tensor(1.0, device=device)
371
+
372
+ # Compute sigma
373
+ sigma = eta * torch.sqrt(
374
+ (1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev)
375
+ )
376
+
377
+ # Direction pointing to x_t
378
+ pred_dir = torch.sqrt(1 - alpha_prev - sigma ** 2) * (
379
+ (x_t - torch.sqrt(alpha_t) * x_0_pred) / torch.sqrt(1 - alpha_t)
380
+ )
381
+
382
+ # DDIM update
383
+ x_prev = torch.sqrt(alpha_prev) * x_0_pred + pred_dir
384
+
385
+ if eta > 0 and sigma > 0:
386
+ noise = torch.randn_like(x_t)
387
+ x_prev = x_prev + sigma * noise
388
+
389
+ return x_prev
390
+
391
+ @staticmethod
392
+ def _gather(
393
+ values: torch.Tensor,
394
+ t: torch.Tensor,
395
+ target: torch.Tensor,
396
+ ) -> torch.Tensor:
397
+ """Gather schedule values for timesteps and reshape for broadcasting.
398
+
399
+ Args:
400
+ values: Schedule values of shape (n_timesteps,).
401
+ t: Timestep indices of shape (batch,).
402
+ target: Target tensor to match shape.
403
+
404
+ Returns:
405
+ Gathered values reshaped for broadcasting with target.
406
+ """
407
+ gathered = values.gather(0, t)
408
+ # Reshape to (batch, 1, 1, ...) for broadcasting
409
+ ndim = target.ndim - 1 # minus batch dim
410
+ for _ in range(ndim):
411
+ gathered = gathered.unsqueeze(-1)
412
+ return gathered.expand_as(target)
413
+
414
+ def get_timestep_schedule(self, n_inference_steps: int) -> list[int]:
415
+ """Get evenly-spaced timestep schedule for inference.
416
+
417
+ For DDIM: use a subset of the training timesteps.
418
+
419
+ Args:
420
+ n_inference_steps: Number of inference steps.
421
+
422
+ Returns:
423
+ List of timestep indices in descending order.
424
+ """
425
+ step_size = self.n_timesteps // n_inference_steps
426
+ return list(range(self.n_timesteps - 1, 0, -step_size))
diffusion_llm/requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AAM Diffusion LLM — Dependencies
2
+
3
+ # Core
4
+ torch>=2.0.0
5
+ numpy>=1.24.0
6
+
7
+ # Training
8
+ tensorboard>=2.13.0
9
+
10
+ # Optional (for logging and monitoring)
11
+ # wandb>=0.15.0
12
+
13
+ # Testing
14
+ pytest>=7.4.0
15
+
16
+ # Note: This framework is designed to be lightweight.
17
+ # No heavy ML framework dependencies beyond PyTorch.
diffusion_llm/scripts/evaluate.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ AAM Diffusion LLM — Evaluation Script
4
+
5
+ Evaluates a trained AAM Diffusion Model on test data or
6
+ generates sample narratives from graph conditioning.
7
+
8
+ Usage:
9
+ # Evaluate on test data
10
+ python scripts/evaluate.py --checkpoint output/best.pt
11
+
12
+ # Generate sample narratives
13
+ python scripts/evaluate.py --checkpoint output/best.pt --generate
14
+
15
+ # Interactive mode
16
+ python scripts/evaluate.py --checkpoint output/best.pt --interactive
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import argparse
22
+ import json
23
+ import logging
24
+ import sys
25
+ from pathlib import Path
26
+
27
+ sys.path.insert(0, str(Path(__file__).parent.parent))
28
+
29
+ from diffusion_llm.config.model_config import AamDiffusionConfig
30
+ from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
31
+ from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
32
+ from diffusion_llm.inference.generator import AamGenerator
33
+
34
+ logging.basicConfig(
35
+ level=logging.INFO,
36
+ format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
37
+ )
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
+ def parse_args() -> argparse.Namespace:
42
+ parser = argparse.ArgumentParser(description="Evaluate AAM Diffusion LLM")
43
+ parser.add_argument("--checkpoint", type=str, required=True, help="Model checkpoint path")
44
+ parser.add_argument("--tokenizer", type=str, default=None, help="Tokenizer path")
45
+ parser.add_argument("--generate", action="store_true", help="Generate sample narratives")
46
+ parser.add_argument("--interactive", action="store_true", help="Interactive mode")
47
+ parser.add_argument("--test_data", type=str, default=None, help="Test data path (JSONL)")
48
+ parser.add_argument("--n_steps", type=int, default=50, help="Inference denoising steps")
49
+ parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
50
+ parser.add_argument("--language", type=str, default="id", help="Output language")
51
+ return parser.parse_args()
52
+
53
+
54
+ def generate_samples(generator: AamGenerator, language: str) -> None:
55
+ """Generate sample narratives from predefined graph conditioning."""
56
+ samples = [
57
+ {
58
+ "trigger": "Siapa yang mencuri Snow Plum Pill?",
59
+ "evidence_nodes": ["Hefei", "Diancang Five Swords", "Ju Jangmok", "Gyeryong Merchant Guild"],
60
+ "anomalies": ["Tidak ada konsumsi pil baru di pasar gelap", "Pencuri menghilang tanpa jejak"],
61
+ "reasoning_steps": ["Cross-reference tanggal kejadian", "Deteksi ketidaksesuaian pola"],
62
+ },
63
+ {
64
+ "trigger": "Analisis pergerakan Diancang Five Swords",
65
+ "evidence_nodes": ["Gu Ilmu", "Jang Hangi", "Diancang Five Swords", "Hefei"],
66
+ "anomalies": ["Success rate pair lebih tinggi dari biasanya"],
67
+ "reasoning_steps": ["Recall laporan terkait", "Pattern completion dari bukti"],
68
+ },
69
+ {
70
+ "trigger": "Hubungan antara Ju Jangmok dan pencurian",
71
+ "evidence_nodes": ["Ju Jangmok", "Snow Plum Pill", "dark_faction"],
72
+ "anomalies": ["Ju Jangmok menghilang hari yang sama"],
73
+ "reasoning_steps": ["Eliminasi tersangka obvious", "Verify konsistensi"],
74
+ },
75
+ ]
76
+
77
+ print("\n" + "=" * 60)
78
+ print(" AAM Diffusion LLM — Sample Generation")
79
+ print("=" * 60)
80
+
81
+ for i, sample in enumerate(samples, 1):
82
+ result = generator.generate(
83
+ trigger=sample["trigger"],
84
+ evidence_nodes=sample["evidence_nodes"],
85
+ anomalies=sample["anomalies"],
86
+ reasoning_steps=sample["reasoning_steps"],
87
+ language=language,
88
+ )
89
+ print(f"\n--- Sample {i} ---")
90
+ print(f"Trigger: {sample['trigger']}")
91
+ print(f"Evidence: {', '.join(sample['evidence_nodes'])}")
92
+ print(f"Anomalies: {'; '.join(sample['anomalies'])}")
93
+ print(f"\nGenerated Narrative:")
94
+ print(result.narrative)
95
+ print(f"\n[Steps: {result.n_diffusion_steps}, Time: {result.generation_time_s:.2f}s]")
96
+
97
+
98
+ def interactive_mode(generator: AamGenerator, language: str) -> None:
99
+ """Interactive generation mode."""
100
+ print("\n" + "=" * 60)
101
+ print(" AAM Diffusion LLM — Interactive Mode")
102
+ print(" Type 'quit' to exit")
103
+ print("=" * 60)
104
+
105
+ while True:
106
+ trigger = input("\nTrigger/Question: ").strip()
107
+ if trigger.lower() in ("quit", "exit", "q"):
108
+ break
109
+
110
+ evidence = input("Evidence nodes (comma-separated): ").strip()
111
+ evidence_nodes = [e.strip() for e in evidence.split(",") if e.strip()] if evidence else None
112
+
113
+ anomalies_input = input("Anomalies (comma-separated): ").strip()
114
+ anomalies = [a.strip() for a in anomalies_input.split(",") if a.strip()] if anomalies_input else None
115
+
116
+ result = generator.generate(
117
+ trigger=trigger,
118
+ evidence_nodes=evidence_nodes,
119
+ anomalies=anomalies,
120
+ language=language,
121
+ )
122
+ print(f"\nGenerated Narrative:\n{result.narrative}")
123
+ print(f"\n[Steps: {result.n_diffusion_steps}, Time: {result.generation_time_s:.2f}s, Confidence: {result.confidence:.1%}]")
124
+
125
+
126
+ def main() -> None:
127
+ args = parse_args()
128
+
129
+ # Load model
130
+ logger.info("Loading model from %s", args.checkpoint)
131
+ model = AamDiffusionModel.load(args.checkpoint)
132
+
133
+ # Load or create tokenizer
134
+ if args.tokenizer:
135
+ tokenizer = AamTokenizer.load(args.tokenizer)
136
+ else:
137
+ # Try to find tokenizer in same directory as checkpoint
138
+ tokenizer_path = Path(args.checkpoint).parent / "data" / "tokenizer.json"
139
+ if tokenizer_path.exists():
140
+ tokenizer = AamTokenizer.load(tokenizer_path)
141
+ else:
142
+ logger.warning("No tokenizer found. Using untrained tokenizer.")
143
+ tokenizer = AamTokenizer()
144
+
145
+ # Create generator
146
+ generator = AamGenerator(model, tokenizer, model.config)
147
+
148
+ if args.interactive:
149
+ interactive_mode(generator, args.language)
150
+ elif args.generate:
151
+ generate_samples(generator, args.language)
152
+ else:
153
+ logger.info("Use --generate or --interactive flag")
154
+
155
+
156
+ if __name__ == "__main__":
157
+ main()
diffusion_llm/scripts/export.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ AAM Diffusion LLM — Export Script
4
+
5
+ Export a trained model for deployment.
6
+
7
+ Usage:
8
+ python scripts/export.py --checkpoint output/best.pt --output model_export/
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import argparse
14
+ import logging
15
+ import sys
16
+ from pathlib import Path
17
+
18
+ sys.path.insert(0, str(Path(__file__).parent.parent))
19
+
20
+ from diffusion_llm.config.model_config import AamDiffusionConfig
21
+ from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
22
+ from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
23
+
24
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ def main() -> None:
29
+ parser = argparse.ArgumentParser(description="Export AAM Diffusion Model")
30
+ parser.add_argument("--checkpoint", type=str, required=True)
31
+ parser.add_argument("--output", type=str, default="./model_export")
32
+ parser.add_argument("--format", type=str, default="pt", choices=["pt", "onnx"])
33
+ args = parser.parse_args()
34
+
35
+ output_dir = Path(args.output)
36
+ output_dir.mkdir(parents=True, exist_ok=True)
37
+
38
+ # Load model
39
+ model = AamDiffusionModel.load(args.checkpoint)
40
+ model.eval()
41
+
42
+ # Save model
43
+ model_path = output_dir / "model.pt"
44
+ model.save(str(model_path))
45
+ logger.info("Model exported to %s", model_path)
46
+
47
+ # Save config
48
+ config_path = output_dir / "config.json"
49
+ model.config.to_json(config_path)
50
+ logger.info("Config saved to %s", config_path)
51
+
52
+ # Try to copy tokenizer
53
+ checkpoint_dir = Path(args.checkpoint).parent
54
+ tokenizer_path = checkpoint_dir / "data" / "tokenizer.json"
55
+ if tokenizer_path.exists():
56
+ import shutil
57
+ shutil.copy(tokenizer_path, output_dir / "tokenizer.json")
58
+ logger.info("Tokenizer copied to %s", output_dir / "tokenizer.json")
59
+
60
+ # Summary
61
+ print(f"\nExport complete!")
62
+ print(f" Model: {model_path}")
63
+ print(f" Config: {config_path}")
64
+ print(f" Parameters: {model._format_params(model.get_num_params())}")
65
+ print(f"\n This is AAM's own body — 1 mind + 1 body.")
66
+ print(f" Mind = RSVS Knowledge Graph")
67
+ print(f" Body = This Diffusion Model ({model.config.model_name})")
68
+
69
+
70
+ if __name__ == "__main__":
71
+ main()
diffusion_llm/scripts/train.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ AAM Diffusion LLM — Training Script
4
+
5
+ Main entry point for training the AAM Diffusion Model.
6
+
7
+ Usage:
8
+ # Train with default config (base model)
9
+ python scripts/train.py
10
+
11
+ # Train with specific model size
12
+ python scripts/train.py --model_size small
13
+
14
+ # Train with custom config
15
+ python scripts/train.py --config path/to/config.json
16
+
17
+ # Train with specific data
18
+ python scripts/train.py --train_data path/to/train.jsonl --val_data path/to/val.jsonl
19
+
20
+ Analogi: Seperti Jin Soun memulai latihan fisiknya —
21
+ ini adalah titik awal di mana "tubuh" AAM mulai dilatih.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import argparse
27
+ import logging
28
+ import sys
29
+ from pathlib import Path
30
+
31
+ # Add parent directory to path for imports
32
+ sys.path.insert(0, str(Path(__file__).parent.parent))
33
+
34
+ from diffusion_llm.config.model_config import AamDiffusionConfig, get_default_config
35
+ from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
36
+ from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
37
+ from diffusion_llm.training.trainer import AamTrainer
38
+ from diffusion_llm.training.dataset import GraphNarrativeDataset
39
+ from diffusion_llm.data.data_pipeline import DataPipeline
40
+
41
+ logging.basicConfig(
42
+ level=logging.INFO,
43
+ format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
44
+ )
45
+ logger = logging.getLogger(__name__)
46
+
47
+
48
+ def parse_args() -> argparse.Namespace:
49
+ """Parse command-line arguments."""
50
+ parser = argparse.ArgumentParser(
51
+ description="Train AAM Diffusion LLM",
52
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
53
+ )
54
+
55
+ # Model configuration
56
+ parser.add_argument(
57
+ "--model_size", type=str, default="base",
58
+ choices=["tiny", "small", "base", "medium"],
59
+ help="Model size preset",
60
+ )
61
+ parser.add_argument(
62
+ "--config", type=str, default=None,
63
+ help="Path to custom config JSON (overrides --model_size)",
64
+ )
65
+
66
+ # Data
67
+ parser.add_argument(
68
+ "--train_data", type=str, default=None,
69
+ help="Path to training data (JSONL)",
70
+ )
71
+ parser.add_argument(
72
+ "--val_data", type=str, default=None,
73
+ help="Path to validation data (JSONL)",
74
+ )
75
+ parser.add_argument(
76
+ "--output_dir", type=str, default="./output",
77
+ help="Output directory for checkpoints and logs",
78
+ )
79
+ parser.add_argument(
80
+ "--force_regenerate", action="store_true",
81
+ help="Force regenerate synthetic data",
82
+ )
83
+
84
+ # Training overrides
85
+ parser.add_argument("--batch_size", type=int, default=None)
86
+ parser.add_argument("--learning_rate", type=float, default=None)
87
+ parser.add_argument("--max_steps", type=int, default=None)
88
+ parser.add_argument("--n_timesteps", type=int, default=None)
89
+ parser.add_argument("--seed", type=int, default=42)
90
+
91
+ return parser.parse_args()
92
+
93
+
94
+ def main() -> None:
95
+ """Main training entry point."""
96
+ args = parse_args()
97
+
98
+ # Load or create config
99
+ if args.config:
100
+ config = AamDiffusionConfig.from_json(args.config)
101
+ logger.info("Loaded config from %s", args.config)
102
+ else:
103
+ config = get_default_config(args.model_size)
104
+ logger.info("Using %s model config", args.model_size)
105
+
106
+ # Apply CLI overrides
107
+ if args.output_dir:
108
+ config.output_dir = args.output_dir
109
+ if args.train_data:
110
+ config.training.train_data_path = args.train_data
111
+ if args.val_data:
112
+ config.training.val_data_path = args.val_data
113
+ if args.batch_size:
114
+ config.training.batch_size = args.batch_size
115
+ if args.learning_rate:
116
+ config.training.learning_rate = args.learning_rate
117
+ if args.max_steps:
118
+ config.training.max_steps = args.max_steps
119
+ if args.n_timesteps:
120
+ config.diffusion.n_timesteps = args.n_timesteps
121
+ config.seed = args.seed
122
+
123
+ # Print config summary
124
+ print(config.summary())
125
+
126
+ # Save config
127
+ config_path = Path(config.output_dir) / "config.json"
128
+ config.to_json(config_path)
129
+ logger.info("Config saved to %s", config_path)
130
+
131
+ # Step 1: Prepare data
132
+ pipeline = DataPipeline(config)
133
+ tokenizer, train_loader, val_loader = pipeline.prepare(
134
+ force_regenerate=args.force_regenerate,
135
+ )
136
+
137
+ # Step 2: Create model
138
+ model = AamDiffusionModel(config)
139
+ logger.info(
140
+ "Model created: %s parameters",
141
+ model._format_params(model.get_num_params()),
142
+ )
143
+
144
+ # Step 3: Create datasets (using pre-created loaders)
145
+ train_dataset = train_loader.dataset
146
+ val_dataset = val_loader.dataset if val_loader else None
147
+
148
+ # Step 4: Create trainer and train
149
+ trainer = AamTrainer(
150
+ config=config,
151
+ model=model,
152
+ tokenizer=tokenizer,
153
+ train_dataset=train_dataset,
154
+ val_dataset=val_dataset,
155
+ )
156
+
157
+ # Override data loaders (already created by pipeline)
158
+ trainer.train_loader = train_loader
159
+ trainer.val_loader = val_loader
160
+
161
+ # Start training
162
+ trainer.train()
163
+
164
+ logger.info("Training complete! Output saved to %s", config.output_dir)
165
+
166
+
167
+ if __name__ == "__main__":
168
+ main()
diffusion_llm/scripts/train_final.py ADDED
@@ -0,0 +1,686 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ AAM Diffusion LLM — Final Training Script
4
+
5
+ Trains the complete AAM Diffusion LLM pipeline:
6
+ 1. Generate synthetic training data (Graph→Narrative pairs)
7
+ 2. Train the AAM Sentence-Level + BPE Tokenizer
8
+ 3. Train the Diffusion Transformer model
9
+ 4. Save final model, tokenizer, and config for HuggingFace upload
10
+
11
+ This is the "birth" of AAM's body — from random weights to
12
+ a model that can arrange sentences from graph conditioning.
13
+
14
+ Usage:
15
+ python scripts/train_final.py --output_dir ./aam-diffusion-v1
16
+ python scripts/train_final.py --model_size tiny --max_steps 500
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import argparse
22
+ import json
23
+ import logging
24
+ import sys
25
+ import time
26
+ from pathlib import Path
27
+
28
+ # Add parent directory to path
29
+ sys.path.insert(0, str(Path(__file__).parent.parent))
30
+
31
+ import torch
32
+ import numpy as np
33
+
34
+ from diffusion_llm.config.model_config import (
35
+ AamDiffusionConfig, get_default_config, ModelConfig,
36
+ DiffusionConfig, GraphEncoderConfig, TokenizerConfig,
37
+ TrainingConfig, InferenceConfig,
38
+ )
39
+ from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
40
+ from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
41
+ from diffusion_llm.training.dataset import GraphNarrativeDataset, collate_fn
42
+ from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator
43
+
44
+ logging.basicConfig(
45
+ level=logging.INFO,
46
+ format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
47
+ )
48
+ logger = logging.getLogger("train_final")
49
+
50
+
51
+ def parse_args():
52
+ parser = argparse.ArgumentParser(description="Train AAM Diffusion LLM (Final)")
53
+ parser.add_argument("--model_size", type=str, default="tiny",
54
+ choices=["tiny", "small", "base", "medium"])
55
+ parser.add_argument("--output_dir", type=str, default="./aam-diffusion-v1")
56
+ parser.add_argument("--max_steps", type=int, default=500)
57
+ parser.add_argument("--batch_size", type=int, default=8)
58
+ parser.add_argument("--learning_rate", type=float, default=3e-4)
59
+ parser.add_argument("--n_synthetic_train", type=int, default=500)
60
+ parser.add_argument("--n_synthetic_val", type=int, default=50)
61
+ parser.add_argument("--seed", type=int, default=42)
62
+ parser.add_argument("--log_every", type=int, default=50)
63
+ parser.add_argument("--save_every", type=int, default=500)
64
+ parser.add_argument("--eval_every", type=int, default=200)
65
+ return parser.parse_args()
66
+
67
+
68
+ def set_seed(seed: int):
69
+ """Set random seeds for reproducibility."""
70
+ torch.manual_seed(seed)
71
+ np.random.seed(seed)
72
+ import random
73
+ random.seed(seed)
74
+
75
+
76
+ def generate_data(output_dir: Path, n_train: int, n_val: int, seed: int):
77
+ """Generate synthetic training data."""
78
+ logger.info("=" * 60)
79
+ logger.info("STEP 1: Generating Synthetic Training Data")
80
+ logger.info("=" * 60)
81
+
82
+ data_dir = output_dir / "data"
83
+ data_dir.mkdir(parents=True, exist_ok=True)
84
+
85
+ train_path, val_path = SyntheticDataGenerator.generate_training_split(
86
+ output_dir=data_dir,
87
+ n_train=n_train,
88
+ n_val=n_val,
89
+ language="id",
90
+ seed=seed,
91
+ )
92
+
93
+ logger.info(f" Train data: {train_path} ({n_train} examples)")
94
+ logger.info(f" Val data: {val_path} ({n_val} examples)")
95
+ return train_path, val_path
96
+
97
+
98
+ def train_tokenizer(train_path: Path, output_dir: Path, config: AamDiffusionConfig) -> AamTokenizer:
99
+ """Train the AAM Tokenizer on synthetic data."""
100
+ logger.info("=" * 60)
101
+ logger.info("STEP 2: Training AAM Sentence-Level + BPE Tokenizer")
102
+ logger.info("=" * 60)
103
+
104
+ tokenizer = AamTokenizer(config=config.tokenizer)
105
+
106
+ # Read training texts
107
+ texts = []
108
+ with open(train_path, "r", encoding="utf-8") as f:
109
+ for line in f:
110
+ line = line.strip()
111
+ if not line:
112
+ continue
113
+ try:
114
+ data = json.loads(line)
115
+ if data.get("narrative"):
116
+ texts.append(data["narrative"])
117
+ if data.get("trigger"):
118
+ texts.append(data["trigger"])
119
+ for ev in data.get("evidence_nodes", []):
120
+ texts.append(ev)
121
+ for anom in data.get("anomalies", []):
122
+ texts.append(anom)
123
+ for step in data.get("reasoning_steps", []):
124
+ texts.append(step)
125
+ for comp in data.get("compositions", []):
126
+ texts.append(comp)
127
+ except json.JSONDecodeError:
128
+ continue
129
+
130
+ logger.info(f" Training tokenizer on {len(texts)} texts...")
131
+ tokenizer.train(texts, vocab_size=config.tokenizer.bpe_vocab_size)
132
+
133
+ # Save tokenizer
134
+ tokenizer_path = output_dir / "tokenizer.json"
135
+ tokenizer.save(tokenizer_path)
136
+ logger.info(f" Tokenizer saved: {tokenizer_path}")
137
+ logger.info(f" Vocab size: {tokenizer.vocab_size}")
138
+ logger.info(f" BPE merges: {len(tokenizer.merges)}")
139
+
140
+ return tokenizer
141
+
142
+
143
+ def create_dataloaders(
144
+ train_path: Path, val_path: Path,
145
+ tokenizer: AamTokenizer, config: AamDiffusionConfig
146
+ ):
147
+ """Create training and validation data loaders."""
148
+ logger.info("=" * 60)
149
+ logger.info("STEP 3: Creating DataLoaders")
150
+ logger.info("=" * 60)
151
+
152
+ train_dataset = GraphNarrativeDataset(
153
+ data_path=train_path,
154
+ tokenizer=tokenizer,
155
+ max_seq_len=config.model.max_seq_len,
156
+ max_evidence=config.graph_encoder.max_evidence_nodes,
157
+ max_anomalies=config.graph_encoder.max_anomalies,
158
+ max_reasoning=config.graph_encoder.max_reasoning_steps,
159
+ augment=True,
160
+ )
161
+
162
+ val_dataset = GraphNarrativeDataset(
163
+ data_path=val_path,
164
+ tokenizer=tokenizer,
165
+ max_seq_len=config.model.max_seq_len,
166
+ max_evidence=config.graph_encoder.max_evidence_nodes,
167
+ max_anomalies=config.graph_encoder.max_anomalies,
168
+ max_reasoning=config.graph_encoder.max_reasoning_steps,
169
+ augment=False,
170
+ )
171
+
172
+ from torch.utils.data import DataLoader
173
+
174
+ train_loader = DataLoader(
175
+ train_dataset,
176
+ batch_size=config.training.batch_size,
177
+ shuffle=True,
178
+ num_workers=0, # CPU training: use 0 workers
179
+ collate_fn=collate_fn,
180
+ pin_memory=False, # CPU: no pin_memory
181
+ )
182
+
183
+ val_loader = DataLoader(
184
+ val_dataset,
185
+ batch_size=config.training.batch_size,
186
+ shuffle=False,
187
+ num_workers=0,
188
+ collate_fn=collate_fn,
189
+ pin_memory=False,
190
+ )
191
+
192
+ logger.info(f" Train: {len(train_dataset)} examples, {len(train_loader)} batches")
193
+ logger.info(f" Val: {len(val_dataset)} examples, {len(val_loader)} batches")
194
+
195
+ return train_loader, val_loader
196
+
197
+
198
+ def train_model(
199
+ model: AamDiffusionModel,
200
+ tokenizer: AamTokenizer,
201
+ train_loader,
202
+ val_loader,
203
+ config: AamDiffusionConfig,
204
+ output_dir: Path,
205
+ args,
206
+ ):
207
+ """Train the AAM Diffusion Model."""
208
+ logger.info("=" * 60)
209
+ logger.info("STEP 4: Training AAM Diffusion LLM")
210
+ logger.info("=" * 60)
211
+
212
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
213
+ logger.info(f" Device: {device}")
214
+ logger.info(f" Parameters: {model._format_params(model.get_num_params())}")
215
+
216
+ model.to(device)
217
+
218
+ # Optimizer
219
+ optimizer = torch.optim.AdamW(
220
+ model.parameters(),
221
+ lr=args.learning_rate,
222
+ weight_decay=config.training.weight_decay,
223
+ betas=(config.training.adam_beta1, config.training.adam_beta2),
224
+ )
225
+
226
+ # LR scheduler with warmup
227
+ warmup_steps = min(200, args.max_steps // 10)
228
+
229
+ def lr_lambda(step):
230
+ if step < warmup_steps:
231
+ return step / max(warmup_steps, 1)
232
+ progress = (step - warmup_steps) / max(args.max_steps - warmup_steps, 1)
233
+ return 0.5 * (1.0 + torch.cos(torch.tensor(progress * 3.14159)).item())
234
+
235
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
236
+
237
+ # Training loop
238
+ global_step = 0
239
+ best_val_loss = float("inf")
240
+ train_losses = []
241
+ start_time = time.time()
242
+
243
+ logger.info(f" Max steps: {args.max_steps}")
244
+ logger.info(f" Batch size: {args.batch_size}")
245
+ logger.info(f" Learning rate: {args.learning_rate}")
246
+ logger.info(f" Warmup steps: {warmup_steps}")
247
+ logger.info("")
248
+
249
+ epoch = 0
250
+ while global_step < args.max_steps:
251
+ epoch += 1
252
+ model.train()
253
+ epoch_loss = 0.0
254
+ n_batches = 0
255
+
256
+ for batch_idx, batch in enumerate(train_loader):
257
+ if global_step >= args.max_steps:
258
+ break
259
+
260
+ # Move batch to device
261
+ batch = {
262
+ k: v.to(device) if isinstance(v, torch.Tensor) else v
263
+ for k, v in batch.items()
264
+ }
265
+
266
+ # Sample random timesteps
267
+ batch_size = batch["token_ids"].shape[0]
268
+ t = torch.randint(
269
+ 0, config.diffusion.n_timesteps,
270
+ (batch_size,), device=device,
271
+ )
272
+
273
+ # Forward pass
274
+ predicted, target = model(
275
+ token_ids=batch["token_ids"],
276
+ timestep=t,
277
+ evidence_ids=batch.get("evidence_ids"),
278
+ evidence_confidence=batch.get("evidence_confidence"),
279
+ anomaly_ids=batch.get("anomaly_ids"),
280
+ anomaly_confidence=batch.get("anomaly_confidence"),
281
+ reasoning_ids=batch.get("reasoning_ids"),
282
+ reasoning_confidence=batch.get("reasoning_confidence"),
283
+ source_trust=batch.get("source_trust"),
284
+ )
285
+
286
+ # Compute loss
287
+ loss = model.compute_loss(predicted, target, t)
288
+
289
+ # Backward pass
290
+ optimizer.zero_grad()
291
+ loss.backward()
292
+
293
+ # Gradient clipping
294
+ torch.nn.utils.clip_grad_norm_(
295
+ model.parameters(), config.training.grad_clip_norm
296
+ )
297
+
298
+ optimizer.step()
299
+ scheduler.step()
300
+
301
+ loss_val = loss.item()
302
+ train_losses.append(loss_val)
303
+ epoch_loss += loss_val
304
+ n_batches += 1
305
+ global_step += 1
306
+
307
+ # Logging
308
+ if global_step % args.log_every == 0:
309
+ lr = optimizer.param_groups[0]["lr"]
310
+ avg_loss = sum(train_losses[-args.log_every:]) / len(train_losses[-args.log_every:])
311
+ elapsed = time.time() - start_time
312
+ steps_per_sec = global_step / max(elapsed, 1)
313
+ logger.info(
314
+ f" Step {global_step:>6d}/{args.max_steps} | "
315
+ f"Loss: {avg_loss:.4f} | "
316
+ f"LR: {lr:.2e} | "
317
+ f"Speed: {steps_per_sec:.1f} steps/s"
318
+ )
319
+
320
+ # Evaluation
321
+ if global_step % args.eval_every == 0 and val_loader is not None:
322
+ val_loss = evaluate(model, val_loader, config, device)
323
+ logger.info(f" >>> Validation loss: {val_loss:.4f}")
324
+ if val_loss < best_val_loss:
325
+ best_val_loss = val_loss
326
+ save_model(model, tokenizer, config, output_dir / "best.pt")
327
+ logger.info(f" >>> New best model saved! (val_loss: {val_loss:.4f})")
328
+
329
+ # Checkpoint
330
+ if global_step % args.save_every == 0:
331
+ save_model(model, tokenizer, config, output_dir / f"step_{global_step}.pt")
332
+
333
+ avg_epoch_loss = epoch_loss / max(n_batches, 1)
334
+ logger.info(f" Epoch {epoch} complete. Avg loss: {avg_epoch_loss:.4f}")
335
+
336
+ # Final save
337
+ save_model(model, tokenizer, config, output_dir / "final.pt")
338
+ elapsed = time.time() - start_time
339
+ logger.info("")
340
+ logger.info(f" Training complete! {global_step} steps in {elapsed/60:.1f} minutes")
341
+ logger.info(f" Best val loss: {best_val_loss:.4f}")
342
+ logger.info(f" Final train loss: {train_losses[-1]:.4f}")
343
+
344
+ return model
345
+
346
+
347
+ def evaluate(model, val_loader, config, device):
348
+ """Evaluate on validation set."""
349
+ model.eval()
350
+ total_loss = 0.0
351
+ n_batches = 0
352
+
353
+ with torch.no_grad():
354
+ for batch in val_loader:
355
+ batch = {
356
+ k: v.to(device) if isinstance(v, torch.Tensor) else v
357
+ for k, v in batch.items()
358
+ }
359
+
360
+ batch_size = batch["token_ids"].shape[0]
361
+ t = torch.randint(
362
+ 0, config.diffusion.n_timesteps,
363
+ (batch_size,), device=device,
364
+ )
365
+
366
+ predicted, target = model(
367
+ token_ids=batch["token_ids"],
368
+ timestep=t,
369
+ evidence_ids=batch.get("evidence_ids"),
370
+ evidence_confidence=batch.get("evidence_confidence"),
371
+ anomaly_ids=batch.get("anomaly_ids"),
372
+ anomaly_confidence=batch.get("anomaly_confidence"),
373
+ reasoning_ids=batch.get("reasoning_ids"),
374
+ reasoning_confidence=batch.get("reasoning_confidence"),
375
+ source_trust=batch.get("source_trust"),
376
+ )
377
+ loss = model.compute_loss(predicted, target, t)
378
+ total_loss += loss.item()
379
+ n_batches += 1
380
+
381
+ model.train()
382
+ return total_loss / max(n_batches, 1)
383
+
384
+
385
+ def save_model(model, tokenizer, config, path):
386
+ """Save model checkpoint with tokenizer."""
387
+ path = Path(path)
388
+ path.parent.mkdir(parents=True, exist_ok=True)
389
+
390
+ checkpoint = {
391
+ "model_state_dict": model.state_dict(),
392
+ "config": config.to_dict(),
393
+ }
394
+ torch.save(checkpoint, path)
395
+
396
+
397
+ def export_for_huggingface(model, tokenizer, config, output_dir: Path):
398
+ """Export model in HuggingFace-compatible format."""
399
+ logger.info("=" * 60)
400
+ logger.info("STEP 5: Exporting for HuggingFace")
401
+ logger.info("=" * 60)
402
+
403
+ hf_dir = output_dir / "huggingface"
404
+ hf_dir.mkdir(parents=True, exist_ok=True)
405
+
406
+ # Save model weights
407
+ model_path = hf_dir / "model.pt"
408
+ model.save(str(model_path))
409
+ logger.info(f" Model saved: {model_path}")
410
+
411
+ # Save tokenizer
412
+ tokenizer_path = hf_dir / "tokenizer.json"
413
+ tokenizer.save(tokenizer_path)
414
+ logger.info(f" Tokenizer saved: {tokenizer_path}")
415
+
416
+ # Save config
417
+ config_path = hf_dir / "config.json"
418
+ config.to_json(config_path)
419
+ logger.info(f" Config saved: {config_path}")
420
+
421
+ # Save model card
422
+ model_card = f"""---
423
+ language:
424
+ - id
425
+ - en
426
+ license: mit
427
+ library_name: pytorch
428
+ tags:
429
+ - diffusion
430
+ - text-generation
431
+ - aam
432
+ - aphantasic-abstraction-model
433
+ - sentence-arrangement
434
+ - graph-conditioned
435
+ ---
436
+
437
+ # AAM Diffusion LLM v1.0
438
+
439
+ > **"AAM = 1 Pikiran + 1 Tubuh" (1 Mind + 1 Body)**
440
+
441
+ The dedicated "body" of the Aphantasic Abstraction Model (AAM) — a small diffusion LLM specifically trained to arrange sentences from structured graph data.
442
+
443
+ ## What is this?
444
+
445
+ This is NOT a general-purpose LLM. This is a SPECIALIZED sentence composer that:
446
+ - Takes **graph-structured conditioning** as input (evidence, anomalies, reasoning chains, confidence scores)
447
+ - Produces **coherent natural language narratives** through iterative denoising
448
+ - **Cannot hallucinate** — it can only narrate what the graph knows
449
+
450
+ ## Architecture
451
+
452
+ ```
453
+ Graph Conditioning Encoder → Diffusion Transformer → Noise Scheduler
454
+ (Mind input) (The Body) (Iterative refinement)
455
+ ```
456
+
457
+ ### Key Components
458
+ - **Graph Conditioning Encoder**: Encodes evidence nodes, compositions, anomalies, reasoning chains with confidence and temporal embeddings
459
+ - **Diffusion Transformer**: Core denoising network with adaptive layer norm, self-attention, and cross-attention to graph conditioning
460
+ - **Noise Scheduler**: Cosine noise schedule with DDPM/DDIM sampling support
461
+
462
+ ## Model Details
463
+
464
+ | Parameter | Value |
465
+ |-----------|-------|
466
+ | Architecture | Diffusion Transformer |
467
+ | d_model | {config.model.d_model} |
468
+ | n_layers | {config.model.n_layers} |
469
+ | n_heads | {config.model.n_heads} |
470
+ | d_ff | {config.model.d_ff} |
471
+ | Parameters | {model._format_params(model.get_num_params())} |
472
+ | Vocab size | {config.model.vocab_size} |
473
+ | Max sequence length | {config.model.max_seq_len} |
474
+ | Diffusion timesteps (train) | {config.diffusion.n_timesteps} |
475
+ | Diffusion timesteps (inference) | {config.diffusion.n_inference_steps} |
476
+ | Noise schedule | {config.diffusion.schedule_type} |
477
+ | Prediction type | {config.diffusion.prediction_type} |
478
+ | Sampling method | {config.diffusion.sampling_method} |
479
+
480
+ ## Usage
481
+
482
+ ```python
483
+ from diffusion_llm import AamDiffusionModel, AamTokenizer, AamGenerator, AamDiffusionConfig
484
+
485
+ # Load model
486
+ config = AamDiffusionConfig.from_json("config.json")
487
+ model = AamDiffusionModel.load("model.pt")
488
+ tokenizer = AamTokenizer.load("tokenizer.json")
489
+
490
+ # Create generator
491
+ generator = AamGenerator(model, tokenizer, config)
492
+
493
+ # Generate narrative from graph conditioning
494
+ result = generator.generate(
495
+ trigger="Siapa yang mencuri Snow Plum Pill?",
496
+ evidence_nodes=["Hefei", "Diancang Five Swords", "Ju Jangmok"],
497
+ anomalies=["Tidak ada konsumsi pil baru di pasar gelap"],
498
+ reasoning_steps=["Cross-reference tanggal kejadian"],
499
+ source_trust=0.85,
500
+ )
501
+ print(result.narrative)
502
+ ```
503
+
504
+ ## Philosophy
505
+
506
+ **AAM = 1 Mind + 1 Body**
507
+
508
+ - **Mind** = RSVS Knowledge Graph (structural memory, perfect recall, relational understanding)
509
+ - **Body** = This Diffusion LLM (sentence arranger, graph-conditioned, anti-hallucination)
510
+
511
+ Unlike using a rented LLM (GPT, Claude) as the "body", this model is specifically trained for AAM:
512
+ - It cannot generate information not present in the graph conditioning
513
+ - It arranges sentences based on structured evidence
514
+ - It uses diffusion (non-sequential generation) instead of autoregressive generation
515
+ - It is small ({model._format_params(model.get_num_params())}) but specialized
516
+
517
+ ## Training
518
+
519
+ Trained on synthetic Graph→Narrative pairs with:
520
+ - Indonesian and English narrative templates
521
+ - Evidence nodes, anomalies, reasoning chains
522
+ - Confidence score distributions
523
+ - Source trust scores
524
+
525
+ ## License
526
+
527
+ MIT
528
+ """
529
+ model_card_path = hf_dir / "README.md"
530
+ with open(model_card_path, "w", encoding="utf-8") as f:
531
+ f.write(model_card)
532
+ logger.info(f" Model card saved: {model_card_path}")
533
+
534
+ # Copy full framework code
535
+ import shutil
536
+ framework_src = Path(__file__).parent.parent # diffusion_llm/
537
+ framework_dst = hf_dir / "diffusion_llm"
538
+ if framework_dst.exists():
539
+ shutil.rmtree(framework_dst)
540
+ shutil.copytree(framework_src, framework_dst,
541
+ ignore=shutil.ignore_patterns('__pycache__', '*.pyc', 'output', 'data'))
542
+ logger.info(f" Framework code copied to: {framework_dst}")
543
+
544
+ # Save training script
545
+ train_script_dst = hf_dir / "train.py"
546
+ shutil.copy2(Path(__file__), train_script_dst)
547
+
548
+ # Save inference example
549
+ inference_example = hf_dir / "inference_example.py"
550
+ with open(inference_example, "w", encoding="utf-8") as f:
551
+ f.write('''#!/usr/bin/env python3
552
+ """AAM Diffusion LLM — Inference Example"""
553
+
554
+ import sys
555
+ from pathlib import Path
556
+ sys.path.insert(0, str(Path(__file__).parent))
557
+
558
+ import torch
559
+ from diffusion_llm import AamDiffusionModel, AamTokenizer, AamGenerator, AamDiffusionConfig
560
+
561
+ def main():
562
+ # Load model and tokenizer
563
+ config = AamDiffusionConfig.from_json("config.json")
564
+ model = AamDiffusionModel.load("model.pt", device="cpu")
565
+ tokenizer = AamTokenizer.load("tokenizer.json")
566
+
567
+ # Create generator
568
+ generator = AamGenerator(model, tokenizer, config)
569
+
570
+ # Generate narrative
571
+ result = generator.generate(
572
+ trigger="Siapa yang mencuri Snow Plum Pill?",
573
+ evidence_nodes=["Hefei", "Diancang Five Swords", "Ju Jangmok"],
574
+ anomalies=["Tidak ada konsumsi pil baru di pasar gelap"],
575
+ reasoning_steps=["Cross-reference tanggal kejadian", "Deteksi anomali pola"],
576
+ source_trust=0.85,
577
+ )
578
+
579
+ print("=" * 60)
580
+ print(" AAM Diffusion LLM — Generated Narrative")
581
+ print("=" * 60)
582
+ print(f" Trigger: {result.evidence_used}")
583
+ print(f" Narrative: {result.narrative}")
584
+ print(f" Confidence: {result.confidence:.1%}")
585
+ print(f" Steps: {result.n_diffusion_steps}")
586
+ print(f" Time: {result.generation_time_s:.2f}s")
587
+
588
+ if __name__ == "__main__":
589
+ main()
590
+ ''')
591
+ logger.info(f" Inference example saved: {inference_example}")
592
+
593
+ logger.info(f"\n HuggingFace export complete: {hf_dir}")
594
+ return hf_dir
595
+
596
+
597
+ def main():
598
+ args = parse_args()
599
+ set_seed(args.seed)
600
+
601
+ output_dir = Path(args.output_dir)
602
+ output_dir.mkdir(parents=True, exist_ok=True)
603
+
604
+ print("=" * 60)
605
+ print(" AAM Diffusion LLM — Final Training")
606
+ print(" \"1 Pikiran + 1 Tubuh\" (1 Mind + 1 Body)")
607
+ print("=" * 60)
608
+ print()
609
+
610
+ # Get config
611
+ config = get_default_config(args.model_size)
612
+
613
+ # CPU-optimized overrides for faster training
614
+ config.model.max_seq_len = 128
615
+ config.model.vocab_size = 8000
616
+ config.graph_encoder.max_evidence_nodes = 10
617
+ config.graph_encoder.max_anomalies = 5
618
+ config.graph_encoder.max_reasoning_steps = 5
619
+ config.graph_encoder.max_compositions = 5
620
+ config.diffusion.n_timesteps = 200
621
+ config.diffusion.n_inference_steps = 20
622
+ config.tokenizer.bpe_vocab_size = 8000 - 13 # minus special tokens
623
+
624
+ # Override settings for CPU training
625
+ config.training.batch_size = args.batch_size
626
+ config.training.learning_rate = args.learning_rate
627
+ config.training.max_steps = args.max_steps
628
+ config.training.use_amp = False # No AMP on CPU
629
+ config.training.num_workers = 0 # No multiprocessing on CPU
630
+ config.training.warmup_steps = min(100, args.max_steps // 5)
631
+ config.output_dir = str(output_dir)
632
+ config.seed = args.seed
633
+ config.model_name = "aam-diffusion-v1.0"
634
+
635
+ # Print config
636
+ print(config.summary())
637
+
638
+ # Step 1: Generate synthetic data
639
+ train_path, val_path = generate_data(
640
+ output_dir, args.n_synthetic_train, args.n_synthetic_val, args.seed
641
+ )
642
+
643
+ # Step 2: Train tokenizer
644
+ tokenizer = train_tokenizer(train_path, output_dir, config)
645
+
646
+ # Update vocab_size to match actual tokenizer
647
+ actual_vocab = tokenizer.vocab_size
648
+ if actual_vocab != config.model.vocab_size:
649
+ logger.info(f" Updating vocab_size: {config.model.vocab_size} → {actual_vocab}")
650
+ config.model.vocab_size = actual_vocab
651
+
652
+ # Step 3: Create dataloaders
653
+ train_loader, val_loader = create_dataloaders(
654
+ train_path, val_path, tokenizer, config
655
+ )
656
+
657
+ # Step 4: Create and train model
658
+ model = AamDiffusionModel(config)
659
+ logger.info(f" Model parameters: {model._format_params(model.get_num_params())}")
660
+
661
+ model = train_model(
662
+ model, tokenizer, train_loader, val_loader,
663
+ config, output_dir, args
664
+ )
665
+
666
+ # Step 5: Export for HuggingFace
667
+ hf_dir = export_for_huggingface(model, tokenizer, config, output_dir)
668
+
669
+ # Final summary
670
+ print()
671
+ print("=" * 60)
672
+ print(" TRAINING COMPLETE!")
673
+ print("=" * 60)
674
+ print(f" Model: {config.model_name}")
675
+ print(f" Parameters: {model._format_params(model.get_num_params())}")
676
+ print(f" Output: {output_dir}")
677
+ print(f" HuggingFace export: {hf_dir}")
678
+ print()
679
+ print(" AAM = 1 Pikiran + 1 Tubuh")
680
+ print(" Pikiran = RSVS Knowledge Graph")
681
+ print(" Tubuh = This Diffusion LLM")
682
+ print("=" * 60)
683
+
684
+
685
+ if __name__ == "__main__":
686
+ main()
diffusion_llm/scripts/train_minimal.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ AAM Diffusion LLM — Minimal Training Script for CPU
4
+
5
+ Trains a very small AAM Diffusion LLM model on CPU.
6
+ """
7
+
8
+ import sys
9
+ import json
10
+ import time
11
+ import logging
12
+ from pathlib import Path
13
+
14
+ sys.path.insert(0, str(Path(__file__).parent.parent))
15
+
16
+ import torch
17
+ import numpy as np
18
+
19
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
20
+ logger = logging.getLogger("train")
21
+
22
+ def main():
23
+ from diffusion_llm.config.model_config import (
24
+ AamDiffusionConfig, ModelConfig, DiffusionConfig,
25
+ GraphEncoderConfig, TokenizerConfig, TrainingConfig, InferenceConfig,
26
+ )
27
+ from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
28
+ from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
29
+ from diffusion_llm.training.dataset import GraphNarrativeDataset, collate_fn
30
+ from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator
31
+ from torch.utils.data import DataLoader
32
+
33
+ output_dir = Path("./aam-diffusion-v1")
34
+ output_dir.mkdir(parents=True, exist_ok=True)
35
+ data_dir = output_dir / "data"
36
+ data_dir.mkdir(parents=True, exist_ok=True)
37
+
38
+ # ===== STEP 1: Generate Data =====
39
+ logger.info("STEP 1: Generating synthetic data...")
40
+ train_path, val_path = SyntheticDataGenerator.generate_training_split(
41
+ output_dir=data_dir, n_train=200, n_val=20, language="id", seed=42,
42
+ )
43
+
44
+ # ===== STEP 2: Train Tokenizer =====
45
+ logger.info("STEP 2: Training tokenizer...")
46
+ tokenizer = AamTokenizer()
47
+
48
+ texts = []
49
+ with open(train_path, "r", encoding="utf-8") as f:
50
+ for line in f:
51
+ line = line.strip()
52
+ if not line:
53
+ continue
54
+ try:
55
+ data = json.loads(line)
56
+ for key in ["narrative", "trigger"]:
57
+ if data.get(key):
58
+ texts.append(data[key])
59
+ for key in ["evidence_nodes", "anomalies", "reasoning_steps"]:
60
+ for item in data.get(key, []):
61
+ texts.append(item)
62
+ except json.JSONDecodeError:
63
+ continue
64
+
65
+ tokenizer.train(texts, vocab_size=2000)
66
+ tokenizer.save(data_dir / "tokenizer.json")
67
+ actual_vocab = tokenizer.vocab_size
68
+ logger.info(f" Tokenizer: vocab_size={actual_vocab}, merges={len(tokenizer.merges)}")
69
+
70
+ # ===== STEP 3: Config =====
71
+ config = AamDiffusionConfig(
72
+ model=ModelConfig(
73
+ d_model=128,
74
+ n_layers=2,
75
+ n_heads=4,
76
+ d_ff=256,
77
+ vocab_size=actual_vocab,
78
+ max_seq_len=64,
79
+ pos_encoding_type="learned",
80
+ use_flash_attention=False,
81
+ norm_type="layernorm",
82
+ init_std=0.02,
83
+ ),
84
+ diffusion=DiffusionConfig(
85
+ n_timesteps=100,
86
+ n_inference_steps=10,
87
+ schedule_type="cosine",
88
+ prediction_type="epsilon",
89
+ loss_type="mse",
90
+ loss_weighting="none",
91
+ ),
92
+ graph_encoder=GraphEncoderConfig(
93
+ d_graph=64,
94
+ n_graph_layers=1,
95
+ n_graph_heads=2,
96
+ max_evidence_nodes=5,
97
+ max_compositions=3,
98
+ max_anomalies=3,
99
+ max_reasoning_steps=3,
100
+ conditioning_method="cross_attention",
101
+ embed_confidence=False,
102
+ embed_temporal=False,
103
+ ),
104
+ tokenizer=TokenizerConfig(bpe_vocab_size=2000),
105
+ training=TrainingConfig(
106
+ batch_size=4,
107
+ learning_rate=1e-3,
108
+ max_steps=100,
109
+ warmup_steps=10,
110
+ use_amp=False,
111
+ num_workers=0,
112
+ grad_clip_norm=1.0,
113
+ ),
114
+ inference=InferenceConfig(n_steps=10),
115
+ model_name="aam-diffusion-v1.0",
116
+ output_dir=str(output_dir),
117
+ seed=42,
118
+ )
119
+
120
+ # ===== STEP 4: Create Model =====
121
+ logger.info("STEP 3: Creating model...")
122
+ model = AamDiffusionModel(config)
123
+ n_params = model.get_num_params()
124
+ logger.info(f" Parameters: {model._format_params(n_params)} ({n_params:,})")
125
+
126
+ # ===== STEP 5: Create DataLoaders =====
127
+ logger.info("STEP 4: Creating dataloaders...")
128
+ train_dataset = GraphNarrativeDataset(
129
+ data_path=train_path, tokenizer=tokenizer,
130
+ max_seq_len=config.model.max_seq_len,
131
+ max_evidence=config.graph_encoder.max_evidence_nodes,
132
+ max_anomalies=config.graph_encoder.max_anomalies,
133
+ max_reasoning=config.graph_encoder.max_reasoning_steps,
134
+ augment=True,
135
+ )
136
+ val_dataset = GraphNarrativeDataset(
137
+ data_path=val_path, tokenizer=tokenizer,
138
+ max_seq_len=config.model.max_seq_len,
139
+ max_evidence=config.graph_encoder.max_evidence_nodes,
140
+ max_anomalies=config.graph_encoder.max_anomalies,
141
+ max_reasoning=config.graph_encoder.max_reasoning_steps,
142
+ augment=False,
143
+ )
144
+
145
+ train_loader = DataLoader(
146
+ train_dataset, batch_size=4, shuffle=True,
147
+ num_workers=0, collate_fn=collate_fn,
148
+ )
149
+ val_loader = DataLoader(
150
+ val_dataset, batch_size=4, shuffle=False,
151
+ num_workers=0, collate_fn=collate_fn,
152
+ )
153
+
154
+ # ===== STEP 6: Train =====
155
+ logger.info("STEP 5: Training...")
156
+ device = torch.device("cpu")
157
+ model.to(device)
158
+
159
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
160
+ max_steps = 100
161
+
162
+ start_time = time.time()
163
+ global_step = 0
164
+ train_losses = []
165
+
166
+ for epoch in range(50): # Max epochs
167
+ model.train()
168
+ for batch in train_loader:
169
+ if global_step >= max_steps:
170
+ break
171
+
172
+ batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
173
+ for k, v in batch.items()}
174
+
175
+ batch_size = batch["token_ids"].shape[0]
176
+ t = torch.randint(0, config.diffusion.n_timesteps, (batch_size,), device=device)
177
+
178
+ predicted, target = model(
179
+ token_ids=batch["token_ids"],
180
+ timestep=t,
181
+ evidence_ids=batch.get("evidence_ids"),
182
+ evidence_confidence=batch.get("evidence_confidence"),
183
+ anomaly_ids=batch.get("anomaly_ids"),
184
+ anomaly_confidence=batch.get("anomaly_confidence"),
185
+ reasoning_ids=batch.get("reasoning_ids"),
186
+ reasoning_confidence=batch.get("reasoning_confidence"),
187
+ source_trust=batch.get("source_trust"),
188
+ )
189
+
190
+ loss = model.compute_loss(predicted, target, t)
191
+ optimizer.zero_grad()
192
+ loss.backward()
193
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
194
+ optimizer.step()
195
+
196
+ train_losses.append(loss.item())
197
+ global_step += 1
198
+
199
+ if global_step % 10 == 0:
200
+ avg = sum(train_losses[-10:]) / len(train_losses[-10:])
201
+ elapsed = time.time() - start_time
202
+ logger.info(f" Step {global_step}/{max_steps} | Loss: {avg:.4f} | Time: {elapsed:.1f}s")
203
+
204
+ if global_step >= max_steps:
205
+ break
206
+
207
+ # ===== STEP 7: Evaluate =====
208
+ logger.info("STEP 6: Evaluating...")
209
+ model.eval()
210
+ val_losses = []
211
+ with torch.no_grad():
212
+ for batch in val_loader:
213
+ batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
214
+ for k, v in batch.items()}
215
+ batch_size = batch["token_ids"].shape[0]
216
+ t = torch.randint(0, config.diffusion.n_timesteps, (batch_size,), device=device)
217
+
218
+ predicted, target = model(
219
+ token_ids=batch["token_ids"],
220
+ timestep=t,
221
+ evidence_ids=batch.get("evidence_ids"),
222
+ evidence_confidence=batch.get("evidence_confidence"),
223
+ anomaly_ids=batch.get("anomaly_ids"),
224
+ anomaly_confidence=batch.get("anomaly_confidence"),
225
+ reasoning_ids=batch.get("reasoning_ids"),
226
+ reasoning_confidence=batch.get("reasoning_confidence"),
227
+ source_trust=batch.get("source_trust"),
228
+ )
229
+ loss = model.compute_loss(predicted, target, t)
230
+ val_losses.append(loss.item())
231
+
232
+ avg_val_loss = sum(val_losses) / len(val_losses) if val_losses else 0
233
+ logger.info(f" Val loss: {avg_val_loss:.4f}")
234
+
235
+ # ===== STEP 8: Save =====
236
+ logger.info("STEP 7: Saving model...")
237
+
238
+ # Save model
239
+ model_path = output_dir / "model.pt"
240
+ torch.save({
241
+ "model_state_dict": model.state_dict(),
242
+ "config": config.to_dict(),
243
+ }, model_path)
244
+
245
+ # Save tokenizer (already saved)
246
+ # Save config
247
+ config.to_json(output_dir / "config.json")
248
+
249
+ elapsed = time.time() - start_time
250
+ logger.info(f"\n DONE! {global_step} steps in {elapsed:.1f}s")
251
+ logger.info(f" Final train loss: {train_losses[-1]:.4f}")
252
+ logger.info(f" Val loss: {avg_val_loss:.4f}")
253
+ logger.info(f" Parameters: {model._format_params(n_params)}")
254
+ logger.info(f" Output: {output_dir}")
255
+
256
+ return model, tokenizer, config, output_dir
257
+
258
+
259
+ if __name__ == "__main__":
260
+ main()
diffusion_llm/tests/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Tests for AAM Diffusion LLM framework."""
diffusion_llm/tests/test_model.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for AAM Diffusion Model components."""
2
+
3
+ import torch
4
+ import pytest
5
+ from diffusion_llm.config.model_config import AamDiffusionConfig, get_default_config, ModelConfig
6
+ from diffusion_llm.model.noise_scheduler import NoiseScheduler
7
+ from diffusion_llm.model.graph_encoder import GraphConditioningEncoder, GraphEncoderConfig
8
+ from diffusion_llm.model.diffusion_transformer import DiffusionTransformer
9
+ from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
10
+ from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
11
+
12
+
13
+ class TestConfig:
14
+ """Test configuration system."""
15
+
16
+ def test_default_config(self):
17
+ """Test default configuration creation."""
18
+ config = get_default_config("base")
19
+ assert config.model.d_model == 768
20
+ assert config.model.n_layers == 12
21
+ assert config.diffusion.n_timesteps == 1000
22
+
23
+ def test_tiny_config(self):
24
+ """Test tiny model configuration."""
25
+ config = get_default_config("tiny")
26
+ assert config.model.d_model == 256
27
+ assert config.model.n_layers == 4
28
+
29
+ def test_config_serialization(self, tmp_path):
30
+ """Test config save/load roundtrip."""
31
+ config = get_default_config("small")
32
+ path = tmp_path / "config.json"
33
+ config.to_json(path)
34
+
35
+ loaded = AamDiffusionConfig.from_json(path)
36
+ assert loaded.model.d_model == config.model.d_model
37
+ assert loaded.model.n_layers == config.model.n_layers
38
+
39
+ def test_param_estimation(self):
40
+ """Test parameter count estimation."""
41
+ config = ModelConfig(d_model=768, n_layers=12, d_ff=3072)
42
+ params = config.estimate_params()
43
+ assert "M" in params # Should be in millions
44
+
45
+
46
+ class TestTokenizer:
47
+ """Test AAM Tokenizer."""
48
+
49
+ def test_basic_encoding(self):
50
+ """Test basic text encoding."""
51
+ tokenizer = AamTokenizer()
52
+ # Train on sample text first
53
+ tokenizer.train(["Hello world this is a test", "Another test sentence"])
54
+
55
+ ids = tokenizer.encode("Hello world")
56
+ assert isinstance(ids, list)
57
+ assert len(ids) > 0
58
+ assert ids[0] == tokenizer.bos_id
59
+ assert ids[-1] == tokenizer.eos_id
60
+
61
+ def test_decode_roundtrip(self):
62
+ """Test encode/decode roundtrip."""
63
+ tokenizer = AamTokenizer()
64
+ texts = [
65
+ "Berdasarkan analisis, pencuri adalah Diancang.",
66
+ "Anomali terdeteksi dalam laporan Hefei.",
67
+ "Evidence: Ju Jangmok, Snow Plum Pill.",
68
+ ]
69
+ tokenizer.train(texts)
70
+
71
+ for text in texts:
72
+ ids = tokenizer.encode(text)
73
+ decoded = tokenizer.decode(ids, skip_special=True)
74
+ # Decoded text should contain key words
75
+ assert len(decoded) > 0
76
+
77
+ def test_special_tokens(self):
78
+ """Test special token IDs."""
79
+ tokenizer = AamTokenizer()
80
+ assert tokenizer.pad_id == 0
81
+ assert tokenizer.bos_id == 1
82
+ assert tokenizer.eos_id == 2
83
+
84
+ def test_sentence_boundaries(self):
85
+ """Test sentence boundary detection."""
86
+ tokenizer = AamTokenizer()
87
+ ids = [1, 10, 20, 5, 30, 40, 5, 50, 2] # BOS, sent, sent, EOS
88
+ boundaries = tokenizer.get_sentence_boundaries(ids)
89
+ assert 3 in boundaries # Index of <sent> token
90
+ assert 6 in boundaries
91
+
92
+ def test_save_load(self, tmp_path):
93
+ """Test tokenizer save/load."""
94
+ tokenizer = AamTokenizer()
95
+ tokenizer.train(["Test text for tokenizer", "Another training example"])
96
+
97
+ path = tmp_path / "tokenizer.json"
98
+ tokenizer.save(path)
99
+
100
+ loaded = AamTokenizer.load(path)
101
+ assert loaded.vocab_size == tokenizer.vocab_size
102
+ assert loaded.is_trained
103
+
104
+ def test_structure_encoding(self):
105
+ """Test encoding with graph structure tokens."""
106
+ tokenizer = AamTokenizer()
107
+ tokenizer.train(["Evidence text", "Anomaly description", "Reasoning step"])
108
+
109
+ ids = tokenizer.encode_with_structure(
110
+ text="Main narrative text",
111
+ evidence_nodes=["evidence1", "evidence2"],
112
+ anomalies=["anomaly1"],
113
+ )
114
+ assert isinstance(ids, list)
115
+ assert len(ids) > 0
116
+
117
+ def test_padding(self):
118
+ """Test sequence padding."""
119
+ tokenizer = AamTokenizer()
120
+ ids = [1, 2, 3]
121
+ padded = tokenizer.pad_sequence(ids, max_len=10)
122
+ assert len(padded) == 10
123
+ assert padded[3:] == [0] * 7 # Padded with pad_id
124
+
125
+
126
+ class TestDiffusionTransformer:
127
+ """Test Diffusion Transformer model."""
128
+
129
+ def test_forward_pass(self):
130
+ """Test basic forward pass."""
131
+ config = ModelConfig(
132
+ d_model=128, n_layers=2, n_heads=4, d_ff=256,
133
+ vocab_size=1000, max_seq_len=64,
134
+ )
135
+ model = DiffusionTransformer(config)
136
+
137
+ x_t = torch.randn(2, 32, 128) # batch=2, seq=32, d=128
138
+ t = torch.tensor([100, 500])
139
+
140
+ output = model(x_t=x_t, t=t)
141
+ assert output.shape == (2, 32, 128)
142
+
143
+ def test_with_graph_conditioning(self):
144
+ """Test forward pass with graph conditioning."""
145
+ config = ModelConfig(
146
+ d_model=128, n_layers=2, n_heads=4, d_ff=256,
147
+ vocab_size=1000, max_seq_len=64,
148
+ )
149
+ model = DiffusionTransformer(config)
150
+
151
+ x_t = torch.randn(2, 32, 128)
152
+ t = torch.tensor([100, 500])
153
+ graph_keys = torch.randn(2, 10, 128) # 10 graph nodes
154
+ graph_values = torch.randn(2, 10, 128)
155
+
156
+ output = model(x_t=x_t, t=t, graph_keys=graph_keys, graph_values=graph_values)
157
+ assert output.shape == (2, 32, 128)
158
+
159
+
160
+ class TestAamDiffusionModel:
161
+ """Test complete AAM Diffusion Model."""
162
+
163
+ def test_model_creation_tiny(self):
164
+ """Test creating a tiny model."""
165
+ config = get_default_config("tiny")
166
+ model = AamDiffusionModel(config)
167
+ n_params = model.get_num_params()
168
+ assert n_params > 0
169
+ assert n_params < 100e6 # Tiny should be under 100M
170
+
171
+ def test_forward_training(self):
172
+ """Test training forward pass."""
173
+ config = get_default_config("tiny")
174
+ model = AamDiffusionModel(config)
175
+ model.eval()
176
+
177
+ token_ids = torch.randint(0, config.model.vocab_size, (2, 32))
178
+ timestep = torch.randint(0, config.diffusion.n_timesteps, (2,))
179
+
180
+ with torch.no_grad():
181
+ predicted, noise = model(token_ids=token_ids, timestep=timestep)
182
+
183
+ assert predicted.shape == noise.shape
184
+
185
+ def test_loss_computation(self):
186
+ """Test loss computation."""
187
+ config = get_default_config("tiny")
188
+ model = AamDiffusionModel(config)
189
+ model.eval()
190
+
191
+ token_ids = torch.randint(0, config.model.vocab_size, (2, 32))
192
+ timestep = torch.randint(0, config.diffusion.n_timesteps, (2,))
193
+
194
+ with torch.no_grad():
195
+ predicted, noise = model(token_ids=token_ids, timestep=timestep)
196
+ loss = model.compute_loss(predicted, noise, timestep)
197
+
198
+ assert loss.item() >= 0
199
+ assert not torch.isnan(loss)
200
+
201
+ def test_save_load(self, tmp_path):
202
+ """Test model save/load."""
203
+ config = get_default_config("tiny")
204
+ model = AamDiffusionModel(config)
205
+
206
+ path = str(tmp_path / "model.pt")
207
+ model.save(path)
208
+
209
+ loaded = AamDiffusionModel.load(path)
210
+ assert loaded.config.model.d_model == config.model.d_model
211
+
212
+
213
+ class TestGraphEncoder:
214
+ """Test Graph Conditioning Encoder."""
215
+
216
+ def test_evidence_encoding(self):
217
+ """Test encoding evidence nodes."""
218
+ config = GraphEncoderConfig(d_graph=128, n_graph_layers=2, n_graph_heads=4)
219
+ encoder = GraphConditioningEncoder(config, vocab_size=1000)
220
+
221
+ evidence_ids = torch.randint(0, 1000, (2, 5, 16)) # 2 batch, 5 nodes, 16 tokens each
222
+ evidence_conf = torch.tensor([[0.8, 0.6, 0.9, 0.7, 0.5],
223
+ [0.7, 0.8, 0.6, 0.9, 0.5]])
224
+
225
+ result = encoder(evidence_ids=evidence_ids, evidence_confidence=evidence_conf)
226
+ assert "keys" in result
227
+ assert "values" in result
228
+
229
+ def test_no_input(self):
230
+ """Test encoder with no graph data (should return zeros)."""
231
+ config = GraphEncoderConfig(d_graph=128, n_graph_layers=2, n_graph_heads=4)
232
+ encoder = GraphConditioningEncoder(config, vocab_size=1000)
233
+
234
+ result = encoder()
235
+ assert "keys" in result
236
+
237
+
238
+ if __name__ == "__main__":
239
+ pytest.main([__file__, "-v"])
diffusion_llm/tests/test_scheduler.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for Noise Scheduler."""
2
+
3
+ import torch
4
+ import pytest
5
+ from diffusion_llm.model.noise_scheduler import NoiseScheduler
6
+
7
+
8
+ class TestNoiseScheduler:
9
+ """Test suite for the NoiseScheduler."""
10
+
11
+ def test_cosine_schedule(self):
12
+ """Test cosine noise schedule creation."""
13
+ scheduler = NoiseScheduler(n_timesteps=1000, schedule_type="cosine")
14
+ assert scheduler.betas.shape == (1000,)
15
+ assert (scheduler.betas > 0).all()
16
+ assert (scheduler.betas < 1).all()
17
+
18
+ def test_linear_schedule(self):
19
+ """Test linear noise schedule creation."""
20
+ scheduler = NoiseScheduler(n_timesteps=1000, schedule_type="linear")
21
+ assert scheduler.betas.shape == (1000,)
22
+ assert scheduler.betas[0] < scheduler.betas[-1] # Increasing
23
+
24
+ def test_sigmoid_schedule(self):
25
+ """Test sigmoid noise schedule creation."""
26
+ scheduler = NoiseScheduler(n_timesteps=1000, schedule_type="sigmoid")
27
+ assert scheduler.betas.shape == (1000,)
28
+ assert (scheduler.betas > 0).all()
29
+
30
+ def test_add_noise(self):
31
+ """Test forward diffusion (adding noise)."""
32
+ scheduler = NoiseScheduler(n_timesteps=1000)
33
+ x_0 = torch.randn(2, 10, 64) # batch=2, seq=10, d=64
34
+ noise = torch.randn_like(x_0)
35
+ t = torch.tensor([0, 500])
36
+
37
+ x_t = scheduler.add_noise(x_0, noise, t)
38
+ assert x_t.shape == x_0.shape
39
+ # At t=0, x_t should be close to x_0
40
+ # At t=500, x_t should be significantly different
41
+
42
+ def test_loss_target_epsilon(self):
43
+ """Test epsilon prediction target."""
44
+ scheduler = NoiseScheduler(prediction_type="epsilon")
45
+ x_0 = torch.randn(2, 10, 64)
46
+ noise = torch.randn_like(x_0)
47
+ t = torch.tensor([100, 500])
48
+
49
+ target = scheduler.compute_loss_target(x_0, noise, t)
50
+ assert torch.allclose(target, noise)
51
+
52
+ def test_loss_target_x0(self):
53
+ """Test x0 prediction target."""
54
+ scheduler = NoiseScheduler(prediction_type="x0")
55
+ x_0 = torch.randn(2, 10, 64)
56
+ noise = torch.randn_like(x_0)
57
+ t = torch.tensor([100, 500])
58
+
59
+ target = scheduler.compute_loss_target(x_0, noise, t)
60
+ assert torch.allclose(target, x_0)
61
+
62
+ def test_predict_x0_from_epsilon(self):
63
+ """Test x0 prediction from epsilon."""
64
+ scheduler = NoiseScheduler(prediction_type="epsilon")
65
+ x_0 = torch.randn(2, 10, 64)
66
+ noise = torch.randn_like(x_0)
67
+ t = torch.tensor([100])
68
+
69
+ x_t = scheduler.add_noise(x_0, noise, t)
70
+ x_0_pred = scheduler.predict_x0_from_epsilon(x_t, noise, t)
71
+ # Should be close to original x_0
72
+ assert x_0_pred.shape == x_0.shape
73
+
74
+ def test_ddpm_step(self):
75
+ """Test single DDPM reverse step."""
76
+ scheduler = NoiseScheduler(n_timesteps=1000)
77
+ x_t = torch.randn(2, 10, 64)
78
+ model_output = torch.randn_like(x_t)
79
+ t = torch.tensor([500, 500])
80
+
81
+ x_prev = scheduler.step_ddpm(model_output, x_t, t)
82
+ assert x_prev.shape == x_t.shape
83
+
84
+ def test_ddim_step(self):
85
+ """Test single DDIM reverse step."""
86
+ scheduler = NoiseScheduler(n_timesteps=1000)
87
+ x_t = torch.randn(2, 10, 64)
88
+ model_output = torch.randn_like(x_t)
89
+
90
+ x_prev = scheduler.step_ddim(model_output, x_t, t=500, t_prev=400)
91
+ assert x_prev.shape == x_t.shape
92
+
93
+ def test_timestep_schedule(self):
94
+ """Test inference timestep schedule."""
95
+ scheduler = NoiseScheduler(n_timesteps=1000)
96
+ schedule = scheduler.get_timestep_schedule(n_inference_steps=50)
97
+ assert len(schedule) > 0
98
+ assert schedule[0] > schedule[-1] # Descending order
diffusion_llm/tokenizer/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Tokenizer module for AAM Diffusion LLM."""
2
+
3
+ from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
4
+
5
+ __all__ = ["AamTokenizer"]
diffusion_llm/tokenizer/aam_tokenizer.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AAM Diffusion LLM — Tokenizer
3
+
4
+ Sentence-level + subword BPE hybrid tokenizer designed specifically
5
+ for AAM's sentence arrangement task.
6
+
7
+ Unlike standard tokenizers (GPT-2 BPE, SentencePiece) that tokenize
8
+ at the subword level, AAM's tokenizer is designed with SENTENCE
9
+ ARRANGEMENT in mind:
10
+
11
+ 1. Sentences are the primary unit of generation (not individual tokens)
12
+ 2. Within sentences, subword BPE handles individual words
13
+ 3. Special tokens for graph structure (evidence, anomaly, confidence)
14
+ 4. Sentence boundary markers for the diffusion model
15
+
16
+ The tokenizer maintains two levels:
17
+ - Sentence level: Where sentences begin/end, for the diffusion model
18
+ to arrange and revise non-sequentially
19
+ - Token level: Subword units within sentences, for detailed generation
20
+
21
+ Analogi: Jin Soun tidak berpikir dalam kata-per-kata — dia
22
+ berpikir dalam KALIMAT. "Pencuri = Diancang pair. Ju Jangmok = cover."
23
+ Setiap kalimat sudah utuh, yang dia susun adalah URUTAN kalimat.
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import json
29
+ import re
30
+ import unicodedata
31
+ from collections import Counter
32
+ from pathlib import Path
33
+ from typing import Optional
34
+
35
+ from diffusion_llm.config.model_config import TokenizerConfig
36
+
37
+
38
+ # Special token IDs (always at the start of vocabulary)
39
+ SPECIAL_TOKENS = [
40
+ "<pad>", # 0
41
+ "<bos>", # 1
42
+ "<eos>", # 2
43
+ "<mask>", # 3
44
+ "<noise>", # 4
45
+ "<sent>", # 5 - sentence boundary
46
+ "<evidence>", # 6
47
+ "<anomaly>", # 7
48
+ "<confidence>", # 8
49
+ "<reasoning>", # 9
50
+ "<composition>",# 10
51
+ "<temporal>", # 11
52
+ "<unk>", # 12
53
+ ]
54
+
55
+
56
+ class AamTokenizer:
57
+ """AAM Sentence-Level + Subword BPE Hybrid Tokenizer.
58
+
59
+ This tokenizer is specifically designed for the AAM Diffusion LLM:
60
+ - It understands sentence boundaries (<sent> tokens)
61
+ - It has special tokens for graph structure
62
+ - It uses BPE for subword tokenization within sentences
63
+ - It can encode/decode both plain text and graph-conditioned text
64
+
65
+ Usage:
66
+ tokenizer = AamTokenizer()
67
+ tokenizer.train(texts, vocab_size=28000)
68
+
69
+ # Encode text
70
+ ids = tokenizer.encode("Berdasarkan analisis, pencuri adalah Diancang.")
71
+
72
+ # Decode back
73
+ text = tokenizer.decode(ids)
74
+
75
+ # With graph structure tokens
76
+ ids = tokenizer.encode_with_structure(
77
+ "Pencuri = Diancang pair",
78
+ evidence_nodes=["hefei", "diancang"],
79
+ anomalies=[{"desc": "no external pill consumption"}],
80
+ )
81
+ """
82
+
83
+ def __init__(self, config: Optional[TokenizerConfig] = None):
84
+ """Initialize the tokenizer.
85
+
86
+ Args:
87
+ config: Tokenizer configuration. Uses defaults if None.
88
+ """
89
+ self.config = config or TokenizerConfig()
90
+
91
+ # Build initial vocabulary with special tokens
92
+ self.vocab: dict[str, int] = {}
93
+ self.id_to_token: dict[int, str] = {}
94
+ self._init_special_tokens()
95
+
96
+ # BPE merges (learned during training)
97
+ self.merges: dict[tuple[str, str], int] = {}
98
+ self._bpe_cache: dict[str, str] = {}
99
+
100
+ # Compiled patterns
101
+ self._sentence_pattern = re.compile(
102
+ r'(?<=[.!?])\s+|(?<=\n)\s*'
103
+ )
104
+ self._word_pattern = re.compile(
105
+ r'\w+|[^\w\s]'
106
+ )
107
+
108
+ # Flag
109
+ self._is_trained = False
110
+
111
+ def _init_special_tokens(self) -> None:
112
+ """Initialize special tokens in vocabulary."""
113
+ for i, token in enumerate(SPECIAL_TOKENS):
114
+ self.vocab[token] = i
115
+ self.id_to_token[i] = token
116
+
117
+ @property
118
+ def pad_id(self) -> int:
119
+ return self.vocab[self.config.pad_token]
120
+
121
+ @property
122
+ def bos_id(self) -> int:
123
+ return self.vocab[self.config.bos_token]
124
+
125
+ @property
126
+ def eos_id(self) -> int:
127
+ return self.vocab[self.config.eos_token]
128
+
129
+ @property
130
+ def mask_id(self) -> int:
131
+ return self.vocab[self.config.mask_token]
132
+
133
+ @property
134
+ def noise_id(self) -> int:
135
+ return self.vocab[self.config.noise_token]
136
+
137
+ @property
138
+ def sent_id(self) -> int:
139
+ return self.vocab[self.config.sentence_boundary_token]
140
+
141
+ @property
142
+ def unk_id(self) -> int:
143
+ return self.vocab.get("<unk>", len(SPECIAL_TOKENS) - 1)
144
+
145
+ @property
146
+ def vocab_size(self) -> int:
147
+ """Current vocabulary size."""
148
+ return len(self.vocab)
149
+
150
+ @property
151
+ def is_trained(self) -> bool:
152
+ """Whether the tokenizer has been trained."""
153
+ return self._is_trained
154
+
155
+ def train(
156
+ self,
157
+ texts: list[str],
158
+ vocab_size: Optional[int] = None,
159
+ ) -> None:
160
+ """Train the BPE tokenizer on a corpus.
161
+
162
+ Args:
163
+ texts: List of training texts.
164
+ vocab_size: Target vocabulary size. Uses config if None.
165
+ """
166
+ target_vocab = vocab_size or self.config.bpe_vocab_size
167
+
168
+ # Step 1: Pre-tokenize into words
169
+ word_freqs: Counter = Counter()
170
+ for text in texts:
171
+ words = self._pre_tokenize(text)
172
+ for word in words:
173
+ word_freqs[word] += 1
174
+
175
+ # Step 2: Initialize character-level vocabulary
176
+ char_vocab: set[str] = set()
177
+ for word in word_freqs:
178
+ for char in word:
179
+ char_vocab.add(char)
180
+
181
+ # Add character tokens to vocabulary
182
+ for char in sorted(char_vocab):
183
+ if char not in self.vocab:
184
+ idx = len(self.vocab)
185
+ self.vocab[char] = idx
186
+ self.id_to_token[idx] = char
187
+
188
+ # Step 3: Split words into character sequences
189
+ word_splits: dict[str, list[str]] = {}
190
+ for word in word_freqs:
191
+ word_splits[word] = list(word)
192
+ # Add end-of-word marker
193
+ if len(word_splits[word]) > 1:
194
+ word_splits[word][-1] = word_splits[word][-1] + "</w>"
195
+
196
+ # Step 4: Learn BPE merges
197
+ n_merges = target_vocab - len(self.vocab)
198
+ for i in range(n_merges):
199
+ # Count pairs
200
+ pair_freqs: Counter = Counter()
201
+ for word, freq in word_freqs.items():
202
+ symbols = word_splits.get(word, [])
203
+ for j in range(len(symbols) - 1):
204
+ pair = (symbols[j], symbols[j + 1])
205
+ pair_freqs[pair] += freq
206
+
207
+ if not pair_freqs:
208
+ break
209
+
210
+ # Find most frequent pair
211
+ best_pair = pair_freqs.most_common(1)[0][0]
212
+
213
+ # Record merge
214
+ self.merges[best_pair] = i
215
+
216
+ # Apply merge
217
+ new_symbol = best_pair[0] + best_pair[1]
218
+ for word in word_splits:
219
+ symbols = word_splits[word]
220
+ new_symbols = []
221
+ j = 0
222
+ while j < len(symbols):
223
+ if (
224
+ j < len(symbols) - 1
225
+ and symbols[j] == best_pair[0]
226
+ and symbols[j + 1] == best_pair[1]
227
+ ):
228
+ new_symbols.append(new_symbol)
229
+ j += 2
230
+ else:
231
+ new_symbols.append(symbols[j])
232
+ j += 1
233
+ word_splits[word] = new_symbols
234
+
235
+ # Add merged token to vocabulary
236
+ if new_symbol not in self.vocab:
237
+ idx = len(self.vocab)
238
+ self.vocab[new_symbol] = idx
239
+ self.id_to_token[idx] = new_symbol
240
+
241
+ self._is_trained = True
242
+ self._bpe_cache.clear()
243
+
244
+ def _pre_tokenize(self, text: str) -> list[str]:
245
+ """Pre-tokenize text into words.
246
+
247
+ Args:
248
+ text: Input text.
249
+
250
+ Returns:
251
+ List of words.
252
+ """
253
+ # Normalize unicode
254
+ text = unicodedata.normalize("NFC", text)
255
+ # Split into words and punctuation
256
+ words = self._word_pattern.findall(text.lower())
257
+ return words
258
+
259
+ def _bpe_encode(self, word: str) -> list[str]:
260
+ """Apply BPE to a single word.
261
+
262
+ Args:
263
+ word: Input word (lowercase).
264
+
265
+ Returns:
266
+ List of BPE tokens.
267
+ """
268
+ if word in self._bpe_cache:
269
+ return self._bpe_cache[word].split()
270
+
271
+ # Start with character-level split
272
+ symbols = list(word)
273
+ if len(symbols) > 1:
274
+ symbols[-1] = symbols[-1] + "</w>"
275
+
276
+ # Apply merges in order
277
+ while len(symbols) > 1:
278
+ # Find the pair with the lowest merge rank
279
+ best_pair = None
280
+ best_rank = float("inf")
281
+
282
+ for i in range(len(symbols) - 1):
283
+ pair = (symbols[i], symbols[i + 1])
284
+ rank = self.merges.get(pair, float("inf"))
285
+ if rank < best_rank:
286
+ best_rank = rank
287
+ best_pair = pair
288
+
289
+ if best_pair is None or best_rank == float("inf"):
290
+ break
291
+
292
+ # Apply merge
293
+ new_symbol = best_pair[0] + best_pair[1]
294
+ new_symbols = []
295
+ i = 0
296
+ while i < len(symbols):
297
+ if (
298
+ i < len(symbols) - 1
299
+ and symbols[i] == best_pair[0]
300
+ and symbols[i + 1] == best_pair[1]
301
+ ):
302
+ new_symbols.append(new_symbol)
303
+ i += 2
304
+ else:
305
+ new_symbols.append(symbols[i])
306
+ i += 1
307
+ symbols = new_symbols
308
+
309
+ # Cache result
310
+ self._bpe_cache[word] = " ".join(symbols)
311
+ return symbols
312
+
313
+ def encode(self, text: str, add_special: bool = True) -> list[int]:
314
+ """Encode text to token IDs.
315
+
316
+ The encoding process:
317
+ 1. Split text into sentences
318
+ 2. Insert sentence boundary tokens between sentences
319
+ 3. BPE-encode each word within sentences
320
+ 4. Add BOS/EOS tokens if requested
321
+
322
+ Args:
323
+ text: Input text.
324
+ add_special: Whether to add BOS/EOS tokens.
325
+
326
+ Returns:
327
+ List of token IDs.
328
+ """
329
+ ids = []
330
+
331
+ if add_special:
332
+ ids.append(self.bos_id)
333
+
334
+ # Split into sentences
335
+ sentences = self._split_sentences(text)
336
+
337
+ for i, sentence in enumerate(sentences):
338
+ if i > 0:
339
+ ids.append(self.sent_id) # Sentence boundary
340
+
341
+ # Tokenize words in the sentence
342
+ words = self._pre_tokenize(sentence)
343
+ for word in words:
344
+ if self._is_trained:
345
+ bpe_tokens = self._bpe_encode(word)
346
+ for token in bpe_tokens:
347
+ if token in self.vocab:
348
+ ids.append(self.vocab[token])
349
+ else:
350
+ ids.append(self.unk_id)
351
+ else:
352
+ # Fallback: character-level encoding
353
+ for char in word:
354
+ if char in self.vocab:
355
+ ids.append(self.vocab[char])
356
+ else:
357
+ ids.append(self.unk_id)
358
+
359
+ if add_special:
360
+ ids.append(self.eos_id)
361
+
362
+ return ids
363
+
364
+ def encode_with_structure(
365
+ self,
366
+ text: str,
367
+ evidence_nodes: Optional[list[str]] = None,
368
+ compositions: Optional[list[str]] = None,
369
+ anomalies: Optional[list[str]] = None,
370
+ reasoning_steps: Optional[list[str]] = None,
371
+ confidence: Optional[float] = None,
372
+ ) -> list[int]:
373
+ """Encode text with graph structure tokens.
374
+
375
+ Adds structural tokens that represent the graph conditioning,
376
+ so the model knows what kind of evidence/anomalies it's
377
+ generating from.
378
+
379
+ Args:
380
+ text: The narrative text.
381
+ evidence_nodes: List of evidence node labels.
382
+ compositions: List of composition descriptions.
383
+ anomalies: List of anomaly descriptions.
384
+ reasoning_steps: List of reasoning step descriptions.
385
+ confidence: Overall confidence score.
386
+
387
+ Returns:
388
+ List of token IDs with structure tokens.
389
+ """
390
+ ids = [self.bos_id]
391
+
392
+ # Evidence section
393
+ if evidence_nodes:
394
+ ids.append(self.vocab["<evidence>"])
395
+ for node in evidence_nodes:
396
+ node_ids = self.encode(node, add_special=False)
397
+ ids.extend(node_ids)
398
+ ids.append(self.vocab["<evidence>"]) # Close section
399
+
400
+ # Anomaly section
401
+ if anomalies:
402
+ ids.append(self.vocab["<anomaly>"])
403
+ for anomaly in anomalies:
404
+ anom_ids = self.encode(anomaly, add_special=False)
405
+ ids.extend(anom_ids)
406
+ ids.append(self.vocab["<anomaly>"])
407
+
408
+ # Reasoning section
409
+ if reasoning_steps:
410
+ ids.append(self.vocab["<reasoning>"])
411
+ for step in reasoning_steps:
412
+ step_ids = self.encode(step, add_special=False)
413
+ ids.extend(step_ids)
414
+ ids.append(self.sent_id)
415
+ ids.append(self.vocab["<reasoning>"])
416
+
417
+ # Confidence
418
+ if confidence is not None:
419
+ ids.append(self.vocab["<confidence>"])
420
+ # Encode confidence as a token (discretized)
421
+ conf_bucket = min(int(confidence * 10), 9)
422
+ conf_token = f"<conf_{conf_bucket}>"
423
+ if conf_token in self.vocab:
424
+ ids.append(self.vocab[conf_token])
425
+
426
+ # Composition section
427
+ if compositions:
428
+ ids.append(self.vocab["<composition>"])
429
+ for comp in compositions:
430
+ comp_ids = self.encode(comp, add_special=False)
431
+ ids.extend(comp_ids)
432
+ ids.append(self.sent_id)
433
+ ids.append(self.vocab["<composition>"])
434
+
435
+ # Main narrative
436
+ narrative_ids = self.encode(text, add_special=False)
437
+ ids.extend(narrative_ids)
438
+
439
+ ids.append(self.eos_id)
440
+ return ids
441
+
442
+ def decode(self, ids: list[int], skip_special: bool = False) -> str:
443
+ """Decode token IDs back to text.
444
+
445
+ Args:
446
+ ids: List of token IDs.
447
+ skip_special: Whether to skip special tokens in output.
448
+
449
+ Returns:
450
+ Decoded text string.
451
+ """
452
+ special_ids = set()
453
+ if skip_special:
454
+ for token in SPECIAL_TOKENS:
455
+ if token in self.vocab:
456
+ special_ids.add(self.vocab[token])
457
+
458
+ tokens = []
459
+ for id_ in ids:
460
+ if skip_special and id_ in special_ids:
461
+ continue
462
+ if id_ in self.id_to_token:
463
+ tokens.append(self.id_to_token[id_])
464
+ else:
465
+ tokens.append("<unk>")
466
+
467
+ # Join and clean up BPE tokens
468
+ text = "".join(tokens)
469
+ text = text.replace("</w>", " ")
470
+ # Clean up sentence boundaries
471
+ text = text.replace("<sent>", ". ")
472
+ # Clean up multiple spaces
473
+ text = re.sub(r'\s+', ' ', text).strip()
474
+
475
+ return text
476
+
477
+ def _split_sentences(self, text: str) -> list[str]:
478
+ """Split text into sentences.
479
+
480
+ Args:
481
+ text: Input text.
482
+
483
+ Returns:
484
+ List of sentence strings.
485
+ """
486
+ sentences = self._sentence_pattern.split(text)
487
+ return [s.strip() for s in sentences if s.strip()]
488
+
489
+ def pad_sequence(
490
+ self,
491
+ ids: list[int],
492
+ max_len: int,
493
+ pad_id: Optional[int] = None,
494
+ ) -> list[int]:
495
+ """Pad a sequence to max_len.
496
+
497
+ Args:
498
+ ids: Token IDs.
499
+ max_len: Target length.
500
+ pad_id: Padding token ID. Uses config if None.
501
+
502
+ Returns:
503
+ Padded sequence.
504
+ """
505
+ padding_id = pad_id if pad_id is not None else self.pad_id
506
+ if len(ids) >= max_len:
507
+ return ids[:max_len]
508
+ return ids + [padding_id] * (max_len - len(ids))
509
+
510
+ def get_sentence_boundaries(self, ids: list[int]) -> list[int]:
511
+ """Find sentence boundary positions in a token sequence.
512
+
513
+ This is used by the diffusion model to identify which tokens
514
+ belong to which sentence, enabling non-sequential generation
515
+ and revision at the sentence level.
516
+
517
+ Args:
518
+ ids: Token IDs.
519
+
520
+ Returns:
521
+ List of indices where sentence boundaries occur.
522
+ """
523
+ boundaries = []
524
+ for i, id_ in enumerate(ids):
525
+ if id_ == self.sent_id:
526
+ boundaries.append(i)
527
+ return boundaries
528
+
529
+ def save(self, path: str | Path) -> None:
530
+ """Save tokenizer to file.
531
+
532
+ Args:
533
+ path: Output file path (JSON).
534
+ """
535
+ path = Path(path)
536
+ path.parent.mkdir(parents=True, exist_ok=True)
537
+
538
+ data = {
539
+ "config": {
540
+ "bpe_vocab_size": self.config.bpe_vocab_size,
541
+ "max_sentences": self.config.max_sentences,
542
+ "sentence_boundary_token": self.config.sentence_boundary_token,
543
+ "pad_token": self.config.pad_token,
544
+ "bos_token": self.config.bos_token,
545
+ "eos_token": self.config.eos_token,
546
+ "mask_token": self.config.mask_token,
547
+ "noise_token": self.config.noise_token,
548
+ "min_frequency": self.config.min_frequency,
549
+ },
550
+ "vocab": self.vocab,
551
+ "merges": {f"{k[0]}|||{k[1]}": v for k, v in self.merges.items()},
552
+ "is_trained": self._is_trained,
553
+ }
554
+
555
+ with open(path, "w", encoding="utf-8") as f:
556
+ json.dump(data, f, ensure_ascii=False, indent=2)
557
+
558
+ @classmethod
559
+ def load(cls, path: str | Path) -> AamTokenizer:
560
+ """Load tokenizer from file.
561
+
562
+ Args:
563
+ path: Input file path (JSON).
564
+
565
+ Returns:
566
+ Loaded AamTokenizer.
567
+ """
568
+ with open(path, "r", encoding="utf-8") as f:
569
+ data = json.load(f)
570
+
571
+ config = TokenizerConfig(**data.get("config", {}))
572
+ tokenizer = cls(config=config)
573
+
574
+ # Restore vocabulary
575
+ tokenizer.vocab = data["vocab"]
576
+ tokenizer.id_to_token = {int(v): k for k, v in data["vocab"].items()}
577
+
578
+ # Restore merges
579
+ tokenizer.merges = {}
580
+ for k_str, v in data.get("merges", {}).items():
581
+ parts = k_str.split("|||")
582
+ tokenizer.merges[(parts[0], parts[1])] = v
583
+
584
+ tokenizer._is_trained = data.get("is_trained", False)
585
+
586
+ return tokenizer
587
+
588
+ def __len__(self) -> int:
589
+ return self.vocab_size
590
+
591
+ def __repr__(self) -> str:
592
+ status = "trained" if self._is_trained else "untrained"
593
+ return (
594
+ f"AamTokenizer(vocab_size={self.vocab_size}, "
595
+ f"merges={len(self.merges)}, status={status})"
596
+ )
diffusion_llm/training/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """Training module for AAM Diffusion LLM."""
2
+
3
+ from diffusion_llm.training.trainer import AamTrainer
4
+ from diffusion_llm.training.dataset import GraphNarrativeDataset
5
+ from diffusion_llm.training.losses import DiffusionLoss, compute_loss
6
+
7
+ __all__ = ["AamTrainer", "GraphNarrativeDataset", "DiffusionLoss", "compute_loss"]
diffusion_llm/training/dataset.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AAM Diffusion LLM — Dataset
3
+
4
+ Dataset class for Graph→Narrative training pairs.
5
+
6
+ Each training example consists of:
7
+ - Graph conditioning: evidence nodes, compositions, confidence,
8
+ anomalies, reasoning chains, temporal context
9
+ - Target narrative: natural language text that represents
10
+ the graph data in sentence form
11
+
12
+ The dataset handles:
13
+ - Loading from JSONL files
14
+ - Tokenization of both graph data and narratives
15
+ - Padding and batching
16
+ - Data augmentation (sentence shuffling, noise injection)
17
+
18
+ Analogi: Seperti Jin Soun berlatih mengungkapkan kesimpulan —
19
+ dia diberi "kasus" (graph data) dan "jawaban yang benar"
20
+ (narrative target), lalu berlatih sampai bisa menyusun
21
+ kalimat yang tepat dari graph.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import json
27
+ import logging
28
+ import random
29
+ from dataclasses import dataclass, field
30
+ from pathlib import Path
31
+ from typing import Optional
32
+
33
+ import torch
34
+ from torch.utils.data import Dataset
35
+
36
+ from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
+ @dataclass
42
+ class GraphNarrativeExample:
43
+ """A single training example: graph conditioning + target narrative.
44
+
45
+ This represents the "input" and "expected output" for one
46
+ training step of the diffusion model.
47
+ """
48
+ # Target narrative (what the model should generate)
49
+ narrative: str = ""
50
+
51
+ # Graph conditioning inputs
52
+ trigger: str = ""
53
+ evidence_nodes: list[str] = field(default_factory=list)
54
+ compositions: list[str] = field(default_factory=list)
55
+ confidence_map: dict[str, float] = field(default_factory=dict)
56
+ anomalies: list[str] = field(default_factory=list)
57
+ reasoning_steps: list[str] = field(default_factory=list)
58
+ source_trust: float = 1.0
59
+ temporal_context: list[str] = field(default_factory=list)
60
+
61
+ # Metadata
62
+ language: str = "id"
63
+ source: str = "synthetic"
64
+
65
+ def to_dict(self) -> dict:
66
+ """Serialize to dictionary."""
67
+ return {
68
+ "narrative": self.narrative,
69
+ "trigger": self.trigger,
70
+ "evidence_nodes": self.evidence_nodes,
71
+ "compositions": self.compositions,
72
+ "confidence_map": self.confidence_map,
73
+ "anomalies": self.anomalies,
74
+ "reasoning_steps": self.reasoning_steps,
75
+ "source_trust": self.source_trust,
76
+ "temporal_context": self.temporal_context,
77
+ "language": self.language,
78
+ "source": self.source,
79
+ }
80
+
81
+ @classmethod
82
+ def from_dict(cls, data: dict) -> GraphNarrativeExample:
83
+ """Deserialize from dictionary."""
84
+ return cls(
85
+ narrative=data.get("narrative", ""),
86
+ trigger=data.get("trigger", ""),
87
+ evidence_nodes=data.get("evidence_nodes", []),
88
+ compositions=data.get("compositions", []),
89
+ confidence_map=data.get("confidence_map", {}),
90
+ anomalies=data.get("anomalies", []),
91
+ reasoning_steps=data.get("reasoning_steps", []),
92
+ source_trust=data.get("source_trust", 1.0),
93
+ temporal_context=data.get("temporal_context", []),
94
+ language=data.get("language", "id"),
95
+ source=data.get("source", "synthetic"),
96
+ )
97
+
98
+
99
+ @dataclass
100
+ class BatchOutput:
101
+ """Output from a single batch.
102
+
103
+ All tensors are already padded to uniform length.
104
+ """
105
+ token_ids: torch.Tensor
106
+ """Target narrative token IDs, shape (batch, seq_len)."""
107
+
108
+ evidence_ids: Optional[torch.Tensor] = None
109
+ """Evidence node token IDs, shape (batch, n_evidence, ev_seq_len)."""
110
+
111
+ evidence_confidence: Optional[torch.Tensor] = None
112
+ """Evidence confidence, shape (batch, n_evidence)."""
113
+
114
+ anomaly_ids: Optional[torch.Tensor] = None
115
+ """Anomaly token IDs, shape (batch, n_anomalies, an_seq_len)."""
116
+
117
+ anomaly_confidence: Optional[torch.Tensor] = None
118
+ """Anomaly confidence, shape (batch, n_anomalies)."""
119
+
120
+ reasoning_ids: Optional[torch.Tensor] = None
121
+ """Reasoning step token IDs, shape (batch, n_steps, r_seq_len)."""
122
+
123
+ reasoning_confidence: Optional[torch.Tensor] = None
124
+ """Reasoning confidence, shape (batch, n_steps)."""
125
+
126
+ source_trust: Optional[torch.Tensor] = None
127
+ """Source trust scores, shape (batch,)."""
128
+
129
+
130
+ class GraphNarrativeDataset(Dataset):
131
+ """Dataset for Graph→Narrative training pairs.
132
+
133
+ Loads training examples from JSONL files and provides
134
+ tokenized, padded batches for training.
135
+
136
+ Args:
137
+ data_path: Path to JSONL file with training data.
138
+ tokenizer: AamTokenizer instance for encoding.
139
+ max_seq_len: Maximum sequence length for narratives.
140
+ max_evidence: Maximum number of evidence nodes.
141
+ max_anomalies: Maximum number of anomalies.
142
+ max_reasoning: Maximum number of reasoning steps.
143
+ augment: Whether to apply data augmentation.
144
+ """
145
+
146
+ def __init__(
147
+ self,
148
+ data_path: str | Path,
149
+ tokenizer: AamTokenizer,
150
+ max_seq_len: int = 512,
151
+ max_evidence: int = 50,
152
+ max_anomalies: int = 10,
153
+ max_reasoning: int = 15,
154
+ augment: bool = True,
155
+ ):
156
+ self.data_path = Path(data_path)
157
+ self.tokenizer = tokenizer
158
+ self.max_seq_len = max_seq_len
159
+ self.max_evidence = max_evidence
160
+ self.max_anomalies = max_anomalies
161
+ self.max_reasoning = max_reasoning
162
+ self.augment = augment
163
+
164
+ # Load data
165
+ self.examples: list[GraphNarrativeExample] = []
166
+ self._load_data()
167
+
168
+ logger.info(
169
+ "GraphNarrativeDataset: %d examples loaded from %s",
170
+ len(self.examples),
171
+ self.data_path,
172
+ )
173
+
174
+ def _load_data(self) -> None:
175
+ """Load examples from JSONL file."""
176
+ if not self.data_path.exists():
177
+ logger.warning("Data file not found: %s", self.data_path)
178
+ return
179
+
180
+ with open(self.data_path, "r", encoding="utf-8") as f:
181
+ for line_num, line in enumerate(f, 1):
182
+ line = line.strip()
183
+ if not line:
184
+ continue
185
+ try:
186
+ data = json.loads(line)
187
+ example = GraphNarrativeExample.from_dict(data)
188
+ if example.narrative: # Skip empty narratives
189
+ self.examples.append(example)
190
+ except json.JSONDecodeError:
191
+ logger.warning("Invalid JSON at line %d", line_num)
192
+
193
+ def __len__(self) -> int:
194
+ return len(self.examples)
195
+
196
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
197
+ """Get a single training example.
198
+
199
+ Returns:
200
+ Dictionary with tokenized inputs.
201
+ """
202
+ example = self.examples[idx]
203
+
204
+ # Data augmentation
205
+ if self.augment:
206
+ example = self._augment(example)
207
+
208
+ # Tokenize narrative (target)
209
+ narrative_ids = self.tokenizer.encode(example.narrative, add_special=True)
210
+ narrative_ids = self.tokenizer.pad_sequence(narrative_ids, self.max_seq_len)
211
+ narrative_tensor = torch.tensor(narrative_ids, dtype=torch.long)
212
+
213
+ # Tokenize evidence nodes
214
+ evidence_data = self._tokenize_node_list(
215
+ example.evidence_nodes, max_nodes=self.max_evidence
216
+ )
217
+
218
+ # Tokenize anomalies
219
+ anomaly_data = self._tokenize_node_list(
220
+ example.anomalies, max_nodes=self.max_anomalies
221
+ )
222
+
223
+ # Tokenize reasoning steps
224
+ reasoning_data = self._tokenize_node_list(
225
+ example.reasoning_steps, max_nodes=self.max_reasoning
226
+ )
227
+
228
+ # Source trust
229
+ source_trust = torch.tensor(example.source_trust, dtype=torch.float32)
230
+
231
+ # Evidence confidence
232
+ conf_values = list(example.confidence_map.values())[:self.max_evidence]
233
+ if conf_values:
234
+ evidence_conf = torch.tensor(conf_values, dtype=torch.float32)
235
+ evidence_conf = torch.nn.functional.pad(
236
+ evidence_conf, (0, self.max_evidence - len(conf_values))
237
+ )
238
+ else:
239
+ evidence_conf = torch.zeros(self.max_evidence, dtype=torch.float32)
240
+
241
+ # Anomaly confidence (default 0.6 for detected anomalies)
242
+ anomaly_conf = torch.full(
243
+ (self.max_anomalies,), 0.6, dtype=torch.float32
244
+ )
245
+
246
+ # Reasoning confidence (default 0.7)
247
+ reasoning_conf = torch.full(
248
+ (self.max_reasoning,), 0.7, dtype=torch.float32
249
+ )
250
+
251
+ return {
252
+ "token_ids": narrative_tensor,
253
+ "evidence_ids": evidence_data["ids"],
254
+ "evidence_confidence": evidence_conf,
255
+ "anomaly_ids": anomaly_data["ids"],
256
+ "anomaly_confidence": anomaly_conf,
257
+ "reasoning_ids": reasoning_data["ids"],
258
+ "reasoning_confidence": reasoning_conf,
259
+ "source_trust": source_trust,
260
+ }
261
+
262
+ def _tokenize_node_list(
263
+ self,
264
+ nodes: list[str],
265
+ max_nodes: int,
266
+ max_node_len: int = 32,
267
+ ) -> dict[str, torch.Tensor]:
268
+ """Tokenize a list of node descriptions.
269
+
270
+ Args:
271
+ nodes: List of node text descriptions.
272
+ max_nodes: Maximum number of nodes to encode.
273
+ max_node_len: Maximum token length per node.
274
+
275
+ Returns:
276
+ Dictionary with padded token IDs tensor.
277
+ """
278
+ if not nodes:
279
+ return {
280
+ "ids": torch.zeros(max_nodes, max_node_len, dtype=torch.long),
281
+ }
282
+
283
+ # Limit to max_nodes
284
+ nodes = nodes[:max_nodes]
285
+
286
+ all_ids = []
287
+ for node in nodes:
288
+ ids = self.tokenizer.encode(node, add_special=False)
289
+ ids = self.tokenizer.pad_sequence(ids, max_node_len)
290
+ all_ids.append(ids)
291
+
292
+ # Pad to max_nodes
293
+ while len(all_ids) < max_nodes:
294
+ all_ids.append([0] * max_node_len)
295
+
296
+ return {
297
+ "ids": torch.tensor(all_ids, dtype=torch.long),
298
+ }
299
+
300
+ def _augment(self, example: GraphNarrativeExample) -> GraphNarrativeExample:
301
+ """Apply data augmentation.
302
+
303
+ Augmentation strategies:
304
+ 1. Random sentence shuffling within the narrative
305
+ 2. Random evidence node dropping (simulate incomplete data)
306
+ 3. Random confidence perturbation
307
+
308
+ Args:
309
+ example: Original training example.
310
+
311
+ Returns:
312
+ Augmented example.
313
+ """
314
+ import copy
315
+ augmented = copy.deepcopy(example)
316
+
317
+ # 1. Sentence shuffling (with 20% probability)
318
+ if random.random() < 0.2:
319
+ sentences = self.tokenizer._split_sentences(augmented.narrative)
320
+ if len(sentences) > 2:
321
+ # Keep first sentence, shuffle the rest
322
+ first = sentences[0]
323
+ rest = sentences[1:]
324
+ random.shuffle(rest)
325
+ augmented.narrative = first + " " + " ".join(rest)
326
+
327
+ # 2. Evidence dropping (with 10% probability per node)
328
+ if augmented.evidence_nodes:
329
+ augmented.evidence_nodes = [
330
+ node for node in augmented.evidence_nodes
331
+ if random.random() > 0.1
332
+ ]
333
+
334
+ # 3. Confidence perturbation
335
+ if augmented.confidence_map:
336
+ perturbed = {}
337
+ for k, v in augmented.confidence_map.items():
338
+ noise = random.gauss(0, 0.05)
339
+ perturbed[k] = max(0.0, min(1.0, v + noise))
340
+ augmented.confidence_map = perturbed
341
+
342
+ return augmented
343
+
344
+
345
+ def collate_fn(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
346
+ """Custom collate function for DataLoader.
347
+
348
+ Handles variable-length graph conditioning by padding
349
+ all tensors in the batch to the same size.
350
+
351
+ Args:
352
+ batch: List of example dictionaries.
353
+
354
+ Returns:
355
+ Batched dictionary of tensors.
356
+ """
357
+ result = {}
358
+
359
+ # Stack all tensors
360
+ for key in batch[0]:
361
+ tensors = [item[key] for item in batch]
362
+ if tensors[0].dim() == 0:
363
+ result[key] = torch.stack(tensors)
364
+ elif tensors[0].dim() == 1:
365
+ result[key] = torch.stack(tensors)
366
+ elif tensors[0].dim() == 2:
367
+ result[key] = torch.stack(tensors)
368
+ else:
369
+ result[key] = torch.stack(tensors)
370
+
371
+ return result
diffusion_llm/training/losses.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AAM Diffusion LLM — Loss Functions
3
+
4
+ Implements various loss functions for training the diffusion model,
5
+ including MSE, MAE, Huber, and weighted variants.
6
+
7
+ Analogi: Seperti Jin Soun mengukur seberapa jauh prediksinya
8
+ dari kenyataan — semakin besar gap, semakin besar "rasa sakit"
9
+ yang mendorong perbaikan.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ from diffusion_llm.config.model_config import DiffusionConfig
19
+
20
+
21
+ class DiffusionLoss(nn.Module):
22
+ """Loss function for diffusion model training.
23
+
24
+ Computes the loss between predicted and target values,
25
+ with optional weighting strategies to balance training
26
+ across different noise levels.
27
+
28
+ Args:
29
+ config: DiffusionConfig with loss hyperparameters.
30
+ """
31
+
32
+ def __init__(self, config: DiffusionConfig):
33
+ super().__init__()
34
+ self.config = config
35
+
36
+ def forward(
37
+ self,
38
+ predicted: torch.Tensor,
39
+ target: torch.Tensor,
40
+ timestep: torch.Tensor,
41
+ alphas_cumprod: torch.Tensor,
42
+ ) -> torch.Tensor:
43
+ """Compute diffusion loss.
44
+
45
+ Args:
46
+ predicted: Model output (predicted noise/x0/v).
47
+ target: Target values.
48
+ timestep: Timestep indices for weighting.
49
+ alphas_cumprod: Cumulative product of alphas from scheduler.
50
+
51
+ Returns:
52
+ Scalar loss value.
53
+ """
54
+ # Base loss
55
+ if self.config.loss_type == "mse":
56
+ loss = F.mse_loss(predicted, target, reduction="none")
57
+ elif self.config.loss_type == "mae":
58
+ loss = F.l1_loss(predicted, target, reduction="none")
59
+ elif self.config.loss_type == "huber":
60
+ loss = F.smooth_l1_loss(predicted, target, reduction="none")
61
+ else:
62
+ raise ValueError(f"Unknown loss_type: {self.config.loss_type}")
63
+
64
+ # Average over feature dimension
65
+ loss = loss.mean(dim=-1) # (batch, seq_len)
66
+
67
+ # Apply weighting
68
+ if self.config.loss_weighting == "min_snr":
69
+ loss = self._min_snr_weight(loss, timestep, alphas_cumprod)
70
+ elif self.config.loss_weighting == "p2":
71
+ loss = self._p2_weight(loss, timestep, alphas_cumprod)
72
+
73
+ return loss.mean()
74
+
75
+ def _min_snr_weight(
76
+ self,
77
+ loss: torch.Tensor,
78
+ timestep: torch.Tensor,
79
+ alphas_cumprod: torch.Tensor,
80
+ gamma: float = 5.0,
81
+ ) -> torch.Tensor:
82
+ """Min-SNR-gamma weighting (Hang et al., 2023)."""
83
+ snr = alphas_cumprod[timestep] / (1 - alphas_cumprod[timestep] + 1e-8)
84
+ weight = torch.clamp(snr, max=gamma) / (snr + 1e-8)
85
+ weight = weight.unsqueeze(-1).expand_as(loss)
86
+ return loss * weight
87
+
88
+ def _p2_weight(
89
+ self,
90
+ loss: torch.Tensor,
91
+ timestep: torch.Tensor,
92
+ alphas_cumprod: torch.Tensor,
93
+ ) -> torch.Tensor:
94
+ """P2 weighting (Choi et al., 2022)."""
95
+ snr = alphas_cumprod[timestep] / (1 - alphas_cumprod[timestep] + 1e-8)
96
+ weight = 1.0 / (snr ** self.config.p2_gamma + self.config.p2_k)
97
+ weight = weight.unsqueeze(-1).expand_as(loss)
98
+ return loss * weight
99
+
100
+
101
+ def compute_loss(
102
+ predicted: torch.Tensor,
103
+ target: torch.Tensor,
104
+ timestep: torch.Tensor,
105
+ alphas_cumprod: torch.Tensor,
106
+ loss_type: str = "mse",
107
+ loss_weighting: str = "none",
108
+ ) -> torch.Tensor:
109
+ """Convenience function to compute diffusion loss without creating a module.
110
+
111
+ Args:
112
+ predicted: Model output.
113
+ target: Target values.
114
+ timestep: Timestep indices.
115
+ alphas_cumprod: Alpha cumulative products.
116
+ loss_type: Loss function type.
117
+ loss_weighting: Weighting strategy.
118
+
119
+ Returns:
120
+ Scalar loss value.
121
+ """
122
+ config = DiffusionConfig(
123
+ loss_type=loss_type,
124
+ loss_weighting=loss_weighting,
125
+ )
126
+ loss_fn = DiffusionLoss(config)
127
+ return loss_fn(predicted, target, timestep, alphas_cumprod)
diffusion_llm/training/trainer.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AAM Diffusion LLM — Trainer
3
+
4
+ Training loop for the AAM Diffusion Model.
5
+
6
+ Handles:
7
+ - Training loop with gradient accumulation
8
+ - Learning rate scheduling with warmup
9
+ - Mixed precision training (AMP)
10
+ - EMA model updates
11
+ - Checkpoint saving/loading
12
+ - Logging to console and Weights & Biases
13
+ - Evaluation on validation set
14
+
15
+ Analogi: Seperti latihan fisik Jin Soun — berulang-ulang,
16
+ bertahap meningkat intensitas, dengan instruktur yang
17
+ mengawasi dan memberi koreksi.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ import logging
24
+ import math
25
+ import time
26
+ from pathlib import Path
27
+ from typing import Optional
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ from torch.utils.data import DataLoader
32
+
33
+ from diffusion_llm.config.model_config import AamDiffusionConfig
34
+ from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
35
+ from diffusion_llm.training.dataset import GraphNarrativeDataset, collate_fn
36
+ from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
37
+ from diffusion_llm.training.losses import DiffusionLoss
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
+ class AamTrainer:
43
+ """Trainer for the AAM Diffusion Model.
44
+
45
+ Args:
46
+ config: AamDiffusionConfig with training settings.
47
+ model: AamDiffusionModel instance.
48
+ tokenizer: AamTokenizer instance.
49
+ train_dataset: Training dataset.
50
+ val_dataset: Optional validation dataset.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ config: AamDiffusionConfig,
56
+ model: AamDiffusionModel,
57
+ tokenizer: AamTokenizer,
58
+ train_dataset: GraphNarrativeDataset,
59
+ val_dataset: Optional[GraphNarrativeDataset] = None,
60
+ ):
61
+ self.config = config
62
+ self.model = model
63
+ self.tokenizer = tokenizer
64
+ self.train_dataset = train_dataset
65
+ self.val_dataset = val_dataset
66
+
67
+ # Device
68
+ self.device = torch.device(
69
+ "cuda" if torch.cuda.is_available() else "cpu"
70
+ )
71
+ self.model.to(self.device)
72
+ logger.info("Training on device: %s", self.device)
73
+
74
+ # Optimizer
75
+ self.optimizer = torch.optim.AdamW(
76
+ self.model.parameters(),
77
+ lr=config.training.learning_rate,
78
+ weight_decay=config.training.weight_decay,
79
+ betas=(config.training.adam_beta1, config.training.adam_beta2),
80
+ eps=config.training.adam_eps,
81
+ )
82
+
83
+ # Loss function
84
+ self.loss_fn = DiffusionLoss(config.diffusion)
85
+
86
+ # Data loaders
87
+ self.train_loader = DataLoader(
88
+ train_dataset,
89
+ batch_size=config.training.batch_size,
90
+ shuffle=True,
91
+ num_workers=config.training.num_workers,
92
+ collate_fn=collate_fn,
93
+ pin_memory=True,
94
+ )
95
+
96
+ if val_dataset:
97
+ self.val_loader = DataLoader(
98
+ val_dataset,
99
+ batch_size=config.training.batch_size,
100
+ shuffle=False,
101
+ num_workers=config.training.num_workers,
102
+ collate_fn=collate_fn,
103
+ pin_memory=True,
104
+ )
105
+ else:
106
+ self.val_loader = None
107
+
108
+ # LR scheduler
109
+ self.scheduler = self._create_lr_scheduler()
110
+
111
+ # AMP
112
+ self.scaler = None
113
+ if config.training.use_amp:
114
+ dtype = torch.bfloat16 if config.training.amp_dtype == "bf16" else torch.float16
115
+ self.scaler = torch.amp.GradScaler("cuda", enabled=(dtype == torch.float16))
116
+
117
+ # EMA
118
+ self.ema_model = None
119
+ if config.training.use_ema:
120
+ self.ema_model = self._create_ema_model()
121
+
122
+ # State tracking
123
+ self.global_step = 0
124
+ self.best_val_loss = float("inf")
125
+ self.train_losses: list[float] = []
126
+
127
+ # Output directory
128
+ self.output_dir = Path(config.output_dir)
129
+ self.output_dir.mkdir(parents=True, exist_ok=True)
130
+
131
+ # Seed
132
+ torch.manual_seed(config.seed)
133
+
134
+ def _create_lr_scheduler(self):
135
+ """Create learning rate scheduler with warmup."""
136
+ total_steps = self.config.training.max_steps
137
+ warmup_steps = self.config.training.warmup_steps
138
+
139
+ def lr_lambda(step: int) -> float:
140
+ if step < warmup_steps:
141
+ return step / max(warmup_steps, 1)
142
+ if self.config.training.lr_schedule == "cosine":
143
+ progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
144
+ return 0.5 * (1.0 + math.cos(math.pi * progress))
145
+ elif self.config.training.lr_schedule == "linear":
146
+ progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
147
+ return 1.0 - progress
148
+ else:
149
+ return 1.0
150
+
151
+ return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
152
+
153
+ def _create_ema_model(self) -> AamDiffusionModel:
154
+ """Create EMA copy of the model."""
155
+ import copy
156
+ ema = copy.deepcopy(self.model)
157
+ for param in ema.parameters():
158
+ param.requires_grad = False
159
+ return ema
160
+
161
+ @torch.no_grad()
162
+ def _update_ema(self) -> None:
163
+ """Update EMA model weights."""
164
+ if self.ema_model is None:
165
+ return
166
+ decay = self.config.training.ema_decay
167
+ for ema_param, model_param in zip(
168
+ self.ema_model.parameters(), self.model.parameters()
169
+ ):
170
+ ema_param.data.mul_(decay).add_(model_param.data, alpha=1 - decay)
171
+
172
+ def train(self) -> None:
173
+ """Main training loop.
174
+
175
+ Runs for max_steps or max_epochs, whichever comes first.
176
+ Saves checkpoints and runs evaluation periodically.
177
+ """
178
+ logger.info("Starting training...")
179
+ logger.info(" Max steps: %d", self.config.training.max_steps)
180
+ logger.info(" Batch size: %d", self.config.training.batch_size)
181
+ logger.info(" Gradient accumulation: %d", self.config.training.gradient_accumulation_steps)
182
+ logger.info(" Effective batch size: %d",
183
+ self.config.training.batch_size * self.config.training.gradient_accumulation_steps)
184
+
185
+ start_time = time.time()
186
+ epoch = 0
187
+
188
+ while self.global_step < self.config.training.max_steps:
189
+ epoch += 1
190
+ if epoch > self.config.training.max_epochs:
191
+ break
192
+
193
+ logger.info("=== Epoch %d ===", epoch)
194
+ epoch_loss = 0.0
195
+ n_batches = 0
196
+
197
+ for batch_idx, batch in enumerate(self.train_loader):
198
+ loss = self._train_step(batch)
199
+ epoch_loss += loss
200
+ n_batches += 1
201
+
202
+ # Logging
203
+ if self.global_step % self.config.training.log_every_steps == 0:
204
+ avg_loss = epoch_loss / max(n_batches, 1)
205
+ lr = self.optimizer.param_groups[0]["lr"]
206
+ elapsed = time.time() - start_time
207
+ steps_per_sec = self.global_step / max(elapsed, 1)
208
+
209
+ logger.info(
210
+ "Step %d | Loss: %.4f | LR: %.2e | Speed: %.1f steps/s",
211
+ self.global_step, loss, lr, steps_per_sec,
212
+ )
213
+
214
+ # Evaluation
215
+ if (self.global_step % self.config.training.eval_every_steps == 0
216
+ and self.val_loader is not None):
217
+ val_loss = self.evaluate()
218
+ logger.info("Validation loss: %.4f", val_loss)
219
+ if val_loss < self.best_val_loss:
220
+ self.best_val_loss = val_loss
221
+ self._save_checkpoint("best.pt")
222
+
223
+ # Checkpoint
224
+ if self.global_step % self.config.training.save_every_steps == 0:
225
+ self._save_checkpoint(f"step_{self.global_step}.pt")
226
+
227
+ # Stop condition
228
+ if self.global_step >= self.config.training.max_steps:
229
+ break
230
+
231
+ avg_epoch_loss = epoch_loss / max(n_batches, 1)
232
+ logger.info("Epoch %d complete. Average loss: %.4f", epoch, avg_epoch_loss)
233
+
234
+ # Final save
235
+ self._save_checkpoint("final.pt")
236
+ elapsed = time.time() - start_time
237
+ logger.info(
238
+ "Training complete! %d steps in %.1f hours",
239
+ self.global_step, elapsed / 3600,
240
+ )
241
+
242
+ def _train_step(self, batch: dict[str, torch.Tensor]) -> float:
243
+ """Single training step.
244
+
245
+ Args:
246
+ batch: Batch of training data.
247
+
248
+ Returns:
249
+ Loss value for this step.
250
+ """
251
+ self.model.train()
252
+
253
+ # Move batch to device
254
+ batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
255
+ for k, v in batch.items()}
256
+
257
+ # Sample random timesteps
258
+ batch_size = batch["token_ids"].shape[0]
259
+ t = torch.randint(
260
+ 0, self.config.diffusion.n_timesteps,
261
+ (batch_size,), device=self.device,
262
+ )
263
+
264
+ # Forward pass
265
+ if self.scaler is not None:
266
+ with torch.amp.autocast("cuda", enabled=True):
267
+ predicted, target = self.model(
268
+ token_ids=batch["token_ids"],
269
+ timestep=t,
270
+ evidence_ids=batch.get("evidence_ids"),
271
+ evidence_confidence=batch.get("evidence_confidence"),
272
+ anomaly_ids=batch.get("anomaly_ids"),
273
+ anomaly_confidence=batch.get("anomaly_confidence"),
274
+ reasoning_ids=batch.get("reasoning_ids"),
275
+ reasoning_confidence=batch.get("reasoning_confidence"),
276
+ source_trust=batch.get("source_trust"),
277
+ )
278
+ loss = self.model.compute_loss(predicted, target, t)
279
+ loss = loss / self.config.training.gradient_accumulation_steps
280
+ else:
281
+ predicted, target = self.model(
282
+ token_ids=batch["token_ids"],
283
+ timestep=t,
284
+ evidence_ids=batch.get("evidence_ids"),
285
+ evidence_confidence=batch.get("evidence_confidence"),
286
+ anomaly_ids=batch.get("anomaly_ids"),
287
+ anomaly_confidence=batch.get("anomaly_confidence"),
288
+ reasoning_ids=batch.get("reasoning_ids"),
289
+ reasoning_confidence=batch.get("reasoning_confidence"),
290
+ source_trust=batch.get("source_trust"),
291
+ )
292
+ loss = self.model.compute_loss(predicted, target, t)
293
+ loss = loss / self.config.training.gradient_accumulation_steps
294
+
295
+ # Backward pass
296
+ if self.scaler is not None:
297
+ self.scaler.scale(loss).backward()
298
+ else:
299
+ loss.backward()
300
+
301
+ # Gradient accumulation
302
+ if (self.global_step + 1) % self.config.training.gradient_accumulation_steps == 0:
303
+ # Gradient clipping
304
+ if self.scaler is not None:
305
+ self.scaler.unscale_(self.optimizer)
306
+ torch.nn.utils.clip_grad_norm_(
307
+ self.model.parameters(),
308
+ self.config.training.grad_clip_norm,
309
+ )
310
+
311
+ # Optimizer step
312
+ if self.scaler is not None:
313
+ self.scaler.step(self.optimizer)
314
+ self.scaler.update()
315
+ else:
316
+ self.optimizer.step()
317
+
318
+ # LR schedule
319
+ self.scheduler.step()
320
+
321
+ # Zero gradients
322
+ self.optimizer.zero_grad()
323
+
324
+ # EMA update
325
+ self._update_ema()
326
+
327
+ self.global_step += 1
328
+ self.train_losses.append(loss.item())
329
+
330
+ return loss.item()
331
+
332
+ @torch.no_grad()
333
+ def evaluate(self) -> float:
334
+ """Evaluate on validation set.
335
+
336
+ Returns:
337
+ Average validation loss.
338
+ """
339
+ if self.val_loader is None:
340
+ return float("inf")
341
+
342
+ self.model.eval()
343
+ total_loss = 0.0
344
+ n_batches = 0
345
+
346
+ for batch in self.val_loader:
347
+ batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
348
+ for k, v in batch.items()}
349
+
350
+ batch_size = batch["token_ids"].shape[0]
351
+ t = torch.randint(
352
+ 0, self.config.diffusion.n_timesteps,
353
+ (batch_size,), device=self.device,
354
+ )
355
+
356
+ predicted, target = self.model(
357
+ token_ids=batch["token_ids"],
358
+ timestep=t,
359
+ evidence_ids=batch.get("evidence_ids"),
360
+ evidence_confidence=batch.get("evidence_confidence"),
361
+ anomaly_ids=batch.get("anomaly_ids"),
362
+ anomaly_confidence=batch.get("anomaly_confidence"),
363
+ reasoning_ids=batch.get("reasoning_ids"),
364
+ reasoning_confidence=batch.get("reasoning_confidence"),
365
+ source_trust=batch.get("source_trust"),
366
+ )
367
+ loss = self.model.compute_loss(predicted, target, t)
368
+ total_loss += loss.item()
369
+ n_batches += 1
370
+
371
+ avg_loss = total_loss / max(n_batches, 1)
372
+ self.model.train()
373
+ return avg_loss
374
+
375
+ def _save_checkpoint(self, filename: str) -> None:
376
+ """Save training checkpoint.
377
+
378
+ Args:
379
+ filename: Checkpoint filename.
380
+ """
381
+ path = self.output_dir / filename
382
+ checkpoint = {
383
+ "model_state_dict": self.model.state_dict(),
384
+ "optimizer_state_dict": self.optimizer.state_dict(),
385
+ "scheduler_state_dict": self.scheduler.state_dict(),
386
+ "global_step": self.global_step,
387
+ "best_val_loss": self.best_val_loss,
388
+ "config": self.config.to_dict(),
389
+ }
390
+ if self.ema_model is not None:
391
+ checkpoint["ema_state_dict"] = self.ema_model.state_dict()
392
+
393
+ torch.save(checkpoint, path)
394
+ logger.info("Checkpoint saved: %s", path)
395
+
396
+ # Clean up old checkpoints
397
+ self._cleanup_checkpoints()
398
+
399
+ def _cleanup_checkpoints(self) -> None:
400
+ """Remove old checkpoints, keeping only the last N."""
401
+ keep_n = self.config.training.keep_last_n_checkpoints
402
+ checkpoints = sorted(self.output_dir.glob("step_*.pt"))
403
+ while len(checkpoints) > keep_n:
404
+ oldest = checkpoints.pop(0)
405
+ oldest.unlink()
406
+ logger.info("Removed old checkpoint: %s", oldest)
407
+
408
+ def load_checkpoint(self, path: str) -> None:
409
+ """Load from checkpoint.
410
+
411
+ Args:
412
+ path: Checkpoint file path.
413
+ """
414
+ checkpoint = torch.load(path, map_location=self.device)
415
+ self.model.load_state_dict(checkpoint["model_state_dict"])
416
+ self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
417
+ self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
418
+ self.global_step = checkpoint["global_step"]
419
+ self.best_val_loss = checkpoint.get("best_val_loss", float("inf"))
420
+ logger.info("Loaded checkpoint from step %d", self.global_step)
inference_example.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """AAM Diffusion LLM v1.0 — Inference Example"""
3
+
4
+ import sys
5
+ from pathlib import Path
6
+ sys.path.insert(0, str(Path(__file__).parent))
7
+
8
+ import torch
9
+ from diffusion_llm import AamDiffusionModel, AamTokenizer, AamGenerator, AamDiffusionConfig
10
+
11
+ def main():
12
+ # Load model and tokenizer
13
+ config = AamDiffusionConfig.from_json("config.json")
14
+ model = AamDiffusionModel.load("model.pt", device="cpu")
15
+ tokenizer = AamTokenizer.load("tokenizer.json")
16
+
17
+ # Create generator
18
+ generator = AamGenerator(model, tokenizer, config)
19
+
20
+ # Generate narrative from graph conditioning
21
+ result = generator.generate(
22
+ trigger="Siapa yang mencuri Snow Plum Pill?",
23
+ evidence_nodes=["Hefei", "Diancang Five Swords", "Ju Jangmok"],
24
+ anomalies=["Tidak ada konsumsi pil baru di pasar gelap"],
25
+ reasoning_steps=["Cross-reference tanggal kejadian", "Deteksi anomali pola"],
26
+ source_trust=0.85,
27
+ )
28
+
29
+ print("=" * 60)
30
+ print(" AAM Diffusion LLM — Generated Narrative")
31
+ print("=" * 60)
32
+ print(f" Narrative: {result.narrative}")
33
+ print(f" Confidence: {result.confidence:.1%}")
34
+ print(f" Steps: {result.n_diffusion_steps}")
35
+ print(f" Time: {result.generation_time_s:.2f}s")
36
+
37
+ if __name__ == "__main__":
38
+ main()
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:651e28c3b1fd60919884cc7e6311cd7a604c368669b9abecf27adb2efbc1eaea
3
+ size 1297247
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch>=2.0.0
2
+ numpy>=1.24.0
tokenizer.json ADDED
@@ -0,0 +1,964 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "config": {
3
+ "bpe_vocab_size": 28000,
4
+ "max_sentences": 32,
5
+ "sentence_boundary_token": "<sent>",
6
+ "pad_token": "<pad>",
7
+ "bos_token": "<bos>",
8
+ "eos_token": "<eos>",
9
+ "mask_token": "<mask>",
10
+ "noise_token": "<noise>",
11
+ "min_frequency": 2
12
+ },
13
+ "vocab": {
14
+ "<pad>": 0,
15
+ "<bos>": 1,
16
+ "<eos>": 2,
17
+ "<mask>": 3,
18
+ "<noise>": 4,
19
+ "<sent>": 5,
20
+ "<evidence>": 6,
21
+ "<anomaly>": 7,
22
+ "<confidence>": 8,
23
+ "<reasoning>": 9,
24
+ "<composition>": 10,
25
+ "<temporal>": 11,
26
+ "<unk>": 12,
27
+ "%": 13,
28
+ ",": 14,
29
+ "-": 15,
30
+ ".": 16,
31
+ "0": 17,
32
+ "1": 18,
33
+ "2": 19,
34
+ "3": 20,
35
+ "4": 21,
36
+ "6": 22,
37
+ "7": 23,
38
+ "8": 24,
39
+ "9": 25,
40
+ ":": 26,
41
+ ";": 27,
42
+ "?": 28,
43
+ "_": 29,
44
+ "a": 30,
45
+ "b": 31,
46
+ "c": 32,
47
+ "d": 33,
48
+ "e": 34,
49
+ "f": 35,
50
+ "g": 36,
51
+ "h": 37,
52
+ "i": 38,
53
+ "j": 39,
54
+ "k": 40,
55
+ "l": 41,
56
+ "m": 42,
57
+ "n": 43,
58
+ "o": 44,
59
+ "p": 45,
60
+ "r": 46,
61
+ "s": 47,
62
+ "t": 48,
63
+ "u": 49,
64
+ "v": 50,
65
+ "w": 51,
66
+ "y": 52,
67
+ "z": 53,
68
+ "an": 54,
69
+ "an</w>": 55,
70
+ "er": 56,
71
+ "en": 57,
72
+ "da": 58,
73
+ "ti": 59,
74
+ "il": 60,
75
+ "si": 61,
76
+ "di": 62,
77
+ "ang</w>": 63,
78
+ "si</w>": 64,
79
+ "anc": 65,
80
+ "kan</w>": 66,
81
+ "al": 67,
82
+ "su": 68,
83
+ "ang": 69,
84
+ "ri</w>": 70,
85
+ "ke": 71,
86
+ "ef": 72,
87
+ "ter": 73,
88
+ "se": 74,
89
+ "te": 75,
90
+ "pa": 76,
91
+ "ng": 77,
92
+ "on</w>": 78,
93
+ "on": 79,
94
+ "hef": 80,
95
+ "hefe": 81,
96
+ "enc": 82,
97
+ "or": 83,
98
+ "la": 84,
99
+ "sim": 85,
100
+ "ul": 86,
101
+ "tida": 87,
102
+ "ar": 88,
103
+ "eng": 89,
104
+ "dari</w>": 90,
105
+ "re": 91,
106
+ "bu": 92,
107
+ "ance</w>": 93,
108
+ "ra": 94,
109
+ "om": 95,
110
+ "hefei</w>": 96,
111
+ "jang": 97,
112
+ "sa": 98,
113
+ "ju</w>": 99,
114
+ "jangm": 100,
115
+ "jangmo": 101,
116
+ "jangmok</w>": 102,
117
+ "al</w>": 103,
118
+ "os": 104,
119
+ "dianc": 105,
120
+ "diancang</w>": 106,
121
+ "ai": 107,
122
+ "in": 108,
123
+ "ja": 109,
124
+ "kon": 110,
125
+ "li": 111,
126
+ "ct</w>": 112,
127
+ "tidak</w>": 113,
128
+ "eri": 114,
129
+ "fi": 115,
130
+ "meng": 116,
131
+ "asi</w>": 117,
132
+ "kesim": 118,
133
+ "kesimp": 119,
134
+ "kesimpul": 120,
135
+ "kesimpulan</w>": 121,
136
+ "di</w>": 122,
137
+ "ngkan</w>": 123,
138
+ "ksi</w>": 124,
139
+ "pi": 125,
140
+ "ya</w>": 126,
141
+ "yang</w>": 127,
142
+ "encu": 128,
143
+ "ta": 129,
144
+ "buk": 130,
145
+ "bukt": 131,
146
+ "bukti</w>": 132,
147
+ "pen": 133,
148
+ "per": 134,
149
+ "lu": 135,
150
+ "le": 136,
151
+ "fiv": 137,
152
+ "five</w>": 138,
153
+ "sw": 139,
154
+ "swor": 140,
155
+ "sword": 141,
156
+ "swords</w>": 142,
157
+ "pencu": 143,
158
+ "ence</w>": 144,
159
+ "ce": 145,
160
+ "ku": 146,
161
+ "ili": 147,
162
+ "sn": 148,
163
+ "sno": 149,
164
+ "snow</w>": 150,
165
+ "plu": 151,
166
+ "plum</w>": 152,
167
+ "pil": 153,
168
+ "pill</w>": 154,
169
+ "mengh": 155,
170
+ "menghil": 156,
171
+ "menghilang</w>": 157,
172
+ "lo": 158,
173
+ "bi": 159,
174
+ "de": 160,
175
+ "anom": 161,
176
+ "anomal": 162,
177
+ "mar": 163,
178
+ "marti": 164,
179
+ "martial</w>": 165,
180
+ "alli": 166,
181
+ "alliance</w>": 167,
182
+ "mu": 168,
183
+ "anal": 169,
184
+ "anali": 170,
185
+ "analisi": 171,
186
+ "analisis</w>": 172,
187
+ "gy": 173,
188
+ "gyer": 174,
189
+ "gyery": 175,
190
+ "gyeryon": 176,
191
+ "gyeryong</w>": 177,
192
+ "mer": 178,
193
+ "merc": 179,
194
+ "merch": 180,
195
+ "merchan": 181,
196
+ "merchant</w>": 182,
197
+ "gu": 183,
198
+ "guil": 184,
199
+ "guild</w>": 185,
200
+ "ha": 186,
201
+ "cr": 187,
202
+ "cros": 188,
203
+ "cross</w>": 189,
204
+ "ref": 190,
205
+ "refer": 191,
206
+ "reference</w>": 192,
207
+ "keja": 193,
208
+ "kejadi": 194,
209
+ "kejadian</w>": 195,
210
+ "simh": 196,
211
+ "simhy": 197,
212
+ "simhye": 198,
213
+ "simhyeon</w>": 199,
214
+ "pav": 200,
215
+ "pavili": 201,
216
+ "pavilion</w>": 202,
217
+ "me": 203,
218
+ "tion</w>": 204,
219
+ "sum": 205,
220
+ "blo": 206,
221
+ "bloo": 207,
222
+ "blood</w>": 208,
223
+ "ser": 209,
224
+ "serpen": 210,
225
+ "serpent</w>": 211,
226
+ "dance</w>": 212,
227
+ "ste": 213,
228
+ "step</w>": 214,
229
+ "pre": 215,
230
+ "predi": 216,
231
+ "tin": 217,
232
+ "tinda": 218,
233
+ "tindakan</w>": 219,
234
+ "beri": 220,
235
+ "beriku": 221,
236
+ "berikut": 222,
237
+ "berikutn": 223,
238
+ "berikutnya</w>": 224,
239
+ "tae": 225,
240
+ "taeul": 226,
241
+ "taeul_": 227,
242
+ "taeul_se": 228,
243
+ "taeul_sect</w>": 229,
244
+ "po": 230,
245
+ "pol": 231,
246
+ "pola</w>": 232,
247
+ "jang</w>": 233,
248
+ "hang": 234,
249
+ "hangi</w>": 235,
250
+ "ad": 236,
251
+ "ada</w>": 237,
252
+ "bar": 238,
253
+ "baru</w>": 239,
254
+ "pat": 240,
255
+ "patter": 241,
256
+ "pattern</w>": 242,
257
+ "terpi": 243,
258
+ "terpisa": 244,
259
+ "terpisah</w>": 245,
260
+ "com": 246,
261
+ "comp": 247,
262
+ "as": 248,
263
+ "dete": 249,
264
+ "deteksi</w>": 250,
265
+ "gu</w>": 251,
266
+ "ilm": 252,
267
+ "ilmu</w>": 253,
268
+ "ketida": 254,
269
+ "ketidak": 255,
270
+ "ketidakse": 256,
271
+ "ketidaksesu": 257,
272
+ "ketidaksesuai": 258,
273
+ "ketidaksesuaian</w>": 259,
274
+ "terk": 260,
275
+ "terkai": 261,
276
+ "terkait</w>": 262,
277
+ "lap": 263,
278
+ "lapor": 264,
279
+ "laporan</w>": 265,
280
+ "hu": 266,
281
+ "hubu": 267,
282
+ "ela": 268,
283
+ "dar": 269,
284
+ "dark": 270,
285
+ "dark_": 271,
286
+ "dark_f": 272,
287
+ "dark_fa": 273,
288
+ "dark_fac": 274,
289
+ "dark_faction</w>": 275,
290
+ "at</w>": 276,
291
+ "anomaly</w>": 277,
292
+ "ban": 278,
293
+ "bandi": 279,
294
+ "bandingkan</w>": 280,
295
+ "tang": 281,
296
+ "tangg": 282,
297
+ "tanggal</w>": 283,
298
+ "hefei": 284,
299
+ "hefei_": 285,
300
+ "hefei_b": 286,
301
+ "hefei_br": 287,
302
+ "hefei_branc": 288,
303
+ "hefei_branch</w>": 289,
304
+ "deng": 290,
305
+ "dengan</w>": 291,
306
+ "hubungkan</w>": 292,
307
+ "fra": 293,
308
+ "frag": 294,
309
+ "fragme": 295,
310
+ "fragmen</w>": 296,
311
+ "pencuri</w>": 297,
312
+ "compos": 298,
313
+ "compose</w>": 299,
314
+ "susu": 300,
315
+ "susun</w>": 301,
316
+ "rec": 302,
317
+ "recal": 303,
318
+ "recall</w>": 304,
319
+ "ing": 305,
320
+ "ingat</w>": 306,
321
+ "semu": 307,
322
+ "semua</w>": 308,
323
+ "predict</w>": 309,
324
+ "perk": 310,
325
+ "perki": 311,
326
+ "perkira": 312,
327
+ "perkirakan</w>": 313,
328
+ "veri": 314,
329
+ "verif": 315,
330
+ "verify</w>": 316,
331
+ "cek</w>": 317,
332
+ "konsi": 318,
333
+ "konsis": 319,
334
+ "konsist": 320,
335
+ "konsisten": 321,
336
+ "konsistensi</w>": 322,
337
+ "konsum": 323,
338
+ "konsumsi</w>": 324,
339
+ "pa</w>": 325,
340
+ "men": 326,
341
+ "ting": 327,
342
+ "fil": 328,
343
+ "filte": 329,
344
+ "filter</w>": 330,
345
+ "eli": 331,
346
+ "elim": 332,
347
+ "elimin": 333,
348
+ "eliminasi</w>": 334,
349
+ "rele": 335,
350
+ "relev": 336,
351
+ "relevan</w>": 337,
352
+ "pil</w>": 338,
353
+ "pasa": 339,
354
+ "pasar</w>": 340,
355
+ "gela": 341,
356
+ "gelap</w>": 342,
357
+ "suc": 343,
358
+ "succe": 344,
359
+ "succes": 345,
360
+ "success</w>": 346,
361
+ "rat": 347,
362
+ "rate</w>": 348,
363
+ "pai": 349,
364
+ "pair</w>": 350,
365
+ "lebi": 351,
366
+ "lebih</w>": 352,
367
+ "tingg": 353,
368
+ "tinggi</w>": 354,
369
+ "bias": 355,
370
+ "biasan": 356,
371
+ "biasanya</w>": 357,
372
+ "dala": 358,
373
+ "dalam</w>": 359,
374
+ "ber": 360,
375
+ "pencur": 361,
376
+ "pencuri": 362,
377
+ "pencurian</w>": 363,
378
+ "ka</w>": 364,
379
+ "tan": 365,
380
+ "tanpa</w>": 366,
381
+ "je": 367,
382
+ "jeja": 368,
383
+ "jejak</w>": 369,
384
+ "perg": 370,
385
+ "perger": 371,
386
+ "pergera": 372,
387
+ "pergerakan</w>": 373,
388
+ "masi</w>": 374,
389
+ "inv": 375,
390
+ "inve": 376,
391
+ "inves": 377,
392
+ "investi": 378,
393
+ "investig": 379,
394
+ "investigasi</w>": 380,
395
+ "hari</w>": 381,
396
+ "sam": 382,
397
+ "sama</w>": 383,
398
+ "dat": 384,
399
+ "data</w>": 385,
400
+ "menu": 386,
401
+ "menun": 387,
402
+ "menunj": 388,
403
+ "menunju": 389,
404
+ "menunjuk": 390,
405
+ "menunjukkan</w>": 391,
406
+ "ca": 392,
407
+ "mi": 393,
408
+ "misi</w>": 394,
409
+ "assi": 395,
410
+ "assig": 396,
411
+ "assign</w>": 397,
412
+ "sen": 398,
413
+ "sendi": 399,
414
+ "sendiri</w>": 400,
415
+ "ka": 401,
416
+ "sete": 402,
417
+ "setela": 403,
418
+ "setelah</w>": 404,
419
+ "temu": 405,
420
+ "ke</w>": 406,
421
+ "sumb": 407,
422
+ "sumbe": 408,
423
+ "sumber</w>": 409,
424
+ "prediksi</w>": 410,
425
+ "ters": 411,
426
+ "tersang": 412,
427
+ "tersangka</w>": 413,
428
+ "penal": 414,
429
+ "penalar": 415,
430
+ "penalaran</w>": 416,
431
+ "menja": 417,
432
+ "menjadi</w>": 418,
433
+ "kun": 419,
434
+ "kunc": 420,
435
+ "kunci</w>": 421,
436
+ "hasi": 422,
437
+ "hasil</w>": 423,
438
+ "inf": 424,
439
+ "infor": 425,
440
+ "informasi</w>": 426,
441
+ "anomali</w>": 427,
442
+ "ya": 428,
443
+ "temuan</w>": 429,
444
+ "berk": 430,
445
+ "berkor": 431,
446
+ "berkorela": 432,
447
+ "berkorelasi</w>": 433,
448
+ "cata": 434,
449
+ "catat": 435,
450
+ "catatan</w>": 436,
451
+ "sia": 437,
452
+ "siapa</w>": 438,
453
+ "mencu": 439,
454
+ "mencuri</w>": 440,
455
+ "terse": 441,
456
+ "tersedi": 442,
457
+ "tersedia</w>": 443,
458
+ "mem": 444,
459
+ "memili": 445,
460
+ "memilik": 446,
461
+ "memiliki</w>": 447,
462
+ "kone": 448,
463
+ "koneksi</w>": 449,
464
+ "con": 450,
465
+ "confi": 451,
466
+ "confid": 452,
467
+ "confidence</w>": 453,
468
+ "mengin": 454,
469
+ "mengindi": 455,
470
+ "mengindika": 456,
471
+ "mengindikasi": 457,
472
+ "mengindikasikan</w>": 458,
473
+ "pene": 459,
474
+ "penelu": 460,
475
+ "penelusu": 461,
476
+ "penelusur": 462,
477
+ "penelusuran</w>": 463,
478
+ "log": 464,
479
+ "logi": 465,
480
+ "logika</w>": 466,
481
+ "pr": 467,
482
+ "pro": 468,
483
+ "prose": 469,
484
+ "proses</w>": 470,
485
+ "ded": 471,
486
+ "dedu": 472,
487
+ "deduksi</w>": 473,
488
+ "mengkon": 474,
489
+ "mengkonfi": 475,
490
+ "mengkonfir": 476,
491
+ "mengkonfirmasi</w>": 477,
492
+ "comple": 478,
493
+ "completion</w>": 479,
494
+ "ba": 480,
495
+ "bah": 481,
496
+ "bahw": 482,
497
+ "bahwa</w>": 483,
498
+ "ev": 484,
499
+ "eval": 485,
500
+ "evalu": 486,
501
+ "evaluasi</w>": 487,
502
+ "keper": 488,
503
+ "keperca": 489,
504
+ "kepercaya": 490,
505
+ "kepercayaan</w>": 491,
506
+ "berta": 492,
507
+ "bertaha": 493,
508
+ "bertahap</w>": 494,
509
+ "insi": 495,
510
+ "insid": 496,
511
+ "inside</w>": 497,
512
+ "jo": 498,
513
+ "job</w>": 499
514
+ },
515
+ "merges": {
516
+ "a|||n": 0,
517
+ "a|||n</w>": 1,
518
+ "e|||r": 2,
519
+ "e|||n": 3,
520
+ "d|||a": 4,
521
+ "t|||i": 5,
522
+ "i|||l": 6,
523
+ "s|||i": 7,
524
+ "d|||i": 8,
525
+ "an|||g</w>": 9,
526
+ "s|||i</w>": 10,
527
+ "an|||c": 11,
528
+ "k|||an</w>": 12,
529
+ "a|||l": 13,
530
+ "s|||u": 14,
531
+ "an|||g": 15,
532
+ "r|||i</w>": 16,
533
+ "k|||e": 17,
534
+ "e|||f": 18,
535
+ "t|||er": 19,
536
+ "s|||e": 20,
537
+ "t|||e": 21,
538
+ "p|||a": 22,
539
+ "n|||g": 23,
540
+ "o|||n</w>": 24,
541
+ "o|||n": 25,
542
+ "h|||ef": 26,
543
+ "hef|||e": 27,
544
+ "en|||c": 28,
545
+ "o|||r": 29,
546
+ "l|||a": 30,
547
+ "si|||m": 31,
548
+ "u|||l": 32,
549
+ "ti|||da": 33,
550
+ "a|||r": 34,
551
+ "en|||g": 35,
552
+ "da|||ri</w>": 36,
553
+ "r|||e": 37,
554
+ "b|||u": 38,
555
+ "anc|||e</w>": 39,
556
+ "r|||a": 40,
557
+ "o|||m": 41,
558
+ "hefe|||i</w>": 42,
559
+ "j|||ang": 43,
560
+ "s|||a": 44,
561
+ "j|||u</w>": 45,
562
+ "jang|||m": 46,
563
+ "jangm|||o": 47,
564
+ "jangmo|||k</w>": 48,
565
+ "a|||l</w>": 49,
566
+ "o|||s": 50,
567
+ "di|||anc": 51,
568
+ "dianc|||ang</w>": 52,
569
+ "a|||i": 53,
570
+ "i|||n": 54,
571
+ "j|||a": 55,
572
+ "k|||on": 56,
573
+ "l|||i": 57,
574
+ "c|||t</w>": 58,
575
+ "tida|||k</w>": 59,
576
+ "er|||i": 60,
577
+ "f|||i": 61,
578
+ "m|||eng": 62,
579
+ "a|||si</w>": 63,
580
+ "ke|||sim": 64,
581
+ "kesim|||p": 65,
582
+ "kesimp|||ul": 66,
583
+ "kesimpul|||an</w>": 67,
584
+ "d|||i</w>": 68,
585
+ "ng|||kan</w>": 69,
586
+ "k|||si</w>": 70,
587
+ "p|||i": 71,
588
+ "y|||a</w>": 72,
589
+ "y|||ang</w>": 73,
590
+ "enc|||u": 74,
591
+ "t|||a": 75,
592
+ "bu|||k": 76,
593
+ "buk|||t": 77,
594
+ "bukt|||i</w>": 78,
595
+ "p|||en": 79,
596
+ "p|||er": 80,
597
+ "l|||u": 81,
598
+ "l|||e": 82,
599
+ "fi|||v": 83,
600
+ "fiv|||e</w>": 84,
601
+ "s|||w": 85,
602
+ "sw|||or": 86,
603
+ "swor|||d": 87,
604
+ "sword|||s</w>": 88,
605
+ "p|||encu": 89,
606
+ "enc|||e</w>": 90,
607
+ "c|||e": 91,
608
+ "k|||u": 92,
609
+ "il|||i": 93,
610
+ "s|||n": 94,
611
+ "sn|||o": 95,
612
+ "sno|||w</w>": 96,
613
+ "p|||lu": 97,
614
+ "plu|||m</w>": 98,
615
+ "p|||il": 99,
616
+ "pil|||l</w>": 100,
617
+ "meng|||h": 101,
618
+ "mengh|||il": 102,
619
+ "menghil|||ang</w>": 103,
620
+ "l|||o": 104,
621
+ "b|||i": 105,
622
+ "d|||e": 106,
623
+ "an|||om": 107,
624
+ "anom|||al": 108,
625
+ "m|||ar": 109,
626
+ "mar|||ti": 110,
627
+ "marti|||al</w>": 111,
628
+ "al|||li": 112,
629
+ "alli|||ance</w>": 113,
630
+ "m|||u": 114,
631
+ "an|||al": 115,
632
+ "anal|||i": 116,
633
+ "anali|||si": 117,
634
+ "analisi|||s</w>": 118,
635
+ "g|||y": 119,
636
+ "gy|||er": 120,
637
+ "gyer|||y": 121,
638
+ "gyery|||on": 122,
639
+ "gyeryon|||g</w>": 123,
640
+ "m|||er": 124,
641
+ "mer|||c": 125,
642
+ "merc|||h": 126,
643
+ "merch|||an": 127,
644
+ "merchan|||t</w>": 128,
645
+ "g|||u": 129,
646
+ "gu|||il": 130,
647
+ "guil|||d</w>": 131,
648
+ "h|||a": 132,
649
+ "c|||r": 133,
650
+ "cr|||os": 134,
651
+ "cros|||s</w>": 135,
652
+ "r|||ef": 136,
653
+ "ref|||er": 137,
654
+ "refer|||ence</w>": 138,
655
+ "ke|||ja": 139,
656
+ "keja|||di": 140,
657
+ "kejadi|||an</w>": 141,
658
+ "sim|||h": 142,
659
+ "simh|||y": 143,
660
+ "simhy|||e": 144,
661
+ "simhye|||on</w>": 145,
662
+ "pa|||v": 146,
663
+ "pav|||ili": 147,
664
+ "pavili|||on</w>": 148,
665
+ "m|||e": 149,
666
+ "ti|||on</w>": 150,
667
+ "su|||m": 151,
668
+ "b|||lo": 152,
669
+ "blo|||o": 153,
670
+ "bloo|||d</w>": 154,
671
+ "s|||er": 155,
672
+ "ser|||pen": 156,
673
+ "serpen|||t</w>": 157,
674
+ "d|||ance</w>": 158,
675
+ "s|||te": 159,
676
+ "ste|||p</w>": 160,
677
+ "p|||re": 161,
678
+ "pre|||di": 162,
679
+ "ti|||n": 163,
680
+ "tin|||da": 164,
681
+ "tinda|||kan</w>": 165,
682
+ "b|||eri": 166,
683
+ "beri|||ku": 167,
684
+ "beriku|||t": 168,
685
+ "berikut|||n": 169,
686
+ "berikutn|||ya</w>": 170,
687
+ "ta|||e": 171,
688
+ "tae|||ul": 172,
689
+ "taeul|||_": 173,
690
+ "taeul_|||se": 174,
691
+ "taeul_se|||ct</w>": 175,
692
+ "p|||o": 176,
693
+ "po|||l": 177,
694
+ "pol|||a</w>": 178,
695
+ "j|||ang</w>": 179,
696
+ "h|||ang": 180,
697
+ "hang|||i</w>": 181,
698
+ "a|||d": 182,
699
+ "ad|||a</w>": 183,
700
+ "b|||ar": 184,
701
+ "bar|||u</w>": 185,
702
+ "pa|||t": 186,
703
+ "pat|||ter": 187,
704
+ "patter|||n</w>": 188,
705
+ "ter|||pi": 189,
706
+ "terpi|||sa": 190,
707
+ "terpisa|||h</w>": 191,
708
+ "c|||om": 192,
709
+ "com|||p": 193,
710
+ "a|||s": 194,
711
+ "de|||te": 195,
712
+ "dete|||ksi</w>": 196,
713
+ "g|||u</w>": 197,
714
+ "il|||m": 198,
715
+ "ilm|||u</w>": 199,
716
+ "ke|||tida": 200,
717
+ "ketida|||k": 201,
718
+ "ketidak|||se": 202,
719
+ "ketidakse|||su": 203,
720
+ "ketidaksesu|||ai": 204,
721
+ "ketidaksesuai|||an</w>": 205,
722
+ "ter|||k": 206,
723
+ "terk|||ai": 207,
724
+ "terkai|||t</w>": 208,
725
+ "la|||p": 209,
726
+ "lap|||or": 210,
727
+ "lapor|||an</w>": 211,
728
+ "h|||u": 212,
729
+ "hu|||bu": 213,
730
+ "e|||la": 214,
731
+ "da|||r": 215,
732
+ "dar|||k": 216,
733
+ "dark|||_": 217,
734
+ "dark_|||f": 218,
735
+ "dark_f|||a": 219,
736
+ "dark_fa|||c": 220,
737
+ "dark_fac|||tion</w>": 221,
738
+ "a|||t</w>": 222,
739
+ "anomal|||y</w>": 223,
740
+ "b|||an": 224,
741
+ "ban|||di": 225,
742
+ "bandi|||ngkan</w>": 226,
743
+ "t|||ang": 227,
744
+ "tang|||g": 228,
745
+ "tangg|||al</w>": 229,
746
+ "hefe|||i": 230,
747
+ "hefei|||_": 231,
748
+ "hefei_|||b": 232,
749
+ "hefei_b|||r": 233,
750
+ "hefei_br|||anc": 234,
751
+ "hefei_branc|||h</w>": 235,
752
+ "d|||eng": 236,
753
+ "deng|||an</w>": 237,
754
+ "hubu|||ngkan</w>": 238,
755
+ "f|||ra": 239,
756
+ "fra|||g": 240,
757
+ "frag|||me": 241,
758
+ "fragme|||n</w>": 242,
759
+ "pencu|||ri</w>": 243,
760
+ "comp|||os": 244,
761
+ "compos|||e</w>": 245,
762
+ "su|||su": 246,
763
+ "susu|||n</w>": 247,
764
+ "re|||c": 248,
765
+ "rec|||al": 249,
766
+ "recal|||l</w>": 250,
767
+ "i|||ng": 251,
768
+ "ing|||at</w>": 252,
769
+ "se|||mu": 253,
770
+ "semu|||a</w>": 254,
771
+ "predi|||ct</w>": 255,
772
+ "per|||k": 256,
773
+ "perk|||i": 257,
774
+ "perki|||ra": 258,
775
+ "perkira|||kan</w>": 259,
776
+ "v|||eri": 260,
777
+ "veri|||f": 261,
778
+ "verif|||y</w>": 262,
779
+ "ce|||k</w>": 263,
780
+ "kon|||si": 264,
781
+ "konsi|||s": 265,
782
+ "konsis|||t": 266,
783
+ "konsist|||en": 267,
784
+ "konsisten|||si</w>": 268,
785
+ "kon|||sum": 269,
786
+ "konsum|||si</w>": 270,
787
+ "p|||a</w>": 271,
788
+ "m|||en": 272,
789
+ "ti|||ng": 273,
790
+ "f|||il": 274,
791
+ "fil|||te": 275,
792
+ "filte|||r</w>": 276,
793
+ "e|||li": 277,
794
+ "eli|||m": 278,
795
+ "elim|||in": 279,
796
+ "elimin|||asi</w>": 280,
797
+ "re|||le": 281,
798
+ "rele|||v": 282,
799
+ "relev|||an</w>": 283,
800
+ "pi|||l</w>": 284,
801
+ "pa|||sa": 285,
802
+ "pasa|||r</w>": 286,
803
+ "g|||ela": 287,
804
+ "gela|||p</w>": 288,
805
+ "su|||c": 289,
806
+ "suc|||ce": 290,
807
+ "succe|||s": 291,
808
+ "succes|||s</w>": 292,
809
+ "ra|||t": 293,
810
+ "rat|||e</w>": 294,
811
+ "pa|||i": 295,
812
+ "pai|||r</w>": 296,
813
+ "le|||bi": 297,
814
+ "lebi|||h</w>": 298,
815
+ "ting|||g": 299,
816
+ "tingg|||i</w>": 300,
817
+ "bi|||as": 301,
818
+ "bias|||an": 302,
819
+ "biasan|||ya</w>": 303,
820
+ "da|||la": 304,
821
+ "dala|||m</w>": 305,
822
+ "b|||er": 306,
823
+ "pencu|||r": 307,
824
+ "pencur|||i": 308,
825
+ "pencuri|||an</w>": 309,
826
+ "k|||a</w>": 310,
827
+ "t|||an": 311,
828
+ "tan|||pa</w>": 312,
829
+ "j|||e": 313,
830
+ "je|||ja": 314,
831
+ "jeja|||k</w>": 315,
832
+ "per|||g": 316,
833
+ "perg|||er": 317,
834
+ "perger|||a": 318,
835
+ "pergera|||kan</w>": 319,
836
+ "m|||asi</w>": 320,
837
+ "in|||v": 321,
838
+ "inv|||e": 322,
839
+ "inve|||s": 323,
840
+ "inves|||ti": 324,
841
+ "investi|||g": 325,
842
+ "investig|||asi</w>": 326,
843
+ "ha|||ri</w>": 327,
844
+ "sa|||m": 328,
845
+ "sam|||a</w>": 329,
846
+ "da|||t": 330,
847
+ "dat|||a</w>": 331,
848
+ "men|||u": 332,
849
+ "menu|||n": 333,
850
+ "menun|||j": 334,
851
+ "menunj|||u": 335,
852
+ "menunju|||k": 336,
853
+ "menunjuk|||kan</w>": 337,
854
+ "c|||a": 338,
855
+ "m|||i": 339,
856
+ "mi|||si</w>": 340,
857
+ "as|||si": 341,
858
+ "assi|||g": 342,
859
+ "assig|||n</w>": 343,
860
+ "s|||en": 344,
861
+ "sen|||di": 345,
862
+ "sendi|||ri</w>": 346,
863
+ "k|||a": 347,
864
+ "se|||te": 348,
865
+ "sete|||la": 349,
866
+ "setela|||h</w>": 350,
867
+ "te|||mu": 351,
868
+ "k|||e</w>": 352,
869
+ "sum|||b": 353,
870
+ "sumb|||e": 354,
871
+ "sumbe|||r</w>": 355,
872
+ "predi|||ksi</w>": 356,
873
+ "ter|||s": 357,
874
+ "ters|||ang": 358,
875
+ "tersang|||ka</w>": 359,
876
+ "pen|||al": 360,
877
+ "penal|||ar": 361,
878
+ "penalar|||an</w>": 362,
879
+ "men|||ja": 363,
880
+ "menja|||di</w>": 364,
881
+ "ku|||n": 365,
882
+ "kun|||c": 366,
883
+ "kunc|||i</w>": 367,
884
+ "ha|||si": 368,
885
+ "hasi|||l</w>": 369,
886
+ "in|||f": 370,
887
+ "inf|||or": 371,
888
+ "infor|||masi</w>": 372,
889
+ "anomal|||i</w>": 373,
890
+ "y|||a": 374,
891
+ "temu|||an</w>": 375,
892
+ "ber|||k": 376,
893
+ "berk|||or": 377,
894
+ "berkor|||ela": 378,
895
+ "berkorela|||si</w>": 379,
896
+ "ca|||ta": 380,
897
+ "cata|||t": 381,
898
+ "catat|||an</w>": 382,
899
+ "si|||a": 383,
900
+ "sia|||pa</w>": 384,
901
+ "m|||encu": 385,
902
+ "mencu|||ri</w>": 386,
903
+ "ter|||se": 387,
904
+ "terse|||di": 388,
905
+ "tersedi|||a</w>": 389,
906
+ "me|||m": 390,
907
+ "mem|||ili": 391,
908
+ "memili|||k": 392,
909
+ "memilik|||i</w>": 393,
910
+ "kon|||e": 394,
911
+ "kone|||ksi</w>": 395,
912
+ "c|||on": 396,
913
+ "con|||fi": 397,
914
+ "confi|||d": 398,
915
+ "confid|||ence</w>": 399,
916
+ "meng|||in": 400,
917
+ "mengin|||di": 401,
918
+ "mengindi|||ka": 402,
919
+ "mengindika|||si": 403,
920
+ "mengindikasi|||kan</w>": 404,
921
+ "pen|||e": 405,
922
+ "pene|||lu": 406,
923
+ "penelu|||su": 407,
924
+ "penelusu|||r": 408,
925
+ "penelusur|||an</w>": 409,
926
+ "lo|||g": 410,
927
+ "log|||i": 411,
928
+ "logi|||ka</w>": 412,
929
+ "p|||r": 413,
930
+ "pr|||o": 414,
931
+ "pro|||se": 415,
932
+ "prose|||s</w>": 416,
933
+ "de|||d": 417,
934
+ "ded|||u": 418,
935
+ "dedu|||ksi</w>": 419,
936
+ "meng|||kon": 420,
937
+ "mengkon|||fi": 421,
938
+ "mengkonfi|||r": 422,
939
+ "mengkonfir|||masi</w>": 423,
940
+ "comp|||le": 424,
941
+ "comple|||tion</w>": 425,
942
+ "b|||a": 426,
943
+ "ba|||h": 427,
944
+ "bah|||w": 428,
945
+ "bahw|||a</w>": 429,
946
+ "e|||v": 430,
947
+ "ev|||al": 431,
948
+ "eval|||u": 432,
949
+ "evalu|||asi</w>": 433,
950
+ "ke|||per": 434,
951
+ "keper|||ca": 435,
952
+ "keperca|||ya": 436,
953
+ "kepercaya|||an</w>": 437,
954
+ "ber|||ta": 438,
955
+ "berta|||ha": 439,
956
+ "bertaha|||p</w>": 440,
957
+ "in|||si": 441,
958
+ "insi|||d": 442,
959
+ "insid|||e</w>": 443,
960
+ "j|||o": 444,
961
+ "jo|||b</w>": 445
962
+ },
963
+ "is_trained": true
964
+ }
training_config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name": "aam-diffusion-v1.0",
3
+ "aam_mind_source": "rsvs_graph",
4
+ "aam_body_type": "specialized_diffusion",
5
+ "architecture": {
6
+ "type": "diffusion_transformer",
7
+ "d_model": 64,
8
+ "n_layers": 2,
9
+ "n_heads": 4,
10
+ "d_ff": 128,
11
+ "vocab_size": 500,
12
+ "max_seq_len": 32,
13
+ "pos_encoding_type": "learned"
14
+ },
15
+ "diffusion": {
16
+ "n_timesteps": 50,
17
+ "n_inference_steps": 5,
18
+ "schedule_type": "cosine",
19
+ "prediction_type": "epsilon",
20
+ "sampling_method": "ddim"
21
+ },
22
+ "graph_encoder": {
23
+ "d_graph": 32,
24
+ "n_graph_layers": 1,
25
+ "conditioning_method": "cross_attention"
26
+ },
27
+ "parameters": 311670
28
+ }