| """OpenThoughts-1.2M Advanced Processor with Multi-Source Integration"""
|
|
|
| import json
|
| import logging
|
| from dataclasses import dataclass, field
|
| from pathlib import Path
|
| from typing import Any, Dict, List, Optional, Tuple, Union, Iterator
|
|
|
| import numpy as np
|
| from datasets import Dataset, DatasetDict, load_dataset, concatenate_datasets
|
| from torch.utils.data import DataLoader, IterableDataset
|
|
|
| from .quality_filter import QualityFilter, QualityMetrics
|
| from .curriculum_sampler import CurriculumSampler
|
| from .data_augmentation import DataAugmenter
|
| from .preprocessing import (
|
| preprocess_conversation,
|
| extract_thoughts,
|
| format_for_training,
|
| detect_domain,
|
| estimate_difficulty,
|
| )
|
| from .utils import compute_length_statistics
|
|
|
| logger = logging.getLogger(__name__)
|
|
|
|
|
@dataclass
class OpenThoughtsConfig:
    """Configuration for OpenThoughts processing."""

    # --- Dataset loading ---
    dataset_name: str = "open-thoughts/OpenThoughts3-1.2M"
    split: str = "train"
    cache_dir: str = "./data/cache"
    # Stream the dataset lazily instead of downloading it fully.
    streaming: bool = False
    # NOTE(review): not consulted anywhere in this file — verify intended use.
    max_samples: Optional[int] = None

    # --- Content selection and filter thresholds ---
    include_thoughts: bool = True
    include_reasoning: bool = True
    include_conversations: bool = True
    # Rows whose "quality_score" falls below this are filtered out.
    min_quality_score: float = 0.7
    # Minimum character length of the "thoughts" field (0 disables the filter).
    min_thoughts_length: int = 100

    # Quality filter used to score each processed sample.
    quality_filter: QualityFilter = field(default_factory=QualityFilter)

    # --- Curriculum sampling ---
    # NOTE(review): the processor checks `curriculum_sampler` directly and
    # never reads `use_curriculum` — confirm which flag is authoritative.
    use_curriculum: bool = True
    curriculum_sampler: Optional[CurriculumSampler] = None

    # --- Augmentation ---
    use_augmentation: bool = False
    augmenter: Optional[DataAugmenter] = None
    # Per-sample probability of applying augmentation.
    augmentation_prob: float = 0.1

    # --- Tokenization ---
    tokenizer: Optional[Any] = None
    max_seq_length: int = 8192

    # --- Extra datasets mixed into the main one; weights align positionally ---
    custom_datasets: List[str] = field(default_factory=list)
    dataset_weights: List[float] = field(default_factory=list)
