Percy3822 committed on
Commit
304794d
·
verified ·
1 Parent(s): 9b4de55

Create make_python_dataset.py

Browse files
Files changed (1) hide show
  1. make_python_dataset.py +298 -0
make_python_dataset.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse, os, json, random, gzip, time, math, textwrap, hashlib
2
+ from datetime import datetime
3
+
4
# Determinism but still varied: seed the module-level RNG with a fixed value
# so every sample builder below draws from the same pseudo-random sequence.
random.seed(1337)
6
+
7
+ # ---------------------- Content Pools ----------------------
8
+
9
# Canonical reference implementations, keyed by function name. Each value is
# the dedented, stripped source text of one small function; the sample
# builders reuse these as prompts and completions.
FUNC_SNIPPETS = {
    func_name: textwrap.dedent(source).strip()
    for func_name, source in [
        (
            "sum_list",
            """\
            from typing import Iterable

            def sum_list(nums: Iterable[int]) -> int:
                \"\"\"Return the sum of integers in an iterable.\"\"\"
                return sum(nums)
            """,
        ),
        (
            "is_prime",
            """\
            from math import isqrt

            def is_prime(n: int) -> bool:
                \"\"\"Return True if n is prime (n>=2).\"\"\"
                if n < 2:
                    return False
                if n % 2 == 0:
                    return n == 2
                r = isqrt(n)
                f = 3
                while f <= r:
                    if n % f == 0:
                        return False
                    f += 2
                return True
            """,
        ),
        (
            "fizzbuzz",
            """\
            def fizzbuzz(n: int) -> list[str]:
                \"\"\"Generate FizzBuzz from 1..n.\"\"\"
                out = []
                for i in range(1, n + 1):
                    s = ""
                    if i % 3 == 0:
                        s += "Fizz"
                    if i % 5 == 0:
                        s += "Buzz"
                    out.append(s or str(i))
                return out
            """,
        ),
        (
            "factorial",
            """\
            def factorial(n: int) -> int:
                \"\"\"Return n! for n>=0.\"\"\"
                if n < 0:
                    raise ValueError("n must be >= 0")
                res = 1
                for i in range(2, n + 1):
                    res *= i
                return res
            """,
        ),
        (
            "fib",
            """\
            def fib(n: int) -> int:
                \"\"\"Return nth Fibonacci number (0-indexed).\"\"\"
                a, b = 0, 1
                for _ in range(n):
                    a, b = b, a + b
                return a
            """,
        ),
    ]
}
67
+
68
# Buggy variants (to generate "debug/fix/explain" style samples).
# Each entry is a 4-tuple: (title, buggy_code, fixed_code, explanation).
# Every fixed_code snippet must be runnable on its own, because it is emitted
# verbatim as the "correct" answer in the dataset.
BUGGY_SNIPPETS = [
    (
        "IndexError on short lists",
        "def last_two_sum(xs):\n    return xs[-1] + xs[-2]\n",
        textwrap.dedent("""\
            def last_two_sum(xs):
                if len(xs) < 2:
                    raise ValueError("Need at least 2 items")
                return xs[-1] + xs[-2]
            """).strip(),
        "Accessing [-1] and [-2] without length check can raise IndexError.",
    ),
    (
        "Mutable default arg",
        "def add_name(name, xs=[]):\n    xs.append(name)\n    return xs\n",
        textwrap.dedent("""\
            def add_name(name, xs=None):
                if xs is None:
                    xs = []
                xs.append(name)
                return xs
            """).strip(),
        "Default list is shared across calls. Use None sentinel.",
    ),
    (
        "Shadowing builtins",
        "def max(nums):\n    return 0\n",
        # The builtin is reached through the `builtins` module: the previous
        # `_builtins_['max']` referenced an undefined name (NameError), and
        # even `__builtins__[...]` is not portable because `__builtins__` is
        # a module, not a dict, inside `__main__`.
        textwrap.dedent("""\
            import builtins
            from typing import Iterable

            def my_max(nums: Iterable[int]) -> int:
                return builtins.max(nums)
            """).strip(),
        "Function named 'max' shadows builtin, causing confusion. Rename or call builtins explicitly.",
    ),
    (
        "Off-by-one",
        "def range_inclusive(a,b):\n    return list(range(a,b))\n",
        textwrap.dedent("""\
            def range_inclusive(a: int, b: int) -> list[int]:
                return list(range(a, b + 1))
            """).strip(),
        "range(a,b) excludes b; inclusive needs b+1.",
    ),
]
114
+
115
# (topic, example_code) pairs for the "explain" samples. The example code is
# emitted verbatim into prompts/completions, so each snippet must be valid
# Python with conventional four-space indentation.
EXPLAIN_TOPICS = [
    ("list comprehensions", "squares = [x*x for x in range(5)]"),
    ("generators", "def gen():\n    yield from range(3)"),
    (
        "context managers",
        "with open('file.txt','r',encoding='utf-8') as f:\n    data = f.read()",
    ),
    (
        "dataclasses",
        "from dataclasses import dataclass\n@dataclass\nclass User:\n    name: str\n    age: int",
    ),
    (
        "typing generics",
        # The constructor must be spelled `__init__`; with `_init_` the class
        # has no initializer and `Box(1)` raises TypeError.
        "from typing import TypeVar, Generic, Iterable\nT = TypeVar('T')\n"
        "class Box(Generic[T]):\n    def __init__(self, x: T): self.x = x",
    ),
    ("asyncio basics", "import asyncio\nasync def main():\n    await asyncio.sleep(1)"),
    ("pandas intro", "import pandas as pd\ndf = pd.DataFrame({'a':[1,2]})"),
]
124
+
125
# Open-ended, project-sized task descriptions. Each string is used verbatim
# as the prompt of an "advanced task" sample.
TASKS_ADVANCED = [
    # Command-line / filesystem tooling
    "Write a CLI using argparse that reads a CSV and prints the average of a column.",
    "Design a file watcher that logs changes in a directory (cross-platform).",
    # Web service
    "Write a simple REST API using FastAPI with one POST /sum endpoint and tests.",
    # Concurrency
    "Implement a thread-safe counter class with locks and unit tests.",
    "Write an async HTTP fetcher with rate limiting and retries (aiohttp).",
    # Data structures
    "Implement a minimal LRU cache with dict + doubly-linked list and tests.",
]
133
+
134
# Prompt styles: one pool of instruction prefixes per sample category.
# Every prefix ends with a newline so the payload (code or function name)
# can be appended directly after it.

