"""OpenThoughts-1.2M Advanced Processor with Multi-Source Integration"""

import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, Iterator

import numpy as np
from datasets import Dataset, DatasetDict, load_dataset, concatenate_datasets
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

from .quality_filter import QualityFilter, QualityMetrics
from .curriculum_sampler import CurriculumSampler
from .data_augmentation import DataAugmenter
from .preprocessing import (
    preprocess_conversation,
    extract_thoughts,
    format_for_training,
    detect_domain,
    estimate_difficulty,
)
from .utils import compute_length_statistics

logger = logging.getLogger(__name__)


@dataclass
class OpenThoughtsConfig:
    """Configuration for OpenThoughts processing."""

    # Source dataset (Hub id or path passed to datasets.load_dataset) and load options.
    dataset_name: str = "open-thoughts/OpenThoughts3-1.2M"
    split: str = "train"
    cache_dir: str = "./data/cache"
    streaming: bool = False
    max_samples: Optional[int] = None  # NOTE(review): not read by any code in this file

    # Processing options
    include_thoughts: bool = True
    include_reasoning: bool = True
    include_conversations: bool = True  # NOTE(review): not read by any code in this file
    min_quality_score: float = 0.7  # rows whose "quality_score" is below this are dropped
    min_thoughts_length: int = 100  # minimum len() of the "thoughts" field; 0 disables the filter

    # Filtering
    quality_filter: QualityFilter = field(default_factory=QualityFilter)

    # Curriculum: the sampler is only used when use_curriculum is True.
    use_curriculum: bool = True
    curriculum_sampler: Optional[CurriculumSampler] = None

    # Augmentation: the augmenter is only used when use_augmentation is True.
    use_augmentation: bool = False
    augmenter: Optional[DataAugmenter] = None
    augmentation_prob: float = 0.1  # per-sample probability of applying the augmenter

    # Tokenization
    tokenizer: Optional[Any] = None
    max_seq_length: int = 8192

    # Multi-dataset mixing: dataset_weights[i] pairs with custom_datasets[i];
    # missing entries default to 1.0. The processor never mutates these lists.
    custom_datasets: List[str] = field(default_factory=list)
    dataset_weights: List[float] = field(default_factory=list)