|
|
|
|
|
class OpenThoughtsProcessor:
    """Advanced processor for OpenThoughts-1.2M with quality filtering, curriculum, and augmentation."""

    def __init__(self, config: OpenThoughtsConfig):
        """Wire up the sub-components declared on *config*."""
        self.config = config
        self.quality_filter = config.quality_filter
        # Augmentation is active only when explicitly enabled on the config.
        self.augmenter = config.augmenter if config.use_augmentation else None
        self.curriculum_sampler = config.curriculum_sampler
        self.tokenizer = config.tokenizer

        # Bookkeeping filled in as data flows through the pipeline.
        self.stats: Dict[str, Any] = {}
        self._processed_count = 0
        self._filtered_count = 0

    def load_dataset(self) -> Dataset:
        """Load OpenThoughts dataset with optional custom datasets.

        Returns:
            The filtered (and possibly mixed) dataset.

        Raises:
            Exception: re-raises any loading failure after logging it.
        """
        logger.info(f"Loading OpenThoughts dataset: {self.config.dataset_name}")

        try:
            ds = load_dataset(
                self.config.dataset_name,
                split=self.config.split,
                cache_dir=self.config.cache_dir,
                streaming=self.config.streaming,
            )
            logger.info(f"Loaded {len(ds) if not self.config.streaming else 'streaming'} samples")

            # Honor the configured sample cap.  The field existed on the
            # config but was previously never applied.
            if self.config.max_samples is not None:
                if self.config.streaming:
                    ds = ds.take(self.config.max_samples)
                else:
                    ds = ds.select(range(min(self.config.max_samples, len(ds))))

            ds = self._apply_initial_filters(ds)

            if self.config.custom_datasets:
                ds = self._mix_datasets(ds)

            self._compute_statistics(ds)
            return ds

        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
            raise

    def _apply_initial_filters(self, ds: Dataset) -> Dataset:
        """Apply quality and content filters.

        Streaming datasets have no length and may not expose column names,
        so size logging and the schema check degrade gracefully for them
        (previously ``len(ds)`` crashed in streaming mode).

        Raises:
            ValueError: if a required column is missing from a known schema.
        """
        logger.info("Applying initial filters...")
        streaming = self.config.streaming
        # Streaming datasets can report column_names as None.
        column_names = ds.column_names or []

        # Schema sanity check, skipped when the schema is unavailable.
        for col in ("conversations",):
            if column_names and col not in column_names:
                raise ValueError(f"Missing required column: {col}")

        if "quality_score" in column_names:
            initial_size = None if streaming else len(ds)
            ds = ds.filter(lambda x: x.get("quality_score", 1.0) >= self.config.min_quality_score)
            if initial_size is not None:
                logger.info(f"Quality filter: {initial_size} -> {len(ds)}")

        if self.config.include_thoughts and self.config.min_thoughts_length > 0:
            initial_size = None if streaming else len(ds)
            # `or ""` guards against an explicit None in the thoughts column.
            ds = ds.filter(lambda x: len(x.get("thoughts") or "") >= self.config.min_thoughts_length)
            if initial_size is not None:
                logger.info(f"Thoughts length filter: {initial_size} -> {len(ds)}")

        # Streaming size is unknown; 0 acts as a "not counted" sentinel.
        self._filtered_count = 0 if streaming else len(ds)
        return ds

    def _mix_datasets(self, main_ds: Dataset) -> Dataset:
        """Mix main dataset with custom datasets according to weights.

        Weights are read positionally from ``config.dataset_weights`` without
        mutating it (the previous implementation ``pop(0)``-ed the list, so a
        second ``load_dataset`` call saw corrupted weights).  Note that
        ``concatenate_datasets`` performs no weighted sampling; the normalized
        weights are logged so a downstream sampler can use them instead of
        being silently discarded.
        """
        logger.info(f"Mixing with custom datasets: {self.config.custom_datasets}")

        datasets = [main_ds]
        weights = [1.0]
        configured_weights = self.config.dataset_weights

        for idx, custom_path in enumerate(self.config.custom_datasets):
            try:
                custom_ds = load_dataset(custom_path, split=self.config.split, cache_dir=self.config.cache_dir)
            except Exception as e:
                # Best-effort: a broken custom dataset is skipped, not fatal.
                logger.warning(f"Failed to load custom dataset {custom_path}: {e}")
                continue
            datasets.append(custom_ds)
            weights.append(configured_weights[idx] if idx < len(configured_weights) else 1.0)
            logger.info(f"Loaded custom dataset: {custom_path} ({len(custom_ds)} samples)")

        if len(datasets) == 1:
            return main_ds

        total_weight = sum(weights)
        normalized = [w / total_weight for w in weights]
        # TODO: switch to interleave_datasets(probabilities=normalized) to
        # actually honor the weights; concatenation ignores them.
        logger.info(f"Mixture weights (not applied by concatenation): {normalized}")

        mixed = concatenate_datasets(datasets)
        logger.info(f"Mixed dataset size: {len(mixed)}")
        return mixed

    def _compute_statistics(self, ds: Dataset):
        """Compute length statistics over a sample of the dataset."""
        logger.info("Computing dataset statistics...")

        if self.config.streaming:
            samples = list(ds.take(1000))
        else:
            samples = ds.select(range(min(1000, len(ds))))

        lengths = []
        for sample in samples:
            conv = sample.get("conversations", "")
            if isinstance(conv, str):
                lengths.append(len(conv))
            elif isinstance(conv, list):
                # Assumes each message is a dict with a "content" key —
                # TODO confirm against the upstream conversation schema.
                lengths.append(sum(len(msg.get("content", "")) for msg in conv))

        length_stats = compute_length_statistics(lengths)

        self.stats = {
            "total_samples": len(ds) if not self.config.streaming else "streaming",
            "avg_length": length_stats["mean"],
            "length_std": length_stats["std"],
            "min_length": length_stats["min"],
            "max_length": length_stats["max"],
            "p90_length": length_stats["p90"],
            "p99_length": length_stats["p99"],
        }

        logger.info(f"Dataset stats: {self.stats}")

    def preprocess_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocess a single sample with all transformations.

        Applies conversation preprocessing, thought extraction, domain and
        difficulty tagging, quality scoring, optional stochastic augmentation,
        and optional tokenization.  Returns the processed sample dict.
        """
        self._processed_count += 1

        conversations = sample.get("conversations", [])
        if isinstance(conversations, str):
            # Some shards serialize the conversation list as a JSON string.
            conversations = json.loads(conversations)

        processed = preprocess_conversation(
            conversations,
            include_thoughts=self.config.include_thoughts,
            include_reasoning=self.config.include_reasoning,
        )

        if "thoughts" in sample:
            processed["thoughts"] = extract_thoughts(sample["thoughts"])

        # Derive domain/difficulty only when the preprocessor did not set them.
        if "domain" not in processed:
            processed["domain"] = detect_domain(conversations)
        if "difficulty" not in processed:
            processed["difficulty"] = estimate_difficulty(conversations, sample.get("thoughts", ""))

        processed["quality_metrics"] = self.quality_filter.compute_metrics(processed)

        # Stochastic augmentation with probability `augmentation_prob`.
        if self.augmenter and np.random.random() < self.config.augmentation_prob:
            augmented = self.augmenter.augment(processed)
            if augmented:
                processed.update(augmented)

        if self.tokenizer:
            processed.update(self._tokenize_sample(processed))

        return processed

    def _tokenize_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Tokenize a processed sample into input_ids/attention_mask/labels.

        Labels mirror input_ids with padding positions set to -100 so they
        are ignored by the loss.  Returns {} when no tokenizer is configured.
        """
        if not self.tokenizer:
            return {}

        text = format_for_training(sample, include_thoughts=self.config.include_thoughts)

        tokenized = self.tokenizer(
            text,
            truncation=True,
            max_length=self.config.max_seq_length,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = tokenized["input_ids"][0]
        attention_mask = tokenized["attention_mask"][0]
        labels = input_ids.clone()
        # Guard: tokenizers without a pad token report pad_token_id=None,
        # and comparing a tensor against None raises.
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100

        # NOTE(review): masking of the thoughts span in the labels is not
        # implemented; all non-padding tokens currently contribute to the loss.

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def create_dataloader(
        self,
        ds: Dataset,
        batch_size: int = 32,
        shuffle: bool = True,
        num_workers: int = 4,
        curriculum_epoch: int = 0,
    ) -> DataLoader:
        """Create DataLoader with optional curriculum sampling.

        Non-streaming datasets are fully preprocessed up front via ``map``.
        When a curriculum sampler is configured and shuffling was requested,
        the sampler replaces shuffling (DataLoader forbids both at once).
        """
        if not self.config.streaming:
            logger.info("Preprocessing entire dataset...")
            # NOTE: with num_proc > 1 the map runs in worker processes, so
            # self._processed_count in this process is not updated by it.
            ds = ds.map(
                self.preprocess_sample,
                batched=False,
                # datasets rejects num_proc=0; fall back to single-process.
                num_proc=num_workers if num_workers > 0 else None,
                remove_columns=ds.column_names,
                desc="Preprocessing",
            )

        if self.curriculum_sampler and shuffle:
            sampler = self.curriculum_sampler.get_sampler(ds, epoch=curriculum_epoch)
            shuffle = False
        else:
            sampler = None

        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
        )