# "write" samples: ask for an implementation.
WRITE_PREFIX = [
    "Write a Python function with type hints and docstring:\n",
    "Implement production-quality Python with clear naming and tests:\n",
]

# "debug" samples: ask for a bug fix plus explanation.
DEBUG_PREFIX = [
    "Find and fix the bug in this Python code. Provide corrected code and a brief explanation:\n",
]

# "explain" samples: ask for an explanation of code or of a concept.
EXPLAIN_PREFIX = [
    "Explain what this code does and possible pitfalls:\n",
    "Explain the concept and show a short example:\n",
]

# "refactor" samples: ask for a more idiomatic rewrite.
REFACTOR_PREFIX = [
    "Refactor this snippet to be more Pythonic and readable:\n",
]

# "tests" samples: ask for pytest tests.
TEST_PREFIX = [
    "Write pytest unit tests for the following function:\n",
]

# "doc" samples: ask for a docstring and usage example.
DOC_PREFIX = [
    "Write a docstring and usage example for the following function:\n",
]
155
+
156
+ # ---------------------- Sample Builders ----------------------
157
+
158
def make_write_sample():
    """Build a "write this function" sample.

    Picks one entry of FUNC_SNIPPETS and one WRITE_PREFIX at random (in that
    order). The prompt is the prefix followed by ``<name>()``; the completion
    is the reference source of that function.

    Returns:
        dict with stripped "prompt" and "completion" strings.
    """
    # Choosing from the item list draws the same index as choosing from the
    # key list, so the random sequence is unaffected.
    func_name, source = random.choice(list(FUNC_SNIPPETS.items()))
    instruction = random.choice(WRITE_PREFIX)
    return {
        "prompt": (instruction + f"{func_name}()\n").strip(),
        "completion": source.strip(),
    }
