File size: 3,270 Bytes
4128ba5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94

from collections import Counter, defaultdict
import unicodedata

def get_stats(ids):
    """
    Tally weighted counts of adjacent token pairs.

    `ids` is an iterable of `(chunk, num)` 2-tuples, where `chunk` is a
    sequence of ints (a list or a bytes object both work) and `num` is the
    weight applied to every pair found inside that chunk.

    Pairs are only formed inside a chunk, never between the last element of
    one chunk and the first element of the next. Every adjacent position is
    counted, so overlapping repeats each contribute: [101, 101, 101] yields
    (101, 101) twice.

    Returns a defaultdict(int) mapping `(left, right)` to its weighted count.

    Example:
        get_stats([(b'abc', 2), (b'bcd', 1), (b'eee', 1)])
        -> defaultdict(<class 'int'>, {(97, 98): 2, (98, 99): 3, (99, 100): 1, (101, 101): 2})
    """
    pair_counts = defaultdict(int)
    for seq, weight in ids:
        # Pair every element with its right-hand neighbour; chunks shorter
        # than two elements contribute nothing.
        for left, right in zip(seq, seq[1:]):
            pair_counts[(left, right)] += weight
    return pair_counts

def merge_batch_get_stats(ids, pairs):
    """
    Apply a batch of merges to every chunk in place and, in the same
    left-to-right pass, tally the adjacent pairs of the merged chunks.

    `ids` is an iterable of `(chunk, num)` 2-tuples; each `chunk` must be a
    mutable sequence because merged elements are overwritten/deleted in place.
    `pairs` maps `(left, right)` to the replacement token and is queried with
    `.get`; a result of None means "no merge for this pair".

    Only not-yet-merged elements are used as the left side of a lookup, so a
    token produced by a merge is never merged again within the same call.

    Returns a defaultdict(int) mapping each adjacent pair of the resulting
    chunks to its count weighted by `num`.
    """
    pair_counts = defaultdict(int)
    for tokens, weight in ids:
        end = len(tokens) - 1  # index of the current last element
        pos = 0
        while pos < end:
            nxt = pos + 1
            replacement = pairs.get((tokens[pos], tokens[nxt]))
            if replacement is not None:
                # Collapse the two elements into one; the chunk shrinks by one.
                tokens[pos] = replacement
                del tokens[nxt]
                end -= 1
            if pos > 0:
                # tokens[pos] is now final, so the pair to its left can be counted.
                pair_counts[(tokens[pos - 1], tokens[pos])] += weight
            pos = nxt
        # If the last element was not absorbed by a merge, the loop stopped
        # right on it and its pair with the previous element is still uncounted.
        if 0 < pos == end:
            pair_counts[(tokens[-2], tokens[pos])] += weight
    return pair_counts

def merge(ids, pair, idx, len_ids=None):
    """
    Replace, in place, every consecutive occurrence of `pair` in the list of
    integers `ids` with the new integer token `idx`.

    The scan is left to right and non-overlapping: after a replacement the
    scan resumes at the element following the new token.

    Args:
        ids: mutable list of ints; modified in place.
        pair: 2-element sequence `(left, right)` to look for.
        idx: token written in place of each matched pair.
        len_ids: current length of `ids`, for callers that already track it.
            When omitted (None), `len(ids)` is used.

    Returns:
        The length of `ids` after all replacements.

    Example:
        ids = [1, 2, 3, 1, 2]
        merge(ids, (1, 2), 4)  ->  3, and ids is now [4, 3, 4]
    """
    if len_ids is None:
        len_ids = len(ids)
    if len_ids < 2:
        # Nothing to pair up; also avoids touching `pair` at all.
        return len_ids
    left, right = pair[0], pair[1]
    i = 0
    while i + 1 < len_ids:
        j = i + 1
        if ids[i] == left and ids[j] == right:
            ids[i] = idx
            del ids[j]
            len_ids -= 1
        i = j
    return len_ids

def replace_control_characters(s: str) -> str:
    """
    Return `s` with every character whose Unicode general category starts
    with "C" replaced by a literal backslash-u escape of its code point.
    """
    # we don't want to print control characters
    # which distort the output (e.g. \n or much worse)
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
    # http://www.unicode.org/reports/tr44/#GC_Values_Table
    return "".join(
        f"\\u{ord(ch):04x}" if unicodedata.category(ch)[0] == "C" else ch
        for ch in s
    )

def render_token(t: bytes) -> str:
    """
    Pretty-print a token: decode it as UTF-8 (undecodable bytes become the
    replacement character) and escape control characters.
    """
    return replace_control_characters(t.decode('utf-8', errors='replace'))

def _process_dicts(batch, compiled_pattern):   # for raw datasets.Dataset
    """
    Count the pattern matches found across a batch of strings.

    Args:
        batch: iterable of strings.
        compiled_pattern: an already-compiled pattern object exposing
            `findall(text)`.

    Returns:
        A Counter mapping each match returned by `findall` to the number of
        times it occurred in the whole batch.
    """
    # Call findall on the compiled pattern itself rather than through a
    # module-level `re` name: no such name is imported in this module, and the
    # method form works regardless of which regex engine compiled the pattern.
    find_all = compiled_pattern.findall
    counter = Counter()
    for item in batch:
        counter.update(find_all(item))
    return counter

def _process_string_scalar(batch, compiled_pattern):
    """
    Count the pattern matches found across a batch of scalar wrappers.

    Args:
        batch: iterable of objects exposing `as_py()`, whose result is the
            text to scan (presumably Arrow string scalars; verify against
            the caller).
        compiled_pattern: an already-compiled pattern object exposing
            `findall(text)`.

    Returns:
        A Counter mapping each match returned by `findall` to the number of
        times it occurred in the whole batch.
    """
    # Call findall on the compiled pattern itself rather than through a
    # module-level `re` name: no such name is imported in this module, and the
    # method form works regardless of which regex engine compiled the pattern.
    find_all = compiled_pattern.findall
    counter = Counter()
    for item in batch:
        counter.update(find_all(item.as_py()))
    return counter