#!/usr/bin/env python
import argparse
import json
import multiprocessing as mp
import os
import time
from functools import partial
from typing import Optional

from datasets import load_dataset
from tqdm import tqdm

# Hardware specifications (hardcoded)
HOST_CPU_COUNT = 24
HOST_MEMORY_GB = 117
HOST_GPU_VRAM_GB = 80
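
# Note: the constants above describe the machine this script was written for;
# a portable alternative would derive the CPU count at runtime via os.cpu_count().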


def simple_minify(code: str) -> str:
    """Fast minification - no regex, single pass."""
    if not code:
        return ""
    out_lines = []
    for line in code.splitlines():
        s = line.rstrip()
        # Drop blank lines and whole-line comments (lstrip also catches indented comments).
        if not s or s.lstrip().startswith("#"):
            continue
        out_lines.append(s)
    return "\n".join(out_lines)


def process_sample(sample: dict, max_chars: int, python_only: bool) -> Optional[dict]:
    """Process a single sample: filter, minify, return dict or None."""
    code = sample.get("content") or sample.get("code") or ""
    if not code:
        return None
    if python_only:
        lang = (sample.get("language") or "").lower()
        if lang and "python" not in lang:
            return None
    if len(code) > max_chars:
        return None
    min_code = simple_minify(code)
    if not min_code.strip():
        return None
    return {"code": min_code}


def writer_proc(out_path: str, q: mp.Queue, flush_every: int = 1000) -> None:
    """Writer process: buffers writes and flushes in chunks."""
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w", encoding="utf-8", buffering=1024 * 1024) as f:
        buf = []
        while True:
            item = q.get()
            if item is None:  # Sentinel to stop
                break
            buf.append(item)
            if len(buf) >= flush_every:
                f.write("".join(buf))
                buf = []
        if buf:
            f.write("".join(buf))


def main(out_dir: str, num_workers: int, shuffle_seed: Optional[int], max_chars: int = 200_000, max_files: Optional[int] = None) -> None:
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, "python_minified.jsonl")
    if os.path.exists(out_path):
        print(f"dataset already exists at {out_path}, skipping")
        return

    print("[INFO] Starting dataset preparation")
    print(f"[INFO] Host hardware: {HOST_CPU_COUNT} CPUs, {HOST_MEMORY_GB}GB RAM, {HOST_GPU_VRAM_GB}GB GPU VRAM")
    print(f"[INFO] Requested workers: {num_workers}")

    start_time = time.time()
    stream = load_dataset(
        "codeparrot/github-code",
        split="train",
        streaming=True,
        trust_remote_code=True,
    )
    print(f"[INFO] Dataset loaded in {time.time() - start_time:.2f}s")

    if shuffle_seed is not None:
        buffer_size = max(2048, num_workers * 1024)
        stream = stream.shuffle(buffer_size=buffer_size, seed=shuffle_seed)

    # Use hardcoded CPU count - limit workers to available CPUs
    effective_workers = min(num_workers, HOST_CPU_COUNT)
    print(f"[INFO] Using {effective_workers} worker processes (max: {HOST_CPU_COUNT})")
    print(f"[INFO] Max chars per file: {max_chars}")
    if max_files:
        print(f"[INFO] Max files to process: {max_files}")

    # Start writer process
    q = mp.Queue(maxsize=5000)
    writer_process = mp.Process(target=writer_proc, args=(out_path, q, 1000))
    writer_process.start()
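
    # The bounded queue (maxsize=5000) provides backpressure: if the writer
    # falls behind, q.put() in the main loop blocks, keeping memory use flat.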

    # Bind fixed filter arguments so pool workers only receive the sample
    process_fn = partial(
        process_sample,
        max_chars=max_chars,
        python_only=True,
    )

    kept = 0
    chunksize = 64  # Process 64 samples per chunk for better parallelism

    # Process in parallel using imap_unordered for better throughput
    with mp.Pool(effective_workers) as pool:
        with tqdm(desc="minifying", unit="files") as pbar:
            for result in pool.imap_unordered(process_fn, stream, chunksize=chunksize):
                if result is None:
                    continue
                # Serialize and queue for writer
                q.put(json.dumps(result, ensure_ascii=False) + "\n")
                kept += 1
                pbar.update(1)
                if max_files and kept >= max_files:
                    break

    # Shutdown
    q.put(None)  # Signal writer to stop
    writer_process.join()

    elapsed = time.time() - start_time
    rate = kept / elapsed if elapsed > 0 else 0
    print(f"[INFO] Completed! Wrote {kept} records to {out_path}")
    print(f"[INFO] Total time: {elapsed:.2f}s | Average rate: {rate:.1f} files/s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--num_workers", type=int, default=12)
    parser.add_argument("--shuffle_seed", type=int, default=None)
    parser.add_argument("--max_chars", type=int, default=200_000)
    parser.add_argument("--max_files", type=int, default=None)
    args = parser.parse_args()

    main(
        out_dir=args.out_dir,
        num_workers=args.num_workers,
        shuffle_seed=args.shuffle_seed,
        max_chars=args.max_chars,
        max_files=args.max_files,
    )
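
# Example invocation (output path is illustrative):
#   python prepare_dataset.py --out_dir ./data --num_workers 12 \
#       --shuffle_seed 42 --max_files 100000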