164
+
165
def make_debug_sample():
    """Build a "find and fix the bug" sample from a random BUGGY_SNIPPETS entry.

    The prompt is a DEBUG_PREFIX instruction followed by the buggy code. The
    completion contains the fixed code in a fenced ``python`` block followed
    by a one-line reason.

    Returns:
        dict with stripped "prompt" and "completion" strings.
    """
    # The title is only a human-readable label for the entry; it is not
    # emitted into the sample.
    _title, bad, fixed, why = random.choice(BUGGY_SNIPPETS)
    prompt = random.choice(DEBUG_PREFIX) + bad
    # Wrap the code in a real Markdown fence. Previously a bare "python" line
    # (a fence without its backticks) was emitted as stray text before the code.
    completion = f"Fixed version:\n```python\n{fixed}\n```\n\nReason: {why}"
    return {"prompt": prompt.strip(), "completion": completion.strip()}
170
+
171
def make_explain_sample():
    """Build an "explain" sample from a random EXPLAIN_TOPICS entry.

    With probability 0.5 the prompt shows the example code (fenced) after an
    EXPLAIN_PREFIX instruction and the completion is a one-line explanation
    stub; otherwise the prompt asks about the topic by name and the completion
    shows the example code (fenced).

    Returns:
        dict with stripped "prompt" and "completion" strings.
    """
    topic, example = random.choice(EXPLAIN_TOPICS)
    # Proper Markdown fence; a bare "python" line would otherwise appear as
    # stray text in front of the code.
    fenced = f"```python\n{example}\n```"
    if random.random() < 0.5:
        prompt = random.choice(EXPLAIN_PREFIX) + fenced
        completion = f"{topic.title()} explained with example above."
    else:
        prompt = f"Explain {topic} with a short example."
        completion = f"Example:\n{fenced}"
    return {"prompt": prompt.strip(), "completion": completion.strip()}
180
+
181
def make_refactor_sample():
    """Build a "refactor to idiomatic Python" sample.

    The snippet is fixed (an index-based loop over ``items``); only the
    instruction prefix is drawn at random from REFACTOR_PREFIX. Both the
    prompt and the completion carry their code inside fenced ``python`` blocks.

    Returns:
        dict with stripped "prompt" and "completion" strings.
    """
    src = "for i in range(len(items)):\n    print(items[i])"
    better = "for item in items:\n    print(item)"
    # Use real Markdown fences; a bare "python" line (fence without backticks)
    # would otherwise be emitted as stray text before the code.
    prompt = random.choice(REFACTOR_PREFIX) + f"```python\n{src}\n```"
    completion = f"Use direct iteration:\n```python\n{better}\n```"
    return {"prompt": prompt.strip(), "completion": completion.strip()}
