""" Flint-1.2B Data Pipeline — Thought-Action Pretraining (TAP) ============================================================ Streams real data from HuggingFace Hub, applies TAP formatting, tokenizes, packs into fixed-length sequences, and batches. CRITICAL: This is what actually feeds the model. If this breaks, the model trains on garbage (or nothing). """ import os import json import random as pyrandom import traceback from typing import Iterator, List, Dict, Any, Optional, Tuple from dataclasses import dataclass import numpy as np from datasets import load_dataset, IterableDataset from transformers import AutoTokenizer, PreTrainedTokenizer # ============================================================ # SPECIAL TOKENS # ============================================================ TAP_SPECIAL_TOKENS = [ "", "", "", "", "", "", "", "", "", "", "<|pad|>", ] def create_tokenizer(base_tokenizer: str = "HuggingFaceTB/cosmo2-tokenizer") -> PreTrainedTokenizer: """Load tokenizer and add TAP special tokens.""" tokenizer = AutoTokenizer.from_pretrained(base_tokenizer) special_tokens_dict = { "additional_special_tokens": TAP_SPECIAL_TOKENS[:-1], "pad_token": "<|pad|>", } num_added = tokenizer.add_special_tokens(special_tokens_dict) print(f"[Tokenizer] Base vocab: 49152 + {num_added} special = {len(tokenizer)} total") print(f"[Tokenizer] Pad token: '{tokenizer.pad_token}' (id={tokenizer.pad_token_id})") return tokenizer # ============================================================ # FORMATTERS # ============================================================ def format_raw_text(sample: Dict[str, Any]) -> str: """Raw text (FineWeb, DCLM, FineMath, code, etc.).""" text = sample.get("text", "") if isinstance(text, str): return text return "" def format_reasoning_trace(sample: Dict[str, Any]) -> str: """OpenThoughts-114k: wrap assistant response in tags.""" conversations = sample.get("conversations", []) if not conversations or not isinstance(conversations, list): return "" parts = [] for msg in conversations: if not isinstance(msg, dict): continue role = msg.get("from", "") content = msg.get("value", "") if not content: continue if role == "user": parts.append(content.strip()) parts.append("") elif role == "assistant": if "" in content: parts.append(content.strip()) else: paragraphs = content.strip().split("\n\n") if len(paragraphs) > 1: reasoning = "\n\n".join(paragraphs[:-1]) answer = paragraphs[-1] parts.append(f"\n{reasoning}\n\n\n{answer}") else: parts.append(f"\n{content.strip()}\n") return "\n".join(parts) def format_tool_call(sample: Dict[str, Any]) -> str: """SmolTalk apigen-80k: messages with tool calls.""" messages = sample.get("messages", []) if not messages or not isinstance(messages, list): return "" parts = [] for msg in messages: if not isinstance(msg, dict): continue role = msg.get("role", "") content = msg.get("content", "") if not content: continue if role == "system": parts.append(f"System: {content.strip()}\n") elif role == "user": parts.append(f"User: {content.strip()}\n") elif role == "assistant": if "" in content or '"name"' in content: parts.append("\nI need to call a tool to answer this.\n\n") parts.append(f"{content.strip()}\n") return "\n".join(parts) def format_agent_instruct(sample: Dict[str, Any]) -> str: """Orca-AgentInstruct: messages field is a JSON string.""" messages_raw = sample.get("messages", "") # Parse JSON string if isinstance(messages_raw, str): try: messages = json.loads(messages_raw) except (json.JSONDecodeError, TypeError): # If it's not JSON, just return as raw text return messages_raw if len(messages_raw) > 20 else "" elif isinstance(messages_raw, list): messages = messages_raw else: return "" if not messages: return "" parts = [] for msg in messages: if not isinstance(msg, dict): continue role = msg.get("role", "") content = msg.get("content", "") if not content or not isinstance(content, str): continue if role == "system" and content.strip(): parts.append(f"System: {content.strip()}\n") elif role == "user": parts.append(f"User: {content.strip()}\n") elif role == "assistant": if len(content) > 500: sentences = content.split(". ") if len(sentences) > 4: cut = max(len(sentences) // 4, 2) thinking = ". ".join(sentences[:cut]) + "." response = ". ".join(sentences[cut:]) parts.append(f"\n{thinking}\n\n\n{response}\n") else: parts.append(f"{content.strip()}\n") else: parts.append(f"{content.strip()}\n") return "\n".join(parts) FORMATTERS = { "raw": format_raw_text, "reasoning": format_reasoning_trace, "tool_call": format_tool_call, "agent": format_agent_instruct, } # ============================================================ # SEQUENCE PACKING # ============================================================ class SequencePacker: """Pack documents into fixed-length sequences. Zero padding waste.""" def __init__(self, max_length: int, pad_token_id: int, eos_token_id: int): self.max_length = max_length self.pad_token_id = pad_token_id self.eos_token_id = eos_token_id self.buffer = [] def add_document(self, token_ids: List[int]) -> Optional[np.ndarray]: """Add tokens. Returns packed sequence when buffer is full.""" if not token_ids: return None # Truncate long documents if len(token_ids) > self.max_length - 1: token_ids = token_ids[:self.max_length - 1] needed = len(token_ids) + 1 # +1 for EOS separator if len(self.buffer) + needed > self.max_length: result = self._emit() self.buffer = token_ids + [self.eos_token_id] return result else: self.buffer.extend(token_ids) self.buffer.append(self.eos_token_id) if len(self.buffer) >= self.max_length: return self._emit() return None def _emit(self) -> np.ndarray: seq = self.buffer[:self.max_length] if len(seq) < self.max_length: seq = seq + [self.pad_token_id] * (self.max_length - len(seq)) self.buffer = [] return np.array(seq, dtype=np.int32) def flush(self) -> Optional[np.ndarray]: if self.buffer: return self._emit() return None # ============================================================ # DATA SOURCES # ============================================================ @dataclass class DataSource: dataset_id: str config_name: Optional[str] weight: float text_column: str formatter: str split: str = "train" def get_stage_sources(stage: int) -> List[DataSource]: """Data sources per curriculum stage. Ordered by reliability.""" if stage == 1: return [ # High-reliability sources first DataSource("HuggingFaceFW/fineweb-edu", "sample-10BT", 0.50, "text", "raw"), DataSource("HuggingFaceTB/smollm-corpus", "cosmopedia-v2", 0.15, "text", "raw"), DataSource("open-thoughts/OpenThoughts-114k", None, 0.15, "conversations", "reasoning"), DataSource("HuggingFaceTB/smoltalk", "apigen-80k", 0.08, "messages", "tool_call"), DataSource("HuggingFaceTB/smollm-corpus", "python-edu", 0.05, "text", "raw"), # Lower priority (may be slow to stream) DataSource("mlfoundations/dclm-baseline-1.0-parquet", None, 0.07, "text", "raw"), ] elif stage == 2: return [ DataSource("HuggingFaceFW/fineweb-edu", "sample-10BT", 0.25, "text", "raw"), DataSource("open-thoughts/OpenThoughts-114k", None, 0.20, "conversations", "reasoning"), DataSource("HuggingFaceTB/finemath", "finemath-3plus", 0.15, "text", "raw"), DataSource("HuggingFaceTB/smoltalk", "apigen-80k", 0.10, "messages", "tool_call"), DataSource("HuggingFaceTB/smollm-corpus", "cosmopedia-v2", 0.10, "text", "raw"), DataSource("open-web-math/open-web-math", None, 0.08, "text", "raw"), DataSource("HuggingFaceTB/smollm-corpus", "python-edu", 0.07, "text", "raw"), DataSource("microsoft/orca-agentinstruct-1M-v1", "creative_content", 0.05, "messages", "agent"), ] else: # stage 3 (annealing) return [ DataSource("HuggingFaceTB/finemath", "finemath-4plus", 0.25, "text", "raw"), DataSource("open-thoughts/OpenThoughts-114k", None, 0.25, "conversations", "reasoning"), DataSource("HuggingFaceFW/fineweb-edu", "sample-10BT", 0.20, "text", "raw"), DataSource("HuggingFaceTB/smoltalk", "apigen-80k", 0.12, "messages", "tool_call"), DataSource("HuggingFaceTB/smollm-corpus", "python-edu", 0.08, "text", "raw"), DataSource("microsoft/orca-agentinstruct-1M-v1", "analytical_reasoning", 0.10, "messages", "agent"), ] # ============================================================ # PIPELINE # ============================================================ def get_current_stage(step: int, config) -> int: """Determine curriculum stage from step number.""" stage1_end = int(config.total_steps * config.stage1_end_frac) stage2_end = int(config.total_steps * config.stage2_end_frac) if step < stage1_end: return 1 elif step < stage2_end: return 2 return 3 class TAPDataPipeline: """ Streams real data from HF Hub, formats with TAP tags, packs sequences. Robust to individual dataset failures — skips broken sources and redistributes weight to working ones. """ def __init__(self, config, tokenizer, start_step=0, start_position=0): self.config = config self.tokenizer = tokenizer self.current_stage = get_current_stage(start_step, config) self.position = start_position self.step = start_step self.samples_processed = 0 self.samples_failed = 0 self.packer = SequencePacker( max_length=config.max_seq_len, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id else 0, ) self._load_stage_datasets() def _load_stage_datasets(self): """Load datasets for current stage. Skip any that fail.""" sources = get_stage_sources(self.current_stage) self.datasets = [] self.weights = [] print(f"[Data] Loading stage {self.current_stage} datasets...") for source in sources: try: kwargs = {"path": source.dataset_id, "split": source.split, "streaming": True} if source.config_name: kwargs["name"] = source.config_name ds = load_dataset(**kwargs) self.datasets.append((ds, source)) self.weights.append(source.weight) print(f" ✓ {source.dataset_id}/{source.config_name or ''} ({source.weight:.0%}) [{source.formatter}]") except Exception as e: print(f" ✗ {source.dataset_id}/{source.config_name or ''}: {e}") if not self.datasets: raise RuntimeError("[Data] FATAL: No datasets loaded! Check network/auth.") # Normalize weights total_w = sum(self.weights) self.weights = [w / total_w for w in self.weights] # Create iterators self.iterators = [iter(ds) for ds, _ in self.datasets] print(f"[Data] ✓ {len(self.datasets)} sources ready for stage {self.current_stage}\n") def _get_sample(self) -> Tuple[Optional[Dict], Optional[DataSource]]: """Get one sample from weighted random source. Handles StopIteration.""" if not self.datasets: return None, None idx = pyrandom.choices(range(len(self.datasets)), weights=self.weights, k=1)[0] _, source = self.datasets[idx] try: sample = next(self.iterators[idx]) return sample, source except StopIteration: # Restart iterator (epoch boundary) ds, _ = self.datasets[idx] self.iterators[idx] = iter(ds) try: sample = next(self.iterators[idx]) return sample, source except StopIteration: return None, None except Exception as e: # Network glitch, rate limit, etc. — skip this sample self.samples_failed += 1 if self.samples_failed % 100 == 0: print(f"[Data] Warning: {self.samples_failed} samples failed ({e})") return None, None def _tokenize(self, sample: Dict, source: DataSource) -> List[int]: """Format + tokenize. Returns empty list on failure.""" try: formatter = FORMATTERS[source.formatter] text = formatter(sample) if not text or len(text.strip()) < 20: return [] tokens = self.tokenizer.encode(text, add_special_tokens=False) if len(tokens) < 5: return [] self.samples_processed += 1 return tokens except Exception: self.samples_failed += 1 return [] def get_batch(self, batch_size: int, step: int) -> np.ndarray: """Get one batch of packed sequences.""" # Check stage switch new_stage = get_current_stage(step, self.config) if new_stage != self.current_stage: print(f"\n[Data] ═══ Stage switch: {self.current_stage} → {new_stage} ═══") self.current_stage = new_stage self._load_stage_datasets() sequences = [] max_attempts = batch_size * 200 for _ in range(max_attempts): if len(sequences) >= batch_size: break sample, source = self._get_sample() if sample is None: continue tokens = self._tokenize(sample, source) if not tokens: continue result = self.packer.add_document(tokens) if result is not None: sequences.append(result) # Flush if needed while len(sequences) < batch_size: result = self.packer.flush() if result is not None: sequences.append(result) else: # Absolute last resort — should never happen with working data print("[Data] ⚠️ Had to insert padding sequence!") sequences.append( np.full(self.config.max_seq_len, self.tokenizer.pad_token_id, dtype=np.int32) ) return np.stack(sequences[:batch_size]) def __iter__(self): step = self.step while True: yield self.get_batch(self.config.global_batch_size, step) step += 1 def create_data_pipeline(config, tokenizer, start_step=0, start_position=0) -> Iterator: """Create the streaming data pipeline. Returns infinite batch iterator.""" pipeline = TAPDataPipeline(config, tokenizer, start_step, start_position) return iter(pipeline)