File size: 5,097 Bytes
4255a26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import argparse
import json
import logging
from collections.abc import Iterator
from pathlib import Path

import numpy as np
from datasets import load_dataset
from more_itertools import batched
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers.tokenization_utils import PreTrainedTokenizer

# Flush interval, measured in batches: `featurize` writes its accumulated texts
# and embeddings to disk whenever the batch index is a non-zero multiple of this.
_SAVE_EVERY = 32


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def featurize(
    dataset: Iterator[dict[str, str]],
    model: SentenceTransformer,
    output_dir: str,
    max_means: int,
    batch_size: int,
    text_key: str,
) -> None:
    """Embed texts from `dataset` and write them to `output_dir` in chunks.

    The dataset is consumed in batches of `batch_size`. For every text the
    model's token embeddings are averaged (excluding the first and last token)
    into a single vector. Accumulated results are written as paired files
    `feature_<batch index>.json` (the truncated texts) and
    `feature_<batch index>.npy` (the mean vectors) every `_SAVE_EVERY` batches,
    plus once more at the end for whatever is left over.

    If `output_dir` already contains `feature_<n>.json` files, every batch with
    an index up to the largest `<n>` is skipped so an interrupted run can be
    resumed.

    Args:
        dataset: Iterator over records; each record must contain `text_key`.
        model: The sentence transformer used to produce token embeddings.
        output_dir: Directory to write into; created if it does not exist.
        max_means: Processing stops at the first batch whose starting offset
            (`batch index * batch_size`) reaches this value.
        batch_size: Number of records encoded per model call.
        text_key: Key under which each record stores its text.

    Raises:
        ValueError: If the model reports no sentence embedding dimension, or a
            batch contains a value that is not a string.
        KeyError: If a record lacks `text_key`.
    """
    output_dir_path = Path(output_dir)
    output_dir_path.mkdir(parents=True, exist_ok=True)

    # Resume support: the numeric suffix of each existing `feature_<n>.json`
    # is the index of the last batch stored in that chunk. Files that do not
    # follow the naming scheme are ignored instead of crashing `int()`.
    finished = [
        int(suffix)
        for path in output_dir_path.glob("feature_*.json")
        if (suffix := path.stem.split("_")[1]).isdecimal()
    ]
    # `None` (rather than 0) means "nothing on disk", so an existing
    # `feature_0.json` is still recognised as completed work.
    largest_batch = max(finished, default=None)
    if largest_batch is not None:
        logger.info(f"Resuming from batch {largest_batch}, skipping previous batches.")

    texts: list[str] = []
    embeddings: list[np.ndarray] = []
    dim = model.get_sentence_embedding_dimension()
    if dim is None:
        msg = "Model has no sentence embedding dimension."
        raise ValueError(msg)

    tokenizer: PreTrainedTokenizer = model.tokenizer
    # Index of the last batch that was actually embedded. The trailing chunk is
    # named after it (not after the loop variable) so that a batch which was
    # only inspected by the `max_means` check is not treated as done on resume.
    last_done = 0
    for i, batch in tqdm(enumerate(batched(dataset, n=batch_size))):
        if i * batch_size >= max_means:
            logger.info(f"Reached maximum number of means: {max_means}")
            break
        if largest_batch is not None and i <= largest_batch:
            continue
        batch = [x[text_key] for x in batch]

        if not all(isinstance(x, str) for x in batch):
            msg = f"Detected non-string at batch: {i}"
            raise ValueError(msg)

        batch_embeddings = model.encode(batch, output_value="token_embeddings")  # type: ignore  # Annoying
        for text, embedding in zip(batch, batch_embeddings, strict=False):
            texts.append(_truncate_text(tokenizer, text))
            # Average over tokens, dropping the first and last position
            # (presumably the tokenizer's special tokens -- TODO confirm for
            # models that do not add both).
            embeddings.append(embedding[1:-1].mean(axis=0).cpu().numpy())
        last_done = i
        if i and i % _SAVE_EVERY == 0:
            _write_chunk(output_dir_path, i, texts, embeddings)
            texts = []
            embeddings = []
    if texts:
        _write_chunk(output_dir_path, last_done, texts, embeddings)


def _write_chunk(
    output_dir_path: Path,
    index: int,
    texts: list[str],
    embeddings: list[np.ndarray],
) -> None:
    """Write one chunk of texts and embeddings as `feature_<index>.{npy,json}`."""
    # The embeddings go first: the resume scan only looks at the JSON files, so
    # the JSON acts as the "chunk complete" marker and must be written last.
    np.save(output_dir_path / f"feature_{index}.npy", embeddings)
    with open(output_dir_path / f"feature_{index}.json", "w", encoding="utf-8") as handle:
        json.dump(texts, handle, indent=4)


def _truncate_text(tokenizer: PreTrainedTokenizer, text: str) -> str:
    """Return `text` after a truncating encode/decode round trip.

    The text is encoded with truncation enabled, capped at the tokenizer's
    `model_max_length`, and the resulting tokens are decoded back to a string
    with special tokens skipped.
    """
    limit = tokenizer.model_max_length
    token_ids = tokenizer.encode(text, truncation=True, max_length=limit)
    return tokenizer.decode(token_ids, skip_special_tokens=True)


def main() -> None:
    """Command-line entry point: featurize a dataset with a sentence transformer.

    Parses the command-line options, loads the requested model and dataset,
    and hands them to `featurize`. When `--output-dir` is not given, the
    output directory name is derived from the model name and dataset path
    (with `/` replaced by `_`) plus a `_featurized` suffix.
    """
    parser = argparse.ArgumentParser(description="Featurize texts using a sentence transformer.")
    parser.add_argument(
        "--model-name",
        type=str,
        default="baai/bge-base-en-v1.5",
        help="The model name for distillation (e.g., 'baai/bge-base-en-v1.5').",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Directory to save the featurized texts.",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        default="allenai/c4",
        help="The dataset path or name (e.g. 'allenai/c4').",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="en",
        help="The dataset configuration name (e.g., 'en' for C4).",
    )
    parser.add_argument(
        "--dataset-split",
        type=str,
        default="train",
        help="The dataset split (e.g., 'train', 'validation').",
    )
    # `store_false` defaults to True, so streaming is on unless the flag is
    # passed. The explicit `dest` makes the parsed value read as what it means
    # (`args.streaming`) rather than its negation.
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Disable streaming mode when loading the dataset.",
    )
    parser.add_argument(
        "--max-means",
        type=int,
        default=1000000,
        help="The maximum number of mean embeddings to generate.",
    )
    parser.add_argument(
        "--key",
        type=str,
        default="text",
        help="The key of the text field in the dataset to featurize (default: 'text').",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size to use for encoding the texts.",
    )

    args = parser.parse_args()

    # Without a configured handler, INFO records (the resume and limit
    # messages emitted by `featurize`) are dropped when run as a script.
    # `basicConfig` is a no-op if the root logger already has handlers.
    logging.basicConfig(level=logging.INFO)

    output_dir = args.output_dir
    if output_dir is None:
        model_name = args.model_name.replace("/", "_")
        dataset_path = args.dataset_path.replace("/", "_")
        output_dir = f"{model_name}_{dataset_path}_featurized"

    model = SentenceTransformer(args.model_name)
    dataset = load_dataset(
        args.dataset_path,
        name=args.dataset_name,
        split=args.dataset_split,
        streaming=args.streaming,
    )

    featurize(iter(dataset), model, output_dir, args.max_means, args.batch_size, args.key)


if __name__ == "__main__":
    main()