# python_ai_coder / make_python_dataset.py
# Added by Percy3822 in verified commit 304794d ("Create make_python_dataset.py"); 10.8 kB.
import argparse
import gzip
import hashlib
import io
import json
import math
import os
import random
import textwrap
import time
from datetime import datetime, timezone
# Determinism but still varied: seed the module-level RNG once at import time so
# every run draws the same sequence of samples.
random.seed(1337)

# ---------------------- Content Pools ----------------------
# Small library of canonical functions to expand/perturb.
# Each value is dedented, valid Python source defining the function named by its key;
# the indentation inside the snippets matters because they are emitted as training code.
FUNC_SNIPPETS = {
    "sum_list": textwrap.dedent("""\
        from typing import Iterable

        def sum_list(nums: Iterable[int]) -> int:
            \"\"\"Return the sum of integers in an iterable.\"\"\"
            return sum(nums)
    """).strip(),
    "is_prime": textwrap.dedent("""\
        from math import isqrt

        def is_prime(n: int) -> bool:
            \"\"\"Return True if n is prime (n>=2).\"\"\"
            if n < 2:
                return False
            if n % 2 == 0:
                return n == 2
            r = isqrt(n)
            f = 3
            while f <= r:
                if n % f == 0:
                    return False
                f += 2
            return True
    """).strip(),
    "fizzbuzz": textwrap.dedent("""\
        def fizzbuzz(n: int) -> list[str]:
            \"\"\"Generate FizzBuzz from 1..n.\"\"\"
            out = []
            for i in range(1, n + 1):
                s = ""
                if i % 3 == 0:
                    s += "Fizz"
                if i % 5 == 0:
                    s += "Buzz"
                out.append(s or str(i))
            return out
    """).strip(),
    "factorial": textwrap.dedent("""\
        def factorial(n: int) -> int:
            \"\"\"Return n! for n>=0.\"\"\"
            if n < 0:
                raise ValueError("n must be >= 0")
            res = 1
            for i in range(2, n + 1):
                res *= i
            return res
    """).strip(),
    "fib": textwrap.dedent("""\
        def fib(n: int) -> int:
            \"\"\"Return nth Fibonacci number (0-indexed).\"\"\"
            a, b = 0, 1
            for _ in range(n):
                a, b = b, a + b
            return a
    """).strip(),
}
# Buggy variants (to generate "debug/fix/explain" style samples).
# Each entry is (title, buggy_source, fixed_source, explanation). Both sources are
# valid, properly indented Python so the pair can be shown, compiled, or executed.
BUGGY_SNIPPETS = [
    (
        "IndexError on short lists",
        "def last_two_sum(xs):\n    return xs[-1] + xs[-2]\n",
        textwrap.dedent("""\
            def last_two_sum(xs):
                if len(xs) < 2:
                    raise ValueError("Need at least 2 items")
                return xs[-1] + xs[-2]
        """).strip(),
        "Accessing [-1] and [-2] without length check can raise IndexError.",
    ),
    (
        "Mutable default arg",
        "def add_name(name, xs=[]):\n    xs.append(name)\n    return xs\n",
        textwrap.dedent("""\
            def add_name(name, xs=None):
                if xs is None:
                    xs = []
                xs.append(name)
                return xs
        """).strip(),
        "Default list is shared across calls. Use None sentinel.",
    ),
    (
        "Shadowing builtins",
        "def max(nums):\n    return 0\n",
        # The `builtins` module is the portable way to reach the real max();
        # `__builtins__` is a module in __main__ but a dict elsewhere.
        textwrap.dedent("""\
            import builtins
            from typing import Iterable

            def my_max(nums: Iterable[int]) -> int:
                return builtins.max(nums)
        """).strip(),
        "Function named 'max' shadows builtin, causing confusion. Rename or call builtins explicitly.",
    ),
    (
        "Off-by-one",
        "def range_inclusive(a,b):\n    return list(range(a,b))\n",
        textwrap.dedent("""\
            def range_inclusive(a: int, b: int) -> list[int]:
                return list(range(a, b + 1))
        """).strip(),
        "range(a,b) excludes b; inclusive needs b+1.",
    ),
]
# (topic, example source) pairs used by the "explain" samples. Every example must be
# syntactically valid Python; note the real dunder name `__init__` in the generics example.
EXPLAIN_TOPICS = [
    ("list comprehensions", "squares = [x*x for x in range(5)]"),
    ("generators", "def gen():\n    yield from range(3)"),
    ("context managers", "with open('file.txt','r',encoding='utf-8') as f:\n    data = f.read()"),
    ("dataclasses", "from dataclasses import dataclass\n@dataclass\nclass User:\n    name: str\n    age: int"),
    ("typing generics", "from typing import TypeVar, Generic, Iterable\nT = TypeVar('T')\nclass Box(Generic[T]):\n    def __init__(self, x: T): self.x = x"),
    ("asyncio basics", "import asyncio\nasync def main():\n    await asyncio.sleep(1)"),
    ("pandas intro", "import pandas as pd\ndf = pd.DataFrame({'a':[1,2]})"),
]
# Open-ended engineering tasks, used verbatim as prompts by make_advanced_task_sample().
TASKS_ADVANCED = [
    "Write a CLI using argparse that reads a CSV and prints the average of a column.",
    "Design a file watcher that logs changes in a directory (cross-platform).",
    "Write a simple REST API using FastAPI with one POST /sum endpoint and tests.",
    "Implement a thread-safe counter class with locks and unit tests.",
    "Write an async HTTP fetcher with rate limiting and retries (aiohttp).",
    "Implement a minimal LRU cache with dict + doubly-linked list and tests.",
]

