# Zenith-7b-V1 / data / openthoughts_processor.py
# (Hugging Face upload header: "Zandy-Wandy — Upload Zenith-7B model — commit 8d18b7c verified")
"""OpenThoughts-1.2M Advanced Processor with Multi-Source Integration"""
import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, Iterator
import numpy as np
from datasets import Dataset, DatasetDict, load_dataset, concatenate_datasets
from torch.utils.data import DataLoader, IterableDataset
from .quality_filter import QualityFilter, QualityMetrics
from .curriculum_sampler import CurriculumSampler
from .data_augmentation import DataAugmenter
from .preprocessing import (
preprocess_conversation,
extract_thoughts,
format_for_training,
detect_domain,
estimate_difficulty,
)
from .utils import compute_length_statistics
logger = logging.getLogger(__name__)
@dataclass
class OpenThoughtsConfig:
    """Configuration for OpenThoughts processing.

    Groups every knob used by OpenThoughtsProcessor: dataset source,
    filtering thresholds, curriculum/augmentation collaborators,
    tokenization settings, and multi-dataset mixing.
    """
    # Dataset source: Hugging Face hub id, split, and local cache location.
    dataset_name: str = "open-thoughts/OpenThoughts3-1.2M"
    split: str = "train"
    cache_dir: str = "./data/cache"
    # When True, the dataset is loaded lazily as an IterableDataset.
    streaming: bool = False
    # Optional hard cap on sample count (None = use everything).
    max_samples: Optional[int] = None
    # Processing options: which parts of each sample to keep.
    include_thoughts: bool = True
    include_reasoning: bool = True
    include_conversations: bool = True
    # Samples below this quality score (when present) are dropped.
    min_quality_score: float = 0.7
    # Minimum character length of the "thoughts" field when thoughts are kept.
    min_thoughts_length: int = 100
    # Filtering: per-instance filter (default_factory avoids a shared mutable default).
    quality_filter: QualityFilter = field(default_factory=QualityFilter)
    # Curriculum: optional difficulty-ordered sampling.
    use_curriculum: bool = True
    curriculum_sampler: Optional[CurriculumSampler] = None
    # Augmentation: applied per-sample with probability augmentation_prob.
    use_augmentation: bool = False
    augmenter: Optional[DataAugmenter] = None
    augmentation_prob: float = 0.1
    # Tokenization: when a tokenizer is set, samples are tokenized during preprocessing.
    tokenizer: Optional[Any] = None
    max_seq_length: int = 8192
    # Multi-dataset mixing: extra dataset paths and their sampling weights
    # (weights align positionally with custom_datasets; missing entries default to 1.0).
    custom_datasets: List[str] = field(default_factory=list)
    dataset_weights: List[float] = field(default_factory=list)
