# Wolfvin — AAM Diffusion LLM v1.0 — The Body of Aphantasic Abstraction Model
# Commit 2d7e335 (verified)
"""
AAM Diffusion LLM — Dataset
Dataset class for Graph→Narrative training pairs.
Each training example consists of:
- Graph conditioning: evidence nodes, compositions, confidence,
anomalies, reasoning chains, temporal context
- Target narrative: natural language text that represents
the graph data in sentence form
The dataset handles:
- Loading from JSONL files
- Tokenization of both graph data and narratives
- Padding and batching
- Data augmentation (sentence shuffling, noise injection)
Analogi: Seperti Jin Soun berlatih mengungkapkan kesimpulan —
dia diberi "kasus" (graph data) dan "jawaban yang benar"
(narrative target), lalu berlatih sampai bisa menyusun
kalimat yang tepat dari graph.
"""
from __future__ import annotations
import json
import logging
import random
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import torch
from torch.utils.data import Dataset
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
# Module-level logger; used below for dataset-loading info and warnings.
logger = logging.getLogger(__name__)
@dataclass
class GraphNarrativeExample:
    """A single training example: graph conditioning + target narrative.

    This represents the "input" and "expected output" for one
    training step of the diffusion model.
    """

    # Target narrative (what the model should generate)
    narrative: str = ""
    # Graph conditioning inputs
    trigger: str = ""
    evidence_nodes: list[str] = field(default_factory=list)
    compositions: list[str] = field(default_factory=list)
    confidence_map: dict[str, float] = field(default_factory=dict)
    anomalies: list[str] = field(default_factory=list)
    reasoning_steps: list[str] = field(default_factory=list)
    source_trust: float = 1.0
    temporal_context: list[str] = field(default_factory=list)
    # Metadata
    language: str = "id"
    source: str = "synthetic"

    def to_dict(self) -> dict:
        """Serialize to a plain dictionary.

        List and dict fields are shallow-copied, so mutating the returned
        dictionary does not modify this example.

        Returns:
            A dict with one entry per field, suitable for ``json.dumps``
            when the field values themselves are JSON-serializable.
        """
        return {
            "narrative": self.narrative,
            "trigger": self.trigger,
            "evidence_nodes": list(self.evidence_nodes),
            "compositions": list(self.compositions),
            "confidence_map": dict(self.confidence_map),
            "anomalies": list(self.anomalies),
            "reasoning_steps": list(self.reasoning_steps),
            "source_trust": self.source_trust,
            "temporal_context": list(self.temporal_context),
            "language": self.language,
            "source": self.source,
        }

    @classmethod
    def from_dict(cls, data: dict) -> GraphNarrativeExample:
        """Deserialize from a dictionary (e.g. one parsed JSONL record).

        Missing keys AND keys whose value is ``None`` (JSON ``null``) fall
        back to the field default, so downstream code never sees ``None``
        where a list/dict/str is expected. A legitimate falsy value such
        as ``source_trust=0.0`` is preserved. Containers are copied so the
        example does not alias the caller's data.

        Args:
            data: Mapping with any subset of the field names as keys.

        Returns:
            A new ``GraphNarrativeExample``.
        """

        def pick(key: str, default):
            # Only None is treated as "absent"; 0.0 / "" / [] are kept.
            value = data.get(key)
            return default if value is None else value

        return cls(
            narrative=pick("narrative", ""),
            trigger=pick("trigger", ""),
            evidence_nodes=list(pick("evidence_nodes", [])),
            compositions=list(pick("compositions", [])),
            confidence_map=dict(pick("confidence_map", {})),
            anomalies=list(pick("anomalies", [])),
            reasoning_steps=list(pick("reasoning_steps", [])),
            source_trust=pick("source_trust", 1.0),
            temporal_context=list(pick("temporal_context", [])),
            language=pick("language", "id"),
            source=pick("source", "synthetic"),
        )
