| """ |
| Flint-1.2B Data Pipeline — Thought-Action Pretraining (TAP) |
| ============================================================ |
| |
| Streams real data from HuggingFace Hub, applies TAP formatting, |
| tokenizes, packs into fixed-length sequences, and batches. |
| |
| CRITICAL: This is what actually feeds the model. If this breaks, |
| the model trains on garbage (or nothing). |
| """ |
|
|
| import os |
| import json |
| import random as pyrandom |
| import traceback |
| from typing import Iterator, List, Dict, Any, Optional, Tuple |
| from dataclasses import dataclass |
|
|
| import numpy as np |
| from datasets import load_dataset, IterableDataset |
| from transformers import AutoTokenizer, PreTrainedTokenizer |
|
|
|
|
| |
| |
| |
|
|
# Paired open/close tags used by the TAP formatters, followed by the pad token.
# The pad token must stay last: create_tokenizer registers every entry except
# the final one as an additional special token and sets the pad token itself.
_TAP_TAG_NAMES = ("think", "tool_call", "tool_response", "observe", "act")

TAP_SPECIAL_TOKENS = [
    tag
    for tag_name in _TAP_TAG_NAMES
    for tag in (f"<{tag_name}>", f"</{tag_name}>")
]
TAP_SPECIAL_TOKENS.append("<|pad|>")
|
|
|
|
def create_tokenizer(base_tokenizer: str = "HuggingFaceTB/cosmo2-tokenizer") -> PreTrainedTokenizer:
    """Load a tokenizer and register the TAP special tokens.

    Args:
        base_tokenizer: Name or path handed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer with every TAP tag added as an additional special token
        and ``<|pad|>`` set as its pad token.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_tokenizer)

    # Measure the vocabulary before extending it so the log line is correct for
    # any base tokenizer, not only the default one.
    base_vocab_size = len(tokenizer)

    special_tokens_dict = {
        # The last entry of TAP_SPECIAL_TOKENS is the pad token; it is
        # registered through its own key below.
        "additional_special_tokens": TAP_SPECIAL_TOKENS[:-1],
        "pad_token": "<|pad|>",
    }
    num_added = tokenizer.add_special_tokens(special_tokens_dict)
    print(f"[Tokenizer] Base vocab: {base_vocab_size} + {num_added} special = {len(tokenizer)} total")
    print(f"[Tokenizer] Pad token: '{tokenizer.pad_token}' (id={tokenizer.pad_token_id})")
    return tokenizer
|
|
|
|
| |
| |
| |
|
|
def format_raw_text(sample: Dict[str, Any]) -> str:
    """Return the sample's ``text`` field for plain-text corpora.

    Gives back an empty string when the field is missing or is not a string.
    """
    candidate = sample.get("text", "")
    return candidate if isinstance(candidate, str) else ""
|
|
|
|
def format_reasoning_trace(sample: Dict[str, Any]) -> str:
    """Render a ``conversations`` sample with assistant turns wrapped in <think> tags.

    Each message is expected to be a dict with a ``from`` role ("user" or
    "assistant") and a ``value`` string. Messages that are not dicts, whose
    value is not a string, or whose value is blank are skipped. Messages with
    any other role are ignored.

    Assistant turns that already contain ``<think>`` are kept verbatim.
    Otherwise every paragraph but the last becomes the reasoning and the last
    paragraph becomes the answer; a single-paragraph turn is wrapped entirely.

    Args:
        sample: Raw dataset row.

    Returns:
        The formatted text, or "" when there is nothing usable.
    """
    conversations = sample.get("conversations", [])
    if not conversations or not isinstance(conversations, list):
        return ""

    parts = []
    for msg in conversations:
        if not isinstance(msg, dict):
            continue

        content = msg.get("value", "")
        # A non-string value would raise on .strip() and discard the whole
        # sample; skip just that message instead.
        if not isinstance(content, str):
            continue
        content = content.strip()
        if not content:
            # Blank turns would otherwise produce an empty <think> block.
            continue

        role = msg.get("from", "")
        if role == "user":
            parts.append(content)
            # The empty entry yields a blank line after the prompt once joined.
            parts.append("")
        elif role == "assistant":
            if "<think>" in content:
                parts.append(content)
                continue
            paragraphs = content.split("\n\n")
            if len(paragraphs) > 1:
                reasoning = "\n\n".join(paragraphs[:-1])
                answer = paragraphs[-1]
                parts.append(f"<think>\n{reasoning}\n</think>\n\n{answer}")
            else:
                parts.append(f"<think>\n{content}\n</think>")

    return "\n".join(parts)
|
|
|
|
def format_tool_call(sample: Dict[str, Any]) -> str:
    """Render a ``messages`` sample that may contain tool calls.

    Each message is expected to be a dict with ``role`` and ``content``.
    System and user turns get a "System: " / "User: " prefix. Assistant turns
    that look like a tool call (they contain ``<tool_call>`` or a ``"name"``
    key) are preceded by a short fixed <think> block. Messages that are not
    dicts, whose content is not a string, or whose content is blank are
    skipped. Messages with any other role are ignored.

    Args:
        sample: Raw dataset row.

    Returns:
        The formatted text, or "" when there is nothing usable.
    """
    messages = sample.get("messages", [])
    if not messages or not isinstance(messages, list):
        return ""

    parts = []
    for msg in messages:
        if not isinstance(msg, dict):
            continue

        content = msg.get("content", "")
        # Non-string content (e.g. a list of parts) would raise on .strip()
        # and discard the whole sample; skip just that message instead.
        if not isinstance(content, str):
            continue
        content = content.strip()
        if not content:
            continue

        role = msg.get("role", "")
        if role == "system":
            parts.append(f"System: {content}\n")
        elif role == "user":
            parts.append(f"User: {content}\n")
        elif role == "assistant":
            if "<tool_call>" in content or '"name"' in content:
                parts.append("<think>\nI need to call a tool to answer this.\n</think>\n")
            parts.append(f"{content}\n")

    return "\n".join(parts)
|
|
|
|
def format_agent_instruct(sample: Dict[str, Any]) -> str:
    """Render a sample whose ``messages`` field is a JSON string or a list.

    A string is decoded as JSON. If decoding fails, the raw string is returned
    when it is longer than 20 characters, otherwise "". A decoded value that is
    not a list yields "".

    System and user turns get a "System: " / "User: " prefix. Assistant turns
    longer than 500 characters with more than four ". "-separated sentences
    are split: roughly the first quarter of the sentences (at least two) goes
    inside a <think> block and the rest becomes the response. Messages that
    are not dicts, whose content is not a string, or whose content is blank
    are skipped.

    Args:
        sample: Raw dataset row.

    Returns:
        The formatted text, or "" when there is nothing usable.
    """
    messages_raw = sample.get("messages", "")

    if isinstance(messages_raw, str):
        try:
            messages = json.loads(messages_raw)
        except (json.JSONDecodeError, TypeError):
            # Not JSON: treat a reasonably long string as plain text.
            return messages_raw if len(messages_raw) > 20 else ""
    elif isinstance(messages_raw, list):
        messages = messages_raw
    else:
        return ""

    # Valid JSON can still decode to a number, string or object; iterating a
    # number would raise, so anything that is not a list is rejected here.
    if not messages or not isinstance(messages, list):
        return ""

    parts = []
    for msg in messages:
        if not isinstance(msg, dict):
            continue

        content = msg.get("content", "")
        if not isinstance(content, str):
            continue
        content = content.strip()
        if not content:
            continue

        role = msg.get("role", "")
        if role == "system":
            parts.append(f"System: {content}\n")
        elif role == "user":
            parts.append(f"User: {content}\n")
        elif role == "assistant":
            sentences = content.split(". ") if len(content) > 500 else []
            if len(sentences) > 4:
                cut = max(len(sentences) // 4, 2)
                thinking = ". ".join(sentences[:cut]) + "."
                response = ". ".join(sentences[cut:])
                parts.append(f"<think>\n{thinking}\n</think>\n\n{response}\n")
            else:
                parts.append(f"{content}\n")

    return "\n".join(parts)
|
|
|
|
# Dispatch table: maps a DataSource.formatter key to the function that turns
# one raw sample into training text.
FORMATTERS = dict(
    raw=format_raw_text,
    reasoning=format_reasoning_trace,
    tool_call=format_tool_call,
    agent=format_agent_instruct,
)
|
|
|
|
| |
| |
| |
|
|
class SequencePacker:
    """Pack tokenized documents into fixed-length sequences.

    Documents are appended whole, each followed by one EOS token. A document
    is never split across two sequences: when the next document does not fit,
    the current sequence is padded to ``max_length`` and emitted, and the
    document starts the next sequence. Documents longer than
    ``max_length - 1`` tokens are truncated so the EOS token always fits.
    """

    def __init__(self, max_length: int, pad_token_id: int, eos_token_id: int):
        """Create an empty packer.

        Args:
            max_length: Length of every emitted sequence; must be at least 2
                (one content token plus EOS).
            pad_token_id: Id used to fill the tail of a short sequence.
            eos_token_id: Id appended after every document.

        Raises:
            ValueError: If ``max_length`` is smaller than 2.
        """
        # With max_length 0 the truncation slice becomes [:-1] and silently
        # drops the last token of every document; with 1 only EOS would fit.
        if max_length < 2:
            raise ValueError(f"max_length must be >= 2, got {max_length}")
        self.max_length = max_length
        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id
        self.buffer = []

    def add_document(self, token_ids: List[int]) -> Optional[np.ndarray]:
        """Append one document to the buffer.

        Args:
            token_ids: Token ids of the document; an empty input is ignored.

        Returns:
            A packed int32 array of length ``max_length`` when this call
            completed a sequence, otherwise None.
        """
        if not token_ids:
            return None

        # Keep one slot free for the EOS token.
        content_limit = self.max_length - 1
        if len(token_ids) > content_limit:
            token_ids = token_ids[:content_limit]

        needed = len(token_ids) + 1

        if len(self.buffer) + needed > self.max_length:
            # Does not fit: emit what we have (padded) and start a new
            # sequence with this document.
            result = self._emit()
            self.buffer = list(token_ids) + [self.eos_token_id]
            return result

        self.buffer.extend(token_ids)
        self.buffer.append(self.eos_token_id)
        if len(self.buffer) >= self.max_length:
            return self._emit()
        return None

    def _emit(self) -> np.ndarray:
        """Return the buffer as a padded int32 array and reset the buffer."""
        seq = self.buffer[:self.max_length]
        shortfall = self.max_length - len(seq)
        if shortfall > 0:
            seq = seq + [self.pad_token_id] * shortfall
        self.buffer = []
        return np.array(seq, dtype=np.int32)

    def flush(self) -> Optional[np.ndarray]:
        """Emit whatever is buffered (padded), or None when the buffer is empty."""
        if self.buffer:
            return self._emit()
        return None
|
|
|
|
| |
| |
| |
|
|
@dataclass
class DataSource:
    """One streamed dataset together with its share of the sampling mix."""

    # Hub repository id, passed to load_dataset as ``path``.
    dataset_id: str
    # Dataset configuration, passed as ``name``; omitted from the call when falsy.
    config_name: Optional[str]
    # Relative sampling weight; renormalised over the sources that loaded.
    weight: float
    # Name of the payload column.
    # NOTE(review): nothing in the visible code reads this field; confirm
    # whether the formatters are meant to use it.
    text_column: str
    # Key into FORMATTERS selecting how a sample is rendered to text.
    formatter: str
    # Dataset split to stream.
    split: str = "train"
|
|
|
|
def get_stage_sources(stage: int) -> List[DataSource]:
    """Return the weighted data mix for a curriculum stage.

    Args:
        stage: Curriculum stage. 1 and 2 select their own mixes; any other
            value selects the final (stage 3) mix.

    Returns:
        The sources for that stage. The weights within each stage sum to 1.0.
    """
    if stage == 1:
        return [
            DataSource("HuggingFaceFW/fineweb-edu", "sample-10BT", 0.50, "text", "raw"),
            DataSource("HuggingFaceTB/smollm-corpus", "cosmopedia-v2", 0.15, "text", "raw"),
            DataSource("open-thoughts/OpenThoughts-114k", None, 0.15, "conversations", "reasoning"),
            DataSource("HuggingFaceTB/smoltalk", "apigen-80k", 0.08, "messages", "tool_call"),
            DataSource("HuggingFaceTB/smollm-corpus", "python-edu", 0.05, "text", "raw"),
            DataSource("mlfoundations/dclm-baseline-1.0-parquet", None, 0.07, "text", "raw"),
        ]

    if stage == 2:
        return [
            DataSource("HuggingFaceFW/fineweb-edu", "sample-10BT", 0.25, "text", "raw"),
            DataSource("open-thoughts/OpenThoughts-114k", None, 0.20, "conversations", "reasoning"),
            DataSource("HuggingFaceTB/finemath", "finemath-3plus", 0.15, "text", "raw"),
            DataSource("HuggingFaceTB/smoltalk", "apigen-80k", 0.10, "messages", "tool_call"),
            DataSource("HuggingFaceTB/smollm-corpus", "cosmopedia-v2", 0.10, "text", "raw"),
            DataSource("open-web-math/open-web-math", None, 0.08, "text", "raw"),
            DataSource("HuggingFaceTB/smollm-corpus", "python-edu", 0.07, "text", "raw"),
            # orca-agentinstruct publishes its task subsets as splits of the
            # default config, not as configs, so the task name goes in `split`.
            # Passing it as the config name makes the load fail and the source
            # gets silently dropped from the mix.
            DataSource(
                "microsoft/orca-agentinstruct-1M-v1", None, 0.05, "messages", "agent",
                split="creative_content",
            ),
        ]

    return [
        DataSource("HuggingFaceTB/finemath", "finemath-4plus", 0.25, "text", "raw"),
        DataSource("open-thoughts/OpenThoughts-114k", None, 0.25, "conversations", "reasoning"),
        DataSource("HuggingFaceFW/fineweb-edu", "sample-10BT", 0.20, "text", "raw"),
        DataSource("HuggingFaceTB/smoltalk", "apigen-80k", 0.12, "messages", "tool_call"),
        DataSource("HuggingFaceTB/smollm-corpus", "python-edu", 0.08, "text", "raw"),
        # Task subset is a split of the default config (see stage 2 note).
        DataSource(
            "microsoft/orca-agentinstruct-1M-v1", None, 0.10, "messages", "agent",
            split="analytical_reasoning",
        ),
    ]
|
|
|
|
| |
| |
| |
|
|
def get_current_stage(step: int, config) -> int:
    """Map a training step to its curriculum stage (1, 2 or 3).

    The stage boundaries are ``total_steps`` scaled by ``stage1_end_frac`` and
    ``stage2_end_frac`` from ``config``, truncated to integers. A step equal to
    a boundary already belongs to the following stage.
    """
    stage_ends = (
        int(config.total_steps * config.stage1_end_frac),
        int(config.total_steps * config.stage2_end_frac),
    )
    for stage_number, stage_end in enumerate(stage_ends, start=1):
        if step < stage_end:
            return stage_number
    return 3
|
|
|
|
class TAPDataPipeline:
    """
    Streams real data from HF Hub, formats with TAP tags, packs sequences.

    Robust to individual dataset failures — sources that fail to load are
    skipped and the remaining weights are renormalised over the working ones.
    """

    # A warning is printed each time the failure counter reaches a multiple
    # of this value.
    _FAILURE_LOG_INTERVAL = 100

    def __init__(self, config, tokenizer, start_step=0, start_position=0):
        """Set up the packer and load the datasets for the starting stage.

        Args:
            config: Object providing ``total_steps``, ``stage1_end_frac``,
                ``stage2_end_frac``, ``max_seq_len`` and ``global_batch_size``.
            tokenizer: Tokenizer providing ``encode``, ``pad_token_id`` and
                ``eos_token_id``.
            start_step: Training step to start from; selects the initial stage.
            start_position: Stored on the instance.

        Raises:
            ValueError: If the tokenizer has no pad token id.
            RuntimeError: If no dataset of the starting stage could be loaded.
        """
        self.config = config
        self.tokenizer = tokenizer
        self.current_stage = get_current_stage(start_step, config)
        # NOTE(review): never read in this class, so a resumed run restarts
        # every stream from its beginning. Confirm whether skipping was intended.
        self.position = start_position
        self.step = start_step
        self.samples_processed = 0
        self.samples_failed = 0

        # Without a pad id the packer would only fail later, in the middle of
        # training, the first time a sequence needs padding.
        if tokenizer.pad_token_id is None:
            raise ValueError("[Data] Tokenizer has no pad_token_id; packed sequences cannot be padded.")

        eos_token_id = tokenizer.eos_token_id
        self.packer = SequencePacker(
            max_length=config.max_seq_len,
            pad_token_id=tokenizer.pad_token_id,
            # Fall back to id 0 only when the tokenizer defines no EOS at all.
            eos_token_id=eos_token_id if eos_token_id is not None else 0,
        )

        self._load_stage_datasets()

    def _load_stage_datasets(self):
        """Load datasets for current stage. Skip any that fail.

        Raises:
            RuntimeError: If none of the stage's sources could be loaded.
        """
        sources = get_stage_sources(self.current_stage)
        self.datasets = []
        self.weights = []

        print(f"[Data] Loading stage {self.current_stage} datasets...")
        for source in sources:
            try:
                kwargs = {"path": source.dataset_id, "split": source.split, "streaming": True}
                if source.config_name:
                    kwargs["name"] = source.config_name
                ds = load_dataset(**kwargs)
                self.datasets.append((ds, source))
                self.weights.append(source.weight)
                print(f" ✓ {source.dataset_id}/{source.config_name or ''} ({source.weight:.0%}) [{source.formatter}]")
            except Exception as e:
                # Best effort: one broken source must not stop the run.
                print(f" ✗ {source.dataset_id}/{source.config_name or ''}: {e}")

        if not self.datasets:
            raise RuntimeError("[Data] FATAL: No datasets loaded! Check network/auth.")

        # Renormalise so the weights of the sources that loaded sum to 1.
        total_w = sum(self.weights)
        self.weights = [w / total_w for w in self.weights]

        # One live iterator per source, index-aligned with self.datasets.
        self.iterators = [iter(ds) for ds, _ in self.datasets]
        print(f"[Data] ✓ {len(self.datasets)} sources ready for stage {self.current_stage}\n")

    def _record_failure(self, error: Exception) -> None:
        """Count one failed sample and print a warning at every log interval."""
        self.samples_failed += 1
        if self.samples_failed % self._FAILURE_LOG_INTERVAL == 0:
            print(f"[Data] Warning: {self.samples_failed} samples failed ({error})")

    def _get_sample(self) -> Tuple[Optional[Dict], Optional["DataSource"]]:
        """Draw one sample from a weighted-random source.

        An exhausted source is restarted once. Any error raised while reading
        or restarting is counted as a failed sample instead of propagating.

        Returns:
            ``(sample, source)``, or ``(None, None)`` when nothing could be read.
        """
        if not self.datasets:
            return None, None

        idx = pyrandom.choices(range(len(self.datasets)), weights=self.weights, k=1)[0]
        ds, source = self.datasets[idx]

        try:
            try:
                sample = next(self.iterators[idx])
            except StopIteration:
                # Source exhausted: restart it. The restart sits inside the
                # outer try so a failure here cannot escape and end training.
                self.iterators[idx] = iter(ds)
                sample = next(self.iterators[idx])
        except StopIteration:
            # Still empty right after a restart: the source has no data.
            return None, None
        except Exception as e:
            self._record_failure(e)
            return None, None
        return sample, source

    def _tokenize(self, sample: Dict, source: "DataSource") -> List[int]:
        """Format + tokenize. Returns empty list on failure.

        Samples whose formatted text has fewer than 20 non-blank characters,
        or that produce fewer than 5 tokens, are rejected without being
        counted as failures.
        """
        try:
            formatter = FORMATTERS[source.formatter]
            text = formatter(sample)
            if not text or len(text.strip()) < 20:
                return []
            tokens = self.tokenizer.encode(text, add_special_tokens=False)
            if len(tokens) < 5:
                return []
            self.samples_processed += 1
            return tokens
        except Exception as e:
            # Best effort per sample, but keep the failure visible in the log.
            self._record_failure(e)
            return []

    def get_batch(self, batch_size: int, step: int) -> np.ndarray:
        """Get one batch of packed sequences.

        Switches to the datasets of the next stage when ``step`` crosses a
        stage boundary. If too few sequences were packed after
        ``batch_size * 200`` attempts, the packer is flushed and any remaining
        slots are filled with all-padding sequences.

        Returns:
            The stacked sequences, one row per sequence.
        """
        new_stage = get_current_stage(step, self.config)
        if new_stage != self.current_stage:
            print(f"\n[Data] ═══ Stage switch: {self.current_stage} → {new_stage} ═══")
            self.current_stage = new_stage
            self._load_stage_datasets()

        sequences = []
        # Bound the work so a run of unusable samples cannot hang the loop.
        max_attempts = batch_size * 200

        for _ in range(max_attempts):
            if len(sequences) >= batch_size:
                break

            sample, source = self._get_sample()
            if sample is None:
                continue

            tokens = self._tokenize(sample, source)
            if not tokens:
                continue

            result = self.packer.add_document(tokens)
            if result is not None:
                sequences.append(result)

        # Out of attempts: use the partial buffer first, then pure padding.
        while len(sequences) < batch_size:
            result = self.packer.flush()
            if result is not None:
                sequences.append(result)
            else:
                print("[Data] ⚠️ Had to insert padding sequence!")
                sequences.append(
                    np.full(self.config.max_seq_len, self.tokenizer.pad_token_id, dtype=np.int32)
                )

        return np.stack(sequences[:batch_size])

    def __iter__(self):
        """Yield batches of ``config.global_batch_size`` forever, counting steps from ``self.step``."""
        step = self.step
        while True:
            yield self.get_batch(self.config.global_batch_size, step)
            step += 1
|
|
|
|
def create_data_pipeline(config, tokenizer, start_step=0, start_position=0) -> Iterator:
    """Build a TAPDataPipeline and return its batch iterator.

    The arguments are forwarded unchanged to the pipeline constructor. The
    returned iterator never ends; it keeps yielding batches.
    """
    return iter(
        TAPDataPipeline(
            config,
            tokenizer,
            start_step=start_step,
            start_position=start_position,
        )
    )
|
|