import hashlib
import itertools
import json
import multiprocessing as mp
import re
import sqlite3
from collections import Counter, defaultdict
from contextlib import closing
from pathlib import Path
from typing import Dict, Iterable, Iterator, List, Optional

try:
    from tqdm import tqdm
except ImportError:  # progress bars are cosmetic; keep the pipeline usable without tqdm

    def tqdm(iterable, *args, **kwargs):  # type: ignore[no-redef]
        """Fallback used when tqdm is not installed: return ``iterable`` unchanged."""
        return iterable


TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE)
CODE_PATTERN = re.compile(
    r"(\bdef\b|\bclass\b|\bimport\b|\breturn\b|=>|function\s+\w+|public\s+class|#include|```)",
    re.IGNORECASE,
)
MIN_CODE_SIGNAL_RE = re.compile(r"(\bdef\s+|\bclass\s+|\bimport\s+|=|\breturn\s+|\bfor\s+|\bif\s+)")
EXPLANATION_PATTERN = re.compile(
    r"\b(explain|because|algorithm|steps|approach|complexity|solution)\b", re.IGNORECASE
)
PROBLEM_PROMPT_RE = re.compile(
    r"\b(solve|given|find|compute|return|input|output|problem|algorithm|task|challenge)\b",
    re.IGNORECASE,
)
# Code detector used for "problem" records. It overlaps CODE_PATTERN but also
# accepts SQL ``SELECT`` and does not accept ``=>``.
RESPONSE_CODE_RE = re.compile(
    r"(\bdef\b|\bclass\b|\breturn\b|\bimport\b|```|function\s+\w+|public\s+class|#include|SELECT\s+)",
    re.IGNORECASE,
)
YAML_FRONTMATTER_RE = re.compile(r"^\s*---\s*\n.*?\n---\s*", re.DOTALL)
FENCE_LINE_RE = re.compile(r"^\s*```(?:[a-zA-Z0-9_+-]+)?\s*$")
CODEBLOCK_DIRECTIVE_RE = re.compile(r"^\s*code-block::\s*\w*\s*$", re.IGNORECASE)
# Lines that look like log output, shell prompts or traceback frames.
# Deliberately case-sensitive: with IGNORECASE, code lines such as
# ``error = None`` or ``debug = True`` were matched and deleted. A log-level word
# must also not be followed by an assignment/comparison, call, attribute access,
# subscript, comma or Rust-style ``!`` so that ``DEBUG = True`` and
# ``info!("...")`` survive, while ``INFO: ...`` / ``ERROR - ...`` are removed.
CLI_NOISE_LINE_RE = re.compile(
    r"^\s*(\[[^\]]+\]\s*)?(INFO|WARNING|ERROR|DEBUG|TRACE)\b(?!\s*(?:[-+*/%&|^@<>!:]?=|[(.\[,!]))|"
    r"^\s*(PS\s+[A-Za-z]:\\|[A-Za-z]:\\[^>]*>)|"
    r"^\s*(Traceback \(most recent call last\):|File \".*\", line \d+|Exception:|RuntimeError:|ValueError:|TypeError:)"
)

# Substrings (matched against the lower-cased name) that map a source or file
# name to a category. Shared by per-record categorisation and per-file grouping
# so the two always agree.
INSTRUCTION_SOURCE_KEYS = ("codealpaca", "evol", "ultrachat", "openhermes", "orca")
PROBLEM_SOURCE_KEYS = ("leetcode", "contest", "problem", "mbpp", "humaneval", "apps", "codeforces")


def _infer_category(name: str) -> str:
    """Map a source/file name to ``"instruction"``, ``"problem"`` or ``"structured"``."""
    low = name.lower()
    if any(key in low for key in INSTRUCTION_SOURCE_KEYS):
        return "instruction"
    if any(key in low for key in PROBLEM_SOURCE_KEYS):
        return "problem"
    return "structured"


def estimate_tokens(text: str) -> int:
    """Return a rough token count: runs of word characters plus single punctuation marks."""
    if not text:
        return 0
    return len(TOKEN_PATTERN.findall(text))


def normalize_text(text: str) -> str:
    """Return ``text`` as a cleaned string.

    ``None`` becomes ``""``; other values are passed through ``str``. Line endings
    are unified to ``\\n``. Control characters other than newline and tab are
    dropped, as are lone UTF-16 surrogates, which cannot be encoded as UTF-8 when
    the output file is written. Trailing whitespace is removed from every line
    and leading/trailing whitespace from the whole text.
    """
    if text is None:
        return ""
    text = str(text).replace("\r\n", "\n").replace("\r", "\n")
    text = "".join(
        ch
        for ch in text
        if ch == "\n" or ch == "\t" or (ord(ch) >= 32 and not 0xD800 <= ord(ch) <= 0xDFFF)
    )
    return "\n".join(line.rstrip() for line in text.split("\n")).strip()


