# json_ai / make_json_dataset.py
# Commit 1f320b1 (verified) by Percy3822: "Create make_json_dataset.py"
import argparse
import gzip
import hashlib
import json
import math
import os
import random
import re
import textwrap
import time
from datetime import datetime, timezone
# Fixed module-level seed so sampling is reproducible from run to run.
random.seed(2025)

# ------------ Content pools (JSON-focused) ------------
# Base objects that the write/transform builders sample from as-is.
BASE_OBJECTS = [
    dict(id=1, name="Widget", price=9.99, tags=["tools", "home"]),
    dict(id=2, name="Gadget", price=19.5, tags=["electronics"]),
    dict(id=3, name="Thing", price=2.75, tags=[]),
]
# Deliberately broken JSON snippets, as (title, invalid_text, fix_tip) triples.
INVALID_SAMPLES = [
    ("Trailing comma", '{ "a": 1, }', "Remove trailing comma after the last property."),
    ("Single quotes", "{ 'a': 1 }", "Use double quotes for keys and string values."),
    ("Unquoted key", "{ a: 1 }", 'Quote keys: { "a": 1 }.'),
    ("NaN value", '{ "x": NaN }', "JSON has no NaN; use null or a string."),
    ("Comments", '{ /* note */ "a": 1 }', "JSON does not allow comments."),
]
# (name, JSON Schema) pairs. Key order inside each schema matters: the schema
# builders serialize these with json.dumps(indent=2) and no key sorting.
_SCHEMA_DRAFT = "https://json-schema.org/draft/2020-12/schema"

_PRODUCT_SCHEMA = {
    "$schema": _SCHEMA_DRAFT,
    "title": "Product",
    "type": "object",
    "required": ["id", "name", "price", "tags"],
    "properties": {
        "id": {"type": "integer", "minimum": 1},
        "name": {"type": "string", "minLength": 1},
        "price": {"type": "number", "minimum": 0},
        "tags": {"type": "array", "items": {"type": "string"}},
    },
    "additionalProperties": False,
}

# One invoice line item; embedded as the "items" element schema below.
_INVOICE_LINE_SCHEMA = {
    "type": "object",
    "required": ["name", "qty", "price"],
    "properties": {
        "name": {"type": "string"},
        "qty": {"type": "integer", "minimum": 1},
        "price": {"type": "number", "minimum": 0},
    },
    "additionalProperties": False,
}

_INVOICE_SCHEMA = {
    "$schema": _SCHEMA_DRAFT,
    "title": "Invoice",
    "type": "object",
    "required": ["number", "currency", "items", "total"],
    "properties": {
        "number": {"type": "string"},
        "currency": {"type": "string"},
        "items": {"type": "array", "items": _INVOICE_LINE_SCHEMA},
        "total": {"type": "number", "minimum": 0},
    },
    "additionalProperties": False,
}

SCHEMAS = [
    ("product", _PRODUCT_SCHEMA),
    ("invoice", _INVOICE_SCHEMA),
]
# (title, rules) pairs; the rules dicts are the input to apply_transform.
TRANSFORM_TASKS = [
    ("Pick fields", {"keep": ["id", "name"]}),
    ("Add field", {"add": {"in_stock": True}}),
    ("Rename field", {"rename": {"price": "unit_price"}}),
    ("Compute field", {"compute": {"value": "price * 1.2"}}),
]

# (title, instruction) pairs for order-API payload tasks.
API_TASKS = [
    (
        "Create payload",
        "Create JSON payload for POST /orders with fields: id (string), "
        "customer {id,name}, items [{sku,qty,price}], "
        "totals {subtotal,tax,total}.",
    ),
    (
        "Validate payload",
        "Given schema below, validate a sample order and fix errors.",
    ),
]

# Free-form explanation prompts.
# NOTE(review): not referenced by any code visible in this file.
EXPLAINERS = [
    "Explain why this JSON is invalid and provide a fixed version.",
    "Explain differences between JSON and JSON5.",
    "Explain when to use JSON Schema and provide a minimal example.",
]
# ------------ Prompt styles ------------
# Prompt openers; the sample builders prepend one of these to task-specific text.
WRITE_PREFIX = [
    "Generate valid JSON for the following entity. "
    "Use only double quotes, no trailing commas:\nEntity: ",
    "Produce compact JSON (no comments) for: ",
]

