File size: 6,857 Bytes
a216fa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
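"""Prepare GLM-5.1-Reasoning (main subset prefix) for the analytic model.

0. (optional, `download` command) stream the Hugging Face dataset into
   main_prefix.jsonl as normalized {"input", "output"} records;
1. build a text corpus sample and train a 16k BPE SentencePiece tokenizer
   (keeps the V x V PMI cooccurrence + SVD feasible, unlike GPT-2's 50k);
2. tokenize every record as   input + "\n\n" + answer   (the output with its
   <think> reasoning stripped; by default only plain-English answers are
   kept), appending <eos> after each document, into glm_train.bin /
   glm_valid.bin (uint16).
"""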
import os, sys, json, time, numpy as np
import sentencepiece as spm

# Root directory for every artifact; overridable via $ONESHOT_DATA or --data.
DATA = os.environ.get("ONESHOT_DATA", "/workspace/ts")
SRC, CORPUS, SPM_PREFIX = (
    os.path.join(DATA, "main_prefix.jsonl"),  # normalized input/output records (jsonl)
    os.path.join(DATA, "glm_corpus.txt"),     # plain-text SentencePiece training corpus
    os.path.join(DATA, "glm16k"),             # SentencePiece model prefix (<prefix>.model)
)
VOCAB = 16384  # SentencePiece vocab size; token ids are later stored as uint16
HF_DATASET = "Jackrong/GLM-5.1-Reasoning-1M-Cleaned"  # default Hugging Face source

def log(*a):
    """Print *a* space-separated behind a ``[HH:MM:SS]`` local-time stamp, flushed at once."""
    stamp = "[" + time.strftime("%H:%M:%S") + "]"
    print(stamp, *a, flush=True)

def first_present(record, names, default=""):
    """Return the value of the first key in *names* that carries usable data.

    A key counts only if it is present, its value is not None, and — for
    strings — the value is not blank. Skipping blank strings matters for
    instruction-style schemas where e.g. ``"input": ""`` sits next to a
    populated ``"instruction"``: the empty alias must not shadow the real
    text (normalize_record rejects records whose chosen field is empty).

    Args:
        record: mapping to search (only ``in`` and ``[]`` are used).
        names: candidate keys, in priority order.
        default: returned when no candidate carries usable data.
    """
    for name in names:
        if name not in record:
            continue
        value = record[name]
        if value is None:
            continue
        if isinstance(value, str) and not value.strip():
            continue  # blank text: fall through to the next alias
        return value
    return default

def normalize_record(record):
    """Coerce one raw dataset row into an ``{"input": ..., "output": ...}`` dict.

    Prompt and response are looked up under several alias keys via
    first_present. List/dict values are serialized with json.dumps (keeping
    non-ASCII characters as-is); anything else goes through str(). Both
    sides are whitespace-stripped, and None is returned when either side
    ends up empty.
    """
    prompt_keys = ["input", "prompt", "instruction", "question", "query"]
    reply_keys = ["output", "response", "answer", "completion"]

    def _as_text(value):
        if isinstance(value, (list, dict)):
            value = json.dumps(value, ensure_ascii=False)
        return str(value).strip()

    prompt = _as_text(first_present(record, prompt_keys))
    reply = _as_text(first_present(record, reply_keys))
    if not (prompt and reply):
        return None
    return {"input": prompt, "output": reply}

def download_jsonl(dataset=HF_DATASET, split="train", subset=None, max_records=0):
    """Stream a Hugging Face dataset split into SRC as normalized jsonl.

    Every row is passed through normalize_record; rows it rejects (None) are
    skipped and the rest are written one JSON object per line. Output goes to
    ``SRC + ".partial"`` first and is renamed over SRC only after the loop
    finishes, so an interrupted download cannot leave a truncated SRC that
    the later stages would consume as if it were complete.

    Args:
        dataset: Hugging Face dataset id.
        split: split name passed to ``load_dataset``.
        subset: optional config name; only passed when truthy.
        max_records: stop after this many written records (0 = no limit).
    """
    from datasets import load_dataset
    kwargs = {"split": split, "streaming": True}
    ds = load_dataset(dataset, subset, **kwargs) if subset else load_dataset(dataset, **kwargs)
    os.makedirs(DATA, exist_ok=True)
    # SRC may point outside DATA (--src), so make sure its own directory exists too.
    os.makedirs(os.path.dirname(os.path.abspath(SRC)), exist_ok=True)
    tmp = SRC + ".partial"
    n = 0
    try:
        with open(tmp, "w", encoding="utf-8") as out:
            for row in ds:
                rec = normalize_record(row)
                if rec is None:
                    continue
                out.write(json.dumps(rec, ensure_ascii=False) + "\n")
                n += 1
                if n % 10000 == 0:
                    log(f"downloaded {n:,} records -> {SRC}")
                if max_records and n >= max_records:
                    break
    except BaseException:
        # Drop the incomplete temp file; any previous complete SRC is untouched.
        if os.path.exists(tmp):
            os.remove(tmp)
        raise
    os.replace(tmp, SRC)  # atomic swap: SRC is either the old file or the full new one
    log(f"download done: {n:,} records -> {SRC}")

def answer_of(r):
    """Return the final answer of record *r* with its reasoning removed.

    The ``output`` field may start with a ``<think>...</think>`` reasoning
    dump; only the text after the last ``</think>`` is kept (stripped).
    An output that opens ``<think>`` but never closes it (presumably a
    truncated generation) contains no final answer, so "" is returned rather
    than the raw reasoning. A missing or None ``output`` also yields "".
    """
    o = r.get("output") or ""
    _, closed, tail = o.rpartition("</think>")
    if closed:
        return tail.strip()
    if "<think>" in o:
        return ""  # reasoning never finished: no answer to keep
    return o.strip()

def is_english_answer(a, *, min_len=40, max_len=4000, min_alpha=0.93, max_sym=0.02):
    """Return True if answer text *a* looks like plain natural-language English.

    Keeps natural-language answers and drops code/math/LaTeX-dominated ones so
    the model learns to answer in plain English (the 'answer English' goal).
    *a* is rejected when:

    * it is empty, or its length is outside ``[min_len, max_len]`` characters;
    * it contains a triple-backtick code fence;
    * fewer than *min_alpha* of its characters are ASCII letters or
      whitespace. ``str.isalpha`` alone is also True for CJK and other
      scripts, which would let e.g. Chinese answers through as "English";
    * more than *max_sym* of its characters are LaTeX/code punctuation.
    """
    if not a or not (min_len <= len(a) <= max_len):
        return False
    if "```" in a:                       # code fence
        return False
    alpha = sum((ord(c) < 128 and c.isalpha()) or c.isspace() for c in a) / len(a)
    if alpha < min_alpha:
        return False
    sym = sum(a.count(c) for c in "{}\\$=#|<>_~^")
    return sym / len(a) <= max_sym       # LaTeX / code punctuation density

def set_paths(data):
    """Re-root DATA and every path derived from it at directory *data*.

    Rebinds the module globals DATA, SRC, CORPUS and SPM_PREFIX, so any
    earlier override of SRC is discarded.
    """
    global DATA, SRC, CORPUS, SPM_PREFIX
    DATA = data
    leaves = ("main_prefix.jsonl", "glm_corpus.txt", "glm16k")
    SRC, CORPUS, SPM_PREFIX = (os.path.join(data, leaf) for leaf in leaves)

def build_corpus(max_records=120_000, max_bytes=400_000_000):
    """Write the plain-text SentencePiece training corpus CORPUS from SRC.

    Each usable jsonl record contributes its ``input``, a newline, its full
    ``output``, and a newline. Reading stops once *max_records* records have
    been written (0 = no limit) or once at least *max_bytes* bytes of UTF-8
    text have been written.

    Blank lines, malformed JSON, and records that are not objects with string
    ``input`` and ``output`` fields are skipped rather than aborting the run.
    """
    n = 0
    n_bytes = 0
    with open(SRC, "r", encoding="utf-8", errors="ignore") as f, \
         open(CORPUS, "w", encoding="utf-8") as out:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                r = json.loads(line)
            except ValueError:
                continue
            if not isinstance(r, dict):
                continue
            inp, outp = r.get("input"), r.get("output")
            if not isinstance(inp, str) or not isinstance(outp, str):
                continue
            txt = inp + "\n" + outp + "\n"
            out.write(txt)
            # The budget is in bytes: len(txt) counts characters, which
            # undercounts multi-byte text (CJK is 3 bytes per character).
            n_bytes += len(txt.encode("utf-8"))
            n += 1
            if (max_records and n >= max_records) or n_bytes >= max_bytes:
                break
    log(f"corpus: {n:,} records, {n_bytes/1e6:.1f} MB -> {CORPUS}")

def train_spm(*, num_threads=32, keep_newlines=True):
    """Train a BPE SentencePiece model on CORPUS, saved as SPM_PREFIX + ".model".

    The model has VOCAB pieces, ids unk=0 / bos=1 / eos=2, no pad id, and
    byte fallback enabled (characters without a piece become byte tokens
    instead of <unk>).

    Args:
        num_threads: trainer thread count.
        keep_newlines: SentencePiece's default ``nmt_nfkc`` normalization
            maps newline characters to spaces, and ``remove_extra_whitespaces``
            then collapses runs of spaces. The blank-line separator that
            tokenize() puts between prompt and answer would therefore be
            encoded as one ordinary space. When True (default), the model
            uses ``identity`` normalization with whitespace collapsing
            disabled, so newlines reach the encoder (presumably as
            byte-fallback tokens, since a line-based corpus gives the trainer
            no newline pieces). Pass False to reproduce the previous
            default-normalization tokenizer.
    """
    options = dict(
        input=CORPUS, model_prefix=SPM_PREFIX, vocab_size=VOCAB,
        model_type="bpe", character_coverage=0.9995,
        input_sentence_size=3_000_000, shuffle_input_sentence=True,
        max_sentence_length=100000, num_threads=num_threads,
        unk_id=0, bos_id=1, eos_id=2, pad_id=-1,
        byte_fallback=True,
    )
    if keep_newlines:
        options.update(normalization_rule_name="identity",
                       remove_extra_whitespaces=False)
    spm.SentencePieceTrainer.train(**options)
    log(f"trained SP -> {SPM_PREFIX}.model (vocab={VOCAB})")

def _collect_docs(english_only):
    """Read SRC and return one "prompt, blank line, answer" string per kept record.

    Blank lines and malformed JSON are skipped. Records that are not objects
    with a string ``input`` (and a string ``output`` if present) are counted
    as seen but not kept. With *english_only*, answers rejected by
    is_english_answer are dropped as well.
    """
    log("scanning + filtering records...")
    docs = []
    seen = 0
    with open(SRC, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                r = json.loads(line)
            except ValueError:
                continue
            seen += 1
            # Skip malformed records instead of crashing halfway through the file.
            if not isinstance(r, dict) or not isinstance(r.get("input"), str):
                continue
            if not isinstance(r.get("output", ""), str):
                continue
            a = answer_of(r)
            if english_only and not is_english_answer(a):
                continue
            docs.append(r["input"].strip() + "\n\n" + a)
    log(f"{seen:,} records -> {len(docs):,} kept "
        f"({100*len(docs)/max(seen,1):.1f}%) english_only={english_only}")
    return docs


def _write_bin(sp, path, docs, eos, batch_size):
    """Encode *docs* and write each one's ids followed by *eos* to *path* as uint16.

    Documents are encoded *batch_size* at a time, and each batch is written
    with a single tofile call. Returns the number of ids written.
    """
    total = 0
    with open(path, "wb") as fo:
        for start in range(0, len(docs), batch_size):
            flat = []
            for ids in sp.encode(docs[start:start + batch_size]):
                flat.extend(ids)
                flat.append(eos)
            arr = np.asarray(flat, dtype=np.uint16)
            arr.tofile(fo)
            total += arr.size
    return total


def tokenize(val_frac=0.04, english_only=True, *, batch_size=1000):
    """Encode SRC into DATA/glm_train.bin and DATA/glm_valid.bin.

    Each kept record becomes its stripped ``input``, a blank line, and
    answer_of(record). The result is encoded with the SentencePiece model at
    SPM_PREFIX, followed by one <eos> id, and stored as flat uint16. The last
    *val_frac* fraction of kept documents (file order, no shuffling) goes to
    the validation file.

    Raises:
        ValueError: if the tokenizer's ids do not fit in uint16.
    """
    sp = spm.SentencePieceProcessor(model_file=SPM_PREFIX + ".model")
    n_pieces = sp.get_piece_size()
    # Ids are written as uint16; a larger vocab (e.g. --vocab 70000) cannot be stored.
    if n_pieces > np.iinfo(np.uint16).max + 1:
        raise ValueError(f"vocab of {n_pieces:,} pieces does not fit in uint16 token files")
    eos = sp.eos_id()
    t0 = time.time()
    docs = _collect_docs(english_only)
    n_val = int(len(docs) * val_frac)
    n_train = len(docs) - n_val
    splits = {"glm_train.bin": docs[:n_train],
              "glm_valid.bin": docs[n_train:]}
    counts = {fname: _write_bin(sp, os.path.join(DATA, fname), dlist, eos, batch_size)
              for fname, dlist in splits.items()}
    log(f"DONE train={counts['glm_train.bin']:,} tokens, "
        f"valid={counts['glm_valid.bin']:,} tokens ({time.time()-t0:.0f}s)")

if __name__ == "__main__":
    import argparse

    # Command-line interface: one positional pipeline step plus path/dataset knobs.
    parser = argparse.ArgumentParser()
    parser.add_argument("cmd", nargs="?", default="all",
                        choices=["download", "corpus", "spm", "tok", "all"])
    parser.add_argument("--data", default=DATA)
    parser.add_argument("--src", default=None)
    parser.add_argument("--vocab", type=int, default=VOCAB)
    parser.add_argument("--dataset", default=HF_DATASET)
    parser.add_argument("--subset", default=None)
    parser.add_argument("--split", default="train")
    parser.add_argument("--max_records", type=int, default=0)
    parser.add_argument("--english_only", type=int, default=1)
    parser.add_argument("--val_frac", type=float, default=0.04)
    opts = parser.parse_args()

    # Re-derive all paths from --data first so an explicit --src can then override SRC.
    set_paths(opts.data)
    if opts.src:
        SRC = opts.src
    VOCAB = opts.vocab

    step = opts.cmd
    everything = step == "all"
    # "download" only runs when named explicitly; "all" covers corpus -> spm -> tok.
    if step == "download":
        download_jsonl(opts.dataset, opts.split, opts.subset, opts.max_records)
    if everything or step == "corpus":
        build_corpus(max_records=opts.max_records or 120_000)
    if everything or step == "spm":
        train_spm()
    if everything or step == "tok":
        tokenize(opts.val_frac, bool(opts.english_only))