186
+
187
def make_tests_sample():
    """Build a "write pytest tests" sample for a random FUNC_SNIPPETS entry.

    The prompt shows the function source in a fenced ``python`` block after a
    TEST_PREFIX instruction. The completion is a fenced, self-contained test
    module: the function source followed by one ``test_basic`` that calls the
    function with a valid argument and checks the exact result.

    Returns:
        dict with stripped "prompt" and "completion" strings.
    """
    # One representative (argument, expected result) pair per known snippet,
    # written as source text. The previous generic call ``name([])`` raised
    # TypeError for is_prime, factorial and fizzbuzz, so the generated test
    # could not pass for most functions.
    known_cases = {
        "sum_list": ("[1, 2, 3]", "6"),
        "is_prime": ("7", "True"),
        "fizzbuzz": ("3", '["1", "2", "Fizz"]'),
        "factorial": ("5", "120"),
        "fib": ("10", "55"),
    }

    # choose a base function to test
    name = random.choice(list(FUNC_SNIPPETS.keys()))
    code = FUNC_SNIPPETS[name]
    prompt = random.choice(TEST_PREFIX) + "```python\n" + code + "\n```"

    if name in known_cases:
        arg, expected = known_cases[name]
        check = f"assert {name}({arg}) == {expected}"
    else:
        # Snippet without a registered case: emit a check that cannot fail
        # on argument types.
        check = f"assert callable({name})"

    # Assemble line by line instead of running textwrap.dedent over an
    # f-string: the interpolated multi-line source starts at column 0 and
    # would defeat the dedent of the surrounding template.
    tests = "\n".join(
        [
            "import pytest",
            "",
            code,
            "",
            "def test_basic():",
            "    " + check,
        ]
    )
    completion = "```python\n" + tests + "\n```"
    return {"prompt": prompt.strip(), "completion": completion.strip()}
202
+
203
def make_doc_sample():
    """Build a "write a docstring" sample for a random FUNC_SNIPPETS entry.

    The prompt shows the function with its docstring removed; the completion
    shows the original, documented source. Both are wrapped in fenced
    ``python`` blocks.

    Returns:
        dict with stripped "prompt" and "completion" strings.
    """
    name = random.choice(list(FUNC_SNIPPETS.keys()))
    code = FUNC_SNIPPETS[name]
    # remove docstring then ask to add one. Dropping every line that contains
    # a triple quote only works for single-line docstrings, which is what all
    # current snippets use.
    no_doc = "\n".join(ln for ln in code.splitlines() if '"""' not in ln)
    # Use real Markdown fences; a bare "python" line (fence without backticks)
    # would otherwise be emitted as stray text before the code.
    prompt = random.choice(DOC_PREFIX) + "```python\n" + no_doc + "\n```"
    completion = "Add docstring and example usage:\n```python\n" + code + "\n```"
    return {"prompt": prompt.strip(), "completion": completion.strip()}
211
+
212
def make_advanced_task_sample():
    """Build an "advanced task" sample.

    The prompt is one randomly chosen TASKS_ADVANCED description; the
    completion is always the same fixed instruction text.

    Returns:
        dict with stripped "prompt" and "completion" strings.
    """
    chosen_task = random.choice(TASKS_ADVANCED)
    answer = "Provide code outline, then full implementation with comments and minimal tests."
    return {
        "prompt": chosen_task.strip(),
        "completion": answer.strip(),
    }
217
+
218
# Sample builders paired with their selection probability.
# The weights add up to 1.0 and are consumed in this order by _pick_maker.
MAKERS = [
    (make_write_sample, 0.25),          # implement a function
    (make_debug_sample, 0.20),          # find and fix a bug
    (make_explain_sample, 0.20),        # explain code or a concept
    (make_refactor_sample, 0.10),       # make a snippet more idiomatic
    (make_tests_sample, 0.10),          # write pytest tests
    (make_doc_sample, 0.10),            # add a docstring
    (make_advanced_task_sample, 0.05),  # open-ended project task
]
227
+
228
def _pick_maker():
    """Return one sample builder from MAKERS, chosen by cumulative weight.

    Draws a single uniform number in [0, 1) and walks MAKERS in order,
    returning the first builder whose running weight total reaches the draw.
    """
    draw = random.random()
    running_total = 0.0
    for maker, weight in MAKERS:
        running_total += weight
        if draw <= running_total:
            return maker
    # Reached only if the draw exceeds the summed weights (e.g. the float sum
    # lands slightly below 1.0); fall back to the last builder.
    return MAKERS[-1][0]
