# 70m-adamwgrok / sft_rag_data.py
# Asilarknes's picture
# Upload sft_rag_data.py with huggingface_hub
# 1bb7b62 verified
"""Build retrieval-augmented SFT data.
Two skills mixed:
(A) GROUNDED answering from context (SberQuAD): teaches "use the provided context".
prompt = "Контекст: {ctx}\nВопрос: {q}\nОтвет:" ; answer = " {span}"<eos>
(B) GENERAL instruction following (saiga_scored, no context), to keep fluency.
prompt = "Вопрос: {q}\nОтвет:" ; answer = " {bot}"<eos>
Loss only on answer (prompt+pad = -100). Saves /root/sft_rag.pt.
"""
import torch, random
from datasets import load_dataset
from tokenizers import Tokenizer
# Paths and thresholds used by the build below.
TOK = "/root/ru_tok.json"  # serialized `tokenizers` tokenizer file
OUT = "/root/sft_rag.pt"  # destination of the torch.save() payload
MAXLEN = 512  # maximum tokens per example, <bos> and <eos> included
MINSCORE = 8  # minimum saiga `opus_score` for a dialogue to be kept

tok = Tokenizer.from_file(TOK)
BOS = tok.token_to_id("<bos>")
EOS = tok.token_to_id("<eos>")
# token_to_id() returns None for a token that is not in the vocabulary.
# Fail here instead of silently writing None into every example's ids/labels.
if BOS is None or EOS is None:
    raise ValueError(
        f"tokenizer {TOK} must define both <bos> and <eos> (got BOS={BOS}, EOS={EOS})"
    )
def enc(s):
    """Return the token ids of the string *s*.

    Special tokens are never added automatically: callers concatenate several
    separately encoded fragments and insert BOS/EOS themselves, so any token
    injected by a tokenizer post-processor would land in the middle of an
    example.
    """
    return tok.encode(s, add_special_tokens=False).ids
# Accumulates {"ids": [...], "labels": [...]} dicts for both data sources.
examples = []

# (A) SberQuAD — grounded
# Layout: <bos> "Контекст: " ctx "\nВопрос: " q "\nОтвет:" " " answer <eos>
# Every prompt position is labelled -100, so loss is computed on the answer only.
sq = load_dataset("sberquad", split="train")
pre = enc("Контекст: ")
n_ctx_trunc = 0  # number of examples whose context had to be shortened
for r in sq:
    ctx, q = r["context"].strip(), r["question"].strip()
    answer_texts = r["answers"]["text"]
    ans = answer_texts[0].strip() if answer_texts else ""
    if not ctx or not q or not ans:
        continue
    midq = enc("\nВопрос: " + q + "\nОтвет:")
    a_ids = enc(" " + ans) + [EOS]
    # Tokens left for the context once <bos>, prefix, question and answer fit.
    budget = MAXLEN - 1 - len(pre) - len(midq) - len(a_ids)
    if budget < 32:
        continue
    # Encode the context directly to get character offsets alongside the ids.
    ctx_enc = tok.encode(ctx, add_special_tokens=False)
    c_ids = ctx_enc.ids
    if len(c_ids) > budget:
        # Keeping only the head can cut the answer span out of the context,
        # which would train the model to answer without grounding. Slide the
        # window so the answer stays inside whenever it can be located.
        start = 0
        pos = ctx.find(ans)
        if pos >= 0:
            end = pos + len(ans)
            # Indices of the tokens whose character span overlaps the answer.
            hit = [
                i for i, (s, e) in enumerate(ctx_enc.offsets)
                if e > pos and s < end
            ]
            if hit and hit[-1] >= budget:
                span = hit[-1] - hit[0] + 1
                if span > budget:
                    # The answer alone exceeds the context budget.
                    continue
                # Centre the answer, then clamp the window to the context.
                start = hit[0] - (budget - span) // 2
                start = max(0, min(start, len(c_ids) - budget))
        c_ids = c_ids[start:start + budget]
        n_ctx_trunc += 1
    ids = [BOS] + pre + c_ids + midq + a_ids
    labels = [-100] * (len(ids) - len(a_ids)) + a_ids
    examples.append({"ids": ids, "labels": labels})
n_sq = len(examples)
# (B) saiga_scored — general, no context
# Layout: <bos> "Вопрос: " user "\nОтвет:" " " bot <eos>; loss on the bot reply only.
sa = load_dataset("IlyaGusev/saiga_scored", split="train")
for row in sa:
    # Keep only Russian rows that are not flagged by the regex filter ...
    if row["language"] != "Russian" or row["is_bad_by_regex"]:
        continue
    # ... score at least MINSCORE (a missing score counts as 0), at most one turn ...
    if not ((row["opus_score"] or 0) >= MINSCORE and row["turns"] <= 1):
        continue
    # ... and open with a user message followed by a bot message.
    msgs = row["messages"]
    if len(msgs) < 2:
        continue
    if msgs[0]["role"] != "user" or msgs[1]["role"] != "bot":
        continue

    user_text = msgs[0]["content"].strip()
    bot_text = msgs[1]["content"].strip()
    if not (user_text and bot_text):
        continue

    prompt_ids = enc(f"Вопрос: {user_text}\nОтвет:")
    answer_ids = enc(" " + bot_text) + [EOS]
    ids = [BOS, *prompt_ids, *answer_ids]
    # Over-long dialogues are dropped rather than truncated.
    if len(ids) > MAXLEN:
        continue
    # <bos> and the prompt are masked with -100; only the answer is supervised.
    labels = [-100] * (len(ids) - len(answer_ids)) + answer_ids
    examples.append({"ids": ids, "labels": labels})
n_sa = len(examples) - n_sq
# Shuffle deterministically, split off a validation set, save, and report.
if not examples:
    # Without this guard an empty dataset would be saved and the summary below
    # would then fail with ZeroDivisionError.
    raise RuntimeError("no examples were built; refusing to save an empty dataset")
random.seed(0)
random.shuffle(examples)
# 1000 validation examples, capped at half the data so that a small dataset
# never leaves the train split empty.
nval = min(1000, len(examples) // 2)
val, train = examples[:nval], examples[nval:]
torch.save({"train": train, "val": val}, OUT)
ntok = sum(len(e["ids"]) for e in examples)
print(f"sberquad={n_sq} (ctx_trunc={n_ctx_trunc}) saiga={n_sa} total={len(examples)} "
      f"train={len(train)} val={len(val)} tokens={ntok/1e6:.2f}M avg_len={ntok/len(examples):.0f}")
print("saved", OUT)