def clean_response_text(response: str) -> str:
    """Strip markdown/YAML wrappers and CLI/log noise from a model response.

    Removes YAML front matter, fence lines, ``code-block::`` directives and
    lines matching ``CLI_NOISE_LINE_RE``. Any remaining inline triple-backtick
    markers are also deleted. Leading/trailing blank lines are trimmed and tabs
    are expanded to four spaces. Returns ``""`` when nothing is left.
    """
    text = normalize_text(response)
    if not text:
        return ""

    text = YAML_FRONTMATTER_RE.sub("", text)
    kept_lines = [
        line
        for line in text.split("\n")
        if not (
            FENCE_LINE_RE.match(line)
            or CODEBLOCK_DIRECTIVE_RE.match(line)
            or CLI_NOISE_LINE_RE.match(line)
        )
    ]
    text = "\n".join(kept_lines)
    # Fence markers that were not on a line of their own.
    text = text.replace("```python", "").replace("```py", "").replace("```", "")

    lines = [ln.rstrip() for ln in text.split("\n")]
    start, end = 0, len(lines)
    while start < end and not lines[start].strip():
        start += 1
    while end > start and not lines[end - 1].strip():
        end -= 1
    if start == end:
        return ""
    # Keep indentation consistent: convert tabs to 4 spaces.
    return "\n".join(ln.replace("\t", "    ") for ln in lines[start:end])


def _ascii_ratio(text: str) -> float:
    """Fraction of characters in ``text`` that are ASCII (1.0 for empty text)."""
    if not text:
        return 1.0
    ascii_count = sum(1 for c in text if ord(c) < 128)
    return ascii_count / len(text)


def _response_is_valid(response: str) -> bool:
    """True when ``response`` shows a code signal or an explanation keyword."""
    if not response:
        return False
    return bool(CODE_PATTERN.search(response) or EXPLANATION_PATTERN.search(response))


def _response_has_code(response: str) -> bool:
    """True when ``response`` matches ``RESPONSE_CODE_RE``."""
    return bool(RESPONSE_CODE_RE.search(response))


def clean_record(
    record: Dict[str, str],
    *,
    min_tokens: int = 10,
    max_tokens: int = 2048,
) -> Optional[Dict[str, str]]:
    """Normalize one raw record, or return ``None`` if it fails a quality filter.

    The category is taken from ``_category`` when present, otherwise inferred
    from ``_source`` via ``_infer_category``. "problem" records get extra checks
    on prompt length/wording and on the response containing code. The combined
    instruction+response token estimate must lie in ``[min_tokens, max_tokens]``.

    Returns a dict with ``instruction``, ``response``, ``_source``, ``_category``
    and ``_tokens`` (an int).
    """
    instruction = normalize_text(record.get("instruction", ""))
    response = clean_response_text(record.get("response", ""))
    source = normalize_text(record.get("_source", "unknown"))
    category = normalize_text(record.get("_category", "")) or _infer_category(source)

    if not instruction or not response:
        return None
    if len(response) < 40:
        return None
    if _ascii_ratio(instruction + response) < 0.85:
        return None
    if not MIN_CODE_SIGNAL_RE.search(response):
        return None
    if not _response_is_valid(response):
        return None

    if category == "problem":
        if len(instruction) <= 50 or not PROBLEM_PROMPT_RE.search(instruction):
            return None
        if not _response_has_code(response):
            return None
        # Problem solutions must include code, not explanation-only text.
        if EXPLANATION_PATTERN.search(response) and not CODE_PATTERN.search(response):
            return None

    total_tokens = estimate_tokens(instruction) + estimate_tokens(response)
    if total_tokens < min_tokens or total_tokens > max_tokens:
        return None

    return {
        "instruction": instruction,
        "response": response,
        "_source": source,
        "_category": category,
        "_tokens": total_tokens,
    }


