File size: 12,811 Bytes
8d18b7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
"""OpenThoughts-1.2M Advanced Processor with Multi-Source Integration"""

import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, Iterator

import numpy as np
from datasets import Dataset, DatasetDict, load_dataset, concatenate_datasets
from torch.utils.data import DataLoader, IterableDataset

from .quality_filter import QualityFilter, QualityMetrics
from .curriculum_sampler import CurriculumSampler
from .data_augmentation import DataAugmenter
from .preprocessing import (
    preprocess_conversation,
    extract_thoughts,
    format_for_training,
    detect_domain,
    estimate_difficulty,
)
from .utils import compute_length_statistics

logger = logging.getLogger(__name__)


@dataclass
class OpenThoughtsConfig:
    """Configuration for OpenThoughts processing.

    Groups the dataset-loading, filtering, curriculum, augmentation,
    tokenization, and multi-dataset mixing options consumed by
    OpenThoughtsProcessor.
    """
    dataset_name: str = "open-thoughts/OpenThoughts3-1.2M"  # HF Hub dataset id
    split: str = "train"  # split passed straight to datasets.load_dataset
    cache_dir: str = "./data/cache"  # local HF datasets cache location
    streaming: bool = False  # stream instead of downloading the full dataset
    max_samples: Optional[int] = None  # cap on sample count; None = no cap
    # NOTE(review): max_samples is not referenced by the processor below — confirm intended use.

    # Processing options
    include_thoughts: bool = True  # keep chain-of-thought text in processed samples
    include_reasoning: bool = True  # forwarded to preprocess_conversation
    include_conversations: bool = True
    min_quality_score: float = 0.7  # rows with "quality_score" below this are dropped
    min_thoughts_length: int = 100  # minimum char length of "thoughts" (0 disables the filter)

    # Filtering
    quality_filter: QualityFilter = field(default_factory=QualityFilter)  # fresh instance per config

    # Curriculum
    use_curriculum: bool = True
    curriculum_sampler: Optional[CurriculumSampler] = None  # None disables curriculum sampling

    # Augmentation
    use_augmentation: bool = False  # when False, a configured augmenter is ignored
    augmenter: Optional[DataAugmenter] = None
    augmentation_prob: float = 0.1  # per-sample probability of applying augmentation

    # Tokenization
    tokenizer: Optional[Any] = None  # HF-style callable tokenizer; None disables tokenization
    max_seq_length: int = 8192  # truncation/padding length in tokens

    # Multi-dataset mixing
    custom_datasets: List[str] = field(default_factory=list)  # extra dataset paths to mix in
    dataset_weights: List[float] = field(default_factory=list)  # weights paired by position with custom_datasets

