HusseinEid's picture
Super-squash branch 'main' using huggingface_hub
68a4c53
"""Parallel frequency counting over corpus shards.
The output is a Counter mapping token → integer count, where:
- Each whole token from `iter_tokens` contributes its raw count.
- For ASCII identifiers, sub-parts contribute fractional counts via
`boost_weight`, accumulated as floats and ceiled at the end so the boost
is actually nonzero (fix #4 from the build plan: the v2.1 draft used
`int(0.3) == 0`, silently disabling the boost).
"""
from __future__ import annotations
import math
import os
from collections import Counter, defaultdict
from collections.abc import Iterator
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from ._accel_loader import USE_RUST, accel
from .corpus import iter_shards, read_shard
from .patterns import is_identifier, iter_tokens, split_identifier
def count_in_text(
    text: str,
    counter: Counter[str],
    boost_acc: defaultdict[str, float],
    boost_weight: float,
    max_token_len: int,
) -> None:
    """Fold the token counts of one text into the caller's accumulators.

    Both accumulators are mutated in place; nothing is returned.

    Parameters
    ----------
    text
        Text to tokenize.
    counter
        Raw counts. Every token no longer than `max_token_len` adds 1.
    boost_acc
        Fractional boost totals. Every sub-part of a token accepted by
        `is_identifier` adds `boost_weight`, except a sub-part equal to the
        token itself.
    boost_weight
        Amount added per sub-part occurrence; values <= 0 disable boosting
        on the pure-Python path.
    max_token_len
        Tokens longer than this are skipped entirely (no count, no boost).
    """
    if not USE_RUST:
        # Pure-Python path.
        boost_enabled = boost_weight > 0
        for token, _a, _b in iter_tokens(text):
            if len(token) > max_token_len:
                continue
            counter[token] += 1
            if not (boost_enabled and is_identifier(token)):
                continue
            for piece in split_identifier(token):
                if piece == token:
                    # The whole token was already counted above.
                    continue
                boost_acc[piece] += boost_weight
        return

    # Accelerated path: the extension returns two delta mappings which are
    # merged into the caller's accumulators (raw counts first, then boosts).
    raw_delta, boost_delta = accel.count_in_text(text, boost_weight, max_token_len)
    for target, delta in ((counter, raw_delta), (boost_acc, boost_delta)):
        for key, amount in delta.items():
            target[key] += amount
def _count_shard(
    args: tuple[Path, float, int],
) -> tuple[Counter[str], dict[str, float]]:
    """Worker entry point: tally a single shard.

    Parameters
    ----------
    args
        A `(shard_path, boost_weight, max_token_len)` triple, packed into one
        argument so the function can be used directly with `map`.

    Returns
    -------
    tuple
        `(raw counts, boost totals)` for every record read from the shard.
    """
    path, weight, length_cap = args
    raw: Counter[str] = Counter()
    boosts: defaultdict[str, float] = defaultdict(float)
    for record in read_shard(path):
        count_in_text(record.text, raw, boosts, weight, length_cap)
    # Hand back a plain dict, matching the annotated return type.
    return raw, dict(boosts)
def count_frequencies(
    shards_dir: Path,
    boost_weight: float = 0.3,
    max_token_len: int = 50,
    workers: int = 0,
) -> Counter[str]:
    """Count token frequencies across all shards.

    Parameters
    ----------
    shards_dir
        Directory containing `shard_*.jsonl.gz` produced by `ingest_corpus`.
    boost_weight
        Fractional weight added per identifier sub-part occurrence.
        Sub-parts that never appear as standalone tokens still get nonzero
        counts and become eligible for PUA assignment.
    max_token_len
        Tokens longer than this are ignored entirely.
    workers
        0 → use os.cpu_count(); 1 → run sequentially in this process;
        N>1 → use a ProcessPoolExecutor with N workers. The effective
        worker count is capped at the number of shards, so a single-shard
        corpus is always counted in-process. Ignored on the Rust path.

    Returns
    -------
    Counter
        Token → integer count. Empty when `shards_dir` yields no shards.
    """
    shards = list(iter_shards(shards_dir))
    if not shards:
        return Counter()

    if USE_RUST:
        # Rust path: Rayon-parallel shard processing inside the extension.
        # `workers` is ignored — Rayon manages its own thread pool sized to
        # available CPUs. Returns a flat dict already merged with the boost.
        rust_dict = accel.count_frequencies(
            [str(s) for s in shards], boost_weight, max_token_len
        )
        return Counter(rust_dict)

    if workers == 0:
        workers = os.cpu_count() or 1
    # One shard is one unit of work, so processes beyond len(shards) would
    # only sit idle. Capping also avoids creating a pool at all for a
    # single shard. The resulting counts are unaffected.
    workers = min(workers, len(shards))

    args_list = [(s, boost_weight, max_token_len) for s in shards]
    total_counter: Counter[str] = Counter()
    boost_total: defaultdict[str, float] = defaultdict(float)

    if workers <= 1:
        results: Iterator[tuple[Counter[str], dict[str, float]]] = (
            _count_shard(a) for a in args_list
        )
    else:
        # imap-style streaming via map; ordered for determinism.
        # `with` guarantees pool shutdown even on exception.
        def _run() -> Iterator[tuple[Counter[str], dict[str, float]]]:
            with ProcessPoolExecutor(max_workers=workers) as ex:
                yield from ex.map(_count_shard, args_list)

        results = _run()

    for c, b in results:
        total_counter.update(c)
        for k, v in b.items():
            boost_total[k] += v

    # Merge boost into the counter, ceiling fractions so any nonzero boost
    # produces at least 1 count. This guarantees the boost is observable.
    for tok, frac in boost_total.items():
        bonus = math.ceil(frac)
        if bonus > 0:
            total_counter[tok] += bonus
    return total_counter
# Public API of this module; `_count_shard` is a private worker helper and is
# deliberately left out.
__all__ = ["count_frequencies", "count_in_text"]