237
+
238
+ # ---------------------- Write Shards ----------------------
239
+
240
def write_shard(path, n_rows):
    """Write one gzip-compressed JSONL shard.

    Args:
        path: Destination file; opened in gzip text mode with UTF-8 encoding.
        n_rows: Number of samples (one JSON object per line) to generate.
    """
    with gzip.open(path, mode="wt", encoding="utf-8") as shard:
        for _ in range(n_rows):
            build = _pick_maker()
            # ensure_ascii=False keeps non-ASCII characters readable in the file.
            row = json.dumps(build(), ensure_ascii=False)
            shard.write(row + "\n")
245
+
246
def sha256_file(path):
    """Return the hexadecimal SHA-256 digest of the file at ``path``.

    The file is read in binary mode in 1 MiB chunks, so large shards are
    never held in memory as a whole.
    """
    digest = hashlib.sha256()
    chunk_size = 1 << 20
    with open(path, "rb") as stream:
        # iter() with a sentinel stops at the first empty read, i.e. at EOF.
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
256
+
257
def main():
    """CLI entry point: generate the dataset shards and a JSON manifest.

    Command-line options:
        --total       Total number of samples to generate (must be >= 0).
        --shard_size  Rows per shard (must be > 0).
        --out_dir     Output directory; created if missing.
        --prefix      File-name prefix for shards and manifest.

    Writes ``<prefix>_NNNN.jsonl.gz`` shards into the output directory, then
    ``<prefix>_manifest.json`` listing each shard's path, row count and
    SHA-256 digest, and finally prints one freshly generated example row.
    Invalid numeric options terminate via ``argparse`` with a usage error.
    """
    # Local import: only `datetime` itself is imported at file level.
    from datetime import timezone

    ap = argparse.ArgumentParser()
    ap.add_argument("--total", type=int, default=1_000_000, help="Total samples to generate")
    ap.add_argument("--shard_size", type=int, default=10_000, help="Rows per shard")
    ap.add_argument("--out_dir", type=str, default="python_dataset_v1", help="Output directory")
    ap.add_argument("--prefix", type=str, default="python", help="Prefix for files")
    args = ap.parse_args()

    # Reject values that would otherwise crash the shard arithmetic below
    # (division by zero) or yield nonsensical shard counts.
    if args.shard_size <= 0:
        ap.error("--shard_size must be a positive integer")
    if args.total < 0:
        ap.error("--total must be >= 0")

    os.makedirs(args.out_dir, exist_ok=True)
    n_shards = math.ceil(args.total / args.shard_size)

    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # timestamp and drop tzinfo so the text stays "<naive ISO>Z" as before.
    created = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    manifest = {
        "created": created,
        "total": args.total,
        "shard_size": args.shard_size,
        "num_shards": n_shards,
        "files": [],
    }

    print(f"Generating {args.total:,} samples in {n_shards} shards of {args.shard_size}...")

    for i in range(n_shards):
        # Every shard is full-sized except the last, which takes the remainder.
        if i < n_shards - 1:
            rows = args.shard_size
        else:
            rows = args.total - args.shard_size * (n_shards - 1)
        shard_path = os.path.join(args.out_dir, f"{args.prefix}_{i:04d}.jsonl.gz")
        t0 = time.time()
        write_shard(shard_path, rows)
        digest = sha256_file(shard_path)
        manifest["files"].append({"path": shard_path, "rows": rows, "sha256": digest})
        dt = time.time() - t0
        print(f"[{i+1}/{n_shards}] wrote {rows} rows → {os.path.basename(shard_path)} in {dt:.1f}s")

    man_path = os.path.join(args.out_dir, f"{args.prefix}_manifest.json")
    with open(man_path, "w", encoding="utf-8") as mf:
        json.dump(manifest, mf, indent=2)

    print("Done. Manifest:", man_path)
    # show one example
    ex = _pick_maker()()
    print("Sample row:", json.dumps(ex, ensure_ascii=False))
296
+
297
# Run the generator only when this file is executed as a script, so that
# importing the module does not start writing shards.
if __name__ == "__main__":
    main()