"""Build mixed train/validation text files from general and refine dialogue data.

Source path: scripts/build_mixed_refine_data.py
"""
import json
import os
import random
import re
# Seed for the random.Random instance that main() uses for every shuffle/sample.
SEED = 1337
# Fraction of the training core that main() samples into the validation holdout.
VAL_RATIO = 0.08
# Row caps applied to the general train/val sets after they are shuffled.
MAX_GENERAL_TRAIN = 3000
MAX_GENERAL_VAL = 420
# Number of extra refine rows (drawn with replacement), as a fraction of the
# refine training set size.
REFINE_BOOST_RATIO = 0.65

# Input files.
SRC_GENERAL_TRAIN = os.path.join("data", "jarvis_train.txt")
SRC_GENERAL_VAL = os.path.join("data", "jarvis_val.txt")
SRC_REFINE_TRAIN = os.path.join("data", "jarvis_refine_train.txt")
SRC_REFINE_VAL = os.path.join("data", "jarvis_refine_val.txt")

# Output files.
OUT_TRAIN = os.path.join("data", "jarvis_mix_train.txt")
OUT_VAL = os.path.join("data", "jarvis_mix_val.txt")
OUT_REPORT = os.path.join("data", "jarvis_mix_report.json")

# One "User: ... / Assistant: ..." record; the assistant text runs until the
# next blank-line-separated "User:" or the end of the input.
PAIR_RE = re.compile(
    r"User:\s*(.*?)\nAssistant:\s*(.*?)(?=\n\nUser:|\Z)",
    flags=re.DOTALL,
)

# Lower-case substrings; an assistant reply containing any of them is dropped
# by filter_low_quality().
LOW_QUALITY_PATTERNS = [
    "set a clear target, run one controlled test",
    "keep only measurable improvements",
    "steps: 1) prepare ingredients, 2) cook in short stages",
]
def normalize(text: str) -> str:
    """Collapse every run of whitespace in *text* to one space and trim the ends."""
    # str.split() with no argument splits on whitespace runs and discards
    # leading/trailing whitespace, so re-joining yields the collapsed form.
    return " ".join(text.split())
def read_pairs(path):
    """Load whitespace-normalized (user, assistant) pairs from a transcript file.

    Args:
        path: Path to a UTF-8 text file containing records in the format
            matched by ``PAIR_RE``.

    Returns:
        A list of ``(user, assistant)`` tuples, each passed through
        ``normalize``. An empty list is returned when ``path`` does not exist.
        A pair is dropped when its user text is shorter than 4 or longer than
        420 characters, or its assistant text is shorter than 8 or longer
        than 840 characters.
    """
    if not os.path.exists(path):
        return []
    # Use a context manager so the handle is closed even if read() raises.
    # errors="ignore" drops undecodable bytes instead of failing.
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        text = f.read()
    pairs = []
    for user, assistant in PAIR_RE.findall(text):
        u = normalize(user)
        a = normalize(assistant)
        if not 4 <= len(u) <= 420:
            continue
        if not 8 <= len(a) <= 840:
            continue
        pairs.append((u, a))
    return pairs
def dedupe(pairs):
    """Return *pairs* without case-insensitive duplicates, keeping first occurrences.

    Two pairs are duplicates when both their user and assistant texts are
    equal after lower-casing. Input order is preserved.
    """
    # dicts keep insertion order and setdefault() never overwrites, so the
    # first pair seen for each key wins.
    first_by_key = {}
    for u, a in pairs:
        first_by_key.setdefault((u.lower(), a.lower()), (u, a))
    return list(first_by_key.values())
def filter_low_quality(pairs):
    """Return only the (user, assistant) pairs that pass the quality heuristics.

    A pair is dropped when any of the following holds:
      * the assistant text contains a ``LOW_QUALITY_PATTERNS`` entry
        (case-insensitive);
      * the user text contains "context:" (case-insensitive) or a newline;
      * the assistant text contains "[", "]" or "â" (the last presumably a
        mojibake marker);
      * the assistant text has fewer than 7 or more than 140 words, where a
        word is a run of ASCII letters, digits or apostrophes;
      * the assistant text does not end with ".", "!" or "?".
    """
    kept = []
    for user_text, reply in pairs:
        user_lower = user_text.lower()
        reply_lower = reply.lower()

        # Cheap substring-based rejections first.
        rejected = (
            any(pattern in reply_lower for pattern in LOW_QUALITY_PATTERNS)
            or "context:" in user_lower
            or "\n" in user_text
            or "[" in reply
            or "]" in reply
            or "â" in reply
        )
        if rejected:
            continue

        word_count = len(re.findall(r"[A-Za-z0-9']+", reply))
        if not 7 <= word_count <= 140:
            continue

        # Require terminal sentence punctuation.
        if re.search(r"[.!?]$", reply) is None:
            continue

        kept.append((user_text, reply))
    return kept
