"""Generate a sharded, gzip-compressed JSONL dataset of synthetic Python samples.

Each row is a ``{"prompt": ..., "completion": ...}`` object produced by one of
the ``make_*_sample`` builders, chosen at random according to ``MAKERS``.
``main`` writes the shards plus a JSON manifest listing every shard's path,
row count and SHA-256 digest.
"""
import argparse
import gzip
import hashlib
import json
import math
import os
import random
import textwrap
import time
from datetime import datetime, timezone

# Determinism but still varied
random.seed(1337)

# ---------------------- Content Pools ----------------------

# Small library of canonical functions to expand/perturb
FUNC_SNIPPETS = {
    "sum_list": textwrap.dedent("""\
        from typing import Iterable

        def sum_list(nums: Iterable[int]) -> int:
            \"\"\"Return the sum of integers in an iterable.\"\"\"
            return sum(nums)
        """).strip(),
    "is_prime": textwrap.dedent("""\
        from math import isqrt

        def is_prime(n: int) -> bool:
            \"\"\"Return True if n is prime (n>=2).\"\"\"
            if n < 2:
                return False
            if n % 2 == 0:
                return n == 2
            r = isqrt(n)
            f = 3
            while f <= r:
                if n % f == 0:
                    return False
                f += 2
            return True
        """).strip(),
    "fizzbuzz": textwrap.dedent("""\
        def fizzbuzz(n: int) -> list[str]:
            \"\"\"Generate FizzBuzz from 1..n.\"\"\"
            out = []
            for i in range(1, n + 1):
                s = ""
                if i % 3 == 0:
                    s += "Fizz"
                if i % 5 == 0:
                    s += "Buzz"
                out.append(s or str(i))
            return out
        """).strip(),
    "factorial": textwrap.dedent("""\
        def factorial(n: int) -> int:
            \"\"\"Return n! for n>=0.\"\"\"
            if n < 0:
                raise ValueError("n must be >= 0")
            res = 1
            for i in range(2, n + 1):
                res *= i
            return res
        """).strip(),
    "fib": textwrap.dedent("""\
        def fib(n: int) -> int:
            \"\"\"Return nth Fibonacci number (0-indexed).\"\"\"
            a, b = 0, 1
            for _ in range(n):
                a, b = b, a + b
            return a
        """).strip(),
}

# Buggy variants (to generate "debug/fix/explain" style samples).
# Each entry is (title, buggy code, fixed code, explanation).
BUGGY_SNIPPETS = [
    (
        "IndexError on short lists",
        "def last_two_sum(xs):\n return xs[-1] + xs[-2]\n",
        textwrap.dedent("""\
            def last_two_sum(xs):
                if len(xs) < 2:
                    raise ValueError("Need at least 2 items")
                return xs[-1] + xs[-2]
            """).strip(),
        "Accessing [-1] and [-2] without length check can raise IndexError.",
    ),
    (
        "Mutable default arg",
        "def add_name(name, xs=[]):\n xs.append(name)\n return xs\n",
        textwrap.dedent("""\
            def add_name(name, xs=None):
                if xs is None:
                    xs = []
                xs.append(name)
                return xs
            """).strip(),
        "Default list is shared across calls. Use None sentinel.",
    ),
    (
        "Shadowing builtins",
        "def max(nums):\n return 0\n",
        # The emitted fix goes through the ``builtins`` module: subscripting
        # ``__builtins__`` only works when it happens to be a dict.
        textwrap.dedent("""\
            import builtins
            from typing import Iterable

            def my_max(nums: Iterable[int]) -> int:
                return builtins.max(nums)
            """).strip(),
        "Function named 'max' shadows builtin, causing confusion. Rename or call builtins explicitly.",
    ),
    (
        "Off-by-one",
        "def range_inclusive(a,b):\n return list(range(a,b))\n",
        textwrap.dedent("""\
            def range_inclusive(a: int, b: int) -> list[int]:
                return list(range(a, b + 1))
            """).strip(),
        "range(a,b) excludes b; inclusive needs b+1.",
    ),
]

# (topic, example code) pairs used by make_explain_sample.
EXPLAIN_TOPICS = [
    ("list comprehensions", "squares = [x*x for x in range(5)]"),
    ("generators", "def gen():\n yield from range(3)"),
    ("context managers", "with open('file.txt','r',encoding='utf-8') as f:\n data = f.read()"),
    ("dataclasses", "from dataclasses import dataclass\n@dataclass\nclass User:\n name: str\n age: int"),
    ("typing generics", "from typing import TypeVar, Generic, Iterable\nT = TypeVar('T')\nclass Box(Generic[T]):\n def __init__(self, x: T): self.x = x"),
    ("asyncio basics", "import asyncio\nasync def main():\n await asyncio.sleep(1)"),
    ("pandas intro", "import pandas as pd\ndf = pd.DataFrame({'a':[1,2]})"),
]

