| """ |
| AAM Diffusion LLM — Dataset |
| |
| Dataset class for Graph→Narrative training pairs. |
| |
| Each training example consists of: |
| - Graph conditioning: evidence nodes, compositions, confidence, |
| anomalies, reasoning chains, temporal context |
| - Target narrative: natural language text that represents |
| the graph data in sentence form |
| |
| The dataset handles: |
| - Loading from JSONL files |
| - Tokenization of both graph data and narratives |
| - Padding and batching |
| - Data augmentation (sentence shuffling, noise injection) |
| |
| Analogi: Seperti Jin Soun berlatih mengungkapkan kesimpulan — |
| dia diberi "kasus" (graph data) dan "jawaban yang benar" |
| (narrative target), lalu berlatih sampai bisa menyusun |
| kalimat yang tepat dari graph. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import logging |
| import random |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Optional |
|
|
| import torch |
| from torch.utils.data import Dataset |
|
|
| from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
@dataclass
class GraphNarrativeExample:
    """A single training example: graph conditioning plus target narrative.

    Represents the "input" (graph conditioning fields) and the "expected
    output" (``narrative``) for one training step of the diffusion model.
    """

    # Target narrative text the model should generate from the graph.
    narrative: str = ""

    # Graph conditioning fields.
    trigger: str = ""
    evidence_nodes: list[str] = field(default_factory=list)
    compositions: list[str] = field(default_factory=list)
    confidence_map: dict[str, float] = field(default_factory=dict)
    anomalies: list[str] = field(default_factory=list)
    reasoning_steps: list[str] = field(default_factory=list)
    source_trust: float = 1.0
    temporal_context: list[str] = field(default_factory=list)

    # Metadata.
    language: str = "id"
    source: str = "synthetic"

    def to_dict(self) -> dict:
        """Serialize to a dictionary with one key per field, in field order.

        Lists and dicts are copied, so mutating the returned dictionary
        does not modify this example.
        """
        from dataclasses import asdict

        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> GraphNarrativeExample:
        """Deserialize from a dictionary (e.g. one parsed JSONL record).

        A key that is missing OR explicitly ``None`` (JSON ``null``) falls
        back to the field's default. Falsy-but-present values such as
        ``0.0`` or ``""`` are kept as-is.
        """

        def _get(key: str, default):
            value = data.get(key)
            return default if value is None else value

        return cls(
            narrative=_get("narrative", ""),
            trigger=_get("trigger", ""),
            evidence_nodes=_get("evidence_nodes", []),
            compositions=_get("compositions", []),
            confidence_map=_get("confidence_map", {}),
            anomalies=_get("anomalies", []),
            reasoning_steps=_get("reasoning_steps", []),
            source_trust=_get("source_trust", 1.0),
            temporal_context=_get("temporal_context", []),
            language=_get("language", "id"),
            source=_get("source", "synthetic"),
        )
|
|
|
|
@dataclass
class BatchOutput:
    """Container for one collated batch.

    Every tensor is expected to be padded to a uniform length already.
    Only ``token_ids`` is required; each graph-conditioning tensor is
    optional and defaults to ``None``.
    """

    # Target narrative token IDs, shape (batch, seq_len).
    token_ids: torch.Tensor

    # Evidence node token IDs, shape (batch, n_evidence, ev_seq_len).
    evidence_ids: Optional[torch.Tensor] = None
    # Evidence confidence, shape (batch, n_evidence).
    evidence_confidence: Optional[torch.Tensor] = None

    # Anomaly token IDs, shape (batch, n_anomalies, an_seq_len).
    anomaly_ids: Optional[torch.Tensor] = None
    # Anomaly confidence, shape (batch, n_anomalies).
    anomaly_confidence: Optional[torch.Tensor] = None

    # Reasoning step token IDs, shape (batch, n_steps, r_seq_len).
    reasoning_ids: Optional[torch.Tensor] = None
    # Reasoning confidence, shape (batch, n_steps).
    reasoning_confidence: Optional[torch.Tensor] = None

    # Source trust scores, shape (batch,).
    source_trust: Optional[torch.Tensor] = None
|
|
|
|
class GraphNarrativeDataset(Dataset):
    """Dataset for Graph→Narrative training pairs.

    Loads training examples from a JSONL file (one JSON object per line)
    and returns tokenized, fixed-size tensors per example, so items can be
    stacked directly into batches.

    Args:
        data_path: Path to JSONL file with training data.
        tokenizer: AamTokenizer instance for encoding.
        max_seq_len: Length narratives are padded to via
            ``tokenizer.pad_sequence``.
        max_evidence: Maximum number of evidence nodes (and evidence
            confidence values) kept per example.
        max_anomalies: Maximum number of anomalies kept per example.
        max_reasoning: Maximum number of reasoning steps kept per example.
        augment: Whether ``__getitem__`` applies random data augmentation.
    """

    # Confidence given to every real (non-padding) anomaly slot; the JSONL
    # records carry no per-anomaly confidence.
    DEFAULT_ANOMALY_CONFIDENCE = 0.6
    # Confidence given to every real (non-padding) reasoning-step slot.
    DEFAULT_REASONING_CONFIDENCE = 0.7

    def __init__(
        self,
        data_path: str | Path,
        tokenizer: AamTokenizer,
        max_seq_len: int = 512,
        max_evidence: int = 50,
        max_anomalies: int = 10,
        max_reasoning: int = 15,
        augment: bool = True,
    ):
        self.data_path = Path(data_path)
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.max_evidence = max_evidence
        self.max_anomalies = max_anomalies
        self.max_reasoning = max_reasoning
        self.augment = augment

        # Populated eagerly; the whole file is held in memory.
        self.examples: list[GraphNarrativeExample] = []
        self._load_data()

        logger.info(
            "GraphNarrativeDataset: %d examples loaded from %s",
            len(self.examples),
            self.data_path,
        )

    def _load_data(self) -> None:
        """Append every usable example in ``self.data_path`` to ``self.examples``.

        The following are skipped rather than aborting the load:
        - a missing file (a warning is logged and nothing is loaded);
        - blank lines;
        - lines that are not valid JSON;
        - lines whose JSON value is not an object;
        - records that cannot be converted into an example;
        - examples with an empty narrative.
        """
        if not self.data_path.exists():
            logger.warning("Data file not found: %s", self.data_path)
            return

        skipped_empty = 0
        with open(self.data_path, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON at line %d", line_num)
                    continue
                # Valid JSON that is not an object (array, string, number)
                # would otherwise raise AttributeError inside from_dict and
                # abort the whole load.
                if not isinstance(data, dict):
                    logger.warning(
                        "Expected a JSON object at line %d, got %s",
                        line_num,
                        type(data).__name__,
                    )
                    continue
                try:
                    example = GraphNarrativeExample.from_dict(data)
                except (TypeError, ValueError) as exc:
                    logger.warning("Malformed example at line %d: %s", line_num, exc)
                    continue
                if example.narrative:
                    self.examples.append(example)
                else:
                    skipped_empty += 1

        if skipped_empty:
            logger.info("Skipped %d examples with an empty narrative", skipped_empty)

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return the tokenized tensors for example ``idx``.

        When ``self.augment`` is true, the tensors are built from an
        augmented copy (see ``_augment``). The stored example is never
        modified.

        Returns:
            Dictionary with keys:
            - ``token_ids``: narrative token IDs, padded to ``max_seq_len``.
            - ``evidence_ids``, ``anomaly_ids``, ``reasoning_ids``: node
              token IDs from ``_tokenize_node_list``, shapes
              ``(max_evidence, 32)``, ``(max_anomalies, 32)`` and
              ``(max_reasoning, 32)``.
            - ``evidence_confidence``, ``anomaly_confidence``,
              ``reasoning_confidence``: float32 vectors of length
              ``max_evidence``, ``max_anomalies`` and ``max_reasoning``.
            - ``source_trust``: 0-dim float32 tensor.

            Padding slots in every confidence vector are 0.0.
        """
        example = self.examples[idx]
        if self.augment:
            example = self._augment(example)

        narrative_ids = self.tokenizer.encode(example.narrative, add_special=True)
        narrative_ids = self.tokenizer.pad_sequence(narrative_ids, self.max_seq_len)
        narrative_tensor = torch.tensor(narrative_ids, dtype=torch.long)

        evidence_data = self._tokenize_node_list(
            example.evidence_nodes, max_nodes=self.max_evidence
        )
        anomaly_data = self._tokenize_node_list(
            example.anomalies, max_nodes=self.max_anomalies
        )
        reasoning_data = self._tokenize_node_list(
            example.reasoning_steps, max_nodes=self.max_reasoning
        )

        trust = 1.0 if example.source_trust is None else example.source_trust
        source_trust = torch.tensor(trust, dtype=torch.float32)

        # NOTE(review): values are taken in confidence_map insertion order,
        # not looked up per evidence node. Slot i therefore matches
        # evidence_nodes[i] only if the record lists both in the same order.
        # _augment's node dropping also breaks that alignment. Presumably the
        # keys are evidence node strings; confirm the intended schema.
        conf_values = list((example.confidence_map or {}).values())[: self.max_evidence]
        evidence_conf = torch.zeros(self.max_evidence, dtype=torch.float32)
        if conf_values:
            evidence_conf[: len(conf_values)] = torch.tensor(
                conf_values, dtype=torch.float32
            )

        # Only real anomaly / reasoning slots get the default confidence.
        # Padding slots stay 0.0, like evidence_confidence, so they are not
        # presented as confident entries alongside all-zero token IDs.
        anomaly_conf = self._fill_confidence(
            len(example.anomalies or ()),
            self.max_anomalies,
            self.DEFAULT_ANOMALY_CONFIDENCE,
        )
        reasoning_conf = self._fill_confidence(
            len(example.reasoning_steps or ()),
            self.max_reasoning,
            self.DEFAULT_REASONING_CONFIDENCE,
        )

        return {
            "token_ids": narrative_tensor,
            "evidence_ids": evidence_data["ids"],
            "evidence_confidence": evidence_conf,
            "anomaly_ids": anomaly_data["ids"],
            "anomaly_confidence": anomaly_conf,
            "reasoning_ids": reasoning_data["ids"],
            "reasoning_confidence": reasoning_conf,
            "source_trust": source_trust,
        }

    @staticmethod
    def _fill_confidence(n_real: int, n_slots: int, value: float) -> torch.Tensor:
        """Return a float32 vector of length ``n_slots``.

        The first ``min(n_real, n_slots)`` entries are ``value``; the rest
        are 0.0.
        """
        conf = torch.zeros(n_slots, dtype=torch.float32)
        conf[:n_real] = value
        return conf

    def _tokenize_node_list(
        self,
        nodes: list[str],
        max_nodes: int,
        max_node_len: int = 32,
    ) -> dict[str, torch.Tensor]:
        """Tokenize a list of node descriptions into a fixed-size ID matrix.

        Args:
            nodes: List of node text descriptions. ``None`` or empty gives
                an all-zero matrix.
            max_nodes: Number of node rows in the output. Extra nodes are
                dropped.
            max_node_len: Length each node's IDs are padded to via
                ``tokenizer.pad_sequence``.

        Returns:
            ``{"ids": LongTensor}`` of shape ``(max_nodes, max_node_len)``.
            Unused node rows are all zeros.
        """
        # Truncate before the emptiness check so max_nodes == 0 yields the
        # same (0, max_node_len) shape as an empty node list.
        kept = nodes[:max_nodes] if nodes else []
        if not kept:
            return {
                "ids": torch.zeros(max_nodes, max_node_len, dtype=torch.long),
            }

        rows = []
        for node in kept:
            ids = self.tokenizer.encode(node, add_special=False)
            rows.append(self.tokenizer.pad_sequence(ids, max_node_len))

        # Fill the remaining node slots with all-zero rows.
        # NOTE(review): this hard-codes 0 as the padding ID, which may differ
        # from the ID the tokenizer's pad_sequence pads with. Confirm.
        rows.extend([0] * max_node_len for _ in range(max_nodes - len(rows)))

        return {
            "ids": torch.tensor(rows, dtype=torch.long),
        }

    def _augment(self, example: GraphNarrativeExample) -> GraphNarrativeExample:
        """Return a randomly augmented deep copy of ``example``.

        The original example is never mutated. Each strategy draws from the
        module-level ``random`` generator:

        1. With probability 0.2, if the narrative splits into more than two
           sentences (via the tokenizer's ``_split_sentences``), keep the
           first sentence in place and shuffle the rest, re-joined with
           single spaces.
        2. Drop each evidence node independently with probability ~0.1,
           simulating incomplete graph data.
        3. Add Gaussian noise (sigma 0.05) to every confidence value and
           clamp the result to [0, 1].

        Args:
            example: Original training example.

        Returns:
            Augmented copy.
        """
        import copy

        augmented = copy.deepcopy(example)

        # 1. Sentence shuffling; the opening sentence stays first.
        if random.random() < 0.2:
            sentences = self.tokenizer._split_sentences(augmented.narrative)
            if len(sentences) > 2:
                first = sentences[0]
                rest = sentences[1:]
                random.shuffle(rest)
                augmented.narrative = first + " " + " ".join(rest)

        # 2. Evidence dropping.
        # NOTE(review): confidence_map entries are not dropped together with
        # their nodes, so positional alignment with evidence_confidence may
        # break. Confirm whether that is intended.
        if augmented.evidence_nodes:
            augmented.evidence_nodes = [
                node for node in augmented.evidence_nodes
                if random.random() > 0.1
            ]

        # 3. Confidence perturbation, clamped to [0, 1].
        if augmented.confidence_map:
            augmented.confidence_map = {
                key: max(0.0, min(1.0, value + random.gauss(0, 0.05)))
                for key, value in augmented.confidence_map.items()
            }

        return augmented
