# QED-75M_artifacts/scripts/prepare_sft_data.py
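"""Prepare fixed-length supervised fine-tuning (SFT) tensors.

Reads an SFT data config (JSON), converts each dataset row to chat messages,
tokenizes them with a ``tokenizers.Tokenizer``, masks non-assistant tokens in
the labels, pads or truncates every example to ``--seq-len`` tokens, and writes
train/val shards plus a ``dataset_summary.json`` via ``SFTShardWriter``.
"""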
from __future__ import annotations
import argparse
import sys
from pathlib import Path
from datasets import load_dataset
from tokenizers import Tokenizer
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT / "src"))
from sllm.config import load_json, save_json
from sllm.data import SFTShardWriter
from sllm.utils import setup_logger
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Prepare fixed-length SFT tensors.")
parser.add_argument("--config", required=True, help="Path to SFT data JSON config.")
parser.add_argument("--tokenizer-dir", required=True, help="Directory with tokenizer.json and metadata.")
parser.add_argument("--output-dir", required=True, help="Directory to store processed SFT tensors.")
parser.add_argument("--seq-len", type=int, default=2_048, help="Packed example length.")
parser.add_argument("--seed", type=int, default=42, help="Random seed for dataset shuffling.")
return parser
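# The tokenizer directory is expected to contain tokenizer.json plus a
# tokenizer_meta.json whose "special_tokens" entry provides the
# bos_token_id / eos_token_id / pad_token_id consumed below in main().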
def load_tokenizer(tokenizer_dir: str | Path) -> tuple[Tokenizer, dict]:
tokenizer_dir = Path(tokenizer_dir)
tokenizer = Tokenizer.from_file(str(tokenizer_dir / "tokenizer.json"))
metadata = load_json(tokenizer_dir / "tokenizer_meta.json")
return tokenizer, metadata
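# row_to_messages() supports three source layouts, selected by config["format"]:
#   - "messages":        a list of {"role", "content"} dicts (content may also be
#                        a list of {"text": ...} parts, which are joined).
#   - "prompt_response": flat prompt/response columns, plus an optional
#                        config-level system_prompt.
#   - "alpaca":          instruction / input / output columns.
# Illustrative config (a sketch only; the dataset id and values below are
# placeholders, not taken from a shipped config file):
# {
#   "path": "your-org/your-sft-dataset",
#   "split": "train",
#   "format": "messages",
#   "messages_field": "messages",
#   "shuffle": true,
#   "val_examples": 1000,
#   "max_train_examples": null
# }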
def row_to_messages(row: dict, config: dict) -> list[dict[str, str]]:
fmt = config.get("format", "messages")
if fmt == "messages":
messages = row.get(config.get("messages_field", "messages"))
if not isinstance(messages, list):
            raise ValueError("Messages list not found in the SFT dataset.")
normalized = []
for message in messages:
if not isinstance(message, dict):
continue
role = message.get("role")
content = message.get("content")
if isinstance(content, list):
parts = [item.get("text", "") for item in content if isinstance(item, dict)]
content = "\n".join(part for part in parts if part)
if isinstance(role, str) and isinstance(content, str) and content.strip():
normalized.append({"role": role, "content": content.strip()})
return normalized
if fmt == "prompt_response":
prompt = row.get(config.get("prompt_field", "prompt"))
response = row.get(config.get("response_field", "response"))
if not isinstance(prompt, str) or not isinstance(response, str):
            raise ValueError("prompt/response fields not found in the SFT dataset.")
system_prompt = config.get("system_prompt")
messages = []
if isinstance(system_prompt, str) and system_prompt.strip():
messages.append({"role": "system", "content": system_prompt.strip()})
messages.append({"role": "user", "content": prompt.strip()})
messages.append({"role": "assistant", "content": response.strip()})
return messages
if fmt == "alpaca":
instruction = row.get(config.get("instruction_field", "instruction"))
input_text = row.get(config.get("input_field", "input"), "")
output_text = row.get(config.get("output_field", "output"))
if not isinstance(instruction, str) or not isinstance(output_text, str):
            raise ValueError("instruction/output fields not found in the Alpaca-style dataset.")
prompt = instruction.strip()
if isinstance(input_text, str) and input_text.strip():
prompt = f"{prompt}\n\n{input_text.strip()}"
return [
{"role": "user", "content": prompt},
{"role": "assistant", "content": output_text.strip()},
]
raise ValueError(f"Unsupported SFT format: {fmt}")
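# tokenize_messages() renders each turn as "<|role|>\n{content}\n". Labels mirror
# input_ids, except that BOS and all non-assistant tokens are set to -100, so the
# loss is only computed on assistant completions (EOS keeps its own id as a label).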
def tokenize_messages(
tokenizer: Tokenizer,
messages: list[dict[str, str]],
bos_id: int,
eos_id: int,
) -> tuple[list[int], list[int]]:
input_ids = [bos_id]
labels = [-100]
for message in messages:
role = message["role"].strip().lower()
content = message["content"].strip()
if not content:
continue
text = f"<|{role}|>\n{content}\n"
piece = tokenizer.encode(text, add_special_tokens=False).ids
if not piece:
continue
input_ids.extend(piece)
if role == "assistant":
labels.extend(piece)
else:
labels.extend([-100] * len(piece))
input_ids.append(eos_id)
labels.append(eos_id)
return input_ids, labels
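# pad_or_truncate() clips both sequences to seq_len and right-pads shorter
# examples with pad_id (inputs) and -100 (labels), so padded positions are
# ignored by the loss.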
def pad_or_truncate(
input_ids: list[int],
labels: list[int],
seq_len: int,
pad_id: int,
) -> tuple[list[int], list[int]]:
input_ids = input_ids[:seq_len]
labels = labels[:seq_len]
if len(input_ids) < seq_len:
pad_length = seq_len - len(input_ids)
input_ids = input_ids + [pad_id] * pad_length
labels = labels + [-100] * pad_length
return input_ids, labels
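# Driver: load config and tokenizer, iterate over the HF dataset (optionally
# streaming), route the first `val_examples` rows to the validation shards and
# the rest to training shards, then write dataset_summary.json.
# Illustrative invocation (paths are placeholders, not repository-specific):
#   python scripts/prepare_sft_data.py \
#       --config configs/sft_data.json \
#       --tokenizer-dir artifacts/tokenizer \
#       --output-dir data/sft \
#       --seq-len 2048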
def main() -> None:
args = build_parser().parse_args()
config = load_json(args.config)
tokenizer, tokenizer_meta = load_tokenizer(args.tokenizer_dir)
specials = tokenizer_meta["special_tokens"]
bos_id = int(specials["bos_token_id"])
eos_id = int(specials["eos_token_id"])
pad_id = int(specials["pad_token_id"])
dataset = load_dataset(
path=config["path"],
name=config.get("config_name"),
split=config.get("split", "train"),
revision=config.get("revision"),
streaming=bool(config.get("streaming", False)),
)
if config.get("shuffle", True):
dataset = dataset.shuffle(seed=args.seed)
val_examples = int(config.get("val_examples", 1_000))
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
logger, log_path = setup_logger("sllm.prepare_sft_data", output_dir, "prepare_sft_data")
logger.info("SFT data preparation started")
logger.info("Log file: %s", log_path)
logger.info(
"Arguments | config=%s tokenizer_dir=%s output_dir=%s seq_len=%s seed=%s",
args.config,
args.tokenizer_dir,
args.output_dir,
args.seq_len,
args.seed,
)
logger.info(
"SFT source config | path=%s config_name=%s split=%s format=%s streaming=%s val_examples=%s max_train_examples=%s",
config.get("path"),
config.get("config_name"),
config.get("split", "train"),
config.get("format", "messages"),
bool(config.get("streaming", False)),
val_examples,
config.get("max_train_examples"),
)
train_writer = SFTShardWriter(output_dir, prefix="train", seq_len=args.seq_len)
val_writer = SFTShardWriter(output_dir, prefix="val", seq_len=args.seq_len)
train_count = 0
val_count = 0
max_train_examples = config.get("max_train_examples")
for row in dataset:
messages = row_to_messages(row, config)
if not messages:
continue
input_ids, labels = tokenize_messages(tokenizer, messages, bos_id=bos_id, eos_id=eos_id)
input_ids, labels = pad_or_truncate(input_ids, labels, args.seq_len, pad_id=pad_id)
if val_count < val_examples:
val_writer.add_example(input_ids, labels)
val_count += 1
else:
train_writer.add_example(input_ids, labels)
train_count += 1
total_examples = train_count + val_count
if total_examples % 5_000 == 0:
logger.info(
"SFT progress | processed=%s train_examples=%s val_examples=%s",
f"{total_examples:,}",
f"{train_count:,}",
f"{val_count:,}",
)
if max_train_examples is not None and train_count >= int(max_train_examples):
break
train_metadata = train_writer.finalize()
val_metadata = val_writer.finalize()
save_json(
output_dir / "dataset_summary.json",
{
"config": config,
"tokenizer_meta": tokenizer_meta,
"train": train_metadata,
"val": val_metadata,
},
)
logger.info("SFT dataset saved | output_dir=%s", output_dir)
logger.info("SFT summary | train_examples=%s val_examples=%s", f"{train_count:,}", f"{val_count:,}")
logger.info("SFT metadata saved | path=%s", output_dir / "dataset_summary.json")
if __name__ == "__main__":
main()