FIX_PREFIX = [
    "Fix the invalid JSON. Provide corrected JSON only:\n",
    "The following is invalid. Output a valid JSON equivalent:\n",
]

EXPLAIN_PREFIX = ["Explain the error and give a valid JSON version:\n"]

SCHEMA_PREFIX = ["Write a JSON Schema (2020-12) for the entity:\n"]

SCHEMA_VALIDATE_PREFIX = ["Given this schema, provide a valid sample JSON instance:\n"]

TRANSFORM_PREFIX = [
    "Transform the input JSON per the rules. Output the new JSON only.\n"
    "Rules:\n",
]

API_PREFIX = [
    "Design the JSON payload for this API:\n",
    "Validate this JSON against the schema and correct it.\n",
]
def to_json(obj):
    """Serialize *obj* as compact, deterministic, strictly valid JSON.

    Keys are sorted and no insignificant whitespace is emitted, so equal
    objects always serialize to the same string. Non-ASCII characters are
    written as-is rather than as \\u escapes.

    Raises:
        ValueError: if *obj* contains NaN or +/-Infinity. Python's json module
            would otherwise emit the non-standard tokens ``NaN``/``Infinity``,
            which are not valid JSON and must never appear in a completion.
        TypeError: if *obj* contains a value json cannot serialize.
    """
    return json.dumps(
        obj,
        ensure_ascii=False,
        separators=(",", ":"),
        sort_keys=True,
        allow_nan=False,
    )
# ------------ Sample builders ------------
def make_write_sample():
    """Build a "write valid JSON" sample.

    Returns a dict with ``prompt`` (a random WRITE_PREFIX opener followed by an
    entity label) and ``completion`` (a random BASE_OBJECTS entry serialized
    with to_json). The label and the object are drawn independently.
    """
    # Draw order (object, label, opener) is kept fixed so seeded runs reproduce.
    obj = random.choice(BASE_OBJECTS)
    label = random.choice(["product", "catalog_item", "inventory_record"])
    opener = random.choice(WRITE_PREFIX)
    return dict(prompt=opener + label, completion=to_json(obj))
def make_fix_sample():
    """Build a "fix this invalid JSON" sample.

    Picks one INVALID_SAMPLES entry. The prompt is a FIX_PREFIX opener plus the
    broken text; the completion is the repaired object serialized with to_json.

    The repair is a naive, best-effort textual pass covering the defect
    families in INVALID_SAMPLES, independent of exact spacing: block comments,
    single quotes, bare identifier keys, NaN and trailing commas. It does not
    understand string contents, so quotes, commas or "NaN" inside string
    values would be rewritten too.
    """
    _title, bad, _tip = random.choice(INVALID_SAMPLES)
    prompt = random.choice(FIX_PREFIX) + bad

    fixed = re.sub(r"/\*.*?\*/", "", bad, flags=re.DOTALL)  # drop /* ... */ comments
    fixed = fixed.replace("'", '"')  # single -> double quotes
    # Quote bare identifier keys: `{ a: 1 }` -> `{ "a": 1 }`.
    fixed = re.sub(r"([{,]\s*)([A-Za-z_]\w*)(\s*:)", r'\1"\2"\3', fixed)
    fixed = re.sub(r"\bNaN\b", "null", fixed)  # JSON has no NaN literal
    fixed = re.sub(r",\s*([}\]])", r"\1", fixed)  # trailing commas before } or ]

    try:
        obj = json.loads(fixed)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        # Last-resort placeholder so generation never stops; with the repairs
        # above it is not reached for the current INVALID_SAMPLES pool.
        obj = {"a": 1}
    return {"prompt": prompt, "completion": to_json(obj)}
