from __future__ import annotations

import argparse
import sys
from pathlib import Path

from datasets import load_dataset
from tokenizers import Tokenizer

ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT / "src"))

from sllm.config import load_json, save_json
from sllm.data import SFTShardWriter
from sllm.utils import setup_logger

def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Prepare fixed-length SFT tensors.")
    parser.add_argument("--config", required=True, help="Path to SFT data JSON config.")
    parser.add_argument("--tokenizer-dir", required=True, help="Directory with tokenizer.json and metadata.")
    parser.add_argument("--output-dir", required=True, help="Directory to store processed SFT tensors.")
    parser.add_argument("--seq-len", type=int, default=2_048, help="Packed example length.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for dataset shuffling.")
    return parser

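# Example invocation (illustrative; the script and config paths below are
# placeholders for whatever layout this repo actually uses):
#   python scripts/prepare_sft_data.py \
#       --config configs/sft_data.json \
#       --tokenizer-dir artifacts/tokenizer \
#       --output-dir data/sft \
#       --seq-len 2048
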
def load_tokenizer(tokenizer_dir: str | Path) -> tuple[Tokenizer, dict]:
    tokenizer_dir = Path(tokenizer_dir)
    tokenizer = Tokenizer.from_file(str(tokenizer_dir / "tokenizer.json"))
    metadata = load_json(tokenizer_dir / "tokenizer_meta.json")
    return tokenizer, metadata

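# tokenizer_meta.json is expected to carry the special-token ids that main() reads;
# a minimal sketch (the concrete id values here are made up for illustration):
#   {"special_tokens": {"bos_token_id": 1, "eos_token_id": 2, "pad_token_id": 0}}
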
def row_to_messages(row: dict, config: dict) -> list[dict[str, str]]:
    fmt = config.get("format", "messages")
    if fmt == "messages":
        messages = row.get(config.get("messages_field", "messages"))
        if not isinstance(messages, list):
            raise ValueError("No messages list found in the SFT dataset.")
        normalized = []
        for message in messages:
            if not isinstance(message, dict):
                continue
            role = message.get("role")
            content = message.get("content")
            if isinstance(content, list):
                parts = [item.get("text", "") for item in content if isinstance(item, dict)]
                content = "\n".join(part for part in parts if part)
            if isinstance(role, str) and isinstance(content, str) and content.strip():
                normalized.append({"role": role, "content": content.strip()})
        return normalized

    if fmt == "prompt_response":
        prompt = row.get(config.get("prompt_field", "prompt"))
        response = row.get(config.get("response_field", "response"))
        if not isinstance(prompt, str) or not isinstance(response, str):
            raise ValueError("Missing prompt/response fields in the SFT dataset.")
        system_prompt = config.get("system_prompt")
        messages = []
        if isinstance(system_prompt, str) and system_prompt.strip():
            messages.append({"role": "system", "content": system_prompt.strip()})
        messages.append({"role": "user", "content": prompt.strip()})
        messages.append({"role": "assistant", "content": response.strip()})
        return messages

    if fmt == "alpaca":
        instruction = row.get(config.get("instruction_field", "instruction"))
        input_text = row.get(config.get("input_field", "input"), "")
        output_text = row.get(config.get("output_field", "output"))
        if not isinstance(instruction, str) or not isinstance(output_text, str):
            raise ValueError("Missing instruction/output fields in the Alpaca-style dataset.")
        prompt = instruction.strip()
        if isinstance(input_text, str) and input_text.strip():
            prompt = f"{prompt}\n\n{input_text.strip()}"
        return [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": output_text.strip()},
        ]

    raise ValueError(f"Unsupported SFT format: {fmt}")

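# Illustrative config snippets for the three supported formats (dataset paths and
# field values below are examples, not requirements of this script):
#   {"path": "HuggingFaceH4/ultrachat_200k", "format": "messages",
#    "messages_field": "messages"}
#   {"path": "my-org/prompt-response-pairs", "format": "prompt_response",
#    "prompt_field": "prompt", "response_field": "response",
#    "system_prompt": "You are a helpful assistant."}
#   {"path": "tatsu-lab/alpaca", "format": "alpaca",
#    "instruction_field": "instruction", "input_field": "input",
#    "output_field": "output"}
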
def tokenize_messages(
    tokenizer: Tokenizer,
    messages: list[dict[str, str]],
    bos_id: int,
    eos_id: int,
) -> tuple[list[int], list[int]]:
    input_ids = [bos_id]
    labels = [-100]

    for message in messages:
        role = message["role"].strip().lower()
        content = message["content"].strip()
        if not content:
            continue
        text = f"<|{role}|>\n{content}\n"
        piece = tokenizer.encode(text, add_special_tokens=False).ids
        if not piece:
            continue
        input_ids.extend(piece)
        if role == "assistant":
            labels.extend(piece)
        else:
            labels.extend([-100] * len(piece))

    input_ids.append(eos_id)
    labels.append(eos_id)
    return input_ids, labels

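# Serialization sketch for a single exchange (token boundaries are illustrative):
#   text   : <bos><|user|>\nHi\n<|assistant|>\nHello!\n<eos>
#   labels : -100 for the BOS and every non-assistant token, the real ids for the
#            assistant turn, and eos_id for the trailing EOS, so only assistant
#            tokens (plus the final EOS) contribute to the loss.
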
def pad_or_truncate(
    input_ids: list[int],
    labels: list[int],
    seq_len: int,
    pad_id: int,
) -> tuple[list[int], list[int]]:
    input_ids = input_ids[:seq_len]
    labels = labels[:seq_len]
    if len(input_ids) < seq_len:
        pad_length = seq_len - len(input_ids)
        input_ids = input_ids + [pad_id] * pad_length
        labels = labels + [-100] * pad_length
    return input_ids, labels

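# Worked example with seq_len=6 and pad_id=0 (values are illustrative):
#   input_ids [5, 7, 9]    -> [5, 7, 9, 0, 0, 0]
#   labels    [-100, 7, 9] -> [-100, 7, 9, -100, -100, -100]
# Padding positions are masked with -100 so they never contribute to the loss.
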
def main() -> None:
    args = build_parser().parse_args()
    config = load_json(args.config)
    tokenizer, tokenizer_meta = load_tokenizer(args.tokenizer_dir)
    specials = tokenizer_meta["special_tokens"]
    bos_id = int(specials["bos_token_id"])
    eos_id = int(specials["eos_token_id"])
    pad_id = int(specials["pad_token_id"])

    dataset = load_dataset(
        path=config["path"],
        name=config.get("config_name"),
        split=config.get("split", "train"),
        revision=config.get("revision"),
        streaming=bool(config.get("streaming", False)),
    )
    if config.get("shuffle", True):
        dataset = dataset.shuffle(seed=args.seed)

    val_examples = int(config.get("val_examples", 1_000))
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    logger, log_path = setup_logger("sllm.prepare_sft_data", output_dir, "prepare_sft_data")
    logger.info("SFT data preparation started")
    logger.info("Log file: %s", log_path)
    logger.info(
        "Arguments | config=%s tokenizer_dir=%s output_dir=%s seq_len=%s seed=%s",
        args.config,
        args.tokenizer_dir,
        args.output_dir,
        args.seq_len,
        args.seed,
    )
    logger.info(
        "SFT source config | path=%s config_name=%s split=%s format=%s streaming=%s val_examples=%s max_train_examples=%s",
        config.get("path"),
        config.get("config_name"),
        config.get("split", "train"),
        config.get("format", "messages"),
        bool(config.get("streaming", False)),
        val_examples,
        config.get("max_train_examples"),
    )
    train_writer = SFTShardWriter(output_dir, prefix="train", seq_len=args.seq_len)
    val_writer = SFTShardWriter(output_dir, prefix="val", seq_len=args.seq_len)

    train_count = 0
    val_count = 0
    max_train_examples = config.get("max_train_examples")

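    # The first `val_examples` rows become the validation split; everything after
    # that goes to the train shards until `max_train_examples` (if set) is reached.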
    for row in dataset:
        messages = row_to_messages(row, config)
        if not messages:
            continue
        input_ids, labels = tokenize_messages(tokenizer, messages, bos_id=bos_id, eos_id=eos_id)
        input_ids, labels = pad_or_truncate(input_ids, labels, args.seq_len, pad_id=pad_id)

        if val_count < val_examples:
            val_writer.add_example(input_ids, labels)
            val_count += 1
        else:
            train_writer.add_example(input_ids, labels)
            train_count += 1

        total_examples = train_count + val_count
        if total_examples % 5_000 == 0:
            logger.info(
                "SFT progress | processed=%s train_examples=%s val_examples=%s",
                f"{total_examples:,}",
                f"{train_count:,}",
                f"{val_count:,}",
            )

        if max_train_examples is not None and train_count >= int(max_train_examples):
            break

    train_metadata = train_writer.finalize()
    val_metadata = val_writer.finalize()
    save_json(
        output_dir / "dataset_summary.json",
        {
            "config": config,
            "tokenizer_meta": tokenizer_meta,
            "train": train_metadata,
            "val": val_metadata,
        },
    )
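    # dataset_summary.json therefore has the shape sketched below; the contents of
    # "train" and "val" depend on whatever SFTShardWriter.finalize() returns:
    #   {"config": {...}, "tokenizer_meta": {...}, "train": {...}, "val": {...}}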
| logger.info("SFT dataset saved | output_dir=%s", output_dir) |
| logger.info("SFT summary | train_examples=%s val_examples=%s", f"{train_count:,}", f"{val_count:,}") |
| logger.info("SFT metadata saved | path=%s", output_dir / "dataset_summary.json") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|