# ---- Prompt prefixes ----
# The sample builders draw one of these at random and append their payload
# (a function name or a code snippet) after the trailing newline.
WRITE_PREFIX = [
    "Write a Python function with type hints and docstring:\n",
    "Implement production-quality Python with clear naming and tests:\n",
]
DEBUG_PREFIX = ["Find and fix the bug in this Python code. Provide corrected code and a brief explanation:\n"]
EXPLAIN_PREFIX = [
    "Explain what this code does and possible pitfalls:\n",
    "Explain the concept and show a short example:\n",
]
REFACTOR_PREFIX = ["Refactor this snippet to be more Pythonic and readable:\n"]
TEST_PREFIX = ["Write pytest unit tests for the following function:\n"]
DOC_PREFIX = ["Write a docstring and usage example for the following function:\n"]
# ---------------------- Sample Builders ----------------------
def make_write_sample():
    """Build a "write this function" sample from a random FUNC_SNIPPETS entry.

    Returns:
        dict with "prompt" (a WRITE_PREFIX instruction followed by "<name>()")
        and "completion" (the snippet's source), both whitespace-stripped.
    """
    func_name = random.choice(list(FUNC_SNIPPETS))
    instruction = random.choice(WRITE_PREFIX)
    return {
        "prompt": f"{instruction}{func_name}()\n".strip(),
        "completion": FUNC_SNIPPETS[func_name].strip(),
    }
def make_debug_sample():
    """Build a "find and fix the bug" sample from a random BUGGY_SNIPPETS entry.

    The prompt shows the buggy code inside a ```python fence; the completion gives
    the corrected code in its own fence followed by a one-line reason.

    Returns:
        dict with whitespace-stripped "prompt" and "completion" strings.
    """
    _title, buggy_src, fixed_src, reason = random.choice(BUGGY_SNIPPETS)
    prompt = random.choice(DEBUG_PREFIX) + "```python\n" + buggy_src.rstrip("\n") + "\n```\n"
    # Real markdown fences: a bare "python" line (what the fence degrades to when the
    # backticks are lost) would otherwise end up in the training text as noise.
    completion = f"Fixed version:\n```python\n{fixed_src}\n```\n\nReason: {reason}"
    return {"prompt": prompt.strip(), "completion": completion.strip()}
def make_explain_sample():
    """Build an "explain" sample from a random EXPLAIN_TOPICS entry.

    With probability 0.5 the prompt shows the example code (fenced) and asks for an
    explanation; otherwise the prompt names the topic and the completion shows the
    fenced example.

    Returns:
        dict with whitespace-stripped "prompt" and "completion" strings.
    """
    topic, example = random.choice(EXPLAIN_TOPICS)
    if random.random() < 0.5:
        prompt = random.choice(EXPLAIN_PREFIX) + f"```python\n{example}\n```\n"
        # NOTE(review): this completion does not actually explain the code or its
        # pitfalls; consider adding per-topic explanations to EXPLAIN_TOPICS.
        completion = f"{topic.title()} explained with example above."
    else:
        prompt = f"Explain {topic} with a short example."
        completion = f"Example:\n```python\n{example}\n```\n"
    return {"prompt": prompt.strip(), "completion": completion.strip()}
def make_refactor_sample():
    """Build a refactoring sample: index-based loop -> direct iteration.

    Both prompt and completion carry their code in ```python fences.

    Returns:
        dict with whitespace-stripped "prompt" and "completion" strings.
    """
    src = "for i in range(len(items)):\n    print(items[i])"
    prompt = random.choice(REFACTOR_PREFIX) + "```python\n" + src + "\n```\n"
    completion = "Use direct iteration:\n```python\nfor item in items:\n    print(item)\n```\n"
    return {"prompt": prompt.strip(), "completion": completion.strip()}