def make_explain_sample():
    """Build an "explain the error and fix it" sample.

    Picks one INVALID_SAMPLES entry. The prompt is an EXPLAIN_PREFIX opener plus
    the broken text; the completion is the entry's tip, then ``Fixed:`` and the
    repaired object serialized with to_json.

    The repair is a naive, best-effort textual pass covering the defect
    families in INVALID_SAMPLES, independent of exact spacing: block comments,
    single quotes, bare identifier keys, NaN and trailing commas. It does not
    understand string contents, so quotes, commas or "NaN" inside string
    values would be rewritten too.
    """
    _title, bad, tip = random.choice(INVALID_SAMPLES)
    prompt = random.choice(EXPLAIN_PREFIX) + bad

    fixed = re.sub(r"/\*.*?\*/", "", bad, flags=re.DOTALL)  # drop /* ... */ comments
    fixed = fixed.replace("'", '"')  # single -> double quotes
    # Quote bare identifier keys: `{ a: 1 }` -> `{ "a": 1 }`.
    fixed = re.sub(r"([{,]\s*)([A-Za-z_]\w*)(\s*:)", r'\1"\2"\3', fixed)
    fixed = re.sub(r"\bNaN\b", "null", fixed)  # JSON has no NaN literal
    fixed = re.sub(r",\s*([}\]])", r"\1", fixed)  # trailing commas before } or ]

    try:
        obj = json.loads(fixed)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        # Last-resort placeholder so generation never stops; with the repairs
        # above it is not reached for the current INVALID_SAMPLES pool.
        obj = {"a": 1}
    completion = f"{tip}\nFixed:\n{to_json(obj)}"
    return {"prompt": prompt, "completion": completion}
def make_schema_sample():
    """Build a "write a JSON Schema" sample: entity name in, schema out.

    The completion is the schema pretty-printed with 2-space indentation in
    its defined key order (no key sorting), not to_json's compact form.
    """
    entity, schema = random.choice(SCHEMAS)
    opener = random.choice(SCHEMA_PREFIX)
    rendered = json.dumps(schema, ensure_ascii=False, indent=2)
    return dict(prompt=f"{opener}{entity}", completion=rendered)
def make_schema_instance_sample():
    """Build a "give a valid instance for this schema" sample.

    The prompt embeds the chosen schema (pretty-printed); the completion is a
    fixed, hand-written instance serialized with to_json. Any schema not named
    "product" receives the invoice instance.
    """
    name, schema = random.choice(SCHEMAS)
    opener = random.choice(SCHEMA_VALIDATE_PREFIX)
    prompt = opener + json.dumps(schema, ensure_ascii=False, indent=2)

    # Very naive instance generator: one canned instance per known schema.
    if name != "product":
        instance = {
            "number": "INV-001",
            "currency": "USD",
            "items": [{"name": "Sample", "qty": 1, "price": 1.0}],
            "total": 1.0,
        }
    else:
        instance = {"id": 1, "name": "Sample", "price": 1.23, "tags": ["demo"]}

    return dict(prompt=prompt, completion=to_json(instance))
# The only compute-expression shape supported: "<field> * <number>".
_COMPUTE_EXPR = re.compile(r"\s*([A-Za-z_]\w*)\s*\*\s*(\d+(?:\.\d+)?)\s*")


def apply_transform(obj, rules):
    """Return a transformed copy of *obj* according to *rules*.

    Rules are applied in this fixed order:
      * ``keep``: project onto the listed fields, skipping missing ones;
      * ``add``: set the given fields;
      * ``rename``: move ``old`` -> ``new``, skipping missing ``old`` fields;
      * ``compute``: maps output field -> expression. Only expressions of the
        form ``"<field> * <number>"`` are supported, and the result is rounded
        to 2 decimals. Unsupported expressions, missing source fields and
        non-numeric source values are skipped (best effort).

    The caller's *obj* is never modified (a shallow copy is transformed).
    """
    obj = dict(obj)
    if "keep" in rules:
        obj = {k: obj[k] for k in rules["keep"] if k in obj}
    if "add" in rules:
        obj.update(rules["add"])
    if "rename" in rules:
        for old, new in rules["rename"].items():
            if old in obj:
                obj[new] = obj.pop(old)
    for target, expr in rules.get("compute", {}).items():
        # Parse the whole expression. A substring test would treat
        # "price * 1.25" as "* 1.2", or compute from the wrong field.
        match = _COMPUTE_EXPR.fullmatch(expr) if isinstance(expr, str) else None
        if match is None:
            continue
        field, factor = match.group(1), float(match.group(2))
        if field not in obj:
            continue
        try:
            obj[target] = round(float(obj[field]) * factor, 2)
        except (TypeError, ValueError):
            continue  # non-numeric source value: leave obj unchanged
    return obj
