"""Tokenize and shard pretraining corpora into train/val token shards.

Documents from every configured source are tokenized in batches and globally
interleaved, with per-source token targets derived from the mixture weights
in the data config.
"""

from __future__ import annotations

import argparse
import math
import random
import sys
from collections import deque
from pathlib import Path

from datasets import load_dataset
from tokenizers import Tokenizer

ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT / "src"))

from sllm.config import DataMixConfig, load_json, save_json
from sllm.data import TokenShardWriter
from sllm.utils import setup_logger

# Number of documents tokenized per tokenizer.encode_batch() call.
TOKENIZE_BATCH_SIZE = 128


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Tokenize and shard pretraining corpora.")
    parser.add_argument("--data-config", required=True, help="Path to data mixture JSON config.")
    parser.add_argument("--tokenizer-dir", required=True, help="Directory with tokenizer.json.")
    parser.add_argument("--output-dir", required=True, help="Root directory for train/val shards.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for dataset shuffling.")
    return parser


def load_tokenizer(tokenizer_dir: str | Path) -> tuple[Tokenizer, dict]:
    tokenizer_dir = Path(tokenizer_dir)
    tokenizer = Tokenizer.from_file(str(tokenizer_dir / "tokenizer.json"))
    metadata = load_json(tokenizer_dir / "tokenizer_meta.json")
    return tokenizer, metadata


def iter_source_rows(source, seed: int):
    """Open one configured dataset and return a row iterator over it."""
    dataset = load_dataset(
        path=source.path,
        name=source.config_name,
        data_dir=source.data_dir,
        split=source.split,
        revision=source.revision,
        streaming=source.streaming,
    )
    if source.streaming:
        # Streaming datasets are shuffled with a bounded buffer instead of a full pass.
        dataset = dataset.shuffle(seed=seed, buffer_size=source.shuffle_buffer)
    return iter(dataset)


def allocate_token_targets(data_config: DataMixConfig, total_tokens: int) -> dict[str, int]:
    """Split a token budget across sources by weight, using largest-remainder rounding."""
    weights = data_config.normalized_weights()
    raw_targets = {
        source.name: total_tokens * weights[source.name]
        for source in data_config.sources
    }
    base_targets = {name: int(math.floor(value)) for name, value in raw_targets.items()}
    # Hand the tokens lost to flooring back out, largest fractional part first,
    # so the per-source targets sum exactly to total_tokens.
    remainder = total_tokens - sum(base_targets.values())
    ranked = sorted(
        raw_targets.items(),
        key=lambda item: (item[1] - math.floor(item[1]), item[0]),
        reverse=True,
    )
    for index in range(remainder):
        name = ranked[index % len(ranked)][0]
        base_targets[name] += 1
    return base_targets
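

# A worked example of the largest-remainder rounding above, using hypothetical
# numbers (source names and weights are illustrative, not from any real config):
# total_tokens=7 with normalized weights {"web": 0.6, "code": 0.4} gives raw
# targets {"web": 4.2, "code": 2.8}; flooring yields {"web": 4, "code": 2},
# leaving remainder=1, and "code" receives the leftover token because its
# fractional part (0.8) is larger, so the final targets are {"web": 4, "code": 3}.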


def make_source_state(source, seed: int) -> dict:
    """Create the mutable bookkeeping record for one source."""
    return {
        "source": source,
        "iterator": iter_source_rows(source, seed),
        "documents_used": 0,
        "train_tokens_written": 0,
        "val_tokens_written": 0,
        "exhausted": False,
        "token_queue": deque(),
    }


def refill_token_queue(state: dict, tokenizer: Tokenizer) -> None:
    """Pull up to TOKENIZE_BATCH_SIZE non-empty texts and tokenize them in one batch."""
    if state["exhausted"]:
        return
    texts: list[str] = []
    while len(texts) < TOKENIZE_BATCH_SIZE:
        try:
            row = next(state["iterator"])
        except StopIteration:
            state["exhausted"] = True
            break
        text = row.get(state["source"].text_field or "")
        if not isinstance(text, str):
            continue
        text = text.strip()
        if not text:
            continue
        texts.append(text)
    if not texts:
        return
    encoded_batch = tokenizer.encode_batch(texts)
    for encoded in encoded_batch:
        token_ids = encoded.ids
        if token_ids:
            state["token_queue"].append(token_ids)


def next_valid_token_ids(state: dict, tokenizer: Tokenizer) -> list[int] | None:
    """Return the next tokenized document, refilling the queue as needed."""
    while True:
        if state["token_queue"]:
            state["documents_used"] += 1
            return state["token_queue"].popleft()
        if state["exhausted"]:
            return None
        refill_token_queue(state, tokenizer)


def choose_source_name(
    states: dict[str, dict],
    targets: dict[str, int],
    split: str,
    rng: random.Random,
) -> str | None:
    """Pick the eligible source that is furthest behind its token target."""
    candidates = []
    for name, state in states.items():
        # A source is only out of the running once both its iterator and its
        # queue of already-tokenized documents are empty; an exhausted iterator
        # alone may still leave usable documents in the queue.
        if state["exhausted"] and not state["token_queue"]:
            continue
        target = targets[name]
        if target <= 0:
            continue
        written = state[f"{split}_tokens_written"]
        if written >= target:
            continue
        progress = written / target
        candidates.append((progress, rng.random(), name))
    if not candidates:
        return None
    # Lowest progress wins; the random component breaks ties between sources
    # that are equally far along.
    candidates.sort(key=lambda item: (item[0], item[1]))
    return candidates[0][2]
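

# Illustration of the selection rule above, with made-up numbers: for
# split="train", targets {"web": 800, "code": 200}, and 400 / 50 tokens
# already written, progress is 400/800 = 0.50 for "web" and 50/200 = 0.25
# for "code", so "code" is selected next; the rng.random() tie-breaker only
# matters when two sources sit at exactly the same progress.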


def interleave_split(
    split: str,
    writer: TokenShardWriter,
    states: dict[str, dict],
    targets: dict[str, int],
    tokenizer: Tokenizer,
    logger,
    rng: random.Random,
) -> int:
    """Write one split by interleaving source documents until every target is met."""
    total_target = sum(targets.values())
    total_written = 0
    emitted_documents = 0
    logger.info(
        "Interleave start | split=%s total_target_tokens=%s strategy=weighted_progress_balancing",
        split,
        f"{total_target:,}",
    )
    while total_written < total_target:
        source_name = choose_source_name(states, targets, split, rng)
        if source_name is None:
            raise RuntimeError(
                f"Not enough data to fill split={split}. "
                "All available sources were exhausted before the target volume was reached."
            )
        state = states[source_name]
        token_ids = next_valid_token_ids(state, tokenizer)
        if token_ids is None:
            logger.warning("Source exhausted early | split=%s source=%s", split, source_name)
            continue
        # Trim the document so neither the source target nor the split target
        # is overshot; any surplus tail of the document is dropped.
        source_remaining = targets[source_name] - state[f"{split}_tokens_written"]
        split_remaining = total_target - total_written
        chunk = token_ids[: min(len(token_ids), source_remaining, split_remaining)]
        if not chunk:
            continue
        writer.add_tokens(chunk)
        state[f"{split}_tokens_written"] += len(chunk)
        total_written += len(chunk)
        emitted_documents += 1
        if emitted_documents % 10_000 == 0:
            logger.info(
                "Interleave progress | split=%s documents=%s total_tokens=%s/%s current_source=%s",
                split,
                f"{emitted_documents:,}",
                f"{total_written:,}",
                f"{total_target:,}",
                source_name,
            )
    logger.info(
        "Interleave done | split=%s documents=%s total_tokens=%s",
        split,
        f"{emitted_documents:,}",
        f"{total_written:,}",
    )
    return total_written


def main() -> None:
    args = build_parser().parse_args()
    data_config = DataMixConfig.from_dict(load_json(args.data_config))
    tokenizer, tokenizer_meta = load_tokenizer(args.tokenizer_dir)

    output_dir = Path(args.output_dir)
    train_dir = output_dir / "train"
    val_dir = output_dir / "val"
    train_dir.mkdir(parents=True, exist_ok=True)
    val_dir.mkdir(parents=True, exist_ok=True)

    logger, log_path = setup_logger("sllm.prepare_pretrain_data", output_dir, "prepare_pretrain_data")
    logger.info("Pretokenization started")
    logger.info("Log file: %s", log_path)
    logger.info(
        "Arguments | data_config=%s tokenizer_dir=%s output_dir=%s seed=%s",
        args.data_config,
        args.tokenizer_dir,
        args.output_dir,
        args.seed,
    )
    logger.info(
        "Tokenizer meta | vocab_size=%s special_tokens=%s",
        tokenizer_meta.get("vocab_size"),
        tokenizer_meta.get("special_tokens"),
    )
    logger.info("Mixing strategy | global interleaving with weighted progress balancing")
    logger.info("Tokenization strategy | encode_batch with batch_size=%s", TOKENIZE_BATCH_SIZE)

    weight_map = data_config.normalized_weights()
    train_targets = allocate_token_targets(data_config, data_config.train_tokens)
    val_targets = allocate_token_targets(data_config, data_config.val_tokens)

    dataset_summary: dict[str, dict] = {}
    states: dict[str, dict] = {}
    for index, source in enumerate(data_config.sources):
        # Offset the seed per source so streaming shuffles are decorrelated.
        states[source.name] = make_source_state(source, args.seed + index)
        logger.info(
            "Source registered | name=%s path=%s data_dir=%s split=%s text_field=%s weight=%.4f train_target=%s val_target=%s streaming=%s",
            source.name,
            source.path,
            source.data_dir,
            source.split,
            source.text_field,
            weight_map[source.name],
            f"{train_targets[source.name]:,}",
            f"{val_targets[source.name]:,}",
            source.streaming,
        )

    rng_val = random.Random(args.seed + 10_000)
    rng_train = random.Random(args.seed + 20_000)
    val_writer = TokenShardWriter(
        output_dir=val_dir,
        prefix="val",
        shard_size_tokens=max(1_000_000, min(data_config.shard_size_tokens, data_config.val_tokens)),
    )
    train_writer = TokenShardWriter(
        output_dir=train_dir,
        prefix="train",
        shard_size_tokens=data_config.shard_size_tokens,
    )

    # Both splits draw from the same shared iterators, so no document is
    # reused across val and train.
    total_val = interleave_split("val", val_writer, states, val_targets, tokenizer, logger, rng_val)
    total_train = interleave_split("train", train_writer, states, train_targets, tokenizer, logger, rng_train)
    train_shards = train_writer.finalize()
    val_shards = val_writer.finalize()

    for source in data_config.sources:
        state = states[source.name]
        dataset_summary[source.name] = {
            "path": source.path,
            "data_dir": source.data_dir,
            "split": source.split,
            "train_target_tokens": train_targets[source.name],
            "val_target_tokens": val_targets[source.name],
            "train_tokens_written": state["train_tokens_written"],
            "val_tokens_written": state["val_tokens_written"],
            "documents_used": state["documents_used"],
        }
        logger.info(
            "Source done | name=%s documents=%s train_tokens=%s/%s val_tokens=%s/%s",
            source.name,
            f"{state['documents_used']:,}",
            f"{state['train_tokens_written']:,}",
            f"{train_targets[source.name]:,}",
            f"{state['val_tokens_written']:,}",
            f"{val_targets[source.name]:,}",
        )

    save_json(
        output_dir / "dataset_summary.json",
        {
            "tokenizer": tokenizer_meta,
            "data_config": data_config.to_dict(),
            "mixing_strategy": "global_interleaving_weighted_progress_balancing",
            "train_target_tokens": data_config.train_tokens,
            "val_target_tokens": data_config.val_tokens,
            "train_tokens_written": total_train,
            "val_tokens_written": total_val,
            "train_shards": len(train_shards),
            "val_shards": len(val_shards),
            "sources": dataset_summary,
        },
    )
    logger.info(
        "Pretokenization finished | output_dir=%s total_train_tokens=%s total_val_tokens=%s train_shards=%s val_shards=%s",
        output_dir,
        f"{total_train:,}",
        f"{total_val:,}",
        len(train_shards),
        len(val_shards),
    )
    logger.info("Dataset summary saved | path=%s", output_dir / "dataset_summary.json")


if __name__ == "__main__":
    main()
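
# Example invocation, assuming this file lives at scripts/prepare_pretrain_data.py
# one level below the repo root (all paths below are illustrative, not shipped
# defaults):
#
#   python scripts/prepare_pretrain_data.py \
#       --data-config configs/data_mix.json \
#       --tokenizer-dir artifacts/tokenizer \
#       --output-dir data/pretrain
#
# This writes train/ and val/ token shards plus dataset_summary.json under
# data/pretrain.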