|
|
|
|
def _pad_to_shape(tensor: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
    """Zero-pad ``tensor`` at the end of each dimension up to ``shape``."""
    if tuple(tensor.shape) == tuple(shape):
        return tensor
    # new_zeros keeps the source tensor's dtype and device.
    padded = tensor.new_zeros(shape)
    padded[tuple(slice(0, size) for size in tensor.shape)] = tensor
    return padded


def collate_fn(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    """Custom collate function for DataLoader.

    Stacks each key's tensors along a new leading batch dimension. If the
    tensors for a key differ in size, each is first zero-padded at the end
    of every dimension to the largest size in the batch. Tensors that
    already share a shape are stacked unchanged.

    Args:
        batch: List of example dictionaries; all share the keys of
            ``batch[0]``.

    Returns:
        Batched dictionary of tensors.

    Raises:
        ValueError: If ``batch`` is empty, or if the tensors for one key
            have different numbers of dimensions.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    result: dict[str, torch.Tensor] = {}
    for key in batch[0]:
        tensors = [item[key] for item in batch]
        ndim = tensors[0].dim()
        if any(t.dim() != ndim for t in tensors):
            raise ValueError(
                f"collate_fn: tensors for key {key!r} have mismatched ranks"
            )
        if ndim > 0:
            # Per-dimension maximum across the batch.
            target = tuple(max(sizes) for sizes in zip(*(t.shape for t in tensors)))
            tensors = [_pad_to_shape(t, target) for t in tensors]
        result[key] = torch.stack(tensors)

    return result
|
|