"""
Flint-1.2B Data Pipeline — Thought-Action Pretraining (TAP)
============================================================
Streams real data from HuggingFace Hub, applies TAP formatting,
tokenizes, packs into fixed-length sequences, and batches.
CRITICAL: This is what actually feeds the model. If this breaks,
the model trains on garbage (or nothing).
"""
import os
import json
import random as pyrandom
import traceback
from typing import Iterator, List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
import numpy as np
from datasets import load_dataset, IterableDataset
from transformers import AutoTokenizer, PreTrainedTokenizer
# ============================================================
# SPECIAL TOKENS
# ============================================================
# Special tokens registered on the tokenizer by create_tokenizer(): every entry
# except the last is passed as an "additional special token"; the last entry
# is the pad token.
# NOTE(review): the ten empty-string entries look like tag literals that were
# lost (e.g. stripped as markup somewhere upstream) — confirm the intended
# token strings before relying on this list.
TAP_SPECIAL_TOKENS = [
"", "",
"", "",
"", "",
"", "",
"", "",
"<|pad|>",
]
def create_tokenizer(base_tokenizer: str = "HuggingFaceTB/cosmo2-tokenizer") -> PreTrainedTokenizer:
    """Load a pretrained tokenizer and register the TAP special tokens.

    Args:
        base_tokenizer: Hub id (or local path) handed to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer with every entry of ``TAP_SPECIAL_TOKENS`` except the
        last added as an additional special token and ``<|pad|>`` set as the
        pad token.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_tokenizer)
    # Measure the vocabulary before extending it so the log line is correct
    # for any base tokenizer, not only the default one.
    base_vocab = len(tokenizer)
    num_added = tokenizer.add_special_tokens(
        {
            "additional_special_tokens": TAP_SPECIAL_TOKENS[:-1],
            "pad_token": "<|pad|>",
        }
    )
    print(f"[Tokenizer] Base vocab: {base_vocab} + {num_added} special = {len(tokenizer)} total")
    print(f"[Tokenizer] Pad token: '{tokenizer.pad_token}' (id={tokenizer.pad_token_id})")
    return tokenizer
# ============================================================
# FORMATTERS
# ============================================================
def format_raw_text(sample: Dict[str, Any]) -> str:
    """Return the sample's ``text`` field for plain-text corpora.

    An empty string is returned when the field is missing or is not a ``str``.
    """
    value = sample.get("text", "")
    return value if isinstance(value, str) else ""
def format_reasoning_trace(sample: Dict[str, Any]) -> str:
    """Flatten a ``conversations`` list (``from`` / ``value`` dicts) into text.

    User turns contribute their stripped text followed by an empty entry;
    assistant turns contribute their text. Entries that are not dicts, have
    empty content, or carry any other role are skipped. Returns ``""`` when
    ``conversations`` is missing, empty, or not a list.

    NOTE(review): the empty-string literals and the bare newline runs in the
    f-strings below look like places where tag literals were lost — confirm
    against the intended template. As written, ``"" in text`` is true for
    every string, so the paragraph-splitting branch is never reached for
    ``str`` content.
    """
    turns = sample.get("conversations", [])
    if not (turns and isinstance(turns, list)):
        return ""

    pieces = []
    for turn in turns:
        if not isinstance(turn, dict):
            continue
        speaker = turn.get("from", "")
        text = turn.get("value", "")
        if not text:
            continue

        if speaker == "user":
            pieces.extend((text.strip(), ""))
        elif speaker == "assistant":
            if "" in text:
                pieces.append(text.strip())
                continue
            chunks = text.strip().split("\n\n")
            if len(chunks) > 1:
                reasoning = "\n\n".join(chunks[:-1])
                answer = chunks[-1]
                pieces.append(f"\n{reasoning}\n\n\n{answer}")
            else:
                pieces.append(f"\n{text.strip()}\n")

    return "\n".join(pieces)
def format_tool_call(sample: Dict[str, Any]) -> str:
    """Render a ``messages`` list (``role`` / ``content`` dicts) as text.

    System and user turns are prefixed with ``System:`` / ``User:``. An
    assistant turn is emitted as its stripped content, preceded by a fixed
    "need to call a tool" line when the content matches the tool-call check.
    Non-dict entries, empty content and other roles are skipped. Returns
    ``""`` when ``messages`` is missing, empty, or not a list.

    NOTE(review): the ``""`` literal in the tool-call check looks like a lost
    tag string — as written it matches every ``str``, so the fixed line is
    always emitted for string content. Confirm the intended marker.
    """
    messages = sample.get("messages", [])
    if not (messages and isinstance(messages, list)):
        return ""

    rendered = []
    for entry in messages:
        if not isinstance(entry, dict):
            continue
        role = entry.get("role", "")
        content = entry.get("content", "")
        if not content:
            continue

        if role == "system":
            rendered.append(f"System: {content.strip()}\n")
        elif role == "user":
            rendered.append(f"User: {content.strip()}\n")
        elif role == "assistant":
            looks_like_call = "" in content or '"name"' in content
            if looks_like_call:
                rendered.append("\nI need to call a tool to answer this.\n\n")
            rendered.append(f"{content.strip()}\n")

    return "\n".join(rendered)
def format_agent_instruct(sample: Dict[str, Any]) -> str:
    """Render an agent-instruct sample whose ``messages`` may be a JSON string.

    ``messages`` is accepted either as a list of ``role`` / ``content`` dicts
    or as a JSON string encoding such a list. A string that is not valid JSON
    is returned verbatim when longer than 20 characters (otherwise ``""``).
    JSON that decodes to anything other than a list (number, bool, object,
    string, null) yields ``""`` instead of raising while iterating.

    System and user turns are prefixed with ``System:`` / ``User:``. Assistant
    turns longer than 500 characters with more than four ``". "``-separated
    sentences are split into a leading part (first quarter, at least two
    sentences) and the remainder; shorter turns are emitted stripped.

    NOTE(review): the bare newline runs in the split template look like
    places where tag literals were lost — confirm the intended template.
    """
    messages_raw = sample.get("messages", "")

    if isinstance(messages_raw, str):
        try:
            messages = json.loads(messages_raw)
        except (json.JSONDecodeError, TypeError):
            # Not JSON: fall back to treating the field as raw text.
            return messages_raw if len(messages_raw) > 20 else ""
    elif isinstance(messages_raw, list):
        messages = messages_raw
    else:
        return ""

    # Valid JSON can still decode to a scalar or an object; only a non-empty
    # list of messages is usable here.
    if not isinstance(messages, list) or not messages:
        return ""

    parts = []
    for msg in messages:
        if not isinstance(msg, dict):
            continue
        role = msg.get("role", "")
        content = msg.get("content", "")
        if not content or not isinstance(content, str):
            continue

        if role == "system" and content.strip():
            parts.append(f"System: {content.strip()}\n")
        elif role == "user":
            parts.append(f"User: {content.strip()}\n")
        elif role == "assistant":
            sentences = content.split(". ") if len(content) > 500 else []
            if len(sentences) > 4:
                cut = max(len(sentences) // 4, 2)
                thinking = ". ".join(sentences[:cut]) + "."
                response = ". ".join(sentences[cut:])
                parts.append(f"\n{thinking}\n\n\n{response}\n")
            else:
                parts.append(f"{content.strip()}\n")

    return "\n".join(parts)
# Registry mapping a DataSource.formatter key to the function that turns one
# raw sample dict into training text (looked up in TAPDataPipeline._tokenize).
FORMATTERS = {
"raw": format_raw_text,
"reasoning": format_reasoning_trace,
"tool_call": format_tool_call,
"agent": format_agent_instruct,
}
# ============================================================
# SEQUENCE PACKING
# ============================================================
class SequencePacker:
    """Greedily pack whole documents into fixed-length token sequences.

    Each document is followed by an EOS separator. Documents are never split
    across sequences: one longer than ``max_length - 1`` tokens is truncated,
    and when the next document does not fit, the current buffer is emitted
    with pad tokens filling the remainder.
    """

    def __init__(self, max_length: int, pad_token_id: int, eos_token_id: int):
        self.max_length = max_length
        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id
        self.buffer = []

    def add_document(self, token_ids: List[int]) -> Optional[np.ndarray]:
        """Append one document; return a packed sequence once one is complete.

        Returns ``None`` for an empty document or while the buffer still has
        room for more.
        """
        if not token_ids:
            return None

        # Leave one slot for the EOS separator.
        limit = self.max_length - 1
        doc = token_ids[:limit] if len(token_ids) > limit else token_ids

        if len(self.buffer) + len(doc) + 1 > self.max_length:
            # Does not fit: ship what we have and start over with this doc.
            packed = self._emit()
            self.buffer = doc + [self.eos_token_id]
            return packed

        self.buffer.extend(doc)
        self.buffer.append(self.eos_token_id)
        return self._emit() if len(self.buffer) >= self.max_length else None

    def _emit(self) -> np.ndarray:
        """Return the buffer as an int32 array of ``max_length`` and reset it."""
        seq = self.buffer[:self.max_length]
        shortfall = self.max_length - len(seq)
        if shortfall > 0:
            seq = seq + [self.pad_token_id] * shortfall
        self.buffer = []
        return np.array(seq, dtype=np.int32)

    def flush(self) -> Optional[np.ndarray]:
        """Emit whatever is buffered (padded), or ``None`` if nothing is."""
        return self._emit() if self.buffer else None
# ============================================================
# DATA SOURCES
# ============================================================
@dataclass
class DataSource:
    """Description of one streamed dataset and how its samples are formatted."""

    # Dataset path passed to load_dataset(path=...).
    dataset_id: str
    # Optional config passed as load_dataset(name=...); omitted when falsy.
    config_name: Optional[str]
    # Relative sampling weight; normalised across the sources that loaded.
    weight: float
    # Name of the sample field holding the content.
    # NOTE(review): not read anywhere in the visible code — the formatters
    # hard-code their own field names.
    text_column: str
    # Key into FORMATTERS selecting the formatting function.
    formatter: str
    # Split passed to load_dataset(split=...).
    split: str = "train"
def get_stage_sources(stage: int) -> List[DataSource]:
    """Return the weighted data sources for a curriculum stage.

    Args:
        stage: 1 or 2 select those stages; any other value selects the
            stage-3 (annealing) mixture.

    Returns:
        A list of ``DataSource`` entries whose weights sum to 1.0 within
        each stage.

    The orca-agentinstruct categories are requested through ``split`` rather
    than ``config_name``: that dataset publishes its categories as splits of
    the default config, so asking for them as configs with ``split="train"``
    fails to load and the source would be dropped by the loader.
    """
    # NOTE(review): confirm that the smollm-corpus "python-edu" config exposes
    # a "text" column; format_raw_text yields "" (sample skipped) if it does not.
    if stage == 1:
        return [
            # High-reliability sources first
            DataSource("HuggingFaceFW/fineweb-edu", "sample-10BT", 0.50, "text", "raw"),
            DataSource("HuggingFaceTB/smollm-corpus", "cosmopedia-v2", 0.15, "text", "raw"),
            DataSource("open-thoughts/OpenThoughts-114k", None, 0.15, "conversations", "reasoning"),
            DataSource("HuggingFaceTB/smoltalk", "apigen-80k", 0.08, "messages", "tool_call"),
            DataSource("HuggingFaceTB/smollm-corpus", "python-edu", 0.05, "text", "raw"),
            # Lower priority (may be slow to stream)
            DataSource("mlfoundations/dclm-baseline-1.0-parquet", None, 0.07, "text", "raw"),
        ]
    if stage == 2:
        return [
            DataSource("HuggingFaceFW/fineweb-edu", "sample-10BT", 0.25, "text", "raw"),
            DataSource("open-thoughts/OpenThoughts-114k", None, 0.20, "conversations", "reasoning"),
            DataSource("HuggingFaceTB/finemath", "finemath-3plus", 0.15, "text", "raw"),
            DataSource("HuggingFaceTB/smoltalk", "apigen-80k", 0.10, "messages", "tool_call"),
            DataSource("HuggingFaceTB/smollm-corpus", "cosmopedia-v2", 0.10, "text", "raw"),
            DataSource("open-web-math/open-web-math", None, 0.08, "text", "raw"),
            DataSource("HuggingFaceTB/smollm-corpus", "python-edu", 0.07, "text", "raw"),
            DataSource(
                "microsoft/orca-agentinstruct-1M-v1", None, 0.05, "messages", "agent",
                split="creative_content",
            ),
        ]
    # stage 3 (annealing)
    return [
        DataSource("HuggingFaceTB/finemath", "finemath-4plus", 0.25, "text", "raw"),
        DataSource("open-thoughts/OpenThoughts-114k", None, 0.25, "conversations", "reasoning"),
        DataSource("HuggingFaceFW/fineweb-edu", "sample-10BT", 0.20, "text", "raw"),
        DataSource("HuggingFaceTB/smoltalk", "apigen-80k", 0.12, "messages", "tool_call"),
        DataSource("HuggingFaceTB/smollm-corpus", "python-edu", 0.08, "text", "raw"),
        DataSource(
            "microsoft/orca-agentinstruct-1M-v1", None, 0.10, "messages", "agent",
            split="analytical_reasoning",
        ),
    ]
# ============================================================
# PIPELINE
# ============================================================
def get_current_stage(step: int, config) -> int:
    """Map a training step to its curriculum stage (1, 2 or 3).

    Stage boundaries are ``config.total_steps`` scaled by
    ``config.stage1_end_frac`` and ``config.stage2_end_frac``, truncated to
    integers; a step at or beyond both boundaries is stage 3.
    """
    boundaries = (
        int(config.total_steps * config.stage1_end_frac),
        int(config.total_steps * config.stage2_end_frac),
    )
    for stage, end in enumerate(boundaries, start=1):
        if step < end:
            return stage
    return 3
class TAPDataPipeline:
    """Stream, format, tokenize and pack samples into fixed-length batches.

    Sources for the current curriculum stage are opened in streaming mode.
    A source that fails to load is skipped and the weights of the remaining
    sources are renormalised. Per-sample failures (fetch, format, tokenize)
    are counted in ``samples_failed`` and skipped instead of aborting.

    ``start_position`` is stored as ``self.position`` but is not otherwise
    used in this class, so streams always start from their beginning.
    """

    def __init__(self, config, tokenizer, start_step=0, start_position=0):
        """Open the datasets for the stage that ``start_step`` falls in.

        Raises:
            RuntimeError: if no dataset for the stage could be loaded.
        """
        self.config = config
        self.tokenizer = tokenizer
        self.current_stage = get_current_stage(start_step, config)
        self.position = start_position
        self.step = start_step
        self.samples_processed = 0
        self.samples_failed = 0
        eos_id = tokenizer.eos_token_id
        self.packer = SequencePacker(
            max_length=config.max_seq_len,
            pad_token_id=tokenizer.pad_token_id,
            # Fall back to id 0 when the tokenizer defines no EOS token.
            eos_token_id=eos_id if eos_id is not None else 0,
        )
        self._load_stage_datasets()

    def _load_stage_datasets(self):
        """Load datasets for the current stage, skipping any that fail.

        Raises:
            RuntimeError: if every source failed to load.
        """
        sources = get_stage_sources(self.current_stage)
        self.datasets = []
        self.weights = []
        print(f"[Data] Loading stage {self.current_stage} datasets...")
        for source in sources:
            try:
                kwargs = {"path": source.dataset_id, "split": source.split, "streaming": True}
                if source.config_name:
                    kwargs["name"] = source.config_name
                ds = load_dataset(**kwargs)
                self.datasets.append((ds, source))
                self.weights.append(source.weight)
                print(f" ✓ {source.dataset_id}/{source.config_name or ''} ({source.weight:.0%}) [{source.formatter}]")
            except Exception as e:
                # Best effort: one broken source must not stop the run.
                print(f" ✗ {source.dataset_id}/{source.config_name or ''}: {e}")
        if not self.datasets:
            raise RuntimeError("[Data] FATAL: No datasets loaded! Check network/auth.")
        # Renormalise so the surviving sources' weights sum to 1.
        total_w = sum(self.weights)
        self.weights = [w / total_w for w in self.weights]
        self.iterators = [iter(ds) for ds, _ in self.datasets]
        print(f"[Data] ✓ {len(self.datasets)} sources ready for stage {self.current_stage}\n")

    def _record_failure(self, error: Exception) -> None:
        """Count a failed sample and print a warning on every 100th failure."""
        self.samples_failed += 1
        if self.samples_failed % 100 == 0:
            print(f"[Data] Warning: {self.samples_failed} samples failed ({error})")

    def _get_sample(self) -> Tuple[Optional[Dict], Optional["DataSource"]]:
        """Draw one sample from a weight-randomly chosen source.

        An exhausted source is restarted once (epoch boundary). Returns
        ``(None, None)`` when there are no sources, when the source is still
        empty after a restart, or when fetching raised — on the first attempt
        or during the restart — in which case the failure is recorded.
        """
        if not self.datasets:
            return None, None
        idx = pyrandom.choices(range(len(self.datasets)), weights=self.weights, k=1)[0]
        ds, source = self.datasets[idx]
        try:
            try:
                sample = next(self.iterators[idx])
            except StopIteration:
                # Epoch boundary: reopen the stream. Errors raised here reach
                # the outer handlers, so the restart is protected as well.
                self.iterators[idx] = iter(ds)
                sample = next(self.iterators[idx])
        except StopIteration:
            # Source yields nothing even right after a restart.
            return None, None
        except Exception as e:
            # Network glitch, rate limit, etc. — skip this sample.
            self._record_failure(e)
            return None, None
        return sample, source

    def _tokenize(self, sample: Dict, source: "DataSource") -> List[int]:
        """Format and tokenize one sample; return ``[]`` when it is unusable.

        Samples whose text is shorter than 20 characters after stripping, or
        that produce fewer than 5 tokens, are dropped. Exceptions are recorded
        as failures and also result in ``[]``.
        """
        try:
            text = FORMATTERS[source.formatter](sample)
            if not text or len(text.strip()) < 20:
                return []
            tokens = self.tokenizer.encode(text, add_special_tokens=False)
            if len(tokens) < 5:
                return []
        except Exception as e:
            # Best effort per sample, but keep failures visible in the log.
            self._record_failure(e)
            return []
        self.samples_processed += 1
        return tokens

    def get_batch(self, batch_size: int, step: int) -> np.ndarray:
        """Return ``batch_size`` packed sequences stacked into one array.

        Switches (and reloads) datasets when ``step`` enters a new stage.
        If not enough sequences were produced within ``batch_size * 200``
        sampling attempts, the packer's partial buffer is flushed and any
        remaining rows are filled with all-pad sequences.
        """
        new_stage = get_current_stage(step, self.config)
        if new_stage != self.current_stage:
            print(f"\n[Data] ═══ Stage switch: {self.current_stage} → {new_stage} ═══")
            self.current_stage = new_stage
            self._load_stage_datasets()

        sequences = []
        max_attempts = batch_size * 200
        for _ in range(max_attempts):
            if len(sequences) >= batch_size:
                break
            sample, source = self._get_sample()
            if sample is None:
                continue
            tokens = self._tokenize(sample, source)
            if not tokens:
                continue
            result = self.packer.add_document(tokens)
            if result is not None:
                sequences.append(result)

        # Top up from the partial buffer, then with padding as a last resort.
        while len(sequences) < batch_size:
            result = self.packer.flush()
            if result is not None:
                sequences.append(result)
            else:
                # Absolute last resort — should never happen with working data
                print("[Data] ⚠️ Had to insert padding sequence!")
                sequences.append(
                    np.full(self.config.max_seq_len, self.tokenizer.pad_token_id, dtype=np.int32)
                )
        return np.stack(sequences[:batch_size])

    def __iter__(self):
        """Yield batches of ``config.global_batch_size`` forever.

        The step counter starts at ``self.step`` and advances locally by one
        per batch; ``self.step`` itself is not updated.
        """
        step = self.step
        while True:
            yield self.get_batch(self.config.global_batch_size, step)
            step += 1
def create_data_pipeline(config, tokenizer, start_step=0, start_position=0) -> Iterator:
    """Build a ``TAPDataPipeline`` and return an iterator over its batches.

    The arguments are forwarded unchanged to ``TAPDataPipeline``; the
    returned iterator never terminates on its own.
    """
    return iter(TAPDataPipeline(config, tokenizer, start_step, start_position))