class OpenThoughtsProcessor:
    """Advanced processor for OpenThoughts-1.2M with quality filtering, curriculum, and augmentation."""

    def __init__(self, config: OpenThoughtsConfig):
        self.config = config
        self.quality_filter = config.quality_filter
        self.augmenter = config.augmenter if config.use_augmentation else None
        # Honour the use_curriculum switch the same way use_augmentation is honoured.
        self.curriculum_sampler = config.curriculum_sampler if config.use_curriculum else None
        self.tokenizer = config.tokenizer

        # Statistics
        self.stats: Dict[str, Any] = {}
        self._processed_count = 0
        self._filtered_count = 0

    def load_dataset(self) -> Dataset:
        """Load the main dataset, filter it, optionally mix in custom datasets, and compute stats.

        Returns:
            A ``datasets.Dataset`` or, when ``config.streaming`` is True, the
            streaming dataset returned by ``datasets.load_dataset``.

        Raises:
            Any exception raised while loading or filtering; it is logged and re-raised.
        """
        logger.info(f"Loading OpenThoughts dataset: {self.config.dataset_name}")

        try:
            # Load main dataset
            ds = load_dataset(
                self.config.dataset_name,
                split=self.config.split,
                cache_dir=self.config.cache_dir,
                streaming=self.config.streaming,
            )
            logger.info(f"Loaded {len(ds) if not self.config.streaming else 'streaming'} samples")

            # Apply initial filtering
            ds = self._apply_initial_filters(ds)

            # Load and mix custom datasets if specified
            if self.config.custom_datasets:
                ds = self._mix_datasets(ds)
                if not self.config.streaming:
                    # Keep the reported size in sync with what iteration actually yields.
                    self._filtered_count = len(ds)

            # Compute statistics
            self._compute_statistics(ds)

            return ds

        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
            raise

    def _passes_quality(self, row: Dict[str, Any]) -> bool:
        """Return True if the row's quality_score meets the minimum (missing/None counts as 1.0)."""
        score = row.get("quality_score")
        return (1.0 if score is None else score) >= self.config.min_quality_score

    def _passes_thoughts_length(self, row: Dict[str, Any]) -> bool:
        """Return True if the row's thoughts are long enough; rows without the field pass."""
        if "thoughts" not in row:
            return True
        return len(row["thoughts"] or "") >= self.config.min_thoughts_length

    def _apply_initial_filters(self, ds: Dataset) -> Dataset:
        """Apply quality and content filters.

        Works for both in-memory and streaming datasets: sizes are only
        measured and logged when the dataset has a length.

        Raises:
            ValueError: if the (known) column set lacks ``conversations``.
        """
        logger.info("Applying initial filters...")
        streaming = self.config.streaming
        # Streaming datasets may report column_names as None when features are unknown.
        columns = ds.column_names

        # Check required columns
        required_cols = ["conversations"]
        if columns is not None:
            for col in required_cols:
                if col not in columns:
                    raise ValueError(f"Missing required column: {col}")

        # Filter by quality score if available
        if columns is None or "quality_score" in columns:
            initial_size = None if streaming else len(ds)
            ds = ds.filter(self._passes_quality)
            if not streaming:
                logger.info(f"Quality filter: {initial_size} -> {len(ds)}")

        # Filter by thoughts length if required
        if self.config.include_thoughts and self.config.min_thoughts_length > 0:
            if columns is not None and "thoughts" not in columns:
                # Without this guard every row would have len("") == 0 and be dropped.
                logger.warning("Thoughts length filter skipped: dataset has no 'thoughts' column")
            else:
                initial_size = None if streaming else len(ds)
                ds = ds.filter(self._passes_thoughts_length)
                if not streaming:
                    logger.info(f"Thoughts length filter: {initial_size} -> {len(ds)}")

        if not streaming:
            self._filtered_count = len(ds)
        return ds

    def _mix_datasets(self, main_ds: Dataset) -> Dataset:
        """Concatenate the main dataset with each loadable custom dataset.

        Custom datasets are loaded with the same split, cache dir and streaming
        mode as the main one; ones that fail to load are skipped with a warning.
        Per-dataset weights are normalised and logged but NOT applied: the
        datasets are concatenated, not weighted-sampled.
        """
        logger.info(f"Mixing with custom datasets: {self.config.custom_datasets}")
        streaming = self.config.streaming
        configured_weights = self.config.dataset_weights

        datasets = [main_ds]
        weights = [1.0]  # Main dataset weight

        for idx, custom_path in enumerate(self.config.custom_datasets):
            try:
                custom_ds = load_dataset(
                    custom_path,
                    split=self.config.split,
                    cache_dir=self.config.cache_dir,
                    streaming=streaming,
                )
            except Exception as e:
                logger.warning(f"Failed to load custom dataset {custom_path}: {e}")
                continue
            datasets.append(custom_ds)
            # Index by position (no pop) so the config is untouched and a failed
            # load does not shift the remaining weights onto the wrong dataset.
            weights.append(configured_weights[idx] if idx < len(configured_weights) else 1.0)
            size = "streaming" if streaming else len(custom_ds)
            logger.info(f"Loaded custom dataset: {custom_path} ({size} samples)")

        if len(datasets) == 1:
            return main_ds

        # Normalize weights
        total_weight = sum(weights)
        if total_weight > 0:
            weights = [w / total_weight for w in weights]
        # TODO: apply these weights (e.g. weighted interleaving); for now they are informational.
        logger.info(f"Normalized mixing weights (not applied, datasets are concatenated): {weights}")

        mixed = concatenate_datasets(datasets)
        if not streaming:
            logger.info(f"Mixed dataset size: {len(mixed)}")
        return mixed

    @staticmethod
    def _conversation_length(conv: Any) -> Optional[int]:
        """Return the character length of a conversation, or None if its type is unsupported.

        Strings count directly; lists sum the text of dict messages, read from
        ``content`` or, failing that, ``value`` (ShareGPT-style ``from``/``value``
        messages — assumed possible for some sources; confirm the schema).
        """
        if isinstance(conv, str):
            return len(conv)
        if isinstance(conv, list):
            total = 0
            for msg in conv:
                if not isinstance(msg, dict):
                    continue
                text = msg.get("content")
                if text is None:
                    text = msg.get("value", "")
                total += len(text or "")
            return total
        return None

    def _compute_statistics(self, ds: Dataset):
        """Compute conversation-length statistics on up to 1000 leading samples into ``self.stats``."""
        logger.info("Computing dataset statistics...")
        streaming = self.config.streaming

        # Sample for analysis
        if streaming:
            samples = list(ds.take(1000))
        else:
            samples = ds.select(range(min(1000, len(ds))))

        lengths = []
        for sample in samples:
            length = self._conversation_length(sample.get("conversations", ""))
            if length is not None:
                lengths.append(length)

        if lengths:
            length_stats = compute_length_statistics(lengths)
        else:
            # Nothing to measure (e.g. everything was filtered out); report None stats.
            logger.warning("No samples available for length statistics")
            length_stats = dict.fromkeys(("mean", "std", "min", "max", "p90", "p99"))

        self.stats = {
            "total_samples": len(ds) if not streaming else "streaming",
            "avg_length": length_stats["mean"],
            "length_std": length_stats["std"],
            "min_length": length_stats["min"],
            "max_length": length_stats["max"],
            "p90_length": length_stats["p90"],
            "p99_length": length_stats["p99"],
        }

        logger.info(f"Dataset stats: {self.stats}")

    def preprocess_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocess a single raw sample with all configured transformations.

        Args:
            sample: One raw dataset row. ``conversations`` may be a list or a
                JSON-encoded string; ``thoughts`` is optional.

        Returns:
            The dict produced by ``preprocess_conversation`` extended with
            ``thoughts`` (if present in ``sample``), ``domain``, ``difficulty``,
            ``quality_metrics``, any augmentation fields and, when a tokenizer
            is configured, ``input_ids``/``attention_mask``/``labels``.
        """
        # NOTE(review): when run via Dataset.map with num_proc > 1 this counter
        # is incremented in worker processes, not in this object.
        self._processed_count += 1

        # Extract conversations (may be stored as a JSON string)
        conversations = sample.get("conversations", [])
        if isinstance(conversations, str):
            conversations = json.loads(conversations)

        # Preprocess conversation
        processed = preprocess_conversation(
            conversations,
            include_thoughts=self.config.include_thoughts,
            include_reasoning=self.config.include_reasoning,
        )

        # Extract thoughts if available
        if "thoughts" in sample:
            processed["thoughts"] = extract_thoughts(sample["thoughts"])

        # Detect domain
        if "domain" not in processed:
            processed["domain"] = detect_domain(conversations)

        # Estimate difficulty (None thoughts are passed as an empty string)
        if "difficulty" not in processed:
            processed["difficulty"] = estimate_difficulty(conversations, sample.get("thoughts") or "")

        # Quality metrics
        processed["quality_metrics"] = self.quality_filter.compute_metrics(processed)

        # Apply augmentation if enabled
        if self.augmenter and np.random.random() < self.config.augmentation_prob:
            augmented = self.augmenter.augment(processed)
            if augmented:
                processed.update(augmented)

        # Tokenize if tokenizer provided
        if self.tokenizer:
            processed.update(self._tokenize_sample(processed))

        return processed

    def _tokenize_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Tokenize ``sample`` into fixed-length training tensors.

        The formatted text is truncated and padded to ``config.max_seq_length``.
        ``labels`` is a copy of ``input_ids`` with padding positions set to -100;
        no shift is applied here (presumably the model shifts labels internally,
        as Hugging Face causal-LM models do — confirm for the model in use).

        Returns:
            ``{"input_ids", "attention_mask", "labels"}`` as 1-D tensors, or ``{}``
            when no tokenizer is configured.
        """
        if not self.tokenizer:
            return {}

        text = format_for_training(sample, include_thoughts=self.config.include_thoughts)

        tokenized = self.tokenizer(
            text,
            truncation=True,
            max_length=self.config.max_seq_length,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = tokenized["input_ids"][0]
        attention_mask = tokenized["attention_mask"][0]

        labels = input_ids.clone()
        # Mask padding by attention_mask rather than by pad_token_id: the latter
        # also masks genuine EOS tokens when the tokenizer reuses EOS as padding.
        labels[attention_mask == 0] = -100
        # TODO: optionally mask non-target spans (e.g. thoughts) for multi-task learning.

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def create_dataloader(
        self,
        ds: Dataset,
        batch_size: int = 32,
        shuffle: bool = True,
        num_workers: int = 4,
        curriculum_epoch: int = 0,
    ) -> DataLoader:
        """Create a DataLoader over preprocessed samples, with optional curriculum sampling.

        In-memory datasets are preprocessed eagerly with ``Dataset.map``.
        Streaming datasets are preprocessed lazily, shuffled through a buffer
        when ``shuffle`` is True, and never use the curriculum sampler, because
        DataLoader accepts neither ``shuffle`` nor ``sampler`` for iterable datasets.
        """
        sampler = None

        if not self.config.streaming:
            logger.info("Preprocessing entire dataset...")
            ds = ds.map(
                self.preprocess_sample,
                batched=False,
                # datasets rejects num_proc <= 0; None means single-process.
                num_proc=num_workers if num_workers > 0 else None,
                remove_columns=ds.column_names,
                desc="Preprocessing",
            )

            # Apply curriculum sampling if configured
            if self.curriculum_sampler and shuffle:
                sampler = self.curriculum_sampler.get_sampler(ds, epoch=curriculum_epoch)
                shuffle = False  # Sampler handles shuffling

            if self.tokenizer:
                # Return token columns as tensors so default collation stacks them
                # into (batch, seq) instead of transposing Python lists.
                ds = ds.with_format(
                    "torch",
                    columns=["input_ids", "attention_mask", "labels"],
                    output_all_columns=True,
                )
        else:
            ds = ds.map(self.preprocess_sample, remove_columns=ds.column_names)
            if shuffle:
                ds = ds.shuffle(buffer_size=1000)
            if self.curriculum_sampler:
                logger.warning("Curriculum sampling is not supported for streaming datasets; ignoring sampler")
            shuffle = False

        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
        )


