""" AAM Diffusion LLM — Data Pipeline Orchestrates data preparation: from raw graph data and narratives to tokenized, batched training data. The pipeline handles: 1. Loading raw graph→narrative pairs 2. Generating synthetic data if real data isn't available 3. Tokenizing all data 4. Creating train/val splits 5. Building DataLoaders Analogi: Seperti proses persiapan sebelum Jin Soun berlatih — mengumpulkan semua kasus, mengorganisirnya, dan menyiapkan data latihan yang terstruktur. """ from __future__ import annotations import logging from pathlib import Path from typing import Optional from torch.utils.data import DataLoader from diffusion_llm.config.model_config import AamDiffusionConfig from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer from diffusion_llm.training.dataset import GraphNarrativeDataset, collate_fn from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator logger = logging.getLogger(__name__) class DataPipeline: """Data preparation pipeline for AAM Diffusion LLM training. Orchestrates the entire data preparation process: 1. Check for existing data 2. Generate synthetic data if needed 3. Train tokenizer on the data 4. Create datasets and dataloaders Usage: pipeline = DataPipeline(config) tokenizer, train_loader, val_loader = pipeline.prepare() """ def __init__(self, config: AamDiffusionConfig): self.config = config self.output_dir = Path(config.output_dir) / "data" self.output_dir.mkdir(parents=True, exist_ok=True) def prepare( self, tokenizer: Optional[AamTokenizer] = None, force_regenerate: bool = False, ) -> tuple[AamTokenizer, DataLoader, Optional[DataLoader]]: """Prepare all data for training. Args: tokenizer: Optional pre-trained tokenizer. force_regenerate: Whether to regenerate synthetic data. Returns: Tuple of (tokenizer, train_loader, val_loader). """ train_path = Path(self.config.training.train_data_path) if self.config.training.train_data_path else None val_path = Path(self.config.training.val_data_path) if self.config.training.val_data_path else None # Step 1: Generate synthetic data if no real data if not train_path or not train_path.exists() or force_regenerate: logger.info("Generating synthetic training data...") train_path, val_path = SyntheticDataGenerator.generate_training_split( output_dir=self.output_dir, n_train=10000, n_val=500, language=self.config.inference.language, seed=self.config.seed, ) # Step 2: Train tokenizer if not provided if tokenizer is None or not tokenizer.is_trained: logger.info("Training tokenizer...") tokenizer = AamTokenizer() # Read training texts for tokenizer training texts = self._read_texts(train_path) tokenizer.train(texts, vocab_size=self.config.tokenizer.bpe_vocab_size) tokenizer.save(self.output_dir / "tokenizer.json") logger.info("Tokenizer trained and saved. Vocab size: %d", tokenizer.vocab_size) # Step 3: Create datasets logger.info("Creating datasets...") train_dataset = GraphNarrativeDataset( data_path=train_path, tokenizer=tokenizer, max_seq_len=self.config.model.max_seq_len, max_evidence=self.config.graph_encoder.max_evidence_nodes, max_anomalies=self.config.graph_encoder.max_anomalies, max_reasoning=self.config.graph_encoder.max_reasoning_steps, ) val_dataset = None if val_path and val_path.exists(): val_dataset = GraphNarrativeDataset( data_path=val_path, tokenizer=tokenizer, max_seq_len=self.config.model.max_seq_len, max_evidence=self.config.graph_encoder.max_evidence_nodes, max_anomalies=self.config.graph_encoder.max_anomalies, max_reasoning=self.config.graph_encoder.max_reasoning_steps, augment=False, # No augmentation for validation ) # Step 4: Create dataloaders train_loader = DataLoader( train_dataset, batch_size=self.config.training.batch_size, shuffle=True, num_workers=self.config.training.num_workers, collate_fn=collate_fn, pin_memory=True, ) val_loader = None if val_dataset: val_loader = DataLoader( val_dataset, batch_size=self.config.training.batch_size, shuffle=False, num_workers=self.config.training.num_workers, collate_fn=collate_fn, pin_memory=True, ) logger.info( "Data pipeline ready: %d training examples, %s validation examples", len(train_dataset), len(val_dataset) if val_dataset else 0, ) return tokenizer, train_loader, val_loader def _read_texts(self, path: Path) -> list[str]: """Read narrative texts from JSONL file for tokenizer training. Args: path: Path to JSONL data file. Returns: List of narrative texts. """ import json texts = [] if not path.exists(): return texts with open(path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue try: data = json.loads(line) # Collect both narratives and evidence for richer tokenizer if data.get("narrative"): texts.append(data["narrative"]) if data.get("trigger"): texts.append(data["trigger"]) for ev in data.get("evidence_nodes", []): texts.append(ev) for anom in data.get("anomalies", []): texts.append(anom) for step in data.get("reasoning_steps", []): texts.append(step) except json.JSONDecodeError: continue return texts