class OpenThoughtsProcessor:
    """Advanced processor for OpenThoughts-1.2M with quality filtering, curriculum, and augmentation."""

    def __init__(self, config: OpenThoughtsConfig):
        """Initialize from *config*, unpacking its collaborators for direct access."""
        self.config = config
        self.quality_filter = config.quality_filter
        # Augmentation is opt-in: ignore a configured augmenter unless enabled.
        self.augmenter = config.augmenter if config.use_augmentation else None
        self.curriculum_sampler = config.curriculum_sampler
        self.tokenizer = config.tokenizer

        # Bookkeeping populated by load_dataset() / preprocess_sample().
        self.stats: Dict[str, Any] = {}
        self._processed_count = 0
        self._filtered_count = 0

    def load_dataset(self) -> Dataset:
        """Load the OpenThoughts dataset, filter it, and optionally mix in custom datasets.

        Honors ``config.max_samples`` (previously declared but never applied)
        by truncating after the initial filters. Any loading error is logged
        and re-raised.

        Returns:
            The filtered (and possibly mixed) dataset.
        """
        logger.info(f"Loading OpenThoughts dataset: {self.config.dataset_name}")

        try:
            # Load main dataset
            ds = load_dataset(
                self.config.dataset_name,
                split=self.config.split,
                cache_dir=self.config.cache_dir,
                streaming=self.config.streaming,
            )
            logger.info(f"Loaded {len(ds) if not self.config.streaming else 'streaming'} samples")

            # Apply initial filtering
            ds = self._apply_initial_filters(ds)

            # Fix: apply the configured sample cap (was silently ignored).
            if self.config.max_samples is not None:
                if self.config.streaming:
                    ds = ds.take(self.config.max_samples)
                else:
                    ds = ds.select(range(min(self.config.max_samples, len(ds))))

            # Load and mix custom datasets if specified
            if self.config.custom_datasets:
                ds = self._mix_datasets(ds)

            # Compute statistics
            self._compute_statistics(ds)

            return ds

        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
            raise

    def _apply_initial_filters(self, ds: Dataset) -> Dataset:
        """Apply quality and content filters.

        Safe for both map-style and streaming datasets: streaming datasets do
        not support ``len()`` and may report ``column_names`` as None, so size
        logging and the required-column check are skipped in that case
        (the previous code crashed on both).
        """
        logger.info("Applying initial filters...")
        streaming = self.config.streaming
        column_names = ds.column_names or []

        # Check required columns when the schema is known.
        for col in ("conversations",):
            if column_names and col not in column_names:
                raise ValueError(f"Missing required column: {col}")

        # Filter by quality score if available
        if "quality_score" in column_names:
            initial_size = None if streaming else len(ds)
            ds = ds.filter(lambda x: x.get("quality_score", 1.0) >= self.config.min_quality_score)
            if initial_size is not None:
                logger.info(f"Quality filter: {initial_size} -> {len(ds)}")

        # Filter by thoughts length if required
        if self.config.include_thoughts and self.config.min_thoughts_length > 0:
            initial_size = None if streaming else len(ds)
            ds = ds.filter(lambda x: len(x.get("thoughts", "")) >= self.config.min_thoughts_length)
            if initial_size is not None:
                logger.info(f"Thoughts length filter: {initial_size} -> {len(ds)}")

        # len() is undefined for streaming datasets; record 0 as "unknown".
        self._filtered_count = 0 if streaming else len(ds)
        return ds

    def _mix_datasets(self, main_ds: Dataset) -> Dataset:
        """Mix the main dataset with custom datasets according to weights.

        Weights are read by position without mutating the config: the previous
        ``dataset_weights.pop(0)`` emptied the list so a second
        ``load_dataset()`` call saw no weights, and a failed dataset shifted
        every later weight onto the wrong dataset.
        """
        logger.info(f"Mixing with custom datasets: {self.config.custom_datasets}")

        datasets = [main_ds]
        weights = [1.0]  # Main dataset weight

        for idx, custom_path in enumerate(self.config.custom_datasets):
            try:
                custom_ds = load_dataset(custom_path, split=self.config.split, cache_dir=self.config.cache_dir)
            except Exception as e:
                # Best-effort: a broken custom dataset is skipped, not fatal.
                logger.warning(f"Failed to load custom dataset {custom_path}: {e}")
                continue
            datasets.append(custom_ds)
            weights.append(self.config.dataset_weights[idx] if idx < len(self.config.dataset_weights) else 1.0)
            logger.info(f"Loaded custom dataset: {custom_path} ({len(custom_ds)} samples)")

        if len(datasets) == 1:
            return main_ds

        # Normalize weights. NOTE: currently informational only — the datasets
        # are concatenated and weighted sampling is deferred to a later stage.
        total_weight = sum(weights)
        weights = [w / total_weight for w in weights]

        mixed = concatenate_datasets(datasets)
        logger.info(f"Mixed dataset size: {len(mixed)}")
        return mixed

    def _compute_statistics(self, ds: Dataset) -> None:
        """Compute and log character-length statistics over up to 1000 samples."""
        logger.info("Computing dataset statistics...")

        # Sample for analysis; streaming datasets are capped via take().
        if self.config.streaming:
            samples = list(ds.take(1000))
        else:
            samples = ds.select(range(min(1000, len(ds))))

        # Conversations may be a raw string or a list of {"content": ...} messages.
        lengths = []
        for sample in samples:
            conv = sample.get("conversations", "")
            if isinstance(conv, str):
                lengths.append(len(conv))
            elif isinstance(conv, list):
                lengths.append(sum(len(msg.get("content", "")) for msg in conv))

        if lengths:
            length_stats = compute_length_statistics(lengths)
        else:
            # Empty dataset (or no parsable conversations): report zeros
            # instead of crashing on an empty statistics computation.
            length_stats = {k: 0 for k in ("mean", "std", "min", "max", "p90", "p99")}

        self.stats = {
            "total_samples": "streaming" if self.config.streaming else len(ds),
            "avg_length": length_stats["mean"],
            "length_std": length_stats["std"],
            "min_length": length_stats["min"],
            "max_length": length_stats["max"],
            "p90_length": length_stats["p90"],
            "p99_length": length_stats["p99"],
        }

        logger.info(f"Dataset stats: {self.stats}")

    def preprocess_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocess one sample: parse, enrich with metadata, optionally augment and tokenize."""
        self._processed_count += 1

        # Conversations may arrive JSON-encoded; decode to a message list.
        # (The previous inner ternary re-tested an isinstance already proven true.)
        conversations = sample.get("conversations", [])
        if isinstance(conversations, str):
            conversations = json.loads(conversations)

        # Preprocess conversation
        processed = preprocess_conversation(
            conversations,
            include_thoughts=self.config.include_thoughts,
            include_reasoning=self.config.include_reasoning,
        )

        # Attach extracted chain-of-thought text when present.
        if "thoughts" in sample:
            processed["thoughts"] = extract_thoughts(sample["thoughts"])

        # Fill in domain/difficulty metadata unless already provided upstream.
        if "domain" not in processed:
            processed["domain"] = detect_domain(conversations)
        if "difficulty" not in processed:
            processed["difficulty"] = estimate_difficulty(conversations, sample.get("thoughts", ""))

        # Quality metrics
        processed["quality_metrics"] = self.quality_filter.compute_metrics(processed)

        # Stochastic augmentation (only set up when config.use_augmentation).
        if self.augmenter and np.random.random() < self.config.augmentation_prob:
            augmented = self.augmenter.augment(processed)
            if augmented:
                processed.update(augmented)

        # Tokenize if a tokenizer was provided.
        if self.tokenizer:
            processed.update(self._tokenize_sample(processed))

        return processed

    def _tokenize_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Tokenize a processed sample into input_ids/attention_mask/labels.

        Returns an empty dict when no tokenizer is configured.
        """
        if not self.tokenizer:
            return {}

        text = format_for_training(sample, include_thoughts=self.config.include_thoughts)

        # Tokenize with truncation and fixed-length padding.
        tokenized = self.tokenizer(
            text,
            truncation=True,
            max_length=self.config.max_seq_length,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = tokenized["input_ids"][0]
        attention_mask = tokenized["attention_mask"][0]

        # Labels mirror the inputs; padding positions are excluded from the
        # loss. Guard against tokenizers without a pad token (pad_token_id
        # is None), where the comparison would be meaningless.
        labels = input_ids.clone()
        pad_id = self.tokenizer.pad_token_id
        if pad_id is not None:
            labels[labels == pad_id] = -100

        # Mask out non-target tokens if needed
        if self.config.include_thoughts and "thoughts" in sample:
            # Could mask specific parts for multi-task learning
            pass

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def create_dataloader(
        self,
        ds: Dataset,
        batch_size: int = 32,
        shuffle: bool = True,
        num_workers: int = 4,
        curriculum_epoch: int = 0,
    ) -> DataLoader:
        """Create a DataLoader with optional curriculum sampling.

        For map-style datasets the preprocessing pipeline is applied eagerly
        via ``Dataset.map``; streaming datasets are expected to be
        preprocessed on the fly by the caller.
        """
        if not self.config.streaming:
            logger.info("Preprocessing entire dataset...")
            ds = ds.map(
                self.preprocess_sample,
                batched=False,
                # Dataset.map requires num_proc >= 1 or None; map
                # num_workers=0 to single-process mapping.
                num_proc=num_workers if num_workers > 0 else None,
                remove_columns=ds.column_names,
                desc="Preprocessing",
            )

        # A curriculum sampler replaces plain shuffling (DataLoader forbids
        # combining shuffle=True with an explicit sampler).
        if self.curriculum_sampler and shuffle:
            sampler = self.curriculum_sampler.get_sampler(ds, epoch=curriculum_epoch)
            shuffle = False
        else:
            sampler = None

        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
        )