class OpenThoughtsDataset(IterableDataset):
    """Streaming dataset for large-scale OpenThoughts training."""

    def __init__(
        self,
        processor: OpenThoughtsProcessor,
        infinite: bool = False,
    ):
        self.processor = processor
        self.infinite = infinite
        # In-memory datasets are loaded once and reused across passes.
        self._dataset: Optional[Dataset] = None

    def _source(self):
        """Return the raw dataset for one pass, split across DataLoader workers."""
        if self.processor.config.streaming:
            # NOTE(review): assumes the datasets streaming IterableDataset splits its
            # shards across torch workers by itself (recent datasets versions do) — confirm.
            return self.processor.load_dataset()

        if self._dataset is None:
            self._dataset = self.processor.load_dataset()
        ds = self._dataset

        # Each DataLoader worker holds its own copy of this object; without sharding
        # every worker would yield the full dataset and samples would be duplicated.
        worker = get_worker_info()
        if worker is not None and worker.num_workers > 1:
            ds = ds.shard(num_shards=worker.num_workers, index=worker.id, contiguous=True)
        return ds

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Yield preprocessed samples; repeat passes forever when ``infinite`` is True."""
        while True:
            produced = False
            for sample in self._source():
                produced = True
                yield self.processor.preprocess_sample(sample)

            if not self.infinite:
                break
            if not produced:
                # An empty pass would otherwise make infinite mode spin forever.
                logger.warning("OpenThoughtsDataset produced no samples; stopping iteration")
                break

    def __len__(self) -> int:
        """Return the number of samples after filtering and mixing.

        Raises:
            TypeError: in streaming mode, where the length is unknown.
        """
        if self.processor.config.streaming:
            raise TypeError("OpenThoughtsDataset has no length in streaming mode")
        return self.processor._filtered_count


def create_openthoughts_pipeline(
    config: OpenThoughtsConfig,
    tokenizer: Optional[Any] = None,
) -> Tuple[OpenThoughtsProcessor, Dataset]:
    """Build a processor from ``config`` and load its dataset.

    Args:
        config: Processing configuration.
        tokenizer: Optional tokenizer. When given, it replaces ``config.tokenizer``;
            when None, any tokenizer already set on ``config`` is kept.

    Returns:
        ``(processor, dataset)``.
    """
    if tokenizer is not None:
        config.tokenizer = tokenizer
    processor = OpenThoughtsProcessor(config)
    dataset = processor.load_dataset()
    return processor, dataset