File size: 4,676 Bytes
68a4c53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""Parallel frequency counting over corpus shards.



The output is a Counter mapping token → integer count, where:

- Each whole token from `iter_tokens` contributes its raw count.

- For ASCII identifiers, sub-parts contribute fractional counts via

  `boost_weight`, accumulated as floats and ceiled at the end so the boost

  is actually nonzero (fix #4 from the build plan: the v2.1 draft used

  `int(0.3) == 0`, silently disabling the boost).

"""

from __future__ import annotations

import math
import os
from collections import Counter, defaultdict
from collections.abc import Iterator
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

from ._accel_loader import USE_RUST, accel
from .corpus import iter_shards, read_shard
from .patterns import is_identifier, iter_tokens, split_identifier


def count_in_text(
    text: str,
    counter: Counter[str],
    boost_acc: defaultdict[str, float],
    boost_weight: float,
    max_token_len: int,
) -> None:
    """Add the token counts found in `text` to the given accumulators.

    Both accumulators are mutated in place; nothing is returned.

    Parameters
    ----------
    text
        Text to tokenize and count.
    counter
        Raw counts; each kept token adds 1 to its entry.
    boost_acc
        Fractional boost totals; each identifier sub-part occurrence adds
        `boost_weight` to the sub-part's entry.
    boost_weight
        Weight added per sub-part occurrence. In the pure-Python path a
        value <= 0 disables sub-part boosting.
    max_token_len
        In the pure-Python path, tokens longer than this are skipped
        (neither counted nor split).
    """
    if USE_RUST:
        # The accelerator returns two delta mappings for this text; fold
        # them into the caller's accumulators.
        raw_delta, boost_delta = accel.count_in_text(text, boost_weight, max_token_len)
        for token, amount in raw_delta.items():
            counter[token] += amount
        for piece, amount in boost_delta.items():
            boost_acc[piece] += amount
        return

    # Loop-invariant: whether sub-part boosting is enabled at all.
    boosting = boost_weight > 0
    for token, _, _ in iter_tokens(text):
        if len(token) > max_token_len:
            continue
        counter[token] += 1
        if not (boosting and is_identifier(token)):
            continue
        for piece in split_identifier(token):
            # A piece equal to the whole token is already counted above.
            if piece == token:
                continue
            boost_acc[piece] += boost_weight


def _count_shard(
    args: tuple[Path, float, int],
) -> tuple[Counter[str], dict[str, float]]:
    """Worker entry point: count every record of a single shard.

    `args` is a `(shard_path, boost_weight, max_token_len)` tuple so the
    function can be mapped over a list of argument tuples. Returns the
    shard's raw counter and its boost accumulator converted to a plain dict.
    """
    path, weight, token_len_limit = args
    raw_counts: Counter[str] = Counter()
    sub_part_boost: defaultdict[str, float] = defaultdict(float)
    for record in read_shard(path):
        count_in_text(record.text, raw_counts, sub_part_boost, weight, token_len_limit)
    # Hand back a plain dict rather than the defaultdict.
    return raw_counts, dict(sub_part_boost)


def count_frequencies(
    shards_dir: Path,
    boost_weight: float = 0.3,
    max_token_len: int = 50,
    workers: int = 0,
) -> Counter[str]:
    """Count token frequencies across all shards.

    Parameters
    ----------
    shards_dir
        Directory containing `shard_*.jsonl.gz` produced by `ingest_corpus`.
    boost_weight
        Fractional weight added per identifier sub-part occurrence.
        Sub-parts that never appear as standalone tokens still get nonzero
        counts and become eligible for PUA assignment.
    max_token_len
        Tokens longer than this are ignored entirely.
    workers
        0 → use os.cpu_count(); 1 (or any other value <= 1, including
        negatives) → run sequentially in this process; N>1 → use a
        ProcessPoolExecutor with at most N workers. The effective worker
        count never exceeds the number of shards. Ignored when the Rust
        accelerator is active.

    Returns
    -------
    Counter[str]
        Raw token counts with the ceiled sub-part boost added in. Empty
        when `shards_dir` yields no shards.
    """
    shards = list(iter_shards(shards_dir))
    if not shards:
        return Counter()

    if USE_RUST:
        # Rust path: Rayon-parallel shard processing inside the extension.
        # `workers` is ignored — Rayon manages its own thread pool sized to
        # available CPUs. Returns a flat dict already merged with the boost.
        rust_dict = accel.count_frequencies(
            [str(s) for s in shards], boost_weight, max_token_len
        )
        return Counter(rust_dict)

    if workers == 0:
        workers = os.cpu_count() or 1
    # One task is submitted per shard, so a pool larger than the shard count
    # only adds idle processes; with a single shard the pool is skipped
    # entirely and the shard is counted in-process.
    workers = min(workers, len(shards))
    if os.name == "nt":
        # ProcessPoolExecutor raises ValueError for max_workers > 61 on Windows.
        workers = min(workers, 61)

    args_list = [(s, boost_weight, max_token_len) for s in shards]

    total_counter: Counter[str] = Counter()
    boost_total: defaultdict[str, float] = defaultdict(float)

    if workers <= 1:
        results: Iterator[tuple[Counter[str], dict[str, float]]] = (
            _count_shard(a) for a in args_list
        )
    else:
        # Streaming via `map`; results come back in shard order, so the
        # merge below is deterministic.
        # `with` guarantees pool shutdown even on exception.
        def _run() -> Iterator[tuple[Counter[str], dict[str, float]]]:
            with ProcessPoolExecutor(max_workers=workers) as ex:
                yield from ex.map(_count_shard, args_list)

        results = _run()

    for shard_counts, shard_boost in results:
        total_counter.update(shard_counts)
        for tok, frac in shard_boost.items():
            boost_total[tok] += frac

    # Merge boost into the counter, ceiling fractions so any nonzero boost
    # produces at least 1 count. This guarantees the boost is observable.
    # NOTE(review): the totals are built by repeated float addition, so a sum
    # that is mathematically a whole number may land a hair above it and be
    # ceiled up by one. Left as-is here — confirm against the Rust path
    # before changing the rounding.
    for tok, frac in boost_total.items():
        bonus = math.ceil(frac)
        if bonus > 0:
            total_counter[tok] += bonus

    return total_counter


# Names exported by `from <module> import *`; the underscore-prefixed
# `_count_shard` worker is not exported.
__all__ = ["count_frequencies", "count_in_text"]