class OpenThoughtsProcessor:
    """Advanced processor for OpenThoughts-1.2M with quality filtering, curriculum, and augmentation."""

    def __init__(self, config: OpenThoughtsConfig):
        """Store the config and pull out the collaborators it configures."""
        self.config = config
        self.quality_filter = config.quality_filter
        # The augmenter is only consulted when augmentation is enabled.
        self.augmenter = config.augmenter if config.use_augmentation else None
        self.curriculum_sampler = config.curriculum_sampler
        self.tokenizer = config.tokenizer
        # Statistics populated by load_dataset() / _compute_statistics().
        self.stats: Dict[str, Any] = {}
        self._processed_count = 0   # samples seen by preprocess_sample()
        self._filtered_count = 0    # dataset size after initial filtering (non-streaming only)

    def load_dataset(self) -> Dataset:
        """Load the OpenThoughts dataset, filter it, and mix in custom datasets.

        Returns:
            The filtered (and possibly mixed) dataset.

        Raises:
            Exception: Re-raises any error from the underlying ``load_dataset`` call
                after logging it.
        """
        logger.info(f"Loading OpenThoughts dataset: {self.config.dataset_name}")
        try:
            # Load the main dataset (optionally in streaming mode).
            ds = load_dataset(
                self.config.dataset_name,
                split=self.config.split,
                cache_dir=self.config.cache_dir,
                streaming=self.config.streaming,
            )
            logger.info(f"Loaded {len(ds) if not self.config.streaming else 'streaming'} samples")
            # Filter first so mixing and statistics only see clean data.
            ds = self._apply_initial_filters(ds)
            # Mix in additional datasets if any were configured.
            if self.config.custom_datasets:
                ds = self._mix_datasets(ds)
            # Record summary statistics for later inspection via self.stats.
            self._compute_statistics(ds)
            return ds
        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
            raise

    def _apply_initial_filters(self, ds: Dataset) -> Dataset:
        """Apply quality-score and thoughts-length filters.

        Raises:
            ValueError: If a required column is missing (only checkable when
                the dataset exposes its column names).
        """
        logger.info("Applying initial filters...")
        # Streaming datasets may not expose column names or a length up front,
        # so guard every len()/column access behind these two flags.
        known_cols = ds.column_names or []
        sized = not self.config.streaming
        for col in ["conversations"]:
            if known_cols and col not in known_cols:
                raise ValueError(f"Missing required column: {col}")
        # Drop samples below the configured quality threshold, when scored.
        if "quality_score" in known_cols:
            initial_size = len(ds) if sized else None
            ds = ds.filter(lambda x: x.get("quality_score", 1.0) >= self.config.min_quality_score)
            if sized:
                logger.info(f"Quality filter: {initial_size} -> {len(ds)}")
        # Drop samples whose reasoning trace is too short to be useful.
        if self.config.include_thoughts and self.config.min_thoughts_length > 0:
            initial_size = len(ds) if sized else None
            ds = ds.filter(lambda x: len(x.get("thoughts", "")) >= self.config.min_thoughts_length)
            if sized:
                logger.info(f"Thoughts length filter: {initial_size} -> {len(ds)}")
        if sized:
            self._filtered_count = len(ds)
        return ds

    def _mix_datasets(self, main_ds: Dataset) -> Dataset:
        """Concatenate the main dataset with the configured custom datasets.

        Weights are normalized for future weighted sampling but are currently
        informational only: the datasets are simply concatenated.
        """
        logger.info(f"Mixing with custom datasets: {self.config.custom_datasets}")
        datasets = [main_ds]
        weights = [1.0]  # implicit weight of the main dataset
        # Read weights positionally instead of pop(0)-ing them off the config:
        # popping mutated the config, so a second load_dataset() call (or a
        # partially failed load) saw the wrong weights.
        for idx, custom_path in enumerate(self.config.custom_datasets):
            try:
                custom_ds = load_dataset(custom_path, split=self.config.split, cache_dir=self.config.cache_dir)
                datasets.append(custom_ds)
                weight = self.config.dataset_weights[idx] if idx < len(self.config.dataset_weights) else 1.0
                weights.append(weight)
                logger.info(f"Loaded custom dataset: {custom_path} ({len(custom_ds)} samples)")
            except Exception as e:
                # Best-effort mixing: a failing custom dataset is skipped, not fatal.
                logger.warning(f"Failed to load custom dataset {custom_path}: {e}")
        if len(datasets) > 1:
            # Normalize weights (not yet used for sampling).
            total_weight = sum(weights)
            weights = [w / total_weight for w in weights]
            # Interleave datasets according to weights
            # For simplicity, we concatenate and will sample later
            mixed = concatenate_datasets(datasets)
            logger.info(f"Mixed dataset size: {len(mixed)}")
            return mixed
        return main_ds

    def _compute_statistics(self, ds: Dataset):
        """Compute character-length statistics over a sample into ``self.stats``."""
        logger.info("Computing dataset statistics...")
        # Analyse at most 1000 samples to keep this cheap.
        sample_size = 1000 if self.config.streaming else min(1000, len(ds))
        if self.config.streaming:
            samples = list(ds.take(sample_size))
        else:
            samples = ds.select(range(sample_size))
        # Per-sample length: raw string length, or the summed "content" lengths
        # for structured (list-of-messages) conversations.
        lengths = []
        for sample in samples:
            conv = sample.get("conversations", "")
            if isinstance(conv, str):
                lengths.append(len(conv))
            elif isinstance(conv, list):
                lengths.append(sum(len(msg.get("content", "")) for msg in conv))
        # NOTE(review): behavior of compute_length_statistics on an empty list
        # is not visible here — confirm it handles an empty dataset.
        length_stats = compute_length_statistics(lengths)
        self.stats = {
            "total_samples": len(ds) if not self.config.streaming else "streaming",
            "avg_length": length_stats["mean"],
            "length_std": length_stats["std"],
            "min_length": length_stats["min"],
            "max_length": length_stats["max"],
            "p90_length": length_stats["p90"],
            "p99_length": length_stats["p99"],
        }
        logger.info(f"Dataset stats: {self.stats}")

    def preprocess_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Run the full preprocessing pipeline on one raw sample.

        Steps: conversation normalization, thoughts extraction, domain
        detection, difficulty estimation, quality scoring, optional
        stochastic augmentation, and optional tokenization.
        """
        self._processed_count += 1
        conversations = sample.get("conversations", [])
        # Conversations may arrive JSON-encoded; decode them to a message list.
        # (The original double-checked isinstance(..., str) redundantly.)
        if isinstance(conversations, str):
            conversations = json.loads(conversations)
        processed = preprocess_conversation(
            conversations,
            include_thoughts=self.config.include_thoughts,
            include_reasoning=self.config.include_reasoning,
        )
        # Attach the raw reasoning trace when the source provides one.
        if "thoughts" in sample:
            processed["thoughts"] = extract_thoughts(sample["thoughts"])
        # Fill in derived metadata only when the source did not provide it.
        if "domain" not in processed:
            processed["domain"] = detect_domain(conversations)
        if "difficulty" not in processed:
            processed["difficulty"] = estimate_difficulty(conversations, sample.get("thoughts", ""))
        # Quality metrics are always attached for downstream filtering/analysis.
        processed["quality_metrics"] = self.quality_filter.compute_metrics(processed)
        # Stochastic augmentation with probability augmentation_prob.
        if self.augmenter and np.random.random() < self.config.augmentation_prob:
            augmented = self.augmenter.augment(processed)
            if augmented:
                processed.update(augmented)
        # Tokenize in-line when a tokenizer was configured.
        if self.tokenizer:
            processed.update(self._tokenize_sample(processed))
        return processed

    def _tokenize_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Tokenize one preprocessed sample into fixed-length training tensors.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels``
            tensors, or an empty dict when no tokenizer is configured.
        """
        if not self.tokenizer:
            return {}
        text = format_for_training(sample, include_thoughts=self.config.include_thoughts)
        # Tokenize to a fixed max_seq_length with truncation + padding.
        tokenized = self.tokenizer(
            text,
            truncation=True,
            max_length=self.config.max_seq_length,
            padding="max_length",
            return_tensors="pt",
        )
        # Drop the batch dimension added by return_tensors="pt".
        input_ids = tokenized["input_ids"][0]
        attention_mask = tokenized["attention_mask"][0]
        # Causal-LM labels: copy of input_ids with padding ignored (-100).
        labels = input_ids.clone()
        # Guard: tokenizers with no pad token would otherwise compare against None.
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        # TODO(review): thoughts-specific label masking for multi-task learning
        # is not implemented (the original held a dead `pass` branch here).
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def create_dataloader(
        self,
        ds: Dataset,
        batch_size: int = 32,
        shuffle: bool = True,
        num_workers: int = 4,
        curriculum_epoch: int = 0,
    ) -> DataLoader:
        """Create a DataLoader, preprocessing eagerly for map-style datasets.

        Args:
            ds: Dataset returned by :meth:`load_dataset`.
            batch_size: Batch size; the incomplete final batch is dropped.
            shuffle: Whether to shuffle (superseded by the curriculum sampler).
            num_workers: Workers for both .map() preprocessing and loading.
            curriculum_epoch: Epoch index forwarded to the curriculum sampler.
        """
        # Streaming datasets are preprocessed lazily per-sample instead.
        if not self.config.streaming:
            logger.info("Preprocessing entire dataset...")
            ds = ds.map(
                self.preprocess_sample,
                batched=False,
                num_proc=num_workers,
                remove_columns=ds.column_names,
                desc="Preprocessing",
            )
        # DataLoader forbids shuffle=True together with a sampler, so the
        # curriculum sampler takes over shuffling when configured.
        if self.curriculum_sampler and shuffle:
            sampler = self.curriculum_sampler.get_sampler(ds, epoch=curriculum_epoch)
            shuffle = False  # Sampler handles shuffling
        else:
            sampler = None
        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
        )
