testoneshot / scripts /glm_prep.py
Asilarknes's picture
upload oneshot glm artifacts
a216fa7 verified
"""Prepare GLM-5.1-Reasoning (main subset prefix) for the analytic model.
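
0. (`download` command only) stream the Hugging Face dataset and write
   normalized {"input", "output"} records, one JSON object per line, to
   main_prefix.jsonl;
1. build a text-corpus sample and train a 16k-vocabulary BPE SentencePiece
   tokenizer (a small vocabulary keeps the V x V PMI co-occurrence matrix and
   its SVD feasible, unlike GPT-2's 50k);
2. tokenize every kept record as its input and its final answer (the output
   with any <think>...</think> reasoning removed) joined by a blank line,
   appending <eos> after each document, into glm_train.bin / glm_valid.bin
   (uint16 token ids). By default only plain-English answers are kept.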
"""
import os, sys, json, time, numpy as np
import sentencepiece as spm
# Root directory for every artifact; override with the ONESHOT_DATA env var
# (or later via set_paths / --data).
DATA = os.getenv("ONESHOT_DATA", "/workspace/ts")
# Paths derived from DATA: raw JSONL records, tokenizer corpus, SPM model prefix.
SRC, CORPUS, SPM_PREFIX = (
    os.path.join(DATA, leaf)
    for leaf in ("main_prefix.jsonl", "glm_corpus.txt", "glm16k")
)
# 16k BPE pieces keep the V x V co-occurrence matrix small; token ids are
# written as uint16 by tokenize().
VOCAB = 16_384
HF_DATASET = "Jackrong/GLM-5.1-Reasoning-1M-Cleaned"
def log(*a):
    """Print the arguments to stdout behind a ``[HH:MM:SS]`` stamp, flushed at once."""
    stamp = time.strftime("%H:%M:%S")
    print("[" + stamp + "]", *a, flush=True)
def first_present(record, names, default=""):
    """Return the value of the first key in *names* that holds a usable value.

    A value is skipped when it is ``None`` or a string that is empty or
    whitespace-only. Without this, an Alpaca-style row such as
    ``{"input": "", "instruction": "..."}`` would yield the empty ``input``
    instead of falling through to ``instruction``.

    Args:
        record: mapping for one dataset row.
        names: candidate keys, in priority order.
        default: returned when no candidate key holds a usable value.
    """
    for name in names:
        if name not in record:
            continue
        value = record[name]
        if value is None or (isinstance(value, str) and not value.strip()):
            continue
        return value
    return default
def normalize_record(record):
    """Map one raw dataset row onto the ``{"input": ..., "output": ...}`` schema.

    The prompt comes from the first present of input / prompt / instruction /
    question / query, and the response from output / response / answer /
    completion. List or dict values are serialized to JSON text. Both sides
    are stripped. Returns ``None`` when either side ends up empty, so the
    caller can drop the row.
    """
    def _as_text(value):
        # Structured values are kept as JSON text rather than Python reprs.
        if isinstance(value, (list, dict)):
            value = json.dumps(value, ensure_ascii=False)
        return str(value).strip()

    prompt = _as_text(first_present(record, ["input", "prompt", "instruction", "question", "query"]))
    response = _as_text(first_present(record, ["output", "response", "answer", "completion"]))
    if prompt and response:
        return {"input": prompt, "output": response}
    return None
def download_jsonl(dataset=HF_DATASET, split="train", subset=None, max_records=0):
    """Stream a Hugging Face dataset and write normalized records to SRC as JSONL.

    Rows rejected by ``normalize_record`` are skipped. Output is written to a
    temporary file next to SRC, which is renamed over SRC only after the
    stream finishes. An interrupted download therefore never leaves a
    truncated SRC for the corpus/tokenize steps to consume.

    Args:
        dataset: Hugging Face dataset id.
        split: split to stream.
        subset: optional dataset config name; omitted from the call when falsy.
        max_records: stop after this many kept records; 0 means no limit.
    """
    from datasets import load_dataset  # lazy: only this step needs `datasets`

    kwargs = {"split": split, "streaming": True}
    ds = load_dataset(dataset, subset, **kwargs) if subset else load_dataset(dataset, **kwargs)
    os.makedirs(DATA, exist_ok=True)
    # With --src the output may live outside DATA, so create its directory too.
    os.makedirs(os.path.dirname(os.path.abspath(SRC)), exist_ok=True)
    tmp_path = SRC + ".tmp"
    n = 0
    with open(tmp_path, "w", encoding="utf-8") as out:
        for row in ds:
            rec = normalize_record(row)
            if rec is None:
                continue
            out.write(json.dumps(rec, ensure_ascii=False) + "\n")
            n += 1
            if n % 10000 == 0:
                log(f"downloaded {n:,} records -> {SRC}")
            if max_records and n >= max_records:
                break
    # Publish the finished file in one rename.
    os.replace(tmp_path, SRC)
    log(f"download done: {n:,} records -> {SRC}")
def answer_of(r):
    """Return a record's final answer with its <think> reasoning removed.

    - If the output contains ``</think>``, only the text after the last
      ``</think>`` is kept.
    - If it opens ``<think>`` but never closes it (e.g. a truncated
      generation), the rest is unfinished reasoning, so only the text before
      the tag is kept (usually nothing).
    - Otherwise the whole output is the answer.

    A missing or ``None`` output yields ``""``. The result is stripped.
    """
    o = r.get("output")
    o = "" if o is None else str(o)
    if "</think>" in o:
        o = o.rpartition("</think>")[2]
    elif "<think>" in o:
        o = o.partition("<think>")[0]
    return o.strip()
def is_english_answer(a):
    """Return True for plain natural-language (English) answers.

    Code-, math- and LaTeX-dominated answers are rejected so the model learns
    to answer in plain English. An answer is kept only if all of these hold:
    its length is in [40, 4000] characters; it contains no ``` code fence; at
    least 93% of its characters are ASCII letters or whitespace; and at most
    2% are LaTeX/code punctuation.
    """
    n = len(a)
    if not 40 <= n <= 4000:
        return False
    if "```" in a:  # code fence
        return False
    # str.isalpha() is True for any Unicode letter (CJK, Cyrillic, ...), which
    # would let non-English answers through; only ASCII letters count here.
    letters = sum((c.isascii() and c.isalpha()) or c.isspace() for c in a)
    if letters / n < 0.93:
        return False
    symbols = sum(a.count(c) for c in "{}\\$=#|<>_~^")
    return symbols / n <= 0.02  # LaTeX / code punctuation density
def set_paths(data):
    """Re-root every artifact path under *data*.

    Rebinds the module globals DATA, SRC, CORPUS and SPM_PREFIX, which the
    pipeline steps read when they are called.
    """
    global DATA, SRC, CORPUS, SPM_PREFIX
    DATA = data
    SRC, CORPUS, SPM_PREFIX = (
        os.path.join(data, leaf)
        for leaf in ("main_prefix.jsonl", "glm_corpus.txt", "glm16k")
    )
def build_corpus(max_records=120_000, max_bytes=400_000_000):
    """Write a plain-text SentencePiece training sample from SRC to CORPUS.

    Each usable JSONL record contributes its ``input`` and ``output`` fields,
    each followed by a newline. The following are skipped:
    - blank lines;
    - lines that are not valid JSON;
    - records without string ``input``/``output`` fields.

    Writing stops after *max_records* records, or once at least *max_bytes*
    UTF-8 bytes have been written, whichever comes first.
    """
    n = 0
    nbytes = 0
    with open(SRC, "r", encoding="utf-8", errors="ignore") as f, \
            open(CORPUS, "w", encoding="utf-8") as out:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                r = json.loads(line)
            except ValueError:  # malformed JSON line
                continue
            if not isinstance(r, dict):
                continue
            inp, outp = r.get("input"), r.get("output")
            # Skip incomplete records instead of crashing halfway through the corpus.
            if not isinstance(inp, str) or not isinstance(outp, str):
                continue
            txt = inp + "\n" + outp + "\n"
            out.write(txt)
            # Count encoded bytes, not characters: max_bytes and the logged MB
            # figure are byte quantities, and non-ASCII text is multi-byte.
            nbytes += len(txt.encode("utf-8"))
            n += 1
            if n >= max_records or nbytes >= max_bytes:
                break
    log(f"corpus: {n:,} records, {nbytes/1e6:.1f} MB -> {CORPUS}")
def train_spm():
    """Train a BPE SentencePiece model on CORPUS, saved as ``SPM_PREFIX + ".model"``.

    Settings: vocab_size=VOCAB, ids unk=0 / bos=1 / eos=2, no pad piece, and
    byte_fallback enabled.

    Raises:
        ValueError: if VOCAB exceeds 65536. ``tokenize`` stores ids as uint16,
            so a larger vocabulary would otherwise fail (or corrupt data) only
            after training had already finished.
    """
    # Fail fast, before the expensive training run.
    if VOCAB > 1 << 16:
        raise ValueError(f"vocab size {VOCAB} does not fit uint16 token ids (max 65536)")
    spm.SentencePieceTrainer.train(
        input=CORPUS, model_prefix=SPM_PREFIX, vocab_size=VOCAB,
        model_type="bpe", character_coverage=0.9995,
        # Train on a shuffled sample of at most 3M corpus lines.
        input_sentence_size=3_000_000, shuffle_input_sentence=True,
        max_sentence_length=100000, num_threads=32,
        unk_id=0, bos_id=1, eos_id=2, pad_id=-1,
        byte_fallback=True,
    )
    log(f"trained SP -> {SPM_PREFIX}.model (vocab={VOCAB})")
def tokenize(val_frac=0.04, english_only=True):
    """Encode SRC records into glm_train.bin / glm_valid.bin under DATA.

    Each kept record becomes one document: its ``input`` and its answer (from
    ``answer_of``) joined by a blank line. The document is encoded with the
    SentencePiece model at ``SPM_PREFIX + ".model"`` and followed by one
    ``<eos>`` id. All ids are written back to back as raw uint16. The last
    ``val_frac`` fraction of kept documents (file order, no shuffling) forms
    the validation split.

    Args:
        val_frac: fraction of documents written to glm_valid.bin, in [0, 1).
        english_only: keep only answers accepted by ``is_english_answer``.

    Raises:
        ValueError: if val_frac is out of range, or the tokenizer vocabulary
            is too large for uint16 ids.
    """
    if not 0.0 <= val_frac < 1.0:
        raise ValueError(f"val_frac must be in [0, 1), got {val_frac}")
    sp = spm.SentencePieceProcessor(model_file=SPM_PREFIX + ".model")
    # Ids are stored as uint16; a larger vocabulary would not fit.
    if sp.get_piece_size() > 1 << 16:
        raise ValueError(f"tokenizer has {sp.get_piece_size()} pieces; uint16 ids allow at most 65536")
    eos = sp.eos_id()

    log("scanning + filtering records...")
    docs = []
    seen = 0
    t0 = time.time()
    with open(SRC, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                r = json.loads(line)
            except ValueError:  # malformed JSON line
                continue
            seen += 1
            if not isinstance(r, dict):
                continue
            inp = r.get("input")
            a = answer_of(r)
            # A record without a prompt or without an answer is not a usable document.
            if not isinstance(inp, str) or not a:
                continue
            if english_only and not is_english_answer(a):
                continue
            docs.append(inp.strip() + "\n\n" + a)
    log(f"{seen:,} records -> {len(docs):,} kept "
        f"({100*len(docs)/max(seen,1):.1f}%) english_only={english_only}")

    n_val = int(len(docs) * val_frac)
    n_train = len(docs) - n_val
    splits = {"glm_train.bin": docs[:n_train],
              "glm_valid.bin": docs[n_train:]}
    counts = {}
    for fname, dlist in splits.items():
        nt = 0
        with open(os.path.join(DATA, fname), "wb") as fo:
            # Encode in batches of 1000 documents to amortize the encode call.
            for start in range(0, len(dlist), 1000):
                for ids in sp.encode(dlist[start:start + 1000]):
                    arr = np.array(ids + [eos], dtype=np.uint16)
                    arr.tofile(fo)
                    nt += arr.size
        counts[fname] = nt
    log(f"DONE train={counts['glm_train.bin']:,} tokens, "
        f"valid={counts['glm_valid.bin']:,} tokens ({time.time()-t0:.0f}s)")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("cmd", nargs="?", default="all",
                        choices=["download", "corpus", "spm", "tok", "all"])
    parser.add_argument("--data", default=DATA)
    parser.add_argument("--src", default=None)
    parser.add_argument("--vocab", type=int, default=VOCAB)
    parser.add_argument("--dataset", default=HF_DATASET)
    parser.add_argument("--subset", default=None)
    parser.add_argument("--split", default="train")
    parser.add_argument("--max_records", type=int, default=0)
    parser.add_argument("--english_only", type=int, default=1)
    parser.add_argument("--val_frac", type=float, default=0.04)
    opts = parser.parse_args()

    set_paths(opts.data)
    # --src overrides only the raw-records path; corpus/SPM paths stay under --data.
    if opts.src:
        SRC = opts.src
    VOCAB = opts.vocab

    # (commands that trigger the step, step) in execution order.
    # "all" runs every step except the network download.
    steps = (
        (("download",),
         lambda: download_jsonl(opts.dataset, opts.split, opts.subset, opts.max_records)),
        (("corpus", "all"),
         lambda: build_corpus(max_records=opts.max_records or 120_000)),
        (("spm", "all"), train_spm),
        (("tok", "all"),
         lambda: tokenize(opts.val_frac, bool(opts.english_only))),
    )
    for triggers, run_step in steps:
        if opts.cmd in triggers:
            run_step()