def make_transform_sample():
    """Build a "transform JSON per rules" sample.

    The prompt lists the rules (pretty-printed) and the input object (compact);
    the completion is apply_transform's result on a copy of the input.
    """
    source = random.choice(BASE_OBJECTS)
    _, rules = random.choice(TRANSFORM_TASKS)
    opener = random.choice(TRANSFORM_PREFIX)
    prompt = "".join([
        opener,
        json.dumps(rules, ensure_ascii=False, indent=2),
        "\nInput:\n",
        to_json(source),
    ])
    result = apply_transform(dict(source), rules)
    return {"prompt": prompt, "completion": to_json(result)}
# Order schema shown in "Validate payload" prompts. The fixed payload below
# satisfies it; the invalid one does not (qty 0, missing price and totals).
_ORDER_SCHEMA = {
    "type": "object",
    "required": ["id", "items", "totals"],
    "properties": {
        "id": {"type": "string"},
        "items": {
            "type": "array",
            "items": {
                "type": "object",
                "required": ["sku", "qty", "price"],
                "properties": {
                    "sku": {"type": "string"},
                    "qty": {"type": "integer", "minimum": 1},
                    "price": {"type": "number", "minimum": 0},
                },
            },
        },
        "totals": {
            "type": "object",
            "required": ["subtotal", "tax", "total"],
            "properties": {
                "subtotal": {"type": "number"},
                "tax": {"type": "number"},
                "total": {"type": "number"},
            },
        },
    },
}


def make_api_sample():
    """Build an order-API payload sample.

    * "Create payload" task: the prompt is the design opener plus the
      instruction; the completion is a canonical order payload.
    * Any other task (the "Validate payload" one): the prompt is the validate
      opener, the instruction, the order schema and an invalid payload; the
      completion restates the invalid payload and gives a fixed one.

    NOTE: assumes API_PREFIX[0] is the "design" opener and API_PREFIX[1] the
    "validate" opener; keep that order if the list is edited.
    """
    title, instruction = random.choice(API_TASKS)
    if "Create payload" in title:
        # The opener is tied to the task. An independently drawn opener could
        # prefix a create instruction with "Validate this JSON ...".
        prompt = API_PREFIX[0] + instruction
        completion = to_json({
            "id": "ORD-1001",
            "customer": {"id": "C-1", "name": "Alex"},
            "items": [{"sku": "ABC", "qty": 2, "price": 4.5}],
            "totals": {"subtotal": 9.0, "tax": 0.72, "total": 9.72},
        })
        return {"prompt": prompt, "completion": completion}

    # Minimal example of "invalid" -> "fixed".
    bad = '{"id":"ORD-1001","items":[{"sku":"ABC","qty":0}]}'
    fixed = {
        "id": "ORD-1001",
        "items": [{"sku": "ABC", "qty": 1, "price": 1.0}],
        "totals": {"subtotal": 1.0, "tax": 0.08, "total": 1.08},
    }
    # The instruction refers to a schema "below" and to a sample to fix, so
    # both must be in the prompt; otherwise the completion is not derivable.
    prompt = (
        API_PREFIX[1]
        + instruction
        + "\nSchema:\n"
        + json.dumps(_ORDER_SCHEMA, ensure_ascii=False, indent=2)
        + "\nInput:\n"
        + bad
    )
    completion = f"Invalid: {bad}\nFixed: {to_json(fixed)}"
    return {"prompt": prompt, "completion": completion}
# Sampling probability for each sample builder; the weights sum to 1.0.
_MAKER_WEIGHTS = {
    make_write_sample: 0.25,
    make_fix_sample: 0.20,
    make_explain_sample: 0.10,
    make_schema_sample: 0.15,
    make_schema_instance_sample: 0.10,
    make_transform_sample: 0.10,
    make_api_sample: 0.10,
}
# (builder, weight) pairs in declaration order, as consumed by pick_maker.
MAKERS = list(_MAKER_WEIGHTS.items())
def pick_maker():
    """Return one builder from MAKERS, chosen with probability equal to its weight.

    Walks the cumulative weights and returns the first builder whose running
    total reaches a single uniform draw from random.random().
    """
    roll = random.random()
    upper = 0.0
    for maker, weight in MAKERS:
        upper += weight
        if roll <= upper:
            return maker
    # Floating-point drift can leave the running total a hair below 1.0;
    # any roll beyond it falls through to the final builder.
    fallback, _ = MAKERS[-1]
    return fallback