# Known-correct assertions for each FUNC_SNIPPETS entry, so generated tests exercise
# real behavior instead of calling every function with the same (often invalid) argument.
_TEST_ASSERTIONS = {
    "sum_list": ("sum_list([]) == 0", "sum_list([1, 2, 3]) == 6"),
    "is_prime": ("not is_prime(1)", "is_prime(2)", "not is_prime(9)", "is_prime(97)"),
    "fizzbuzz": (
        "fizzbuzz(0) == []",
        "fizzbuzz(5) == ['1', '2', 'Fizz', '4', 'Buzz']",
        "fizzbuzz(15)[-1] == 'FizzBuzz'",
    ),
    "factorial": ("factorial(0) == 1", "factorial(5) == 120"),
    "fib": ("fib(0) == 0", "fib(1) == 1", "fib(10) == 55"),
}
# Calls that must raise, keyed by function name: (call expression, exception name).
_TEST_RAISES = {
    "factorial": ("factorial(-1)", "ValueError"),
}


def make_tests_sample():
    """Build a "write pytest tests" sample for a random FUNC_SNIPPETS entry.

    The prompt shows the function in a ```python fence. The completion is a
    self-contained pytest module: the function source, a test of known values, and
    (where _TEST_RAISES has an entry) a pytest.raises test. Functions without an
    entry in _TEST_ASSERTIONS fall back to a minimal callable() check.

    Returns:
        dict with whitespace-stripped "prompt" and "completion" strings.
    """
    # choose a base function to test
    name = random.choice(list(FUNC_SNIPPETS.keys()))
    code = FUNC_SNIPPETS[name]
    prompt = random.choice(TEST_PREFIX) + "```python\n" + code + "\n```\n"

    # Assemble line by line: dedenting a template around an interpolated multi-line
    # snippet cannot produce consistent indentation.
    raises = _TEST_RAISES.get(name)
    lines = ["import pytest", "", ""] if raises else []
    lines += [code, "", "", f"def test_{name}():"]
    lines += [f"    assert {expr}" for expr in _TEST_ASSERTIONS.get(name, (f"callable({name})",))]
    if raises:
        call, exc_name = raises
        lines += [
            "",
            "",
            f"def test_{name}_raises():",
            f"    with pytest.raises({exc_name}):",
            f"        {call}",
        ]
    tests = "\n".join(lines)
    completion = "```python\n" + tests + "\n```\n"
    return {"prompt": prompt.strip(), "completion": completion.strip()}
# Doctest-style usage example for each FUNC_SNIPPETS entry: (call, expected repr).
_DOC_EXAMPLES = {
    "sum_list": ("sum_list([1, 2, 3])", "6"),
    "is_prime": ("is_prime(97)", "True"),
    "fizzbuzz": ("fizzbuzz(5)", "['1', '2', 'Fizz', '4', 'Buzz']"),
    "factorial": ("factorial(5)", "120"),
    "fib": ("fib(10)", "55"),
}


def make_doc_sample():
    """Build an "add a docstring and usage example" sample for a random FUNC_SNIPPETS entry.

    The prompt shows the function with its docstring removed. The completion shows
    the documented function and, when _DOC_EXAMPLES has an entry, a doctest-style
    usage example. All code is in ```python fences.

    Returns:
        dict with whitespace-stripped "prompt" and "completion" strings.
    """
    name = random.choice(list(FUNC_SNIPPETS.keys()))
    code = FUNC_SNIPPETS[name]
    # Remove the docstring, then ask for one. This drops every line containing
    # triple quotes, so it assumes single-line docstrings (as the snippets use).
    no_doc = "\n".join(ln for ln in code.splitlines() if '"""' not in ln)
    prompt = random.choice(DOC_PREFIX) + "```python\n" + no_doc + "\n```\n"
    completion = "Add docstring and example usage:\n```python\n" + code + "\n```\n"
    example = _DOC_EXAMPLES.get(name)
    if example is not None:
        call, result = example
        completion += f"\nExample usage:\n```python\n>>> {call}\n{result}\n```\n"
    return {"prompt": prompt.strip(), "completion": completion.strip()}
def make_advanced_task_sample():
    """Build a sample whose prompt is a random entry of TASKS_ADVANCED.

    NOTE(review): the completion is a fixed instruction-like placeholder rather than
    an actual solution to the task.
    """
    placeholder = "Provide code outline, then full implementation with comments and minimal tests."
    chosen_task = random.choice(TASKS_ADVANCED)
    return {"prompt": chosen_task.strip(), "completion": placeholder.strip()}
# (builder, probability) pairs; the probabilities sum to 1.0.
MAKERS = [
    (make_write_sample, 0.25),
    (make_debug_sample, 0.20),
    (make_explain_sample, 0.20),
    (make_refactor_sample, 0.10),
    (make_tests_sample, 0.10),
    (make_doc_sample, 0.10),
    (make_advanced_task_sample, 0.05),
]


