# Zenith-7b-V1 / configs / data_config.py
# Provenance (Hugging Face upload header): Zandy-Wandy — "Upload Zenith-7B model",
# commit 8d18b7c (verified).
"""Advanced Data Configuration for OpenThoughts-1.2M Integration"""
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Callable
from enum import Enum
class DataDomain(Enum):
    """Supported data domains for multi-task training.

    The string *values* (e.g. ``"code"``) are used as plain-string keys
    elsewhere in this module (see ``CurriculumConfig.domain_weights``),
    so they must stay stable.
    """

    CODE = "code"
    MATHEMATICS = "mathematics"
    SCIENCE = "science"
    REASONING = "reasoning"
    DIALOGUE = "dialogue"
    INSTRUCTION = "instruction"
    # Short member name, long string value — kept asymmetric on purpose
    # so config files can refer to the readable "emotional_intelligence".
    EQ = "emotional_intelligence"
@dataclass
class QualityFilterConfig:
    """Configuration for data quality filtering.

    Holds thresholds and on/off switches for screening training samples.
    Enforcement happens in the data pipeline, not in this dataclass, so
    the exact semantics below are inferred from the field names and
    should be confirmed against the filtering code.
    """

    min_length: int = 50                  # minimum sample length (chars vs tokens not shown here — TODO confirm)
    max_length: int = 32768               # maximum sample length; matches DataConfig's 32768 hard cap
    min_quality_score: float = 0.7        # presumably samples scoring below this are dropped — verify in pipeline
    max_repetition_ratio: float = 0.3     # presumably rejects samples with >30% repeated content — verify
    language_confidence: float = 0.9      # minimum language-ID confidence to keep a sample (assumed — confirm)
    remove_pii: bool = True               # strip/drop personally identifiable information
    remove_harmful: bool = True           # drop harmful content
    remove_low_quality: bool = True       # apply the quality-score filter at all
    deduplicate: bool = True              # drop duplicate samples
    min_thoughts_length: int = 100        # For OpenThoughts chain-of-thought
@dataclass
class CurriculumConfig:
    """Configuration for curriculum learning schedule."""

    # Master switch for staged domain scheduling.
    enable_curriculum: bool = True

    # Ordered training stages. Each entry carries a name, an ``epoch``
    # boundary (presumably the epoch through which the stage applies —
    # confirm against the sampler), and the domains it draws from;
    # ``domains: None`` means "all domains".
    stages: List[Dict[str, Any]] = field(default_factory=lambda: [
        {"name": "foundation", "epoch": 1, "domains": [DataDomain.INSTRUCTION, DataDomain.DIALOGUE]},
        {"name": "reasoning", "epoch": 2, "domains": [DataDomain.REASONING, DataDomain.MATHEMATICS]},
        {"name": "code", "epoch": 3, "domains": [DataDomain.CODE]},
        {"name": "full", "epoch": float('inf'), "domains": None},  # All domains
    ])

    difficulty_weighting: bool = True

    # Relative sampling weight per domain, keyed by the enum's string
    # value. Built from (domain, weight) pairs so the mapping to
    # DataDomain members stays explicit.
    domain_weights: Dict[str, float] = field(default_factory=lambda: {
        domain.value: weight
        for domain, weight in (
            (DataDomain.CODE, 1.0),
            (DataDomain.MATHEMATICS, 1.0),
            (DataDomain.SCIENCE, 0.8),
            (DataDomain.REASONING, 0.9),
            (DataDomain.DIALOGUE, 0.7),
            (DataDomain.INSTRUCTION, 0.8),
            (DataDomain.EQ, 0.6),
        )
    })
@dataclass
class DataConfig:
    """Ultra-sophisticated data pipeline configuration.

    Aggregates every data-side knob: dataset sources, tokenization,
    quality filtering, curriculum learning, multi-task loss weights,
    augmentation, sampling, caching, and batch composition.

    Raises:
        ValueError: from ``__post_init__`` when ``max_seq_length`` exceeds
            32768 or ``temperature`` falls outside ``(0, 2]``.
    """

    # --- Dataset sources ---
    openthoughts_dataset: str = "open-thoughts/OpenThoughts3-1.2M"
    openthoughts_split: str = "train"
    custom_datasets: List[str] = field(default_factory=list)

    # --- Processing ---
    tokenizer_name: str = "meta-llama/Llama-2-7b-hf"
    max_seq_length: int = 8192
    stride: int = 4096                 # presumably the overlap step when chunking long samples — confirm in pipeline
    add_special_tokens: bool = True
    use_chat_template: bool = True

    # --- Quality filtering ---
    quality_filter: QualityFilterConfig = field(default_factory=QualityFilterConfig)

    # --- Curriculum learning ---
    curriculum: CurriculumConfig = field(default_factory=CurriculumConfig)

    # --- Multi-task weighting (relative loss weight per training objective) ---
    task_weights: Dict[str, float] = field(default_factory=lambda: {
        "next_token": 1.0,
        "thoughts": 0.5,  # Chain-of-thought loss
        "eq_classification": 0.3,
        "frustration_detection": 0.2,
    })

    # --- Data augmentation ---
    enable_augmentation: bool = True
    augmentation_ratio: float = 0.1    # presumably the fraction of samples augmented — verify against augmenter
    augmentation_methods: List[str] = field(default_factory=lambda: [
        "synonym_replacement",
        "back_translation",
        "code_perturbation",
    ])

    # --- Sampling ---
    sampling_strategy: str = "curriculum"  # "curriculum", "mixed", "domain_balanced"
    temperature: float = 1.0               # must lie in (0, 2]; validated in __post_init__
    oversample_quality: bool = True

    # --- Caching ---
    cache_dir: str = "./data/cache"
    use_cache: bool = True
    prefetch_size: int = 1000

    # --- Batch composition ---
    dynamic_batching: bool = True
    length_bucket_size: int = 128
    mixed_length_batches: bool = True

    def __post_init__(self):
        """Validate configuration.

        Uses explicit ``raise ValueError`` rather than ``assert`` so that
        validation still runs when Python is started with ``-O`` (asserts
        are stripped under optimized mode).
        """
        if self.max_seq_length > 32768:
            raise ValueError("Max sequence length cannot exceed 32768")
        if not 0 < self.temperature <= 2.0:
            raise ValueError("Temperature must be in (0, 2]")
@dataclass
class OpenThoughtsFeatures:
    """Features specific to OpenThoughts dataset structure.

    Maps the column/key names this pipeline expects to find in an
    OpenThoughts-style record, plus the conversation formatting
    conventions (roles and thought delimiters).
    """

    # Per-record key names in the raw dataset.
    conversation_key: str = "conversations"
    thoughts_key: str = "thoughts"
    reasoning_key: str = "reasoning"
    quality_score_key: str = "quality_score"
    domain_key: str = "domain"
    difficulty_key: str = "difficulty"
    language_key: str = "language"
    # Expected structure — note "language" is deliberately (or perhaps
    # accidentally) absent here even though language_key is defined
    # above; confirm against the actual dataset schema.
    expected_columns: List[str] = field(default_factory=lambda: [
        "conversations",
        "thoughts",
        "reasoning",
        "quality_score",
        "domain",
        "difficulty",
    ])
    # Conversation format: accepted message roles and the markers that
    # bracket chain-of-thought text inside a response.
    message_roles: List[str] = field(default_factory=lambda: ["user", "assistant", "system"])
    thought_delimiters: List[str] = field(default_factory=lambda: ["<think>", "</think>"])