|
|
|
|
|
class OpenThoughtsDataset(IterableDataset):
    """Streaming dataset for large-scale OpenThoughts training."""

    def __init__(
        self,
        processor: "OpenThoughtsProcessor",
        infinite: bool = False,
    ):
        """Store the processor; with infinite=True, iteration restarts forever."""
        self.processor = processor
        self.infinite = infinite
        # Kept for interface compatibility; reserved for future buffering.
        # (The old `if not self._buffer` guard in __iter__ was always true
        # because nothing ever filled the buffer, so it has been dropped.)
        self._buffer = []

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Yield preprocessed samples, reloading the dataset each epoch."""
        while True:
            ds = self.processor.load_dataset()
            for sample in ds:
                yield self.processor.preprocess_sample(sample)

            if not self.infinite:
                break

    def __len__(self) -> int:
        """Return the number of samples that survived initial filtering.

        Raises:
            TypeError: for streaming datasets, whose size is unknown.
                (The previous implementation returned ``float("inf")``, which
                made ``len()`` raise an opaque TypeError anyway; this makes
                the failure explicit and keeps the declared int return type.)
        """
        if self.processor.config.streaming:
            raise TypeError("length is undefined for a streaming OpenThoughtsDataset")
        return self.processor._filtered_count
|
|
|
|
|
def create_openthoughts_pipeline(
    config: OpenThoughtsConfig,
    tokenizer: Optional[Any] = None,
) -> Tuple[OpenThoughtsProcessor, Dataset]:
    """Factory function to create complete OpenThoughts pipeline.

    Args:
        config: Processing configuration.  Its ``tokenizer`` field is updated
            only when *tokenizer* is supplied — the previous version
            unconditionally assigned it, clobbering a tokenizer the caller
            had already set on the config with ``None``.
        tokenizer: Optional tokenizer to install on the config.

    Returns:
        A ``(processor, dataset)`` tuple with the dataset already loaded.
    """
    if tokenizer is not None:
        config.tokenizer = tokenizer

    processor = OpenThoughtsProcessor(config)
    dataset = processor.load_dataset()

    return processor, dataset
|
|
|