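"""Build a high-quality Russian SFT dataset from IlyaGusev/saiga_scored.

Filters: Russian, single-turn, not bad-by-regex, opus_score >= MINSCORE.
Format (consistent with inference):
    prompt = "Вопрос: {user}\nОтвет:"
    answer = " {bot}" + <eos>
Each example is <bos> + prompt + answer; examples longer than MAXLEN tokens
are dropped. Loss is computed only on the answer: the <bos> and prompt
positions are labelled -100 (padding is not added by this script; pad
positions are expected to be masked with -100 as well).
The examples are shuffled with seed 0 and the first 800 form the val split.
Saves /root/sft.pt = {"train": [...], "val": [...]} of dicts {ids, labels}.
"""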
import random

import torch
from datasets import load_dataset
from tokenizers import Tokenizer


# Minimum opus_score a conversation must have to be kept.
MINSCORE: int = 8
# Examples whose full token sequence (<bos> + prompt + answer) is longer
# than this are dropped rather than truncated.
MAXLEN: int = 512
# Tokenizer file, loaded with tokenizers.Tokenizer.from_file.
TOK: str = "/root/ru_tok.json"
# Destination of the saved {"train": ..., "val": ...} dict.
OUT: str = "/root/sft.pt"


# Load the tokenizer and look up the special tokens that frame every example.
tok = Tokenizer.from_file(TOK)
BOS = tok.token_to_id("<bos>")
EOS = tok.token_to_id("<eos>")
# token_to_id returns None for a token missing from the vocabulary; a None id
# would be written into every example and only break much later in training,
# so fail fast here instead.
if BOS is None or EOS is None:
    raise SystemExit(f"{TOK}: tokenizer has no <bos> and/or <eos> token")

ds = load_dataset("IlyaGusev/saiga_scored", split="train")


def keep(r, min_score=None):
    """Return True if dataset row ``r`` should become a training example.

    A row is kept when it is Russian, not flagged by the regex filter, has an
    ``opus_score`` of at least the threshold (a missing/None score counts as
    0), has at most one turn, and its first two messages are a user message
    followed by a bot reply.

    Args:
        r: one dataset row, a mapping with the keys ``language``,
            ``is_bad_by_regex``, ``opus_score``, ``turns`` and ``messages``
            (a list of dicts with a ``role`` key).
        min_score: score threshold; when None the module-level ``MINSCORE``
            is used, which preserves the original behaviour.
    """
    # Conditions are evaluated left to right and short-circuit, so later keys
    # (e.g. "messages") are only read for rows that pass the earlier checks.
    return (
        r["language"] == "Russian"
        and not r["is_bad_by_regex"]
        and (r["opus_score"] or 0)
        >= (MINSCORE if min_score is None else min_score)
        and r["turns"] <= 1
        and len(r["messages"]) >= 2
        and r["messages"][0]["role"] == "user"
        and r["messages"][1]["role"] == "bot"
    )


# Encode every kept row as <bos> + prompt + answer + <eos>, with labels that
# mask everything except the answer.
# The special tokens are inserted by hand, so the tokenizer is told not to add
# its own (add_special_tokens=False); otherwise a tokenizer with a
# post-processor would inject duplicate special tokens into both segments.
examples = []
for r in ds:
    if not keep(r):
        continue
    user = r["messages"][0]["content"].strip()
    bot = r["messages"][1]["content"].strip()
    if not user or not bot:
        continue  # skip rows with an empty question or answer
    prompt = f"Вопрос: {user}\nОтвет:"
    p_ids = tok.encode(prompt, add_special_tokens=False).ids
    # The leading space reproduces the " {bot}" answer format.
    a_ids = tok.encode(" " + bot, add_special_tokens=False).ids + [EOS]
    ids = [BOS] + p_ids + a_ids
    if len(ids) > MAXLEN:
        continue  # over-long examples are dropped, not truncated
    # -100 on <bos> and the prompt so the loss covers only answer tokens.
    labels = [-100] * (1 + len(p_ids)) + a_ids
    assert len(ids) == len(labels)
    examples.append({"ids": ids, "labels": labels})


# Deterministic train/val split: shuffle with a fixed seed, then hold out the
# first `nval` examples for validation.
nval = 800
# Fail loudly instead of saving an empty train split (or dividing by zero in
# the stats below) when filtering leaves too few examples.
if len(examples) <= nval:
    raise SystemExit(
        f"only {len(examples)} examples survived filtering; "
        f"need more than {nval} to leave any training data"
    )
random.seed(0)
random.shuffle(examples)
val = examples[:nval]
train = examples[nval:]
torch.save({"train": train, "val": val}, OUT)

# Summary statistics; answer tokens are the positions that carry a loss
# (label != -100).
ntok = sum(len(e["ids"]) for e in examples)
ans_tok = sum(lab != -100 for e in examples for lab in e["labels"])
print(f"examples={len(examples)} train={len(train)} val={len(val)} "
      f"total_tokens={ntok/1e6:.2f}M answer_tokens={ans_tok/1e6:.2f}M "
      f"avg_len={ntok/len(examples):.0f}")
print("saved", OUT)