AAM Diffusion LLM v1.0 — The Body of Aphantasic Abstraction Model
Browse files- .gitignore +14 -0
- README.md +253 -0
- config.json +105 -0
- diffusion_llm/README.md +331 -0
- diffusion_llm/__init__.py +60 -0
- diffusion_llm/config/__init__.py +5 -0
- diffusion_llm/config/model_config.py +620 -0
- diffusion_llm/data/__init__.py +6 -0
- diffusion_llm/data/data_pipeline.py +179 -0
- diffusion_llm/data/synthetic_generator.py +427 -0
- diffusion_llm/inference/__init__.py +5 -0
- diffusion_llm/inference/generator.py +333 -0
- diffusion_llm/model/__init__.py +13 -0
- diffusion_llm/model/aam_diffusion_model.py +475 -0
- diffusion_llm/model/diffusion_transformer.py +394 -0
- diffusion_llm/model/graph_encoder.py +553 -0
- diffusion_llm/model/noise_scheduler.py +426 -0
- diffusion_llm/requirements.txt +17 -0
- diffusion_llm/scripts/evaluate.py +157 -0
- diffusion_llm/scripts/export.py +71 -0
- diffusion_llm/scripts/train.py +168 -0
- diffusion_llm/scripts/train_final.py +686 -0
- diffusion_llm/scripts/train_minimal.py +260 -0
- diffusion_llm/tests/__init__.py +1 -0
- diffusion_llm/tests/test_model.py +239 -0
- diffusion_llm/tests/test_scheduler.py +98 -0
- diffusion_llm/tokenizer/__init__.py +5 -0
- diffusion_llm/tokenizer/aam_tokenizer.py +596 -0
- diffusion_llm/training/__init__.py +7 -0
- diffusion_llm/training/dataset.py +371 -0
- diffusion_llm/training/losses.py +127 -0
- diffusion_llm/training/trainer.py +420 -0
- inference_example.py +38 -0
- pytorch_model.bin +3 -0
- requirements.txt +2 -0
- tokenizer.json +964 -0
- training_config.json +28 -0
.gitignore
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AAM Diffusion LLM v1.0 — HuggingFace Repository Files
|
| 2 |
+
|
| 3 |
+
__pycache__/
|
| 4 |
+
*.pyc
|
| 5 |
+
*.pyo
|
| 6 |
+
*.egg-info/
|
| 7 |
+
dist/
|
| 8 |
+
build/
|
| 9 |
+
*.so
|
| 10 |
+
.env
|
| 11 |
+
output/
|
| 12 |
+
aam-diffusion-v1/
|
| 13 |
+
*.log
|
| 14 |
+
.DS_Store
|
README.md
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- id
|
| 4 |
+
- en
|
| 5 |
+
license: mit
|
| 6 |
+
library_name: pytorch
|
| 7 |
+
tags:
|
| 8 |
+
- diffusion
|
| 9 |
+
- text-generation
|
| 10 |
+
- aam
|
| 11 |
+
- aphantasic-abstraction-model
|
| 12 |
+
- sentence-arrangement
|
| 13 |
+
- graph-conditioned
|
| 14 |
+
- indonesian
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
# AAM Diffusion LLM v1.0
|
| 18 |
+
|
| 19 |
+
> **"AAM = 1 Pikiran + 1 Tubuh" (1 Mind + 1 Body)**
|
| 20 |
+
|
| 21 |
+
The dedicated "body" of the **Aphantasic Abstraction Model (AAM)** — a small diffusion LLM specifically trained to arrange sentences from structured graph data.
|
| 22 |
+
|
| 23 |
+
## What is this?
|
| 24 |
+
|
| 25 |
+
This is **NOT** a general-purpose LLM. This is a **SPECIALIZED sentence composer** that:
|
| 26 |
+
- Takes **graph-structured conditioning** as input (evidence nodes, anomalies, reasoning chains, confidence scores)
|
| 27 |
+
- Produces **coherent natural language narratives** through iterative denoising (diffusion process)
|
| 28 |
+
- **Cannot hallucinate** — it can only narrate what the graph knows
|
| 29 |
+
|
| 30 |
+
### Why Diffusion (Not Autoregressive)?
|
| 31 |
+
|
| 32 |
+
1. **Non-sequential generation** — Can revise earlier parts while generating later parts, mirroring how thoughts form: vague intuition → clearer pattern → explicit narrative
|
| 33 |
+
2. **Graph conditioning** — The entire graph structure is encoded as conditioning, not just a text prefix
|
| 34 |
+
3. **Anti-hallucination by design** — Trained exclusively on Graph→Narrative pairs, the model has no capability to generate information outside the graph conditioning
|
| 35 |
+
|
| 36 |
+
## Architecture
|
| 37 |
+
|
| 38 |
+
```
|
| 39 |
+
┌──────────────────────────────────────────────────────────┐
|
| 40 |
+
│ AAM = 1 Pikiran + 1 Tubuh │
|
| 41 |
+
│ │
|
| 42 |
+
│ Pikiran (Mind) = RSVS Knowledge Graph │
|
| 43 |
+
│ - Structural memory — perfect recall │
|
| 44 |
+
│ - Relational — understands concept connections │
|
| 45 |
+
│ - Confidence scores — knows certainty levels │
|
| 46 |
+
│ │
|
| 47 |
+
│ Tubuh (Body) = AAM Diffusion LLM (This Model) │
|
| 48 |
+
│ ┌─────────────────────────────────────────────┐ │
|
| 49 |
+
│ │ Graph Conditioning Encoder │ │
|
| 50 |
+
│ │ ├─ Evidence Node Encoder │ │
|
| 51 |
+
│ │ ├─ Composition Encoder │ │
|
| 52 |
+
│ │ ├─ Anomaly Encoder │ │
|
| 53 |
+
│ │ ├─ Reasoning Chain Encoder │ │
|
| 54 |
+
│ │ ├─ Confidence Embedding │ │
|
| 55 |
+
│ │ ├─ Temporal Embedding │ │
|
| 56 |
+
│ │ └─ Graph Attention Layers │ │
|
| 57 |
+
│ │ ↓ (cross-attention keys/values) │ │
|
| 58 |
+
│ ├─────────────────────────────────────────────┤ │
|
| 59 |
+
│ │ Diffusion Transformer (Denoiser) │ │
|
| 60 |
+
│ │ ├─ Token Embedding │ │
|
| 61 |
+
│ │ ├─ Timestep Embedding (sinusoidal) │ │
|
| 62 |
+
│ │ ├─ N × TransformerBlock: │ │
|
| 63 |
+
│ │ │ ├─ AdaptiveLayerNorm + Self-Attention │ │
|
| 64 |
+
│ │ │ ├─ AdaptiveLayerNorm + Cross-Attention │ │
|
| 65 |
+
│ │ │ └─ AdaptiveLayerNorm + Feed-Forward │ │
|
| 66 |
+
│ │ └─ Output Projection │ │
|
| 67 |
+
│ │ ↓ (predicted noise) │ │
|
| 68 |
+
│ ├─────────────────────────────────────────────┤ │
|
| 69 |
+
│ │ Noise Scheduler │ │
|
| 70 |
+
│ │ ├─ Forward: x_0 + noise → x_t │ │
|
| 71 |
+
│ │ └─ Reverse: x_t → denoise → x_{t-1} │ │
|
| 72 |
+
│ └─────────────────────────────────────────────┘ │
|
| 73 |
+
│ │
|
| 74 |
+
│ Training: Graph→Narrative pairs │
|
| 75 |
+
│ Inference: Noise → N denoising steps → Narrative │
|
| 76 |
+
└──────────────────────────────────────────────────────────┘
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
## Model Details (v1.0 — Trained)
|
| 80 |
+
|
| 81 |
+
| Parameter | Value |
|
| 82 |
+
|-----------|-------|
|
| 83 |
+
| Architecture | Diffusion Transformer with Graph Conditioning |
|
| 84 |
+
| d_model | 64 |
|
| 85 |
+
| n_layers | 2 |
|
| 86 |
+
| n_heads | 4 |
|
| 87 |
+
| d_ff | 128 |
|
| 88 |
+
| **Total Parameters** | **311,670 (311.7K)** |
|
| 89 |
+
| Vocab size | 500 (BPE + special tokens) |
|
| 90 |
+
| Max sequence length | 32 |
|
| 91 |
+
| Diffusion timesteps (train) | 50 |
|
| 92 |
+
| Diffusion timesteps (inference) | 5 |
|
| 93 |
+
| Noise schedule | Cosine |
|
| 94 |
+
| Prediction type | Epsilon (noise prediction) |
|
| 95 |
+
| Sampling method | DDIM |
|
| 96 |
+
|
| 97 |
+
> **Note**: This v1.0 model was trained with minimal parameters (311K) for proof-of-concept on CPU. For production use, scale up to the `base` (170M) or `medium` (300M) configurations provided in the framework.
|
| 98 |
+
|
| 99 |
+
## Model Sizes (Framework Supports)
|
| 100 |
+
|
| 101 |
+
| Size | d_model | Layers | Heads | Params | Recommended For |
|
| 102 |
+
|------|---------|--------|-------|--------|----------------|
|
| 103 |
+
| tiny | 256 | 4 | 4 | ~25M | Quick testing, debugging |
|
| 104 |
+
| small | 512 | 8 | 8 | ~70M | Development, prototyping |
|
| 105 |
+
| **base** | **768** | **12** | **12** | **~170M** | **Recommended for training** |
|
| 106 |
+
| medium | 1024 | 12 | 16 | ~300M | Final training, best quality |
|
| 107 |
+
|
| 108 |
+
## Usage
|
| 109 |
+
|
| 110 |
+
### Quick Start
|
| 111 |
+
|
| 112 |
+
```python
|
| 113 |
+
from diffusion_llm import AamDiffusionModel, AamTokenizer, AamGenerator, AamDiffusionConfig
|
| 114 |
+
|
| 115 |
+
# Load model
|
| 116 |
+
config = AamDiffusionConfig.from_json("config.json")
|
| 117 |
+
model = AamDiffusionModel.load("model.pt", device="cpu")
|
| 118 |
+
tokenizer = AamTokenizer.load("tokenizer.json")
|
| 119 |
+
|
| 120 |
+
# Create generator
|
| 121 |
+
generator = AamGenerator(model, tokenizer, config)
|
| 122 |
+
|
| 123 |
+
# Generate narrative from graph conditioning
|
| 124 |
+
result = generator.generate(
|
| 125 |
+
trigger="Siapa yang mencuri Snow Plum Pill?",
|
| 126 |
+
evidence_nodes=["Hefei", "Diancang Five Swords", "Ju Jangmok"],
|
| 127 |
+
anomalies=["Tidak ada konsumsi pil baru di pasar gelap"],
|
| 128 |
+
reasoning_steps=["Cross-reference tanggal kejadian", "Deteksi anomali pola"],
|
| 129 |
+
source_trust=0.85,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
print(result.narrative)
|
| 133 |
+
print(f"Confidence: {result.confidence:.1%}")
|
| 134 |
+
print(f"Steps: {result.n_diffusion_steps}")
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
### Training Your Own Model
|
| 138 |
+
|
| 139 |
+
```python
|
| 140 |
+
from diffusion_llm import AamDiffusionConfig, AamDiffusionModel, get_default_config
|
| 141 |
+
from diffusion_llm.training import AamTrainer, GraphNarrativeDataset
|
| 142 |
+
from diffusion_llm.data import DataPipeline
|
| 143 |
+
|
| 144 |
+
# Get config for your desired size
|
| 145 |
+
config = get_default_config("base") # 170M params
|
| 146 |
+
|
| 147 |
+
# Prepare data pipeline
|
| 148 |
+
pipeline = DataPipeline(config)
|
| 149 |
+
tokenizer, train_loader, val_loader = pipeline.prepare()
|
| 150 |
+
|
| 151 |
+
# Create and train model
|
| 152 |
+
model = AamDiffusionModel(config)
|
| 153 |
+
trainer = AamTrainer(config, model, tokenizer, train_loader.dataset, val_loader.dataset)
|
| 154 |
+
trainer.train()
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
### Command Line
|
| 158 |
+
|
| 159 |
+
```bash
|
| 160 |
+
# Train with default config
|
| 161 |
+
python diffusion_llm/scripts/train.py --model_size base
|
| 162 |
+
|
| 163 |
+
# Generate narratives
|
| 164 |
+
python diffusion_llm/scripts/evaluate.py --checkpoint output/best.pt --generate
|
| 165 |
+
|
| 166 |
+
# Export model
|
| 167 |
+
python diffusion_llm/scripts/export.py --checkpoint output/best.pt --output model_export/
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
## Philosophy
|
| 171 |
+
|
| 172 |
+
**AAM = 1 Pikiran + 1 Tubuh (1 Mind + 1 Body)**
|
| 173 |
+
|
| 174 |
+
- **Mind** = RSVS Knowledge Graph (structural memory, perfect recall, relational understanding)
|
| 175 |
+
- **Body** = This Diffusion LLM (sentence arranger, graph-conditioned, anti-hallucination)
|
| 176 |
+
|
| 177 |
+
Unlike using a rented LLM (GPT, Claude, etc.) as the "body", this model is **specifically trained for AAM**:
|
| 178 |
+
- It **cannot generate** information not present in the graph conditioning
|
| 179 |
+
- It **arranges sentences** based on structured evidence
|
| 180 |
+
- It uses **diffusion** (non-sequential generation) instead of autoregressive generation
|
| 181 |
+
- It is **small** but **specialized** — like Jin Soun's body in the novel, it may be "third-rate" but it's **his own**
|
| 182 |
+
|
| 183 |
+
> Jin Soun bukan orang yang menyewa tubuh orang lain untuk berbicara.
|
| 184 |
+
> Dia punya tubuh sendiri — lemah, third-rate, tapi MILIKNYA.
|
| 185 |
+
> Karena tubuhnya khusus dilatih untuk mengeksekusi perintah dari
|
| 186 |
+
> pikirannya (bukan pikiran orang lain), outputnya lebih terarah
|
| 187 |
+
> daripada orang yang punya tubuh lebih kuat tapi pikiran lebih lemah.
|
| 188 |
+
|
| 189 |
+
## Framework Structure
|
| 190 |
+
|
| 191 |
+
```
|
| 192 |
+
diffusion_llm/
|
| 193 |
+
├── __init__.py # Public API
|
| 194 |
+
├── config/
|
| 195 |
+
│ └── model_config.py # All configuration dataclasses
|
| 196 |
+
├── tokenizer/
|
| 197 |
+
│ └── aam_tokenizer.py # Sentence-level + BPE hybrid tokenizer
|
| 198 |
+
├── model/
|
| 199 |
+
│ ├── noise_scheduler.py # Forward/reverse diffusion process
|
| 200 |
+
│ ├── graph_encoder.py # Graph conditioning encoder
|
| 201 |
+
│ ├── diffusion_transformer.py # Core denoising transformer
|
| 202 |
+
│ └── aam_diffusion_model.py # Complete model (combines all)
|
| 203 |
+
├── training/
|
| 204 |
+
│ ├── losses.py # Loss functions (MSE, MAE, Huber, weighted)
|
| 205 |
+
│ ├── dataset.py # GraphNarrative dataset
|
| 206 |
+
│ └── trainer.py # Training loop with AMP, EMA, etc.
|
| 207 |
+
├── inference/
|
| 208 |
+
│ └── generator.py # Inference pipeline
|
| 209 |
+
├── data/
|
| 210 |
+
│ ├── synthetic_generator.py # Synthetic training data
|
| 211 |
+
│ └── data_pipeline.py # Data preparation pipeline
|
| 212 |
+
├── scripts/
|
| 213 |
+
│ ├── train.py # Training entry point
|
| 214 |
+
│ ├── evaluate.py # Evaluation & generation
|
| 215 |
+
│ └── export.py # Model export
|
| 216 |
+
└── tests/
|
| 217 |
+
├── test_model.py # Model component tests
|
| 218 |
+
└── test_scheduler.py # Noise scheduler tests
|
| 219 |
+
```
|
| 220 |
+
|
| 221 |
+
## Training Data Format
|
| 222 |
+
|
| 223 |
+
Training data is stored in JSONL format, one JSON object per line (the example below is pretty-printed for readability):
|
| 224 |
+
|
| 225 |
+
```json
|
| 226 |
+
{
|
| 227 |
+
"narrative": "Berdasarkan analisis, Diancang Five Swords mencuri Snow Plum Pill.",
|
| 228 |
+
"trigger": "Siapa yang mencuri Snow Plum Pill?",
|
| 229 |
+
"evidence_nodes": ["Hefei", "Diancang Five Swords", "Ju Jangmok"],
|
| 230 |
+
"compositions": [],
|
| 231 |
+
"confidence_map": {"Hefei": 0.9, "Diancang Five Swords": 0.85},
|
| 232 |
+
"anomalies": ["Tidak ada konsumsi pil baru di pasar gelap"],
|
| 233 |
+
"reasoning_steps": ["Cross-reference tanggal kejadian", "Deteksi anomali pola"],
|
| 234 |
+
"source_trust": 0.85,
|
| 235 |
+
"language": "id"
|
| 236 |
+
}
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
## License
|
| 240 |
+
|
| 241 |
+
MIT
|
| 242 |
+
|
| 243 |
+
## Citation
|
| 244 |
+
|
| 245 |
+
```bibtex
|
| 246 |
+
@software{aam_diffusion_llm_v1,
|
| 247 |
+
title = {AAM Diffusion LLM: The Body of Aphantasic Abstraction Model},
|
| 248 |
+
author = {AAM Team},
|
| 249 |
+
year = {2026},
|
| 250 |
+
description = {A specialized diffusion LLM for sentence arrangement from graph-structured data},
|
| 251 |
+
url = {https://huggingface.co/aam-diffusion-v1}
|
| 252 |
+
}
|
| 253 |
+
```
|
config.json
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model": {
|
| 3 |
+
"d_model": 64,
|
| 4 |
+
"n_layers": 2,
|
| 5 |
+
"n_heads": 4,
|
| 6 |
+
"d_ff": 128,
|
| 7 |
+
"dropout": 0.1,
|
| 8 |
+
"activation": "gelu",
|
| 9 |
+
"max_seq_len": 32,
|
| 10 |
+
"vocab_size": 500,
|
| 11 |
+
"pos_encoding_type": "learned",
|
| 12 |
+
"use_flash_attention": false,
|
| 13 |
+
"norm_type": "layernorm",
|
| 14 |
+
"norm_eps": 1e-06,
|
| 15 |
+
"init_std": 0.02
|
| 16 |
+
},
|
| 17 |
+
"diffusion": {
|
| 18 |
+
"n_timesteps": 50,
|
| 19 |
+
"n_inference_steps": 5,
|
| 20 |
+
"schedule_type": "cosine",
|
| 21 |
+
"beta_start": 0.0001,
|
| 22 |
+
"beta_end": 0.02,
|
| 23 |
+
"prediction_type": "epsilon",
|
| 24 |
+
"sampling_method": "ddim",
|
| 25 |
+
"eta_ddim": 0.0,
|
| 26 |
+
"clip_sample_max": 5.0,
|
| 27 |
+
"clip_sample_min": -5.0,
|
| 28 |
+
"loss_type": "mse",
|
| 29 |
+
"loss_weighting": "none",
|
| 30 |
+
"p2_gamma": 1.0,
|
| 31 |
+
"p2_k": 1.0
|
| 32 |
+
},
|
| 33 |
+
"graph_encoder": {
|
| 34 |
+
"d_graph": 32,
|
| 35 |
+
"n_graph_layers": 1,
|
| 36 |
+
"n_graph_heads": 2,
|
| 37 |
+
"max_evidence_nodes": 3,
|
| 38 |
+
"max_compositions": 2,
|
| 39 |
+
"max_anomalies": 2,
|
| 40 |
+
"max_reasoning_steps": 2,
|
| 41 |
+
"conditioning_method": "cross_attention",
|
| 42 |
+
"embed_confidence": false,
|
| 43 |
+
"embed_temporal": false
|
| 44 |
+
},
|
| 45 |
+
"tokenizer": {
|
| 46 |
+
"bpe_vocab_size": 500,
|
| 47 |
+
"max_sentences": 32,
|
| 48 |
+
"sentence_boundary_token": "<sent>",
|
| 49 |
+
"pad_token": "<pad>",
|
| 50 |
+
"bos_token": "<bos>",
|
| 51 |
+
"eos_token": "<eos>",
|
| 52 |
+
"mask_token": "<mask>",
|
| 53 |
+
"noise_token": "<noise>",
|
| 54 |
+
"evidence_token": "<evidence>",
|
| 55 |
+
"anomaly_token": "<anomaly>",
|
| 56 |
+
"confidence_token": "<confidence>",
|
| 57 |
+
"reasoning_token": "<reasoning>",
|
| 58 |
+
"composition_token": "<composition>",
|
| 59 |
+
"temporal_token": "<temporal>",
|
| 60 |
+
"min_frequency": 2,
|
| 61 |
+
"dropout_rate": 0.0
|
| 62 |
+
},
|
| 63 |
+
"training": {
|
| 64 |
+
"learning_rate": 0.001,
|
| 65 |
+
"weight_decay": 0.01,
|
| 66 |
+
"adam_beta1": 0.9,
|
| 67 |
+
"adam_beta2": 0.999,
|
| 68 |
+
"adam_eps": 1e-08,
|
| 69 |
+
"lr_schedule": "cosine",
|
| 70 |
+
"warmup_steps": 5,
|
| 71 |
+
"batch_size": 2,
|
| 72 |
+
"gradient_accumulation_steps": 4,
|
| 73 |
+
"max_steps": 50,
|
| 74 |
+
"max_epochs": 100,
|
| 75 |
+
"dropout": 0.1,
|
| 76 |
+
"grad_clip_norm": 1.0,
|
| 77 |
+
"use_amp": false,
|
| 78 |
+
"amp_dtype": "bf16",
|
| 79 |
+
"save_every_steps": 5000,
|
| 80 |
+
"eval_every_steps": 1000,
|
| 81 |
+
"keep_last_n_checkpoints": 3,
|
| 82 |
+
"use_ema": true,
|
| 83 |
+
"ema_decay": 0.9999,
|
| 84 |
+
"train_data_path": "",
|
| 85 |
+
"val_data_path": "",
|
| 86 |
+
"num_workers": 0,
|
| 87 |
+
"log_every_steps": 100,
|
| 88 |
+
"wandb_project": "aam-diffusion-llm",
|
| 89 |
+
"wandb_run_name": ""
|
| 90 |
+
},
|
| 91 |
+
"inference": {
|
| 92 |
+
"n_steps": 5,
|
| 93 |
+
"temperature": 1.0,
|
| 94 |
+
"top_k": 50,
|
| 95 |
+
"top_p": 0.95,
|
| 96 |
+
"repetition_penalty": 1.2,
|
| 97 |
+
"max_output_sentences": 16,
|
| 98 |
+
"language": "id"
|
| 99 |
+
},
|
| 100 |
+
"model_name": "aam-diffusion-v1.0",
|
| 101 |
+
"output_dir": "./aam-diffusion-v1",
|
| 102 |
+
"seed": 42,
|
| 103 |
+
"aam_mind_source": "rsvs_graph",
|
| 104 |
+
"aam_body_type": "specialized_diffusion"
|
| 105 |
+
}
|
diffusion_llm/README.md
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AAM Diffusion LLM Framework
|
| 2 |
+
|
| 3 |
+
> **"AAM = 1 Pikiran + 1 Tubuh" (1 Mind + 1 Body)**
|
| 4 |
+
|
| 5 |
+
Framework khusus untuk melatih Diffusion LLM yang menjadi "tubuh" (body) dari Aphantasic Abstraction Model (AAM). Ini BUKAN LLM umum — ini model yang KHUSUS dilatih untuk menyusun kalimat dari data graph yang terstruktur.
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## Filosofi
|
| 10 |
+
|
| 11 |
+
### Kenapa Bukan LLM Umum?
|
| 12 |
+
|
| 13 |
+
Konsep sebelumnya: "tubuh Jin Soun = LLM umum (GPT, Claude, dll.)" — ini **salah besar**.
|
| 14 |
+
|
| 15 |
+
| Aspek | LLM Umum (Sewaan) | AAM Diffusion LLM (Milik Sendiri) |
|
| 16 |
+
|-------|-------------------|-----------------------------------|
|
| 17 |
+
| Input | Prompt teks | Graph conditioning (evidence, anomaly, dll.) |
|
| 18 |
+
| Output | Teks probabilistik | Narrative yang grounded di graph |
|
| 19 |
+
| Hallucination | BISA mengarang | TIDAK BISA — hanya menarasikan apa yang graph ketahui |
|
| 20 |
+
| Tujuan | General purpose | Khusus menyusun kalimat dari graph |
|
| 21 |
+
| Ukuran | 7B-175B params | ~25M-300M params (lihat tabel Model Sizes) |
|
| 22 |
+
| Metode | Autoregressive | Diffusion (non-sequential) |
|
| 23 |
+
| Identitas | Sewaan | MILIK AAM sendiri |
|
| 24 |
+
|
| 25 |
+
### Kenapa Diffusion (Bukan Autoregressive)?
|
| 26 |
+
|
| 27 |
+
1. **Non-sequential** — Bisa merevisi bagian awal saat generating bagian akhir. Mirip cara Jin Soun membentuk pikiran: vague → clearer → explicit.
|
| 28 |
+
|
| 29 |
+
2. **Graph conditioning** — Seluruh graph bisa di-encode sebagai conditioning, bukan hanya prefix. Autoregressive hanya bisa melihat "apa yang sudah di-generate sebelumnya."
|
| 30 |
+
|
| 31 |
+
3. **Coherent long-form** — Diffusion menghasilkan teks yang lebih koheren untuk narasi panjang karena setiap token "mengetahui" tentang token lain.
|
| 32 |
+
|
| 33 |
+
4. **Anti-hallucination** — Model dilatih KHUSUS untuk Graph→Narrative, tidak punya kapabilitas mengarang informasi di luar graph.
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
|
| 37 |
+
## Arsitektur
|
| 38 |
+
|
| 39 |
+
```
|
| 40 |
+
┌──────────────────────────────────────────────────────────┐
|
| 41 |
+
│ AAM = 1 Pikiran + 1 Tubuh │
|
| 42 |
+
│ │
|
| 43 |
+
│ Pikiran (Mind) = RSVS Knowledge Graph │
|
| 44 |
+
│ - Structural memory — mengingat SEMUA │
|
| 45 |
+
│ - Relational — memahami koneksi antar konsep │
|
| 46 |
+
│ - Perfect recall — tidak pernah lupa │
|
| 47 |
+
│ - Confidence scores — tahu apa yang pasti vs ragu │
|
| 48 |
+
│ │
|
| 49 |
+
│ Tubuh (Body) = AAM Diffusion LLM │
|
| 50 |
+
│ ┌─────────────────────────────────────────────┐ │
|
| 51 |
+
│ │ Graph Conditioning Encoder │ │
|
| 52 |
+
│ │ ├─ Evidence Node Encoder │ │
|
| 53 |
+
│ │ ├─ Composition Encoder │ │
|
| 54 |
+
│ │ ├─ Anomaly Encoder │ │
|
| 55 |
+
│ │ ├─ Reasoning Chain Encoder │ │
|
| 56 |
+
│ │ ├─ Confidence Embedding │ │
|
| 57 |
+
│ │ ├─ Temporal Embedding │ │
|
| 58 |
+
│ │ └─ Graph Attention Layers │ │
|
| 59 |
+
│ │ ↓ (cross-attention keys/values) │ │
|
| 60 |
+
│ ├─────────────────────────────────────────────┤ │
|
| 61 |
+
│ │ Diffusion Transformer (Denoiser) │ │
|
| 62 |
+
│ │ ├─ Token Embedding │ │
|
| 63 |
+
│ │ ├─ Timestep Embedding (sinusoidal) │ │
|
| 64 |
+
│ │ ├─ N × TransformerBlock: │ │
|
| 65 |
+
│ │ │ ├─ AdaptiveLayerNorm + Self-Attention │ │
|
| 66 |
+
│ │ │ ├─ AdaptiveLayerNorm + Cross-Attention │ │
|
| 67 |
+
│ │ │ └─ AdaptiveLayerNorm + Feed-Forward │ │
|
| 68 |
+
│ │ └─ Output Projection │ │
|
| 69 |
+
│ │ ↓ (predicted noise) │ │
|
| 70 |
+
│ ├─────────────────────────────────────────────┤ │
|
| 71 |
+
│ │ Noise Scheduler │ │
|
| 72 |
+
│ │ ├─ Forward: x_0 + noise → x_t │ │
|
| 73 |
+
│ │ └─ Reverse: x_t → denoise → x_{t-1} │ │
|
| 74 |
+
│ └─────────────────────────────────────────────┘ │
|
| 75 |
+
│ │
|
| 76 |
+
│ Training: Graph→Narrative pairs │
|
| 77 |
+
│ Inference: Noise → N denoising steps → Narrative │
|
| 78 |
+
└──────────────────────────────────────────────────────────┘
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
---
|
| 82 |
+
|
| 83 |
+
## Struktur Folder
|
| 84 |
+
|
| 85 |
+
```
|
| 86 |
+
diffusion_llm/
|
| 87 |
+
├── __init__.py # Package init with public API
|
| 88 |
+
├── config/
|
| 89 |
+
│ ├── __init__.py
|
| 90 |
+
│ └── model_config.py # All configuration dataclasses
|
| 91 |
+
├── tokenizer/
|
| 92 |
+
│ ├── __init__.py
|
| 93 |
+
│ └── aam_tokenizer.py # Sentence-level + BPE hybrid tokenizer
|
| 94 |
+
├── model/
|
| 95 |
+
│ ├── __init__.py
|
| 96 |
+
│ ├── noise_scheduler.py # Forward/reverse diffusion process
|
| 97 |
+
│ ├── graph_encoder.py # Graph conditioning encoder
|
| 98 |
+
│ ├── diffusion_transformer.py # Core denoising transformer
|
| 99 |
+
│ └── aam_diffusion_model.py # Complete model (combines all)
|
| 100 |
+
├── training/
|
| 101 |
+
│ ├── __init__.py
|
| 102 |
+
│ ├── losses.py # Loss functions (MSE, MAE, Huber, weighted)
|
| 103 |
+
│ ├── dataset.py # GraphNarrative dataset
|
| 104 |
+
│ └── trainer.py # Training loop with AMP, EMA, etc.
|
| 105 |
+
├── inference/
|
| 106 |
+
│ ├── __init__.py
|
| 107 |
+
│ └── generator.py # Inference pipeline
|
| 108 |
+
├── data/
|
| 109 |
+
│ ├── __init__.py
|
| 110 |
+
│ ├── synthetic_generator.py # Synthetic training data
|
| 111 |
+
│ └── data_pipeline.py # Data preparation pipeline
|
| 112 |
+
├── scripts/
|
| 113 |
+
│ ├── train.py # Training entry point
|
| 114 |
+
│ ├── evaluate.py # Evaluation & generation
|
| 115 |
+
│ └── export.py # Model export
|
| 116 |
+
├── tests/
|
| 117 |
+
│ ├── __init__.py
|
| 118 |
+
│ ├── test_scheduler.py # Noise scheduler tests
|
| 119 |
+
│ └── test_model.py # Model component tests
|
| 120 |
+
├── requirements.txt # Python dependencies
|
| 121 |
+
└── README.md # This file
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
---
|
| 125 |
+
|
| 126 |
+
## Quick Start
|
| 127 |
+
|
| 128 |
+
### 1. Install Dependencies
|
| 129 |
+
|
| 130 |
+
```bash
|
| 131 |
+
pip install torch numpy pytest
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
### 2. Generate Synthetic Data
|
| 135 |
+
|
| 136 |
+
```python
|
| 137 |
+
from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator
|
| 138 |
+
|
| 139 |
+
generator = SyntheticDataGenerator(seed=42, language="id")
|
| 140 |
+
train_path, val_path = generator.generate_training_split(
|
| 141 |
+
output_dir="./data",
|
| 142 |
+
n_train=10000,
|
| 143 |
+
n_val=500,
|
| 144 |
+
)
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
### 3. Train the Model
|
| 148 |
+
|
| 149 |
+
```bash
|
| 150 |
+
# Quick test with tiny model
|
| 151 |
+
python diffusion_llm/scripts/train.py --model_size tiny --max_steps 100
|
| 152 |
+
|
| 153 |
+
# Full training with base model
|
| 154 |
+
python diffusion_llm/scripts/train.py --model_size base --max_steps 500000
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
### 4. Generate Narratives
|
| 158 |
+
|
| 159 |
+
```bash
|
| 160 |
+
# Generate samples
|
| 161 |
+
python diffusion_llm/scripts/evaluate.py --checkpoint output/best.pt --generate
|
| 162 |
+
|
| 163 |
+
# Interactive mode
|
| 164 |
+
python diffusion_llm/scripts/evaluate.py --checkpoint output/best.pt --interactive
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
### 5. Programmatic Usage
|
| 168 |
+
|
| 169 |
+
```python
|
| 170 |
+
from diffusion_llm import (
|
| 171 |
+
AamDiffusionConfig, get_default_config,
|
| 172 |
+
AamDiffusionModel, AamTokenizer, AamGenerator,
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
# Load model and tokenizer
|
| 176 |
+
config = AamDiffusionConfig.from_json("output/config.json")
|
| 177 |
+
model = AamDiffusionModel.load("output/best.pt")
|
| 178 |
+
tokenizer = AamTokenizer.load("output/data/tokenizer.json")
|
| 179 |
+
|
| 180 |
+
# Create generator
|
| 181 |
+
generator = AamGenerator(model, tokenizer, config)
|
| 182 |
+
|
| 183 |
+
# Generate narrative from graph conditioning
|
| 184 |
+
result = generator.generate(
|
| 185 |
+
trigger="Siapa yang mencuri Snow Plum Pill?",
|
| 186 |
+
evidence_nodes=["Hefei", "Diancang Five Swords", "Ju Jangmok"],
|
| 187 |
+
anomalies=["Tidak ada konsumsi pil baru di pasar gelap"],
|
| 188 |
+
reasoning_steps=["Cross-reference tanggal kejadian", "Deteksi anomali"],
|
| 189 |
+
source_trust=0.85,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
print(result.narrative)
|
| 193 |
+
print(f"Confidence: {result.confidence:.1%}")
|
| 194 |
+
print(f"Steps: {result.n_diffusion_steps}")
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
---
|
| 198 |
+
|
| 199 |
+
## Model Sizes
|
| 200 |
+
|
| 201 |
+
| Size | d_model | Layers | Heads | Params | Recommended For |
|
| 202 |
+
|------|---------|--------|-------|--------|----------------|
|
| 203 |
+
| tiny | 256 | 4 | 4 | ~25M | Quick testing, debugging |
|
| 204 |
+
| small | 512 | 8 | 8 | ~70M | Development, prototyping |
|
| 205 |
+
| **base** | **768** | **12** | **12** | **~170M** | **Recommended for training** |
|
| 206 |
+
| medium | 1024 | 12 | 16 | ~300M | Final training, best quality |
|
| 207 |
+
|
| 208 |
+
---
|
| 209 |
+
|
| 210 |
+
## Konfigurasi
|
| 211 |
+
|
| 212 |
+
### Model Config
|
| 213 |
+
|
| 214 |
+
```python
|
| 215 |
+
from diffusion_llm.config.model_config import AamDiffusionConfig, ModelConfig, DiffusionConfig
|
| 216 |
+
|
| 217 |
+
config = AamDiffusionConfig(
|
| 218 |
+
model=ModelConfig(
|
| 219 |
+
d_model=768, # Hidden dimension
|
| 220 |
+
n_layers=12, # Transformer blocks
|
| 221 |
+
n_heads=12, # Attention heads
|
| 222 |
+
d_ff=3072, # Feed-forward dimension
|
| 223 |
+
vocab_size=32000, # Vocabulary size
|
| 224 |
+
max_seq_len=512, # Maximum sequence length
|
| 225 |
+
),
|
| 226 |
+
diffusion=DiffusionConfig(
|
| 227 |
+
n_timesteps=1000, # Training timesteps
|
| 228 |
+
n_inference_steps=50, # Inference steps (fewer = faster)
|
| 229 |
+
schedule_type="cosine", # Noise schedule
|
| 230 |
+
prediction_type="epsilon", # Predict noise
|
| 231 |
+
sampling_method="ddim", # Fast deterministic sampling
|
| 232 |
+
),
|
| 233 |
+
)
|
| 234 |
+
```
|
| 235 |
+
|
| 236 |
+
### Inference Config
|
| 237 |
+
|
| 238 |
+
```python
|
| 239 |
+
from diffusion_llm.config.model_config import InferenceConfig
|
| 240 |
+
|
| 241 |
+
inference = InferenceConfig(
|
| 242 |
+
n_steps=50, # Denoising steps
|
| 243 |
+
temperature=1.0, # Sampling temperature
|
| 244 |
+
top_k=50, # Top-k sampling
|
| 245 |
+
max_output_sentences=16, # Max sentences
|
| 246 |
+
language="id", # Output language
|
| 247 |
+
)
|
| 248 |
+
```
|
| 249 |
+
|
| 250 |
+
---
|
| 251 |
+
|
| 252 |
+
## Integrasi dengan AAM Pipeline
|
| 253 |
+
|
| 254 |
+
Framework ini dirancang untuk menjadi "tubuh" dari AAM. Setelah model dilatih,
|
| 255 |
+
integrasi dengan `pipeline.py` sangat mudah:
|
| 256 |
+
|
| 257 |
+
```python
|
| 258 |
+
# Dalam pipeline.py, ganti fallback:
|
| 259 |
+
from diffusion_llm import AamDiffusionConfig, AamDiffusionModel, AamTokenizer, AamGenerator
|
| 260 |
+
|
| 261 |
+
class AamPipeline:
|
| 262 |
+
def __init__(self, ...):
|
| 263 |
+
# Load trained diffusion model
|
| 264 |
+
diffusion_config = AamDiffusionConfig.from_json("path/to/config.json")
|
| 265 |
+
diffusion_model = AamDiffusionModel.load("path/to/best.pt")
|
| 266 |
+
diffusion_tokenizer = AamTokenizer.load("path/to/tokenizer.json")
|
| 267 |
+
self.diffusion_llm = AamGenerator(diffusion_model, diffusion_tokenizer, diffusion_config)
|
| 268 |
+
```
|
| 269 |
+
|
| 270 |
+
---
|
| 271 |
+
|
| 272 |
+
## Training Data Format
|
| 273 |
+
|
| 274 |
+
Data training dalam format JSONL, satu contoh per baris:
|
| 275 |
+
|
| 276 |
+
```json
|
| 277 |
+
{
|
| 278 |
+
"narrative": "Berdasarkan analisis, Diancang Five Swords mencuri Snow Plum Pill menggunakan Ju Jangmok sebagai kambing hitam.",
|
| 279 |
+
"trigger": "Siapa yang mencuri Snow Plum Pill?",
|
| 280 |
+
"evidence_nodes": ["Hefei", "Diancang Five Swords", "Ju Jangmok", "Gyeryong Merchant Guild"],
|
| 281 |
+
"compositions": [],
|
| 282 |
+
"confidence_map": {"Hefei": 0.9, "Diancang Five Swords": 0.85, "Ju Jangmok": 0.7},
|
| 283 |
+
"anomalies": ["Tidak ada konsumsi pil baru di pasar gelap", "Pencuri menghilang tanpa jejak"],
|
| 284 |
+
"reasoning_steps": ["Cross-reference tanggal kejadian", "Deteksi ketidaksesuaian pola", "Pattern completion dari bukti terpisah"],
|
| 285 |
+
"source_trust": 0.85,
|
| 286 |
+
"temporal_context": [],
|
| 287 |
+
"language": "id",
|
| 288 |
+
"source": "synthetic"
|
| 289 |
+
}
|
| 290 |
+
```
|
| 291 |
+
|
| 292 |
+
---
|
| 293 |
+
|
| 294 |
+
## Running Tests
|
| 295 |
+
|
| 296 |
+
```bash
|
| 297 |
+
# Run all tests
|
| 298 |
+
cd diffusion_llm
|
| 299 |
+
python -m pytest tests/ -v
|
| 300 |
+
|
| 301 |
+
# Run specific test
|
| 302 |
+
python -m pytest tests/test_model.py -v
|
| 303 |
+
|
| 304 |
+
# Run with coverage
|
| 305 |
+
python -m pytest tests/ --cov=diffusion_llm
|
| 306 |
+
```
|
| 307 |
+
|
| 308 |
+
---
|
| 309 |
+
|
| 310 |
+
## Roadmap
|
| 311 |
+
|
| 312 |
+
- [x] **Phase 1: Framework Design** — Arsitektur, config, interface
|
| 313 |
+
- [x] **Phase 2: Core Components** — Noise scheduler, transformer, graph encoder, tokenizer
|
| 314 |
+
- [x] **Phase 3: Training Infrastructure** — Trainer, dataset, loss functions, synthetic data
|
| 315 |
+
- [x] **Phase 4: Inference Pipeline** — Generator, batch generation, interactive mode
|
| 316 |
+
- [ ] **Phase 5: Training Execution** — Train on synthetic data, iterate
|
| 317 |
+
- [ ] **Phase 6: Real Data** — Collect real Graph→Narrative pairs from AAM usage
|
| 318 |
+
- [ ] **Phase 7: Optimization** — Quantization, distillation, flash attention
|
| 319 |
+
- [ ] **Phase 8: Integration** — Plug trained model into AAM pipeline
|
| 320 |
+
|
| 321 |
+
---
|
| 322 |
+
|
| 323 |
+
## Analogi Novel
|
| 324 |
+
|
| 325 |
+
> Jin Soun bukan orang yang menyewa tubuh orang lain untuk berbicara.
|
| 326 |
+
> Dia punya tubuh sendiri — lemah, third-rate, tapi MILIKNYA.
|
| 327 |
+
> Karena tubuhnya khusus dilatih untuk mengeksekusi perintah dari
|
| 328 |
+
> pikirannya (bukan pikiran orang lain), outputnya lebih terarah
|
| 329 |
+
> daripada orang yang punya tubuh lebih kuat tapi pikiran lebih lemah.
|
| 330 |
+
>
|
| 331 |
+
> **AAM = 1 pikiran + 1 tubuh. Bukan 1 pikiran + tubuh sewaan.**
|
diffusion_llm/__init__.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AAM Diffusion LLM Framework — The Body of Aphantasic Abstraction Model
|
| 3 |
+
|
| 4 |
+
"AAM = 1 Pikiran + 1 Tubuh" (1 Mind + 1 Body)
|
| 5 |
+
|
| 6 |
+
Pikiran (Mind) = RSVS Knowledge Graph — structural, relational, perfect memory
|
| 7 |
+
Tubuh (Body) = This Diffusion LLM — generates natural language FROM the graph
|
| 8 |
+
|
| 9 |
+
This is NOT a general-purpose LLM. This is a SPECIALIZED sentence composer
|
| 10 |
+
that takes structured graph data as input and produces coherent, evidence-backed
|
| 11 |
+
narrative output. Think of it as a "vocal cord" for the graph — it can only
|
| 12 |
+
say what the graph knows, but it says it fluently.
|
| 13 |
+
|
| 14 |
+
Why Diffusion?
|
| 15 |
+
- Diffusion models start from noise and iteratively denoise
|
| 16 |
+
- This mirrors how Jin Soun's thoughts form: from vague intuition ->
|
| 17 |
+
clearer pattern -> explicit narrative
|
| 18 |
+
- Unlike autoregressive LLMs (GPT), diffusion models can:
|
| 19 |
+
- Be conditioned on structured input (graph)
|
| 20 |
+
- Revise earlier parts during generation (non-sequential)
|
| 21 |
+
- Produce more coherent long-form text from structure
|
| 22 |
+
|
| 23 |
+
Architecture:
|
| 24 |
+
Input: Graph conditioning (evidence nodes, compositions, confidence, anomalies)
|
| 25 |
+
Process: Iterative denoising from noise
|
| 26 |
+
Output: Natural language narrative grounded in graph structure
|
| 27 |
+
|
| 28 |
+
Analogi: Jin Soun (graph) + tubuhnya (this model).
|
| 29 |
+
Tubuhnya third-rate, tapi karena KHUSUS dilatih untuk
|
| 30 |
+
mengeksekusi perintah dari graph-nya sendiri, outputnya
|
| 31 |
+
lebih terarah daripada LLM umum yang "tidak kenal" graph.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
__version__ = "0.1.0"
|
| 35 |
+
__author__ = "AAM Team"
|
| 36 |
+
|
| 37 |
+
from diffusion_llm.config.model_config import AamDiffusionConfig, get_default_config
|
| 38 |
+
from diffusion_llm.model.noise_scheduler import NoiseScheduler
|
| 39 |
+
from diffusion_llm.model.graph_encoder import GraphConditioningEncoder
|
| 40 |
+
from diffusion_llm.model.diffusion_transformer import DiffusionTransformer
|
| 41 |
+
from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
|
| 42 |
+
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
|
| 43 |
+
from diffusion_llm.inference.generator import AamGenerator
|
| 44 |
+
from diffusion_llm.training.trainer import AamTrainer
|
| 45 |
+
from diffusion_llm.training.dataset import GraphNarrativeDataset
|
| 46 |
+
from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator
|
| 47 |
+
|
| 48 |
+
__all__ = [
|
| 49 |
+
"AamDiffusionConfig",
|
| 50 |
+
"get_default_config",
|
| 51 |
+
"NoiseScheduler",
|
| 52 |
+
"GraphConditioningEncoder",
|
| 53 |
+
"DiffusionTransformer",
|
| 54 |
+
"AamDiffusionModel",
|
| 55 |
+
"AamTokenizer",
|
| 56 |
+
"AamGenerator",
|
| 57 |
+
"AamTrainer",
|
| 58 |
+
"GraphNarrativeDataset",
|
| 59 |
+
"SyntheticDataGenerator",
|
| 60 |
+
]
|
diffusion_llm/config/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration module for AAM Diffusion LLM."""
|
| 2 |
+
|
| 3 |
+
from diffusion_llm.config.model_config import AamDiffusionConfig, get_default_config
|
| 4 |
+
|
| 5 |
+
__all__ = ["AamDiffusionConfig", "get_default_config"]
|
diffusion_llm/config/model_config.py
ADDED
|
@@ -0,0 +1,620 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AAM Diffusion LLM — Model Configuration
|
| 3 |
+
|
| 4 |
+
Defines all hyperparameters for the diffusion model architecture,
|
| 5 |
+
training process, and inference pipeline.
|
| 6 |
+
|
| 7 |
+
Design Philosophy:
|
| 8 |
+
- Small model (100M-500M params) — specialized, not general
|
| 9 |
+
- Sentence-level tokenization — not subword, because AAM arranges
|
| 10 |
+
sentences, not individual tokens
|
| 11 |
+
- Graph-conditioned — the model MUST receive graph structure as input
|
| 12 |
+
- Non-sequential generation — diffusion, not autoregressive
|
| 13 |
+
|
| 14 |
+
Analogi: Seperti tubuh Jin Soun, model ini kecil tapi KHUSUS
|
| 15 |
+
dilatih untuk satu tugas: menarasikan dari graph. Tidak perlu
|
| 16 |
+
7B params kalau tugasku hanya menyusun kalimat dari data yang
|
| 17 |
+
sudah terstruktur.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import json
|
| 23 |
+
from dataclasses import dataclass, field, asdict
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from typing import Optional
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
class ModelConfig:
    """Hyperparameters describing the Diffusion Transformer architecture.

    The stated design target is a small model in the 100M-500M parameter
    range. ``estimate_params`` gives a rough count that covers the token
    embedding table plus, for every layer, the attention projections and
    the two feed-forward matrices.

    NOTE(review): size figures quoted for the presets (e.g. "~170M" for
    d_model=768, n_layers=12) are larger than what ``estimate_params``
    returns for the same settings; presumably they include components
    that are not counted here -- confirm against the real model.
    """

    # --- Core Transformer ---
    # Hidden dimension of the transformer.
    d_model: int = 768
    # Number of transformer blocks.
    n_layers: int = 12
    # Number of attention heads (d_model must be divisible by n_heads).
    n_heads: int = 12
    # Feed-forward hidden dimension (typically 4x d_model).
    d_ff: int = 3072
    # Dropout rate for attention and feed-forward layers.
    dropout: float = 0.1
    # Activation function: 'gelu' or 'relu'.
    activation: str = "gelu"

    # --- Sequence ---
    # Maximum sequence length (in sentence-level tokens).
    max_seq_len: int = 512

    # --- Vocabulary ---
    # Tokenizer vocabulary size; the tokenizer is a sentence-level +
    # subword BPE hybrid, so this covers special tokens and subword units.
    vocab_size: int = 32_000

    # --- Positional Encoding ---
    # Positional encoding type: 'rotary' (RoPE) or 'learned'.
    pos_encoding_type: str = "rotary"

    # --- Attention ---
    # Whether to use Flash Attention 2 if available.
    use_flash_attention: bool = True

    # --- Normalization ---
    # Normalization type: 'rmsnorm' or 'layernorm'.
    norm_type: str = "rmsnorm"
    # Epsilon for normalization layers.
    norm_eps: float = 1e-6

    # --- Initialization ---
    # Standard deviation for weight initialization.
    init_std: float = 0.02

    def estimate_params(self) -> str:
        """Return a rough parameter count as a human-readable string.

        The count is the embedding table (``vocab_size * d_model``) plus,
        per layer, four ``d_model x d_model`` attention projections
        (Q, K, V, O) and two ``d_model x d_ff`` feed-forward matrices.

        Returns:
            The count with one decimal place and a ``B``, ``M`` or ``K``
            suffix (billions, millions, thousands).
        """
        width = self.d_model
        attention = 4 * width ** 2
        feed_forward = 2 * width * self.d_ff
        count = self.vocab_size * width + self.n_layers * (attention + feed_forward)

        # Largest unit first; anything below one million is shown in K.
        for threshold, suffix in ((1e9, "B"), (1e6, "M")):
            if count >= threshold:
                return f"{count / threshold:.1f}{suffix}"
        return f"{count / 1e3:.1f}K"
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@dataclass
class DiffusionConfig:
    """Settings for the diffusion process applied to text representations.

    As described by the project, the process has three phases:

    1. Forward: Gaussian noise is added to text embeddings over T timesteps.
    2. Reverse: the model learns to remove that noise step by step.
    3. Inference: generation starts from pure noise and is denoised into
       coherent text.

    The project contrasts this with image diffusion: it works in a learned
    latent space rather than pixel space, the data has discrete structure
    (sentences), and a text-specific noise schedule is used.
    """

    # --- Noise Schedule ---
    # Total number of diffusion timesteps for training.
    n_timesteps: int = 1000
    # Number of denoising steps at inference (fewer = faster, less quality).
    n_inference_steps: int = 50
    # Noise schedule type: 'linear', 'cosine', or 'sigmoid'.
    schedule_type: str = "cosine"
    # Starting beta for the linear schedule.
    beta_start: float = 1e-4
    # Ending beta for the linear schedule.
    beta_end: float = 0.02

    # --- Noise Prediction ---
    # What the model predicts: 'epsilon' (noise), 'x0' (clean data),
    # or 'v' (velocity).
    prediction_type: str = "epsilon"

    # --- Sampling ---
    # Sampling method: 'ddpm' (slow, stochastic) or 'ddim' (fast, deterministic).
    sampling_method: str = "ddim"
    # DDIM stochasticity parameter (0 = deterministic, 1 = full stochastic).
    eta_ddim: float = 0.0

    # --- Clipping ---
    # Upper bound applied to clipped samples during inference.
    clip_sample_max: float = 5.0
    # Lower bound applied to clipped samples during inference.
    clip_sample_min: float = -5.0

    # --- Loss ---
    # Loss function: 'mse' (L2), 'mae' (L1) or 'huber'.
    loss_type: str = "mse"
    # Loss weighting strategy: 'none', 'min_snr', or 'p2'.
    loss_weighting: str = "min_snr"
    # P2 weighting gamma (only used if loss_weighting='p2').
    p2_gamma: float = 1.0
    # P2 weighting k (only used if loss_weighting='p2').
    p2_k: float = 1.0
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
@dataclass
class GraphEncoderConfig:
    """Settings for the Graph Conditioning Encoder.

    The encoder consumes structured graph data (evidence nodes,
    compositions, confidence scores, anomalies, reasoning chains) and
    produces the conditioning signal that steers the diffusion process.
    Conditioning on graph structure, rather than on a text prompt alone,
    is what the project presents as its main difference from general LLMs.
    """

    # --- Graph Encoder Architecture ---
    # Hidden dimension for graph encoding.
    d_graph: int = 512
    # Number of graph attention layers.
    n_graph_layers: int = 4
    # Number of attention heads for graph encoding.
    n_graph_heads: int = 8

    # --- Input Dimensions ---
    # Maximum number of evidence nodes to encode.
    max_evidence_nodes: int = 50
    # Maximum number of compositions to encode.
    max_compositions: int = 20
    # Maximum number of anomalies to encode.
    max_anomalies: int = 10
    # Maximum number of reasoning steps to encode.
    max_reasoning_steps: int = 15

    # --- Conditioning Injection ---
    # How graph conditioning is injected into the diffusion model:
    #   'cross_attention' - separate encoder, cross-attention in the transformer
    #   'ada_ln'          - adaptive layer norm; conditioning modulates scale/shift
    #   'concat'          - conditioning is concatenated to the input sequence
    conditioning_method: str = "cross_attention"

    # --- Confidence Embedding ---
    # Whether to embed confidence scores as part of the conditioning.
    embed_confidence: bool = True

    # --- Temporal Embedding ---
    # Whether to embed temporal context (time-based relationships).
    embed_temporal: bool = True
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
@dataclass
class TokenizerConfig:
    """Settings for the AAM sentence-level tokenizer.

    The tokenizer is described as built for sentence arrangement rather
    than plain subword modelling: sentences are the primary generation
    unit, subword BPE handles the words inside a sentence, and dedicated
    special tokens mark graph structure (evidence, anomaly, ...).
    """

    # --- BPE ---
    # Subword BPE vocabulary size (within the total vocab_size).
    bpe_vocab_size: int = 28_000

    # --- Sentence-Level ---
    # Maximum number of sentences in one generation.
    max_sentences: int = 32
    # Token marking sentence boundaries.
    sentence_boundary_token: str = "<sent>"

    # --- Special Tokens ---
    pad_token: str = "<pad>"
    bos_token: str = "<bos>"
    eos_token: str = "<eos>"
    mask_token: str = "<mask>"
    noise_token: str = "<noise>"

    # --- Graph-Structure Tokens ---
    evidence_token: str = "<evidence>"
    anomaly_token: str = "<anomaly>"
    confidence_token: str = "<confidence>"
    reasoning_token: str = "<reasoning>"
    composition_token: str = "<composition>"
    temporal_token: str = "<temporal>"

    # --- Training ---
    # Minimum frequency for BPE merge operations.
    min_frequency: int = 2
    # BPE dropout rate (0 = no dropout; acts as regularization during training).
    dropout_rate: float = 0.0
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
@dataclass
class TrainingConfig:
    """Hyperparameters and operational settings for a training run.

    Groups optimizer, learning-rate schedule, batching, regularization,
    mixed precision, checkpointing, EMA, data location and logging options.
    """

    # --- Optimizer ---
    # Peak learning rate.
    learning_rate: float = 1e-4
    # Weight decay for AdamW.
    weight_decay: float = 0.01
    # Adam beta1.
    adam_beta1: float = 0.9
    # Adam beta2.
    adam_beta2: float = 0.999
    # Adam epsilon.
    adam_eps: float = 1e-8

    # --- Learning Rate Schedule ---
    # LR schedule: 'cosine', 'linear', or 'constant'.
    lr_schedule: str = "cosine"
    # Number of warmup steps.
    warmup_steps: int = 2000

    # --- Training ---
    # Training batch size (per GPU).
    batch_size: int = 32
    # Gradient accumulation steps (effective batch = batch_size * this).
    gradient_accumulation_steps: int = 4
    # Maximum training steps.
    max_steps: int = 500_000
    # Maximum training epochs.
    max_epochs: int = 100

    # --- Regularization ---
    # Training dropout rate.
    dropout: float = 0.1
    # Gradient clipping max norm.
    grad_clip_norm: float = 1.0

    # --- Mixed Precision ---
    # Whether to use Automatic Mixed Precision (fp16/bf16).
    use_amp: bool = True
    # AMP data type: 'fp16' or 'bf16'.
    amp_dtype: str = "bf16"

    # --- Checkpointing ---
    # Save a checkpoint every N steps.
    save_every_steps: int = 5000
    # Evaluate every N steps.
    eval_every_steps: int = 1000
    # Keep only the last N checkpoints.
    keep_last_n_checkpoints: int = 3

    # --- EMA ---
    # Whether to use an Exponential Moving Average for inference weights.
    use_ema: bool = True
    # EMA decay rate.
    ema_decay: float = 0.9999

    # --- Data ---
    # Path to training data (JSONL format).
    train_data_path: str = ""
    # Path to validation data (JSONL format).
    val_data_path: str = ""
    # Number of data loading workers.
    num_workers: int = 4

    # --- Logging ---
    # Log training metrics every N steps.
    log_every_steps: int = 100
    # Weights & Biases project name.
    wandb_project: str = "aam-diffusion-llm"
    # Weights & Biases run name (auto-generated if empty).
    wandb_run_name: str = ""
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
@dataclass
class InferenceConfig:
    """Generation-time settings: denoising budget, decoding, output limits."""

    # Number of denoising steps (more = better quality, slower).
    n_steps: int = 50
    # Sampling temperature (1.0 = standard, <1 = more deterministic).
    temperature: float = 1.0
    # Top-k sampling for token decoding.
    top_k: int = 50
    # Nucleus sampling threshold.
    top_p: float = 0.95
    # Penalty for repeating tokens.
    repetition_penalty: float = 1.2
    # Maximum number of sentences in the output.
    max_output_sentences: int = 16
    # Output language: 'id' (Indonesian) or 'en' (English).
    language: str = "id"
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
@dataclass
class AamDiffusionConfig:
    """Master configuration for the AAM Diffusion LLM.

    Bundles every sub-configuration (model, diffusion, graph encoder,
    tokenizer, training, inference) together with a few run-level
    settings, and provides JSON round-tripping plus a text summary.
    """

    model: ModelConfig = field(default_factory=ModelConfig)
    diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
    graph_encoder: GraphEncoderConfig = field(default_factory=GraphEncoderConfig)
    tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    inference: InferenceConfig = field(default_factory=InferenceConfig)

    # --- Meta ---
    # Model name for saving/loading.
    model_name: str = "aam-diffusion-v0.1"
    # Base output directory.
    output_dir: str = "./output"
    # Random seed for reproducibility.
    seed: int = 42

    # --- AAM Philosophy ---
    # Source of the "mind" that conditions this "body"; intended to always
    # be 'rsvs_graph' (the model is meant to say only what the graph holds).
    aam_mind_source: str = "rsvs_graph"
    # Type of the "body"; intended to always be 'specialized_diffusion'
    # (a sentence arranger over graph evidence, not a general LLM).
    aam_body_type: str = "specialized_diffusion"

    def to_dict(self) -> dict:
        """Return the configuration, including sub-configs, as nested dicts."""
        return asdict(self)

    def to_json(self, path: str | Path) -> None:
        """Write the configuration to ``path`` as indented UTF-8 JSON.

        Missing parent directories are created first.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

    @classmethod
    def from_json(cls, path: str | Path) -> AamDiffusionConfig:
        """Load a configuration from a JSON file.

        Sections or top-level keys absent from the file fall back to the
        dataclass defaults, so the defaults are defined in one place only.
        A section stored as JSON ``null`` is treated like a missing one.

        Raises:
            TypeError: if a section contains a key its config class does
                not accept (unchanged from the previous behavior).
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        section_types = {
            "model": ModelConfig,
            "diffusion": DiffusionConfig,
            "graph_encoder": GraphEncoderConfig,
            "tokenizer": TokenizerConfig,
            "training": TrainingConfig,
            "inference": InferenceConfig,
        }
        kwargs = {}
        for name, section_cls in section_types.items():
            section = data.get(name)
            # Missing or null section -> build it entirely from defaults.
            kwargs[name] = section_cls(**({} if section is None else section))

        # Only forward scalars that are actually present; everything else
        # takes the field default declared on the dataclass.
        for name in ("model_name", "output_dir", "seed",
                     "aam_mind_source", "aam_body_type"):
            if name in data:
                kwargs[name] = data[name]

        return cls(**kwargs)

    def summary(self) -> str:
        """Return a multi-line, human-readable summary of the configuration."""
        rule = "=" * 60
        model = self.model
        diffusion = self.diffusion
        graph = self.graph_encoder
        training = self.training
        lines = [
            rule,
            f" AAM Diffusion LLM Configuration: {self.model_name}",
            rule,
            "",
            " Model Architecture:",
            f" d_model={model.d_model}, n_layers={model.n_layers}, "
            f"n_heads={model.n_heads}",
            f" d_ff={model.d_ff}, vocab_size={model.vocab_size}",
            f" max_seq_len={model.max_seq_len}",
            f" Estimated params: {model.estimate_params()}",
            "",
            " Diffusion Process:",
            f" Timesteps (train)={diffusion.n_timesteps}",
            f" Timesteps (inference)={diffusion.n_inference_steps}",
            f" Schedule={diffusion.schedule_type}",
            f" Prediction={diffusion.prediction_type}",
            f" Sampling={diffusion.sampling_method}",
            "",
            " Graph Encoder:",
            f" d_graph={graph.d_graph}",
            f" n_layers={graph.n_graph_layers}",
            f" Conditioning={graph.conditioning_method}",
            f" Max evidence nodes={graph.max_evidence_nodes}",
            "",
            " Training:",
            f" LR={training.learning_rate}",
            f" Batch={training.batch_size} x {training.gradient_accumulation_steps} accum",
            f" Max steps={training.max_steps}",
            f" AMP={training.use_amp} ({training.amp_dtype})",
            "",
            " AAM Philosophy:",
            f" Mind = {self.aam_mind_source} (RSVS Knowledge Graph)",
            f" Body = {self.aam_body_type} (This Model)",
            " Identity = 1 Mind + 1 Body (NOT rented LLM)",
            "",
            rule,
        ]
        return "\n".join(lines)
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def get_default_config(
    model_size: str = "base",
) -> AamDiffusionConfig:
    """Return the preset configuration for a named model size.

    The size name is validated first and only the requested preset is
    instantiated, so an invalid name fails fast and a valid one does not
    pay for building the three presets it does not need.

    Args:
        model_size: One of 'tiny', 'small', 'base', 'medium'. The
            parameter counts below are the original author's estimates:
            - tiny: ~25M params (for quick testing)
            - small: ~70M params (for development)
            - base: ~170M params (recommended for training)
            - medium: ~300M params (for final training)

    Returns:
        A freshly constructed AamDiffusionConfig for the requested size,
        named ``aam-diffusion-<model_size>``.

    Raises:
        ValueError: If ``model_size`` is not one of the known presets.
    """
    # Plain-data description of every preset; nothing is instantiated
    # until the requested size is known to be valid.
    presets = {
        "tiny": {
            "model": dict(
                d_model=256,
                n_layers=4,
                n_heads=4,
                d_ff=1024,
                vocab_size=16000,
                max_seq_len=256,
            ),
            "graph_encoder": dict(d_graph=256, n_graph_layers=2, n_graph_heads=4),
            "diffusion": dict(n_timesteps=500, n_inference_steps=20),
            "training": dict(
                batch_size=16,
                learning_rate=3e-4,
                warmup_steps=500,
                max_steps=100000,
            ),
        },
        "small": {
            "model": dict(
                d_model=512,
                n_layers=8,
                n_heads=8,
                d_ff=2048,
                vocab_size=24000,
                max_seq_len=384,
            ),
            "graph_encoder": dict(d_graph=384, n_graph_layers=4, n_graph_heads=8),
            "diffusion": dict(n_timesteps=1000, n_inference_steps=30),
            "training": dict(
                batch_size=24,
                learning_rate=2e-4,
                warmup_steps=1000,
                max_steps=200000,
            ),
        },
        "base": {
            "model": dict(
                d_model=768,
                n_layers=12,
                n_heads=12,
                d_ff=3072,
                vocab_size=32000,
                max_seq_len=512,
            ),
            "graph_encoder": dict(d_graph=512, n_graph_layers=4, n_graph_heads=8),
            "diffusion": dict(n_timesteps=1000, n_inference_steps=50),
            "training": dict(
                batch_size=32,
                learning_rate=1e-4,
                warmup_steps=2000,
                max_steps=500000,
            ),
        },
        "medium": {
            "model": dict(
                d_model=1024,
                n_layers=12,
                n_heads=16,
                d_ff=4096,
                vocab_size=32000,
                max_seq_len=768,
            ),
            "graph_encoder": dict(d_graph=768, n_graph_layers=6, n_graph_heads=12),
            "diffusion": dict(n_timesteps=1000, n_inference_steps=50),
            "training": dict(
                batch_size=16,
                learning_rate=5e-5,
                warmup_steps=5000,
                max_steps=1000000,
            ),
        },
    }

    if model_size not in presets:
        raise ValueError(
            f"Unknown model_size '{model_size}'. "
            f"Choose from: {list(presets)}"
        )

    preset = presets[model_size]
    return AamDiffusionConfig(
        model=ModelConfig(**preset["model"]),
        graph_encoder=GraphEncoderConfig(**preset["graph_encoder"]),
        diffusion=DiffusionConfig(**preset["diffusion"]),
        training=TrainingConfig(**preset["training"]),
        model_name=f"aam-diffusion-{model_size}",
    )
|
diffusion_llm/data/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data pipeline module for AAM Diffusion LLM."""
|
| 2 |
+
|
| 3 |
+
from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator
|
| 4 |
+
from diffusion_llm.data.data_pipeline import DataPipeline
|
| 5 |
+
|
| 6 |
+
__all__ = ["SyntheticDataGenerator", "DataPipeline"]
|
diffusion_llm/data/data_pipeline.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AAM Diffusion LLM — Data Pipeline
|
| 3 |
+
|
| 4 |
+
Orchestrates data preparation: from raw graph data and narratives
|
| 5 |
+
to tokenized, batched training data.
|
| 6 |
+
|
| 7 |
+
The pipeline handles:
|
| 8 |
+
1. Loading raw graph→narrative pairs
|
| 9 |
+
2. Generating synthetic data if real data isn't available
|
| 10 |
+
3. Tokenizing all data
|
| 11 |
+
4. Creating train/val splits
|
| 12 |
+
5. Building DataLoaders
|
| 13 |
+
|
| 14 |
+
Analogi: Seperti proses persiapan sebelum Jin Soun berlatih —
|
| 15 |
+
mengumpulkan semua kasus, mengorganisirnya, dan menyiapkan
|
| 16 |
+
data latihan yang terstruktur.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import logging
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Optional
|
| 24 |
+
|
| 25 |
+
from torch.utils.data import DataLoader
|
| 26 |
+
|
| 27 |
+
from diffusion_llm.config.model_config import AamDiffusionConfig
|
| 28 |
+
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
|
| 29 |
+
from diffusion_llm.training.dataset import GraphNarrativeDataset, collate_fn
|
| 30 |
+
from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class DataPipeline:
    """Data preparation pipeline for AAM Diffusion LLM training.

    Orchestrates the entire data preparation process:
    1. Check for existing data
    2. Generate synthetic data if needed
    3. Train tokenizer on the data
    4. Create datasets and dataloaders

    Usage:
        pipeline = DataPipeline(config)
        tokenizer, train_loader, val_loader = pipeline.prepare()
    """

    def __init__(self, config: AamDiffusionConfig):
        """Store the config and make sure ``<output_dir>/data`` exists.

        Args:
            config: Full model/training configuration; ``config.output_dir``
                determines where generated data and the tokenizer are saved.
        """
        self.config = config
        self.output_dir = Path(config.output_dir) / "data"
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def prepare(
        self,
        tokenizer: Optional[AamTokenizer] = None,
        force_regenerate: bool = False,
    ) -> tuple[AamTokenizer, DataLoader, Optional[DataLoader]]:
        """Prepare all data for training.

        Args:
            tokenizer: Optional pre-trained tokenizer. If it is missing or
                reports ``is_trained`` as false, a new tokenizer is trained
                on the training file and saved under the output directory.
            force_regenerate: Whether to regenerate synthetic data even if
                a training file already exists.

        Returns:
            Tuple of (tokenizer, train_loader, val_loader). ``val_loader``
            is None when there is no validation file or the validation
            dataset is empty.
        """
        training_cfg = self.config.training
        train_path = (
            Path(training_cfg.train_data_path)
            if training_cfg.train_data_path
            else None
        )
        val_path = (
            Path(training_cfg.val_data_path)
            if training_cfg.val_data_path
            else None
        )

        # Step 1: Generate synthetic data if no real data
        if not train_path or not train_path.exists() or force_regenerate:
            logger.info("Generating synthetic training data...")
            train_path, val_path = SyntheticDataGenerator.generate_training_split(
                output_dir=self.output_dir,
                n_train=10000,
                n_val=500,
                language=self.config.inference.language,
                seed=self.config.seed,
            )

        # Step 2: Train tokenizer if not provided
        if tokenizer is None or not tokenizer.is_trained:
            logger.info("Training tokenizer...")
            tokenizer = AamTokenizer()
            # Read training texts for tokenizer training
            texts = self._read_texts(train_path)
            tokenizer.train(texts, vocab_size=self.config.tokenizer.bpe_vocab_size)
            tokenizer.save(self.output_dir / "tokenizer.json")
            logger.info(
                "Tokenizer trained and saved. Vocab size: %d",
                tokenizer.vocab_size,
            )

        # Step 3: Create datasets (train and val share every limit)
        logger.info("Creating datasets...")
        graph_cfg = self.config.graph_encoder
        dataset_kwargs = dict(
            tokenizer=tokenizer,
            max_seq_len=self.config.model.max_seq_len,
            max_evidence=graph_cfg.max_evidence_nodes,
            max_anomalies=graph_cfg.max_anomalies,
            max_reasoning=graph_cfg.max_reasoning_steps,
        )
        train_dataset = GraphNarrativeDataset(data_path=train_path, **dataset_kwargs)

        val_dataset = None
        if val_path and val_path.exists():
            val_dataset = GraphNarrativeDataset(
                data_path=val_path,
                augment=False,  # No augmentation for validation
                **dataset_kwargs,
            )

        # Step 4: Create dataloaders
        train_loader = self._make_loader(train_dataset, shuffle=True)
        val_loader = None
        if val_dataset:
            val_loader = self._make_loader(val_dataset, shuffle=False)

        logger.info(
            "Data pipeline ready: %d training examples, %s validation examples",
            len(train_dataset),
            len(val_dataset) if val_dataset else 0,
        )

        return tokenizer, train_loader, val_loader

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using the training batch settings."""
        return DataLoader(
            dataset,
            batch_size=self.config.training.batch_size,
            shuffle=shuffle,
            num_workers=self.config.training.num_workers,
            collate_fn=collate_fn,
            pin_memory=True,
        )

    def _read_texts(self, path: Path) -> list[str]:
        """Read narrative texts from JSONL file for tokenizer training.

        Each non-blank line is parsed as JSON. From every JSON object the
        non-empty string values of ``narrative`` and ``trigger`` are
        collected, followed by the string items of ``evidence_nodes``,
        ``anomalies`` and ``reasoning_steps`` (in that order).

        Lines that are not valid JSON, or that decode to something other
        than an object, are skipped; the number of skipped lines is logged
        as a warning instead of being dropped silently. Non-string values
        and ``null`` list fields are ignored rather than raising.

        Args:
            path: Path to JSONL data file.

        Returns:
            List of texts; empty if the file does not exist.
        """
        import json

        texts: list[str] = []
        if not path.exists():
            return texts

        skipped = 0
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    skipped += 1
                    continue
                # A line such as `[1, 2]` or `"text"` is valid JSON but has
                # no fields to read; treat it like a malformed record.
                if not isinstance(data, dict):
                    skipped += 1
                    continue

                # Collect both narratives and evidence for richer tokenizer
                for key in ("narrative", "trigger"):
                    value = data.get(key)
                    if value and isinstance(value, str):
                        texts.append(value)
                for key in ("evidence_nodes", "anomalies", "reasoning_steps"):
                    items = data.get(key)
                    if isinstance(items, list):
                        texts.extend(
                            item for item in items if isinstance(item, str)
                        )

        if skipped:
            logger.warning(
                "Skipped %d malformed line(s) while reading %s", skipped, path
            )

        return texts
|
diffusion_llm/data/synthetic_generator.py
ADDED
|
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AAM Diffusion LLM — Synthetic Data Generator
|
| 3 |
+
|
| 4 |
+
Generates synthetic Graph→Narrative training pairs for
|
| 5 |
+
pre-training the diffusion model before real data is available.
|
| 6 |
+
|
| 7 |
+
The synthetic data follows the AAM pattern:
|
| 8 |
+
- Graph conditioning: evidence, compositions, anomalies, reasoning
|
| 9 |
+
- Target narrative: natural language text that represents the graph data
|
| 10 |
+
|
| 11 |
+
This is essential because:
|
| 12 |
+
1. We need training data before the model can be used
|
| 13 |
+
2. The data must follow the Graph→Narrative format specifically
|
| 14 |
+
3. Synthetic data helps bootstrap the model's ability to
|
| 15 |
+
arrange sentences from structured evidence
|
| 16 |
+
|
| 17 |
+
Analogi: Seperti Jin Soun berlatih dengan kasus-kasus fiktif
|
| 18 |
+
sebelum menghadapi kasus nyata — data sintetis memberikan
|
| 19 |
+
"latihan dasar" sebelum data asli tersedia.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import json
|
| 25 |
+
import logging
|
| 26 |
+
import random
|
| 27 |
+
from dataclasses import dataclass, field
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
from typing import Optional
|
| 30 |
+
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# --- Templates for synthetic data generation ---
|
| 35 |
+
|
| 36 |
+
# Indonesian narrative templates
|
| 37 |
+
ID_TEMPLATES = {
    # Top-level sentence frames; the {..._summary} slots are filled from
    # the three template families below.
    "analysis": [
        "Berdasarkan analisis terhadap {trigger}: {evidence_summary}. {reasoning_summary}. Tingkat keyakinan: {confidence_pct}.",
        "Analisis menunjukkan bahwa {trigger} terkait dengan {evidence_summary}. {anomaly_summary}. Kesimpulan: {reasoning_summary}.",
        "Dari data yang tersedia, {trigger} memiliki koneksi ke {evidence_summary}. {reasoning_summary}. Confidence: {confidence_pct}.",
        "Hasil investigasi: {trigger}. Bukti: {evidence_summary}. {anomaly_summary}. {reasoning_summary}.",
        "Temuan: {trigger} berkorelasi dengan {evidence_summary}. Catatan: {anomaly_summary}. Analisis: {reasoning_summary}.",
    ],
    "evidence_summary": [
        "bukti menunjukkan {nodes}",
        "data dari {nodes} mengindikasikan",
        "{nodes} menjadi kunci",
        "informasi dari {nodes} mengarah ke",
        "sumber {nodes} mengkonfirmasi",
    ],
    "anomaly_summary": [
        "Anomali terdeteksi: {anomalies}",
        "Perhatian: {anomalies}",
        "Pola tidak lazim: {anomalies}",
        "Ketidaksesuaian ditemukan: {anomalies}",
        "Terdapat kejanggalan: {anomalies}",
    ],
    "reasoning_summary": [
        "Langkah penalaran: {steps}",
        "Proses deduksi: {steps}",
        "Analisis bertahap: {steps}",
        "Penelusuran logika: {steps}",
        "Rantai penalaran: {steps}",
    ],
}

# English narrative templates (same keys and slot names as ID_TEMPLATES)
EN_TEMPLATES = {
    "analysis": [
        "Based on analysis of {trigger}: {evidence_summary}. {reasoning_summary}. Confidence: {confidence_pct}.",
        "Analysis indicates that {trigger} is related to {evidence_summary}. {anomaly_summary}. Conclusion: {reasoning_summary}.",
        "From available data, {trigger} has connections to {evidence_summary}. {reasoning_summary}. Confidence level: {confidence_pct}.",
        "Investigation results: {trigger}. Evidence: {evidence_summary}. {anomaly_summary}. {reasoning_summary}.",
        "Findings: {trigger} correlates with {evidence_summary}. Note: {anomaly_summary}. Analysis: {reasoning_summary}.",
    ],
    "evidence_summary": [
        "evidence shows {nodes}",
        "data from {nodes} indicates",
        "{nodes} are key factors",
        "information from {nodes} points to",
        "sources {nodes} confirm",
    ],
    "anomaly_summary": [
        "Anomaly detected: {anomalies}",
        "Note: {anomalies}",
        "Unusual pattern: {anomalies}",
        "Inconsistency found: {anomalies}",
        "Irregularity observed: {anomalies}",
    ],
    "reasoning_summary": [
        "Reasoning steps: {steps}",
        "Deductive process: {steps}",
        "Step-by-step analysis: {steps}",
        "Logical trace: {steps}",
        "Reasoning chain: {steps}",
    ],
}

# Sample graph data for synthetic generation.
# Evidence labels are the same in both languages; the comprehension still
# gives each language its own list object.
SAMPLE_EVIDENCE_NODES = {
    lang_code: [
        "Hefei",
        "Diancang Five Swords",
        "Ju Jangmok",
        "Snow Plum Pill",
        "Gyeryong Merchant Guild",
        "Simhyeon Pavilion",
        "Martial Alliance",
        "Gu Ilmu",
        "Jang Hangi",
        "Blood Serpent Dance Step",
        "taeul_sect",
        "dark_faction",
        "hefei_branch",
    ]
    for lang_code in ("id", "en")
}

SAMPLE_TRIGGERS = {
    "id": [
        "Siapa yang mencuri Snow Plum Pill?",
        "Analisis pergerakan Diancang Five Swords",
        "Hubungan antara Ju Jangmok dan pencurian",
        "Anomali dalam laporan Hefei",
        "Investigasi inside job di Diancang",
        "Pola konsumsi Snow Plum Pill",
        "Cross-reference kejadian di Hefei",
        "Evaluasi kepercayaan sumber informasi",
        "Prediksi tindakan berikutnya tersangka",
        "Pattern completion dari bukti terpisah",
    ],
    "en": [
        "Who stole the Snow Plum Pill?",
        "Analysis of Diancang Five Swords movements",
        "Connection between Ju Jangmok and the theft",
        "Anomalies in the Hefei reports",
        "Investigation of inside job at Diancang",
        "Pattern of Snow Plum Pill consumption",
        "Cross-referencing events in Hefei",
        "Source trustworthiness evaluation",
        "Predicting next suspect actions",
        "Pattern completion from disparate evidence",
    ],
}

SAMPLE_ANOMALIES = {
    "id": [
        "Tidak ada konsumsi pil baru di pasar gelap",
        "Pencuri menghilang tanpa jejak",
        "Success rate pair lebih tinggi dari biasanya",
        "Misi di-assign dari dalam Diancang sendiri",
        "Ju Jangmok menghilang hari yang sama dengan pencurian",
        "Tidak ada pencuri baru setelah Ju Jangmok menghilang",
    ],
    "en": [
        "No new pill consumption in black market",
        "Thief disappeared without a trace",
        "Pair success rate unusually high",
        "Mission assigned from within Diancang itself",
        "Ju Jangmok disappeared same day as theft",
        "No new thief appeared after Ju Jangmok vanished",
    ],
}

SAMPLE_REASONING_STEPS = {
    "id": [
        "Recall: Ingat semua laporan terkait Hefei",
        "Cross-reference: Bandingkan tanggal kejadian",
        "Filter: Eliminasi yang tidak relevan",
        "Anomaly: Deteksi ketidaksesuaian pola",
        "Pattern: Hubungkan fragmen terpisah",
        "Compose: Susun kesimpulan dari bukti",
        "Predict: Perkirakan tindakan berikutnya",
        "Verify: Cek konsistensi kesimpulan",
    ],
    "en": [
        "Recall: Remember all reports related to Hefei",
        "Cross-reference: Compare event dates",
        "Filter: Eliminate irrelevant data",
        "Anomaly: Detect pattern inconsistency",
        "Pattern: Connect disparate fragments",
        "Compose: Assemble conclusion from evidence",
        "Predict: Estimate next actions",
        "Verify: Check conclusion consistency",
    ],
}
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class SyntheticDataGenerator:
    """Generate synthetic Graph→Narrative training pairs.

    This generator creates training data that follows the AAM
    pattern: structured graph conditioning → natural language narrative.

    The generated data covers:
    - Various trigger types (questions, analysis requests)
    - A random number of evidence nodes per example (between
      ``min_evidence`` and ``max_evidence``, capped by the pool size)
    - Optional anomalies (1-3) and reasoning chains (2-6 steps)
    - Random per-node confidence and source-trust scores
    - Both Indonesian and English

    Randomness comes from a private ``random.Random`` instance seeded in
    ``__init__``, so creating or using a generator never reseeds or
    consumes the process-wide ``random`` module state, and two generators
    do not interfere with each other.

    Usage:
        generator = SyntheticDataGenerator()
        examples = generator.generate(n=1000, language="id")
        generator.save(examples, "training_data.jsonl")
    """

    def __init__(
        self,
        seed: int = 42,
        language: str = "id",
    ):
        """Initialize the synthetic data generator.

        Args:
            seed: Random seed for reproducibility.
            language: Default language for generation.
        """
        self.seed = seed
        self.language = language
        # Dedicated RNG: same draw sequence as seeding the global module
        # with `seed`, but without touching global state.
        self._rng = random.Random(seed)

    def generate(
        self,
        n: int = 1000,
        language: Optional[str] = None,
        min_evidence: int = 2,
        max_evidence: int = 15,
        anomaly_probability: float = 0.6,
        reasoning_probability: float = 0.8,
    ) -> list[dict]:
        """Generate synthetic training examples.

        Args:
            n: Number of examples to generate.
            language: Language override; falls back to the instance
                default. ``"id"`` selects the Indonesian templates, any
                other value the English ones.
            min_evidence: Minimum evidence nodes per example.
            max_evidence: Maximum evidence nodes per example (the actual
                count never exceeds the size of the evidence pool).
            anomaly_probability: Probability of including anomalies.
            reasoning_probability: Probability of including reasoning steps.

        Returns:
            List of training example dictionaries.
        """
        rng = self._rng
        lang = language or self.language
        templates = ID_TEMPLATES if lang == "id" else EN_TEMPLATES
        evidence_pool = SAMPLE_EVIDENCE_NODES.get(lang, SAMPLE_EVIDENCE_NODES["en"])
        trigger_pool = SAMPLE_TRIGGERS.get(lang, SAMPLE_TRIGGERS["en"])
        anomaly_pool = SAMPLE_ANOMALIES.get(lang, SAMPLE_ANOMALIES["en"])
        reasoning_pool = SAMPLE_REASONING_STEPS.get(lang, SAMPLE_REASONING_STEPS["en"])

        examples = []
        for _ in range(n):
            # Random trigger
            trigger = rng.choice(trigger_pool)

            # Random evidence nodes
            n_evidence = rng.randint(min_evidence, max_evidence)
            evidence = rng.sample(evidence_pool, min(n_evidence, len(evidence_pool)))

            # Random confidence map
            confidence_map = {
                node: round(rng.uniform(0.3, 1.0), 2)
                for node in evidence
            }

            # Random anomalies
            anomalies = []
            if rng.random() < anomaly_probability:
                n_anomalies = rng.randint(1, 3)
                anomalies = rng.sample(anomaly_pool, min(n_anomalies, len(anomaly_pool)))

            # Random reasoning steps
            reasoning_steps = []
            if rng.random() < reasoning_probability:
                n_steps = rng.randint(2, 6)
                reasoning_steps = rng.sample(reasoning_pool, min(n_steps, len(reasoning_pool)))

            # Source trust
            source_trust = round(rng.uniform(0.5, 1.0), 2)

            # Generate narrative from template
            narrative = self._generate_narrative(
                trigger=trigger,
                evidence=evidence,
                anomalies=anomalies,
                reasoning_steps=reasoning_steps,
                confidence_map=confidence_map,
                templates=templates,
                lang=lang,
            )

            examples.append(
                {
                    "narrative": narrative,
                    "trigger": trigger,
                    "evidence_nodes": evidence,
                    "compositions": [],
                    "confidence_map": confidence_map,
                    "anomalies": anomalies,
                    "reasoning_steps": reasoning_steps,
                    "source_trust": source_trust,
                    "language": lang,
                    "source": "synthetic",
                }
            )

        logger.info("Generated %d synthetic examples (language=%s)", n, lang)
        return examples

    def _generate_narrative(
        self,
        trigger: str,
        evidence: list[str],
        anomalies: list[str],
        reasoning_steps: list[str],
        confidence_map: dict[str, float],
        templates: dict,
        lang: str,
    ) -> str:
        """Generate a narrative from templates.

        Only the first 5 evidence labels, 3 anomalies and 4 reasoning
        steps are mentioned in the text.

        Args:
            trigger: Trigger text.
            evidence: Evidence node labels.
            anomalies: Anomaly descriptions.
            reasoning_steps: Reasoning step descriptions.
            confidence_map: Confidence scores; their mean is rendered as a
                whole-number percentage.
            templates: Template dictionary with the keys ``analysis``,
                ``evidence_summary``, ``anomaly_summary`` and
                ``reasoning_summary``.
            lang: Language code (currently unused; the language is
                determined by ``templates``).

        Returns:
            Generated narrative string.
        """
        rng = self._rng

        # Build narrative parts
        evidence_str = ", ".join(evidence[:5])
        avg_confidence = sum(confidence_map.values()) / max(len(confidence_map), 1)

        # Fill templates
        evidence_summary = rng.choice(templates["evidence_summary"]).format(
            nodes=evidence_str
        )

        # NOTE: when there are no anomalies / reasoning steps the summary
        # stays empty, so an "analysis" template that references that slot
        # is rendered with an empty slot.
        anomaly_summary = ""
        if anomalies:
            anomaly_summary = rng.choice(templates["anomaly_summary"]).format(
                anomalies="; ".join(anomalies[:3])
            )

        reasoning_summary = ""
        if reasoning_steps:
            reasoning_summary = rng.choice(templates["reasoning_summary"]).format(
                steps="; ".join(reasoning_steps[:4])
            )

        # Main narrative
        narrative = rng.choice(templates["analysis"]).format(
            trigger=trigger,
            evidence_summary=evidence_summary,
            anomaly_summary=anomaly_summary,
            reasoning_summary=reasoning_summary,
            confidence_pct=f"{avg_confidence:.0%}",
        )

        return narrative

    def save(
        self,
        examples: list[dict],
        path: str | Path,
    ) -> None:
        """Save examples to JSONL file.

        Parent directories are created if needed; an existing file is
        overwritten. Non-ASCII characters are written as-is (UTF-8).

        Args:
            examples: List of example dictionaries.
            path: Output file path.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        with open(path, "w", encoding="utf-8") as f:
            for example in examples:
                f.write(json.dumps(example, ensure_ascii=False) + "\n")

        logger.info("Saved %d examples to %s", len(examples), path)

    @classmethod
    def generate_training_split(
        cls,
        output_dir: str | Path,
        n_train: int = 10000,
        n_val: int = 500,
        language: str = "id",
        seed: int = 42,
    ) -> tuple[Path, Path]:
        """Generate and save train/val splits.

        Writes ``train.jsonl`` and ``val.jsonl`` into ``output_dir``. The
        validation set is drawn from a separate generator seeded with
        ``seed + 1``.

        Args:
            output_dir: Output directory.
            n_train: Number of training examples.
            n_val: Number of validation examples.
            language: Language for generation.
            seed: Random seed.

        Returns:
            Tuple of (train_path, val_path).
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        generator = cls(seed=seed, language=language)

        # Generate training data
        train_examples = generator.generate(n=n_train, language=language)
        train_path = output_dir / "train.jsonl"
        generator.save(train_examples, train_path)

        # Generate validation data (different seed)
        val_generator = cls(seed=seed + 1, language=language)
        val_examples = val_generator.generate(n=n_val, language=language)
        val_path = output_dir / "val.jsonl"
        val_generator.save(val_examples, val_path)

        logger.info(
            "Generated training split: %d train, %d val",
            n_train, n_val,
        )

        return train_path, val_path
|
diffusion_llm/inference/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inference module for AAM Diffusion LLM."""
|
| 2 |
+
|
| 3 |
+
from diffusion_llm.inference.generator import AamGenerator
|
| 4 |
+
|
| 5 |
+
__all__ = ["AamGenerator"]
|
diffusion_llm/inference/generator.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AAM Diffusion LLM — Inference Generator
|
| 3 |
+
|
| 4 |
+
Generates natural language narratives from graph conditioning
|
| 5 |
+
using the trained diffusion model.
|
| 6 |
+
|
| 7 |
+
The generation process:
|
| 8 |
+
1. Encode graph conditioning (evidence, anomalies, reasoning)
|
| 9 |
+
2. Start from pure noise in the latent space
|
| 10 |
+
3. Iteratively denoise for N steps
|
| 11 |
+
4. Convert denoised embeddings to token IDs
|
| 12 |
+
5. Detokenize to natural language text
|
| 13 |
+
|
| 14 |
+
Analogi: Seperti Jin Soun akhirnya "berbicara" — dari
|
| 15 |
+
pikiran yang kabur (noise) menjadi kata-kata yang jelas
|
| 16 |
+
(denoised narrative). Setiap langkah denoising = satu
|
| 17 |
+
langkah lebih dekat ke koherensi.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import logging
|
| 23 |
+
import time
|
| 24 |
+
from dataclasses import dataclass, field
|
| 25 |
+
from typing import Optional
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
from diffusion_llm.config.model_config import AamDiffusionConfig, InferenceConfig
|
| 30 |
+
from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
|
| 31 |
+
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
|
| 32 |
+
|
| 33 |
+
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
class GenerationResult:
    """Outcome of a single generation call.

    Bundles the generated narrative with metadata describing how it was
    produced, so a result can be traced back to its inputs and settings.
    """

    # Generated narrative text.
    narrative: str
    # Generated token IDs.
    token_ids: list[int] = field(default_factory=list)
    # Number of denoising steps used.
    n_diffusion_steps: int = 0
    # Wall-clock generation time, in seconds.
    generation_time_s: float = 0.0
    # Name of the model used.
    model_name: str = ""
    # Evidence nodes that were provided as conditioning.
    evidence_used: list[str] = field(default_factory=list)
    # Overall confidence of the generation.
    confidence: float = 0.0
    # Output language.
    language: str = "id"

    def to_dict(self) -> dict:
        """Serialize to a plain dictionary.

        ``token_ids`` is not included. ``generation_time_s`` and
        ``confidence`` are rounded to three decimal places.
        """
        exported = (
            "narrative",
            "n_diffusion_steps",
            "generation_time_s",
            "model_name",
            "evidence_used",
            "confidence",
            "language",
        )
        payload = {name: getattr(self, name) for name in exported}
        # Re-assigning existing keys keeps their position in the dict.
        for name in ("generation_time_s", "confidence"):
            payload[name] = round(payload[name], 3)
        return payload
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class AamGenerator:
    """Generate narratives from graph conditioning using the trained model.

    This is the main inference interface. It takes graph-structured
    data (from the RSVS Knowledge Graph) and produces natural
    language narratives through the diffusion denoising process.

    Usage:
        # Load model and tokenizer
        config = AamDiffusionConfig.from_json("config.json")
        model = AamDiffusionModel.load("best.pt")
        tokenizer = AamTokenizer.load("tokenizer.json")

        # Create generator
        generator = AamGenerator(model, tokenizer, config)

        # Generate narrative
        result = generator.generate(
            trigger="Siapa yang mencuri Snow Plum Pill?",
            evidence_nodes=["hefei", "diancang", "ju_jangmok"],
            anomalies=["no external pill consumption"],
            reasoning_steps=["Diancang pair was in Hefei before theft"],
        )
        print(result.narrative)

    Args:
        model: Trained AamDiffusionModel.
        tokenizer: Trained AamTokenizer.
        config: AamDiffusionConfig with inference settings.
    """

    # Token length every conditioning item is padded to before encoding.
    _COND_TOKEN_LEN = 32

    def __init__(
        self,
        model: AamDiffusionModel,
        tokenizer: AamTokenizer,
        config: AamDiffusionConfig,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.inference_config = config.inference

        # Run on whatever device the model's first parameter lives on.
        self.device = next(model.parameters()).device

        # Inference only: switch the model to eval mode.
        self.model.eval()

    def _encode_items(self, texts: list[str], max_items: int) -> list[list[int]]:
        """Tokenize up to ``max_items`` texts into exactly ``max_items`` rows.

        Each text is encoded without special tokens and passed through the
        tokenizer's ``pad_sequence`` with length ``_COND_TOKEN_LEN``. Missing
        rows are filled with all-zero rows of that same length.
        """
        rows = []
        for text in texts[:max_items]:
            ids = self.tokenizer.encode(text, add_special=False)
            rows.append(self.tokenizer.pad_sequence(ids, self._COND_TOKEN_LEN))
        rows.extend(
            [0] * self._COND_TOKEN_LEN for _ in range(max_items - len(rows))
        )
        return rows

    def _ids_tensor(self, rows: list[list[int]]) -> torch.Tensor:
        """Wrap padded ID rows into a long tensor with a leading batch dim of 1."""
        return torch.tensor([rows], dtype=torch.long, device=self.device)

    @torch.no_grad()
    def generate(
        self,
        trigger: str = "",
        evidence_nodes: Optional[list[str]] = None,
        compositions: Optional[list[str]] = None,
        confidence_map: Optional[dict[str, float]] = None,
        anomalies: Optional[list[str]] = None,
        reasoning_steps: Optional[list[str]] = None,
        source_trust: float = 1.0,
        n_steps: Optional[int] = None,
        temperature: Optional[float] = None,
        language: Optional[str] = None,
        max_sentences: Optional[int] = None,
    ) -> GenerationResult:
        """Generate a narrative from graph conditioning.

        Steps:
        1. Tokenize the graph conditioning data
        2. Encode it through the graph encoder
        3. Start from noise and iteratively denoise
        4. Convert the result to token IDs
        5. Decode to text and optionally truncate to ``max_sentences``

        NOTE(review): ``trigger`` and ``compositions`` are accepted but are
        not fed into the conditioning here, and ``language`` is only echoed
        into the result — confirm whether that is intended.

        Args:
            trigger: The trigger question or topic (currently unused).
            evidence_nodes: Evidence node descriptions. Only the first
                ``max_evidence_nodes`` are encoded.
            compositions: Composition descriptions (currently unused).
            confidence_map: Node confidence scores keyed by evidence text;
                nodes without an entry get 0.7.
            anomalies: Anomaly descriptions.
            reasoning_steps: Reasoning step descriptions.
            source_trust: Source trust score.
            n_steps: Override number of denoising steps.
            temperature: Override sampling temperature.
            language: Override output language.
            max_sentences: Maximum sentences in output.

            For the four overrides, any falsy value (``None``, ``0``,
            ``0.0``, ``""``) falls back to the inference config.

        Returns:
            GenerationResult with the narrative and metadata.
        """
        # perf_counter is monotonic, so the elapsed time cannot be skewed
        # by system clock adjustments.
        start_time = time.perf_counter()

        # Use config defaults if not overridden
        n_steps = n_steps or self.inference_config.n_steps
        temperature = temperature or self.inference_config.temperature
        language = language or self.inference_config.language
        max_sentences = max_sentences or self.inference_config.max_output_sentences

        encoder_cfg = self.config.graph_encoder

        # --- Step 1: Tokenize graph conditioning ---
        evidence_ids_tensor = None
        evidence_conf_tensor = None
        anomaly_ids_tensor = None
        anomaly_conf_tensor = None
        reasoning_ids_tensor = None
        reasoning_conf_tensor = None

        if evidence_nodes:
            max_nodes = encoder_cfg.max_evidence_nodes
            scores = confidence_map or {}
            # Real nodes get their mapped confidence (default 0.7);
            # padded slots get 0.0.
            evidence_conf = [
                scores.get(node, 0.7) for node in evidence_nodes[:max_nodes]
            ]
            evidence_conf += [0.0] * (max_nodes - len(evidence_conf))

            evidence_ids_tensor = self._ids_tensor(
                self._encode_items(evidence_nodes, max_nodes)
            )
            evidence_conf_tensor = torch.tensor(
                [evidence_conf], dtype=torch.float32, device=self.device
            )

        if anomalies:
            max_anomalies = encoder_cfg.max_anomalies
            anomaly_ids_tensor = self._ids_tensor(
                self._encode_items(anomalies, max_anomalies)
            )
            # NOTE(review): unlike evidence, padded anomaly slots also get
            # 0.6 here — confirm this matches how training data is built.
            anomaly_conf_tensor = torch.full(
                (1, max_anomalies),
                0.6, dtype=torch.float32, device=self.device,
            )

        if reasoning_steps:
            max_steps = encoder_cfg.max_reasoning_steps
            reasoning_ids_tensor = self._ids_tensor(
                self._encode_items(reasoning_steps, max_steps)
            )
            # NOTE(review): padded reasoning slots also get 0.7 — see above.
            reasoning_conf_tensor = torch.full(
                (1, max_steps),
                0.7, dtype=torch.float32, device=self.device,
            )

        source_trust_tensor = torch.tensor(
            [source_trust], dtype=torch.float32, device=self.device
        )

        # --- Step 2: Encode graph conditioning ---
        # batch_size is passed explicitly, as the training forward pass does,
        # so the encoder knows the batch size even when every conditioning
        # input is None.
        graph_cond = self.model.graph_encoder(
            evidence_ids=evidence_ids_tensor,
            evidence_confidence=evidence_conf_tensor,
            anomaly_ids=anomaly_ids_tensor,
            anomaly_confidence=anomaly_conf_tensor,
            reasoning_ids=reasoning_ids_tensor,
            reasoning_confidence=reasoning_conf_tensor,
            source_trust=source_trust_tensor,
            batch_size=1,
        )

        # --- Step 3: Generate via diffusion denoising ---
        shape = (
            1,
            self.config.model.max_seq_len,
            self.config.model.d_model,
        )

        denoised = self.model.sample(
            graph_cond=graph_cond,
            n_steps=n_steps,
            method=self.config.diffusion.sampling_method,
            shape=shape,
            device=self.device,
        )

        # --- Step 4: Convert to tokens ---
        token_ids = self.model.embeddings_to_tokens(
            denoised, temperature=temperature,
            top_k=self.inference_config.top_k,
        )

        # --- Step 5: Detokenize ---
        token_list = token_ids[0].cpu().tolist()
        narrative = self.tokenizer.decode(token_list, skip_special=True)

        # Truncate to max sentences
        if max_sentences:
            sentences = self.tokenizer._split_sentences(narrative)
            if len(sentences) > max_sentences:
                narrative = ". ".join(sentences[:max_sentences]) + "."

        generation_time = time.perf_counter() - start_time

        # Reported confidence: mean of confidence_map when given,
        # otherwise the source trust score.
        avg_confidence = source_trust
        if confidence_map:
            avg_confidence = sum(confidence_map.values()) / len(confidence_map)

        return GenerationResult(
            narrative=narrative,
            token_ids=token_list,
            n_diffusion_steps=n_steps,
            generation_time_s=generation_time,
            model_name=self.config.model_name,
            evidence_used=evidence_nodes or [],
            confidence=avg_confidence,
            language=language,
        )

    def generate_batch(
        self,
        triggers: list[str],
        evidence_nodes_list: Optional[list[list[str]]] = None,
        anomalies_list: Optional[list[list[str]]] = None,
        **kwargs,
    ) -> list[GenerationResult]:
        """Generate narratives for multiple triggers, one at a time.

        Args:
            triggers: List of trigger questions.
            evidence_nodes_list: One evidence list per trigger, or None/empty
                to run every trigger without evidence.
            anomalies_list: One anomaly list per trigger, or None/empty.
            **kwargs: Additional arguments passed to generate().

        Returns:
            List of GenerationResult objects, in trigger order.

        Raises:
            ValueError: If a non-empty per-trigger list has fewer entries
                than ``triggers``.
        """
        # Fail fast: a short list would otherwise raise IndexError only
        # after earlier (expensive) generations had already run.
        for label, per_trigger in (
            ("evidence_nodes_list", evidence_nodes_list),
            ("anomalies_list", anomalies_list),
        ):
            if per_trigger and len(per_trigger) < len(triggers):
                raise ValueError(
                    f"{label} has {len(per_trigger)} entries "
                    f"but {len(triggers)} triggers were given"
                )

        results = []
        for i, trigger in enumerate(triggers):
            evidence = evidence_nodes_list[i] if evidence_nodes_list else None
            anomalies = anomalies_list[i] if anomalies_list else None
            results.append(
                self.generate(
                    trigger=trigger,
                    evidence_nodes=evidence,
                    anomalies=anomalies,
                    **kwargs,
                )
            )
        return results
|
diffusion_llm/model/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model components for AAM Diffusion LLM."""
|
| 2 |
+
|
| 3 |
+
from diffusion_llm.model.noise_scheduler import NoiseScheduler
|
| 4 |
+
from diffusion_llm.model.graph_encoder import GraphConditioningEncoder
|
| 5 |
+
from diffusion_llm.model.diffusion_transformer import DiffusionTransformer
|
| 6 |
+
from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"NoiseScheduler",
|
| 10 |
+
"GraphConditioningEncoder",
|
| 11 |
+
"DiffusionTransformer",
|
| 12 |
+
"AamDiffusionModel",
|
| 13 |
+
]
|
diffusion_llm/model/aam_diffusion_model.py
ADDED
|
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AAM Diffusion LLM — Complete Model
|
| 3 |
+
|
| 4 |
+
Combines the Diffusion Transformer, Graph Encoder, and Noise Scheduler
|
| 5 |
+
into a single, unified model for training and inference.
|
| 6 |
+
|
| 7 |
+
This is the "body" of AAM — the specialized sentence composer that
|
| 8 |
+
takes graph conditioning as input and produces coherent narratives
|
| 9 |
+
through iterative denoising.
|
| 10 |
+
|
| 11 |
+
Architecture:
|
| 12 |
+
┌──────────────────────────────────────────────────┐
|
| 13 |
+
│ AAM Diffusion Model (The Body) │
|
| 14 |
+
│ │
|
| 15 |
+
│ Input: │
|
| 16 |
+
│ - Token IDs (text) │
|
| 17 |
+
│ - Graph conditioning (evidence, compositions, │
|
| 18 |
+
│ confidence, anomalies, reasoning chains) │
|
| 19 |
+
│ │
|
| 20 |
+
│ Training Process: │
|
| 21 |
+
│ 1. Tokenize text → embeddings │
|
| 22 |
+
│ 2. Sample random timestep t │
|
| 23 |
+
│ 3. Add noise: x_t = schedule.add_noise(x_0, t) │
|
| 24 |
+
│ 4. Encode graph conditioning │
|
| 25 |
+
│ 5. Predict noise: eps = transformer(x_t, t, c) │
|
| 26 |
+
│ 6. Compute loss: L = MSE(eps, eps_target) │
|
| 27 |
+
│ │
|
| 28 |
+
│ Inference Process: │
|
| 29 |
+
│ 1. Start from pure noise x_T │
|
| 30 |
+
│ 2. Encode graph conditioning │
|
| 31 |
+
│ 3. For t = T, T-1, ..., 1: │
|
| 32 |
+
│ a. Predict noise: eps = transformer(x_t, t) │
|
| 33 |
+
│ b. Denoise: x_{t-1} = schedule.step(eps) │
|
| 34 |
+
│ 4. Decode final x_0 → text tokens │
|
| 35 |
+
│ 5. Detokenize → natural language narrative │
|
| 36 |
+
│ │
|
| 37 |
+
│ Key Constraint: │
|
| 38 |
+
│ The model CANNOT generate information not │
|
| 39 |
+
│ present in the graph conditioning. It can only │
|
| 40 |
+
│ ARRANGE what the graph knows into sentences. │
|
| 41 |
+
│ │
|
| 42 |
+
│ Analogi: Jin Soun (mind/graph) + tubuhnya │
|
| 43 |
+
│ (this model). Tubuhnya hanya bisa mengucapkan │
|
| 44 |
+
│ apa yang dipikirkannya — tidak bisa mengarang. │
|
| 45 |
+
└──────────────────────────────────────────────────┘
|
| 46 |
+
|
| 47 |
+
Analogi: Ini adalah seluruh "tubuh" Jin Soun — bukan hanya
|
| 48 |
+
ototnya (transformer), tapi juga sistem saraf (graph encoder)
|
| 49 |
+
dan kemampuan untuk memperbaiki diri (diffusion denoising).
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
from __future__ import annotations
|
| 53 |
+
|
| 54 |
+
import logging
|
| 55 |
+
from typing import Optional
|
| 56 |
+
|
| 57 |
+
import torch
|
| 58 |
+
import torch.nn as nn
|
| 59 |
+
|
| 60 |
+
from diffusion_llm.config.model_config import AamDiffusionConfig
|
| 61 |
+
from diffusion_llm.model.noise_scheduler import NoiseScheduler
|
| 62 |
+
from diffusion_llm.model.graph_encoder import GraphConditioningEncoder
|
| 63 |
+
from diffusion_llm.model.diffusion_transformer import DiffusionTransformer
|
| 64 |
+
|
| 65 |
+
logger = logging.getLogger(__name__)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class AamDiffusionModel(nn.Module):
|
| 69 |
+
"""Complete AAM Diffusion LLM model.
|
| 70 |
+
|
| 71 |
+
Combines:
|
| 72 |
+
- DiffusionTransformer: Core denoising network
|
| 73 |
+
- GraphConditioningEncoder: Encodes graph structure for conditioning
|
| 74 |
+
- NoiseScheduler: Manages the diffusion process
|
| 75 |
+
|
| 76 |
+
This model is designed to be trained on Graph→Narrative pairs,
|
| 77 |
+
where the graph data comes from the RSVS Knowledge Graph and
|
| 78 |
+
the narrative is the target natural language output.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
config: AamDiffusionConfig with all hyperparameters.
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
def __init__(self, config: AamDiffusionConfig):
    """Build the scheduler, graph encoder, transformer and tied LM head.

    Args:
        config: AamDiffusionConfig; the ``diffusion``, ``graph_encoder``,
            ``model`` and ``training`` sections and ``model_name`` are
            read here.
    """
    super().__init__()
    self.config = config

    # Core components
    self.noise_scheduler = NoiseScheduler(
        n_timesteps=config.diffusion.n_timesteps,
        schedule_type=config.diffusion.schedule_type,
        beta_start=config.diffusion.beta_start,
        beta_end=config.diffusion.beta_end,
        prediction_type=config.diffusion.prediction_type,
    )

    self.graph_encoder = GraphConditioningEncoder(
        config=config.graph_encoder,
        vocab_size=config.model.vocab_size,
    )
    # Align graph encoder output dim with transformer's d_model
    self.graph_encoder.set_output_dim(config.model.d_model)

    self.transformer = DiffusionTransformer(config.model)

    # Token-to-embedding projection (shared with transformer)
    # The transformer's token_embedding is used for both
    # encoding input text and decoding output text

    # Output head: project from d_model to vocab_size
    self.lm_head = nn.Linear(
        config.model.d_model, config.model.vocab_size, bias=False
    )

    # Tie weights between token embedding and LM head
    # This is standard practice and reduces parameter count.
    # The assignment rebinds lm_head.weight to the embedding's weight, so
    # the Linear's own freshly initialised weight is discarded.
    # (assumes token_embedding.weight is (vocab_size, d_model) — TODO confirm)
    self.lm_head.weight = self.transformer.token_embedding.weight

    # EMA model (for inference, updated during training)
    # Starts as None; nothing in this constructor creates it.
    self._ema_model: Optional[AamDiffusionModel] = None
    self._ema_decay = config.training.ema_decay

    logger.info(
        "AamDiffusionModel initialized: %s params, %s",
        self._format_params(self.get_num_params()),
        config.model_name,
    )
|
| 128 |
+
|
| 129 |
+
def forward(
    self,
    token_ids: torch.Tensor,
    timestep: torch.Tensor,
    evidence_ids: Optional[torch.Tensor] = None,
    evidence_confidence: Optional[torch.Tensor] = None,
    evidence_timestamps: Optional[torch.Tensor] = None,
    composition_ids: Optional[torch.Tensor] = None,
    composition_confidence: Optional[torch.Tensor] = None,
    anomaly_ids: Optional[torch.Tensor] = None,
    anomaly_confidence: Optional[torch.Tensor] = None,
    anomaly_timestamps: Optional[torch.Tensor] = None,
    reasoning_ids: Optional[torch.Tensor] = None,
    reasoning_confidence: Optional[torch.Tensor] = None,
    source_trust: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Forward pass for training.

    1. Get clean embeddings from token IDs
    2. Sample Gaussian noise and add it at the given timestep
    3. Encode graph conditioning
    4. Run the transformer on the noised embeddings
    5. Return the transformer output together with the sampled noise
       (loss computed externally)

    Args:
        token_ids: Clean text token IDs, shape (batch, seq_len).
        timestep: Random timestep indices, shape (batch,).
        evidence_ids: Evidence node token IDs.
        evidence_confidence: Evidence confidence scores.
        evidence_timestamps: Evidence timestamps.
        composition_ids: Composition token IDs.
        composition_confidence: Composition confidence.
        anomaly_ids: Anomaly token IDs.
        anomaly_confidence: Anomaly confidence.
        anomaly_timestamps: Anomaly timestamps.
        reasoning_ids: Reasoning step token IDs.
        reasoning_confidence: Reasoning confidence.
        source_trust: Source trust score.

    Returns:
        A ``(predicted, noise)`` tuple: the transformer output and the
        noise sample that was added to the clean embeddings. ``noise``
        has the same shape as the token embeddings.
    """
    # Step 1: Get clean embeddings (x_0)
    x_0 = self.transformer.token_embedding(token_ids)

    # Step 2: Add noise
    noise = torch.randn_like(x_0)
    x_t = self.noise_scheduler.add_noise(x_0, noise, timestep)

    # Step 3: Encode graph conditioning
    # batch_size is passed explicitly alongside the (all optional) inputs.
    batch_size = token_ids.shape[0]
    graph_cond = self.graph_encoder(
        evidence_ids=evidence_ids,
        evidence_confidence=evidence_confidence,
        evidence_timestamps=evidence_timestamps,
        composition_ids=composition_ids,
        composition_confidence=composition_confidence,
        anomaly_ids=anomaly_ids,
        anomaly_confidence=anomaly_confidence,
        anomaly_timestamps=anomaly_timestamps,
        reasoning_ids=reasoning_ids,
        reasoning_confidence=reasoning_confidence,
        source_trust=source_trust,
        batch_size=batch_size,
    )

    # Extract cross-attention keys/values from graph conditioning
    # (.get returns None when the encoder output lacks the entry)
    graph_keys = graph_cond.get("keys")
    graph_values = graph_cond.get("values")

    # Step 4: Predict noise via transformer
    predicted = self.transformer(
        x_t=x_t,
        t=timestep,
        graph_keys=graph_keys,
        graph_values=graph_values,
    )

    # NOTE(review): the sampled noise is returned as the companion target;
    # presumably that is only the right target for noise prediction —
    # confirm how callers handle other prediction_type settings.
    return predicted, noise
|
| 208 |
+
|
| 209 |
+
def compute_loss(
    self,
    predicted: torch.Tensor,
    target: torch.Tensor,
    timestep: torch.Tensor,
) -> torch.Tensor:
    """Compute the scalar diffusion training loss.

    The element-wise loss is chosen by ``config.diffusion.loss_type``
    ("mse", "mae" or "huber"), averaged over the last dimension, optionally
    re-weighted per timestep according to ``config.diffusion.loss_weighting``
    ("min_snr" or "p2"; any other value applies no weighting), and finally
    averaged over all remaining dimensions.

    Args:
        predicted: Model output (predicted noise/x0/v).
        target: Target (actual noise/x0/v).
        timestep: Timestep indices for loss weighting.

    Returns:
        Scalar loss value.

    Raises:
        ValueError: If ``loss_type`` is not one of the supported names.
    """
    elementwise_losses = {
        "mse": nn.functional.mse_loss,
        "mae": nn.functional.l1_loss,
        "huber": nn.functional.smooth_l1_loss,
    }
    loss_fn = elementwise_losses.get(self.config.diffusion.loss_type)
    if loss_fn is None:
        raise ValueError(f"Unknown loss_type: {self.config.diffusion.loss_type}")

    # Element-wise loss, then mean over the feature dimension.
    per_position = loss_fn(predicted, target, reduction="none").mean(dim=-1)

    # Optional per-timestep weighting.
    weighting = self.config.diffusion.loss_weighting
    if weighting == "min_snr":
        per_position = self._apply_min_snr_weighting(per_position, timestep)
    elif weighting == "p2":
        per_position = self._apply_p2_weighting(per_position, timestep)

    # Mean over everything that is left.
    return per_position.mean()
|
| 248 |
+
|
| 249 |
+
def _apply_min_snr_weighting(
|
| 250 |
+
self,
|
| 251 |
+
loss: torch.Tensor,
|
| 252 |
+
timestep: torch.Tensor,
|
| 253 |
+
gamma: float = 5.0,
|
| 254 |
+
) -> torch.Tensor:
|
| 255 |
+
"""Apply Min-SNR weighting strategy.
|
| 256 |
+
|
| 257 |
+
Weights the loss by min(SNR, gamma) / SNR, where
|
| 258 |
+
SNR = alpha_bar / (1 - alpha_bar).
|
| 259 |
+
|
| 260 |
+
This helps balance the loss across timesteps, preventing
|
| 261 |
+
high-noise steps from dominating.
|
| 262 |
+
|
| 263 |
+
Args:
|
| 264 |
+
loss: Unweighted loss.
|
| 265 |
+
timestep: Timestep indices.
|
| 266 |
+
gamma: SNR clipping value.
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
Weighted loss.
|
| 270 |
+
"""
|
| 271 |
+
alpha_bar = self.noise_scheduler.alphas_cumprod.to(loss.device)
|
| 272 |
+
snr = alpha_bar[timestep] / (1 - alpha_bar[timestep] + 1e-8)
|
| 273 |
+
weight = torch.clamp(snr, max=gamma) / (snr + 1e-8)
|
| 274 |
+
# Expand weight to match loss shape
|
| 275 |
+
weight = weight.unsqueeze(-1).expand_as(loss)
|
| 276 |
+
return loss * weight
|
| 277 |
+
|
| 278 |
+
def _apply_p2_weighting(
|
| 279 |
+
self,
|
| 280 |
+
loss: torch.Tensor,
|
| 281 |
+
timestep: torch.Tensor,
|
| 282 |
+
) -> torch.Tensor:
|
| 283 |
+
"""Apply P2 weighting strategy.
|
| 284 |
+
|
| 285 |
+
weight = 1 / (SNR^gamma + k)
|
| 286 |
+
|
| 287 |
+
Args:
|
| 288 |
+
loss: Unweighted loss.
|
| 289 |
+
timestep: Timestep indices.
|
| 290 |
+
|
| 291 |
+
Returns:
|
| 292 |
+
Weighted loss.
|
| 293 |
+
"""
|
| 294 |
+
alpha_bar = self.noise_scheduler.alphas_cumprod.to(loss.device)
|
| 295 |
+
snr = alpha_bar[timestep] / (1 - alpha_bar[timestep] + 1e-8)
|
| 296 |
+
gamma = self.config.diffusion.p2_gamma
|
| 297 |
+
k = self.config.diffusion.p2_k
|
| 298 |
+
weight = 1.0 / (snr ** gamma + k)
|
| 299 |
+
weight = weight.unsqueeze(-1).expand_as(loss)
|
| 300 |
+
return loss * weight
|
| 301 |
+
|
| 302 |
+
@torch.no_grad()
def sample(
    self,
    graph_cond: dict[str, torch.Tensor],
    n_steps: Optional[int] = None,
    method: str = "ddim",
    shape: Optional[tuple[int, ...]] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Generate samples via iterative denoising.

    This is the inference entry point: start from Gaussian noise and
    repeatedly ask the transformer for a prediction, letting the noise
    scheduler turn each prediction into the next, less noisy sample.

    Args:
        graph_cond: Graph conditioning dict; its ``"keys"`` and
            ``"values"`` entries (``None`` when absent) are forwarded to
            the transformer.
        n_steps: Number of DDIM steps. Defaults to
            ``config.diffusion.n_inference_steps``. Ignored by ``'ddpm'``,
            which always walks all ``config.diffusion.n_timesteps``.
        method: Sampling method, ``'ddpm'`` or ``'ddim'``.
        shape: Shape of the output ``(batch, seq_len, d_model)``. Defaults
            to ``(1, max_seq_len, d_model)`` from the model config.
        device: Device to generate on. Defaults to the device of the
            model's first parameter.

    Returns:
        Denoised embeddings of the requested shape.

    Raises:
        ValueError: If ``method`` is neither ``'ddpm'`` nor ``'ddim'``.
    """
    # Fail loudly instead of silently handing back the initial noise.
    if method not in ("ddpm", "ddim"):
        raise ValueError(
            f"Unknown sampling method {method!r}; expected 'ddpm' or 'ddim'"
        )

    if n_steps is None:
        n_steps = self.config.diffusion.n_inference_steps
    if device is None:
        device = next(self.parameters()).device
    if shape is None:
        shape = (1, self.config.model.max_seq_len, self.config.model.d_model)

    # Start from pure noise
    x = torch.randn(shape, device=device)

    graph_keys = graph_cond.get("keys")
    graph_values = graph_cond.get("values")

    if method == "ddpm":
        # Full DDPM sampling: visit every training timestep, last to first.
        for t in reversed(range(self.config.diffusion.n_timesteps)):
            t_tensor = torch.full((shape[0],), t, device=device, dtype=torch.long)
            predicted = self.transformer(
                x_t=x, t=t_tensor,
                graph_keys=graph_keys,
                graph_values=graph_values,
            )
            x = self.noise_scheduler.step_ddpm(predicted, x, t_tensor)
    else:
        # Fast DDIM sampling over consecutive (t, t_prev) schedule pairs.
        # NOTE(review): the last schedule entry is only ever used as a
        # `t_prev` target, never as a starting `t`. This is correct only if
        # get_timestep_schedule() ends at the final target timestep —
        # confirm against the scheduler.
        timesteps = self.noise_scheduler.get_timestep_schedule(n_steps)
        for i in range(len(timesteps) - 1):
            t = timesteps[i]
            t_prev = timesteps[i + 1]

            t_tensor = torch.full((shape[0],), t, device=device, dtype=torch.long)
            predicted = self.transformer(
                x_t=x, t=t_tensor,
                graph_keys=graph_keys,
                graph_values=graph_values,
            )
            x = self.noise_scheduler.step_ddim(
                predicted, x, t, t_prev,
                eta=self.config.diffusion.eta_ddim,
            )

    return x
|
| 370 |
+
|
| 371 |
+
def embeddings_to_tokens(
    self,
    embeddings: torch.Tensor,
    temperature: float = 1.0,
    top_k: int = 50,
) -> torch.Tensor:
    """Convert continuous embeddings to discrete token IDs.

    This is the final step of generation: project the embeddings to
    vocabulary logits with ``self.lm_head`` and pick one token per position.

    Args:
        embeddings: Denoised embeddings whose last dimension is the model
            width, typically ``(batch, seq_len, d_model)``.
        temperature: Sampling temperature. ``0`` selects greedy decoding.
        top_k: Number of highest-scoring tokens to sample from. Values
            above the vocabulary size are clamped to it; ``top_k <= 0``
            selects greedy (argmax) decoding.

    Returns:
        Token IDs with the shape of ``embeddings`` minus its last
        dimension, typically ``(batch, seq_len)``.

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")

    logits = self.lm_head(embeddings)

    # Greedy decoding. Dividing by a positive temperature never changes the
    # argmax, and skipping the division avoids inf/nan at temperature == 0.
    if top_k <= 0 or temperature == 0:
        return torch.argmax(logits, dim=-1)

    logits = logits / temperature

    # torch.topk raises when k exceeds the size of the selected dimension.
    k = min(top_k, logits.shape[-1])
    top_k_values, top_k_indices = torch.topk(logits, k, dim=-1)
    probs = torch.softmax(top_k_values, dim=-1)

    # multinomial only accepts 1-D/2-D input: flatten the leading dims,
    # draw one candidate per position, then restore the leading shape.
    sampled_indices = torch.multinomial(probs.reshape(-1, k), 1)
    sampled_indices = sampled_indices.reshape(logits.shape[:-1])

    # Map positions inside the top-k slice back to vocabulary IDs.
    return top_k_indices.gather(-1, sampled_indices.unsqueeze(-1)).squeeze(-1)
|
| 407 |
+
|
| 408 |
+
def get_num_params(self) -> int:
    """Return the total element count across all of the model's parameters."""
    total = 0
    for param in self.parameters():
        total += param.numel()
    return total
|
| 411 |
+
|
| 412 |
+
@staticmethod
def _format_params(n: int) -> str:
    """Render a parameter count compactly, e.g. ``1.5B``, ``2.5M``, ``1.5K``.

    Counts below one thousand are returned as their plain string form.
    """
    # Largest unit first so the first matching threshold wins.
    for threshold, suffix in ((1e9, "B"), (1e6, "M"), (1e3, "K")):
        if n >= threshold:
            return f"{n / threshold:.1f}{suffix}"
    return str(n)
|
| 422 |
+
|
| 423 |
+
def save(self, path: str) -> None:
    """Write a checkpoint containing the model weights and its config.

    The checkpoint is a dict with the keys ``"model_state_dict"`` and
    ``"config"`` (the config serialised through ``to_dict()``).

    Args:
        path: Output file path.
    """
    checkpoint = {
        "model_state_dict": self.state_dict(),
        "config": self.config.to_dict(),
    }
    torch.save(checkpoint, path)
    logger.info("Model saved to %s", path)
|
| 434 |
+
|
| 435 |
+
@classmethod
def load(cls, path: str, device: str = "cpu") -> AamDiffusionModel:
    """Load a model from a checkpoint file.

    The checkpoint is expected to hold ``"model_state_dict"`` and,
    optionally, ``"config"``. A dict config is rebuilt section by section
    into an ``AamDiffusionConfig``; if that fails, the default config is
    used and a warning (with traceback) is logged. A non-dict config is
    used as-is.

    Args:
        path: Checkpoint file path.
        device: Device to load to.

    Returns:
        Loaded AamDiffusionModel.

    Raises:
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so
    # only load checkpoints from trusted sources. It is kept because the
    # "config" entry may be a pickled object rather than a plain dict.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    config_dict = checkpoint.get("config", {})
    if isinstance(config_dict, dict):
        # Fallback used when the stored dict cannot be turned into a config.
        config = AamDiffusionConfig()
        try:
            from diffusion_llm.config.model_config import (
                ModelConfig, DiffusionConfig, GraphEncoderConfig,
                TokenizerConfig, TrainingConfig, InferenceConfig,
            )
            config = AamDiffusionConfig(
                model=ModelConfig(**config_dict.get("model", {})),
                diffusion=DiffusionConfig(**config_dict.get("diffusion", {})),
                graph_encoder=GraphEncoderConfig(**config_dict.get("graph_encoder", {})),
                tokenizer=TokenizerConfig(**config_dict.get("tokenizer", {})),
                training=TrainingConfig(**config_dict.get("training", {})),
                inference=InferenceConfig(**config_dict.get("inference", {})),
                model_name=config_dict.get("model_name", "aam-diffusion-v0.1"),
                output_dir=config_dict.get("output_dir", "./output"),
                seed=config_dict.get("seed", 42),
            )
        except Exception:
            # Deliberately best-effort, but keep the traceback: a silent
            # fallback to defaults usually resurfaces later as a confusing
            # shape mismatch in load_state_dict().
            logger.warning(
                "Could not reconstruct config from checkpoint, using defaults",
                exc_info=True,
            )
    else:
        config = config_dict
    model = cls(config)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    logger.info("Model loaded from %s", path)
    return model
|
diffusion_llm/model/diffusion_transformer.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AAM Diffusion LLM — Diffusion Transformer (Denoiser)
|
| 3 |
+
|
| 4 |
+
The core denoising network. Takes noisy text embeddings and graph
|
| 5 |
+
conditioning, and predicts the noise (or clean data) at each
|
| 6 |
+
diffusion timestep.
|
| 7 |
+
|
| 8 |
+
Architecture:
|
| 9 |
+
Input: Noisy embeddings x_t + timestep t + graph conditioning
|
| 10 |
+
Output: Predicted noise epsilon (or x_0 or v)
|
| 11 |
+
|
| 12 |
+
The transformer uses:
|
| 13 |
+
- Self-attention over the text sequence
|
| 14 |
+
- Cross-attention to graph conditioning (evidence, anomalies, etc.)
|
| 15 |
+
- Timestep embedding (sinusoidal) injected via adaptive layer norm
|
| 16 |
+
- Optional flash attention for efficiency (the `use_flash_attention` flag is accepted by the blocks but not yet used by them)
|
| 17 |
+
|
| 18 |
+
This is the "brainstem" of the body — the core computation that
|
| 19 |
+
transforms noisy signals into coherent patterns.
|
| 20 |
+
|
| 21 |
+
Analogi: Seperti otot Jin Soun yang merespons sinyal dari otak —
|
| 22 |
+
model ini menerima "sinyal noise" dan "instruksi dari graph",
|
| 23 |
+
lalu mengubahnya menjadi gerakan yang koheren (kalimat).
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import math
|
| 29 |
+
from typing import Optional
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn as nn
|
| 33 |
+
import torch.nn.functional as F
|
| 34 |
+
|
| 35 |
+
from diffusion_llm.config.model_config import ModelConfig
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class SinusoidalTimestepEmbedding(nn.Module):
    """Sinusoidal embedding for diffusion timesteps.

    Maps integer timesteps to d_model-dimensional vectors using
    sinusoidal position encoding, similar to Transformers, followed by
    a two-layer MLP.

    This allows the model to know "how noisy" the current input is,
    which is essential for the denoising process.

    Args:
        d_model: Width of the produced embedding.
        max_period: Controls the lowest frequency of the sinusoids.
    """

    def __init__(self, d_model: int, max_period: int = 10000):
        super().__init__()
        self.d_model = d_model
        self.max_period = max_period

        # Two-layer MLP to project sinusoidal features
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed timesteps.

        Args:
            t: Timestep indices of shape (batch,).

        Returns:
            Timestep embeddings of shape (batch, d_model).
        """
        half_dim = self.d_model // 2

        # Geometric frequency ladder from 1 down to 1 / max_period.
        # half_dim - 1 is zero for d_model in {2, 3}; clamp the denominator
        # so tiny models do not raise ZeroDivisionError.
        log_step = math.log(self.max_period) / max(half_dim - 1, 1)
        freqs = torch.exp(
            torch.arange(half_dim, device=t.device, dtype=torch.float32) * -log_step
        )

        angles = t.float().unsqueeze(-1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

        # An odd d_model leaves one channel unfilled; zero-pad it.
        if emb.shape[-1] < self.d_model:
            emb = F.pad(emb, (0, self.d_model - emb.shape[-1]))

        return self.mlp(emb)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class AdaptiveLayerNorm(nn.Module):
    """Adaptive Layer Normalization conditioned on timestep.

    Instead of fixed scale/shift parameters, this layer norm
    uses the timestep embedding to produce scale and shift:

        y = (1 + scale(t)) * norm(x) + shift(t)

    This allows the model to behave differently at different
    noise levels — more "creative" at high noise, more
    "precise" at low noise.

    Both projections start at zero, so a freshly constructed layer is a
    plain, non-affine layer norm (multiplier 1, shift 0).

    Analogy: Jin Soun adjusts the intensity of his thinking based on how
    blurry the situation is — the blurrier it is, the more "creative"
    the approach.

    Args:
        d_model: Feature width of both the input and the timestep embedding.
        eps: Epsilon passed to the underlying layer norm.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, elementwise_affine=False, eps=eps)
        self.scale_proj = nn.Linear(d_model, d_model)
        self.shift_proj = nn.Linear(d_model, d_model)

        # Start as an identity modulation: shift(t) == 0 and scale(t) == 0,
        # so the effective multiplier (1 + scale(t)) is exactly one.
        # A ones-initialised scale weight would instead yield
        # 1 + sum(timestep_emb) on every channel.
        nn.init.zeros_(self.shift_proj.weight)
        nn.init.zeros_(self.shift_proj.bias)
        nn.init.zeros_(self.scale_proj.weight)
        nn.init.zeros_(self.scale_proj.bias)

    def forward(
        self,
        x: torch.Tensor,
        timestep_emb: torch.Tensor,
    ) -> torch.Tensor:
        """Apply adaptive layer norm.

        Args:
            x: Input tensor of shape (batch, seq_len, d_model).
            timestep_emb: Timestep embedding of shape (batch, d_model).

        Returns:
            Normalized and modulated tensor with the shape of ``x``.
        """
        normalized = self.norm(x)
        # unsqueeze(1) broadcasts the per-sample modulation over seq_len.
        scale = (1 + self.scale_proj(timestep_emb)).unsqueeze(1)
        shift = self.shift_proj(timestep_emb).unsqueeze(1)
        return normalized * scale + shift
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class TransformerBlock(nn.Module):
    """Single transformer block with self-attention, cross-attention, and FFN.

    The block structure:
        1. Adaptive Layer Norm + Self-Attention
        2. Adaptive Layer Norm + Cross-Attention (to graph conditioning)
        3. Adaptive Layer Norm + Feed-Forward Network

    Each sub-layer has a residual connection whose branch is multiplied
    by a learnable scalar initialised to 0.1.

    Args:
        d_model: Feature width of the sequence.
        n_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward network.
        dropout: Dropout probability used in attention and the FFN.
        norm_eps: Epsilon for the adaptive layer norms.
        norm_type: Accepted for configuration compatibility; every norm in
            this block is an AdaptiveLayerNorm regardless of this value.
        use_flash_attention: Accepted for configuration compatibility;
            currently not used by this block.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        norm_eps: float = 1e-6,
        norm_type: str = "rmsnorm",
        use_flash_attention: bool = True,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads

        # No norm class is selected from `norm_type` here: the value was
        # never used, and merely evaluating `nn.RMSNorm` raises
        # AttributeError on PyTorch builds that do not ship it.

        # Self-attention
        self.self_attn_norm = AdaptiveLayerNorm(d_model, eps=norm_eps)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.self_attn_dropout = nn.Dropout(dropout)

        # Cross-attention (to graph conditioning)
        self.cross_attn_norm = AdaptiveLayerNorm(d_model, eps=norm_eps)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
            kdim=d_model,
            vdim=d_model,
        )
        self.cross_attn_dropout = nn.Dropout(dropout)

        # Feed-forward
        self.ff_norm = AdaptiveLayerNorm(d_model, eps=norm_eps)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

        # Layer scales: learnable scalar gates on each residual branch
        # (helps with deep networks).
        self.self_attn_scale = nn.Parameter(torch.ones(1) * 0.1)
        self.cross_attn_scale = nn.Parameter(torch.ones(1) * 0.1)
        self.ff_scale = nn.Parameter(torch.ones(1) * 0.1)

    def forward(
        self,
        x: torch.Tensor,
        timestep_emb: torch.Tensor,
        graph_keys: Optional[torch.Tensor] = None,
        graph_values: Optional[torch.Tensor] = None,
        causal_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Input sequence of shape (batch, seq_len, d_model).
            timestep_emb: Timestep embedding of shape (batch, d_model).
            graph_keys: Graph conditioning keys for cross-attention,
                shape (batch, n_graph_nodes, d_model).
            graph_values: Graph conditioning values for cross-attention,
                shape (batch, n_graph_nodes, d_model).
            causal_mask: Optional attention mask for self-attention.

        Returns:
            Output sequence of shape (batch, seq_len, d_model).
        """
        # 1. Self-attention with adaptive layer norm
        normed = self.self_attn_norm(x, timestep_emb)
        attn_out, _ = self.self_attn(
            normed, normed, normed,
            attn_mask=causal_mask,
            need_weights=False,
        )
        x = x + self.self_attn_scale * self.self_attn_dropout(attn_out)

        # 2. Cross-attention to graph conditioning. Skipped entirely unless
        #    both keys and values are supplied.
        if graph_keys is not None and graph_values is not None:
            normed = self.cross_attn_norm(x, timestep_emb)
            cross_out, _ = self.cross_attn(
                normed, graph_keys, graph_values,
                need_weights=False,
            )
            x = x + self.cross_attn_scale * self.cross_attn_dropout(cross_out)

        # 3. Feed-forward with adaptive layer norm (dropout lives inside
        #    self.ff, so no extra dropout is applied here).
        normed = self.ff_norm(x, timestep_emb)
        ff_out = self.ff(normed)
        x = x + self.ff_scale * ff_out

        return x
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class DiffusionTransformer(nn.Module):
    """Diffusion Transformer — the core denoising network for AAM.

    This transformer takes:
        - Noisy text embeddings (x_t)
        - Diffusion timestep (t)
        - Graph conditioning (evidence, anomalies, reasoning chains)

    And predicts the noise that was added (or the clean data,
    depending on prediction_type).

    Architecture Overview:
        Input: x_t (noisy embeddings), or token IDs embedded on the fly
          + learned positional embedding (only when
            pos_encoding_type == "learned")

        N x TransformerBlock:
          - AdaLN + Self-Attention
          - AdaLN + Cross-Attention (to graph)
          - AdaLN + Feed-Forward

        Final norm -> output projection (d_model -> d_model)

    Key Features:
        - Adaptive Layer Norm: timestep-conditioned normalization
        - Cross-Attention: graph conditioning guides generation
        - Layer Scales: helps training deep networks

    NOTE(review): for any pos_encoding_type other than "learned" this
    class adds no positional signal itself, and the blocks it builds use
    nn.MultiheadAttention without rotary embeddings — RoPE does not appear
    to be implemented in this module. Confirm before relying on it.

    Args:
        config: ModelConfig with architecture hyperparameters.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Token embedding (token IDs -> d_model); only used by forward()
        # when x_t is not supplied.
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)

        # Timestep embedding
        self.timestep_embedding = SinusoidalTimestepEmbedding(config.d_model)

        # Positional encoding
        if config.pos_encoding_type == "learned":
            self.position_embedding = nn.Embedding(
                config.max_seq_len, config.d_model
            )
        else:
            self.position_embedding = None

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                d_model=config.d_model,
                n_heads=config.n_heads,
                d_ff=config.d_ff,
                dropout=config.dropout,
                norm_eps=config.norm_eps,
                norm_type=config.norm_type,
                use_flash_attention=config.use_flash_attention,
            )
            for _ in range(config.n_layers)
        ])

        # Final norm
        NormClass = nn.RMSNorm if config.norm_type == "rmsnorm" else nn.LayerNorm
        self.final_norm = NormClass(config.d_model, eps=config.norm_eps)

        # Output projection (predict noise/x0/v)
        self.output_proj = nn.Linear(config.d_model, config.d_model)

        # Initialize weights. This runs over every submodule after
        # construction, so it also overwrites any initialisation that
        # submodules applied to their own nn.Linear layers.
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize Linear/Embedding weights from N(0, init_std); zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)

    def forward(
        self,
        x_t: Optional[torch.Tensor],
        t: torch.Tensor,
        token_ids: Optional[torch.Tensor] = None,
        graph_keys: Optional[torch.Tensor] = None,
        graph_values: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass: predict noise given noisy input and timestep.

        Args:
            x_t: Noisy text embeddings of shape (batch, seq_len, d_model).
                If None, token_ids must be provided. Takes precedence over
                token_ids when both are given.
            t: Timestep indices of shape (batch,).
            token_ids: Token IDs of shape (batch, seq_len).
                Used to create embeddings if x_t is not provided directly.
            graph_keys: Graph conditioning keys for cross-attention,
                shape (batch, n_graph_nodes, d_model).
            graph_values: Graph conditioning values for cross-attention,
                shape (batch, n_graph_nodes, d_model).

        Returns:
            Prediction of shape (batch, seq_len, d_model).

        Raises:
            ValueError: If neither x_t nor token_ids is provided, or if the
                sequence is longer than the learned position table.
        """
        # Get input embeddings
        if x_t is not None:
            h = x_t
        elif token_ids is not None:
            h = self.token_embedding(token_ids)
        else:
            raise ValueError("Either x_t or token_ids must be provided")

        # Add learned positional encoding (if configured)
        if self.position_embedding is not None:
            seq_len = h.shape[1]
            max_len = self.position_embedding.num_embeddings
            # Give a clear error instead of an opaque embedding IndexError.
            if seq_len > max_len:
                raise ValueError(
                    f"Sequence length {seq_len} exceeds max_seq_len {max_len} "
                    "of the learned position embedding"
                )
            positions = torch.arange(seq_len, device=h.device).unsqueeze(0)
            h = h + self.position_embedding(positions)

        # Embed timestep
        t_emb = self.timestep_embedding(t)

        # Pass through transformer blocks
        for block in self.blocks:
            h = block(
                h,
                timestep_emb=t_emb,
                graph_keys=graph_keys,
                graph_values=graph_values,
            )

        # Final norm and projection
        h = self.final_norm(h)
        output = self.output_proj(h)

        return output

    def get_num_params(self) -> int:
        """Get total number of parameters."""
        return sum(p.numel() for p in self.parameters())

    def get_num_trainable_params(self) -> int:
        """Get number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
diffusion_llm/model/graph_encoder.py
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AAM Diffusion LLM — Graph Conditioning Encoder
|
| 3 |
+
|
| 4 |
+
Encodes structured graph data into a conditioning vector that guides
|
| 5 |
+
the diffusion process. This is the KEY differentiator from general LLMs:
|
| 6 |
+
the model is conditioned on GRAPH STRUCTURE, not just text prompts.
|
| 7 |
+
|
| 8 |
+
The graph encoder takes:
|
| 9 |
+
- Evidence nodes (what the graph knows)
|
| 10 |
+
- Compositions (how concepts compose)
|
| 11 |
+
- Confidence scores (how sure the graph is)
|
| 12 |
+
- Anomalies (what doesn't fit)
|
| 13 |
+
- Reasoning chains (how the graph reached conclusions)
|
| 14 |
+
- Temporal context (when events happened)
|
| 15 |
+
|
| 16 |
+
And produces a conditioning representation that the diffusion model
|
| 17 |
+
uses to guide denoising.
|
| 18 |
+
|
| 19 |
+
Analogi: Seperti otak Jin Soun mengirimkan sinyal ke pita suaranya —
|
| 20 |
+
graph memberi "tahu" apa yang harus dikatakan, dan encoder ini
|
| 21 |
+
menerjemahkan "pengetahuan graph" menjadi "instruksi untuk tubuh".
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import math
|
| 27 |
+
from typing import Optional
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
import torch.nn as nn
|
| 31 |
+
import torch.nn.functional as F
|
| 32 |
+
|
| 33 |
+
from diffusion_llm.config.model_config import GraphEncoderConfig
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ConfidenceEmbedding(nn.Module):
    """Embed confidence scores as continuous values.

    Maps scalar confidence scores (expected in [0, 1]) to
    d_graph-dimensional vectors with a small learnable MLP
    (Linear -> GELU -> Linear).

    Analogy: Jin Soun knows the difference between "I am 100% sure" and
    "maybe 60%" — this encoding teaches the model to tell them apart too.

    Args:
        d_graph: Width of the produced embedding.
    """

    def __init__(self, d_graph: int):
        super().__init__()
        self.d_graph = d_graph
        # Keep at least one hidden unit; d_graph // 4 is 0 for d_graph < 4,
        # which would create a degenerate zero-width layer.
        hidden = max(d_graph // 4, 1)
        # Learnable projection from scalar to d_graph
        self.projection = nn.Sequential(
            nn.Linear(1, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_graph),
        )

    def forward(self, confidence: torch.Tensor) -> torch.Tensor:
        """Embed confidence scores.

        Args:
            confidence: Confidence values. Accepted layouts:
                a 0-D scalar, a 1-D tensor ``(n,)``, a tensor with a
                trailing singleton dimension ``(..., 1)``, or any other
                tensor ``(...)``, which is treated as one score per element.

        Returns:
            Tensor of shape ``(..., d_graph)``. A 0-D input yields
            ``(1, d_graph)``, and a tensor whose last dimension is already
            1 keeps that dimension as the feature axis.
        """
        if confidence.dim() == 0:
            confidence = confidence.unsqueeze(0)
        # The MLP needs a trailing feature axis of size 1. Add it for 1-D
        # input and for any input whose last axis is not already size 1.
        if confidence.dim() == 1 or confidence.shape[-1] != 1:
            confidence = confidence.unsqueeze(-1)
        return self.projection(confidence)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class TemporalEmbedding(nn.Module):
    """Embed temporal context as position-aware vectors.

    Uses sinusoidal positional encoding adapted for timestamps,
    followed by a two-layer MLP, allowing the model to understand
    time-based relationships.

    Analogy: Jin Soun remembers that "event A happened 3 days before
    event B" — the temporal embedding teaches the model to understand
    time relationships between events.

    Args:
        d_graph: Width of the produced embedding.
        max_period: Controls the lowest frequency of the sinusoids.
    """

    def __init__(self, d_graph: int, max_period: int = 10000):
        super().__init__()
        self.d_graph = d_graph
        self.max_period = max_period
        self.projection = nn.Sequential(
            nn.Linear(d_graph, d_graph),
            nn.GELU(),
            nn.Linear(d_graph, d_graph),
        )

    def forward(self, timestamps: torch.Tensor) -> torch.Tensor:
        """Embed timestamps.

        Args:
            timestamps: Tensor of shape (batch, n_events) with normalized
                timestamps (0 = earliest, 1 = latest). Other ranks are
                accepted too; a feature axis is appended to whatever
                shape is given.

        Returns:
            Tensor of shape (batch, n_events, d_graph), i.e. the input
            shape with a trailing ``d_graph`` axis.
        """
        # Sinusoidal encoding
        half_dim = self.d_graph // 2
        # half_dim - 1 is zero for d_graph in {2, 3}; clamp the denominator
        # so tiny widths do not raise ZeroDivisionError.
        log_step = math.log(self.max_period) / max(half_dim - 1, 1)
        freqs = torch.exp(
            torch.arange(half_dim, device=timestamps.device, dtype=torch.float32)
            * -log_step
        )
        # (..., 1) * (half_dim,) broadcasts to (..., half_dim)
        angles = timestamps.float().unsqueeze(-1) * freqs
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

        if emb.shape[-1] < self.d_graph:
            # Pad if d_graph is odd
            emb = F.pad(emb, (0, self.d_graph - emb.shape[-1]))

        return self.projection(emb)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class NodeEncoder(nn.Module):
    """Turn a tokenized graph node into one d_graph-dimensional vector.

    A node's representation is assembled from up to three parts:
      - the mean of its token embeddings,
      - an embedding of its confidence score (if enabled),
      - an embedding of its timestamp (if enabled).

    The enabled parts are concatenated along the feature axis and fused by
    a Linear -> GELU -> LayerNorm stack.

    Args:
        d_graph: Width of every partial embedding and of the output.
        vocab_size: Number of rows in the token embedding table.
        embed_confidence: Reserve a slot for the confidence embedding.
        embed_temporal: Reserve a slot for the temporal embedding.
    """

    def __init__(
        self,
        d_graph: int,
        vocab_size: int = 32000,
        embed_confidence: bool = True,
        embed_temporal: bool = True,
    ):
        super().__init__()
        self.d_graph = d_graph

        # Token embedding table (intended to be shared with the main model).
        self.text_embed = nn.Embedding(vocab_size, d_graph)

        # Optional scalar-feature embedders.
        self.use_confidence = embed_confidence
        if self.use_confidence:
            self.conf_embed = ConfidenceEmbedding(d_graph)

        self.use_temporal = embed_temporal
        if self.use_temporal:
            self.temporal_embed = TemporalEmbedding(d_graph)

        # The fusion layer is sized for every enabled slot. A slot whose
        # data is missing at call time is filled with zeros so the input
        # width of the layer never changes.
        self._n_max_inputs = 1 + int(embed_confidence) + int(embed_temporal)
        fused_width = d_graph * self._n_max_inputs
        self.fusion = nn.Sequential(
            nn.Linear(fused_width, d_graph),
            nn.GELU(),
            nn.LayerNorm(d_graph),
        )

    def forward(
        self,
        token_ids: torch.Tensor,
        confidence: Optional[torch.Tensor] = None,
        timestamps: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode a batch of nodes.

        Args:
            token_ids: Token IDs of shape (batch, n_nodes, seq_len).
            confidence: Confidence scores of shape (batch, n_nodes).
            timestamps: Timestamps of shape (batch, n_nodes).

        Returns:
            Encoded nodes of shape (batch, n_nodes, d_graph).
        """
        # Average the token embeddings over the sequence axis.
        pooled_text = self.text_embed(token_ids).mean(dim=-2)
        parts = [pooled_text]

        if self.use_confidence:
            parts.append(
                torch.zeros_like(pooled_text)
                if confidence is None
                else self.conf_embed(confidence.unsqueeze(-1))
            )

        if self.use_temporal:
            parts.append(
                torch.zeros_like(pooled_text)
                if timestamps is None
                else self.temporal_embed(timestamps)
            )

        return self.fusion(torch.cat(parts, dim=-1))
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class GraphAttentionLayer(nn.Module):
    """Self-attention block over the sequence of graph nodes.

    The layer is a post-norm transformer block: multi-head self-attention
    with a residual connection and LayerNorm, followed by a 4x-wide GELU
    feed-forward network with its own residual connection and LayerNorm.

    Edges are not modeled explicitly. Every node may attend to every other
    node, and structural information is expected to live in the node
    features. A later version could add explicit edge structure via graph
    attention networks (GAT).

    Args:
        d_graph: Node feature width.
        n_heads: Number of attention heads.
        dropout: Dropout rate used in attention and in the feed-forward net.
    """

    def __init__(self, d_graph: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            d_graph, n_heads, dropout=dropout, batch_first=True
        )
        self.norm = nn.LayerNorm(d_graph)

        hidden_width = d_graph * 4
        self.ff = nn.Sequential(
            nn.Linear(d_graph, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, d_graph),
            nn.Dropout(dropout),
        )
        self.norm_ff = nn.LayerNorm(d_graph)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply attention and feed-forward sub-layers.

        Args:
            x: Node features of shape (batch, n_nodes, d_graph).
            mask: Optional mask forwarded as ``attn_mask`` to
                ``nn.MultiheadAttention``.

        Returns:
            Updated node features with the same shape as ``x``.
        """
        # Sub-layer 1: self-attention, residual, norm.
        attended = self.attention(x, x, x, attn_mask=mask)[0]
        hidden = self.norm(x + attended)

        # Sub-layer 2: feed-forward, residual, norm.
        return self.norm_ff(hidden + self.ff(hidden))
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class GraphConditioningEncoder(nn.Module):
    """Encode graph-structured conditioning data for the diffusion model.

    This encoder takes structured data from the RSVS Knowledge Graph
    and produces conditioning vectors that guide the diffusion process.

    The encoding process:
    1. Encode each evidence node (text + confidence + temporal)
    2. Encode compositions (how concepts relate)
    3. Encode anomalies (what doesn't fit)
    4. Encode reasoning chain (step-by-step logic)
    5. Aggregate via graph attention layers
    6. Project to conditioning vector for the diffusion model

    Output modes (conditioning_method):
    - 'cross_attention': Returns keys and values for cross-attention in the transformer
    - 'ada_ln': Returns scale/shift parameters for adaptive layer norm
    - 'concat': Returns a conditioning prefix to concatenate with input

    Any other ``conditioning_method`` creates no projection layers, and the
    result then carries only the ``'global'`` entry.

    Args:
        config: GraphEncoderConfig with hyperparameters.
        vocab_size: Vocabulary size (must match tokenizer).
    """

    # Attribute names of the projection layers each conditioning method needs.
    _PROJECTION_NAMES = {
        "cross_attention": ("key_proj", "value_proj"),
        "ada_ln": ("scale_proj", "shift_proj"),
        "concat": ("concat_proj",),
    }

    def __init__(
        self,
        config: GraphEncoderConfig,
        vocab_size: int = 32000,
    ):
        super().__init__()
        self.config = config
        self.conditioning_method = config.conditioning_method

        # Node encoders for different graph element types
        self.evidence_encoder = NodeEncoder(
            d_graph=config.d_graph,
            vocab_size=vocab_size,
            embed_confidence=config.embed_confidence,
            embed_temporal=config.embed_temporal,
        )

        self.composition_encoder = NodeEncoder(
            d_graph=config.d_graph,
            vocab_size=vocab_size,
            embed_confidence=config.embed_confidence,
            embed_temporal=False,  # Compositions don't have temporal info
        )

        self.anomaly_encoder = NodeEncoder(
            d_graph=config.d_graph,
            vocab_size=vocab_size,
            embed_confidence=True,  # Anomalies always have confidence
            embed_temporal=config.embed_temporal,
        )

        self.reasoning_encoder = NodeEncoder(
            d_graph=config.d_graph,
            vocab_size=vocab_size,
            embed_confidence=True,  # Reasoning steps have confidence
            embed_temporal=False,
        )

        # Source trust embedding
        self.trust_embed = ConfidenceEmbedding(config.d_graph)

        # Graph attention layers for cross-node interaction
        self.graph_layers = nn.ModuleList([
            GraphAttentionLayer(
                d_graph=config.d_graph,
                n_heads=config.n_graph_heads,
                dropout=0.1,
            )
            for _ in range(config.n_graph_layers)
        ])

        # The output width defaults to d_graph; set_output_dim() changes it.
        self._d_model_out = config.d_graph
        self._build_projections(self._d_model_out)

        # Global pooling for summary
        self.global_pool_proj = nn.Sequential(
            nn.Linear(config.d_graph, config.d_graph),
            nn.GELU(),
            nn.Linear(config.d_graph, config.d_graph),
        )

        # Type embeddings for different graph element types
        # 0 = evidence, 1 = composition, 2 = anomaly, 3 = reasoning
        self.type_embeddings = nn.Embedding(4, config.d_graph)

    def _build_projections(self, d_model_out: int) -> None:
        """Create (or replace) the projection layers of the active method.

        A replaced layer is freshly initialized but placed on the device
        and dtype of the layer it replaces, so resizing after the module
        was moved or cast does not leave it on the defaults.
        """
        for name in self._PROJECTION_NAMES.get(self.conditioning_method, ()):
            previous = getattr(self, name, None)
            layer = nn.Linear(self.config.d_graph, d_model_out)
            if previous is not None:
                layer = layer.to(
                    device=previous.weight.device, dtype=previous.weight.dtype
                )
            setattr(self, name, layer)

    def set_output_dim(self, d_model_out: int) -> None:
        """Set the output dimension for the projection layers.

        This must be called after __init__ if d_graph != d_model
        (which is typically the case when the graph encoder's d_graph
        differs from the transformer's d_model).

        The projection layers are rebuilt with new random weights, so call
        this before loading a checkpoint or creating an optimizer.

        Args:
            d_model_out: Output dimension (typically the transformer's d_model).
        """
        if d_model_out == self._d_model_out:
            return  # No change needed

        self._d_model_out = d_model_out
        self._build_projections(d_model_out)

    def _encode_typed(
        self,
        encoder: nn.Module,
        type_id: int,
        device: torch.device,
        token_ids: torch.Tensor,
        *features: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Encode one kind of graph element and add its type embedding."""
        emb = encoder(token_ids, *features)
        type_ids = torch.full(
            (emb.shape[1],), type_id, dtype=torch.long, device=device
        )
        # (n_nodes, d_graph) broadcast over the batch axis.
        return emb + self.type_embeddings(type_ids).unsqueeze(0)

    def forward(
        self,
        evidence_ids: Optional[torch.Tensor] = None,
        evidence_confidence: Optional[torch.Tensor] = None,
        evidence_timestamps: Optional[torch.Tensor] = None,
        composition_ids: Optional[torch.Tensor] = None,
        composition_confidence: Optional[torch.Tensor] = None,
        anomaly_ids: Optional[torch.Tensor] = None,
        anomaly_confidence: Optional[torch.Tensor] = None,
        anomaly_timestamps: Optional[torch.Tensor] = None,
        reasoning_ids: Optional[torch.Tensor] = None,
        reasoning_confidence: Optional[torch.Tensor] = None,
        source_trust: Optional[torch.Tensor] = None,
        batch_size: Optional[int] = None,
    ) -> dict[str, torch.Tensor]:
        """Encode graph conditioning data.

        All inputs are optional — the encoder handles missing data gracefully.

        Args:
            evidence_ids: Evidence node token IDs, shape (batch, n_evidence, seq_len).
            evidence_confidence: Evidence confidence scores, shape (batch, n_evidence).
            evidence_timestamps: Evidence timestamps, shape (batch, n_evidence).
            composition_ids: Composition token IDs, shape (batch, n_compositions, seq_len).
            composition_confidence: Composition confidence, shape (batch, n_compositions).
            anomaly_ids: Anomaly token IDs, shape (batch, n_anomalies, seq_len).
            anomaly_confidence: Anomaly confidence, shape (batch, n_anomalies).
            anomaly_timestamps: Anomaly timestamps, shape (batch, n_anomalies).
            reasoning_ids: Reasoning step token IDs, shape (batch, n_steps, seq_len).
            reasoning_confidence: Reasoning confidence, shape (batch, n_steps).
            source_trust: Source trust score, shape (batch,).
            batch_size: Batch size to use when no ``*_ids`` tensor is given
                (defaults to 1 in that case).

        Returns:
            Dictionary with conditioning tensors depending on conditioning_method:
            - 'cross_attention': {'keys': ..., 'values': ..., 'global': ...}
            - 'ada_ln': {'scale': ..., 'shift': ..., 'global': ...}
            - 'concat': {'prefix': ..., 'global': ...}
            'global' has shape (batch, d_graph). When no graph data is
            given, the projections are applied to a single all-zero node
            and 'global' is all zeros.
        """
        reference = next(self.parameters())
        device = reference.device

        # Encode each kind of graph element that was provided.
        node_embeddings = []
        if evidence_ids is not None:
            node_embeddings.append(self._encode_typed(
                self.evidence_encoder, 0, device,
                evidence_ids, evidence_confidence, evidence_timestamps,
            ))
        if composition_ids is not None:
            node_embeddings.append(self._encode_typed(
                self.composition_encoder, 1, device,
                composition_ids, composition_confidence,
            ))
        if anomaly_ids is not None:
            node_embeddings.append(self._encode_typed(
                self.anomaly_encoder, 2, device,
                anomaly_ids, anomaly_confidence, anomaly_timestamps,
            ))
        if reasoning_ids is not None:
            node_embeddings.append(self._encode_typed(
                self.reasoning_encoder, 3, device,
                reasoning_ids, reasoning_confidence,
            ))

        # If no graph data, return zero conditioning
        if not node_embeddings:
            bsz = batch_size or self._infer_batch_size(
                evidence_ids, composition_ids, anomaly_ids, reasoning_ids
            )
            dummy = torch.zeros(
                bsz, 1, self.config.d_graph,
                device=device, dtype=reference.dtype,
            )
            result = self._project_conditioning(dummy)
            # Keep the documented contract: 'global' is always present.
            result["global"] = dummy.new_zeros(bsz, self.config.d_graph)
            return result

        # Concatenate all node embeddings
        all_nodes = torch.cat(node_embeddings, dim=1)  # (batch, n_total_nodes, d_graph)

        # Add source trust as a global bias
        if source_trust is not None:
            trust_emb = self.trust_embed(source_trust.unsqueeze(-1))  # (batch, d_graph)
            # Broadcast trust to all nodes
            all_nodes = all_nodes + trust_emb.unsqueeze(1) * 0.1  # Small influence

        # Apply graph attention layers
        for layer in self.graph_layers:
            all_nodes = layer(all_nodes)

        # Compute global conditioning (mean pool)
        global_cond = self.global_pool_proj(all_nodes.mean(dim=1))  # (batch, d_graph)

        # Project based on conditioning method
        result = self._project_conditioning(all_nodes)
        result["global"] = global_cond

        return result

    def _project_conditioning(
        self, node_features: torch.Tensor
    ) -> dict[str, torch.Tensor]:
        """Project node features to conditioning format.

        Args:
            node_features: Shape (batch, n_nodes, d_graph).

        Returns:
            Dictionary with conditioning tensors (without the 'global' entry).
        """
        result = {}

        if self.conditioning_method == "cross_attention":
            result["keys"] = self.key_proj(node_features)
            result["values"] = self.value_proj(node_features)

        elif self.conditioning_method == "ada_ln":
            # Use mean-pooled features for scale/shift
            pooled = node_features.mean(dim=1)
            result["scale"] = self.scale_proj(pooled)
            result["shift"] = self.shift_proj(pooled)

        elif self.conditioning_method == "concat":
            result["prefix"] = self.concat_proj(node_features)

        return result

    @staticmethod
    def _infer_batch_size(*tensors) -> int:
        """Infer batch size from the first non-None tensor (1 if all are None)."""
        for t in tensors:
            if t is not None:
                return t.shape[0]
        return 1
|
diffusion_llm/model/noise_scheduler.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AAM Diffusion LLM — Noise Scheduler
|
| 3 |
+
|
| 4 |
+
Implements the forward (noising) and reverse (denoising) diffusion process.
|
| 5 |
+
|
| 6 |
+
Forward Process:
|
| 7 |
+
q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
|
| 8 |
+
|
| 9 |
+
Reverse Process:
|
| 10 |
+
p(x_{t-1} | x_t) = N(x_{t-1}; mu_theta(x_t, t), sigma_t^2 * I)
|
| 11 |
+
|
| 12 |
+
This scheduler supports:
|
| 13 |
+
- Linear noise schedule (Ho et al., 2020)
|
| 14 |
+
- Cosine noise schedule (Nichol & Dhariwal, 2021) — recommended
|
| 15 |
+
- Sigmoid noise schedule
|
| 16 |
+
|
| 17 |
+
Analogi: Seperti Jin Soun membentuk pikirannya — dari noise
|
| 18 |
+
(kabur, tidak jelas) menjadi sinyal (pola yang jelas).
|
| 19 |
+
Setiap langkah denoising = satu langkah lebih dekat ke
|
| 20 |
+
kesimpulan yang koheren.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import math
|
| 26 |
+
from typing import Optional
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
import torch.nn as nn
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class NoiseScheduler(nn.Module):
|
| 33 |
+
"""Noise scheduler for the diffusion process.
|
| 34 |
+
|
| 35 |
+
Manages the noise schedule (beta values, alpha values, etc.)
|
| 36 |
+
and provides methods for adding noise and computing posterior
|
| 37 |
+
distributions.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
n_timesteps: Total number of diffusion timesteps.
|
| 41 |
+
schedule_type: Type of noise schedule ('linear', 'cosine', 'sigmoid').
|
| 42 |
+
beta_start: Starting beta for linear schedule.
|
| 43 |
+
beta_end: Ending beta for linear schedule.
|
| 44 |
+
prediction_type: What the model predicts ('epsilon', 'x0', or 'v').
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
n_timesteps: int = 1000,
|
| 50 |
+
schedule_type: str = "cosine",
|
| 51 |
+
beta_start: float = 1e-4,
|
| 52 |
+
beta_end: float = 0.02,
|
| 53 |
+
prediction_type: str = "epsilon",
|
| 54 |
+
):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.n_timesteps = n_timesteps
|
| 57 |
+
self.schedule_type = schedule_type
|
| 58 |
+
self.beta_start = beta_start
|
| 59 |
+
self.beta_end = beta_end
|
| 60 |
+
self.prediction_type = prediction_type
|
| 61 |
+
|
| 62 |
+
# Compute and register noise schedule buffers
|
| 63 |
+
betas = self._compute_betas()
|
| 64 |
+
alphas = 1.0 - betas
|
| 65 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
| 66 |
+
alphas_cumprod_prev = torch.cat(
|
| 67 |
+
[torch.ones(1, dtype=betas.dtype), alphas_cumprod[:-1]]
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Register all as buffers (part of model state but not parameters)
|
| 71 |
+
self.register_buffer("betas", betas)
|
| 72 |
+
self.register_buffer("alphas", alphas)
|
| 73 |
+
self.register_buffer("alphas_cumprod", alphas_cumprod)
|
| 74 |
+
self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
|
| 75 |
+
|
| 76 |
+
# For q(x_t | x_0) computation
|
| 77 |
+
self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
|
| 78 |
+
self.register_buffer(
|
| 79 |
+
"sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# For posterior q(x_{t-1} | x_t, x_0)
|
| 83 |
+
posterior_variance = (
|
| 84 |
+
betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
| 85 |
+
)
|
| 86 |
+
self.register_buffer("posterior_variance", posterior_variance)
|
| 87 |
+
self.register_buffer(
|
| 88 |
+
"posterior_log_variance_clipped",
|
| 89 |
+
torch.log(posterior_variance.clamp(min=1e-20)),
|
| 90 |
+
)
|
| 91 |
+
self.register_buffer(
|
| 92 |
+
"posterior_mean_coef1",
|
| 93 |
+
betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
|
| 94 |
+
)
|
| 95 |
+
self.register_buffer(
|
| 96 |
+
"posterior_mean_coef2",
|
| 97 |
+
(1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
def _compute_betas(self) -> torch.Tensor:
|
| 101 |
+
"""Compute beta schedule.
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
Tensor of shape (n_timesteps,) with beta values.
|
| 105 |
+
"""
|
| 106 |
+
if self.schedule_type == "linear":
|
| 107 |
+
return torch.linspace(
|
| 108 |
+
self.beta_start, self.beta_end, self.n_timesteps
|
| 109 |
+
)
|
| 110 |
+
elif self.schedule_type == "cosine":
|
| 111 |
+
return self._cosine_schedule()
|
| 112 |
+
elif self.schedule_type == "sigmoid":
|
| 113 |
+
return self._sigmoid_schedule()
|
| 114 |
+
else:
|
| 115 |
+
raise ValueError(
|
| 116 |
+
f"Unknown schedule_type '{self.schedule_type}'. "
|
| 117 |
+
f"Use 'linear', 'cosine', or 'sigmoid'."
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
def _cosine_schedule(self, s: float = 0.008) -> torch.Tensor:
|
| 121 |
+
"""Cosine schedule as proposed in Nichol & Dhariwal 2021.
|
| 122 |
+
|
| 123 |
+
alpha_bar(t) = cos^2((t/T + s) / (1 + s) * pi/2)
|
| 124 |
+
beta(t) = 1 - alpha_bar(t) / alpha_bar(t-1)
|
| 125 |
+
|
| 126 |
+
This schedule avoids too much noise at the end and too
|
| 127 |
+
little at the beginning, leading to more stable training.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
s: Offset to prevent singularity at t=0.
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
Tensor of beta values.
|
| 134 |
+
"""
|
| 135 |
+
steps = self.n_timesteps + 1
|
| 136 |
+
t = torch.linspace(0, self.n_timesteps, steps)
|
| 137 |
+
alphas_cumprod = torch.cos(
|
| 138 |
+
((t / self.n_timesteps) + s) / (1 + s) * math.pi * 0.5
|
| 139 |
+
) ** 2
|
| 140 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
| 141 |
+
betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
| 142 |
+
return torch.clamp(betas, 0.0001, 0.9999)
|
| 143 |
+
|
| 144 |
+
def _sigmoid_schedule(self) -> torch.Tensor:
|
| 145 |
+
"""Sigmoid-based noise schedule.
|
| 146 |
+
|
| 147 |
+
beta(t) = sigmoid(-gamma * (t - T/2) + offset) * (beta_end - beta_start) + beta_start
|
| 148 |
+
|
| 149 |
+
Provides a smooth transition between low and high noise.
|
| 150 |
+
"""
|
| 151 |
+
betas = torch.linspace(-6, 6, self.n_timesteps)
|
| 152 |
+
betas = torch.sigmoid(betas) * (self.beta_end - self.beta_start) + self.beta_start
|
| 153 |
+
return torch.clamp(betas, 0.0001, 0.9999)
|
| 154 |
+
|
| 155 |
+
def add_noise(
    self,
    x_0: torch.Tensor,
    noise: torch.Tensor,
    t: torch.Tensor,
) -> torch.Tensor:
    """Sample x_t from the forward process q(x_t | x_0).

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

    Args:
        x_0: Clean data tensor of shape (batch, seq_len, d_model).
        noise: Noise tensor with the same shape as ``x_0``.
        t: Timestep indices of shape (batch,).

    Returns:
        Noisy data x_t with the same shape as ``x_0``.
    """
    # Per-sample coefficients looked up from the precomputed schedule.
    signal_scale = self._gather(self.sqrt_alphas_cumprod, t, x_0)
    noise_scale = self._gather(self.sqrt_one_minus_alphas_cumprod, t, x_0)

    scaled_signal = signal_scale * x_0
    scaled_noise = noise_scale * noise
    return scaled_signal + scaled_noise
|
| 182 |
+
|
| 183 |
+
def compute_loss_target(
    self,
    x_0: torch.Tensor,
    noise: torch.Tensor,
    t: torch.Tensor,
) -> torch.Tensor:
    """Return the regression target matching ``prediction_type``.

    - 'epsilon': the noise that was added
    - 'x0': the clean data
    - 'v': the velocity, sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0

    Args:
        x_0: Clean data.
        noise: Noise that was added.
        t: Timestep indices.

    Returns:
        Target tensor for loss computation.

    Raises:
        ValueError: If ``prediction_type`` is none of the three above.
    """
    mode = self.prediction_type
    if mode == "epsilon":
        return noise
    if mode == "x0":
        return x_0
    if mode != "v":
        raise ValueError(f"Unknown prediction_type: {self.prediction_type}")

    # Velocity target: mixes noise and clean data with the schedule weights.
    signal_scale = self._gather(self.sqrt_alphas_cumprod, t, x_0)
    noise_scale = self._gather(self.sqrt_one_minus_alphas_cumprod, t, x_0)
    return signal_scale * noise - noise_scale * x_0
|
| 217 |
+
|
| 218 |
+
def predict_x0_from_epsilon(
    self,
    x_t: torch.Tensor,
    epsilon: torch.Tensor,
    t: torch.Tensor,
) -> torch.Tensor:
    """Recover the clean-data estimate from a noise prediction.

    Inverts the forward process:

        x_0 = (x_t - sqrt(1 - alpha_bar_t) * epsilon) / sqrt(alpha_bar_t)

    Args:
        x_t: Noisy data.
        epsilon: Predicted noise.
        t: Timestep indices.

    Returns:
        Predicted clean data x_0.
    """
    signal_scale = self._gather(self.sqrt_alphas_cumprod, t, x_t)
    noise_scale = self._gather(self.sqrt_one_minus_alphas_cumprod, t, x_t)
    # Remove the estimated noise contribution, then undo the signal scaling.
    denoised = x_t - noise_scale * epsilon
    return denoised / signal_scale
|
| 241 |
+
|
| 242 |
+
def predict_x0_from_v(
    self,
    x_t: torch.Tensor,
    v: torch.Tensor,
    t: torch.Tensor,
) -> torch.Tensor:
    """Recover the clean-data estimate from a velocity prediction.

        x_0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v

    Args:
        x_t: Noisy data.
        v: Predicted velocity.
        t: Timestep indices.

    Returns:
        Predicted clean data x_0.
    """
    signal_scale = self._gather(self.sqrt_alphas_cumprod, t, x_t)
    noise_scale = self._gather(self.sqrt_one_minus_alphas_cumprod, t, x_t)
    from_noisy = signal_scale * x_t
    from_velocity = noise_scale * v
    return from_noisy - from_velocity
|
| 263 |
+
|
| 264 |
+
def posterior_mean(
    self,
    x_0: torch.Tensor,
    x_t: torch.Tensor,
    t: torch.Tensor,
) -> torch.Tensor:
    """Return the mean of the posterior q(x_{t-1} | x_t, x_0).

    The mean is the linear blend ``coef1 * x_0 + coef2 * x_t`` where both
    coefficients come from the precomputed ``posterior_mean_coef1`` and
    ``posterior_mean_coef2`` tables, indexed by ``t``.

    Args:
        x_0: Predicted or actual clean data.
        x_t: Noisy data at timestep t.
        t: Timestep indices.

    Returns:
        Posterior mean tensor.
    """
    clean_weight = self._gather(self.posterior_mean_coef1, t, x_t)
    noisy_weight = self._gather(self.posterior_mean_coef2, t, x_t)

    clean_term = clean_weight * x_0
    noisy_term = noisy_weight * x_t
    return clean_term + noisy_term
|
| 285 |
+
|
| 286 |
+
def step_ddpm(
    self,
    model_output: torch.Tensor,
    x_t: torch.Tensor,
    t: torch.Tensor,
) -> torch.Tensor:
    """Single DDPM reverse step: x_t -> x_{t-1}.

    The model output is first converted to an x_0 estimate according to
    ``self.prediction_type``, clamped to [-5, 5] for stability, and turned
    into the posterior mean. Posterior noise is then added *per sample*:
    entries with ``t > 0`` receive noise, entries with ``t == 0`` return the
    mean unchanged. (Previously a single ``t == 0`` entry in the batch
    suppressed the noise for every sample.)

    Args:
        model_output: Model prediction (epsilon, x0, or v).
        x_t: Noisy data at timestep t.
        t: Current timestep indices, one per batch element.

    Returns:
        Denoised data at timestep t-1.

    Raises:
        ValueError: If ``self.prediction_type`` is not one of
            ``"epsilon"``, ``"x0"`` or ``"v"``.
    """
    # Get predicted x_0
    if self.prediction_type == "epsilon":
        x_0_pred = self.predict_x0_from_epsilon(x_t, model_output, t)
    elif self.prediction_type == "x0":
        x_0_pred = model_output
    elif self.prediction_type == "v":
        x_0_pred = self.predict_x0_from_v(x_t, model_output, t)
    else:
        raise ValueError(f"Unknown prediction_type: {self.prediction_type}")

    # Clamp x_0 prediction for stability
    x_0_pred = x_0_pred.clamp(-5.0, 5.0)

    # Compute posterior mean
    mean = self.posterior_mean(x_0_pred, x_t, t)

    # Decide per sample whether noise is added (never at t == 0).
    needs_noise = t > 0
    if not bool(needs_noise.any()):
        # Nothing to perturb; skip sampling so the RNG state is untouched.
        return mean

    noise = torch.randn_like(x_t)
    log_variance = self._gather(self.posterior_log_variance_clipped, t, x_t)
    noise_scale = torch.exp(0.5 * log_variance)

    # Reshape the (batch,) mask to (batch, 1, ..., 1) so it broadcasts over x_t.
    mask = needs_noise.reshape(-1, *([1] * (x_t.ndim - 1)))
    return torch.where(mask, mean + noise_scale * noise, mean)
|
| 329 |
+
|
| 330 |
+
def step_ddim(
    self,
    model_output: torch.Tensor,
    x_t: torch.Tensor,
    t: int,
    t_prev: int,
    eta: float = 0.0,
) -> torch.Tensor:
    """Single DDIM reverse step: x_t -> x_{t_prev}.

    With ``eta == 0`` the update is deterministic, which is what allows
    sampling with fewer steps at inference time.

    Args:
        model_output: Model prediction.
        x_t: Noisy data at timestep t.
        t: Current timestep (scalar).
        t_prev: Previous timestep (scalar, < t). A negative value means
            "fully denoised" and uses alpha_bar = 1.
        eta: Stochasticity parameter (0 = deterministic).

    Returns:
        Denoised data at timestep t_prev.

    Raises:
        ValueError: If ``self.prediction_type`` is not recognised.
    """
    device = x_t.device
    # The x_0 helpers expect one timestep index per batch element.
    batch_t = torch.tensor([t], device=device).expand(x_t.shape[0])

    # Recover the x_0 estimate from whatever the model predicts.
    if self.prediction_type == "epsilon":
        x_0_pred = self.predict_x0_from_epsilon(x_t, model_output, batch_t)
    elif self.prediction_type == "x0":
        x_0_pred = model_output
    elif self.prediction_type == "v":
        x_0_pred = self.predict_x0_from_v(x_t, model_output, batch_t)
    else:
        raise ValueError(f"Unknown prediction_type: {self.prediction_type}")

    x_0_pred = x_0_pred.clamp(-5.0, 5.0)

    # Cumulative alpha at the current and the target timestep.
    alpha_t = self.alphas_cumprod[t]
    if t_prev >= 0:
        alpha_prev = self.alphas_cumprod[t_prev]
    else:
        alpha_prev = torch.tensor(1.0, device=device)

    # Standard deviation of the injected noise (zero when eta == 0).
    variance_term = (1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev)
    sigma = eta * torch.sqrt(variance_term)

    # Noise implied by x_t and the (clamped) x_0 estimate.
    implied_noise = (x_t - torch.sqrt(alpha_t) * x_0_pred) / torch.sqrt(1 - alpha_t)
    # "Direction pointing to x_t" term of the DDIM update.
    pred_dir = torch.sqrt(1 - alpha_prev - sigma ** 2) * implied_noise

    x_prev = torch.sqrt(alpha_prev) * x_0_pred + pred_dir

    if eta > 0 and sigma > 0:
        x_prev = x_prev + sigma * torch.randn_like(x_t)

    return x_prev
|
| 390 |
+
|
| 391 |
+
@staticmethod
|
| 392 |
+
def _gather(
|
| 393 |
+
values: torch.Tensor,
|
| 394 |
+
t: torch.Tensor,
|
| 395 |
+
target: torch.Tensor,
|
| 396 |
+
) -> torch.Tensor:
|
| 397 |
+
"""Gather schedule values for timesteps and reshape for broadcasting.
|
| 398 |
+
|
| 399 |
+
Args:
|
| 400 |
+
values: Schedule values of shape (n_timesteps,).
|
| 401 |
+
t: Timestep indices of shape (batch,).
|
| 402 |
+
target: Target tensor to match shape.
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
Gathered values reshaped for broadcasting with target.
|
| 406 |
+
"""
|
| 407 |
+
gathered = values.gather(0, t)
|
| 408 |
+
# Reshape to (batch, 1, 1, ...) for broadcasting
|
| 409 |
+
ndim = target.ndim - 1 # minus batch dim
|
| 410 |
+
for _ in range(ndim):
|
| 411 |
+
gathered = gathered.unsqueeze(-1)
|
| 412 |
+
return gathered.expand_as(target)
|
| 413 |
+
|
| 414 |
+
def get_timestep_schedule(self, n_inference_steps: int) -> list[int]:
    """Get an evenly spaced, descending timestep schedule for inference.

    For DDIM: use a subset of the training timesteps. The stride is
    ``n_timesteps // n_inference_steps``, clamped to at least 1 so that
    asking for more inference steps than training timesteps falls back to
    visiting every timestep instead of failing on a zero stride.

    The schedule starts at ``n_timesteps - 1`` and stops before 0, so
    timestep 0 itself is never part of the list. Because of the integer
    stride the list can contain slightly more entries than requested.

    Args:
        n_inference_steps: Number of inference steps (must be positive).

    Returns:
        List of timestep indices in descending order.

    Raises:
        ValueError: If ``n_inference_steps`` is not positive.
    """
    if n_inference_steps <= 0:
        raise ValueError(
            f"n_inference_steps must be positive, got {n_inference_steps}"
        )

    # A stride of 0 would make range() raise; never go below 1.
    step_size = max(1, self.n_timesteps // n_inference_steps)
    return list(range(self.n_timesteps - 1, 0, -step_size))
|
diffusion_llm/requirements.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AAM Diffusion LLM — Dependencies
|
| 2 |
+
|
| 3 |
+
# Core
|
| 4 |
+
torch>=2.0.0
|
| 5 |
+
numpy>=1.24.0
|
| 6 |
+
|
| 7 |
+
# Training
|
| 8 |
+
tensorboard>=2.13.0
|
| 9 |
+
|
| 10 |
+
# Optional (for logging and monitoring)
|
| 11 |
+
# wandb>=0.15.0
|
| 12 |
+
|
| 13 |
+
# Testing
|
| 14 |
+
pytest>=7.4.0
|
| 15 |
+
|
| 16 |
+
# Note: This framework is designed to be lightweight.
|
| 17 |
+
# No heavy ML framework dependencies beyond PyTorch.
|
diffusion_llm/scripts/evaluate.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
AAM Diffusion LLM — Evaluation Script
|
| 4 |
+
|
| 5 |
+
Evaluates a trained AAM Diffusion Model on test data or
|
| 6 |
+
generates sample narratives from graph conditioning.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
# Evaluate on test data
|
| 10 |
+
python scripts/evaluate.py --checkpoint output/best.pt
|
| 11 |
+
|
| 12 |
+
# Generate sample narratives
|
| 13 |
+
python scripts/evaluate.py --checkpoint output/best.pt --generate
|
| 14 |
+
|
| 15 |
+
# Interactive mode
|
| 16 |
+
python scripts/evaluate.py --checkpoint output/best.pt --interactive
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import argparse
|
| 22 |
+
import json
|
| 23 |
+
import logging
|
| 24 |
+
import sys
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
|
| 27 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 28 |
+
|
| 29 |
+
from diffusion_llm.config.model_config import AamDiffusionConfig
|
| 30 |
+
from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
|
| 31 |
+
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
|
| 32 |
+
from diffusion_llm.inference.generator import AamGenerator
|
| 33 |
+
|
| 34 |
+
logging.basicConfig(
|
| 35 |
+
level=logging.INFO,
|
| 36 |
+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
| 37 |
+
)
|
| 38 |
+
logger = logging.getLogger(__name__)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def parse_args() -> argparse.Namespace:
    """Define the evaluation script's command-line options and parse them.

    Returns:
        The parsed arguments. ``--checkpoint`` is required; every other
        option has a default.
    """
    cli = argparse.ArgumentParser(description="Evaluate AAM Diffusion LLM")

    # (flag, add_argument keyword arguments), registered in display order.
    option_specs = (
        ("--checkpoint", {"type": str, "required": True, "help": "Model checkpoint path"}),
        ("--tokenizer", {"type": str, "default": None, "help": "Tokenizer path"}),
        ("--generate", {"action": "store_true", "help": "Generate sample narratives"}),
        ("--interactive", {"action": "store_true", "help": "Interactive mode"}),
        ("--test_data", {"type": str, "default": None, "help": "Test data path (JSONL)"}),
        ("--n_steps", {"type": int, "default": 50, "help": "Inference denoising steps"}),
        ("--temperature", {"type": float, "default": 1.0, "help": "Sampling temperature"}),
        ("--language", {"type": str, "default": "id", "help": "Output language"}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def generate_samples(generator: AamGenerator, language: str) -> None:
    """Run the generator on three built-in graph prompts and print the results.

    Each prompt supplies a trigger question, evidence nodes, anomalies and
    reasoning steps; these are forwarded to ``generator.generate`` together
    with ``language``.

    Args:
        generator: Object exposing ``generate(...)``; the result must provide
            ``narrative``, ``n_diffusion_steps`` and ``generation_time_s``.
        language: Output language passed through to the generator.
    """
    prompts = [
        dict(
            trigger="Siapa yang mencuri Snow Plum Pill?",
            evidence_nodes=["Hefei", "Diancang Five Swords", "Ju Jangmok", "Gyeryong Merchant Guild"],
            anomalies=["Tidak ada konsumsi pil baru di pasar gelap", "Pencuri menghilang tanpa jejak"],
            reasoning_steps=["Cross-reference tanggal kejadian", "Deteksi ketidaksesuaian pola"],
        ),
        dict(
            trigger="Analisis pergerakan Diancang Five Swords",
            evidence_nodes=["Gu Ilmu", "Jang Hangi", "Diancang Five Swords", "Hefei"],
            anomalies=["Success rate pair lebih tinggi dari biasanya"],
            reasoning_steps=["Recall laporan terkait", "Pattern completion dari bukti"],
        ),
        dict(
            trigger="Hubungan antara Ju Jangmok dan pencurian",
            evidence_nodes=["Ju Jangmok", "Snow Plum Pill", "dark_faction"],
            anomalies=["Ju Jangmok menghilang hari yang sama"],
            reasoning_steps=["Eliminasi tersangka obvious", "Verify konsistensi"],
        ),
    ]

    rule = "=" * 60
    print("\n" + rule)
    print(" AAM Diffusion LLM — Sample Generation")
    print(rule)

    for number, prompt in enumerate(prompts, start=1):
        # Generate first, then print, so output for a sample appears together.
        outcome = generator.generate(language=language, **prompt)

        evidence_text = ", ".join(prompt["evidence_nodes"])
        anomaly_text = "; ".join(prompt["anomalies"])

        print(f"\n--- Sample {number} ---")
        print(f"Trigger: {prompt['trigger']}")
        print(f"Evidence: {evidence_text}")
        print(f"Anomalies: {anomaly_text}")
        print("\nGenerated Narrative:")
        print(outcome.narrative)
        print(f"\n[Steps: {outcome.n_diffusion_steps}, Time: {outcome.generation_time_s:.2f}s]")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def interactive_mode(generator: AamGenerator, language: str) -> None:
    """Prompt for graph conditioning in a loop and print generated narratives.

    Each round asks for a trigger, a comma-separated list of evidence nodes
    and a comma-separated list of anomalies, then calls
    ``generator.generate``. The loop ends when the trigger is ``quit``,
    ``exit`` or ``q`` (case-insensitive), or when input ends (EOF on piped
    stdin, Ctrl-D) or the user presses Ctrl-C at a prompt; previously those
    last cases crashed with a traceback.

    Args:
        generator: Object exposing ``generate(...)``; the result must provide
            ``narrative``, ``n_diffusion_steps``, ``generation_time_s`` and
            ``confidence``.
        language: Output language passed through to the generator.
    """
    print("\n" + "=" * 60)
    print(" AAM Diffusion LLM — Interactive Mode")
    print(" Type 'quit' to exit")
    print("=" * 60)

    while True:
        try:
            trigger = input("\nTrigger/Question: ").strip()
            if trigger.lower() in ("quit", "exit", "q"):
                break

            evidence = input("Evidence nodes (comma-separated): ").strip()
            anomalies_input = input("Anomalies (comma-separated): ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl-C: leave the session cleanly.
            print()
            break

        # An empty answer means "not provided" and is passed on as None.
        evidence_nodes = [e.strip() for e in evidence.split(",") if e.strip()] if evidence else None
        anomalies = [a.strip() for a in anomalies_input.split(",") if a.strip()] if anomalies_input else None

        result = generator.generate(
            trigger=trigger,
            evidence_nodes=evidence_nodes,
            anomalies=anomalies,
            language=language,
        )
        print(f"\nGenerated Narrative:\n{result.narrative}")
        print(f"\n[Steps: {result.n_diffusion_steps}, Time: {result.generation_time_s:.2f}s, Confidence: {result.confidence:.1%}]")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def main() -> None:
    """Load a checkpoint and tokenizer, then run the requested evaluation mode.

    Tokenizer resolution order: the ``--tokenizer`` path if given, otherwise
    ``<checkpoint dir>/data/tokenizer.json`` if it exists, otherwise a fresh
    (untrained) ``AamTokenizer`` with a warning.

    NOTE(review): ``--n_steps``, ``--temperature`` and ``--test_data`` are
    parsed but not read anywhere in this function.
    """
    args = parse_args()

    logger.info("Loading model from %s", args.checkpoint)
    model = AamDiffusionModel.load(args.checkpoint)

    # Resolve the tokenizer: explicit path, then sibling file, then fallback.
    if args.tokenizer:
        tokenizer = AamTokenizer.load(args.tokenizer)
    else:
        sibling_tokenizer = Path(args.checkpoint).parent / "data" / "tokenizer.json"
        if not sibling_tokenizer.exists():
            logger.warning("No tokenizer found. Using untrained tokenizer.")
            tokenizer = AamTokenizer()
        else:
            tokenizer = AamTokenizer.load(sibling_tokenizer)

    generator = AamGenerator(model, tokenizer, model.config)

    # --interactive takes precedence over --generate.
    if args.interactive:
        interactive_mode(generator, args.language)
        return
    if args.generate:
        generate_samples(generator, args.language)
        return
    logger.info("Use --generate or --interactive flag")


if __name__ == "__main__":
    main()
|
diffusion_llm/scripts/export.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
AAM Diffusion LLM — Export Script
|
| 4 |
+
|
| 5 |
+
Export a trained model for deployment.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python scripts/export.py --checkpoint output/best.pt --output model_export/
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import logging
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 19 |
+
|
| 20 |
+
from diffusion_llm.config.model_config import AamDiffusionConfig
|
| 21 |
+
from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
|
| 22 |
+
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
|
| 23 |
+
|
| 24 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def main() -> None:
    """Export a trained checkpoint into a deployment directory.

    Writes ``model.pt`` and ``config.json`` into ``--output`` (created if
    missing) and copies ``<checkpoint dir>/data/tokenizer.json`` next to them
    when that file exists, then prints a short summary.

    Only the ``pt`` format is produced by this script. ``--format onnx`` is
    still accepted for compatibility, but now triggers a warning instead of
    being ignored silently.
    """
    parser = argparse.ArgumentParser(description="Export AAM Diffusion Model")
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--output", type=str, default="./model_export")
    parser.add_argument("--format", type=str, default="pt", choices=["pt", "onnx"])
    args = parser.parse_args()

    if args.format != "pt":
        # No conversion path exists below; tell the user what they will get.
        logger.warning(
            "--format %s is not implemented by this script; exporting 'pt' instead",
            args.format,
        )

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load model
    model = AamDiffusionModel.load(args.checkpoint)
    model.eval()

    # Save model
    model_path = output_dir / "model.pt"
    model.save(str(model_path))
    logger.info("Model exported to %s", model_path)

    # Save config
    config_path = output_dir / "config.json"
    model.config.to_json(config_path)
    logger.info("Config saved to %s", config_path)

    # Try to copy tokenizer (best effort: skipped when the file is absent)
    checkpoint_dir = Path(args.checkpoint).parent
    tokenizer_path = checkpoint_dir / "data" / "tokenizer.json"
    if tokenizer_path.exists():
        import shutil

        exported_tokenizer = output_dir / "tokenizer.json"
        shutil.copy(tokenizer_path, exported_tokenizer)
        logger.info("Tokenizer copied to %s", exported_tokenizer)

    # Summary
    print("\nExport complete!")
    print(f" Model: {model_path}")
    print(f" Config: {config_path}")
    print(f" Parameters: {model._format_params(model.get_num_params())}")
    print("\n This is AAM's own body — 1 mind + 1 body.")
    print(" Mind = RSVS Knowledge Graph")
    print(f" Body = This Diffusion Model ({model.config.model_name})")


if __name__ == "__main__":
    main()
|
diffusion_llm/scripts/train.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
AAM Diffusion LLM — Training Script
|
| 4 |
+
|
| 5 |
+
Main entry point for training the AAM Diffusion Model.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
# Train with default config (base model)
|
| 9 |
+
python scripts/train.py
|
| 10 |
+
|
| 11 |
+
# Train with specific model size
|
| 12 |
+
python scripts/train.py --model_size small
|
| 13 |
+
|
| 14 |
+
# Train with custom config
|
| 15 |
+
python scripts/train.py --config path/to/config.json
|
| 16 |
+
|
| 17 |
+
# Train with specific data
|
| 18 |
+
python scripts/train.py --train_data path/to/train.jsonl --val_data path/to/val.jsonl
|
| 19 |
+
|
| 20 |
+
Analogi: Seperti Jin Soun memulai latihan fisiknya —
|
| 21 |
+
ini adalah titik awal di mana "tubuh" AAM mulai dilatih.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import argparse
|
| 27 |
+
import logging
|
| 28 |
+
import sys
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
|
| 31 |
+
# Add parent directory to path for imports
|
| 32 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 33 |
+
|
| 34 |
+
from diffusion_llm.config.model_config import AamDiffusionConfig, get_default_config
|
| 35 |
+
from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
|
| 36 |
+
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
|
| 37 |
+
from diffusion_llm.training.trainer import AamTrainer
|
| 38 |
+
from diffusion_llm.training.dataset import GraphNarrativeDataset
|
| 39 |
+
from diffusion_llm.data.data_pipeline import DataPipeline
|
| 40 |
+
|
| 41 |
+
logging.basicConfig(
|
| 42 |
+
level=logging.INFO,
|
| 43 |
+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
| 44 |
+
)
|
| 45 |
+
logger = logging.getLogger(__name__)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def parse_args() -> argparse.Namespace:
    """Build the training CLI and parse the command line.

    Options fall into three groups: model selection (``--model_size``,
    ``--config``), data and output locations, and optional overrides of
    training hyperparameters (left as ``None`` when not given).

    Returns:
        The parsed arguments.
    """
    cli = argparse.ArgumentParser(
        description="Train AAM Diffusion LLM",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = cli.add_argument

    # Model configuration
    add("--model_size", type=str, default="base",
        choices=["tiny", "small", "base", "medium"],
        help="Model size preset")
    add("--config", type=str, default=None,
        help="Path to custom config JSON (overrides --model_size)")

    # Data
    add("--train_data", type=str, default=None,
        help="Path to training data (JSONL)")
    add("--val_data", type=str, default=None,
        help="Path to validation data (JSONL)")
    add("--output_dir", type=str, default="./output",
        help="Output directory for checkpoints and logs")
    add("--force_regenerate", action="store_true",
        help="Force regenerate synthetic data")

    # Training overrides
    for flag, value_type in (
        ("--batch_size", int),
        ("--learning_rate", float),
        ("--max_steps", int),
        ("--n_timesteps", int),
    ):
        add(flag, type=value_type, default=None)
    add("--seed", type=int, default=42)

    return cli.parse_args()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def main() -> None:
    """Main training entry point.

    Steps: resolve the config (JSON file or size preset), apply CLI
    overrides, save the effective config, prepare data via ``DataPipeline``,
    build the model, and run ``AamTrainer.train()``.

    Numeric overrides (``--batch_size``, ``--learning_rate``, ``--max_steps``,
    ``--n_timesteps``) are applied whenever they were passed on the command
    line, including an explicit ``0`` / ``0.0``; before, such values were
    dropped by a truthiness check and the config default was used instead.
    """
    args = parse_args()

    # Load or create config
    if args.config:
        config = AamDiffusionConfig.from_json(args.config)
        logger.info("Loaded config from %s", args.config)
    else:
        config = get_default_config(args.model_size)
        logger.info("Using %s model config", args.model_size)

    # Apply CLI overrides. Paths use truthiness (an empty string is not a
    # usable path); numeric options use "is not None" so 0 is honoured.
    if args.output_dir:
        config.output_dir = args.output_dir
    if args.train_data:
        config.training.train_data_path = args.train_data
    if args.val_data:
        config.training.val_data_path = args.val_data
    if args.batch_size is not None:
        config.training.batch_size = args.batch_size
    if args.learning_rate is not None:
        config.training.learning_rate = args.learning_rate
    if args.max_steps is not None:
        config.training.max_steps = args.max_steps
    if args.n_timesteps is not None:
        config.diffusion.n_timesteps = args.n_timesteps
    config.seed = args.seed

    # Print config summary
    print(config.summary())

    # Save config. Whether to_json creates missing directories is not visible
    # here, so make sure the output directory exists first.
    config_path = Path(config.output_dir) / "config.json"
    config_path.parent.mkdir(parents=True, exist_ok=True)
    config.to_json(config_path)
    logger.info("Config saved to %s", config_path)

    # Step 1: Prepare data
    pipeline = DataPipeline(config)
    tokenizer, train_loader, val_loader = pipeline.prepare(
        force_regenerate=args.force_regenerate,
    )

    # Step 2: Create model
    model = AamDiffusionModel(config)
    logger.info(
        "Model created: %s parameters",
        model._format_params(model.get_num_params()),
    )

    # Step 3: Create datasets (using pre-created loaders)
    train_dataset = train_loader.dataset
    val_dataset = val_loader.dataset if val_loader else None

    # Step 4: Create trainer and train
    trainer = AamTrainer(
        config=config,
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
    )

    # Override data loaders (already created by pipeline)
    trainer.train_loader = train_loader
    trainer.val_loader = val_loader

    # Start training
    trainer.train()

    logger.info("Training complete! Output saved to %s", config.output_dir)


if __name__ == "__main__":
    main()
|
diffusion_llm/scripts/train_final.py
ADDED
|
@@ -0,0 +1,686 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
AAM Diffusion LLM — Final Training Script
|
| 4 |
+
|
| 5 |
+
Trains the complete AAM Diffusion LLM pipeline:
|
| 6 |
+
1. Generate synthetic training data (Graph→Narrative pairs)
|
| 7 |
+
2. Train the AAM Sentence-Level + BPE Tokenizer
|
| 8 |
+
3. Train the Diffusion Transformer model
|
| 9 |
+
4. Save final model, tokenizer, and config for HuggingFace upload
|
| 10 |
+
|
| 11 |
+
This is the "birth" of AAM's body — from random weights to
|
| 12 |
+
a model that can arrange sentences from graph conditioning.
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
python scripts/train_final.py --output_dir ./aam-diffusion-v1
|
| 16 |
+
python scripts/train_final.py --model_size tiny --max_steps 500
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import argparse
|
| 22 |
+
import json
|
| 23 |
+
import logging
|
| 24 |
+
import sys
|
| 25 |
+
import time
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
# Add parent directory to path
|
| 29 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
import numpy as np
|
| 33 |
+
|
| 34 |
+
from diffusion_llm.config.model_config import (
|
| 35 |
+
AamDiffusionConfig, get_default_config, ModelConfig,
|
| 36 |
+
DiffusionConfig, GraphEncoderConfig, TokenizerConfig,
|
| 37 |
+
TrainingConfig, InferenceConfig,
|
| 38 |
+
)
|
| 39 |
+
from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
|
| 40 |
+
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
|
| 41 |
+
from diffusion_llm.training.dataset import GraphNarrativeDataset, collate_fn
|
| 42 |
+
from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator
|
| 43 |
+
|
| 44 |
+
logging.basicConfig(
|
| 45 |
+
level=logging.INFO,
|
| 46 |
+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
| 47 |
+
)
|
| 48 |
+
logger = logging.getLogger("train_final")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _positive_int(value):
    """argparse ``type=`` callable that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an integer, got {value!r}"
        ) from None
    if number < 1:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {number}"
        )
    return number


def parse_args(argv=None):
    """Parse the command-line options of the final training script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The populated ``argparse.Namespace``.

    The interval options (``--log_every``, ``--save_every``, ``--eval_every``)
    and ``--batch_size`` must be >= 1: the training loop uses the intervals as
    the right-hand side of ``%``, so a value of 0 would raise
    ``ZeroDivisionError`` in the middle of training instead of failing here.
    """
    parser = argparse.ArgumentParser(description="Train AAM Diffusion LLM (Final)")
    parser.add_argument("--model_size", type=str, default="tiny",
                        choices=["tiny", "small", "base", "medium"])
    parser.add_argument("--output_dir", type=str, default="./aam-diffusion-v1")
    parser.add_argument("--max_steps", type=int, default=500)
    parser.add_argument("--batch_size", type=_positive_int, default=8)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--n_synthetic_train", type=int, default=500)
    parser.add_argument("--n_synthetic_val", type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--log_every", type=_positive_int, default=50)
    parser.add_argument("--save_every", type=_positive_int, default=500)
    parser.add_argument("--eval_every", type=_positive_int, default=200)
    return parser.parse_args(argv)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def set_seed(seed: int):
    """Seed the global RNGs of ``random``, NumPy and PyTorch with ``seed``.

    Args:
        seed: Value passed unchanged to all three generators.
    """
    import random

    # The three generators are independent, so the seeding order is irrelevant.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def generate_data(output_dir: Path, n_train: int, n_val: int, seed: int):
    """Create the synthetic train/validation files under ``output_dir/data``.

    Args:
        output_dir: Run directory; a ``data`` sub-directory is created in it.
        n_train: Number of training examples requested from the generator.
        n_val: Number of validation examples requested from the generator.
        seed: Seed forwarded to the generator.

    Returns:
        ``(train_path, val_path)`` exactly as returned by
        ``SyntheticDataGenerator.generate_training_split``.
    """
    rule = "=" * 60
    for message in (rule, "STEP 1: Generating Synthetic Training Data", rule):
        logger.info(message)

    data_dir = output_dir / "data"
    data_dir.mkdir(parents=True, exist_ok=True)

    # language="id" is fixed here (presumably Indonesian, matching the
    # Indonesian prompts used elsewhere in this script).
    split = SyntheticDataGenerator.generate_training_split(
        output_dir=data_dir,
        n_train=n_train,
        n_val=n_val,
        language="id",
        seed=seed,
    )
    train_path, val_path = split

    logger.info(f" Train data: {train_path} ({n_train} examples)")
    logger.info(f" Val data: {val_path} ({n_val} examples)")
    return train_path, val_path
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# Record keys holding a single string / a list of strings that feed the tokenizer.
_TOKENIZER_TEXT_KEYS = ("narrative", "trigger")
_TOKENIZER_LIST_KEYS = ("evidence_nodes", "anomalies", "reasoning_steps", "compositions")


def _collect_tokenizer_texts(train_path):
    """Read a JSONL file and return ``(texts, n_malformed)`` for tokenizer training."""
    texts = []
    n_malformed = 0
    with open(train_path, "r", encoding="utf-8") as f:
        for raw in f:
            raw = raw.strip()
            if not raw:
                continue
            try:
                record = json.loads(raw)
            except json.JSONDecodeError:
                # Keep going (best effort), but count the line so the caller
                # can report it instead of losing data silently.
                n_malformed += 1
                continue
            for key in _TOKENIZER_TEXT_KEYS:
                if record.get(key):
                    texts.append(record[key])
            for key in _TOKENIZER_LIST_KEYS:
                # ``or []`` also covers a key that is present with a null value.
                texts.extend(record.get(key) or [])
    return texts, n_malformed


def train_tokenizer(train_path: Path, output_dir: Path, config: AamDiffusionConfig) -> AamTokenizer:
    """Train the AAM tokenizer on the synthetic data and save it.

    Every non-empty line of ``train_path`` is parsed as one JSON record. The
    ``narrative`` and ``trigger`` strings and every item of ``evidence_nodes``,
    ``anomalies``, ``reasoning_steps`` and ``compositions`` become training
    texts. Lines that are not valid JSON are skipped and reported in a single
    warning.

    Args:
        train_path: JSONL file with the training records.
        output_dir: Directory that receives ``tokenizer.json``.
        config: Full configuration; ``config.tokenizer`` configures the
            tokenizer and ``config.tokenizer.bpe_vocab_size`` is passed to
            ``tokenizer.train``.

    Returns:
        The trained tokenizer (already saved to ``output_dir/tokenizer.json``).
    """
    logger.info("=" * 60)
    logger.info("STEP 2: Training AAM Sentence-Level + BPE Tokenizer")
    logger.info("=" * 60)

    tokenizer = AamTokenizer(config=config.tokenizer)

    texts, n_malformed = _collect_tokenizer_texts(train_path)
    if n_malformed:
        logger.warning(f" Skipped {n_malformed} malformed JSON line(s) in {train_path}")
    if not texts:
        logger.warning(f" No training texts found in {train_path}")

    logger.info(f" Training tokenizer on {len(texts)} texts...")
    tokenizer.train(texts, vocab_size=config.tokenizer.bpe_vocab_size)

    # Save tokenizer
    tokenizer_path = output_dir / "tokenizer.json"
    tokenizer.save(tokenizer_path)
    logger.info(f" Tokenizer saved: {tokenizer_path}")
    logger.info(f" Vocab size: {tokenizer.vocab_size}")
    logger.info(f" BPE merges: {len(tokenizer.merges)}")

    return tokenizer
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def create_dataloaders(
    train_path: Path, val_path: Path,
    tokenizer: AamTokenizer, config: AamDiffusionConfig
):
    """Build the training and validation ``DataLoader`` objects.

    Both datasets share the tokenizer and the length limits taken from
    ``config``; only the training dataset is created with ``augment=True`` and
    only the training loader shuffles.

    Args:
        train_path: File with the training examples.
        val_path: File with the validation examples.
        tokenizer: Tokenizer handed to both datasets.
        config: Supplies ``model.max_seq_len``, the ``graph_encoder`` limits
            and ``training.batch_size``.

    Returns:
        ``(train_loader, val_loader)``.
    """
    rule = "=" * 60
    for message in (rule, "STEP 3: Creating DataLoaders", rule):
        logger.info(message)

    # Settings that are identical for the two datasets.
    encoder_cfg = config.graph_encoder
    shared_dataset_kwargs = dict(
        tokenizer=tokenizer,
        max_seq_len=config.model.max_seq_len,
        max_evidence=encoder_cfg.max_evidence_nodes,
        max_anomalies=encoder_cfg.max_anomalies,
        max_reasoning=encoder_cfg.max_reasoning_steps,
    )
    train_dataset = GraphNarrativeDataset(
        data_path=train_path, augment=True, **shared_dataset_kwargs
    )
    val_dataset = GraphNarrativeDataset(
        data_path=val_path, augment=False, **shared_dataset_kwargs
    )

    from torch.utils.data import DataLoader

    # Settings that are identical for the two loaders.
    shared_loader_kwargs = dict(
        batch_size=config.training.batch_size,
        num_workers=0,  # CPU training: use 0 workers
        collate_fn=collate_fn,
        pin_memory=False,  # CPU: no pin_memory
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_kwargs)

    logger.info(f" Train: {len(train_dataset)} examples, {len(train_loader)} batches")
    logger.info(f" Val: {len(val_dataset)} examples, {len(val_loader)} batches")

    return train_loader, val_loader
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def _warmup_cosine_lambda(warmup_steps, max_steps):
    """Return an LR multiplier: linear warmup followed by cosine decay to 0."""
    import math

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(warmup_steps, 1)
        progress = (step - warmup_steps) / max(max_steps - warmup_steps, 1)
        # Clamp so a step past max_steps can never raise the multiplier again.
        return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))

    return lr_lambda


def train_model(
    model: AamDiffusionModel,
    tokenizer: AamTokenizer,
    train_loader,
    val_loader,
    config: AamDiffusionConfig,
    output_dir: Path,
    args,
):
    """Train the AAM Diffusion Model for ``args.max_steps`` optimizer steps.

    The loop iterates ``train_loader`` repeatedly until the step budget is
    used. For each batch it samples one random timestep per example, runs the
    model, computes ``model.compute_loss`` and applies an AdamW update with
    gradient clipping and a warmup + cosine learning-rate schedule.

    Args:
        model: Model to train; moved to CUDA when available, otherwise CPU.
        tokenizer: Passed through to ``save_model``.
        train_loader: Iterable of batch dicts; must contain ``"token_ids"``.
        val_loader: Validation batches, or ``None`` to disable evaluation.
        config: Supplies optimizer settings (``training``) and
            ``diffusion.n_timesteps``.
        output_dir: Directory receiving ``best.pt``, ``step_<n>.pt`` and
            ``final.pt``.
        args: Parsed CLI options (``learning_rate``, ``max_steps``,
            ``batch_size``, ``log_every``, ``eval_every``, ``save_every``).

    Returns:
        The trained model (same object that was passed in).

    Raises:
        RuntimeError: If ``train_loader`` yields no batches. Without this
            check the outer loop would never terminate.
    """
    logger.info("=" * 60)
    logger.info("STEP 4: Training AAM Diffusion LLM")
    logger.info("=" * 60)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f" Device: {device}")
    logger.info(f" Parameters: {model._format_params(model.get_num_params())}")

    model.to(device)

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=config.training.weight_decay,
        betas=(config.training.adam_beta1, config.training.adam_beta2),
    )

    # LR scheduler with warmup.
    # NOTE(review): main() also stores config.training.warmup_steps
    # (min(100, max_steps // 5)), which is not consulted here -- confirm which
    # of the two values is the intended one.
    warmup_steps = min(200, args.max_steps // 10)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, _warmup_cosine_lambda(warmup_steps, args.max_steps)
    )

    # Training loop
    global_step = 0
    best_val_loss = float("inf")
    train_losses = []
    start_time = time.time()

    logger.info(f" Max steps: {args.max_steps}")
    logger.info(f" Batch size: {args.batch_size}")
    logger.info(f" Learning rate: {args.learning_rate}")
    logger.info(f" Warmup steps: {warmup_steps}")
    logger.info("")

    epoch = 0
    while global_step < args.max_steps:
        epoch += 1
        model.train()
        epoch_loss = 0.0
        n_batches = 0

        for batch in train_loader:
            if global_step >= args.max_steps:
                break

            # Move batch to device (non-tensor entries are passed through).
            batch = {
                k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()
            }

            # Sample one random timestep per example.
            batch_size = batch["token_ids"].shape[0]
            t = torch.randint(
                0, config.diffusion.n_timesteps,
                (batch_size,), device=device,
            )

            # Forward pass; conditioning entries missing from the batch are
            # passed as None.
            predicted, target = model(
                token_ids=batch["token_ids"],
                timestep=t,
                evidence_ids=batch.get("evidence_ids"),
                evidence_confidence=batch.get("evidence_confidence"),
                anomaly_ids=batch.get("anomaly_ids"),
                anomaly_confidence=batch.get("anomaly_confidence"),
                reasoning_ids=batch.get("reasoning_ids"),
                reasoning_confidence=batch.get("reasoning_confidence"),
                source_trust=batch.get("source_trust"),
            )

            # Compute loss
            loss = model.compute_loss(predicted, target, t)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), config.training.grad_clip_norm
            )

            optimizer.step()
            scheduler.step()

            loss_val = loss.item()
            train_losses.append(loss_val)
            epoch_loss += loss_val
            n_batches += 1
            global_step += 1

            # Logging
            if global_step % args.log_every == 0:
                lr = optimizer.param_groups[0]["lr"]
                recent = train_losses[-args.log_every:]
                avg_loss = sum(recent) / len(recent)
                elapsed = time.time() - start_time
                # Guard only against a zero division; clamping to a whole
                # second would under-report the speed of early steps.
                steps_per_sec = global_step / max(elapsed, 1e-8)
                logger.info(
                    f" Step {global_step:>6d}/{args.max_steps} | "
                    f"Loss: {avg_loss:.4f} | "
                    f"LR: {lr:.2e} | "
                    f"Speed: {steps_per_sec:.1f} steps/s"
                )

            # Evaluation
            if global_step % args.eval_every == 0 and val_loader is not None:
                val_loss = evaluate(model, val_loader, config, device)
                logger.info(f" >>> Validation loss: {val_loss:.4f}")
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    save_model(model, tokenizer, config, output_dir / "best.pt")
                    logger.info(f" >>> New best model saved! (val_loss: {val_loss:.4f})")

            # Checkpoint
            if global_step % args.save_every == 0:
                save_model(model, tokenizer, config, output_dir / f"step_{global_step}.pt")

        if n_batches == 0:
            # The while-condition guarantees budget was left when this epoch
            # started, so zero batches means the loader itself is empty.
            raise RuntimeError(
                "train_loader yielded no batches; cannot reach "
                f"max_steps={args.max_steps}"
            )

        avg_epoch_loss = epoch_loss / max(n_batches, 1)
        logger.info(f" Epoch {epoch} complete. Avg loss: {avg_epoch_loss:.4f}")

    # Final save
    save_model(model, tokenizer, config, output_dir / "final.pt")
    elapsed = time.time() - start_time
    logger.info("")
    logger.info(f" Training complete! {global_step} steps in {elapsed/60:.1f} minutes")
    logger.info(f" Best val loss: {best_val_loss:.4f}")
    if train_losses:
        # Empty when max_steps <= 0, in which case no step was executed.
        logger.info(f" Final train loss: {train_losses[-1]:.4f}")

    return model
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def evaluate(model, val_loader, config, device):
    """Return the mean loss of ``model`` over all batches of ``val_loader``.

    Each batch gets freshly sampled random timesteps, so the result varies
    between calls unless the caller controls the torch RNG state.

    Args:
        model: Module exposing ``__call__`` (returning ``(predicted, target)``)
            and ``compute_loss(predicted, target, t)``.
        val_loader: Iterable of batch dicts; each must contain ``"token_ids"``.
        config: Supplies ``diffusion.n_timesteps`` (exclusive upper bound of
            the sampled timesteps).
        device: Device the tensors of each batch are moved to.

    Returns:
        The average per-batch loss as a float, or ``0.0`` for an empty loader.

    The model is put in eval mode for the duration of the call and afterwards
    returned to whatever mode it was in before, also if an exception occurs.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    n_batches = 0

    try:
        with torch.no_grad():
            for batch in val_loader:
                batch = {
                    k: v.to(device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()
                }

                batch_size = batch["token_ids"].shape[0]
                t = torch.randint(
                    0, config.diffusion.n_timesteps,
                    (batch_size,), device=device,
                )

                predicted, target = model(
                    token_ids=batch["token_ids"],
                    timestep=t,
                    evidence_ids=batch.get("evidence_ids"),
                    evidence_confidence=batch.get("evidence_confidence"),
                    anomaly_ids=batch.get("anomaly_ids"),
                    anomaly_confidence=batch.get("anomaly_confidence"),
                    reasoning_ids=batch.get("reasoning_ids"),
                    reasoning_confidence=batch.get("reasoning_confidence"),
                    source_trust=batch.get("source_trust"),
                )
                loss = model.compute_loss(predicted, target, t)
                total_loss += loss.item()
                n_batches += 1
    finally:
        # Restore the caller's mode rather than forcing train mode.
        model.train(was_training)

    return total_loss / max(n_batches, 1)
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def save_model(model, tokenizer, config, path):
    """Write a checkpoint with the model weights and the config dict.

    The file stored at ``path`` is a dict with two keys:
    ``"model_state_dict"`` (``model.state_dict()``) and ``"config"``
    (``config.to_dict()``). Missing parent directories are created.

    Args:
        model: Object providing ``state_dict()``.
        tokenizer: Part of the signature but not used by this function; the
            tokenizer is NOT stored in the checkpoint.
        config: Object providing ``to_dict()``.
        path: Destination file (``str`` or ``Path``).
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "config": config.to_dict(),
        },
        target,
    )
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
# Source of the stand-alone inference script written next to the exported model.
_INFERENCE_EXAMPLE_SOURCE = '''#!/usr/bin/env python3
"""AAM Diffusion LLM — Inference Example"""

import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))

import torch
from diffusion_llm import AamDiffusionModel, AamTokenizer, AamGenerator, AamDiffusionConfig

def main():
    # Load model and tokenizer
    config = AamDiffusionConfig.from_json("config.json")
    model = AamDiffusionModel.load("model.pt", device="cpu")
    tokenizer = AamTokenizer.load("tokenizer.json")

    # Create generator
    generator = AamGenerator(model, tokenizer, config)

    # Generate narrative
    result = generator.generate(
        trigger="Siapa yang mencuri Snow Plum Pill?",
        evidence_nodes=["Hefei", "Diancang Five Swords", "Ju Jangmok"],
        anomalies=["Tidak ada konsumsi pil baru di pasar gelap"],
        reasoning_steps=["Cross-reference tanggal kejadian", "Deteksi anomali pola"],
        source_trust=0.85,
    )

    print("=" * 60)
    print(" AAM Diffusion LLM — Generated Narrative")
    print("=" * 60)
    print(f" Evidence used: {result.evidence_used}")
    print(f" Narrative: {result.narrative}")
    print(f" Confidence: {result.confidence:.1%}")
    print(f" Steps: {result.n_diffusion_steps}")
    print(f" Time: {result.generation_time_s:.2f}s")

if __name__ == "__main__":
    main()
'''


def _render_model_card(model, config) -> str:
    """Return the README.md (model card) text for the exported model."""
    n_params = model._format_params(model.get_num_params())
    return f"""---
language:
- id
- en
license: mit
library_name: pytorch
tags:
- diffusion
- text-generation
- aam
- aphantasic-abstraction-model
- sentence-arrangement
- graph-conditioned
---

# AAM Diffusion LLM v1.0

> **"AAM = 1 Pikiran + 1 Tubuh" (1 Mind + 1 Body)**

The dedicated "body" of the Aphantasic Abstraction Model (AAM) — a small diffusion LLM specifically trained to arrange sentences from structured graph data.

## What is this?

This is NOT a general-purpose LLM. This is a SPECIALIZED sentence composer that:
- Takes **graph-structured conditioning** as input (evidence, anomalies, reasoning chains, confidence scores)
- Produces **coherent natural language narratives** through iterative denoising
- **Cannot hallucinate** — it can only narrate what the graph knows

## Architecture

```
Graph Conditioning Encoder → Diffusion Transformer → Noise Scheduler
(Mind input) (The Body) (Iterative refinement)
```

### Key Components
- **Graph Conditioning Encoder**: Encodes evidence nodes, compositions, anomalies, reasoning chains with confidence and temporal embeddings
- **Diffusion Transformer**: Core denoising network with adaptive layer norm, self-attention, and cross-attention to graph conditioning
- **Noise Scheduler**: Cosine noise schedule with DDPM/DDIM sampling support

## Model Details

| Parameter | Value |
|-----------|-------|
| Architecture | Diffusion Transformer |
| d_model | {config.model.d_model} |
| n_layers | {config.model.n_layers} |
| n_heads | {config.model.n_heads} |
| d_ff | {config.model.d_ff} |
| Parameters | {n_params} |
| Vocab size | {config.model.vocab_size} |
| Max sequence length | {config.model.max_seq_len} |
| Diffusion timesteps (train) | {config.diffusion.n_timesteps} |
| Diffusion timesteps (inference) | {config.diffusion.n_inference_steps} |
| Noise schedule | {config.diffusion.schedule_type} |
| Prediction type | {config.diffusion.prediction_type} |
| Sampling method | {config.diffusion.sampling_method} |

## Usage

```python
from diffusion_llm import AamDiffusionModel, AamTokenizer, AamGenerator, AamDiffusionConfig

# Load model
config = AamDiffusionConfig.from_json("config.json")
model = AamDiffusionModel.load("model.pt")
tokenizer = AamTokenizer.load("tokenizer.json")

# Create generator
generator = AamGenerator(model, tokenizer, config)

# Generate narrative from graph conditioning
result = generator.generate(
    trigger="Siapa yang mencuri Snow Plum Pill?",
    evidence_nodes=["Hefei", "Diancang Five Swords", "Ju Jangmok"],
    anomalies=["Tidak ada konsumsi pil baru di pasar gelap"],
    reasoning_steps=["Cross-reference tanggal kejadian"],
    source_trust=0.85,
)
print(result.narrative)
```

## Philosophy

**AAM = 1 Mind + 1 Body**

- **Mind** = RSVS Knowledge Graph (structural memory, perfect recall, relational understanding)
- **Body** = This Diffusion LLM (sentence arranger, graph-conditioned, anti-hallucination)

Unlike using a rented LLM (GPT, Claude) as the "body", this model is specifically trained for AAM:
- It cannot generate information not present in the graph conditioning
- It arranges sentences based on structured evidence
- It uses diffusion (non-sequential generation) instead of autoregressive generation
- It is small ({n_params}) but specialized

## Training

Trained on synthetic Graph→Narrative pairs with:
- Indonesian and English narrative templates
- Evidence nodes, anomalies, reasoning chains
- Confidence score distributions
- Source trust scores

## License

MIT
"""


def _copy_framework_code(framework_src, framework_dst, excluded_dirs):
    """Copy the package tree to ``framework_dst``, skipping ``excluded_dirs``."""
    import shutil

    by_pattern = shutil.ignore_patterns('__pycache__', '*.pyc', 'output', 'data')
    excluded = {Path(d).resolve() for d in excluded_dirs}

    def ignore(directory, names):
        skipped = set(by_pattern(directory, names))
        here = Path(directory).resolve()
        # Never descend into an excluded directory (e.g. the run's output
        # directory, which contains the copy destination itself).
        skipped.update(n for n in names if (here / n).resolve() in excluded)
        return skipped

    if framework_dst.exists():
        shutil.rmtree(framework_dst)
    shutil.copytree(framework_src, framework_dst, ignore=ignore)


def export_for_huggingface(model, tokenizer, config, output_dir: Path):
    """Export model, tokenizer, config, docs and code to ``output_dir/huggingface``.

    Files written into the export directory:
      * ``model.pt`` via ``model.save``
      * ``tokenizer.json`` via ``tokenizer.save``
      * ``config.json`` via ``config.to_json``
      * ``README.md`` (model card rendered from ``config`` and the model size)
      * ``diffusion_llm/`` (copy of the framework package)
      * ``train.py`` (copy of this script)
      * ``inference_example.py``

    Args:
        model: Trained model; must provide ``save``, ``get_num_params`` and
            ``_format_params``.
        tokenizer: Trained tokenizer; must provide ``save``.
        config: Configuration; must provide ``to_json`` and the ``model`` /
            ``diffusion`` fields shown in the model card.
        output_dir: Run directory under which ``huggingface/`` is created.

    Returns:
        Path of the export directory.

    ``output_dir`` is excluded from the package copy. When it lies inside the
    package directory (e.g. the script is started from there with a relative
    ``--output_dir``), copying it would duplicate every checkpoint and make
    ``copytree`` descend into its own destination.
    """
    import shutil

    logger.info("=" * 60)
    logger.info("STEP 5: Exporting for HuggingFace")
    logger.info("=" * 60)

    hf_dir = output_dir / "huggingface"
    hf_dir.mkdir(parents=True, exist_ok=True)

    # Save model weights
    model_path = hf_dir / "model.pt"
    model.save(str(model_path))
    logger.info(f" Model saved: {model_path}")

    # Save tokenizer
    tokenizer_path = hf_dir / "tokenizer.json"
    tokenizer.save(tokenizer_path)
    logger.info(f" Tokenizer saved: {tokenizer_path}")

    # Save config
    config_path = hf_dir / "config.json"
    config.to_json(config_path)
    logger.info(f" Config saved: {config_path}")

    # Save model card
    model_card_path = hf_dir / "README.md"
    with open(model_card_path, "w", encoding="utf-8") as f:
        f.write(_render_model_card(model, config))
    logger.info(f" Model card saved: {model_card_path}")

    # Copy full framework code
    framework_src = Path(__file__).parent.parent  # diffusion_llm/
    framework_dst = hf_dir / "diffusion_llm"
    _copy_framework_code(framework_src, framework_dst, excluded_dirs=[output_dir])
    logger.info(f" Framework code copied to: {framework_dst}")

    # Save training script
    train_script_dst = hf_dir / "train.py"
    shutil.copy2(Path(__file__), train_script_dst)

    # Save inference example
    inference_example = hf_dir / "inference_example.py"
    with open(inference_example, "w", encoding="utf-8") as f:
        f.write(_INFERENCE_EXAMPLE_SOURCE)
    logger.info(f" Inference example saved: {inference_example}")

    logger.info(f"\n HuggingFace export complete: {hf_dir}")
    return hf_dir
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
def main():
    """Run the full pipeline: data, tokenizer, dataloaders, training, export.

    Steps, in order: parse the CLI options and seed the RNGs; build the
    default config for the requested model size and apply the fixed overrides
    below; generate synthetic data; train the tokenizer and sync
    ``config.model.vocab_size`` with it; create the dataloaders; train the
    model; export everything in HuggingFace layout; print a summary.
    """
    args = parse_args()
    set_seed(args.seed)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    rule = "=" * 60
    print(rule)
    print(" AAM Diffusion LLM — Final Training")
    print(' "1 Pikiran + 1 Tubuh" (1 Mind + 1 Body)')
    print(rule)
    print()

    # Get config
    config = get_default_config(args.model_size)

    # CPU-optimized overrides for faster training
    model_cfg = config.model
    model_cfg.max_seq_len = 128
    model_cfg.vocab_size = 8000

    encoder_cfg = config.graph_encoder
    encoder_cfg.max_evidence_nodes = 10
    encoder_cfg.max_anomalies = 5
    encoder_cfg.max_reasoning_steps = 5
    encoder_cfg.max_compositions = 5

    diffusion_cfg = config.diffusion
    diffusion_cfg.n_timesteps = 200
    diffusion_cfg.n_inference_steps = 20

    config.tokenizer.bpe_vocab_size = 8000 - 13  # minus special tokens

    # Override settings for CPU training
    training_cfg = config.training
    training_cfg.batch_size = args.batch_size
    training_cfg.learning_rate = args.learning_rate
    training_cfg.max_steps = args.max_steps
    training_cfg.use_amp = False  # No AMP on CPU
    training_cfg.num_workers = 0  # No multiprocessing on CPU
    training_cfg.warmup_steps = min(100, args.max_steps // 5)

    config.output_dir = str(output_dir)
    config.seed = args.seed
    config.model_name = "aam-diffusion-v1.0"

    # Print config
    print(config.summary())

    # Step 1: Generate synthetic data
    train_path, val_path = generate_data(
        output_dir, args.n_synthetic_train, args.n_synthetic_val, args.seed
    )

    # Step 2: Train tokenizer
    tokenizer = train_tokenizer(train_path, output_dir, config)

    # The model's vocab_size must match what the trained tokenizer produced.
    actual_vocab = tokenizer.vocab_size
    if actual_vocab != config.model.vocab_size:
        logger.info(f" Updating vocab_size: {config.model.vocab_size} → {actual_vocab}")
        config.model.vocab_size = actual_vocab

    # Step 3: Create dataloaders
    train_loader, val_loader = create_dataloaders(
        train_path, val_path, tokenizer, config
    )

    # Step 4: Create and train model
    model = AamDiffusionModel(config)
    logger.info(f" Model parameters: {model._format_params(model.get_num_params())}")

    model = train_model(
        model, tokenizer, train_loader, val_loader,
        config, output_dir, args
    )

    # Step 5: Export for HuggingFace
    hf_dir = export_for_huggingface(model, tokenizer, config, output_dir)

    # Final summary
    print()
    print(rule)
    print(" TRAINING COMPLETE!")
    print(rule)
    summary_lines = (
        f" Model: {config.model_name}",
        f" Parameters: {model._format_params(model.get_num_params())}",
        f" Output: {output_dir}",
        f" HuggingFace export: {hf_dir}",
        "",
        " AAM = 1 Pikiran + 1 Tubuh",
        " Pikiran = RSVS Knowledge Graph",
        " Tubuh = This Diffusion LLM",
        rule,
    )
    for line in summary_lines:
        print(line)


if __name__ == "__main__":
    main()
|
diffusion_llm/scripts/train_minimal.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
AAM Diffusion LLM — Minimal Training Script for CPU
|
| 4 |
+
|
| 5 |
+
Trains a very small AAM Diffusion LLM model on CPU.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import sys
|
| 9 |
+
import json
|
| 10 |
+
import time
|
| 11 |
+
import logging
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
# Put the directory that CONTAINS the ``diffusion_llm`` package on sys.path
# (scripts/ -> diffusion_llm/ -> repository root); adding the package
# directory itself does not make ``from diffusion_llm...`` importable.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
# Root logging configuration for the minimal script: INFO level, each record
# prefixed with a timestamp and the level name.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("train")
|
| 21 |
+
|
| 22 |
+
def main():
    """Train, evaluate and save a very small AAM Diffusion model on CPU.

    Everything is written below ``./aam-diffusion-v1``. The run consists of:

    1. generating a synthetic train/validation split,
    2. training the tokenizer on the text fields of the training split,
    3. building the model from a tiny configuration,
    4. wrapping both splits in dataloaders,
    5. the optimisation loop,
    6. computing the mean validation loss,
    7. saving the checkpoint and the configuration.

    Batch size, learning rate, step budget, gradient-clipping norm, worker
    count and random seed are read back from ``config`` instead of being
    repeated as literals, so the ``config.json`` written at the end matches
    the values the run actually used.

    Returns:
        Tuple ``(model, tokenizer, config, output_dir)``.
    """
    from diffusion_llm.config.model_config import (
        AamDiffusionConfig, ModelConfig, DiffusionConfig,
        GraphEncoderConfig, TokenizerConfig, TrainingConfig, InferenceConfig,
    )
    from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
    from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
    from diffusion_llm.training.dataset import GraphNarrativeDataset, collate_fn
    from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator
    from torch.utils.data import DataLoader

    output_dir = Path("./aam-diffusion-v1")
    output_dir.mkdir(parents=True, exist_ok=True)
    data_dir = output_dir / "data"
    data_dir.mkdir(parents=True, exist_ok=True)

    # ===== STEP 1: Generate data =====
    logger.info("STEP 1: Generating synthetic data...")
    train_path, val_path = SyntheticDataGenerator.generate_training_split(
        output_dir=data_dir, n_train=200, n_val=20, language="id", seed=42,
    )

    # ===== STEP 2: Train tokenizer =====
    logger.info("STEP 2: Training tokenizer...")
    bpe_vocab_size = 2000  # shared by tokenizer training and TokenizerConfig
    tokenizer = AamTokenizer()

    texts = []
    with open(train_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Only the parse itself is guarded: malformed lines are skipped.
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue
            for key in ("narrative", "trigger"):
                if data.get(key):
                    texts.append(data[key])
            for key in ("evidence_nodes", "anomalies", "reasoning_steps"):
                # "or []" also tolerates a field that is present but null.
                texts.extend(data.get(key) or [])

    tokenizer.train(texts, vocab_size=bpe_vocab_size)
    tokenizer.save(data_dir / "tokenizer.json")
    actual_vocab = tokenizer.vocab_size
    logger.info(f" Tokenizer: vocab_size={actual_vocab}, merges={len(tokenizer.merges)}")

    # ===== Configuration (single source of truth for steps 3-7) =====
    config = AamDiffusionConfig(
        model=ModelConfig(
            d_model=128,
            n_layers=2,
            n_heads=4,
            d_ff=256,
            vocab_size=actual_vocab,
            max_seq_len=64,
            pos_encoding_type="learned",
            use_flash_attention=False,
            norm_type="layernorm",
            init_std=0.02,
        ),
        diffusion=DiffusionConfig(
            n_timesteps=100,
            n_inference_steps=10,
            schedule_type="cosine",
            prediction_type="epsilon",
            loss_type="mse",
            loss_weighting="none",
        ),
        graph_encoder=GraphEncoderConfig(
            d_graph=64,
            n_graph_layers=1,
            n_graph_heads=2,
            max_evidence_nodes=5,
            max_compositions=3,
            max_anomalies=3,
            max_reasoning_steps=3,
            conditioning_method="cross_attention",
            embed_confidence=False,
            embed_temporal=False,
        ),
        tokenizer=TokenizerConfig(bpe_vocab_size=bpe_vocab_size),
        training=TrainingConfig(
            batch_size=4,
            learning_rate=1e-3,
            max_steps=100,
            warmup_steps=10,
            use_amp=False,
            num_workers=0,
            grad_clip_norm=1.0,
        ),
        inference=InferenceConfig(n_steps=10),
        model_name="aam-diffusion-v1.0",
        output_dir=str(output_dir),
        seed=42,
    )

    # Seed torch so weight init, shuffling and timestep sampling are
    # repeatable for a given config.seed.
    torch.manual_seed(config.seed)

    # ===== STEP 3: Create model =====
    logger.info("STEP 3: Creating model...")
    model = AamDiffusionModel(config)
    n_params = model.get_num_params()
    logger.info(f" Parameters: {model._format_params(n_params)} ({n_params:,})")

    # ===== STEP 4: Create dataloaders =====
    logger.info("STEP 4: Creating dataloaders...")

    def make_dataset(path, augment):
        """Build a dataset for one split using the limits stored in config."""
        return GraphNarrativeDataset(
            data_path=path, tokenizer=tokenizer,
            max_seq_len=config.model.max_seq_len,
            max_evidence=config.graph_encoder.max_evidence_nodes,
            max_anomalies=config.graph_encoder.max_anomalies,
            max_reasoning=config.graph_encoder.max_reasoning_steps,
            augment=augment,
        )

    train_dataset = make_dataset(train_path, augment=True)
    val_dataset = make_dataset(val_path, augment=False)

    train_loader = DataLoader(
        train_dataset, batch_size=config.training.batch_size, shuffle=True,
        num_workers=config.training.num_workers, collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=config.training.batch_size, shuffle=False,
        num_workers=config.training.num_workers, collate_fn=collate_fn,
    )

    # ===== STEP 5: Train =====
    logger.info("STEP 5: Training...")
    device = torch.device("cpu")
    model.to(device)

    # weight_decay is not among the TrainingConfig fields set above, so it
    # stays a literal here.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config.training.learning_rate, weight_decay=0.01,
    )
    max_steps = config.training.max_steps

    def batch_loss(batch):
        """Move one collated batch to the device and return its diffusion loss."""
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                 for k, v in batch.items()}
        batch_size = batch["token_ids"].shape[0]
        # One random diffusion timestep per sample in the batch.
        t = torch.randint(0, config.diffusion.n_timesteps, (batch_size,), device=device)
        predicted, target = model(
            token_ids=batch["token_ids"],
            timestep=t,
            evidence_ids=batch.get("evidence_ids"),
            evidence_confidence=batch.get("evidence_confidence"),
            anomaly_ids=batch.get("anomaly_ids"),
            anomaly_confidence=batch.get("anomaly_confidence"),
            reasoning_ids=batch.get("reasoning_ids"),
            reasoning_confidence=batch.get("reasoning_confidence"),
            source_trust=batch.get("source_trust"),
        )
        return model.compute_loss(predicted, target, t)

    start_time = time.time()
    global_step = 0
    train_losses = []

    for _epoch in range(50):  # upper bound on epochs; max_steps ends the run earlier
        model.train()
        for batch in train_loader:
            if global_step >= max_steps:
                break

            loss = batch_loss(batch)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.grad_clip_norm)
            optimizer.step()

            train_losses.append(loss.item())
            global_step += 1

            if global_step % 10 == 0:
                recent = train_losses[-10:]
                avg = sum(recent) / len(recent)
                elapsed = time.time() - start_time
                logger.info(f" Step {global_step}/{max_steps} | Loss: {avg:.4f} | Time: {elapsed:.1f}s")

        if global_step >= max_steps:
            break

    # ===== STEP 6: Evaluate =====
    logger.info("STEP 6: Evaluating...")
    model.eval()
    val_losses = []
    with torch.no_grad():
        for batch in val_loader:
            val_losses.append(batch_loss(batch).item())

    avg_val_loss = sum(val_losses) / len(val_losses) if val_losses else 0
    logger.info(f" Val loss: {avg_val_loss:.4f}")

    # ===== STEP 7: Save =====
    logger.info("STEP 7: Saving model...")

    # Checkpoint: weights plus the configuration as a plain dict.
    model_path = output_dir / "model.pt"
    torch.save({
        "model_state_dict": model.state_dict(),
        "config": config.to_dict(),
    }, model_path)

    # The tokenizer was already saved in step 2; only the config is left.
    config.to_json(output_dir / "config.json")

    elapsed = time.time() - start_time
    logger.info(f"\n DONE! {global_step} steps in {elapsed:.1f}s")
    if train_losses:
        logger.info(f" Final train loss: {train_losses[-1]:.4f}")
    else:
        # An empty training loader would otherwise crash on train_losses[-1].
        logger.warning(" No training steps were run; final train loss unavailable")
    logger.info(f" Val loss: {avg_val_loss:.4f}")
    logger.info(f" Parameters: {model._format_params(n_params)}")
    logger.info(f" Output: {output_dir}")

    return model, tokenizer, config, output_dir
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
if __name__ == "__main__":
|
| 260 |
+
main()
|
diffusion_llm/tests/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Tests for AAM Diffusion LLM framework."""
|
diffusion_llm/tests/test_model.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for AAM Diffusion Model components."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import pytest
|
| 5 |
+
from diffusion_llm.config.model_config import AamDiffusionConfig, get_default_config, ModelConfig
|
| 6 |
+
from diffusion_llm.model.noise_scheduler import NoiseScheduler
|
| 7 |
+
from diffusion_llm.model.graph_encoder import GraphConditioningEncoder, GraphEncoderConfig
|
| 8 |
+
from diffusion_llm.model.diffusion_transformer import DiffusionTransformer
|
| 9 |
+
from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
|
| 10 |
+
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TestConfig:
    """Checks for preset lookup, JSON persistence and size estimation of the config objects."""

    def test_default_config(self):
        """The "base" preset is 768 wide, 12 layers deep and uses 1000 timesteps."""
        base = get_default_config("base")
        model_cfg = base.model
        assert model_cfg.d_model == 768
        assert model_cfg.n_layers == 12
        assert base.diffusion.n_timesteps == 1000

    def test_tiny_config(self):
        """The "tiny" preset is 256 wide and 4 layers deep."""
        model_cfg = get_default_config("tiny").model
        assert model_cfg.d_model == 256
        assert model_cfg.n_layers == 4

    def test_config_serialization(self, tmp_path):
        """Writing a config to JSON and reading it back keeps the model dimensions."""
        original = get_default_config("small")
        json_file = tmp_path / "config.json"
        original.to_json(json_file)

        restored = AamDiffusionConfig.from_json(json_file)
        for field in ("d_model", "n_layers"):
            assert getattr(restored.model, field) == getattr(original.model, field)

    def test_param_estimation(self):
        """The size estimate of a 768/12/3072 model is reported with an "M" in it."""
        estimate = ModelConfig(d_model=768, n_layers=12, d_ff=3072).estimate_params()
        assert "M" in estimate  # expected to be expressed in millions
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class TestTokenizer:
    """Checks for encoding, decoding, persistence and padding of AamTokenizer."""

    def test_basic_encoding(self):
        """Encoding yields a non-empty list framed by the BOS and EOS IDs."""
        tok = AamTokenizer()
        # The tokenizer is trained on a small sample before it is used.
        tok.train(["Hello world this is a test", "Another test sentence"])

        encoded = tok.encode("Hello world")
        assert isinstance(encoded, list)
        assert len(encoded) > 0
        first, last = encoded[0], encoded[-1]
        assert first == tok.bos_id
        assert last == tok.eos_id

    def test_decode_roundtrip(self):
        """Every training sentence decodes back to a non-empty string."""
        corpus = [
            "Berdasarkan analisis, pencuri adalah Diancang.",
            "Anomali terdeteksi dalam laporan Hefei.",
            "Evidence: Ju Jangmok, Snow Plum Pill.",
        ]
        tok = AamTokenizer()
        tok.train(corpus)

        for sentence in corpus:
            restored = tok.decode(tok.encode(sentence), skip_special=True)
            # Only non-emptiness is checked, not the exact text.
            assert len(restored) > 0

    def test_special_tokens(self):
        """An untrained tokenizer maps pad/bos/eos to IDs 0/1/2."""
        tok = AamTokenizer()
        assert tok.pad_id == 0
        assert tok.bos_id == 1
        assert tok.eos_id == 2

    def test_sentence_boundaries(self):
        """Positions holding ID 5 are reported as sentence boundaries."""
        tok = AamTokenizer()
        # Layout: BOS, two tokens, <sent>, two tokens, <sent>, one token, EOS.
        sequence = [1, 10, 20, 5, 30, 40, 5, 50, 2]
        found = tok.get_sentence_boundaries(sequence)
        for position in (3, 6):
            assert position in found

    def test_save_load(self, tmp_path):
        """A saved tokenizer reloads with the same vocab size and a trained flag."""
        original = AamTokenizer()
        original.train(["Test text for tokenizer", "Another training example"])

        target = tmp_path / "tokenizer.json"
        original.save(target)

        restored = AamTokenizer.load(target)
        assert restored.vocab_size == original.vocab_size
        assert restored.is_trained

    def test_structure_encoding(self):
        """Encoding text together with evidence and anomalies yields a non-empty list."""
        tok = AamTokenizer()
        tok.train(["Evidence text", "Anomaly description", "Reasoning step"])

        encoded = tok.encode_with_structure(
            text="Main narrative text",
            evidence_nodes=["evidence1", "evidence2"],
            anomalies=["anomaly1"],
        )
        assert isinstance(encoded, list)
        assert len(encoded) > 0

    def test_padding(self):
        """A three-token sequence padded to length 10 gets seven trailing zeros."""
        tok = AamTokenizer()
        result = tok.pad_sequence([1, 2, 3], max_len=10)
        assert len(result) == 10
        assert result[3:] == [0] * 7  # the tail is filled with the pad ID
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class TestDiffusionTransformer:
    """Shape checks for the DiffusionTransformer forward pass."""

    @staticmethod
    def _small_transformer():
        """Build the 2-layer, 128-wide transformer shared by the tests below."""
        cfg = ModelConfig(
            d_model=128, n_layers=2, n_heads=4, d_ff=256,
            vocab_size=1000, max_seq_len=64,
        )
        return DiffusionTransformer(cfg)

    def test_forward_pass(self):
        """Without conditioning, the output has the same shape as the noisy input."""
        net = self._small_transformer()

        noisy = torch.randn(2, 32, 128)  # batch=2, seq=32, d=128
        steps = torch.tensor([100, 500])

        result = net(x_t=noisy, t=steps)
        assert result.shape == (2, 32, 128)

    def test_with_graph_conditioning(self):
        """Passing graph keys/values does not change the output shape."""
        net = self._small_transformer()

        noisy = torch.randn(2, 32, 128)
        steps = torch.tensor([100, 500])
        keys = torch.randn(2, 10, 128)  # 10 graph nodes
        values = torch.randn(2, 10, 128)

        result = net(x_t=noisy, t=steps, graph_keys=keys, graph_values=values)
        assert result.shape == (2, 32, 128)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class TestAamDiffusionModel:
    """Construction, forward pass, loss and persistence checks for AamDiffusionModel."""

    @staticmethod
    def _tiny_eval_setup():
        """Return a "tiny" model in eval mode plus random token IDs and timesteps for batch 2."""
        cfg = get_default_config("tiny")
        net = AamDiffusionModel(cfg)
        net.eval()
        tokens = torch.randint(0, cfg.model.vocab_size, (2, 32))
        steps = torch.randint(0, cfg.diffusion.n_timesteps, (2,))
        return net, tokens, steps

    def test_model_creation_tiny(self):
        """The "tiny" preset builds a model with a positive parameter count below 100M."""
        net = AamDiffusionModel(get_default_config("tiny"))
        count = net.get_num_params()
        assert 0 < count < 100e6  # tiny should be under 100M

    def test_forward_training(self):
        """The forward pass returns a prediction and a target of equal shape."""
        net, tokens, steps = self._tiny_eval_setup()

        with torch.no_grad():
            predicted, noise = net(token_ids=tokens, timestep=steps)

        assert predicted.shape == noise.shape

    def test_loss_computation(self):
        """The loss of a forward pass is non-negative and not NaN."""
        net, tokens, steps = self._tiny_eval_setup()

        with torch.no_grad():
            predicted, noise = net(token_ids=tokens, timestep=steps)
            value = net.compute_loss(predicted, noise, steps)

        assert value.item() >= 0
        assert not torch.isnan(value)

    def test_save_load(self, tmp_path):
        """A saved model reloads with the same d_model in its config."""
        cfg = get_default_config("tiny")
        net = AamDiffusionModel(cfg)

        checkpoint = str(tmp_path / "model.pt")
        net.save(checkpoint)

        restored = AamDiffusionModel.load(checkpoint)
        assert restored.config.model.d_model == cfg.model.d_model
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class TestGraphEncoder:
    """Checks that GraphConditioningEncoder returns its expected result keys."""

    @staticmethod
    def _make_encoder():
        """Build a 128-wide, 2-layer, 4-head encoder over a 1000-token vocabulary."""
        cfg = GraphEncoderConfig(d_graph=128, n_graph_layers=2, n_graph_heads=4)
        return GraphConditioningEncoder(cfg, vocab_size=1000)

    def test_evidence_encoding(self):
        """Evidence IDs with confidences produce both "keys" and "values"."""
        enc = self._make_encoder()

        # 2 batch entries, 5 evidence nodes each, 16 tokens per node.
        node_tokens = torch.randint(0, 1000, (2, 5, 16))
        node_conf = torch.tensor([
            [0.8, 0.6, 0.9, 0.7, 0.5],
            [0.7, 0.8, 0.6, 0.9, 0.5],
        ])

        out = enc(evidence_ids=node_tokens, evidence_confidence=node_conf)
        for expected_key in ("keys", "values"):
            assert expected_key in out

    def test_no_input(self):
        """Calling the encoder without graph data still yields a "keys" entry."""
        out = self._make_encoder()()
        # Only the presence of the key is checked, not its contents.
        assert "keys" in out
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
if __name__ == "__main__":
|
| 239 |
+
pytest.main([__file__, "-v"])
|
diffusion_llm/tests/test_scheduler.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for Noise Scheduler."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import pytest
|
| 5 |
+
from diffusion_llm.model.noise_scheduler import NoiseScheduler
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TestNoiseScheduler:
    """Schedule construction, forward noising, loss targets and reverse steps of NoiseScheduler."""

    @staticmethod
    def _random_pair():
        """Draw a (2, 10, 64) tensor (batch, seq, dim) and a second random tensor of the same shape."""
        first = torch.randn(2, 10, 64)
        return first, torch.randn_like(first)

    def test_cosine_schedule(self):
        """A cosine schedule has 1000 betas, all strictly between 0 and 1."""
        sched = NoiseScheduler(n_timesteps=1000, schedule_type="cosine")
        assert sched.betas.shape == (1000,)
        assert (sched.betas > 0).all()
        assert (sched.betas < 1).all()

    def test_linear_schedule(self):
        """A linear schedule has 1000 betas and ends higher than it starts."""
        sched = NoiseScheduler(n_timesteps=1000, schedule_type="linear")
        assert sched.betas.shape == (1000,)
        first_beta, last_beta = sched.betas[0], sched.betas[-1]
        assert first_beta < last_beta  # increasing

    def test_sigmoid_schedule(self):
        """A sigmoid schedule has 1000 strictly positive betas."""
        sched = NoiseScheduler(n_timesteps=1000, schedule_type="sigmoid")
        assert sched.betas.shape == (1000,)
        assert (sched.betas > 0).all()

    def test_add_noise(self):
        """Forward diffusion keeps the shape of the clean sample."""
        sched = NoiseScheduler(n_timesteps=1000)
        clean, eps = self._random_pair()
        steps = torch.tensor([0, 500])

        noisy = sched.add_noise(clean, eps, steps)
        # Only the shape is checked here; how close the result stays to the
        # clean sample at t=0 versus t=500 is not asserted.
        assert noisy.shape == clean.shape

    def test_loss_target_epsilon(self):
        """With epsilon prediction the loss target equals the noise."""
        sched = NoiseScheduler(prediction_type="epsilon")
        clean, eps = self._random_pair()
        steps = torch.tensor([100, 500])

        assert torch.allclose(sched.compute_loss_target(clean, eps, steps), eps)

    def test_loss_target_x0(self):
        """With x0 prediction the loss target equals the clean sample."""
        sched = NoiseScheduler(prediction_type="x0")
        clean, eps = self._random_pair()
        steps = torch.tensor([100, 500])

        assert torch.allclose(sched.compute_loss_target(clean, eps, steps), clean)

    def test_predict_x0_from_epsilon(self):
        """Recovering x0 from the true noise returns a tensor of the original shape."""
        sched = NoiseScheduler(prediction_type="epsilon")
        clean, eps = self._random_pair()
        steps = torch.tensor([100])

        noisy = sched.add_noise(clean, eps, steps)
        recovered = sched.predict_x0_from_epsilon(noisy, eps, steps)
        # Only the shape is asserted, not closeness to the clean sample.
        assert recovered.shape == clean.shape

    def test_ddpm_step(self):
        """One DDPM reverse step preserves the sample shape."""
        sched = NoiseScheduler(n_timesteps=1000)
        noisy, model_out = self._random_pair()
        steps = torch.tensor([500, 500])

        previous = sched.step_ddpm(model_out, noisy, steps)
        assert previous.shape == noisy.shape

    def test_ddim_step(self):
        """One DDIM reverse step from t=500 to t=400 preserves the sample shape."""
        sched = NoiseScheduler(n_timesteps=1000)
        noisy, model_out = self._random_pair()

        previous = sched.step_ddim(model_out, noisy, t=500, t_prev=400)
        assert previous.shape == noisy.shape

    def test_timestep_schedule(self):
        """The inference schedule is non-empty and starts above where it ends."""
        sched = NoiseScheduler(n_timesteps=1000)
        timesteps = sched.get_timestep_schedule(n_inference_steps=50)
        assert len(timesteps) > 0
        assert timesteps[0] > timesteps[-1]  # descending order
|
diffusion_llm/tokenizer/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tokenizer module for AAM Diffusion LLM."""
|
| 2 |
+
|
| 3 |
+
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
|
| 4 |
+
|
| 5 |
+
__all__ = ["AamTokenizer"]
|
diffusion_llm/tokenizer/aam_tokenizer.py
ADDED
|
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AAM Diffusion LLM — Tokenizer
|
| 3 |
+
|
| 4 |
+
Sentence-level + subword BPE hybrid tokenizer designed specifically
|
| 5 |
+
for AAM's sentence arrangement task.
|
| 6 |
+
|
| 7 |
+
Unlike standard tokenizers (GPT-2 BPE, SentencePiece) that tokenize
|
| 8 |
+
at the subword level, AAM's tokenizer is designed with SENTENCE
|
| 9 |
+
ARRANGEMENT in mind:
|
| 10 |
+
|
| 11 |
+
1. Sentences are the primary unit of generation (not individual tokens)
|
| 12 |
+
2. Within sentences, subword BPE handles individual words
|
| 13 |
+
3. Special tokens for graph structure (evidence, anomaly, confidence)
|
| 14 |
+
4. Sentence boundary markers for the diffusion model
|
| 15 |
+
|
| 16 |
+
The tokenizer maintains two levels:
|
| 17 |
+
- Sentence level: Where sentences begin/end, for the diffusion model
|
| 18 |
+
to arrange and revise non-sequentially
|
| 19 |
+
- Token level: Subword units within sentences, for detailed generation
|
| 20 |
+
|
| 21 |
+
Analogi: Jin Soun tidak berpikir dalam kata-per-kata — dia
|
| 22 |
+
berpikir dalam KALIMAT. "Pencuri = Diancang pair. Ju Jangmok = cover."
|
| 23 |
+
Setiap kalimat sudah utuh, yang dia susun adalah URUTAN kalimat.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import json
|
| 29 |
+
import re
|
| 30 |
+
import unicodedata
|
| 31 |
+
from collections import Counter
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
from typing import Optional
|
| 34 |
+
|
| 35 |
+
from diffusion_llm.config.model_config import TokenizerConfig
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Reserved vocabulary entries. Each token's position in this list is the ID it
# is registered under, so the order must stay fixed.
SPECIAL_TOKENS = [
    # IDs 0-4
    "<pad>", "<bos>", "<eos>", "<mask>", "<noise>",
    # ID 5: sentence boundary
    "<sent>",
    # IDs 6-11
    "<evidence>", "<anomaly>", "<confidence>",
    "<reasoning>", "<composition>", "<temporal>",
    # ID 12
    "<unk>",
]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class AamTokenizer:
|
| 57 |
+
"""AAM Sentence-Level + Subword BPE Hybrid Tokenizer.
|
| 58 |
+
|
| 59 |
+
This tokenizer is specifically designed for the AAM Diffusion LLM:
|
| 60 |
+
- It understands sentence boundaries (<sent> tokens)
|
| 61 |
+
- It has special tokens for graph structure
|
| 62 |
+
- It uses BPE for subword tokenization within sentences
|
| 63 |
+
- It can encode/decode both plain text and graph-conditioned text
|
| 64 |
+
|
| 65 |
+
Usage:
|
| 66 |
+
tokenizer = AamTokenizer()
|
| 67 |
+
tokenizer.train(texts, vocab_size=28000)
|
| 68 |
+
|
| 69 |
+
# Encode text
|
| 70 |
+
ids = tokenizer.encode("Berdasarkan analisis, pencuri adalah Diancang.")
|
| 71 |
+
|
| 72 |
+
# Decode back
|
| 73 |
+
text = tokenizer.decode(ids)
|
| 74 |
+
|
| 75 |
+
# With graph structure tokens
|
| 76 |
+
ids = tokenizer.encode_with_structure(
|
| 77 |
+
"Pencuri = Diancang pair",
|
| 78 |
+
evidence_nodes=["hefei", "diancang"],
|
| 79 |
+
anomalies=[{"desc": "no external pill consumption"}],
|
| 80 |
+
)
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
def __init__(self, config: Optional[TokenizerConfig] = None):
    """Create an untrained tokenizer that knows only the special tokens.

    Args:
        config: Tokenizer configuration. A default ``TokenizerConfig`` is
            created when this is None (or otherwise falsy).
    """
    self.config = config if config else TokenizerConfig()

    # Token <-> ID tables, seeded with the reserved special tokens.
    self.vocab: dict[str, int] = {}
    self.id_to_token: dict[int, str] = {}
    self._init_special_tokens()

    # Learned BPE state: merge rules with their rank, and a per-word cache
    # of encodings stored as space-joined symbol strings.
    self.merges: dict[tuple[str, str], int] = {}
    self._bpe_cache: dict[str, str] = {}

    # Sentence splitter: whitespace after ., ! or ?, or anything after a newline.
    self._sentence_pattern = re.compile(r'(?<=[.!?])\s+|(?<=\n)\s*')
    # Word splitter: runs of word characters, or single punctuation marks.
    self._word_pattern = re.compile(r'\w+|[^\w\s]')

    # Set to True by train() (or restored by load()).
    self._is_trained = False
|
| 110 |
+
|
| 111 |
+
def _init_special_tokens(self) -> None:
    """Register every entry of SPECIAL_TOKENS under its list position."""
    for token_id, special in enumerate(SPECIAL_TOKENS):
        self.id_to_token[token_id] = special
        self.vocab[special] = token_id
|
| 116 |
+
|
| 117 |
+
@property
def pad_id(self) -> int:
    """ID of the padding token named by ``config.pad_token``."""
    token = self.config.pad_token
    return self.vocab[token]
|
| 120 |
+
|
| 121 |
+
@property
def bos_id(self) -> int:
    """ID of the beginning-of-sequence token named by ``config.bos_token``."""
    token = self.config.bos_token
    return self.vocab[token]
|
| 124 |
+
|
| 125 |
+
@property
def eos_id(self) -> int:
    """ID of the end-of-sequence token named by ``config.eos_token``."""
    token = self.config.eos_token
    return self.vocab[token]
|
| 128 |
+
|
| 129 |
+
@property
def mask_id(self) -> int:
    """ID of the mask token named by ``config.mask_token``."""
    token = self.config.mask_token
    return self.vocab[token]
|
| 132 |
+
|
| 133 |
+
@property
def noise_id(self) -> int:
    """ID of the noise token named by ``config.noise_token``."""
    token = self.config.noise_token
    return self.vocab[token]
|
| 136 |
+
|
| 137 |
+
@property
def sent_id(self) -> int:
    """ID of the sentence-boundary token named by ``config.sentence_boundary_token``."""
    token = self.config.sentence_boundary_token
    return self.vocab[token]
|
| 140 |
+
|
| 141 |
+
@property
def unk_id(self) -> int:
    """ID of the ``<unk>`` token.

    Falls back to the index of the last entry of SPECIAL_TOKENS when the
    vocabulary has no ``<unk>`` entry.
    """
    fallback = len(SPECIAL_TOKENS) - 1
    return self.vocab.get("<unk>", fallback)
|
| 144 |
+
|
| 145 |
+
@property
def vocab_size(self) -> int:
    """Number of tokens currently in the vocabulary (special tokens included)."""
    entries = self.vocab
    return len(entries)
|
| 149 |
+
|
| 150 |
+
@property
def is_trained(self) -> bool:
    """True once BPE merges have been learned or restored."""
    trained = self._is_trained
    return trained
|
| 154 |
+
|
| 155 |
+
def train(
    self,
    texts: list[str],
    vocab_size: Optional[int] = None,
) -> None:
    """Train the BPE tokenizer on a corpus.

    Words are pre-tokenized, split into characters, and the most frequent
    adjacent symbol pair is merged repeatedly until the vocabulary reaches
    the target size or no pairs remain. Words longer than one character
    carry an end-of-word marker (``</w>``) on their last symbol;
    single-character words are stored without it.

    Every base symbol is registered in the vocabulary, including the
    ``</w>``-suffixed final characters. Without those entries a word whose
    last character is never merged would encode to ``<unk>``.

    Args:
        texts: List of training texts.
        vocab_size: Target vocabulary size. Uses ``config.bpe_vocab_size``
            when None (or 0).
    """
    target_vocab = vocab_size or self.config.bpe_vocab_size

    # Step 1: Count how often each pre-tokenized word occurs.
    word_freqs: Counter = Counter()
    for text in texts:
        word_freqs.update(self._pre_tokenize(text))

    # Step 2: Split each word into symbols and tag the end of multi-character
    # words so merges can tell word-final pieces apart.
    word_splits: dict[str, list[str]] = {}
    for word in word_freqs:
        symbols = list(word)
        if len(symbols) > 1:
            symbols[-1] += "</w>"
        word_splits[word] = symbols

    # Step 3: Register the base symbols. Plain characters come first (sorted),
    # then the end-of-word variants, so character IDs do not depend on which
    # word-final symbols occur.
    char_vocab = {char for word in word_freqs for char in word}
    end_symbols = {
        symbols[-1] for symbols in word_splits.values() if len(symbols) > 1
    }
    for symbol in sorted(char_vocab) + sorted(end_symbols):
        if symbol not in self.vocab:
            idx = len(self.vocab)
            self.vocab[symbol] = idx
            self.id_to_token[idx] = symbol

    # Step 4: Learn BPE merges. A non-positive budget simply skips the loop.
    n_merges = target_vocab - len(self.vocab)
    for i in range(n_merges):
        # Count adjacent symbol pairs, weighted by word frequency.
        pair_freqs: Counter = Counter()
        for word, freq in word_freqs.items():
            symbols = word_splits[word]
            for pair in zip(symbols, symbols[1:]):
                pair_freqs[pair] += freq

        if not pair_freqs:
            break  # every word is already a single symbol

        # The merge index doubles as the rank used when encoding.
        best_pair = pair_freqs.most_common(1)[0][0]
        self.merges[best_pair] = i

        # Apply the merge to every word, left to right.
        left, right = best_pair
        new_symbol = left + right
        for word, symbols in word_splits.items():
            merged: list[str] = []
            j = 0
            last = len(symbols) - 1
            while j < len(symbols):
                if j < last and symbols[j] == left and symbols[j + 1] == right:
                    merged.append(new_symbol)
                    j += 2
                else:
                    merged.append(symbols[j])
                    j += 1
            word_splits[word] = merged

        # Add merged token to vocabulary
        if new_symbol not in self.vocab:
            idx = len(self.vocab)
            self.vocab[new_symbol] = idx
            self.id_to_token[idx] = new_symbol

    self._is_trained = True
    # Cached encodings were computed with the previous merge table.
    self._bpe_cache.clear()
|
| 243 |
+
|
| 244 |
+
def _pre_tokenize(self, text: str) -> list[str]:
    """Split text into lowercase words and single punctuation marks.

    Args:
        text: Input text.

    Returns:
        The matches of the word pattern over the NFC-normalized,
        lowercased text; whitespace is dropped.
    """
    normalized = unicodedata.normalize("NFC", text)
    lowered = normalized.lower()
    return self._word_pattern.findall(lowered)
|
| 258 |
+
|
| 259 |
+
def _bpe_encode(self, word: str) -> list[str]:
    """Apply the learned BPE merges to a single word.

    Results are memoized in ``_bpe_cache`` as space-joined symbol strings.

    Args:
        word: Input word (lowercase).

    Returns:
        List of BPE tokens.
    """
    cached = self._bpe_cache.get(word)
    if cached is not None:
        return cached.split()

    # Character-level start; multi-character words get the end-of-word marker.
    pieces = list(word)
    if len(pieces) > 1:
        pieces[-1] += "</w>"

    unranked = float("inf")
    while len(pieces) > 1:
        # Pick the adjacent pair with the lowest merge rank (the first such
        # pair on ties); stop once no adjacent pair has a merge rule.
        candidates = list(zip(pieces, pieces[1:]))
        target = min(candidates, key=lambda p: self.merges.get(p, unranked))
        if self.merges.get(target, unranked) == unranked:
            break

        left, right = target
        fused = left + right
        rebuilt: list[str] = []
        pos = 0
        last = len(pieces) - 1
        while pos < len(pieces):
            if pos < last and pieces[pos] == left and pieces[pos + 1] == right:
                rebuilt.append(fused)
                pos += 2
            else:
                rebuilt.append(pieces[pos])
                pos += 1
        pieces = rebuilt

    self._bpe_cache[word] = " ".join(pieces)
    return pieces
|
| 312 |
+
|
| 313 |
+
def encode(self, text: str, add_special: bool = True) -> list[int]:
    """Convert text into a list of token IDs.

    The text is split into sentences; a sentence-boundary token is placed
    between consecutive sentences (not after the last one). Each word is
    BPE-encoded when the tokenizer is trained, otherwise it is encoded
    character by character. Symbols missing from the vocabulary map to
    the unknown-token ID.

    Args:
        text: Input text.
        add_special: Whether to wrap the result in BOS/EOS tokens.

    Returns:
        List of token IDs.
    """
    token_ids: list[int] = [self.bos_id] if add_special else []

    for position, sentence in enumerate(self._split_sentences(text)):
        if position:
            token_ids.append(self.sent_id)

        for word in self._pre_tokenize(sentence):
            # Untrained fallback: iterating the word yields its characters.
            pieces = self._bpe_encode(word) if self._is_trained else word
            token_ids.extend(
                self.vocab[piece] if piece in self.vocab else self.unk_id
                for piece in pieces
            )

    if add_special:
        token_ids.append(self.eos_id)
    return token_ids
|
| 363 |
+
|
| 364 |
+
def encode_with_structure(
    self,
    text: str,
    evidence_nodes: Optional[list[str]] = None,
    compositions: Optional[list[str]] = None,
    anomalies: Optional[list[str]] = None,
    reasoning_steps: Optional[list[str]] = None,
    confidence: Optional[float] = None,
) -> list[int]:
    """Encode text with graph structure tokens.

    Output layout: BOS, then the optional sections in the order
    evidence, anomaly, reasoning, confidence, composition, then the
    narrative, then EOS. Each list section is wrapped in its marker token
    on both sides. Reasoning steps and compositions are each followed by a
    sentence-boundary token; evidence nodes and anomalies are concatenated
    without a separator.

    Args:
        text: The narrative text.
        evidence_nodes: List of evidence node labels (strings).
        compositions: List of composition descriptions (strings).
        anomalies: List of anomaly descriptions (strings).
        reasoning_steps: List of reasoning step descriptions (strings).
        confidence: Overall confidence score. It is discretized into ten
            buckets; values outside [0, 1] are clamped to the nearest
            bucket so a negative score no longer loses its bucket token.

    Returns:
        List of token IDs with structure tokens.
    """
    ids = [self.bos_id]

    def add_section(marker: str, items: Optional[list[str]], separate: bool) -> None:
        """Append ``marker items... marker``; skip empty or missing sections."""
        if not items:
            return
        ids.append(self.vocab[marker])
        for item in items:
            ids.extend(self.encode(item, add_special=False))
            if separate:
                ids.append(self.sent_id)
        ids.append(self.vocab[marker])  # Close section

    add_section("<evidence>", evidence_nodes, separate=False)
    add_section("<anomaly>", anomalies, separate=False)
    add_section("<reasoning>", reasoning_steps, separate=True)

    # Confidence
    if confidence is not None:
        ids.append(self.vocab["<confidence>"])
        # Encode confidence as a token (discretized), clamped to 0..9.
        conf_bucket = max(0, min(int(confidence * 10), 9))
        conf_token = f"<conf_{conf_bucket}>"
        # NOTE(review): nothing in this class registers "<conf_N>" tokens;
        # unless the vocabulary was extended elsewhere, only the marker is
        # emitted - confirm whether bucket tokens should be reserved.
        if conf_token in self.vocab:
            ids.append(self.vocab[conf_token])

    add_section("<composition>", compositions, separate=True)

    # Main narrative
    ids.extend(self.encode(text, add_special=False))
    ids.append(self.eos_id)
    return ids
|
| 441 |
+
|
| 442 |
+
def decode(self, ids: list[int], skip_special: bool = False) -> str:
    """Turn token IDs back into a text string.

    IDs missing from the vocabulary are rendered as ``<unk>``. End-of-word
    markers become spaces, sentence-boundary tokens become ``". "``, and
    runs of whitespace are collapsed.

    Args:
        ids: List of token IDs.
        skip_special: Whether to drop the tokens listed in SPECIAL_TOKENS.

    Returns:
        Decoded text string.
    """
    hidden: set[int] = set()
    if skip_special:
        hidden = {
            self.vocab[special]
            for special in SPECIAL_TOKENS
            if special in self.vocab
        }

    pieces = [
        self.id_to_token.get(token_id, "<unk>")
        for token_id in ids
        if not (skip_special and token_id in hidden)
    ]

    # Undo BPE word markers, expand sentence boundaries, tidy whitespace.
    joined = "".join(pieces)
    spaced = joined.replace("</w>", " ").replace("<sent>", ". ")
    return re.sub(r'\s+', ' ', spaced).strip()
|
| 476 |
+
|
| 477 |
+
def _split_sentences(self, text: str) -> list[str]:
    """Break text into trimmed, non-empty sentences.

    Args:
        text: Input text.

    Returns:
        The pieces produced by the sentence pattern, stripped of
        surrounding whitespace, with empty pieces removed.
    """
    sentences: list[str] = []
    for fragment in self._sentence_pattern.split(text):
        trimmed = fragment.strip()
        if trimmed:
            sentences.append(trimmed)
    return sentences
|
| 488 |
+
|
| 489 |
+
def pad_sequence(
    self,
    ids: list[int],
    max_len: int,
    pad_id: Optional[int] = None,
) -> list[int]:
    """Force a sequence to exactly ``max_len`` tokens.

    Sequences that are too long are truncated from the right; shorter
    ones are extended with the padding ID.

    Args:
        ids: Token IDs.
        max_len: Target length.
        pad_id: Padding token ID. Uses the tokenizer's ``pad_id`` if None.

    Returns:
        A list of length ``max_len``.
    """
    fill = self.pad_id if pad_id is None else pad_id
    shortfall = max_len - len(ids)
    if shortfall <= 0:
        return ids[:max_len]
    return ids + [fill] * shortfall
|
| 509 |
+
|
| 510 |
+
def get_sentence_boundaries(self, ids: list[int]) -> list[int]:
    """Locate the sentence-boundary tokens in a token sequence.

    The positions let a consumer work out which tokens belong to which
    sentence.

    Args:
        ids: Token IDs.

    Returns:
        Indices (in ascending order) at which the sentence-boundary
        token occurs.
    """
    marker = self.sent_id
    return [index for index, token_id in enumerate(ids) if token_id == marker]
|
| 528 |
+
|
| 529 |
+
def save(self, path: str | Path) -> None:
    """Write the tokenizer state to a JSON file.

    Parent directories are created when missing. Merge pairs are stored
    under ``"left|||right"`` keys because JSON object keys must be strings.

    Args:
        path: Output file path (JSON).
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    config_fields = (
        "bpe_vocab_size",
        "max_sentences",
        "sentence_boundary_token",
        "pad_token",
        "bos_token",
        "eos_token",
        "mask_token",
        "noise_token",
        "min_frequency",
    )
    serialized_merges = {
        f"{left}|||{right}": rank for (left, right), rank in self.merges.items()
    }
    payload = {
        "config": {name: getattr(self.config, name) for name in config_fields},
        "vocab": self.vocab,
        "merges": serialized_merges,
        "is_trained": self._is_trained,
    }

    with open(target, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False, indent=2)
|
| 557 |
+
|
| 558 |
+
@classmethod
def load(cls, path: str | Path) -> AamTokenizer:
    """Rebuild a tokenizer from a JSON file written by ``save``.

    Args:
        path: Input file path (JSON).

    Returns:
        Loaded AamTokenizer.
    """
    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    tokenizer = cls(config=TokenizerConfig(**payload.get("config", {})))

    # Restore vocabulary and derive the reverse (ID -> token) table from it.
    stored_vocab = payload["vocab"]
    tokenizer.vocab = stored_vocab
    tokenizer.id_to_token = {
        int(token_id): token for token, token_id in stored_vocab.items()
    }

    # Merge keys were serialized as "left|||right"; rebuild the pair tuples.
    restored: dict[tuple[str, str], int] = {}
    for joined, rank in payload.get("merges", {}).items():
        parts = joined.split("|||")
        restored[(parts[0], parts[1])] = rank
    tokenizer.merges = restored

    tokenizer._is_trained = payload.get("is_trained", False)
    return tokenizer
|
| 587 |
+
|
| 588 |
+
def __len__(self) -> int:
    """Return the current vocabulary size."""
    size = self.vocab_size
    return size
|
| 590 |
+
|
| 591 |
+
def __repr__(self) -> str:
    """Summarize vocabulary size, merge count and training status."""
    if self._is_trained:
        status = "trained"
    else:
        status = "untrained"
    merge_count = len(self.merges)
    return f"AamTokenizer(vocab_size={self.vocab_size}, merges={merge_count}, status={status})"
|
diffusion_llm/training/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training module for AAM Diffusion LLM."""
|
| 2 |
+
|
| 3 |
+
from diffusion_llm.training.trainer import AamTrainer
|
| 4 |
+
from diffusion_llm.training.dataset import GraphNarrativeDataset
|
| 5 |
+
from diffusion_llm.training.losses import DiffusionLoss, compute_loss
|
| 6 |
+
|
| 7 |
+
__all__ = ["AamTrainer", "GraphNarrativeDataset", "DiffusionLoss", "compute_loss"]
|
diffusion_llm/training/dataset.py
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AAM Diffusion LLM — Dataset
|
| 3 |
+
|
| 4 |
+
Dataset class for Graph→Narrative training pairs.
|
| 5 |
+
|
| 6 |
+
Each training example consists of:
|
| 7 |
+
- Graph conditioning: evidence nodes, compositions, confidence,
|
| 8 |
+
anomalies, reasoning chains, temporal context
|
| 9 |
+
- Target narrative: natural language text that represents
|
| 10 |
+
the graph data in sentence form
|
| 11 |
+
|
| 12 |
+
The dataset handles:
|
| 13 |
+
- Loading from JSONL files
|
| 14 |
+
- Tokenization of both graph data and narratives
|
| 15 |
+
- Padding and batching
|
| 16 |
+
- Data augmentation (sentence shuffling, noise injection)
|
| 17 |
+
|
| 18 |
+
Analogy: Like Jin Soun practicing how to state a conclusion —
|
| 19 |
+
they are given a "case" (graph data) and the "correct answer"
|
| 20 |
+
(narrative target), then practice until they can compose
|
| 21 |
+
the right sentences from the graph.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import json
|
| 27 |
+
import logging
|
| 28 |
+
import random
|
| 29 |
+
from dataclasses import dataclass, field
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
from typing import Optional
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
from torch.utils.data import Dataset
|
| 35 |
+
|
| 36 |
+
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
|
| 37 |
+
|
| 38 |
+
logger = logging.getLogger(__name__)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
class GraphNarrativeExample:
    """One training pair: graph conditioning plus the target narrative.

    The graph fields are the model's conditioning input; ``narrative`` is
    the text the model is expected to produce for them.
    """

    # Target narrative (what the model should generate)
    narrative: str = ""

    # Graph conditioning inputs
    trigger: str = ""
    evidence_nodes: list[str] = field(default_factory=list)
    compositions: list[str] = field(default_factory=list)
    confidence_map: dict[str, float] = field(default_factory=dict)
    anomalies: list[str] = field(default_factory=list)
    reasoning_steps: list[str] = field(default_factory=list)
    source_trust: float = 1.0
    temporal_context: list[str] = field(default_factory=list)

    # Metadata
    language: str = "id"
    source: str = "synthetic"

    def to_dict(self) -> dict:
        """Return the example as a plain dictionary keyed by field name."""
        names = (
            "narrative",
            "trigger",
            "evidence_nodes",
            "compositions",
            "confidence_map",
            "anomalies",
            "reasoning_steps",
            "source_trust",
            "temporal_context",
            "language",
            "source",
        )
        return {name: getattr(self, name) for name in names}

    @classmethod
    def from_dict(cls, data: dict) -> GraphNarrativeExample:
        """Build an example from a dictionary, defaulting any missing key."""
        # Built per call so the mutable fallbacks are never shared.
        fallbacks = {
            "narrative": "",
            "trigger": "",
            "evidence_nodes": [],
            "compositions": [],
            "confidence_map": {},
            "anomalies": [],
            "reasoning_steps": [],
            "source_trust": 1.0,
            "temporal_context": [],
            "language": "id",
            "source": "synthetic",
        }
        values = {
            name: data.get(name, fallback) for name, fallback in fallbacks.items()
        }
        return cls(**values)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@dataclass
class BatchOutput:
    """Container for the tensors of one batch.

    All tensors are already padded to uniform length. Only ``token_ids``
    is required; the conditioning tensors default to None.
    """

    # Target narrative token IDs, shape (batch, seq_len).
    token_ids: torch.Tensor

    # Evidence node token IDs, shape (batch, n_evidence, ev_seq_len).
    evidence_ids: Optional[torch.Tensor] = None

    # Evidence confidence, shape (batch, n_evidence).
    evidence_confidence: Optional[torch.Tensor] = None

    # Anomaly token IDs, shape (batch, n_anomalies, an_seq_len).
    anomaly_ids: Optional[torch.Tensor] = None

    # Anomaly confidence, shape (batch, n_anomalies).
    anomaly_confidence: Optional[torch.Tensor] = None

    # Reasoning step token IDs, shape (batch, n_steps, r_seq_len).
    reasoning_ids: Optional[torch.Tensor] = None

    # Reasoning confidence, shape (batch, n_steps).
    reasoning_confidence: Optional[torch.Tensor] = None

    # Source trust scores, shape (batch,).
    source_trust: Optional[torch.Tensor] = None
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class GraphNarrativeDataset(Dataset):
|
| 131 |
+
"""Dataset for Graph→Narrative training pairs.
|
| 132 |
+
|
| 133 |
+
Loads training examples from JSONL files and provides
|
| 134 |
+
tokenized, padded batches for training.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
data_path: Path to JSONL file with training data.
|
| 138 |
+
tokenizer: AamTokenizer instance for encoding.
|
| 139 |
+
max_seq_len: Maximum sequence length for narratives.
|
| 140 |
+
max_evidence: Maximum number of evidence nodes.
|
| 141 |
+
max_anomalies: Maximum number of anomalies.
|
| 142 |
+
max_reasoning: Maximum number of reasoning steps.
|
| 143 |
+
augment: Whether to apply data augmentation.
|
| 144 |
+
"""
|
| 145 |
+
|
| 146 |
+
def __init__(
    self,
    data_path: str | Path,
    tokenizer: AamTokenizer,
    max_seq_len: int = 512,
    max_evidence: int = 50,
    max_anomalies: int = 10,
    max_reasoning: int = 15,
    augment: bool = True,
):
    """Store the settings and eagerly load every example from ``data_path``."""
    self.data_path = Path(data_path)
    self.tokenizer = tokenizer

    # Size limits applied when examples are tokenized.
    self.max_seq_len = max_seq_len
    self.max_evidence, self.max_anomalies, self.max_reasoning = (
        max_evidence,
        max_anomalies,
        max_reasoning,
    )
    self.augment = augment

    # Examples are read once, up front.
    self.examples: list[GraphNarrativeExample] = []
    self._load_data()

    example_count = len(self.examples)
    logger.info(
        "GraphNarrativeDataset: %d examples loaded from %s",
        example_count,
        self.data_path,
    )
|
| 173 |
+
|
| 174 |
+
def _load_data(self) -> None:
    """Load examples from the JSONL file at ``self.data_path``.

    Each non-blank line is parsed as one JSON object and converted to a
    ``GraphNarrativeExample``; examples with an empty narrative are
    dropped. A missing file, a malformed line, or a line whose JSON value
    is not an object is logged as a warning and skipped, so one bad
    record never aborts the whole load.
    """
    if not self.data_path.exists():
        logger.warning("Data file not found: %s", self.data_path)
        return

    with open(self.data_path, "r", encoding="utf-8") as f:
        for line_num, raw_line in enumerate(f, 1):
            line = raw_line.strip()
            if not line:
                continue

            # Keep the try body minimal: only parsing can raise here.
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                logger.warning("Invalid JSON at line %d", line_num)
                continue

            # json.loads also accepts lists, numbers, strings and null;
            # from_dict needs a mapping, so reject anything else here
            # instead of crashing on `.get`.
            if not isinstance(data, dict):
                logger.warning(
                    "Expected a JSON object at line %d, got %s",
                    line_num,
                    type(data).__name__,
                )
                continue

            example = GraphNarrativeExample.from_dict(data)
            if example.narrative:  # Skip empty narratives
                self.examples.append(example)
|
| 192 |
+
|
| 193 |
+
def __len__(self) -> int:
    """Return the number of loaded examples."""
    loaded = self.examples
    return len(loaded)
|
| 195 |
+
|
| 196 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
    """Tokenize one example into fixed-size tensors.

    Returns:
        Dictionary with the padded narrative token IDs, the token IDs of
        the evidence, anomaly and reasoning lists, their confidence
        vectors, and the source trust score.
    """
    sample = self.examples[idx]
    if self.augment:
        sample = self._augment(sample)

    # Target narrative, wrapped in BOS/EOS and padded/truncated.
    encoded = self.tokenizer.encode(sample.narrative, add_special=True)
    padded = self.tokenizer.pad_sequence(encoded, self.max_seq_len)
    token_ids = torch.tensor(padded, dtype=torch.long)

    # Graph conditioning lists.
    evidence_ids = self._tokenize_node_list(
        sample.evidence_nodes, max_nodes=self.max_evidence
    )["ids"]
    anomaly_ids = self._tokenize_node_list(
        sample.anomalies, max_nodes=self.max_anomalies
    )["ids"]
    reasoning_ids = self._tokenize_node_list(
        sample.reasoning_steps, max_nodes=self.max_reasoning
    )["ids"]

    trust = torch.tensor(sample.source_trust, dtype=torch.float32)

    # Evidence confidence: map values in dict order, zero-padded on the right.
    # NOTE(review): values are taken positionally, not matched to
    # evidence_nodes by key - confirm the two are meant to line up.
    scores = list(sample.confidence_map.values())[:self.max_evidence]
    evidence_confidence = torch.zeros(self.max_evidence, dtype=torch.float32)
    if scores:
        evidence_confidence[:len(scores)] = torch.tensor(
            scores, dtype=torch.float32
        )

    # Fixed default confidences: 0.6 for anomalies, 0.7 for reasoning steps.
    anomaly_confidence = torch.full(
        (self.max_anomalies,), 0.6, dtype=torch.float32
    )
    reasoning_confidence = torch.full(
        (self.max_reasoning,), 0.7, dtype=torch.float32
    )

    return {
        "token_ids": token_ids,
        "evidence_ids": evidence_ids,
        "evidence_confidence": evidence_confidence,
        "anomaly_ids": anomaly_ids,
        "anomaly_confidence": anomaly_confidence,
        "reasoning_ids": reasoning_ids,
        "reasoning_confidence": reasoning_confidence,
        "source_trust": trust,
    }
|
| 261 |
+
|
| 262 |
+
def _tokenize_node_list(
    self,
    nodes: list[str],
    max_nodes: int,
    max_node_len: int = 32,
) -> dict[str, torch.Tensor]:
    """Tokenize a list of node descriptions into one padded tensor.

    Every padding position - inside a node row and in the rows beyond the
    supplied nodes - holds the tokenizer's padding ID, so consumers can
    mask on a single value. With the default special tokens that ID is 0,
    which matches the previous hard-coded fill.

    Args:
        nodes: List of node text descriptions (may be empty or None).
        max_nodes: Maximum number of nodes to encode; extra nodes are
            dropped.
        max_node_len: Maximum token length per node; longer nodes are
            truncated.

    Returns:
        Dictionary with key ``"ids"`` holding a long tensor of shape
        ``(max_nodes, max_node_len)``.
    """
    pad_id = self.tokenizer.pad_id

    # Start from an all-padding grid so the shape is right even when there
    # are no nodes or max_nodes is 0.
    ids = torch.full((max_nodes, max_node_len), pad_id, dtype=torch.long)

    kept = nodes[:max_nodes] if nodes else []
    for row, node in enumerate(kept):
        node_ids = self.tokenizer.encode(node, add_special=False)
        node_ids = self.tokenizer.pad_sequence(node_ids, max_node_len)
        ids[row] = torch.tensor(node_ids, dtype=torch.long)

    return {
        "ids": ids,
    }
|
| 299 |
+
|
| 300 |
+
def _augment(self, example: GraphNarrativeExample) -> GraphNarrativeExample:
|
| 301 |
+
"""Apply data augmentation.
|
| 302 |
+
|
| 303 |
+
Augmentation strategies:
|
| 304 |
+
1. Random sentence shuffling within the narrative
|
| 305 |
+
2. Random evidence node dropping (simulate incomplete data)
|
| 306 |
+
3. Random confidence perturbation
|
| 307 |
+
|
| 308 |
+
Args:
|
| 309 |
+
example: Original training example.
|
| 310 |
+
|
| 311 |
+
Returns:
|
| 312 |
+
Augmented example.
|
| 313 |
+
"""
|
| 314 |
+
import copy
|
| 315 |
+
augmented = copy.deepcopy(example)
|
| 316 |
+
|
| 317 |
+
# 1. Sentence shuffling (with 20% probability)
|
| 318 |
+
if random.random() < 0.2:
|
| 319 |
+
sentences = self.tokenizer._split_sentences(augmented.narrative)
|
| 320 |
+
if len(sentences) > 2:
|
| 321 |
+
# Keep first sentence, shuffle the rest
|
| 322 |
+
first = sentences[0]
|
| 323 |
+
rest = sentences[1:]
|
| 324 |
+
random.shuffle(rest)
|
| 325 |
+
augmented.narrative = first + " " + " ".join(rest)
|
| 326 |
+
|
| 327 |
+
# 2. Evidence dropping (with 10% probability per node)
|
| 328 |
+
if augmented.evidence_nodes:
|
| 329 |
+
augmented.evidence_nodes = [
|
| 330 |
+
node for node in augmented.evidence_nodes
|
| 331 |
+
if random.random() > 0.1
|
| 332 |
+
]
|
| 333 |
+
|
| 334 |
+
# 3. Confidence perturbation
|
| 335 |
+
if augmented.confidence_map:
|
| 336 |
+
perturbed = {}
|
| 337 |
+
for k, v in augmented.confidence_map.items():
|
| 338 |
+
noise = random.gauss(0, 0.05)
|
| 339 |
+
perturbed[k] = max(0.0, min(1.0, v + noise))
|
| 340 |
+
augmented.confidence_map = perturbed
|
| 341 |
+
|
| 342 |
+
return augmented
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def collate_fn(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    """Stack per-example tensors into batched tensors for a DataLoader.

    For every key of the first example, the corresponding tensors of all
    examples are combined with ``torch.stack`` along a new leading batch
    dimension. No padding is performed here: all examples must already
    provide tensors of identical shape for a given key.

    Args:
        batch: List of example dictionaries sharing the same keys.

    Returns:
        Dictionary mapping each key to the stacked tensor.

    Raises:
        IndexError: If ``batch`` is empty.
    """
    if not batch:
        raise IndexError("collate_fn received an empty batch")

    # Every tensor rank is handled the same way, so a single stack suffices.
    return {
        key: torch.stack([item[key] for item in batch])
        for key in batch[0]
    }
|
diffusion_llm/training/losses.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AAM Diffusion LLM — Loss Functions
|
| 3 |
+
|
| 4 |
+
Implements various loss functions for training the diffusion model,
|
| 5 |
+
including MSE, MAE, Huber, and weighted variants.
|
| 6 |
+
|
| 7 |
+
Analogy: like Jin Soun measuring how far his prediction is
from reality — the larger the gap, the greater the "pain"
that drives improvement.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
from diffusion_llm.config.model_config import DiffusionConfig
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DiffusionLoss(nn.Module):
    """Loss function for diffusion model training.

    Computes an element-wise loss between prediction and target, averages
    it over the last (feature) dimension, optionally re-weights it per
    sample according to the signal-to-noise ratio of the sampled timestep,
    and reduces it to a scalar.

    Args:
        config: DiffusionConfig with loss hyperparameters. The attributes
            read here are ``loss_type``, ``loss_weighting`` and, for P2
            weighting, ``p2_gamma`` and ``p2_k``.
    """

    def __init__(self, config: DiffusionConfig):
        super().__init__()
        self.config = config

    def forward(
        self,
        predicted: torch.Tensor,
        target: torch.Tensor,
        timestep: torch.Tensor,
        alphas_cumprod: torch.Tensor,
    ) -> torch.Tensor:
        """Compute diffusion loss.

        Args:
            predicted: Model output (predicted noise/x0/v).
            target: Target values, same shape as ``predicted``.
            timestep: Timestep indices used to look up ``alphas_cumprod``.
            alphas_cumprod: Cumulative product of alphas from the scheduler.

        Returns:
            Scalar loss value.

        Raises:
            ValueError: If ``config.loss_type`` is not "mse", "mae" or "huber".
        """
        # Base loss
        loss_type = self.config.loss_type
        if loss_type == "mse":
            loss = F.mse_loss(predicted, target, reduction="none")
        elif loss_type == "mae":
            loss = F.l1_loss(predicted, target, reduction="none")
        elif loss_type == "huber":
            loss = F.smooth_l1_loss(predicted, target, reduction="none")
        else:
            raise ValueError(f"Unknown loss_type: {self.config.loss_type}")

        # Average over feature dimension
        loss = loss.mean(dim=-1)

        # Apply weighting; any other value leaves the loss unweighted.
        weighting = self.config.loss_weighting
        if weighting == "min_snr":
            loss = self._min_snr_weight(loss, timestep, alphas_cumprod)
        elif weighting == "p2":
            loss = self._p2_weight(loss, timestep, alphas_cumprod)

        return loss.mean()

    @staticmethod
    def _snr(timestep: torch.Tensor, alphas_cumprod: torch.Tensor) -> torch.Tensor:
        """Return alpha_bar / (1 - alpha_bar) for the given timesteps."""
        # Indexing a CPU schedule with CUDA timesteps fails, so align devices.
        alpha_bar = alphas_cumprod.to(timestep.device)[timestep]
        return alpha_bar / (1 - alpha_bar + 1e-8)

    @staticmethod
    def _apply_weight(loss: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """Multiply ``loss`` by a per-sample ``weight`` of lower rank."""
        # Append singleton dims so the weight lines up with the leading
        # (batch) dimension of the loss whatever the loss rank is.
        n_extra = loss.dim() - weight.dim()
        if n_extra > 0:
            weight = weight.reshape(weight.shape + (1,) * n_extra)
        return loss * weight.to(loss.device)

    def _min_snr_weight(
        self,
        loss: torch.Tensor,
        timestep: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        gamma: float = 5.0,
    ) -> torch.Tensor:
        """Min-SNR-gamma weighting (Hang et al., 2023).

        The weight is ``min(snr, gamma) / snr``.
        NOTE(review): this is the form used for noise prediction; confirm it
        matches the prediction type the model is trained with.
        """
        snr = self._snr(timestep, alphas_cumprod)
        weight = torch.clamp(snr, max=gamma) / (snr + 1e-8)
        return self._apply_weight(loss, weight)

    def _p2_weight(
        self,
        loss: torch.Tensor,
        timestep: torch.Tensor,
        alphas_cumprod: torch.Tensor,
    ) -> torch.Tensor:
        """P2 weighting (Choi et al., 2022): ``1 / (snr ** p2_gamma + p2_k)``."""
        snr = self._snr(timestep, alphas_cumprod)
        weight = 1.0 / (snr ** self.config.p2_gamma + self.config.p2_k)
        return self._apply_weight(loss, weight)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def compute_loss(
    predicted: torch.Tensor,
    target: torch.Tensor,
    timestep: torch.Tensor,
    alphas_cumprod: torch.Tensor,
    loss_type: str = "mse",
    loss_weighting: str = "none",
) -> torch.Tensor:
    """Compute the diffusion loss in one call, without keeping a module around.

    A throwaway ``DiffusionConfig`` is built from ``loss_type`` and
    ``loss_weighting`` (all other fields keep their defaults) and handed
    to a fresh ``DiffusionLoss`` instance.

    Args:
        predicted: Model output.
        target: Target values.
        timestep: Timestep indices.
        alphas_cumprod: Alpha cumulative products.
        loss_type: Loss function type.
        loss_weighting: Weighting strategy.

    Returns:
        Scalar loss value.
    """
    criterion = DiffusionLoss(
        DiffusionConfig(loss_type=loss_type, loss_weighting=loss_weighting)
    )
    return criterion(predicted, target, timestep, alphas_cumprod)
|
diffusion_llm/training/trainer.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AAM Diffusion LLM — Trainer
|
| 3 |
+
|
| 4 |
+
Training loop for the AAM Diffusion Model.
|
| 5 |
+
|
| 6 |
+
Handles:
|
| 7 |
+
- Training loop with gradient accumulation
|
| 8 |
+
- Learning rate scheduling with warmup
|
| 9 |
+
- Mixed precision training (AMP)
|
| 10 |
+
- EMA model updates
|
| 11 |
+
- Checkpoint saving/loading
|
| 12 |
+
- Logging to console and Weights & Biases
|
| 13 |
+
- Evaluation on validation set
|
| 14 |
+
|
| 15 |
+
Analogy: like Jin Soun's physical training — repeated over and over,
gradually increasing in intensity, with an instructor who
watches and gives corrections.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import json
|
| 23 |
+
import logging
|
| 24 |
+
import math
|
| 25 |
+
import time
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Optional
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
import torch.nn as nn
|
| 31 |
+
from torch.utils.data import DataLoader
|
| 32 |
+
|
| 33 |
+
from diffusion_llm.config.model_config import AamDiffusionConfig
|
| 34 |
+
from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
|
| 35 |
+
from diffusion_llm.training.dataset import GraphNarrativeDataset, collate_fn
|
| 36 |
+
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
|
| 37 |
+
from diffusion_llm.training.losses import DiffusionLoss
|
| 38 |
+
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class AamTrainer:
    """Trainer for the AAM Diffusion Model.

    Owns the optimizer, the warmup/decay LR schedule, the optional AMP
    gradient scaler, the optional EMA copy of the model, the data loaders
    and checkpoint management.

    Args:
        config: AamDiffusionConfig with training settings.
        model: AamDiffusionModel instance.
        tokenizer: AamTokenizer instance.
        train_dataset: Training dataset.
        val_dataset: Optional validation dataset.
    """

    def __init__(
        self,
        config: AamDiffusionConfig,
        model: AamDiffusionModel,
        tokenizer: AamTokenizer,
        train_dataset: GraphNarrativeDataset,
        val_dataset: Optional[GraphNarrativeDataset] = None,
    ):
        self.config = config
        self.model = model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

        # Device
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model.to(self.device)
        logger.info("Training on device: %s", self.device)

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config.training.learning_rate,
            weight_decay=config.training.weight_decay,
            betas=(config.training.adam_beta1, config.training.adam_beta2),
            eps=config.training.adam_eps,
        )

        # Loss function
        # NOTE(review): the training/eval steps call model.compute_loss, so
        # this module is currently only constructed, not used — confirm.
        self.loss_fn = DiffusionLoss(config.diffusion)

        # Data loaders
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=config.training.batch_size,
            shuffle=True,
            num_workers=config.training.num_workers,
            collate_fn=collate_fn,
            pin_memory=True,
        )

        if val_dataset:
            self.val_loader = DataLoader(
                val_dataset,
                batch_size=config.training.batch_size,
                shuffle=False,
                num_workers=config.training.num_workers,
                collate_fn=collate_fn,
                pin_memory=True,
            )
        else:
            self.val_loader = None

        # LR scheduler
        self.scheduler = self._create_lr_scheduler()

        # AMP: remember the configured dtype so autocast actually uses it.
        self.scaler = None
        self._amp_dtype = None
        if config.training.use_amp:
            self._amp_dtype = (
                torch.bfloat16 if config.training.amp_dtype == "bf16" else torch.float16
            )
            # Loss scaling is only needed for float16.
            self.scaler = torch.amp.GradScaler(
                "cuda", enabled=(self._amp_dtype == torch.float16)
            )

        # EMA
        self.ema_model = None
        if config.training.use_ema:
            self.ema_model = self._create_ema_model()

        # State tracking
        self.global_step = 0
        self.best_val_loss = float("inf")
        self.train_losses: list[float] = []

        # Output directory
        self.output_dir = Path(config.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Seed
        torch.manual_seed(config.seed)

    def _create_lr_scheduler(self):
        """Create learning rate scheduler with linear warmup.

        After warmup the multiplier follows ``config.training.lr_schedule``:
        "cosine" or "linear" decay to zero at ``max_steps``, otherwise constant.
        """
        total_steps = self.config.training.max_steps
        warmup_steps = self.config.training.warmup_steps
        schedule = self.config.training.lr_schedule

        def lr_lambda(step: int) -> float:
            if step < warmup_steps:
                return step / max(warmup_steps, 1)
            # Clamp so the multiplier never goes negative (linear) or rises
            # again (cosine) if the scheduler is stepped past total_steps.
            progress = min(
                (step - warmup_steps) / max(total_steps - warmup_steps, 1), 1.0
            )
            if schedule == "cosine":
                return 0.5 * (1.0 + math.cos(math.pi * progress))
            if schedule == "linear":
                return 1.0 - progress
            return 1.0

        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

    def _create_ema_model(self) -> AamDiffusionModel:
        """Create a frozen deep copy of the model used for EMA weights."""
        import copy
        ema = copy.deepcopy(self.model)
        for param in ema.parameters():
            param.requires_grad = False
        return ema

    @torch.no_grad()
    def _update_ema(self) -> None:
        """Blend current model parameters into the EMA copy.

        Only parameters are averaged; buffers keep their init-time values.
        """
        if self.ema_model is None:
            return
        decay = self.config.training.ema_decay
        for ema_param, model_param in zip(
            self.ema_model.parameters(), self.model.parameters()
        ):
            ema_param.data.mul_(decay).add_(model_param.data, alpha=1 - decay)

    def train(self) -> None:
        """Main training loop.

        Runs for max_steps or max_epochs, whichever comes first.
        Saves checkpoints and runs evaluation periodically.
        """
        cfg = self.config.training
        logger.info("Starting training...")
        logger.info("  Max steps: %d", cfg.max_steps)
        logger.info("  Batch size: %d", cfg.batch_size)
        logger.info("  Gradient accumulation: %d", cfg.gradient_accumulation_steps)
        logger.info("  Effective batch size: %d",
                    cfg.batch_size * cfg.gradient_accumulation_steps)

        start_time = time.time()
        # Steps done before this call (non-zero after load_checkpoint) must
        # not count towards the throughput of this run.
        start_step = self.global_step
        epoch = 0

        while self.global_step < cfg.max_steps:
            epoch += 1
            if epoch > cfg.max_epochs:
                break

            logger.info("=== Epoch %d ===", epoch)
            epoch_loss = 0.0
            n_batches = 0

            for batch in self.train_loader:
                loss = self._train_step(batch)
                epoch_loss += loss
                n_batches += 1

                # Logging
                if self.global_step % cfg.log_every_steps == 0:
                    lr = self.optimizer.param_groups[0]["lr"]
                    elapsed = time.time() - start_time
                    steps_per_sec = (self.global_step - start_step) / max(elapsed, 1e-8)

                    logger.info(
                        "Step %d | Loss: %.4f | LR: %.2e | Speed: %.1f steps/s",
                        self.global_step, loss, lr, steps_per_sec,
                    )

                # Evaluation
                if (self.global_step % cfg.eval_every_steps == 0
                        and self.val_loader is not None):
                    val_loss = self.evaluate()
                    logger.info("Validation loss: %.4f", val_loss)
                    if val_loss < self.best_val_loss:
                        self.best_val_loss = val_loss
                        self._save_checkpoint("best.pt")

                # Checkpoint
                if self.global_step % cfg.save_every_steps == 0:
                    self._save_checkpoint(f"step_{self.global_step}.pt")

                # Stop condition
                if self.global_step >= cfg.max_steps:
                    break

            avg_epoch_loss = epoch_loss / max(n_batches, 1)
            logger.info("Epoch %d complete. Average loss: %.4f", epoch, avg_epoch_loss)

        # Final save
        self._save_checkpoint("final.pt")
        elapsed = time.time() - start_time
        logger.info(
            "Training complete! %d steps in %.1f hours",
            self.global_step, elapsed / 3600,
        )

    def _to_device(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Move every tensor value of ``batch`` to the training device."""
        return {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()
        }

    def _sample_timesteps(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """Draw one uniform random timestep per example in ``batch``."""
        batch_size = batch["token_ids"].shape[0]
        return torch.randint(
            0, self.config.diffusion.n_timesteps,
            (batch_size,), device=self.device,
        )

    def _forward_loss(
        self, batch: dict[str, torch.Tensor], t: torch.Tensor
    ) -> torch.Tensor:
        """Run the model on ``batch`` at timesteps ``t`` and return its loss."""
        predicted, target = self.model(
            token_ids=batch["token_ids"],
            timestep=t,
            evidence_ids=batch.get("evidence_ids"),
            evidence_confidence=batch.get("evidence_confidence"),
            anomaly_ids=batch.get("anomaly_ids"),
            anomaly_confidence=batch.get("anomaly_confidence"),
            reasoning_ids=batch.get("reasoning_ids"),
            reasoning_confidence=batch.get("reasoning_confidence"),
            source_trust=batch.get("source_trust"),
        )
        return self.model.compute_loss(predicted, target, t)

    def _train_step(self, batch: dict[str, torch.Tensor]) -> float:
        """Single training (micro-)step.

        Accumulates gradients on every call and performs the optimizer,
        scheduler and EMA update once every ``gradient_accumulation_steps``
        calls.

        Args:
            batch: Batch of training data.

        Returns:
            Unscaled loss value for this step (comparable to ``evaluate``).
        """
        self.model.train()

        batch = self._to_device(batch)
        t = self._sample_timesteps(batch)

        accum = self.config.training.gradient_accumulation_steps
        use_amp = self.scaler is not None

        # Forward pass. Autocast is a no-op when AMP is off or no GPU is used;
        # the dtype is passed explicitly so a bf16 config does not fall back
        # to float16 while the grad scaler is disabled.
        with torch.amp.autocast(
            "cuda",
            dtype=self._amp_dtype,
            enabled=use_amp and self.device.type == "cuda",
        ):
            loss = self._forward_loss(batch, t)
        scaled_loss = loss / accum

        # Backward pass
        if use_amp:
            self.scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()

        # Gradient accumulation
        # NOTE(review): global_step counts micro-batches while the LR
        # scheduler advances once per optimizer step, so with accum > 1 the
        # schedule covers only max_steps / accum of its horizon — confirm.
        if (self.global_step + 1) % accum == 0:
            # Gradient clipping (on unscaled gradients)
            if use_amp:
                self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.training.grad_clip_norm,
            )

            # Optimizer step
            if use_amp:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()

            # LR schedule
            self.scheduler.step()

            # Zero gradients
            self.optimizer.zero_grad()

            # EMA update
            self._update_ema()

        self.global_step += 1
        loss_value = loss.item()
        self.train_losses.append(loss_value)

        return loss_value

    @torch.no_grad()
    def evaluate(self) -> float:
        """Evaluate on validation set.

        Timesteps are sampled randomly per batch, so repeated calls give
        slightly different values.

        Returns:
            Average validation loss, or ``inf`` when there is no validation set.
        """
        if self.val_loader is None:
            return float("inf")

        self.model.eval()
        total_loss = 0.0
        n_batches = 0

        for batch in self.val_loader:
            batch = self._to_device(batch)
            t = self._sample_timesteps(batch)
            total_loss += self._forward_loss(batch, t).item()
            n_batches += 1

        avg_loss = total_loss / max(n_batches, 1)
        self.model.train()
        return avg_loss

    def _save_checkpoint(self, filename: str) -> None:
        """Save training checkpoint.

        Args:
            filename: Checkpoint filename, relative to the output directory.
        """
        path = self.output_dir / filename
        checkpoint = {
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "global_step": self.global_step,
            "best_val_loss": self.best_val_loss,
            "config": self.config.to_dict(),
        }
        if self.ema_model is not None:
            checkpoint["ema_state_dict"] = self.ema_model.state_dict()

        torch.save(checkpoint, path)
        logger.info("Checkpoint saved: %s", path)

        # Clean up old checkpoints
        self._cleanup_checkpoints()

    @staticmethod
    def _checkpoint_step(path: Path) -> Optional[int]:
        """Return the step number of a ``step_<N>.pt`` file, else ``None``."""
        try:
            return int(path.stem.split("_", 1)[1])
        except (IndexError, ValueError):
            return None

    def _cleanup_checkpoints(self) -> None:
        """Remove old ``step_<N>.pt`` checkpoints, keeping only the last N.

        Files are ordered by their numeric step, not by name, so
        ``step_1000.pt`` counts as newer than ``step_500.pt``. Files matching
        the pattern without a numeric step are left untouched.
        """
        keep_n = self.config.training.keep_last_n_checkpoints
        numbered = [
            (step, path)
            for path in self.output_dir.glob("step_*.pt")
            if (step := self._checkpoint_step(path)) is not None
        ]
        numbered.sort(key=lambda pair: pair[0])
        stale = numbered[:-keep_n] if keep_n > 0 else numbered
        for _, oldest in stale:
            oldest.unlink()
            logger.info("Removed old checkpoint: %s", oldest)

    def load_checkpoint(self, path: str) -> None:
        """Load from checkpoint.

        Restores model, optimizer, scheduler, step counter, best validation
        loss and, when EMA is enabled, the EMA weights.

        Args:
            path: Checkpoint file path.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        if self.ema_model is not None:
            # Without this the EMA would keep averaging from its init-time
            # copy; fall back to the loaded weights for pre-EMA checkpoints.
            ema_state = checkpoint.get("ema_state_dict")
            self.ema_model.load_state_dict(
                ema_state if ema_state is not None else checkpoint["model_state_dict"]
            )
        self.global_step = checkpoint["global_step"]
        self.best_val_loss = checkpoint.get("best_val_loss", float("inf"))
        logger.info("Loaded checkpoint from step %d", self.global_step)
|
inference_example.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""AAM Diffusion LLM v1.0 — Inference Example"""
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from diffusion_llm import AamDiffusionModel, AamTokenizer, AamGenerator, AamDiffusionConfig
|
| 10 |
+
|
| 11 |
+
def main():
    """Load the model, generate one narrative from graph conditioning, print it."""
    # Load configuration, weights and tokenizer from the working directory.
    # NOTE(review): the repository lists `pytorch_model.bin`, while this
    # loads "model.pt" — confirm which file AamDiffusionModel.load expects.
    cfg = AamDiffusionConfig.from_json("config.json")
    net = AamDiffusionModel.load("model.pt", device="cpu")
    tok = AamTokenizer.load("tokenizer.json")

    # Create generator
    gen = AamGenerator(net, tok, cfg)

    # Generate narrative from graph conditioning
    conditioning = {
        "trigger": "Siapa yang mencuri Snow Plum Pill?",
        "evidence_nodes": ["Hefei", "Diancang Five Swords", "Ju Jangmok"],
        "anomalies": ["Tidak ada konsumsi pil baru di pasar gelap"],
        "reasoning_steps": ["Cross-reference tanggal kejadian", "Deteksi anomali pola"],
        "source_trust": 0.85,
    }
    result = gen.generate(**conditioning)

    # Report the generated text together with its metadata.
    rule = "=" * 60
    print(rule)
    print("  AAM Diffusion LLM — Generated Narrative")
    print(rule)
    print(f"  Narrative: {result.narrative}")
    print(f"  Confidence: {result.confidence:.1%}")
    print(f"  Steps: {result.n_diffusion_steps}")
    print(f"  Time: {result.generation_time_s:.2f}s")


if __name__ == "__main__":
    main()
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:651e28c3b1fd60919884cc7e6311cd7a604c368669b9abecf27adb2efbc1eaea
|
| 3 |
+
size 1297247
|
requirements.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
numpy>=1.24.0
|
tokenizer.json
ADDED
|
@@ -0,0 +1,964 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"config": {
|
| 3 |
+
"bpe_vocab_size": 28000,
|
| 4 |
+
"max_sentences": 32,
|
| 5 |
+
"sentence_boundary_token": "<sent>",
|
| 6 |
+
"pad_token": "<pad>",
|
| 7 |
+
"bos_token": "<bos>",
|
| 8 |
+
"eos_token": "<eos>",
|
| 9 |
+
"mask_token": "<mask>",
|
| 10 |
+
"noise_token": "<noise>",
|
| 11 |
+
"min_frequency": 2
|
| 12 |
+
},
|
| 13 |
+
"vocab": {
|
| 14 |
+
"<pad>": 0,
|
| 15 |
+
"<bos>": 1,
|
| 16 |
+
"<eos>": 2,
|
| 17 |
+
"<mask>": 3,
|
| 18 |
+
"<noise>": 4,
|
| 19 |
+
"<sent>": 5,
|
| 20 |
+
"<evidence>": 6,
|
| 21 |
+
"<anomaly>": 7,
|
| 22 |
+
"<confidence>": 8,
|
| 23 |
+
"<reasoning>": 9,
|
| 24 |
+
"<composition>": 10,
|
| 25 |
+
"<temporal>": 11,
|
| 26 |
+
"<unk>": 12,
|
| 27 |
+
"%": 13,
|
| 28 |
+
",": 14,
|
| 29 |
+
"-": 15,
|
| 30 |
+
".": 16,
|
| 31 |
+
"0": 17,
|
| 32 |
+
"1": 18,
|
| 33 |
+
"2": 19,
|
| 34 |
+
"3": 20,
|
| 35 |
+
"4": 21,
|
| 36 |
+
"6": 22,
|
| 37 |
+
"7": 23,
|
| 38 |
+
"8": 24,
|
| 39 |
+
"9": 25,
|
| 40 |
+
":": 26,
|
| 41 |
+
";": 27,
|
| 42 |
+
"?": 28,
|
| 43 |
+
"_": 29,
|
| 44 |
+
"a": 30,
|
| 45 |
+
"b": 31,
|
| 46 |
+
"c": 32,
|
| 47 |
+
"d": 33,
|
| 48 |
+
"e": 34,
|
| 49 |
+
"f": 35,
|
| 50 |
+
"g": 36,
|
| 51 |
+
"h": 37,
|
| 52 |
+
"i": 38,
|
| 53 |
+
"j": 39,
|
| 54 |
+
"k": 40,
|
| 55 |
+
"l": 41,
|
| 56 |
+
"m": 42,
|
| 57 |
+
"n": 43,
|
| 58 |
+
"o": 44,
|
| 59 |
+
"p": 45,
|
| 60 |
+
"r": 46,
|
| 61 |
+
"s": 47,
|
| 62 |
+
"t": 48,
|
| 63 |
+
"u": 49,
|
| 64 |
+
"v": 50,
|
| 65 |
+
"w": 51,
|
| 66 |
+
"y": 52,
|
| 67 |
+
"z": 53,
|
| 68 |
+
"an": 54,
|
| 69 |
+
"an</w>": 55,
|
| 70 |
+
"er": 56,
|
| 71 |
+
"en": 57,
|
| 72 |
+
"da": 58,
|
| 73 |
+
"ti": 59,
|
| 74 |
+
"il": 60,
|
| 75 |
+
"si": 61,
|
| 76 |
+
"di": 62,
|
| 77 |
+
"ang</w>": 63,
|
| 78 |
+
"si</w>": 64,
|
| 79 |
+
"anc": 65,
|
| 80 |
+
"kan</w>": 66,
|
| 81 |
+
"al": 67,
|
| 82 |
+
"su": 68,
|
| 83 |
+
"ang": 69,
|
| 84 |
+
"ri</w>": 70,
|
| 85 |
+
"ke": 71,
|
| 86 |
+
"ef": 72,
|
| 87 |
+
"ter": 73,
|
| 88 |
+
"se": 74,
|
| 89 |
+
"te": 75,
|
| 90 |
+
"pa": 76,
|
| 91 |
+
"ng": 77,
|
| 92 |
+
"on</w>": 78,
|
| 93 |
+
"on": 79,
|
| 94 |
+
"hef": 80,
|
| 95 |
+
"hefe": 81,
|
| 96 |
+
"enc": 82,
|
| 97 |
+
"or": 83,
|
| 98 |
+
"la": 84,
|
| 99 |
+
"sim": 85,
|
| 100 |
+
"ul": 86,
|
| 101 |
+
"tida": 87,
|
| 102 |
+
"ar": 88,
|
| 103 |
+
"eng": 89,
|
| 104 |
+
"dari</w>": 90,
|
| 105 |
+
"re": 91,
|
| 106 |
+
"bu": 92,
|
| 107 |
+
"ance</w>": 93,
|
| 108 |
+
"ra": 94,
|
| 109 |
+
"om": 95,
|
| 110 |
+
"hefei</w>": 96,
|
| 111 |
+
"jang": 97,
|
| 112 |
+
"sa": 98,
|
| 113 |
+
"ju</w>": 99,
|
| 114 |
+
"jangm": 100,
|
| 115 |
+
"jangmo": 101,
|
| 116 |
+
"jangmok</w>": 102,
|
| 117 |
+
"al</w>": 103,
|
| 118 |
+
"os": 104,
|
| 119 |
+
"dianc": 105,
|
| 120 |
+
"diancang</w>": 106,
|
| 121 |
+
"ai": 107,
|
| 122 |
+
"in": 108,
|
| 123 |
+
"ja": 109,
|
| 124 |
+
"kon": 110,
|
| 125 |
+
"li": 111,
|
| 126 |
+
"ct</w>": 112,
|
| 127 |
+
"tidak</w>": 113,
|
| 128 |
+
"eri": 114,
|
| 129 |
+
"fi": 115,
|
| 130 |
+
"meng": 116,
|
| 131 |
+
"asi</w>": 117,
|
| 132 |
+
"kesim": 118,
|
| 133 |
+
"kesimp": 119,
|
| 134 |
+
"kesimpul": 120,
|
| 135 |
+
"kesimpulan</w>": 121,
|
| 136 |
+
"di</w>": 122,
|
| 137 |
+
"ngkan</w>": 123,
|
| 138 |
+
"ksi</w>": 124,
|
| 139 |
+
"pi": 125,
|
| 140 |
+
"ya</w>": 126,
|
| 141 |
+
"yang</w>": 127,
|
| 142 |
+
"encu": 128,
|
| 143 |
+
"ta": 129,
|
| 144 |
+
"buk": 130,
|
| 145 |
+
"bukt": 131,
|
| 146 |
+
"bukti</w>": 132,
|
| 147 |
+
"pen": 133,
|
| 148 |
+
"per": 134,
|
| 149 |
+
"lu": 135,
|
| 150 |
+
"le": 136,
|
| 151 |
+
"fiv": 137,
|
| 152 |
+
"five</w>": 138,
|
| 153 |
+
"sw": 139,
|
| 154 |
+
"swor": 140,
|
| 155 |
+
"sword": 141,
|
| 156 |
+
"swords</w>": 142,
|
| 157 |
+
"pencu": 143,
|
| 158 |
+
"ence</w>": 144,
|
| 159 |
+
"ce": 145,
|
| 160 |
+
"ku": 146,
|
| 161 |
+
"ili": 147,
|
| 162 |
+
"sn": 148,
|
| 163 |
+
"sno": 149,
|
| 164 |
+
"snow</w>": 150,
|
| 165 |
+
"plu": 151,
|
| 166 |
+
"plum</w>": 152,
|
| 167 |
+
"pil": 153,
|
| 168 |
+
"pill</w>": 154,
|
| 169 |
+
"mengh": 155,
|
| 170 |
+
"menghil": 156,
|
| 171 |
+
"menghilang</w>": 157,
|
| 172 |
+
"lo": 158,
|
| 173 |
+
"bi": 159,
|
| 174 |
+
"de": 160,
|
| 175 |
+
"anom": 161,
|
| 176 |
+
"anomal": 162,
|
| 177 |
+
"mar": 163,
|
| 178 |
+
"marti": 164,
|
| 179 |
+
"martial</w>": 165,
|
| 180 |
+
"alli": 166,
|
| 181 |
+
"alliance</w>": 167,
|
| 182 |
+
"mu": 168,
|
| 183 |
+
"anal": 169,
|
| 184 |
+
"anali": 170,
|
| 185 |
+
"analisi": 171,
|
| 186 |
+
"analisis</w>": 172,
|
| 187 |
+
"gy": 173,
|
| 188 |
+
"gyer": 174,
|
| 189 |
+
"gyery": 175,
|
| 190 |
+
"gyeryon": 176,
|
| 191 |
+
"gyeryong</w>": 177,
|
| 192 |
+
"mer": 178,
|
| 193 |
+
"merc": 179,
|
| 194 |
+
"merch": 180,
|
| 195 |
+
"merchan": 181,
|
| 196 |
+
"merchant</w>": 182,
|
| 197 |
+
"gu": 183,
|
| 198 |
+
"guil": 184,
|
| 199 |
+
"guild</w>": 185,
|
| 200 |
+
"ha": 186,
|
| 201 |
+
"cr": 187,
|
| 202 |
+
"cros": 188,
|
| 203 |
+
"cross</w>": 189,
|
| 204 |
+
"ref": 190,
|
| 205 |
+
"refer": 191,
|
| 206 |
+
"reference</w>": 192,
|
| 207 |
+
"keja": 193,
|
| 208 |
+
"kejadi": 194,
|
| 209 |
+
"kejadian</w>": 195,
|
| 210 |
+
"simh": 196,
|
| 211 |
+
"simhy": 197,
|
| 212 |
+
"simhye": 198,
|
| 213 |
+
"simhyeon</w>": 199,
|
| 214 |
+
"pav": 200,
|
| 215 |
+
"pavili": 201,
|
| 216 |
+
"pavilion</w>": 202,
|
| 217 |
+
"me": 203,
|
| 218 |
+
"tion</w>": 204,
|
| 219 |
+
"sum": 205,
|
| 220 |
+
"blo": 206,
|
| 221 |
+
"bloo": 207,
|
| 222 |
+
"blood</w>": 208,
|
| 223 |
+
"ser": 209,
|
| 224 |
+
"serpen": 210,
|
| 225 |
+
"serpent</w>": 211,
|
| 226 |
+
"dance</w>": 212,
|
| 227 |
+
"ste": 213,
|
| 228 |
+
"step</w>": 214,
|
| 229 |
+
"pre": 215,
|
| 230 |
+
"predi": 216,
|
| 231 |
+
"tin": 217,
|
| 232 |
+
"tinda": 218,
|
| 233 |
+
"tindakan</w>": 219,
|
| 234 |
+
"beri": 220,
|
| 235 |
+
"beriku": 221,
|
| 236 |
+
"berikut": 222,
|
| 237 |
+
"berikutn": 223,
|
| 238 |
+
"berikutnya</w>": 224,
|
| 239 |
+
"tae": 225,
|
| 240 |
+
"taeul": 226,
|
| 241 |
+
"taeul_": 227,
|
| 242 |
+
"taeul_se": 228,
|
| 243 |
+
"taeul_sect</w>": 229,
|
| 244 |
+
"po": 230,
|
| 245 |
+
"pol": 231,
|
| 246 |
+
"pola</w>": 232,
|
| 247 |
+
"jang</w>": 233,
|
| 248 |
+
"hang": 234,
|
| 249 |
+
"hangi</w>": 235,
|
| 250 |
+
"ad": 236,
|
| 251 |
+
"ada</w>": 237,
|
| 252 |
+
"bar": 238,
|
| 253 |
+
"baru</w>": 239,
|
| 254 |
+
"pat": 240,
|
| 255 |
+
"patter": 241,
|
| 256 |
+
"pattern</w>": 242,
|
| 257 |
+
"terpi": 243,
|
| 258 |
+
"terpisa": 244,
|
| 259 |
+
"terpisah</w>": 245,
|
| 260 |
+
"com": 246,
|
| 261 |
+
"comp": 247,
|
| 262 |
+
"as": 248,
|
| 263 |
+
"dete": 249,
|
| 264 |
+
"deteksi</w>": 250,
|
| 265 |
+
"gu</w>": 251,
|
| 266 |
+
"ilm": 252,
|
| 267 |
+
"ilmu</w>": 253,
|
| 268 |
+
"ketida": 254,
|
| 269 |
+
"ketidak": 255,
|
| 270 |
+
"ketidakse": 256,
|
| 271 |
+
"ketidaksesu": 257,
|
| 272 |
+
"ketidaksesuai": 258,
|
| 273 |
+
"ketidaksesuaian</w>": 259,
|
| 274 |
+
"terk": 260,
|
| 275 |
+
"terkai": 261,
|
| 276 |
+
"terkait</w>": 262,
|
| 277 |
+
"lap": 263,
|
| 278 |
+
"lapor": 264,
|
| 279 |
+
"laporan</w>": 265,
|
| 280 |
+
"hu": 266,
|
| 281 |
+
"hubu": 267,
|
| 282 |
+
"ela": 268,
|
| 283 |
+
"dar": 269,
|
| 284 |
+
"dark": 270,
|
| 285 |
+
"dark_": 271,
|
| 286 |
+
"dark_f": 272,
|
| 287 |
+
"dark_fa": 273,
|
| 288 |
+
"dark_fac": 274,
|
| 289 |
+
"dark_faction</w>": 275,
|
| 290 |
+
"at</w>": 276,
|
| 291 |
+
"anomaly</w>": 277,
|
| 292 |
+
"ban": 278,
|
| 293 |
+
"bandi": 279,
|
| 294 |
+
"bandingkan</w>": 280,
|
| 295 |
+
"tang": 281,
|
| 296 |
+
"tangg": 282,
|
| 297 |
+
"tanggal</w>": 283,
|
| 298 |
+
"hefei": 284,
|
| 299 |
+
"hefei_": 285,
|
| 300 |
+
"hefei_b": 286,
|
| 301 |
+
"hefei_br": 287,
|
| 302 |
+
"hefei_branc": 288,
|
| 303 |
+
"hefei_branch</w>": 289,
|
| 304 |
+
"deng": 290,
|
| 305 |
+
"dengan</w>": 291,
|
| 306 |
+
"hubungkan</w>": 292,
|
| 307 |
+
"fra": 293,
|
| 308 |
+
"frag": 294,
|
| 309 |
+
"fragme": 295,
|
| 310 |
+
"fragmen</w>": 296,
|
| 311 |
+
"pencuri</w>": 297,
|
| 312 |
+
"compos": 298,
|
| 313 |
+
"compose</w>": 299,
|
| 314 |
+
"susu": 300,
|
| 315 |
+
"susun</w>": 301,
|
| 316 |
+
"rec": 302,
|
| 317 |
+
"recal": 303,
|
| 318 |
+
"recall</w>": 304,
|
| 319 |
+
"ing": 305,
|
| 320 |
+
"ingat</w>": 306,
|
| 321 |
+
"semu": 307,
|
| 322 |
+
"semua</w>": 308,
|
| 323 |
+
"predict</w>": 309,
|
| 324 |
+
"perk": 310,
|
| 325 |
+
"perki": 311,
|
| 326 |
+
"perkira": 312,
|
| 327 |
+
"perkirakan</w>": 313,
|
| 328 |
+
"veri": 314,
|
| 329 |
+
"verif": 315,
|
| 330 |
+
"verify</w>": 316,
|
| 331 |
+
"cek</w>": 317,
|
| 332 |
+
"konsi": 318,
|
| 333 |
+
"konsis": 319,
|
| 334 |
+
"konsist": 320,
|
| 335 |
+
"konsisten": 321,
|
| 336 |
+
"konsistensi</w>": 322,
|
| 337 |
+
"konsum": 323,
|
| 338 |
+
"konsumsi</w>": 324,
|
| 339 |
+
"pa</w>": 325,
|
| 340 |
+
"men": 326,
|
| 341 |
+
"ting": 327,
|
| 342 |
+
"fil": 328,
|
| 343 |
+
"filte": 329,
|
| 344 |
+
"filter</w>": 330,
|
| 345 |
+
"eli": 331,
|
| 346 |
+
"elim": 332,
|
| 347 |
+
"elimin": 333,
|
| 348 |
+
"eliminasi</w>": 334,
|
| 349 |
+
"rele": 335,
|
| 350 |
+
"relev": 336,
|
| 351 |
+
"relevan</w>": 337,
|
| 352 |
+
"pil</w>": 338,
|
| 353 |
+
"pasa": 339,
|
| 354 |
+
"pasar</w>": 340,
|
| 355 |
+
"gela": 341,
|
| 356 |
+
"gelap</w>": 342,
|
| 357 |
+
"suc": 343,
|
| 358 |
+
"succe": 344,
|
| 359 |
+
"succes": 345,
|
| 360 |
+
"success</w>": 346,
|
| 361 |
+
"rat": 347,
|
| 362 |
+
"rate</w>": 348,
|
| 363 |
+
"pai": 349,
|
| 364 |
+
"pair</w>": 350,
|
| 365 |
+
"lebi": 351,
|
| 366 |
+
"lebih</w>": 352,
|
| 367 |
+
"tingg": 353,
|
| 368 |
+
"tinggi</w>": 354,
|
| 369 |
+
"bias": 355,
|
| 370 |
+
"biasan": 356,
|
| 371 |
+
"biasanya</w>": 357,
|
| 372 |
+
"dala": 358,
|
| 373 |
+
"dalam</w>": 359,
|
| 374 |
+
"ber": 360,
|
| 375 |
+
"pencur": 361,
|
| 376 |
+
"pencuri": 362,
|
| 377 |
+
"pencurian</w>": 363,
|
| 378 |
+
"ka</w>": 364,
|
| 379 |
+
"tan": 365,
|
| 380 |
+
"tanpa</w>": 366,
|
| 381 |
+
"je": 367,
|
| 382 |
+
"jeja": 368,
|
| 383 |
+
"jejak</w>": 369,
|
| 384 |
+
"perg": 370,
|
| 385 |
+
"perger": 371,
|
| 386 |
+
"pergera": 372,
|
| 387 |
+
"pergerakan</w>": 373,
|
| 388 |
+
"masi</w>": 374,
|
| 389 |
+
"inv": 375,
|
| 390 |
+
"inve": 376,
|
| 391 |
+
"inves": 377,
|
| 392 |
+
"investi": 378,
|
| 393 |
+
"investig": 379,
|
| 394 |
+
"investigasi</w>": 380,
|
| 395 |
+
"hari</w>": 381,
|
| 396 |
+
"sam": 382,
|
| 397 |
+
"sama</w>": 383,
|
| 398 |
+
"dat": 384,
|
| 399 |
+
"data</w>": 385,
|
| 400 |
+
"menu": 386,
|
| 401 |
+
"menun": 387,
|
| 402 |
+
"menunj": 388,
|
| 403 |
+
"menunju": 389,
|
| 404 |
+
"menunjuk": 390,
|
| 405 |
+
"menunjukkan</w>": 391,
|
| 406 |
+
"ca": 392,
|
| 407 |
+
"mi": 393,
|
| 408 |
+
"misi</w>": 394,
|
| 409 |
+
"assi": 395,
|
| 410 |
+
"assig": 396,
|
| 411 |
+
"assign</w>": 397,
|
| 412 |
+
"sen": 398,
|
| 413 |
+
"sendi": 399,
|
| 414 |
+
"sendiri</w>": 400,
|
| 415 |
+
"ka": 401,
|
| 416 |
+
"sete": 402,
|
| 417 |
+
"setela": 403,
|
| 418 |
+
"setelah</w>": 404,
|
| 419 |
+
"temu": 405,
|
| 420 |
+
"ke</w>": 406,
|
| 421 |
+
"sumb": 407,
|
| 422 |
+
"sumbe": 408,
|
| 423 |
+
"sumber</w>": 409,
|
| 424 |
+
"prediksi</w>": 410,
|
| 425 |
+
"ters": 411,
|
| 426 |
+
"tersang": 412,
|
| 427 |
+
"tersangka</w>": 413,
|
| 428 |
+
"penal": 414,
|
| 429 |
+
"penalar": 415,
|
| 430 |
+
"penalaran</w>": 416,
|
| 431 |
+
"menja": 417,
|
| 432 |
+
"menjadi</w>": 418,
|
| 433 |
+
"kun": 419,
|
| 434 |
+
"kunc": 420,
|
| 435 |
+
"kunci</w>": 421,
|
| 436 |
+
"hasi": 422,
|
| 437 |
+
"hasil</w>": 423,
|
| 438 |
+
"inf": 424,
|
| 439 |
+
"infor": 425,
|
| 440 |
+
"informasi</w>": 426,
|
| 441 |
+
"anomali</w>": 427,
|
| 442 |
+
"ya": 428,
|
| 443 |
+
"temuan</w>": 429,
|
| 444 |
+
"berk": 430,
|
| 445 |
+
"berkor": 431,
|
| 446 |
+
"berkorela": 432,
|
| 447 |
+
"berkorelasi</w>": 433,
|
| 448 |
+
"cata": 434,
|
| 449 |
+
"catat": 435,
|
| 450 |
+
"catatan</w>": 436,
|
| 451 |
+
"sia": 437,
|
| 452 |
+
"siapa</w>": 438,
|
| 453 |
+
"mencu": 439,
|
| 454 |
+
"mencuri</w>": 440,
|
| 455 |
+
"terse": 441,
|
| 456 |
+
"tersedi": 442,
|
| 457 |
+
"tersedia</w>": 443,
|
| 458 |
+
"mem": 444,
|
| 459 |
+
"memili": 445,
|
| 460 |
+
"memilik": 446,
|
| 461 |
+
"memiliki</w>": 447,
|
| 462 |
+
"kone": 448,
|
| 463 |
+
"koneksi</w>": 449,
|
| 464 |
+
"con": 450,
|
| 465 |
+
"confi": 451,
|
| 466 |
+
"confid": 452,
|
| 467 |
+
"confidence</w>": 453,
|
| 468 |
+
"mengin": 454,
|
| 469 |
+
"mengindi": 455,
|
| 470 |
+
"mengindika": 456,
|
| 471 |
+
"mengindikasi": 457,
|
| 472 |
+
"mengindikasikan</w>": 458,
|
| 473 |
+
"pene": 459,
|
| 474 |
+
"penelu": 460,
|
| 475 |
+
"penelusu": 461,
|
| 476 |
+
"penelusur": 462,
|
| 477 |
+
"penelusuran</w>": 463,
|
| 478 |
+
"log": 464,
|
| 479 |
+
"logi": 465,
|
| 480 |
+
"logika</w>": 466,
|
| 481 |
+
"pr": 467,
|
| 482 |
+
"pro": 468,
|
| 483 |
+
"prose": 469,
|
| 484 |
+
"proses</w>": 470,
|
| 485 |
+
"ded": 471,
|
| 486 |
+
"dedu": 472,
|
| 487 |
+
"deduksi</w>": 473,
|
| 488 |
+
"mengkon": 474,
|
| 489 |
+
"mengkonfi": 475,
|
| 490 |
+
"mengkonfir": 476,
|
| 491 |
+
"mengkonfirmasi</w>": 477,
|
| 492 |
+
"comple": 478,
|
| 493 |
+
"completion</w>": 479,
|
| 494 |
+
"ba": 480,
|
| 495 |
+
"bah": 481,
|
| 496 |
+
"bahw": 482,
|
| 497 |
+
"bahwa</w>": 483,
|
| 498 |
+
"ev": 484,
|
| 499 |
+
"eval": 485,
|
| 500 |
+
"evalu": 486,
|
| 501 |
+
"evaluasi</w>": 487,
|
| 502 |
+
"keper": 488,
|
| 503 |
+
"keperca": 489,
|
| 504 |
+
"kepercaya": 490,
|
| 505 |
+
"kepercayaan</w>": 491,
|
| 506 |
+
"berta": 492,
|
| 507 |
+
"bertaha": 493,
|
| 508 |
+
"bertahap</w>": 494,
|
| 509 |
+
"insi": 495,
|
| 510 |
+
"insid": 496,
|
| 511 |
+
"inside</w>": 497,
|
| 512 |
+
"jo": 498,
|
| 513 |
+
"job</w>": 499
|
| 514 |
+
},
|
| 515 |
+
"merges": {
|
| 516 |
+
"a|||n": 0,
|
| 517 |
+
"a|||n</w>": 1,
|
| 518 |
+
"e|||r": 2,
|
| 519 |
+
"e|||n": 3,
|
| 520 |
+
"d|||a": 4,
|
| 521 |
+
"t|||i": 5,
|
| 522 |
+
"i|||l": 6,
|
| 523 |
+
"s|||i": 7,
|
| 524 |
+
"d|||i": 8,
|
| 525 |
+
"an|||g</w>": 9,
|
| 526 |
+
"s|||i</w>": 10,
|
| 527 |
+
"an|||c": 11,
|
| 528 |
+
"k|||an</w>": 12,
|
| 529 |
+
"a|||l": 13,
|
| 530 |
+
"s|||u": 14,
|
| 531 |
+
"an|||g": 15,
|
| 532 |
+
"r|||i</w>": 16,
|
| 533 |
+
"k|||e": 17,
|
| 534 |
+
"e|||f": 18,
|
| 535 |
+
"t|||er": 19,
|
| 536 |
+
"s|||e": 20,
|
| 537 |
+
"t|||e": 21,
|
| 538 |
+
"p|||a": 22,
|
| 539 |
+
"n|||g": 23,
|
| 540 |
+
"o|||n</w>": 24,
|
| 541 |
+
"o|||n": 25,
|
| 542 |
+
"h|||ef": 26,
|
| 543 |
+
"hef|||e": 27,
|
| 544 |
+
"en|||c": 28,
|
| 545 |
+
"o|||r": 29,
|
| 546 |
+
"l|||a": 30,
|
| 547 |
+
"si|||m": 31,
|
| 548 |
+
"u|||l": 32,
|
| 549 |
+
"ti|||da": 33,
|
| 550 |
+
"a|||r": 34,
|
| 551 |
+
"en|||g": 35,
|
| 552 |
+
"da|||ri</w>": 36,
|
| 553 |
+
"r|||e": 37,
|
| 554 |
+
"b|||u": 38,
|
| 555 |
+
"anc|||e</w>": 39,
|
| 556 |
+
"r|||a": 40,
|
| 557 |
+
"o|||m": 41,
|
| 558 |
+
"hefe|||i</w>": 42,
|
| 559 |
+
"j|||ang": 43,
|
| 560 |
+
"s|||a": 44,
|
| 561 |
+
"j|||u</w>": 45,
|
| 562 |
+
"jang|||m": 46,
|
| 563 |
+
"jangm|||o": 47,
|
| 564 |
+
"jangmo|||k</w>": 48,
|
| 565 |
+
"a|||l</w>": 49,
|
| 566 |
+
"o|||s": 50,
|
| 567 |
+
"di|||anc": 51,
|
| 568 |
+
"dianc|||ang</w>": 52,
|
| 569 |
+
"a|||i": 53,
|
| 570 |
+
"i|||n": 54,
|
| 571 |
+
"j|||a": 55,
|
| 572 |
+
"k|||on": 56,
|
| 573 |
+
"l|||i": 57,
|
| 574 |
+
"c|||t</w>": 58,
|
| 575 |
+
"tida|||k</w>": 59,
|
| 576 |
+
"er|||i": 60,
|
| 577 |
+
"f|||i": 61,
|
| 578 |
+
"m|||eng": 62,
|
| 579 |
+
"a|||si</w>": 63,
|
| 580 |
+
"ke|||sim": 64,
|
| 581 |
+
"kesim|||p": 65,
|
| 582 |
+
"kesimp|||ul": 66,
|
| 583 |
+
"kesimpul|||an</w>": 67,
|
| 584 |
+
"d|||i</w>": 68,
|
| 585 |
+
"ng|||kan</w>": 69,
|
| 586 |
+
"k|||si</w>": 70,
|
| 587 |
+
"p|||i": 71,
|
| 588 |
+
"y|||a</w>": 72,
|
| 589 |
+
"y|||ang</w>": 73,
|
| 590 |
+
"enc|||u": 74,
|
| 591 |
+
"t|||a": 75,
|
| 592 |
+
"bu|||k": 76,
|
| 593 |
+
"buk|||t": 77,
|
| 594 |
+
"bukt|||i</w>": 78,
|
| 595 |
+
"p|||en": 79,
|
| 596 |
+
"p|||er": 80,
|
| 597 |
+
"l|||u": 81,
|
| 598 |
+
"l|||e": 82,
|
| 599 |
+
"fi|||v": 83,
|
| 600 |
+
"fiv|||e</w>": 84,
|
| 601 |
+
"s|||w": 85,
|
| 602 |
+
"sw|||or": 86,
|
| 603 |
+
"swor|||d": 87,
|
| 604 |
+
"sword|||s</w>": 88,
|
| 605 |
+
"p|||encu": 89,
|
| 606 |
+
"enc|||e</w>": 90,
|
| 607 |
+
"c|||e": 91,
|
| 608 |
+
"k|||u": 92,
|
| 609 |
+
"il|||i": 93,
|
| 610 |
+
"s|||n": 94,
|
| 611 |
+
"sn|||o": 95,
|
| 612 |
+
"sno|||w</w>": 96,
|
| 613 |
+
"p|||lu": 97,
|
| 614 |
+
"plu|||m</w>": 98,
|
| 615 |
+
"p|||il": 99,
|
| 616 |
+
"pil|||l</w>": 100,
|
| 617 |
+
"meng|||h": 101,
|
| 618 |
+
"mengh|||il": 102,
|
| 619 |
+
"menghil|||ang</w>": 103,
|
| 620 |
+
"l|||o": 104,
|
| 621 |
+
"b|||i": 105,
|
| 622 |
+
"d|||e": 106,
|
| 623 |
+
"an|||om": 107,
|
| 624 |
+
"anom|||al": 108,
|
| 625 |
+
"m|||ar": 109,
|
| 626 |
+
"mar|||ti": 110,
|
| 627 |
+
"marti|||al</w>": 111,
|
| 628 |
+
"al|||li": 112,
|
| 629 |
+
"alli|||ance</w>": 113,
|
| 630 |
+
"m|||u": 114,
|
| 631 |
+
"an|||al": 115,
|
| 632 |
+
"anal|||i": 116,
|
| 633 |
+
"anali|||si": 117,
|
| 634 |
+
"analisi|||s</w>": 118,
|
| 635 |
+
"g|||y": 119,
|
| 636 |
+
"gy|||er": 120,
|
| 637 |
+
"gyer|||y": 121,
|
| 638 |
+
"gyery|||on": 122,
|
| 639 |
+
"gyeryon|||g</w>": 123,
|
| 640 |
+
"m|||er": 124,
|
| 641 |
+
"mer|||c": 125,
|
| 642 |
+
"merc|||h": 126,
|
| 643 |
+
"merch|||an": 127,
|
| 644 |
+
"merchan|||t</w>": 128,
|
| 645 |
+
"g|||u": 129,
|
| 646 |
+
"gu|||il": 130,
|
| 647 |
+
"guil|||d</w>": 131,
|
| 648 |
+
"h|||a": 132,
|
| 649 |
+
"c|||r": 133,
|
| 650 |
+
"cr|||os": 134,
|
| 651 |
+
"cros|||s</w>": 135,
|
| 652 |
+
"r|||ef": 136,
|
| 653 |
+
"ref|||er": 137,
|
| 654 |
+
"refer|||ence</w>": 138,
|
| 655 |
+
"ke|||ja": 139,
|
| 656 |
+
"keja|||di": 140,
|
| 657 |
+
"kejadi|||an</w>": 141,
|
| 658 |
+
"sim|||h": 142,
|
| 659 |
+
"simh|||y": 143,
|
| 660 |
+
"simhy|||e": 144,
|
| 661 |
+
"simhye|||on</w>": 145,
|
| 662 |
+
"pa|||v": 146,
|
| 663 |
+
"pav|||ili": 147,
|
| 664 |
+
"pavili|||on</w>": 148,
|
| 665 |
+
"m|||e": 149,
|
| 666 |
+
"ti|||on</w>": 150,
|
| 667 |
+
"su|||m": 151,
|
| 668 |
+
"b|||lo": 152,
|
| 669 |
+
"blo|||o": 153,
|
| 670 |
+
"bloo|||d</w>": 154,
|
| 671 |
+
"s|||er": 155,
|
| 672 |
+
"ser|||pen": 156,
|
| 673 |
+
"serpen|||t</w>": 157,
|
| 674 |
+
"d|||ance</w>": 158,
|
| 675 |
+
"s|||te": 159,
|
| 676 |
+
"ste|||p</w>": 160,
|
| 677 |
+
"p|||re": 161,
|
| 678 |
+
"pre|||di": 162,
|
| 679 |
+
"ti|||n": 163,
|
| 680 |
+
"tin|||da": 164,
|
| 681 |
+
"tinda|||kan</w>": 165,
|
| 682 |
+
"b|||eri": 166,
|
| 683 |
+
"beri|||ku": 167,
|
| 684 |
+
"beriku|||t": 168,
|
| 685 |
+
"berikut|||n": 169,
|
| 686 |
+
"berikutn|||ya</w>": 170,
|
| 687 |
+
"ta|||e": 171,
|
| 688 |
+
"tae|||ul": 172,
|
| 689 |
+
"taeul|||_": 173,
|
| 690 |
+
"taeul_|||se": 174,
|
| 691 |
+
"taeul_se|||ct</w>": 175,
|
| 692 |
+
"p|||o": 176,
|
| 693 |
+
"po|||l": 177,
|
| 694 |
+
"pol|||a</w>": 178,
|
| 695 |
+
"j|||ang</w>": 179,
|
| 696 |
+
"h|||ang": 180,
|
| 697 |
+
"hang|||i</w>": 181,
|
| 698 |
+
"a|||d": 182,
|
| 699 |
+
"ad|||a</w>": 183,
|
| 700 |
+
"b|||ar": 184,
|
| 701 |
+
"bar|||u</w>": 185,
|
| 702 |
+
"pa|||t": 186,
|
| 703 |
+
"pat|||ter": 187,
|
| 704 |
+
"patter|||n</w>": 188,
|
| 705 |
+
"ter|||pi": 189,
|
| 706 |
+
"terpi|||sa": 190,
|
| 707 |
+
"terpisa|||h</w>": 191,
|
| 708 |
+
"c|||om": 192,
|
| 709 |
+
"com|||p": 193,
|
| 710 |
+
"a|||s": 194,
|
| 711 |
+
"de|||te": 195,
|
| 712 |
+
"dete|||ksi</w>": 196,
|
| 713 |
+
"g|||u</w>": 197,
|
| 714 |
+
"il|||m": 198,
|
| 715 |
+
"ilm|||u</w>": 199,
|
| 716 |
+
"ke|||tida": 200,
|
| 717 |
+
"ketida|||k": 201,
|
| 718 |
+
"ketidak|||se": 202,
|
| 719 |
+
"ketidakse|||su": 203,
|
| 720 |
+
"ketidaksesu|||ai": 204,
|
| 721 |
+
"ketidaksesuai|||an</w>": 205,
|
| 722 |
+
"ter|||k": 206,
|
| 723 |
+
"terk|||ai": 207,
|
| 724 |
+
"terkai|||t</w>": 208,
|
| 725 |
+
"la|||p": 209,
|
| 726 |
+
"lap|||or": 210,
|
| 727 |
+
"lapor|||an</w>": 211,
|
| 728 |
+
"h|||u": 212,
|
| 729 |
+
"hu|||bu": 213,
|
| 730 |
+
"e|||la": 214,
|
| 731 |
+
"da|||r": 215,
|
| 732 |
+
"dar|||k": 216,
|
| 733 |
+
"dark|||_": 217,
|
| 734 |
+
"dark_|||f": 218,
|
| 735 |
+
"dark_f|||a": 219,
|
| 736 |
+
"dark_fa|||c": 220,
|
| 737 |
+
"dark_fac|||tion</w>": 221,
|
| 738 |
+
"a|||t</w>": 222,
|
| 739 |
+
"anomal|||y</w>": 223,
|
| 740 |
+
"b|||an": 224,
|
| 741 |
+
"ban|||di": 225,
|
| 742 |
+
"bandi|||ngkan</w>": 226,
|
| 743 |
+
"t|||ang": 227,
|
| 744 |
+
"tang|||g": 228,
|
| 745 |
+
"tangg|||al</w>": 229,
|
| 746 |
+
"hefe|||i": 230,
|
| 747 |
+
"hefei|||_": 231,
|
| 748 |
+
"hefei_|||b": 232,
|
| 749 |
+
"hefei_b|||r": 233,
|
| 750 |
+
"hefei_br|||anc": 234,
|
| 751 |
+
"hefei_branc|||h</w>": 235,
|
| 752 |
+
"d|||eng": 236,
|
| 753 |
+
"deng|||an</w>": 237,
|
| 754 |
+
"hubu|||ngkan</w>": 238,
|
| 755 |
+
"f|||ra": 239,
|
| 756 |
+
"fra|||g": 240,
|
| 757 |
+
"frag|||me": 241,
|
| 758 |
+
"fragme|||n</w>": 242,
|
| 759 |
+
"pencu|||ri</w>": 243,
|
| 760 |
+
"comp|||os": 244,
|
| 761 |
+
"compos|||e</w>": 245,
|
| 762 |
+
"su|||su": 246,
|
| 763 |
+
"susu|||n</w>": 247,
|
| 764 |
+
"re|||c": 248,
|
| 765 |
+
"rec|||al": 249,
|
| 766 |
+
"recal|||l</w>": 250,
|
| 767 |
+
"i|||ng": 251,
|
| 768 |
+
"ing|||at</w>": 252,
|
| 769 |
+
"se|||mu": 253,
|
| 770 |
+
"semu|||a</w>": 254,
|
| 771 |
+
"predi|||ct</w>": 255,
|
| 772 |
+
"per|||k": 256,
|
| 773 |
+
"perk|||i": 257,
|
| 774 |
+
"perki|||ra": 258,
|
| 775 |
+
"perkira|||kan</w>": 259,
|
| 776 |
+
"v|||eri": 260,
|
| 777 |
+
"veri|||f": 261,
|
| 778 |
+
"verif|||y</w>": 262,
|
| 779 |
+
"ce|||k</w>": 263,
|
| 780 |
+
"kon|||si": 264,
|
| 781 |
+
"konsi|||s": 265,
|
| 782 |
+
"konsis|||t": 266,
|
| 783 |
+
"konsist|||en": 267,
|
| 784 |
+
"konsisten|||si</w>": 268,
|
| 785 |
+
"kon|||sum": 269,
|
| 786 |
+
"konsum|||si</w>": 270,
|
| 787 |
+
"p|||a</w>": 271,
|
| 788 |
+
"m|||en": 272,
|
| 789 |
+
"ti|||ng": 273,
|
| 790 |
+
"f|||il": 274,
|
| 791 |
+
"fil|||te": 275,
|
| 792 |
+
"filte|||r</w>": 276,
|
| 793 |
+
"e|||li": 277,
|
| 794 |
+
"eli|||m": 278,
|
| 795 |
+
"elim|||in": 279,
|
| 796 |
+
"elimin|||asi</w>": 280,
|
| 797 |
+
"re|||le": 281,
|
| 798 |
+
"rele|||v": 282,
|
| 799 |
+
"relev|||an</w>": 283,
|
| 800 |
+
"pi|||l</w>": 284,
|
| 801 |
+
"pa|||sa": 285,
|
| 802 |
+
"pasa|||r</w>": 286,
|
| 803 |
+
"g|||ela": 287,
|
| 804 |
+
"gela|||p</w>": 288,
|
| 805 |
+
"su|||c": 289,
|
| 806 |
+
"suc|||ce": 290,
|
| 807 |
+
"succe|||s": 291,
|
| 808 |
+
"succes|||s</w>": 292,
|
| 809 |
+
"ra|||t": 293,
|
| 810 |
+
"rat|||e</w>": 294,
|
| 811 |
+
"pa|||i": 295,
|
| 812 |
+
"pai|||r</w>": 296,
|
| 813 |
+
"le|||bi": 297,
|
| 814 |
+
"lebi|||h</w>": 298,
|
| 815 |
+
"ting|||g": 299,
|
| 816 |
+
"tingg|||i</w>": 300,
|
| 817 |
+
"bi|||as": 301,
|
| 818 |
+
"bias|||an": 302,
|
| 819 |
+
"biasan|||ya</w>": 303,
|
| 820 |
+
"da|||la": 304,
|
| 821 |
+
"dala|||m</w>": 305,
|
| 822 |
+
"b|||er": 306,
|
| 823 |
+
"pencu|||r": 307,
|
| 824 |
+
"pencur|||i": 308,
|
| 825 |
+
"pencuri|||an</w>": 309,
|
| 826 |
+
"k|||a</w>": 310,
|
| 827 |
+
"t|||an": 311,
|
| 828 |
+
"tan|||pa</w>": 312,
|
| 829 |
+
"j|||e": 313,
|
| 830 |
+
"je|||ja": 314,
|
| 831 |
+
"jeja|||k</w>": 315,
|
| 832 |
+
"per|||g": 316,
|
| 833 |
+
"perg|||er": 317,
|
| 834 |
+
"perger|||a": 318,
|
| 835 |
+
"pergera|||kan</w>": 319,
|
| 836 |
+
"m|||asi</w>": 320,
|
| 837 |
+
"in|||v": 321,
|
| 838 |
+
"inv|||e": 322,
|
| 839 |
+
"inve|||s": 323,
|
| 840 |
+
"inves|||ti": 324,
|
| 841 |
+
"investi|||g": 325,
|
| 842 |
+
"investig|||asi</w>": 326,
|
| 843 |
+
"ha|||ri</w>": 327,
|
| 844 |
+
"sa|||m": 328,
|
| 845 |
+
"sam|||a</w>": 329,
|
| 846 |
+
"da|||t": 330,
|
| 847 |
+
"dat|||a</w>": 331,
|
| 848 |
+
"men|||u": 332,
|
| 849 |
+
"menu|||n": 333,
|
| 850 |
+
"menun|||j": 334,
|
| 851 |
+
"menunj|||u": 335,
|
| 852 |
+
"menunju|||k": 336,
|
| 853 |
+
"menunjuk|||kan</w>": 337,
|
| 854 |
+
"c|||a": 338,
|
| 855 |
+
"m|||i": 339,
|
| 856 |
+
"mi|||si</w>": 340,
|
| 857 |
+
"as|||si": 341,
|
| 858 |
+
"assi|||g": 342,
|
| 859 |
+
"assig|||n</w>": 343,
|
| 860 |
+
"s|||en": 344,
|
| 861 |
+
"sen|||di": 345,
|
| 862 |
+
"sendi|||ri</w>": 346,
|
| 863 |
+
"k|||a": 347,
|
| 864 |
+
"se|||te": 348,
|
| 865 |
+
"sete|||la": 349,
|
| 866 |
+
"setela|||h</w>": 350,
|
| 867 |
+
"te|||mu": 351,
|
| 868 |
+
"k|||e</w>": 352,
|
| 869 |
+
"sum|||b": 353,
|
| 870 |
+
"sumb|||e": 354,
|
| 871 |
+
"sumbe|||r</w>": 355,
|
| 872 |
+
"predi|||ksi</w>": 356,
|
| 873 |
+
"ter|||s": 357,
|
| 874 |
+
"ters|||ang": 358,
|
| 875 |
+
"tersang|||ka</w>": 359,
|
| 876 |
+
"pen|||al": 360,
|
| 877 |
+
"penal|||ar": 361,
|
| 878 |
+
"penalar|||an</w>": 362,
|
| 879 |
+
"men|||ja": 363,
|
| 880 |
+
"menja|||di</w>": 364,
|
| 881 |
+
"ku|||n": 365,
|
| 882 |
+
"kun|||c": 366,
|
| 883 |
+
"kunc|||i</w>": 367,
|
| 884 |
+
"ha|||si": 368,
|
| 885 |
+
"hasi|||l</w>": 369,
|
| 886 |
+
"in|||f": 370,
|
| 887 |
+
"inf|||or": 371,
|
| 888 |
+
"infor|||masi</w>": 372,
|
| 889 |
+
"anomal|||i</w>": 373,
|
| 890 |
+
"y|||a": 374,
|
| 891 |
+
"temu|||an</w>": 375,
|
| 892 |
+
"ber|||k": 376,
|
| 893 |
+
"berk|||or": 377,
|
| 894 |
+
"berkor|||ela": 378,
|
| 895 |
+
"berkorela|||si</w>": 379,
|
| 896 |
+
"ca|||ta": 380,
|
| 897 |
+
"cata|||t": 381,
|
| 898 |
+
"catat|||an</w>": 382,
|
| 899 |
+
"si|||a": 383,
|
| 900 |
+
"sia|||pa</w>": 384,
|
| 901 |
+
"m|||encu": 385,
|
| 902 |
+
"mencu|||ri</w>": 386,
|
| 903 |
+
"ter|||se": 387,
|
| 904 |
+
"terse|||di": 388,
|
| 905 |
+
"tersedi|||a</w>": 389,
|
| 906 |
+
"me|||m": 390,
|
| 907 |
+
"mem|||ili": 391,
|
| 908 |
+
"memili|||k": 392,
|
| 909 |
+
"memilik|||i</w>": 393,
|
| 910 |
+
"kon|||e": 394,
|
| 911 |
+
"kone|||ksi</w>": 395,
|
| 912 |
+
"c|||on": 396,
|
| 913 |
+
"con|||fi": 397,
|
| 914 |
+
"confi|||d": 398,
|
| 915 |
+
"confid|||ence</w>": 399,
|
| 916 |
+
"meng|||in": 400,
|
| 917 |
+
"mengin|||di": 401,
|
| 918 |
+
"mengindi|||ka": 402,
|
| 919 |
+
"mengindika|||si": 403,
|
| 920 |
+
"mengindikasi|||kan</w>": 404,
|
| 921 |
+
"pen|||e": 405,
|
| 922 |
+
"pene|||lu": 406,
|
| 923 |
+
"penelu|||su": 407,
|
| 924 |
+
"penelusu|||r": 408,
|
| 925 |
+
"penelusur|||an</w>": 409,
|
| 926 |
+
"lo|||g": 410,
|
| 927 |
+
"log|||i": 411,
|
| 928 |
+
"logi|||ka</w>": 412,
|
| 929 |
+
"p|||r": 413,
|
| 930 |
+
"pr|||o": 414,
|
| 931 |
+
"pro|||se": 415,
|
| 932 |
+
"prose|||s</w>": 416,
|
| 933 |
+
"de|||d": 417,
|
| 934 |
+
"ded|||u": 418,
|
| 935 |
+
"dedu|||ksi</w>": 419,
|
| 936 |
+
"meng|||kon": 420,
|
| 937 |
+
"mengkon|||fi": 421,
|
| 938 |
+
"mengkonfi|||r": 422,
|
| 939 |
+
"mengkonfir|||masi</w>": 423,
|
| 940 |
+
"comp|||le": 424,
|
| 941 |
+
"comple|||tion</w>": 425,
|
| 942 |
+
"b|||a": 426,
|
| 943 |
+
"ba|||h": 427,
|
| 944 |
+
"bah|||w": 428,
|
| 945 |
+
"bahw|||a</w>": 429,
|
| 946 |
+
"e|||v": 430,
|
| 947 |
+
"ev|||al": 431,
|
| 948 |
+
"eval|||u": 432,
|
| 949 |
+
"evalu|||asi</w>": 433,
|
| 950 |
+
"ke|||per": 434,
|
| 951 |
+
"keper|||ca": 435,
|
| 952 |
+
"keperca|||ya": 436,
|
| 953 |
+
"kepercaya|||an</w>": 437,
|
| 954 |
+
"ber|||ta": 438,
|
| 955 |
+
"berta|||ha": 439,
|
| 956 |
+
"bertaha|||p</w>": 440,
|
| 957 |
+
"in|||si": 441,
|
| 958 |
+
"insi|||d": 442,
|
| 959 |
+
"insid|||e</w>": 443,
|
| 960 |
+
"j|||o": 444,
|
| 961 |
+
"jo|||b</w>": 445
|
| 962 |
+
},
|
| 963 |
+
"is_trained": true
|
| 964 |
+
}
|
training_config.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_name": "aam-diffusion-v1.0",
|
| 3 |
+
"aam_mind_source": "rsvs_graph",
|
| 4 |
+
"aam_body_type": "specialized_diffusion",
|
| 5 |
+
"architecture": {
|
| 6 |
+
"type": "diffusion_transformer",
|
| 7 |
+
"d_model": 64,
|
| 8 |
+
"n_layers": 2,
|
| 9 |
+
"n_heads": 4,
|
| 10 |
+
"d_ff": 128,
|
| 11 |
+
"vocab_size": 500,
|
| 12 |
+
"max_seq_len": 32,
|
| 13 |
+
"pos_encoding_type": "learned"
|
| 14 |
+
},
|
| 15 |
+
"diffusion": {
|
| 16 |
+
"n_timesteps": 50,
|
| 17 |
+
"n_inference_steps": 5,
|
| 18 |
+
"schedule_type": "cosine",
|
| 19 |
+
"prediction_type": "epsilon",
|
| 20 |
+
"sampling_method": "ddim"
|
| 21 |
+
},
|
| 22 |
+
"graph_encoder": {
|
| 23 |
+
"d_graph": 32,
|
| 24 |
+
"n_graph_layers": 1,
|
| 25 |
+
"conditioning_method": "cross_attention"
|
| 26 |
+
},
|
| 27 |
+
"parameters": 311670
|
| 28 |
+
}
|