| """Build retrieval-augmented SFT data. |
| |
| Two skills mixed: |
| (A) GROUNDED answering from context (SberQuAD): teaches "use the provided context". |
| prompt = "Контекст: {ctx}\nВопрос: {q}\nОтвет:" ; answer = " {span}"<eos> |
| (B) GENERAL instruction following (saiga_scored, no context), to keep fluency. |
| prompt = "Вопрос: {q}\nОтвет:" ; answer = " {bot}"<eos> |
| Loss only on answer (prompt+pad = -100). Saves /root/sft_rag.pt. |
| """ |
| import torch, random |
| from datasets import load_dataset |
| from tokenizers import Tokenizer |
|
|
# Path of the tokenizer definition handed to Tokenizer.from_file.
TOK: str = "/root/ru_tok.json"
# Destination of the torch.save'd {"train": ..., "val": ...} dict.
OUT: str = "/root/sft_rag.pt"
# Upper bound on tokens per example, counting the hand-added BOS and EOS.
MAXLEN: int = 512
# Lowest opus_score a saiga_scored row may have and still be kept.
MINSCORE: int = 8
|
|
# Load the tokenizer and resolve the ids of the two special tokens that are
# spliced into every example by hand.
tok = Tokenizer.from_file(TOK)
BOS = tok.token_to_id("<bos>")
EOS = tok.token_to_id("<eos>")
# token_to_id returns None for a token that is not in the vocabulary; stop
# here instead of silently writing None into the saved id/label lists.
if BOS is None or EOS is None:
    raise ValueError(
        f"tokenizer {TOK} is missing a special token: <bos>={BOS} <eos>={EOS}"
    )
|
|
|
|
def enc(s):
    """Return the token ids of ``s`` with no special tokens added.

    The examples below are assembled from separately encoded pieces with
    BOS/EOS inserted explicitly, so any special tokens a tokenizer
    post-processor would add per call must be suppressed here. For a
    tokenizer without such a post-processor the result is unchanged.
    """
    return tok.encode(s, add_special_tokens=False).ids
|
|
|
|
# Accumulates {"ids": [...], "labels": [...]} dicts from both data sources.
examples = []


def _answer_window(offsets, span, budget):
    """Return (lo, hi) token bounds of a ``budget``-wide window covering ``span``.

    ``offsets`` are per-token (start, end) character offsets into the context,
    ``span`` is the (start, end) character range of the answer. When the
    answer already lies inside the first ``budget`` tokens (or no token
    overlaps it) the plain head window ``(0, budget)`` is returned. Only
    meant to be called when ``len(offsets) > budget``.
    """
    span_start, span_end = span
    hits = [i for i, (s, e) in enumerate(offsets) if s < span_end and e > span_start]
    if not hits or hits[-1] < budget:
        return 0, budget
    first, last = hits[0], hits[-1]
    # Spread the spare budget evenly on both sides of the answer, then clamp
    # so the window never runs past the end of the context.
    spare = max(budget - (last - first + 1), 0)
    lo = max(0, min(first - spare // 2, len(offsets) - budget))
    return lo, lo + budget


# (A) Grounded answering: SberQuAD rows formatted as
#     "Контекст: {ctx}\nВопрос: {q}\nОтвет:" -> " {answer}" + EOS
sq = load_dataset("sberquad", split="train")
pre = enc("Контекст: ")
n_ctx_trunc = 0
for r in sq:
    raw_ctx = r["context"]
    ctx, q = raw_ctx.strip(), r["question"].strip()
    answers = r["answers"]
    ans = answers["text"][0].strip() if answers["text"] else ""
    if not ctx or not q or not ans:
        continue
    midq = enc("\nВопрос: " + q + "\nОтвет:")
    a_ids = enc(" " + ans) + [EOS]
    # Tokens left for the context once BOS, the fixed prompt pieces and the
    # answer are accounted for; skip rows that leave almost no context.
    budget = MAXLEN - 1 - len(pre) - len(midq) - len(a_ids)
    if budget < 32:
        continue
    # Encoded directly (not via enc) because the offsets are needed below.
    ctx_enc = tok.encode(ctx, add_special_tokens=False)
    c_ids = ctx_enc.ids
    if len(c_ids) > budget:
        # Cutting the tail blindly can remove the answer span and train the
        # model to produce an answer that is not in its context. Locate the
        # answer (answer_start is relative to the unstripped context) and keep
        # a window around it; fall back to head truncation if it is not found.
        starts = answers.get("answer_start") or []
        lead = len(raw_ctx) - len(raw_ctx.lstrip())
        hint = max(starts[0] - lead, 0) if starts else 0
        pos = ctx.find(ans, hint)
        if pos < 0:
            pos = ctx.find(ans)
        if pos < 0:
            lo, hi = 0, budget
        else:
            lo, hi = _answer_window(ctx_enc.offsets, (pos, pos + len(ans)), budget)
        c_ids = c_ids[lo:hi]
        n_ctx_trunc += 1
    ids = [BOS] + pre + c_ids + midq + a_ids
    # Loss only on the answer tokens: everything before them is masked.
    labels = [-100] * (len(ids) - len(a_ids)) + a_ids
    examples.append({"ids": ids, "labels": labels})
n_sq = len(examples)
|
|
| |
# (B) General instruction following: rows from saiga_scored formatted as
#     "Вопрос: {user}\nОтвет:" -> " {bot}" + EOS, with no context.
sa = load_dataset("IlyaGusev/saiga_scored", split="train")
for r in sa:
    # Row filters, evaluated in the same order as one short-circuiting "and".
    if r["language"] != "Russian" or r["is_bad_by_regex"]:
        continue
    if not ((r["opus_score"] or 0) >= MINSCORE):
        continue
    if not (r["turns"] <= 1):
        continue
    if not (len(r["messages"]) >= 2):
        continue
    if r["messages"][0]["role"] != "user" or r["messages"][1]["role"] != "bot":
        continue

    u = r["messages"][0]["content"].strip()
    b = r["messages"][1]["content"].strip()
    if not (u and b):
        continue

    p_ids = enc(f"Вопрос: {u}\nОтвет:")
    a_ids = enc(" " + b) + [EOS]
    ids = [BOS] + p_ids + a_ids
    # Over-long dialogues are dropped rather than truncated.
    if len(ids) > MAXLEN:
        continue
    # Mask BOS and the prompt; only the answer tokens carry loss.
    labels = [-100] * (len(ids) - len(a_ids)) + a_ids
    examples.append({"ids": ids, "labels": labels})
n_sa = len(examples) - n_sq
|
|
# Deterministic shuffle so the train/val split is reproducible across runs.
random.seed(0)
random.shuffle(examples)

# The first nval shuffled examples become validation, the rest training.
nval = 1000
val = examples[:nval]
train = examples[nval:]
torch.save({"train": train, "val": val}, OUT)

# Summary statistics over the full (train + val) set.
ntok = sum(map(len, (e["ids"] for e in examples)))
summary = (
    f"sberquad={n_sq} (ctx_trunc={n_ctx_trunc}) saiga={n_sa} total={len(examples)} "
    f"train={len(train)} val={len(val)} tokens={ntok/1e6:.2f}M avg_len={ntok/len(examples):.0f}"
)
print(summary)
print("saved", OUT)
|
|