@dataclass
class BatchOutput:
    """Tensors for one batch, already padded to uniform length.

    Only ``token_ids`` is required; every conditioning tensor is optional
    and defaults to ``None``.

    Attributes:
        token_ids: Target narrative token IDs, shape (batch, seq_len).
        evidence_ids: Evidence node token IDs,
            shape (batch, n_evidence, ev_seq_len).
        evidence_confidence: Evidence confidence, shape (batch, n_evidence).
        anomaly_ids: Anomaly token IDs, shape (batch, n_anomalies, an_seq_len).
        anomaly_confidence: Anomaly confidence, shape (batch, n_anomalies).
        reasoning_ids: Reasoning step token IDs,
            shape (batch, n_steps, r_seq_len).
        reasoning_confidence: Reasoning confidence, shape (batch, n_steps).
        source_trust: Source trust scores, shape (batch,).
    """

    # (batch, seq_len)
    token_ids: torch.Tensor
    # (batch, n_evidence, ev_seq_len) and (batch, n_evidence)
    evidence_ids: Optional[torch.Tensor] = None
    evidence_confidence: Optional[torch.Tensor] = None
    # (batch, n_anomalies, an_seq_len) and (batch, n_anomalies)
    anomaly_ids: Optional[torch.Tensor] = None
    anomaly_confidence: Optional[torch.Tensor] = None
    # (batch, n_steps, r_seq_len) and (batch, n_steps)
    reasoning_ids: Optional[torch.Tensor] = None
    reasoning_confidence: Optional[torch.Tensor] = None
    # (batch,)
    source_trust: Optional[torch.Tensor] = None