def _iter_jsonl(path: Path) -> Iterator[Dict[str, str]]:
    """Yield each JSON object (dict) from a JSONL file, skipping anything else.

    Blank lines and lines that are not valid JSON are skipped, as are lines that
    are not valid UTF-8 and JSON values that are not objects. The file is read
    as bytes so ``json.loads`` can handle a UTF-8 BOM on the first line.
    """
    with path.open("rb") as f:
        for raw in f:
            raw = raw.strip()
            if not raw:
                continue
            try:
                obj = json.loads(raw)
            except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
                continue
            if isinstance(obj, dict):
                yield obj


def _clean_record_worker(payload: Dict[str, object]) -> Optional[Dict[str, str]]:
    """Pool entry point: unpack ``payload`` and call ``clean_record``."""
    record = payload["record"]
    min_tokens = int(payload["min_tokens"])
    max_tokens = int(payload["max_tokens"])
    return clean_record(record, min_tokens=min_tokens, max_tokens=max_tokens)


def _batched(iterable: Iterable[Dict[str, str]], size: int) -> Iterator[List[Dict[str, str]]]:
    """Yield consecutive lists of at most ``size`` items from ``iterable``."""
    it = iter(iterable)
    while True:
        chunk = list(itertools.islice(it, size))
        if not chunk:
            return
        yield chunk


def iter_cleaned_records(
    path: Path,
    *,
    min_tokens: int,
    max_tokens: int,
    num_workers: int = 1,
    batch_size: int = 2000,
) -> Iterator[Dict[str, str]]:
    """Yield the records of ``path`` that survive ``clean_record``, in file order.

    With ``num_workers > 1`` records are cleaned in a process pool, ``batch_size``
    at a time so the whole file is never queued in memory. Results are yielded
    in input order in both modes, so the output does not depend on
    ``num_workers`` or on scheduling.
    """
    if num_workers <= 1:
        for record in _iter_jsonl(path):
            cleaned = clean_record(record, min_tokens=min_tokens, max_tokens=max_tokens)
            if cleaned is not None:
                yield cleaned
        return

    pool = mp.Pool(processes=num_workers)
    try:
        for batch in _batched(_iter_jsonl(path), max(1, batch_size)):
            payloads = [
                {"record": r, "min_tokens": min_tokens, "max_tokens": max_tokens} for r in batch
            ]
            # Ordered imap: callers stop once a quota is met, so completion order
            # would otherwise decide which records end up in the dataset.
            for cleaned in pool.imap(_clean_record_worker, payloads, chunksize=64):
                if cleaned is not None:
                    yield cleaned
    finally:
        # terminate() rather than close(): if the consumer stopped early, pending
        # work is discarded instead of being waited on.
        pool.terminate()
        pool.join()


def _remove_sqlite_artifacts(sqlite_path: Path) -> None:
    """Delete ``sqlite_path`` and its ``-wal``/``-shm`` side files if they exist."""
    candidates = (
        sqlite_path,
        sqlite_path.with_name(sqlite_path.name + "-wal"),
        sqlite_path.with_name(sqlite_path.name + "-shm"),
    )
    for p in candidates:
        if p.exists():
            p.unlink()


def _open_dedupe_db(sqlite_path: Path) -> sqlite3.Connection:
    """Create a fresh (previous contents deleted) WAL-mode hash table at ``sqlite_path``."""
    sqlite_path = sqlite_path.resolve()
    sqlite_path.parent.mkdir(parents=True, exist_ok=True)
    _remove_sqlite_artifacts(sqlite_path)
    conn = sqlite3.connect(str(sqlite_path))
    conn.execute("PRAGMA journal_mode=WAL;")
    conn.execute("CREATE TABLE IF NOT EXISTS seen_hashes (h TEXT PRIMARY KEY)")
    return conn


def _is_duplicate(conn: sqlite3.Connection, instruction: str, response: str) -> bool:
    """Record the (instruction, response) pair; return True if it was already seen."""
    # JSON-encoding the pair is unambiguous; a plain "||" join made
    # ("a||", "b") and ("a", "||b") hash identically.
    key = json.dumps([instruction, response])
    digest = hashlib.sha256(key.encode("utf-8")).hexdigest()
    cur = conn.execute("INSERT OR IGNORE INTO seen_hashes(h) VALUES (?)", (digest,))
    return cur.rowcount == 0


