File size: 8,270 Bytes
4d939fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
"""Embedding generation CLI for DETree."""

from __future__ import annotations

import argparse
from pathlib import Path
from typing import Iterable, Literal, Optional

import pandas as pd
import torch
import torch.nn.functional as F
from lightning import Fabric
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from detree.model.text_embedding import TextEmbeddingModel
from detree.utils.dataset import SCLDataset, load_datapath


def infer(passages_dataloader, fabric, tokenizer, model, args):
    """Encode every passage and collect per-layer embeddings on rank 0.

    Each batch is tokenized, run through ``model`` with hidden states enabled,
    reduced to the layers listed in ``args.need_layer`` and all-gathered across
    the Fabric processes. Only rank 0 accumulates and returns results.

    Args:
        passages_dataloader: iterable of ``(text, label, write_model, ids)``
            batches (e.g. from ``PassagesDataset``).
        fabric: the Lightning ``Fabric`` instance driving the run.
        tokenizer: tokenizer exposing ``batch_encode_plus``.
        model: encoder called as ``model(encoded_batch, hidden_states=True)``;
            its output is a rank-3 tensor indexed ``[sample, layer, feature]``
            (presumably ``(batch, num_layers, hidden)`` — confirm against the model).
        args: namespace providing ``need_layer`` and ``max_length``.

    Returns:
        On rank 0: ``(ids, embeddings, labels)``. ``ids`` and ``labels`` are 1-D
        tensors with one entry per gathered sample. ``embeddings`` maps each
        requested layer index to an L2-normalised bfloat16 CPU tensor with one
        row per sample. Note that ``labels`` holds the batch's third field
        (``write_model``), not its second. On every other rank: ``([], [], [])``.
    """
    # De-duplicate while preserving order: a repeated layer index would otherwise
    # append the same rows twice under one key and misalign it with ``ids``.
    layers = list(dict.fromkeys(args.need_layer))
    is_main = fabric.global_rank == 0
    if is_main:
        passages_dataloader = tqdm(passages_dataloader)
        all_ids, all_labels = [], []
        all_embeddings = {layer: [] for layer in layers}

    with torch.no_grad():
        for text, _label, write_model, ids in passages_dataloader:
            encoded_batch = tokenizer.batch_encode_plus(
                text,
                return_tensors="pt",
                max_length=args.max_length,
                padding="max_length",
                truncation=True,
            )
            # Put inputs on the device Fabric assigned to this process (where the model lives).
            encoded_batch = {k: v.to(fabric.device) for k, v in encoded_batch.items()}
            hidden = model(encoded_batch, hidden_states=True)
            # Keep only the requested layers *before* gathering, so less data is
            # communicated between processes and copied to host memory.
            hidden = hidden[:, layers, :]
            hidden = fabric.all_gather(hidden).view(-1, hidden.size(-2), hidden.size(-1))
            labels = fabric.all_gather(write_model).view(-1)
            ids = fabric.all_gather(ids).view(-1)
            if is_main:
                hidden = F.normalize(hidden, dim=-1).cpu().to(torch.bfloat16)
                for pos, layer in enumerate(layers):
                    # clone() detaches the slice from the per-batch buffer.
                    all_embeddings[layer].append(hidden[:, pos, :].clone())
                all_ids.extend(ids.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())
            del hidden, labels, ids

    if not is_main:
        return [], [], []

    for layer in layers:
        chunks = all_embeddings[layer]
        # torch.cat rejects an empty list; an empty dataloader yields zero rows instead.
        all_embeddings[layer] = (
            torch.cat(chunks, dim=0) if chunks else torch.empty((0, 0), dtype=torch.bfloat16)
        )
    ids_tensor = torch.tensor(all_ids) if all_ids else torch.empty(0, dtype=torch.long)
    labels_tensor = torch.tensor(all_labels) if all_labels else torch.empty(0, dtype=torch.long)
    return ids_tensor, all_embeddings, labels_tensor


