"""
AAM Diffusion LLM — Data Pipeline

Orchestrates data preparation: from raw graph data and narratives
to tokenized, batched training data.

The pipeline handles:
1. Loading raw graph→narrative pairs
2. Generating synthetic data if real data isn't available
3. Tokenizing all data
4. Creating train/val splits
5. Building DataLoaders

Analogy: like the preparation before Jin Soun trains, gathering
all the cases, organizing them, and laying out structured
training data.
"""

from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Optional

from torch.utils.data import DataLoader

from diffusion_llm.config.model_config import AamDiffusionConfig
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
from diffusion_llm.training.dataset import GraphNarrativeDataset, collate_fn
from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator

logger = logging.getLogger(__name__)


class DataPipeline:
    """Data preparation pipeline for AAM Diffusion LLM training.

    Orchestrates the entire data preparation process:
    1. Check for existing data
    2. Generate synthetic data if needed
    3. Train tokenizer on the data
    4. Create datasets and dataloaders

    Usage:
        pipeline = DataPipeline(config)
        tokenizer, train_loader, val_loader = pipeline.prepare()
    """

    def __init__(self, config: AamDiffusionConfig):
        self.config = config
        self.output_dir = Path(config.output_dir) / "data"
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def prepare(
        self,
        tokenizer: Optional[AamTokenizer] = None,
        force_regenerate: bool = False,
    ) -> tuple[AamTokenizer, DataLoader, Optional[DataLoader]]:
        """Prepare all data for training.

        Args:
            tokenizer: Optional pre-trained tokenizer.
            force_regenerate: Whether to regenerate synthetic data.

        Returns:
            Tuple of (tokenizer, train_loader, val_loader).
        """
        train_path = (
            Path(self.config.training.train_data_path)
            if self.config.training.train_data_path
            else None
        )
        val_path = (
            Path(self.config.training.val_data_path)
            if self.config.training.val_data_path
            else None
        )

        # Step 1: Generate synthetic data if no real data
        if not train_path or not train_path.exists() or force_regenerate:
            logger.info("Generating synthetic training data...")
            train_path, val_path = SyntheticDataGenerator.generate_training_split(
                output_dir=self.output_dir,
                n_train=10000,
                n_val=500,
                language=self.config.inference.language,
                seed=self.config.seed,
            )

        # Step 2: Train tokenizer if not provided
        if tokenizer is None or not tokenizer.is_trained:
            logger.info("Training tokenizer...")
            tokenizer = AamTokenizer()
            # Read training texts for tokenizer training
            texts = self._read_texts(train_path)
            tokenizer.train(texts, vocab_size=self.config.tokenizer.bpe_vocab_size)
            tokenizer.save(self.output_dir / "tokenizer.json")
            logger.info("Tokenizer trained and saved. Vocab size: %d", tokenizer.vocab_size)

        # Step 3: Create datasets
        logger.info("Creating datasets...")
        train_dataset = GraphNarrativeDataset(
            data_path=train_path,
            tokenizer=tokenizer,
            max_seq_len=self.config.model.max_seq_len,
            max_evidence=self.config.graph_encoder.max_evidence_nodes,
            max_anomalies=self.config.graph_encoder.max_anomalies,
            max_reasoning=self.config.graph_encoder.max_reasoning_steps,
        )

        val_dataset = None
        if val_path and val_path.exists():
            val_dataset = GraphNarrativeDataset(
                data_path=val_path,
                tokenizer=tokenizer,
                max_seq_len=self.config.model.max_seq_len,
                max_evidence=self.config.graph_encoder.max_evidence_nodes,
                max_anomalies=self.config.graph_encoder.max_anomalies,
                max_reasoning=self.config.graph_encoder.max_reasoning_steps,
                augment=False,  # No augmentation for validation
            )

        # Step 4: Create dataloaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.training.batch_size,
            shuffle=True,
            num_workers=self.config.training.num_workers,
            collate_fn=collate_fn,
            pin_memory=True,
        )

        val_loader = None
        if val_dataset:
            val_loader = DataLoader(
                val_dataset,
                batch_size=self.config.training.batch_size,
                shuffle=False,
                num_workers=self.config.training.num_workers,
                collate_fn=collate_fn,
                pin_memory=True,
            )

        logger.info(
            "Data pipeline ready: %d training examples, %s validation examples",
            len(train_dataset),
            len(val_dataset) if val_dataset else 0,
        )

        return tokenizer, train_loader, val_loader

    def _read_texts(self, path: Path) -> list[str]:
        """Read narrative texts from JSONL file for tokenizer training.

        Args:
            path: Path to JSONL data file.

        Returns:
            List of narrative texts.
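
        Note:
            The exact record schema is an assumption inferred from the fields
            read below; a hypothetical JSONL line might look like::

                {"trigger": "...", "narrative": "...",
                 "evidence_nodes": ["..."], "anomalies": ["..."],
                 "reasoning_steps": ["..."]}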
        """
        texts: list[str] = []
        if not path.exists():
            return texts

        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    # Collect both narratives and evidence for richer tokenizer
                    if data.get("narrative"):
                        texts.append(data["narrative"])
                    if data.get("trigger"):
                        texts.append(data["trigger"])
                    for ev in data.get("evidence_nodes", []):
                        texts.append(ev)
                    for anom in data.get("anomalies", []):
                        texts.append(anom)
                    for step in data.get("reasoning_steps", []):
                        texts.append(step)
                except json.JSONDecodeError:
                    continue

        return texts
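

# Minimal usage sketch, not part of the documented API: it assumes
# AamDiffusionConfig can be constructed with defaults and that output_dir,
# seed, and the nested training/tokenizer settings then hold sensible values.
# Adjust to this repo's real config entry point before running.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    config = AamDiffusionConfig()  # hypothetical default construction
    pipeline = DataPipeline(config)
    tokenizer, train_loader, val_loader = pipeline.prepare()

    # Peek at one batch to confirm the loaders are wired up
    # (assumes collate_fn returns a dict-like batch).
    batch = next(iter(train_loader))
    logger.info("First training batch keys: %s", list(batch.keys()))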