File size: 2,188 Bytes
4255a26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import json
import logging
from collections import Counter
from pathlib import Path

import numpy as np
import regex
from tqdm import tqdm

logger = logging.getLogger(__name__)


def create_vocab(texts: list[str], vocab_size: int = 56_000) -> list[str]:
    """
    Create a vocabulary from a list of texts.

    Each text is lowercased and split into tokens that are either a run of word
    characters or a run of characters that are neither word characters nor
    whitespace. The most frequent tokens form the vocabulary.

    :param texts: The list of texts to create the vocabulary from.
    :param vocab_size: The size of the vocabulary. Defaults to 56,000, which is the vocab_size used for our 32M models.
    :return: The vocabulary: at most `vocab_size` tokens, ordered from most to least frequent.
    """
    tokenizer_regex = regex.compile(r"\w+|[^\w\s]+")

    # Count text by text instead of first gathering every token into one big
    # list: the counts (and the first-seen order used to break ties) are the
    # same, but peak memory now follows the number of distinct tokens rather
    # than the total number of tokens in the corpus.
    token_counts: Counter[str] = Counter()
    for text in tqdm(texts, desc="Tokenizing texts"):
        token_counts.update(tokenizer_regex.findall(text.lower()))

    # Get the most common tokens as the vocabulary
    return [word for word, _ in token_counts.most_common(vocab_size)]


def collect_means_and_texts(paths: list[Path]) -> tuple[list[str], np.ndarray]:
    """
    Collect means and texts from a list of paths.

    Every path whose name ends in ".json" is read as a JSON file of items, and the
    vectors are loaded from the ".npy" file with the same base name in the same
    directory. Rows of the vectors that contain NaN are dropped together with the
    items at the same positions. Files that cannot be loaded, or whose items and
    vectors do not line up one-to-one, are logged and skipped.

    :param paths: Candidate paths; entries whose name does not end in ".json" are ignored.
    :return: The kept items and the kept vectors concatenated along the first axis.
        The array is an empty 1-D array when nothing was collected.
    """
    txts: list[str] = []
    vectors_list: list[np.ndarray] = []
    for items_path in tqdm(paths, desc="Collecting means and texts"):
        if not items_path.name.endswith(".json"):
            continue
        stem = items_path.stem
        base_path = items_path.with_name(stem)
        vectors_path = items_path.with_name(stem.replace(".json", "") + ".npy")
        try:
            # Read JSON as UTF-8 explicitly rather than relying on the platform's
            # default text encoding. A decode failure is a ValueError subclass,
            # so it is handled by the clause below.
            with open(items_path, encoding="utf-8") as f:
                items = json.load(f)
            vectors = np.load(vectors_path, allow_pickle=False)
        except (KeyError, FileNotFoundError, ValueError) as e:
            logger.info(f"Error loading data from {base_path}: {e}")
            continue

        # The NaN mask below is computed per row and applied to both containers,
        # so they must be a list and a 2-D array of the same length. Skip the
        # pair (like any other unloadable file) instead of failing mid-run.
        if not isinstance(items, list) or vectors.ndim != 2 or len(items) != len(vectors):
            logger.warning(
                f"Skipping {base_path}: expected a list of items matching the rows of a 2-D array, "
                f"got {type(items).__name__} and vectors of shape {vectors.shape}"
            )
            continue

        # Filter out any NaN vectors before appending
        keep_mask = ~np.isnan(vectors).any(axis=1)
        # Filter the items as a plain list: converting them to a NumPy string
        # array would pad every entry to the length of the longest text.
        txts.extend(item for item, keep in zip(items, keep_mask.tolist()) if keep)
        # Boolean indexing already yields a new array, so no extra copy is made.
        vectors_list.append(vectors[keep_mask])

    all_vectors = np.concatenate(vectors_list, axis=0) if vectors_list else np.array([])
    return txts, all_vectors