def stable_long_hash(input_string: str) -> int:
    """Return a run-independent, non-negative 63-bit hash of ``input_string``.

    The built-in ``hash`` of a ``str`` is salted per interpreter run. This value
    is instead derived from SHA-256, so it is identical across runs and worker
    processes, and it always fits in a signed 64-bit integer.
    """
    import hashlib

    digest = hashlib.sha256(input_string.encode()).digest()
    # Interpret the digest as one big-endian integer, then keep only the low 63 bits.
    return int.from_bytes(digest, "big") & 0x7FFF_FFFF_FFFF_FFFF


def load_data(split: Literal["train", "test", "extra"], include_adversarial: bool, fp: Path) -> pd.DataFrame:
    """Read the CSV file for ``split`` from directory ``fp``.

    The file is ``<split>.csv`` when adversarial samples are included and
    ``<split>_none.csv`` otherwise.

    Raises:
        ValueError: if ``split`` is not one of ``"train"``, ``"test"``, ``"extra"``.
    """
    valid_splits = ("train", "test", "extra")
    if split not in valid_splits:
        raise ValueError("`split` must be one of (\"train\", \"test\", \"extra\")")

    suffix = "" if include_adversarial else "_none"
    return pd.read_csv(fp / f"{split}{suffix}.csv")


class PassagesDataset(Dataset):
    """Map-style dataset over record dicts with ``generation``, ``model`` and ``attack`` keys.

    Records whose ``attack`` is neither ``"none"`` nor ``"paraphrase"`` are
    down-sampled deterministically. A record is kept only when its text's
    ``stable_long_hash`` modulo 10 is at least 5, so the same subset is chosen
    on every run.
    """

    def __init__(self, data):
        # Materialise once: ``data`` is iterated twice below, and a one-shot
        # iterator would otherwise leave the class list silently empty.
        data = list(data)
        self.passages = [
            item
            for item in data
            if item["attack"] in ("none", "paraphrase") or stable_long_hash(item["generation"]) % 10 >= 5
        ]
        # Classes come from *all* records, not only the kept passages.
        self.classes = sorted({item["model"] for item in data})
        # Precomputed name -> index lookup used on every __getitem__ call.
        self._class_index = {name: idx for idx, name in enumerate(self.classes)}
        if "human" not in self._class_index:
            raise ValueError("PassagesDataset requires a 'human' class in the `model` column")
        self.human_id = self._class_index["human"]

    def __len__(self):
        return len(self.passages)

    def __getitem__(self, idx):
        """Return ``(text, is_human, model_index, text_hash)`` for passage ``idx``."""
        record = self.passages[idx]
        text = record["generation"]
        model = self._class_index[record["model"]]
        label = int(model == self.human_id)
        ids = stable_long_hash(text)
        return text, int(label), int(model), int(ids)


