# File size: 4,905 Bytes | c5f49b9 (file-viewer header; line-number gutter removed)
import json
import os
import random
import re
# --- Sampling / mixing knobs --------------------------------------------------
SEED = 1337  # seed for the single random.Random instance created in main()
VAL_RATIO = 0.08  # share of the train core requested as validation holdout
MAX_GENERAL_TRAIN = 3000  # general train pairs kept after shuffling
MAX_GENERAL_VAL = 420  # general validation pairs kept after shuffling
REFINE_BOOST_RATIO = 0.65  # extra refine rows drawn, relative to len(refine_train)

# --- Input and output locations -----------------------------------------------
_DATA_DIR = "data"

SRC_GENERAL_TRAIN = os.path.join(_DATA_DIR, "jarvis_train.txt")
SRC_GENERAL_VAL = os.path.join(_DATA_DIR, "jarvis_val.txt")
SRC_REFINE_TRAIN = os.path.join(_DATA_DIR, "jarvis_refine_train.txt")
SRC_REFINE_VAL = os.path.join(_DATA_DIR, "jarvis_refine_val.txt")

OUT_TRAIN = os.path.join(_DATA_DIR, "jarvis_mix_train.txt")
OUT_VAL = os.path.join(_DATA_DIR, "jarvis_mix_val.txt")
OUT_REPORT = os.path.join(_DATA_DIR, "jarvis_mix_report.json")

# One "User: ...\nAssistant: ..." record. The assistant text is matched lazily
# up to the next blank-line-separated "User:" marker or the end of the input;
# DOTALL lets either capture span several lines.
PAIR_RE = re.compile(
    r"User:\s*(.*?)\nAssistant:\s*(.*?)(?=\n\nUser:|\Z)",
    flags=re.DOTALL,
)

# Lower-case substrings; an answer containing any of them is discarded by
# filter_low_quality().
LOW_QUALITY_PATTERNS = [
    "set a clear target, run one controlled test",
    "keep only measurable improvements",
    "steps: 1) prepare ingredients, 2) cook in short stages",
]
def normalize(text: str) -> str:
    """Collapse each run of whitespace in *text* to one space and trim both ends."""
    # Splitting on whitespace discards leading/trailing runs and treats any
    # interior run as a single separator, so re-joining with " " yields the
    # collapsed, trimmed string.
    tokens = text.split()
    return " ".join(tokens)
def read_pairs(path):
    """Parse ``User:``/``Assistant:`` exchanges from the text file at *path*.

    Args:
        path: Location of a UTF-8 text file holding records matched by
            ``PAIR_RE``.

    Returns:
        A list of ``(user, assistant)`` tuples with whitespace normalized.
        Pairs whose user text is shorter than 4 or longer than 420 characters,
        or whose assistant text is shorter than 8 or longer than 840
        characters, are skipped. An empty list is returned when *path* does
        not exist.
    """
    if not os.path.exists(path):
        return []
    # The context manager closes the handle as soon as the text is read
    # (the bare open(...).read() form left that to the garbage collector).
    # Undecodable bytes are dropped rather than aborting the whole read.
    with open(path, "r", encoding="utf-8", errors="ignore") as handle:
        text = handle.read()
    pairs = []
    for user, assistant in PAIR_RE.findall(text):
        u = normalize(user)
        a = normalize(assistant)
        # Too short to be a meaningful exchange.
        if len(u) < 4 or len(a) < 8:
            continue
        # Too long for the target sample size.
        if len(u) > 420 or len(a) > 840:
            continue
        pairs.append((u, a))
    return pairs
def dedupe(pairs):
    """Return *pairs* without case-insensitive duplicates, keeping first occurrences.

    Two pairs count as duplicates when both the user text and the assistant
    text are equal after lower-casing. The surviving entries keep their
    original casing and relative order.
    """
    first_seen = {}
    for question, answer in pairs:
        # setdefault only stores the pair the first time its key shows up;
        # dict insertion order then gives back the original ordering.
        first_seen.setdefault((question.lower(), answer.lower()), (question, answer))
    return list(first_seen.values())
def filter_low_quality(pairs):
    """Keep only the ``(user, assistant)`` pairs that pass every quality check.

    A pair is rejected when any of the following holds:
    - the lower-cased answer contains one of ``LOW_QUALITY_PATTERNS``;
    - the lower-cased question contains ``"context:"`` or the question
      contains a newline;
    - the answer contains ``[``, ``]`` or ``â``;
    - the answer has fewer than 7 or more than 140 word-like tokens;
    - the answer does not end in ``.``, ``!`` or ``?``.
    """

    def is_acceptable(question, answer):
        answer_lc = answer.lower()
        if any(pattern in answer_lc for pattern in LOW_QUALITY_PATTERNS):
            return False
        if "context:" in question.lower() or "\n" in question:
            return False
        # NOTE(review): "â" is presumably a guard against mis-decoded text — confirm.
        if "[" in answer or "]" in answer or "â" in answer:
            return False
        word_total = len(re.findall(r"[A-Za-z0-9']+", answer))
        if not 7 <= word_total <= 140:
            return False
        return re.search(r"[.!?]$", answer) is not None

    return [
        (question, answer)
        for question, answer in pairs
        if is_acceptable(question, answer)
    ]
