"""
AAM Diffusion LLM — Data Pipeline
Orchestrates data preparation: from raw graph data and narratives
to tokenized, batched training data.
The pipeline handles:
1. Loading raw graph→narrative pairs
2. Generating synthetic data if real data isn't available
3. Tokenizing all data
4. Creating train/val splits
5. Building DataLoaders
Analogy: like the preparation before Jin Soun trains: gathering all the
cases, organizing them, and laying out structured training data.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Optional
from torch.utils.data import DataLoader
from diffusion_llm.config.model_config import AamDiffusionConfig
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
from diffusion_llm.training.dataset import GraphNarrativeDataset, collate_fn
from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator
logger = logging.getLogger(__name__)


class DataPipeline:
"""Data preparation pipeline for AAM Diffusion LLM training.
Orchestrates the entire data preparation process:
1. Check for existing data
2. Generate synthetic data if needed
3. Train tokenizer on the data
4. Create datasets and dataloaders
Usage:
pipeline = DataPipeline(config)
tokenizer, train_loader, val_loader = pipeline.prepare()
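        # Typical consumption of the result (a sketch; batch structure is
        # whatever collate_fn produces for GraphNarrativeDataset):
        for batch in train_loader:
            ...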
"""
def __init__(self, config: AamDiffusionConfig):
self.config = config
self.output_dir = Path(config.output_dir) / "data"
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def prepare(
self,
tokenizer: Optional[AamTokenizer] = None,
force_regenerate: bool = False,
) -> tuple[AamTokenizer, DataLoader, Optional[DataLoader]]:
"""Prepare all data for training.
Args:
tokenizer: Optional pre-trained tokenizer.
force_regenerate: Whether to regenerate synthetic data.
Returns:
Tuple of (tokenizer, train_loader, val_loader).
"""
train_path = Path(self.config.training.train_data_path) if self.config.training.train_data_path else None
val_path = Path(self.config.training.val_data_path) if self.config.training.val_data_path else None
# Step 1: Generate synthetic data if no real data
if not train_path or not train_path.exists() or force_regenerate:
logger.info("Generating synthetic training data...")
train_path, val_path = SyntheticDataGenerator.generate_training_split(
output_dir=self.output_dir,
n_train=10000,
n_val=500,
language=self.config.inference.language,
seed=self.config.seed,
)
# Step 2: Train tokenizer if not provided
if tokenizer is None or not tokenizer.is_trained:
logger.info("Training tokenizer...")
tokenizer = AamTokenizer()
# Read training texts for tokenizer training
texts = self._read_texts(train_path)
tokenizer.train(texts, vocab_size=self.config.tokenizer.bpe_vocab_size)
tokenizer.save(self.output_dir / "tokenizer.json")
logger.info("Tokenizer trained and saved. Vocab size: %d", tokenizer.vocab_size)
# Step 3: Create datasets
logger.info("Creating datasets...")
train_dataset = GraphNarrativeDataset(
data_path=train_path,
tokenizer=tokenizer,
max_seq_len=self.config.model.max_seq_len,
max_evidence=self.config.graph_encoder.max_evidence_nodes,
max_anomalies=self.config.graph_encoder.max_anomalies,
max_reasoning=self.config.graph_encoder.max_reasoning_steps,
)
val_dataset = None
if val_path and val_path.exists():
val_dataset = GraphNarrativeDataset(
data_path=val_path,
tokenizer=tokenizer,
max_seq_len=self.config.model.max_seq_len,
max_evidence=self.config.graph_encoder.max_evidence_nodes,
max_anomalies=self.config.graph_encoder.max_anomalies,
max_reasoning=self.config.graph_encoder.max_reasoning_steps,
augment=False, # No augmentation for validation
)
# Step 4: Create dataloaders
train_loader = DataLoader(
train_dataset,
batch_size=self.config.training.batch_size,
shuffle=True,
num_workers=self.config.training.num_workers,
collate_fn=collate_fn,
pin_memory=True,
)
val_loader = None
if val_dataset:
val_loader = DataLoader(
val_dataset,
batch_size=self.config.training.batch_size,
shuffle=False,
num_workers=self.config.training.num_workers,
collate_fn=collate_fn,
pin_memory=True,
)
logger.info(
"Data pipeline ready: %d training examples, %s validation examples",
len(train_dataset),
len(val_dataset) if val_dataset else 0,
)
        return tokenizer, train_loader, val_loader

    def _read_texts(self, path: Path) -> list[str]:
"""Read narrative texts from JSONL file for tokenizer training.
Args:
path: Path to JSONL data file.
Returns:
List of narrative texts.
"""
import json
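        # Expected JSONL record shape (inferred from the keys read below;
        # every field is optional and unknown extras are ignored):
        #   {"narrative": "...", "trigger": "...",
        #    "evidence_nodes": [...], "anomalies": [...],
        #    "reasoning_steps": [...]}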
texts = []
if not path.exists():
return texts
with open(path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
data = json.loads(line)
# Collect both narratives and evidence for richer tokenizer
if data.get("narrative"):
texts.append(data["narrative"])
if data.get("trigger"):
texts.append(data["trigger"])
for ev in data.get("evidence_nodes", []):
texts.append(ev)
for anom in data.get("anomalies", []):
texts.append(anom)
for step in data.get("reasoning_steps", []):
texts.append(step)
except json.JSONDecodeError:
continue
return texts
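

# Minimal smoke-test sketch, assuming AamDiffusionConfig can be constructed
# with defaults (adjust output_dir / training data paths as needed). When no
# real training data path is configured, prepare() falls back to synthetic data.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    config = AamDiffusionConfig()  # assumption: default construction works
    pipeline = DataPipeline(config)
    tokenizer, train_loader, val_loader = pipeline.prepare()

    # Pull one batch to confirm the loaders are wired up end to end.
    first_batch = next(iter(train_loader))
    logger.info("First training batch type: %s", type(first_batch))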