class OpenThoughtsDataset(IterableDataset):
    """Streaming dataset for large-scale OpenThoughts training.

    Wraps an OpenThoughtsProcessor and yields preprocessed samples. With
    ``infinite=True`` the underlying dataset is reloaded and re-iterated
    indefinitely (useful for step-based training loops).
    """

    def __init__(
        self,
        processor: "OpenThoughtsProcessor",
        infinite: bool = False,
    ):
        self.processor = processor
        self.infinite = infinite
        # Reserved for future sample caching; previously part of a dead
        # "if not self._buffer" check in __iter__ (it was never populated).
        self._buffer: List[Dict[str, Any]] = []

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Yield preprocessed samples, reloading the dataset each pass."""
        while True:
            ds = self.processor.load_dataset()
            for sample in ds:
                yield self.processor.preprocess_sample(sample)
            if not self.infinite:
                break

    def __len__(self) -> int:
        """Return the number of post-filter samples.

        Raises:
            TypeError: for streaming datasets, whose length is unknown.
                (``__len__`` must return an int; the previous
                ``float("inf")`` return made ``len()`` raise an opaque
                TypeError anyway — now the error message explains why.)
        """
        if self.processor.config.streaming:
            raise TypeError("streaming OpenThoughtsDataset has no defined length")
        return self.processor._filtered_count


def create_openthoughts_pipeline(
    config: OpenThoughtsConfig,
    tokenizer: Optional[Any] = None,
) -> Tuple[OpenThoughtsProcessor, Dataset]:
    """Build a ready-to-use OpenThoughts pipeline.

    Attaches *tokenizer* to *config* (mutating the config in place, as
    callers expect), constructs a processor around it, and loads the
    filtered/mixed dataset.

    Returns:
        A ``(processor, dataset)`` pair.
    """
    config.tokenizer = tokenizer
    processor = OpenThoughtsProcessor(config)
    return processor, processor.load_dataset()