def _pick_maker():
    """Return a sample builder from MAKERS, chosen with probability equal to its weight.

    Draws one uniform number and returns the first builder whose cumulative weight
    reaches it. If float rounding leaves the final cumulative sum slightly below the
    draw, the last builder is returned.
    """
    draw = random.random()
    running_total = 0.0
    upper_bounds = []
    for _, weight in MAKERS:
        running_total += weight
        upper_bounds.append(running_total)
    return next(
        (maker for (maker, _), bound in zip(MAKERS, upper_bounds) if draw <= bound),
        MAKERS[-1][0],
    )
# ---------------------- Write Shards ----------------------
def write_shard(path, n_rows, make_sample=None):
    """Write ``n_rows`` samples as gzip-compressed JSON Lines to ``path``.

    Args:
        path: Destination file path (str or path-like).
        n_rows: Number of lines (samples) to write.
        make_sample: Optional zero-argument callable returning a JSON-serializable
            sample. Defaults to drawing a random builder via ``_pick_maker()``.

    The gzip header's mtime is pinned to 0 and newlines are always written as
    ``"\\n"``. Identical samples therefore yield byte-identical shards, and the
    SHA-256 digests recorded in the manifest are reproducible across runs and platforms.
    """
    build = make_sample if make_sample is not None else (lambda: _pick_maker()())
    with gzip.GzipFile(path, mode="wb", mtime=0) as gz:
        with io.TextIOWrapper(gz, encoding="utf-8", newline="\n") as f:
            for _ in range(n_rows):
                f.write(json.dumps(build(), ensure_ascii=False) + "\n")
def sha256_file(path):
    """Return the hex SHA-256 digest of the file at ``path``.

    The file is streamed in 1 MiB chunks, so large shards are never loaded whole.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for block in iter(lambda: handle.read(1 << 20), b""):
            digest.update(block)
    return digest.hexdigest()
def main():
    """Command-line entry point: write gzip JSONL dataset shards plus a manifest.

    Shards are named ``{prefix}_{i:04d}.jsonl.gz`` and written into ``--out_dir``.
    ``{prefix}_manifest.json`` records the creation time (UTC), sizes, seed, and
    each shard's path, row count, and SHA-256. Invalid sizes exit via argparse
    with an error.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--total", type=int, default=1_000_000, help="Total samples to generate")
    ap.add_argument("--shard_size", type=int, default=10_000, help="Rows per shard")
    ap.add_argument("--out_dir", type=str, default="python_dataset_v1", help="Output directory")
    ap.add_argument("--prefix", type=str, default="python", help="Prefix for files")
    ap.add_argument("--seed", type=int, default=1337, help="Random seed for sample generation")
    args = ap.parse_args()
    # A zero shard size would divide by zero; negative sizes give a meaningless shard count.
    if args.shard_size < 1:
        ap.error("--shard_size must be >= 1")
    if args.total < 0:
        ap.error("--total must be >= 0")

    # Re-seed so --seed controls the run; the default matches the module-level seed.
    random.seed(args.seed)
    os.makedirs(args.out_dir, exist_ok=True)
    # Integer ceil-division: exact for any total, unlike float division.
    n_shards = -(-args.total // args.shard_size)
    manifest = {
        # Aware UTC time (utcnow() is deprecated), rendered as the same "...Z" string.
        "created": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        "total": args.total,
        "shard_size": args.shard_size,
        "num_shards": n_shards,
        "seed": args.seed,
        "files": [],
    }
    print(f"Generating {args.total:,} samples in {n_shards} shards of {args.shard_size}...")
    for i in range(n_shards):
        # Every shard is full except possibly the last, which takes the remainder.
        rows = min(args.shard_size, args.total - i * args.shard_size)
        shard_path = os.path.join(args.out_dir, f"{args.prefix}_{i:04d}.jsonl.gz")
        t0 = time.perf_counter()
        write_shard(shard_path, rows)
        digest = sha256_file(shard_path)
        manifest["files"].append({"path": shard_path, "rows": rows, "sha256": digest})
        dt = time.perf_counter() - t0
        print(f"[{i+1}/{n_shards}] wrote {rows} rows → {os.path.basename(shard_path)} in {dt:.1f}s")
    man_path = os.path.join(args.out_dir, f"{args.prefix}_manifest.json")
    with open(man_path, "w", encoding="utf-8") as mf:
        json.dump(manifest, mf, indent=2)
    print("Done. Manifest:", man_path)
    # show one example (freshly generated; not read back from a shard)
    ex = _pick_maker()()
    print("Sample row:", json.dumps(ex, ensure_ascii=False))


if __name__ == "__main__":
    main()