class OpenThoughtsDataset(IterableDataset):
    """Streaming dataset for large-scale OpenThoughts training.

    Wraps an :class:`OpenThoughtsProcessor` and yields preprocessed samples;
    with ``infinite=True`` the underlying dataset is reloaded and replayed
    forever (useful for step-based training loops).
    """

    def __init__(
        self,
        processor: "OpenThoughtsProcessor",
        infinite: bool = False,
    ):
        """
        Args:
            processor: Configured processor used to load and preprocess samples.
            infinite: When True, restart iteration after the dataset is exhausted.
        """
        self.processor = processor
        self.infinite = infinite
        # Kept for backward compatibility; no buffering is currently performed.
        self._buffer: List[Dict[str, Any]] = []

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Yield preprocessed samples, optionally looping forever.

        The original gated reloading on ``self._buffer``, which was never
        populated — the condition was always true (and, had it ever been
        false, the loop would have spun forever yielding nothing). The dead
        branch is removed; behavior is unchanged.
        """
        while True:
            # Reload each pass: streaming datasets are single-use iterators.
            ds = self.processor.load_dataset()
            for sample in ds:
                yield self.processor.preprocess_sample(sample)
            if not self.infinite:
                break

    def __len__(self) -> int:
        """Return the number of samples that survived initial filtering.

        Raises:
            TypeError: For streaming datasets, whose length is unknown.
                (The original returned ``float("inf")``, which also made
                ``len()`` raise TypeError — just with an opaque message.)
        """
        if self.processor.config.streaming:
            raise TypeError("Length of a streaming OpenThoughtsDataset is undefined")
        return self.processor._filtered_count
def create_openthoughts_pipeline(
    config: OpenThoughtsConfig,
    tokenizer: Optional[Any] = None,
) -> Tuple[OpenThoughtsProcessor, Dataset]:
    """Build a ready-to-use OpenThoughts pipeline.

    Attaches the tokenizer to the config, constructs a processor, and
    eagerly loads the (filtered and optionally mixed) dataset.

    Args:
        config: Processing configuration; its ``tokenizer`` field is overwritten.
        tokenizer: Optional tokenizer forwarded into the config.

    Returns:
        A ``(processor, dataset)`` pair.
    """
    config.tokenizer = tokenizer
    pipeline_processor = OpenThoughtsProcessor(config)
    return pipeline_processor, pipeline_processor.load_dataset()