# 70m-adamwgrok / sft_data.py
# Uploaded by Asilarknes with huggingface_hub (commit e0bb40a).
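"""Build a high-quality Russian SFT dataset from IlyaGusev/saiga_scored.

Filters: Russian, single-turn (a user message followed by a bot reply),
not bad-by-regex, opus_score >= MINSCORE (a missing score counts as 0).
Examples longer than MAXLEN tokens (including <bos>/<eos>) are dropped.

Format (consistent with inference):
    sequence = <bos> + prompt + answer
    prompt   = "Вопрос: {user}\\nОтвет:"
    answer   = " {bot}" + <eos>

Labels equal the token ids with <bos> and the prompt replaced by -100, so the
loss is computed only on the answer and <eos>. No padding is added here
(presumably the training collator pads); padded positions should likewise be
labelled -100.

Saves /root/sft.pt = {"train": [...], "val": [...]} of dicts {ids, labels};
the validation split is the first 800 examples after a seeded shuffle.
"""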
import torch
from datasets import load_dataset
from tokenizers import Tokenizer
MINSCORE = 8  # minimum opus_score for a row to be kept
MAXLEN = 512  # examples longer than this many tokens (incl. <bos>/<eos>) are dropped
TOK = "/root/ru_tok.json"
OUT = "/root/sft.pt"
tok = Tokenizer.from_file(TOK)
BOS = tok.token_to_id("<bos>")
EOS = tok.token_to_id("<eos>")
# token_to_id returns None for tokens that are not in the vocabulary. A None
# id would be silently written into every saved example, so fail early.
if BOS is None or EOS is None:
    raise ValueError(
        f"tokenizer {TOK} lacks required special tokens: "
        f"<bos>={BOS}, <eos>={EOS}"
    )
ds = load_dataset("IlyaGusev/saiga_scored", split="train")
def keep(r):
    """Decide whether dataset row ``r`` is a usable single-turn Russian example.

    A row passes only if all of these hold, checked in this order:
    it is labelled Russian, it is not flagged by the regex filter, its
    opus_score (a missing/falsy score counts as 0) is at least MINSCORE,
    it has at most one turn, it has at least two messages, and the first two
    messages are a "user" message followed by a "bot" reply.
    """
    if r["language"] != "Russian" or r["is_bad_by_regex"]:
        return False
    score_ok = (r["opus_score"] or 0) >= MINSCORE
    if not (score_ok and r["turns"] <= 1):
        return False
    msgs = r["messages"]
    if len(msgs) < 2:
        return False
    return msgs[0]["role"] == "user" and msgs[1]["role"] == "bot"
def _encode_example(user, bot):
    """Tokenize one (question, answer) pair into an SFT training example.

    ``ids`` is <bos> + prompt tokens + answer tokens + <eos>. ``labels`` is
    aligned with ``ids`` but has <bos> and the prompt replaced by -100, so only
    the answer and <eos> are learned. Returns None when ``ids`` would exceed
    MAXLEN tokens.
    """
    prompt_ids = tok.encode(f"Вопрос: {user}\nОтвет:").ids
    answer_ids = tok.encode(" " + bot).ids + [EOS]
    seq = [BOS, *prompt_ids, *answer_ids]
    if len(seq) > MAXLEN:
        return None
    # Mask <bos> plus every prompt token; keep answer + <eos> as targets.
    masked = [-100] * (len(prompt_ids) + 1)
    target = masked + answer_ids
    assert len(seq) == len(target)
    return {"ids": seq, "labels": target}


# Collect examples from rows accepted by keep(); rows whose user or bot text
# is empty after stripping, or whose encoding is too long, are skipped.
examples = []
for row in filter(keep, ds):
    user = row["messages"][0]["content"].strip()
    bot = row["messages"][1]["content"].strip()
    if user and bot:
        ex = _encode_example(user, bot)
        if ex is not None:
            examples.append(ex)
import random

# Shuffle `examples`, split off a validation set, save both splits to OUT,
# and print dataset statistics.
nval = 800  # number of held-out validation examples
# Fail before writing anything if the split would leave train empty. This also
# covers the zero-example case. Without this guard, zero examples would save an
# empty file and then crash with ZeroDivisionError when computing avg_len.
if len(examples) <= nval:
    raise RuntimeError(
        f"only {len(examples)} examples passed filtering; need more than "
        f"nval={nval} to leave a non-empty train split"
    )
# A private fixed-seed RNG gives the same reproducible permutation as
# random.seed(0) + random.shuffle, without mutating the global RNG state.
random.Random(0).shuffle(examples)
val = examples[:nval]
train = examples[nval:]
torch.save({"train": train, "val": val}, OUT)

ntok = sum(len(e["ids"]) for e in examples)
# Positions that carry loss; masked positions are labelled -100.
ans_tok = sum(1 for e in examples for lab in e["labels"] if lab != -100)
print(f"examples={len(examples)} train={len(train)} val={len(val)} "
      f"total_tokens={ntok/1e6:.2f}M answer_tokens={ans_tok/1e6:.2f}M "
      f"avg_len={ntok/len(examples):.0f}")
print("saved", OUT)