def sample_without_replacement(rng, items, count):
    """Draw up to *count* distinct entries from *items* using *rng*.

    Returns an empty list for a non-positive *count*, a shallow copy of
    *items* (original order, no randomness consumed) when *count* is at least
    ``len(items)``, and ``rng.sample(items, count)`` otherwise.
    """
    if count > 0:
        wants_everything = count >= len(items)
        return list(items) if wants_everything else rng.sample(items, count)
    return []
def sample_with_replacement(rng, items, count):
    """Draw *count* entries from *items* with repetition allowed, using *rng*.

    Returns an empty list when *count* is not positive or *items* is empty.
    Each draw is an independent ``rng.choice(items)`` call.
    """
    drawn = []
    if count > 0 and items:
        for _ in range(count):
            drawn.append(rng.choice(items))
    return drawn
def write_pairs(path, pairs):
    """Write *pairs* to *path* as UTF-8 ``User:``/``Assistant:`` records.

    Each pair becomes two lines followed by a blank line. An existing file at
    *path* is overwritten.
    """
    with open(path, "w", encoding="utf-8") as sink:
        sink.writelines(
            f"User: {question}\nAssistant: {answer}\n\n"
            for question, answer in pairs
        )
def main():
    """Build the mixed train/validation files and write a JSON summary report.

    Every random operation draws from one generator seeded with ``SEED``, so
    the order of the steps below determines the output:

    1. Load, quality-filter and de-duplicate the four source files.
    2. Shuffle the general train/val pairs and cap them at
       ``MAX_GENERAL_TRAIN`` / ``MAX_GENERAL_VAL``.
    3. Form the train core from the refine pairs plus the general pairs not
       already present in the refine set, then add refine rows sampled with
       replacement (``REFINE_BOOST_RATIO`` of the refine count).
    4. Build validation from both validation files plus a holdout sampled
       from the train core (at least 120 rows requested).
    5. Remove from training every pair that also appears in validation.
    6. Write both datasets and the report, and print the report.
    """
    rng = random.Random(SEED)

    def load_clean(source_path):
        return dedupe(filter_low_quality(read_pairs(source_path)))

    def pair_key(pair):
        # Case-insensitive identity of a (user, assistant) pair.
        return (pair[0].lower(), pair[1].lower())

    general_train = load_clean(SRC_GENERAL_TRAIN)
    general_val = load_clean(SRC_GENERAL_VAL)
    refine_train = load_clean(SRC_REFINE_TRAIN)
    refine_val = load_clean(SRC_REFINE_VAL)

    # Shuffle before capping so the cap keeps a random subset.
    rng.shuffle(general_train)
    rng.shuffle(general_val)
    general_train = general_train[:MAX_GENERAL_TRAIN]
    general_val = general_val[:MAX_GENERAL_VAL]

    refine_keys = {pair_key(pair) for pair in refine_train}
    general_only_train = [
        pair for pair in general_train if pair_key(pair) not in refine_keys
    ]
    train_core = refine_train + general_only_train
    rng.shuffle(train_core)

    boosted_refine = sample_with_replacement(
        rng, refine_train, int(len(refine_train) * REFINE_BOOST_RATIO)
    )
    mix_train = train_core + boosted_refine
    rng.shuffle(mix_train)

    # Validation set from dedicated val files + small holdout from non-boosted train core.
    holdout_size = max(120, int(len(train_core) * VAL_RATIO))
    holdout = sample_without_replacement(rng, train_core, holdout_size)
    mix_val = dedupe(refine_val + general_val + holdout)

    # Remove validation leakage from training.
    val_keys = {pair_key(pair) for pair in mix_val}
    mix_train = [pair for pair in mix_train if pair_key(pair) not in val_keys]
    rng.shuffle(mix_train)

    os.makedirs("data", exist_ok=True)
    write_pairs(OUT_TRAIN, mix_train)
    write_pairs(OUT_VAL, mix_val)

    report = {
        "seed": SEED,
        "val_ratio": VAL_RATIO,
        "general_train_used": len(general_train),
        "general_val_used": len(general_val),
        "refine_train_used": len(refine_train),
        "refine_val_used": len(refine_val),
        # Counted before the leakage filter above is applied to mix_train.
        "boosted_refine_rows": len(boosted_refine),
        "mix_train_rows": len(mix_train),
        "mix_val_rows": len(mix_val),
        "out_train": OUT_TRAIN,
        "out_val": OUT_VAL,
    }
    # Serialize once; the same text goes to the report file and to stdout.
    report_text = json.dumps(report, indent=2)
    with open(OUT_REPORT, "w", encoding="utf-8") as report_file:
        report_file.write(report_text)
    print(report_text)
# Run the mixing pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()