# ------------ Shard writer ------------
def write_shard(path, n_rows):
    """Write *n_rows* freshly generated samples to *path* as gzip'd JSON Lines.

    Each line is one sample dict from a builder chosen by pick_maker, encoded
    as UTF-8 JSON with non-ASCII characters left unescaped.
    """
    # Lazy generator: samples are produced one at a time while the file is open.
    lines = (
        json.dumps(pick_maker()(), ensure_ascii=False) + "\n"
        for _ in range(n_rows)
    )
    with gzip.open(path, "wt", encoding="utf-8") as out:
        out.writelines(lines)
def sha256_file(path):
    """Return the hex SHA-256 digest of the file at *path*.

    The file is streamed in 1 MiB chunks, so large shards are never fully
    loaded into memory.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        for block in iter(lambda: stream.read(1 << 20), b""):
            digest.update(block)
    return digest.hexdigest()
def main():
    """Command-line entry point: generate the sharded dataset and its manifest.

    Writes ``<out_dir>/<prefix>_NNNN.jsonl.gz`` shards of ``--shard_size`` rows
    (the last shard holds the remainder) and ``<prefix>_manifest.json``. The
    manifest records each shard's path, row count and SHA-256. Exits with an
    argparse usage error on a non-positive ``--shard_size`` or negative
    ``--total``.
    """
    ap = argparse.ArgumentParser(description="Generate a synthetic JSON-task dataset.")
    ap.add_argument("--total", type=int, default=1_000_000)
    ap.add_argument("--shard_size", type=int, default=10_000)
    ap.add_argument("--out_dir", type=str, default="json_dataset_v1")
    ap.add_argument("--prefix", type=str, default="json")
    ap.add_argument("--seed", type=int, default=2025,
                    help="random seed; the default 2025 matches the module-level seed")
    args = ap.parse_args()
    # Reject values that would divide by zero or produce negative row counts.
    if args.shard_size < 1:
        ap.error("--shard_size must be >= 1")
    if args.total < 0:
        ap.error("--total must be >= 0")

    # Re-seed so --seed controls the output. Nothing above draws random
    # numbers, so the default reproduces the module-level seed's dataset.
    random.seed(args.seed)

    os.makedirs(args.out_dir, exist_ok=True)
    # Exact integer ceiling; float division loses precision for huge totals.
    n_shards = (args.total + args.shard_size - 1) // args.shard_size
    manifest = {
        # Naive-UTC ISO timestamp with a "Z" suffix, avoiding the deprecated
        # datetime.utcnow().
        "created": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z",
        "total": args.total,
        "shard_size": args.shard_size,
        "num_shards": n_shards,
        "seed": args.seed,
        "files": [],
    }
    print(f"Generating {args.total:,} samples in {n_shards} shards of {args.shard_size}...")
    for i in range(n_shards):
        # Full shards, except the last, which takes whatever remains.
        rows = min(args.shard_size, args.total - i * args.shard_size)
        shard_path = os.path.join(args.out_dir, f"{args.prefix}_{i:04d}.jsonl.gz")
        t0 = time.time()
        write_shard(shard_path, rows)
        digest = sha256_file(shard_path)
        dt = time.time() - t0
        manifest["files"].append({"path": shard_path, "rows": rows, "sha256": digest})
        print(f"[{i+1}/{n_shards}] wrote {rows} rows → {os.path.basename(shard_path)} in {dt:.1f}s")

    man_path = os.path.join(args.out_dir, f"{args.prefix}_manifest.json")
    with open(man_path, "w", encoding="utf-8") as mf:
        json.dump(manifest, mf, indent=2)
    print("Done. Manifest:", man_path)
    # Print one extra sample for a quick visual check (not written to disk).
    print("Sample:", json.dumps(pick_maker()(), ensure_ascii=False))


if __name__ == "__main__":
    main()