TASKS_ADVANCED = [
    "Write a CLI using argparse that reads a CSV and prints the average of a column.",
    "Design a file watcher that logs changes in a directory (cross-platform).",
    "Write a simple REST API using FastAPI with one POST /sum endpoint and tests.",
    "Implement a thread-safe counter class with locks and unit tests.",
    "Write an async HTTP fetcher with rate limiting and retries (aiohttp).",
    "Implement a minimal LRU cache with dict + doubly-linked list and tests.",
]

# Prompt styles
# NOTE(review): the builders below put a bare "python" line in front of code.
# That looks like a Markdown code fence whose backticks were lost -- confirm
# the intended format before training on the output.
WRITE_PREFIX = [
    "Write a Python function with type hints and docstring:\n",
    "Implement production-quality Python with clear naming and tests:\n",
]
DEBUG_PREFIX = [
    "Find and fix the bug in this Python code. Provide corrected code and a brief explanation:\n",
]
EXPLAIN_PREFIX = [
    "Explain what this code does and possible pitfalls:\n",
    "Explain the concept and show a short example:\n",
]
REFACTOR_PREFIX = [
    "Refactor this snippet to be more Pythonic and readable:\n",
]
TEST_PREFIX = [
    "Write pytest unit tests for the following function:\n",
]
DOC_PREFIX = [
    "Write a docstring and usage example for the following function:\n",
]


# ---------------------- Sample Builders ----------------------
# Every builder returns {"prompt": str, "completion": str}, both stripped.

def make_write_sample() -> dict:
    """Ask for one of the canonical functions; the completion is its source."""
    name = random.choice(list(FUNC_SNIPPETS.keys()))
    code = FUNC_SNIPPETS[name]
    prompt = random.choice(WRITE_PREFIX) + f"{name}()\n"
    completion = code
    return {"prompt": prompt.strip(), "completion": completion.strip()}


def make_debug_sample() -> dict:
    """Show a buggy snippet; the completion is the fix plus the reason."""
    _title, bad, fixed, why = random.choice(BUGGY_SNIPPETS)
    prompt = random.choice(DEBUG_PREFIX) + bad
    completion = f"Fixed version:\npython\n{fixed}\n\nReason: {why}"
    return {"prompt": prompt.strip(), "completion": completion.strip()}


def make_explain_sample() -> dict:
    """Build an explanation sample in one of two shapes, chosen 50/50.

    Either the prompt carries the example code and the completion names the
    topic, or the prompt names the topic and the completion carries the code.
    """
    topic, example = random.choice(EXPLAIN_TOPICS)
    if random.random() < 0.5:
        prompt = random.choice(EXPLAIN_PREFIX) + f"python\n{example}\n"
        completion = f"{topic.title()} explained with example above."
    else:
        prompt = f"Explain {topic} with a short example."
        completion = f"Example:\npython\n{example}\n"
    return {"prompt": prompt.strip(), "completion": completion.strip()}


def make_refactor_sample() -> dict:
    """Return the single fixed index-loop -> direct-iteration refactor sample."""
    src = "for i in range(len(items)):\n print(items[i])"
    prompt = random.choice(REFACTOR_PREFIX) + "python\n" + src + "\n"
    completion = "Use direct iteration:\npython\nfor item in items:\n print(item)\n"
    return {"prompt": prompt.strip(), "completion": completion.strip()}


def make_tests_sample() -> dict:
    """Ask for pytest tests of a canonical function; completion is a runnable test module.

    The emitted module contains the function source followed by a smoke test
    that calls it with a valid argument for its signature.
    """
    # choose a base function to test
    name = random.choice(list(FUNC_SNIPPETS.keys()))
    code = FUNC_SNIPPETS[name]
    prompt = random.choice(TEST_PREFIX) + "python\n" + code + "\n"
    # Only sum_list takes an iterable; every other snippet takes an int, and
    # calling those with [] raises TypeError instead of exercising the code.
    arg = "[]" if name == "sum_list" else "0"
    # Assemble line by line: interpolating multi-line code into a dedented
    # template would defeat textwrap.dedent and leave the test mis-indented.
    tests = "\n".join([
        "import pytest",
        "",
        code,
        "",
        "",
        "def test_basic():",
        f"    assert isinstance({name}({arg}), (int, list, bool, type(None)))",
    ])
    completion = "python\n" + tests + "\n"
    return {"prompt": prompt.strip(), "completion": completion.strip()}