def build_balanced_dataset(
    input_paths: List[Path],
    output_path: Path,
    *,
    target_size: int = 1_000_000,
    min_tokens: int = 10,
    max_tokens: int = 2048,
    category_weights: Optional[Dict[str, float]] = None,
    sqlite_path: Optional[Path] = None,
    num_workers: int = 1,
) -> Dict[str, object]:
    """Write a deduplicated, category-balanced JSONL dataset to ``output_path``.

    Phase 1 walks the input files grouped by inferred category and fills each
    category up to its quota (``target_size * weight``). Phase 2 tops up any
    remaining slots from all inputs, ignoring quotas. Exact-duplicate
    (instruction, response) pairs are filtered through a sqlite table at
    ``sqlite_path``, which is recreated on every call. It defaults to
    ``dedupe_hashes.sqlite`` next to the output.

    Returns summary statistics: totals, average token length, per-source and
    per-category counts, instruction-vs-raw percentages and the quotas used.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if sqlite_path is None:
        sqlite_path = output_path.parent / "dedupe_hashes.sqlite"

    weights = category_weights or {"instruction": 0.60, "structured": 0.30, "problem": 0.10}
    target_by_cat = {k: int(target_size * v) for k, v in weights.items()}
    # "problem" receives whatever "instruction" and "structured" leave over (this
    # also absorbs int() truncation); a missing weight counts as zero.
    target_by_cat["problem"] = (
        target_size - target_by_cat.get("instruction", 0) - target_by_cat.get("structured", 0)
    )

    grouped_paths: Dict[str, List[Path]] = defaultdict(list)
    for path in input_paths:
        if path.exists():
            grouped_paths[_infer_category(path.stem)].append(path)

    source_counter = Counter()
    category_counter = Counter()
    total_tokens = 0
    total_kept = 0

    conn = _open_dedupe_db(sqlite_path)

    def try_write(cleaned: Dict[str, str], out_f, enforce_category_target: bool) -> bool:
        """Write ``cleaned`` unless its quota is full (when enforced) or it is a duplicate."""
        nonlocal total_kept, total_tokens
        category = cleaned["_category"]
        # Quota check comes before the dedupe insert so a record rejected here is
        # not marked as seen and can still be written in phase 2.
        if enforce_category_target and category_counter[category] >= target_by_cat.get(category, 0):
            return False
        if _is_duplicate(conn, cleaned["instruction"], cleaned["response"]):
            return False
        category_counter[category] += 1
        source_counter[cleaned["_source"]] += 1
        total_tokens += int(cleaned["_tokens"])
        total_kept += 1
        out_f.write(
            json.dumps(
                {"instruction": cleaned["instruction"], "response": cleaned["response"]},
                ensure_ascii=False,
            )
            + "\n"
        )
        return True

    def stream(path: Path):
        """Cleaned records of ``path``, closed promptly (pool terminated) on early exit."""
        return closing(
            iter_cleaned_records(
                path, min_tokens=min_tokens, max_tokens=max_tokens, num_workers=num_workers
            )
        )

    try:
        with output_path.open("w", encoding="utf-8") as out_f:
            # Phase 1: fill each category from its own files up to its quota.
            for category in ("instruction", "structured", "problem"):
                quota = target_by_cat.get(category, 0)
                for path in grouped_paths.get(category, []):
                    # Skip remaining files once the quota is met instead of opening them.
                    if total_kept >= target_size or category_counter[category] >= quota:
                        break
                    with stream(path) as cleaned_iter:
                        for cleaned in tqdm(cleaned_iter, desc=f"balance1:{path.name}", unit="rows"):
                            if total_kept >= target_size or category_counter[category] >= quota:
                                break
                            try_write(cleaned, out_f, enforce_category_target=True)
                    conn.commit()

            # Phase 2: fill remaining slots from all categories while preserving dedupe.
            for path in input_paths:
                if total_kept >= target_size:
                    break
                if not path.exists():
                    continue
                with stream(path) as cleaned_iter:
                    for cleaned in tqdm(cleaned_iter, desc=f"balance2:{path.name}", unit="rows"):
                        if total_kept >= target_size:
                            break
                        try_write(cleaned, out_f, enforce_category_target=False)
                conn.commit()
    finally:
        conn.close()

    denom = max(total_kept, 1)
    avg_len = round(total_tokens / total_kept, 2) if total_kept else 0.0
    raw_converted = category_counter["structured"] + category_counter["problem"]
    ratio = {
        "instruction_pct": round(100.0 * category_counter["instruction"] / denom, 2),
        "raw_converted_pct": round(100.0 * raw_converted / denom, 2),
    }
    return {
        "total_samples": total_kept,
        "avg_length_tokens": avg_len,
        "source_breakdown": dict(source_counter),
        "category_breakdown": dict(category_counter),
        "instruction_vs_raw_ratio": ratio,
        "targets": target_by_cat,
    }