r"""Build retrieval-augmented SFT data.

Two skills are mixed:

(A) GROUNDED answering from context (SberQuAD): teaches "use the provided
    context".
        prompt = "Контекст: {ctx}\nВопрос: {q}\nОтвет:" ; answer = " {span}"
(B) GENERAL instruction following (saiga_scored, no context), to keep fluency.
        prompt = "Вопрос: {q}\nОтвет:" ; answer = " {bot}"

Loss only on the answer (prompt+pad = -100).
Over-long SberQuAD contexts are cut to a token window that still contains the
gold answer span; examples whose answer cannot be kept are dropped.
Saves /root/sft_rag.pt.
"""
import random

import torch

TOK = "/root/ru_tok.json"
OUT = "/root/sft_rag.pt"
MAXLEN = 512
MINSCORE = 8
NVAL = 1000          # held-out examples taken from the front after the shuffle
MIN_CTX_BUDGET = 32  # skip SberQuAD rows that leave fewer context tokens
IGNORE = -100        # label value for positions excluded from the loss


def locate_answer(ctx, ans, hint=None):
    """Return the (start, end) character span of *ans* inside *ctx*, or None.

    *hint* is the expected start offset. It is used only when *ans* really
    starts there; otherwise the first occurrence of *ans* is taken.
    """
    if hint is not None and hint >= 0 and ctx.startswith(ans, hint):
        return hint, hint + len(ans)
    pos = ctx.find(ans)
    return (pos, pos + len(ans)) if pos >= 0 else None


def answer_window(offsets, span, budget):
    """Choose a slice of at most *budget* context tokens that covers *span*.

    offsets: per-token (start, end) character offsets of the context encoding.
    span:    (start, end) character span of the answer in the context.

    Returns (lo, hi) slice bounds into the token list. The earliest such
    window is preferred, so the result equals plain head-truncation whenever
    the answer already lies in the first *budget* tokens. Returns None when
    no token overlaps *span* or the span alone needs more than *budget* tokens.
    """
    n = len(offsets)
    if n <= budget:
        return 0, n
    s, e = span
    hits = [i for i, (a, b) in enumerate(offsets) if a < e and b > s]
    if not hits or hits[-1] - hits[0] + 1 > budget:
        return None
    # Earliest start that still includes the last answer token; it also
    # includes the first one because the span fits in the budget.
    lo = max(0, hits[-1] + 1 - budget)
    return lo, lo + budget


def make_example(prompt_ids, answer_ids):
    """Pack one example; the prompt is masked with IGNORE so only the answer is trained."""
    return {
        "ids": prompt_ids + answer_ids,
        "labels": [IGNORE] * len(prompt_ids) + answer_ids,
    }


def _build_sberquad(tok, bos, eos):
    """(A) Grounded examples; returns (examples, n_ctx_trunc, n_ans_lost)."""
    from datasets import load_dataset

    def enc(s):
        return tok.encode(s).ids

    pre = enc("Контекст: ")
    examples = []
    n_trunc = n_lost = 0
    for r in load_dataset("sberquad", split="train"):
        raw_ctx = r["context"]
        ctx, q = raw_ctx.strip(), r["question"].strip()
        texts = r["answers"]["text"]
        raw_ans = texts[0] if texts else ""
        ans = raw_ans.strip()
        if not ctx or not q or not ans:
            continue
        midq = enc("\nВопрос: " + q + "\nОтвет:")
        a_ids = enc(" " + ans) + [eos]
        # 1 token for BOS; the remainder of MAXLEN is what the context may use.
        budget = MAXLEN - 1 - len(pre) - len(midq) - len(a_ids)
        if budget < MIN_CTX_BUDGET:
            continue
        c_enc = tok.encode(ctx)
        c_ids = c_enc.ids
        if len(c_ids) > budget:
            # Head-truncation alone can cut the gold span out of the context,
            # which would train the model to answer without grounding. Keep a
            # window that still contains the answer, or drop the example.
            starts = r["answers"].get("answer_start") or []
            hint = None
            if starts:
                # Presumably a SQuAD-style char offset into the unstripped
                # context/answer; re-base it for the stripped strings.
                # locate_answer falls back to a plain search if it is off.
                hint = (starts[0]
                        - (len(raw_ctx) - len(raw_ctx.lstrip()))
                        + (len(raw_ans) - len(raw_ans.lstrip())))
            span = locate_answer(ctx, ans, hint)
            win = answer_window(c_enc.offsets, span, budget) if span else None
            if win is None:
                n_lost += 1
                continue
            c_ids = c_ids[win[0]:win[1]]
            n_trunc += 1
        examples.append(make_example([bos] + pre + c_ids + midq, a_ids))
    return examples, n_trunc, n_lost


def _build_saiga(tok, bos, eos):
    """(B) General single-turn Russian instruction examples (no context)."""
    from datasets import load_dataset

    def enc(s):
        return tok.encode(s).ids

    examples = []
    for r in load_dataset("IlyaGusev/saiga_scored", split="train"):
        msgs = r["messages"]
        if not (r["language"] == "Russian" and not r["is_bad_by_regex"]
                and (r["opus_score"] or 0) >= MINSCORE and r["turns"] <= 1
                and len(msgs) >= 2 and msgs[0]["role"] == "user"
                and msgs[1]["role"] == "bot"):
            continue
        u = msgs[0]["content"].strip()
        b = msgs[1]["content"].strip()
        if not u or not b:
            continue
        p_ids = [bos] + enc(f"Вопрос: {u}\nОтвет:")
        a_ids = enc(" " + b) + [eos]
        if len(p_ids) + len(a_ids) > MAXLEN:
            continue
        examples.append(make_example(p_ids, a_ids))
    return examples


def main():
    """Build both parts, shuffle with a fixed seed, split off NVAL, and save to OUT."""
    from tokenizers import Tokenizer

    tok = Tokenizer.from_file(TOK)
    # NOTE(review): the special-token names are empty strings as this file
    # stands (they may have been stripped, e.g. "<s>"/"</s>") -- confirm.
    bos = tok.token_to_id("")
    eos = tok.token_to_id("")
    if bos is None or eos is None:
        # token_to_id returns None for unknown tokens; saving None ids would
        # only surface as a crash (or worse) at training time.
        raise SystemExit(f"BOS/EOS not found in {TOK}: bos={bos} eos={eos}")

    sq, n_ctx_trunc, n_ans_lost = _build_sberquad(tok, bos, eos)
    sa = _build_saiga(tok, bos, eos)
    examples = sq + sa
    if len(examples) <= NVAL:
        raise SystemExit(f"only {len(examples)} examples; need more than NVAL={NVAL}")

    random.seed(0)
    random.shuffle(examples)
    val, train = examples[:NVAL], examples[NVAL:]
    torch.save({"train": train, "val": val}, OUT)

    ntok = sum(len(e["ids"]) for e in examples)
    print(f"sberquad={len(sq)} (ctx_trunc={n_ctx_trunc} ans_lost={n_ans_lost}) "
          f"saiga={len(sa)} total={len(examples)} "
          f"train={len(train)} val={len(val)} tokens={ntok/1e6:.2f}M "
          f"avg_len={ntok/len(examples):.0f}")
    print("saved", OUT)


if __name__ == "__main__":
    main()