def build_argument_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the embedding-generation CLI.

    Options are registered in a fixed order: runtime knobs, input/model
    selection, pooling and layers, paired on/off switches, then output options.
    Defaults are shown in ``--help`` via ``ArgumentDefaultsHelpFormatter``.
    """
    parser = argparse.ArgumentParser(
        description="Generate embedding databases for DETree evaluators",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Runtime / batching knobs: all plain integers.
    for flag, default in (
        ("--device-num", 1),
        ("--batch-size", 64),
        ("--num-workers", 8),
        ("--max-length", 512),
    ):
        parser.add_argument(flag, type=int, default=default)

    # Input data and encoder selection.
    parser.add_argument("--path", type=Path, required=True, help="Dataset root directory or JSONL file path.")
    parser.add_argument("--database-name", type=str, default="M4_monolingual")
    model_help = (
        "Model identifier for embeddings generation. Accepts either a Hugging Face "
        "model hub name or a local path to a directory in Hugging Face format."
    )
    parser.add_argument("--model-name", type=str, default="FacebookAI/roberta-large", help=model_help)

    # Pooling strategy and which hidden layers to keep.
    parser.add_argument("--pooling", type=str, default="max", choices=("max", "average", "cls"))
    parser.add_argument("--need-layer", type=int, nargs="+", default=[16, 17, 18, 19, 22, 23])

    # Each on/off pair writes one destination; set_defaults runs after both
    # actions exist so it overrides the implicit store_true/store_false defaults.
    for on_flag, off_flag, dest, initial in (
        ("--adversarial", "--no-adversarial", "adversarial", True),
        ("--has-mix", "--no-has-mix", "has_mix", False),
    ):
        parser.add_argument(on_flag, dest=dest, action="store_true")
        parser.add_argument(off_flag, dest=dest, action="store_false")
        parser.set_defaults(**{dest: initial})

    # Output location and which split to encode.
    parser.add_argument("--savedir", type=Path, required=True, help="Output directory for the embedding database.")
    parser.add_argument("--name", type=str, required=True, help="Filename (without extension) for the saved embeddings.")
    parser.add_argument("--split", type=str, default="train", choices=("train", "test", "extra"))

    return parser


def _build_fabric(device_num: int) -> Fabric:
    """Create and launch a bf16-mixed CUDA Fabric, using the DDP strategy only for multi-device runs."""
    fabric_kwargs = {"accelerator": "cuda", "precision": "bf16-mixed", "devices": device_num}
    if device_num > 1:
        fabric_kwargs["strategy"] = "ddp"
    fabric = Fabric(**fabric_kwargs)
    fabric.launch()
    return fabric


def _build_dataloader(args: argparse.Namespace, fabric: Fabric, tokenizer):
    """Pick the dataset implementation from ``args.path``; return ``(dataset, fabric_dataloader)``."""
    path_str = str(args.path)
    if "LLM_detect_data" in path_str:
        records = load_data(args.split, include_adversarial=args.adversarial, fp=Path(args.path))
        dataset = PassagesDataset(records.to_dict(orient="records"))
        # NOTE(review): only this branch keeps Fabric's distributed sampler. With
        # DistributedSampler(drop_last=False), samples may be repeated to even out
        # shards, which would duplicate ids in the saved database — confirm intended.
        use_distributed_sampler = True
    elif path_str.endswith(".jsonl"):
        dataset = SCLDataset([path_str], fabric, tokenizer, need_ids=True, adv_p=0)
        use_distributed_sampler = False
    else:
        data_path = load_datapath(
            path_str,
            include_adversarial=args.adversarial,
            dataset_name=args.database_name,
        )[args.split]
        dataset = SCLDataset(data_path, fabric, tokenizer, need_ids=True, adv_p=0, has_mix=args.has_mix)
        use_distributed_sampler = False

    dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
    dataloader = fabric.setup_dataloaders(dataloader, use_distributed_sampler=use_distributed_sampler)
    return dataset, dataloader


def generate_embeddings(args: argparse.Namespace) -> None:
    """Encode the selected data split and save an embedding database.

    On rank 0, writes ``{"embeddings", "labels", "ids", "classes"}`` with
    ``torch.save`` to ``<args.savedir>/<args.name>.pt``, creating the directory
    if needed.

    Args:
        args: parsed CLI namespace (see ``build_argument_parser``). ``path`` and
            ``savedir`` may be ``str`` or ``Path``.
    """
    fabric = _build_fabric(args.device_num)

    model = TextEmbeddingModel(
        args.model_name,
        output_hidden_states=True,
        infer=True,
        use_pooling=args.pooling,
    ).cuda()
    tokenizer = model.tokenizer
    model.eval()

    # The tokenizer is needed by SCLDataset, so the model is built first.
    dataset, dataloader = _build_dataloader(args, fabric, tokenizer)

    model = fabric.setup(model)
    classes = dataset.classes
    train_ids, train_embeddings, train_labels = infer(dataloader, fabric, tokenizer, model, args)

    torch.cuda.empty_cache()
    if fabric.global_rank == 0:
        savedir = Path(args.savedir)
        savedir.mkdir(parents=True, exist_ok=True)
        emb_dict = {
            "embeddings": train_embeddings,
            "labels": train_labels,
            "ids": train_ids,
            "classes": classes,
        }
        output_path = savedir / f"{args.name}.pt"
        torch.save(emb_dict, output_path)
        print(f"Saved embedding database to {output_path}")


def main(argv: Optional[Iterable[str]] = None) -> None:
    """CLI entry point: parse ``argv`` (``sys.argv[1:]`` when ``None``) and generate embeddings."""
    generate_embeddings(build_argument_parser().parse_args(argv))


# Public API of this module, declared before the script entry point.
__all__ = ["build_argument_parser", "generate_embeddings", "main"]

if __name__ == "__main__":
    main()