"""Build a high-quality Russian SFT dataset from IlyaGusev/saiga_scored.

Filters: Russian, single-turn, not bad-by-regex, opus_score >= MINSCORE.

Format (consistent with inference):
    prompt = "Вопрос: {user}\nОтвет:"
    answer = " {bot}" followed by the EOS token

Loss is computed only on the answer: the BOS token and the prompt get label
-100. No padding is added here; presumably the training collate pads labels
with -100 as well (verify in the trainer).

Saves /root/sft.pt = {"train": [...], "val": [...]} of dicts {ids, labels}.
"""
import random

import torch
from datasets import load_dataset
from tokenizers import Tokenizer

MINSCORE = 8  # minimum opus_score for a row to be kept
MAXLEN = 512  # examples longer than this (in tokens) are dropped, not truncated
NVAL = 800  # number of held-out validation examples
SEED = 0  # seed for the train/val shuffle
IGNORE = -100  # label value for positions excluded from the loss
TOK = "/root/ru_tok.json"
OUT = "/root/sft.pt"
# NOTE(review): both special-token literals are empty strings, which is almost
# certainly unintended (angle-bracket names such as "<s>" / "</s>" may have been
# lost). _special_id() fails loudly if they do not resolve -- confirm the names
# against the tokenizer's special tokens.
BOS_TOKEN = ""
EOS_TOKEN = ""


def _special_id(tok, token):
    """Return the vocab id of *token*, raising instead of returning None."""
    idx = tok.token_to_id(token)
    if idx is None:
        raise ValueError(f"special token {token!r} not found in tokenizer {TOK}")
    return idx


def keep(r):
    """Return True if dataset row *r* passes the quality/format filters.

    A row is kept when it is Russian, not flagged by the regex filter, has an
    opus_score (None counts as 0) of at least MINSCORE, has at most one turn,
    and its first two messages are a user message followed by a bot message.
    """
    return (r["language"] == "Russian"
            and not r["is_bad_by_regex"]
            and (r["opus_score"] or 0) >= MINSCORE
            and r["turns"] <= 1
            and len(r["messages"]) >= 2
            and r["messages"][0]["role"] == "user"
            and r["messages"][1]["role"] == "bot")


def build_example(user, bot, tok, bos, eos, maxlen=MAXLEN):
    """Tokenize one (user, bot) pair; return {"ids", "labels"} or None if too long.

    ids    = [bos] + prompt_ids + answer_ids + [eos]
    labels = [IGNORE] * (1 + len(prompt_ids)) + answer_ids + [eos]
    so the model learns the answer and the EOS token, never the prompt.
    """
    prompt = f"Вопрос: {user}\nОтвет:"
    # BOS/EOS are inserted by hand, so a post-processor template in the
    # tokenizer file must not add special tokens a second time.
    p_ids = tok.encode(prompt, add_special_tokens=False).ids
    a_ids = tok.encode(" " + bot, add_special_tokens=False).ids + [eos]
    ids = [bos] + p_ids + a_ids
    if len(ids) > maxlen:
        return None
    labels = [IGNORE] * (1 + len(p_ids)) + a_ids  # mask bos+prompt
    return {"ids": ids, "labels": labels}


def main():
    """Filter saiga_scored, tokenize, shuffle, split into train/val and save."""
    tok = Tokenizer.from_file(TOK)
    bos = _special_id(tok, BOS_TOKEN)
    eos = _special_id(tok, EOS_TOKEN)

    ds = load_dataset("IlyaGusev/saiga_scored", split="train")

    examples = []
    for r in ds:
        if not keep(r):
            continue
        # `or ""` tolerates a None content field instead of crashing on .strip()
        user = (r["messages"][0]["content"] or "").strip()
        bot = (r["messages"][1]["content"] or "").strip()
        if not user or not bot:
            continue
        ex = build_example(user, bot, tok, bos, eos)
        if ex is not None:
            examples.append(ex)

    if len(examples) <= NVAL:
        raise RuntimeError(
            f"only {len(examples)} examples passed the filters; need more than "
            f"{NVAL} to hold out a validation set and keep a non-empty train set")

    # A dedicated Random(SEED) yields the same permutation as random.seed(SEED)
    # + random.shuffle without mutating the global RNG state.
    random.Random(SEED).shuffle(examples)
    val = examples[:NVAL]
    train = examples[NVAL:]
    torch.save({"train": train, "val": val}, OUT)

    ntok = sum(len(e["ids"]) for e in examples)
    ans_tok = sum(sum(lab != IGNORE for lab in e["labels"]) for e in examples)
    print(f"examples={len(examples)} train={len(train)} val={len(val)} "
          f"total_tokens={ntok/1e6:.2f}M answer_tokens={ans_tok/1e6:.2f}M "
          f"avg_len={ntok/len(examples):.0f}")
    print("saved", OUT)


if __name__ == "__main__":
    main()