def sample_without_replacement(rng, items, count):
    """Pick up to *count* distinct elements of *items* using *rng*.

    Returns an empty list when *count* is not positive, a plain copy of
    *items* (original order) when *count* covers the whole sequence, and
    otherwise ``rng.sample(items, count)``.
    """
    if count <= 0:
        return []
    if count < len(items):
        return rng.sample(items, count)
    # Asking for everything (or more): no randomness is consumed.
    return list(items)
def sample_with_replacement(rng, items, count):
    """Draw *count* elements from *items* with replacement using *rng*.

    Returns an empty list when *count* is not positive or *items* is empty.
    Each draw is an independent ``rng.choice(items)`` call.
    """
    if count <= 0 or not items:
        return []
    picks = []
    for _ in range(count):
        picks.append(rng.choice(items))
    return picks
def write_pairs(path, pairs):
    """Write *pairs* to *path* (UTF-8, overwriting) as blank-line-separated records.

    Each record is a "User: ..." line followed by an "Assistant: ..." line.
    """
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(f"User: {u}\nAssistant: {a}\n\n" for u, a in pairs)
def main():
    """Build the mixed train/validation files and a JSON summary report.

    Reads the four source files, filters and de-duplicates each, then:
      1. shuffles and caps the general sets;
      2. forms the training core from refine rows plus general rows that are
         not already in the refine set;
      3. adds extra refine rows sampled with replacement;
      4. builds the validation set from the dedicated validation files plus a
         holdout sampled from the training core;
      5. removes every validation row from the training mix.

    Writes OUT_TRAIN, OUT_VAL and OUT_REPORT, and prints the report as JSON.
    All randomness comes from one ``random.Random(SEED)``; the order of the
    shuffle/sample calls below determines the output.
    """

    def pair_key(pair):
        # Case-insensitive identity of a (user, assistant) pair.
        return (pair[0].lower(), pair[1].lower())

    rng = random.Random(SEED)

    general_train = dedupe(filter_low_quality(read_pairs(SRC_GENERAL_TRAIN)))
    general_val = dedupe(filter_low_quality(read_pairs(SRC_GENERAL_VAL)))
    refine_train = dedupe(filter_low_quality(read_pairs(SRC_REFINE_TRAIN)))
    refine_val = dedupe(filter_low_quality(read_pairs(SRC_REFINE_VAL)))

    # Shuffle before capping so the cap keeps a random subset.
    rng.shuffle(general_train)
    rng.shuffle(general_val)
    general_train = general_train[:MAX_GENERAL_TRAIN]
    general_val = general_val[:MAX_GENERAL_VAL]

    # Drop general rows that duplicate a refine row.
    refine_key = {pair_key(p) for p in refine_train}
    general_only_train = [p for p in general_train if pair_key(p) not in refine_key]
    train_core = refine_train + general_only_train
    rng.shuffle(train_core)

    boost_count = int(len(refine_train) * REFINE_BOOST_RATIO)
    boosted_refine = sample_with_replacement(rng, refine_train, boost_count)
    mix_train = train_core + boosted_refine
    rng.shuffle(mix_train)

    # Validation set from dedicated val files + small holdout from non-boosted train core.
    val_holdout_n = max(120, int(len(train_core) * VAL_RATIO))
    # The 120-row floor can meet or exceed a small train_core, which would move
    # every training row into validation and leave OUT_TRAIN empty. Never hold
    # out more than half of the core (no effect once the core has >= 240 rows).
    val_holdout_n = min(val_holdout_n, len(train_core) // 2)
    holdout = sample_without_replacement(rng, train_core, val_holdout_n)
    mix_val = dedupe(refine_val + general_val + holdout)

    # Remove validation leakage from training.
    val_key = {pair_key(p) for p in mix_val}
    mix_train = [p for p in mix_train if pair_key(p) not in val_key]
    rng.shuffle(mix_train)

    os.makedirs("data", exist_ok=True)
    write_pairs(OUT_TRAIN, mix_train)
    write_pairs(OUT_VAL, mix_val)

    report = {
        "seed": SEED,
        "val_ratio": VAL_RATIO,
        "general_train_used": len(general_train),
        "general_val_used": len(general_val),
        "refine_train_used": len(refine_train),
        "refine_val_used": len(refine_val),
        # Counted before leakage removal, so some of these rows may not be in
        # the final training file.
        "boosted_refine_rows": len(boosted_refine),
        "mix_train_rows": len(mix_train),
        "mix_val_rows": len(mix_val),
        "out_train": OUT_TRAIN,
        "out_val": OUT_VAL,
    }
    with open(OUT_REPORT, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print(json.dumps(report, indent=2))


if __name__ == "__main__":
    main()