def make_doc_sample() -> dict:
    """Strip the docstring from a canonical function and ask for it back."""
    name = random.choice(list(FUNC_SNIPPETS.keys()))
    code = FUNC_SNIPPETS[name]
    # remove docstring then ask to add one (drops every line containing
    # a triple quote, which removes the snippets' one-line docstrings)
    no_doc = "\n".join([ln for ln in code.splitlines() if '\"\"\"' not in ln])
    prompt = random.choice(DOC_PREFIX) + "python\n" + no_doc + "\n"
    completion = "Add docstring and example usage:\npython\n" + code + "\n"
    return {"prompt": prompt.strip(), "completion": completion.strip()}


def make_advanced_task_sample() -> dict:
    """Pair a larger task description with a fixed placeholder completion."""
    task = random.choice(TASKS_ADVANCED)
    prompt = task
    completion = "Provide code outline, then full implementation with comments and minimal tests."
    return {"prompt": prompt.strip(), "completion": completion.strip()}


# (builder, probability) pairs; the probabilities sum to 1.0.
MAKERS = [
    (make_write_sample, 0.25),
    (make_debug_sample, 0.20),
    (make_explain_sample, 0.20),
    (make_refactor_sample, 0.10),
    (make_tests_sample, 0.10),
    (make_doc_sample, 0.10),
    (make_advanced_task_sample, 0.05),
]


# cumulative weights
def _pick_maker():
    """Return one builder from MAKERS, sampled according to its weight."""
    r = random.random()
    acc = 0.0
    for fn, w in MAKERS:
        acc += w
        if r <= acc:
            return fn
    # Reached only if float rounding leaves the running total just below r.
    return MAKERS[-1][0]


# ---------------------- Write Shards ----------------------

def write_shard(path, n_rows) -> None:
    """Write ``n_rows`` random samples to ``path`` as gzip-compressed JSON lines."""
    with gzip.open(path, "wt", encoding="utf-8") as f:
        for _ in range(n_rows):
            sample = _pick_maker()()
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")


def sha256_file(path) -> str:
    """Return the hex SHA-256 digest of the file at ``path``, read in 1 MiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as rf:
        for chunk in iter(lambda: rf.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()


def main(argv=None) -> None:
    """Generate all shards and the manifest.

    Args:
        argv: Command-line arguments without the program name. ``None``
            (the default) makes argparse read ``sys.argv``.

    Exits with argparse's usage error when ``--total`` is negative or
    ``--shard_size`` is below 1.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--total", type=int, default=1_000_000, help="Total samples to generate")
    ap.add_argument("--shard_size", type=int, default=10_000, help="Rows per shard")
    ap.add_argument("--out_dir", type=str, default="python_dataset_v1", help="Output directory")
    ap.add_argument("--prefix", type=str, default="python", help="Prefix for files")
    args = ap.parse_args(argv)

    # Reject sizes that would divide by zero or yield negative shard counts,
    # before anything is created on disk.
    if args.total < 0:
        ap.error("--total must be >= 0")
    if args.shard_size < 1:
        ap.error("--shard_size must be >= 1")

    os.makedirs(args.out_dir, exist_ok=True)
    n_shards = math.ceil(args.total / args.shard_size)
    manifest = {
        # Aware UTC clock, rendered in the same "<naive ISO timestamp>Z" form.
        "created": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z",
        "total": args.total,
        "shard_size": args.shard_size,
        "num_shards": n_shards,
        "files": [],
    }

    print(f"Generating {args.total:,} samples in {n_shards} shards of {args.shard_size}...")
    for i in range(n_shards):
        # Every shard is full except possibly the last, which takes the remainder.
        rows = args.shard_size if (i < n_shards - 1) else (args.total - args.shard_size * (n_shards - 1))
        shard_path = os.path.join(args.out_dir, f"{args.prefix}_{i:04d}.jsonl.gz")
        t0 = time.time()
        write_shard(shard_path, rows)
        digest = sha256_file(shard_path)
        manifest["files"].append({"path": shard_path, "rows": rows, "sha256": digest})
        dt = time.time() - t0
        print(f"[{i+1}/{n_shards}] wrote {rows} rows → {os.path.basename(shard_path)} in {dt:.1f}s")

    man_path = os.path.join(args.out_dir, f"{args.prefix}_manifest.json")
    with open(man_path, "w", encoding="utf-8") as mf:
        json.dump(manifest, mf, indent=2)
    print("Done. Manifest:", man_path)

    # show one example
    ex = _pick_maker()()
    print("Sample row:", json.dumps(ex, ensure_ascii=False))


if __name__ == "__main__":
    main()