class GraphNarrativeDataset(Dataset):
    """Dataset for Graph→Narrative training pairs.

    Loads training examples from a JSONL file and returns tokenized,
    fixed-size tensors per example.

    Args:
        data_path: Path to JSONL file with training data.
        tokenizer: AamTokenizer instance for encoding.
        max_seq_len: Maximum sequence length for narratives.
        max_evidence: Maximum number of evidence nodes.
        max_anomalies: Maximum number of anomalies.
        max_reasoning: Maximum number of reasoning steps.
        augment: Whether to apply data augmentation in ``__getitem__``.
    """

    # Confidence assigned to each real (non-padding) anomaly slot.
    _ANOMALY_CONFIDENCE = 0.6
    # Confidence assigned to each real (non-padding) reasoning-step slot.
    _REASONING_CONFIDENCE = 0.7

    def __init__(
        self,
        data_path: str | Path,
        tokenizer: AamTokenizer,
        max_seq_len: int = 512,
        max_evidence: int = 50,
        max_anomalies: int = 10,
        max_reasoning: int = 15,
        augment: bool = True,
    ):
        self.data_path = Path(data_path)
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.max_evidence = max_evidence
        self.max_anomalies = max_anomalies
        self.max_reasoning = max_reasoning
        self.augment = augment
        # Load data
        self.examples: list[GraphNarrativeExample] = []
        self._load_data()
        logger.info(
            "GraphNarrativeDataset: %d examples loaded from %s",
            len(self.examples),
            self.data_path,
        )

    def _load_data(self) -> None:
        """Load examples from the JSONL file at ``self.data_path``.

        A missing file logs a warning and leaves the dataset empty. Blank
        lines are ignored; lines that are not valid JSON or whose JSON value
        is not an object are logged and skipped; examples with an empty
        narrative are skipped silently.
        """
        if not self.data_path.exists():
            logger.warning("Data file not found: %s", self.data_path)
            return
        with open(self.data_path, "r", encoding="utf-8") as f:
            for line_num, raw_line in enumerate(f, 1):
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON at line %d", line_num)
                    continue
                if not isinstance(data, dict):
                    # e.g. a bare list/string/number; from_dict needs a mapping.
                    logger.warning("Expected a JSON object at line %d", line_num)
                    continue
                example = GraphNarrativeExample.from_dict(data)
                if example.narrative:  # Skip empty narratives
                    self.examples.append(example)

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Get a single tokenized training example.

        Args:
            idx: Index into ``self.examples``.

        Returns:
            Dictionary with keys ``token_ids``, ``evidence_ids``,
            ``evidence_confidence``, ``anomaly_ids``, ``anomaly_confidence``,
            ``reasoning_ids``, ``reasoning_confidence`` and ``source_trust``.
            Confidence vectors have one entry per slot (``max_evidence``,
            ``max_anomalies``, ``max_reasoning``); padding slots get 0.0.
        """
        example = self.examples[idx]
        # Data augmentation (works on a copy; the stored example is untouched)
        if self.augment:
            example = self._augment(example)

        # Tokenize narrative (target)
        narrative_ids = self.tokenizer.encode(example.narrative, add_special=True)
        narrative_ids = self.tokenizer.pad_sequence(narrative_ids, self.max_seq_len)
        narrative_tensor = torch.tensor(narrative_ids, dtype=torch.long)

        # Tokenize graph conditioning lists
        evidence_data = self._tokenize_node_list(
            example.evidence_nodes, max_nodes=self.max_evidence
        )
        anomaly_data = self._tokenize_node_list(
            example.anomalies, max_nodes=self.max_anomalies
        )
        reasoning_data = self._tokenize_node_list(
            example.reasoning_steps, max_nodes=self.max_reasoning
        )

        # Source trust
        source_trust = torch.tensor(example.source_trust, dtype=torch.float32)

        # Evidence confidence: the i-th value pairs with the i-th evidence
        # slot. Values beyond the real evidence nodes would land on padding
        # slots, so they are dropped and padding stays at 0.0.
        n_evidence = min(len(example.evidence_nodes or ()), self.max_evidence)
        conf_values = list((example.confidence_map or {}).values())[:n_evidence]
        evidence_conf = torch.zeros(self.max_evidence, dtype=torch.float32)
        if conf_values:
            evidence_conf[: len(conf_values)] = torch.tensor(
                conf_values, dtype=torch.float32
            )

        # Anomaly / reasoning confidence: fixed default for real slots only.
        anomaly_conf = self._slot_confidence(
            len(example.anomalies or ()), self.max_anomalies, self._ANOMALY_CONFIDENCE
        )
        reasoning_conf = self._slot_confidence(
            len(example.reasoning_steps or ()),
            self.max_reasoning,
            self._REASONING_CONFIDENCE,
        )

        return {
            "token_ids": narrative_tensor,
            "evidence_ids": evidence_data["ids"],
            "evidence_confidence": evidence_conf,
            "anomaly_ids": anomaly_data["ids"],
            "anomaly_confidence": anomaly_conf,
            "reasoning_ids": reasoning_data["ids"],
            "reasoning_confidence": reasoning_conf,
            "source_trust": source_trust,
        }

    @staticmethod
    def _slot_confidence(n_items: int, max_slots: int, value: float) -> torch.Tensor:
        """Return a float32 ``(max_slots,)`` vector: ``value`` for the first
        ``min(n_items, max_slots)`` slots and 0.0 for the padding slots."""
        conf = torch.zeros(max_slots, dtype=torch.float32)
        conf[: min(n_items, max_slots)] = value
        return conf

    def _tokenize_node_list(
        self,
        nodes: list[str],
        max_nodes: int,
        max_node_len: int = 32,
    ) -> dict[str, torch.Tensor]:
        """Tokenize a list of node descriptions.

        Args:
            nodes: List of node text descriptions (empty/None allowed).
            max_nodes: Maximum number of nodes to encode; extra nodes are
                dropped and missing ones are filled with all-zero rows.
            max_node_len: Token length each node is padded to via
                ``tokenizer.pad_sequence``.

        Returns:
            Dictionary with key ``"ids"``: a long tensor with ``max_nodes``
            rows of ``max_node_len`` token IDs.
        """
        if not nodes:
            return {
                "ids": torch.zeros(max_nodes, max_node_len, dtype=torch.long),
            }
        # Limit to max_nodes
        nodes = nodes[:max_nodes]
        all_ids = []
        for node in nodes:
            ids = self.tokenizer.encode(node, add_special=False)
            ids = self.tokenizer.pad_sequence(ids, max_node_len)
            all_ids.append(ids)
        # Pad to max_nodes with all-zero rows (same fill as the empty case)
        while len(all_ids) < max_nodes:
            all_ids.append([0] * max_node_len)
        return {
            "ids": torch.tensor(all_ids, dtype=torch.long),
        }

    def _augment(self, example: GraphNarrativeExample) -> GraphNarrativeExample:
        """Apply data augmentation to a deep copy of ``example``.

        Augmentation strategies:
        1. Sentence shuffling (20% chance): keep the first sentence,
           shuffle the rest (only when there are more than 2 sentences).
        2. Evidence dropping: each evidence node is dropped with 10%
           probability (simulates incomplete data). The confidence entry at
           the same position is dropped too, so positional pairing of
           evidence and confidence in ``__getitem__`` stays aligned.
        3. Confidence perturbation: Gaussian noise (sigma 0.05), clamped
           to [0, 1].

        Args:
            example: Original training example (not modified).

        Returns:
            Augmented copy.
        """
        import copy

        augmented = copy.deepcopy(example)

        # 1. Sentence shuffling (with 20% probability)
        if random.random() < 0.2:
            sentences = self.tokenizer._split_sentences(augmented.narrative)
            if len(sentences) > 2:
                # Keep first sentence, shuffle the rest
                first, *rest = sentences
                random.shuffle(rest)
                augmented.narrative = first + " " + " ".join(rest)

        # 2. Evidence dropping (with 10% probability per node)
        if augmented.evidence_nodes:
            keep = [random.random() > 0.1 for _ in augmented.evidence_nodes]
            augmented.evidence_nodes = [
                node for node, kept in zip(augmented.evidence_nodes, keep) if kept
            ]
            if augmented.confidence_map:
                # Drop the same positions; entries past the node count have
                # no corresponding node and are left alone.
                augmented.confidence_map = {
                    key: value
                    for i, (key, value) in enumerate(augmented.confidence_map.items())
                    if i >= len(keep) or keep[i]
                }

        # 3. Confidence perturbation
        if augmented.confidence_map:
            perturbed = {}
            for k, v in augmented.confidence_map.items():
                noise = random.gauss(0, 0.05)
                perturbed[k] = max(0.0, min(1.0, v + noise))
            augmented.confidence_map = perturbed

        return augmented
def _pad_to_common_shape(tensors: list[torch.Tensor]) -> list[torch.Tensor]:
    """Zero-pad same-rank tensors at the end of each dim to a common shape.

    The input list is returned unchanged when the tensors are scalars,
    already share one shape, or differ in rank (``torch.stack`` then
    reports the mismatch to the caller).
    """
    reference = tensors[0]
    if reference.dim() == 0 or all(t.shape == reference.shape for t in tensors):
        return tensors
    ndim = reference.dim()
    if any(t.dim() != ndim for t in tensors):
        return tensors
    target = [max(t.shape[d] for t in tensors) for d in range(ndim)]
    padded = []
    for t in tensors:
        # F.pad expects (before, after) pairs starting from the LAST dim.
        pad: list[int] = []
        for d in reversed(range(ndim)):
            pad.extend((0, target[d] - t.shape[d]))
        padded.append(torch.nn.functional.pad(t, pad))
    return padded


def collate_fn(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    """Custom collate function for DataLoader.

    For every key of the first example, gathers that tensor from each
    example and stacks them along a new leading batch dimension. If all
    tensors for a key already share a shape they are stacked as-is;
    same-rank tensors of different sizes are first zero-padded at the end
    of each dimension to the largest size in the batch.

    Args:
        batch: List of example dictionaries (all with the same keys).

    Returns:
        Batched dictionary of tensors; ``{}`` for an empty batch.

    Raises:
        KeyError: If a later example lacks a key present in the first.
        RuntimeError: From ``torch.stack`` if tensor ranks differ.
    """
    if not batch:
        return {}
    return {
        key: torch.stack(_pad_to_common_shape([item[key] for item in batch]))
        for key in batch[0]
    }