"""Embedding generation CLI for DETree."""
from __future__ import annotations

import argparse
import hashlib
from pathlib import Path
from typing import Iterable, Literal, Optional

import pandas as pd
import torch
import torch.nn.functional as F
from lightning import Fabric
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from detree.model.text_embedding import TextEmbeddingModel
from detree.utils.dataset import SCLDataset, load_datapath


def infer(passages_dataloader, fabric, tokenizer, model, args):
    """Run the model over the dataloader and collect per-layer embeddings.

    Only rank 0 accumulates results; all other ranks return empty lists.
    """
    if fabric.global_rank == 0:
        passages_dataloader = tqdm(passages_dataloader)
    all_ids, all_labels = [], []
    all_embeddings = {layer: [] for layer in args.need_layer}
    with torch.no_grad():
        for batch in passages_dataloader:
            text, label, write_model, ids = batch
            encoded_batch = tokenizer.batch_encode_plus(
                text,
                return_tensors="pt",
                max_length=args.max_length,
                padding="max_length",
                truncation=True,
            )
            encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
            embeddings = model(encoded_batch, hidden_states=True)
            # Gather from all ranks and flatten to (world_size * batch, layers, hidden).
            embeddings = fabric.all_gather(embeddings).view(-1, embeddings.size(-2), embeddings.size(-1))
            # Note: the stored "labels" are model-class indices, not the binary human flag.
            label = fabric.all_gather(write_model).view(-1)
            ids = fabric.all_gather(ids).view(-1)
            if fabric.global_rank == 0:
                embeddings = F.normalize(embeddings, dim=-1).cpu().to(torch.bfloat16)
                for layer in args.need_layer:
                    all_embeddings[layer].append(embeddings[:, layer, :].clone())
                all_ids.extend(ids.cpu().tolist())
                all_labels.extend(label.cpu().tolist())
            del embeddings, label, ids
    if fabric.global_rank == 0:
        for layer in args.need_layer:
            all_embeddings[layer] = torch.cat(all_embeddings[layer], dim=0)
        return torch.tensor(all_ids), all_embeddings, torch.tensor(all_labels)
    return [], [], []


def stable_long_hash(input_string: str) -> int:
    """Deterministic 63-bit integer hash of a string."""
    hash_object = hashlib.sha256(input_string.encode())
    int_hash = int(hash_object.hexdigest(), 16)
    # Mask to 63 bits so the value fits in a non-negative int64 tensor element.
    return int_hash & ((1 << 63) - 1)
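
# Why `stable_long_hash` instead of the builtin hash()? The builtin is salted
# per interpreter run (PYTHONHASHSEED), so it would assign different sample IDs
# across dataloader workers and reruns; SHA-256 gives the same ID everywhere.
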
def load_data(split: Literal["train", "test", "extra"], include_adversarial: bool, fp: Path) -> pd.DataFrame:
    if split not in ("train", "test", "extra"):
        raise ValueError('`split` must be one of ("train", "test", "extra")')
    fname = f"{split}.csv" if include_adversarial else f"{split}_none.csv"
    return pd.read_csv(fp / fname)


class PassagesDataset(Dataset):
    """Dataset over records with `generation`, `model`, and `attack` fields."""

    def __init__(self, data):
        self.passages = []
        for item in data:
            # Keep all clean/paraphrase samples; deterministically subsample the
            # remaining attack types to roughly 50% so the split is reproducible.
            if item["attack"] not in ("none", "paraphrase") and stable_long_hash(item["generation"]) % 10 < 5:
                continue
            self.passages.append(item)
        # Class indices are built from the full, unfiltered data so they stay stable.
        self.classes = sorted({item["model"] for item in data})
        self.human_id = self.classes.index("human")

    def __len__(self):
        return len(self.passages)

    def __getitem__(self, idx):
        data_now = self.passages[idx]
        text = data_now["generation"]
        model = self.classes.index(data_now["model"])
        label = int(model == self.human_id)
        ids = stable_long_hash(text)
        return text, label, model, ids
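
# Minimal usage sketch for PassagesDataset (record values are illustrative; the
# real CSVs supply the same `generation`/`model`/`attack` fields):
#
#   rows = [
#       {"generation": "some human text", "model": "human", "attack": "none"},
#       {"generation": "some model text", "model": "gpt-4", "attack": "none"},
#   ]
#   ds = PassagesDataset(rows)
#   text, is_human, model_idx, sample_id = ds[0]
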
def build_argument_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Generate embedding databases for DETree evaluators",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--device-num", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--num-workers", type=int, default=8)
    parser.add_argument("--max-length", type=int, default=512)
    parser.add_argument("--path", type=Path, required=True, help="Dataset root directory or JSONL file path.")
    parser.add_argument("--database-name", type=str, default="M4_monolingual")
    parser.add_argument(
        "--model-name",
        type=str,
        default="FacebookAI/roberta-large",
        help=(
            "Model identifier for embedding generation. Accepts either a Hugging Face "
            "model hub name or a local path to a directory in Hugging Face format."
        ),
    )
    parser.add_argument("--pooling", type=str, default="max", choices=("max", "average", "cls"))
    parser.add_argument("--need-layer", type=int, nargs="+", default=[16, 17, 18, 19, 22, 23])
    parser.add_argument("--adversarial", dest="adversarial", action="store_true")
    parser.add_argument("--no-adversarial", dest="adversarial", action="store_false")
    parser.set_defaults(adversarial=True)
    parser.add_argument("--has-mix", dest="has_mix", action="store_true")
    parser.add_argument("--no-has-mix", dest="has_mix", action="store_false")
    parser.set_defaults(has_mix=False)
    parser.add_argument("--savedir", type=Path, required=True, help="Output directory for the embedding database.")
    parser.add_argument("--name", type=str, required=True, help="Filename (without extension) for the saved embeddings.")
    parser.add_argument("--split", type=str, default="train", choices=("train", "test", "extra"))
    return parser
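
# Programmatic sketch of driving the CLI from Python (paths are placeholders):
#
#   parser = build_argument_parser()
#   args = parser.parse_args([
#       "--path", "data/LLM_detect_data",
#       "--savedir", "out",
#       "--name", "train_emb",
#   ])
#   generate_embeddings(args)
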
def generate_embeddings(args: argparse.Namespace) -> None:
    if args.device_num > 1:
        fabric = Fabric(accelerator="cuda", precision="bf16-mixed", devices=args.device_num, strategy="ddp")
    else:
        fabric = Fabric(accelerator="cuda", precision="bf16-mixed", devices=args.device_num)
    fabric.launch()
    model = TextEmbeddingModel(
        args.model_name,
        output_hidden_states=True,
        infer=True,
        use_pooling=args.pooling,
    ).cuda()
    tokenizer = model.tokenizer
    model.eval()
    path_str = str(args.path)
    if "LLM_detect_data" in path_str:
        # CSV layout: <split>.csv includes adversarial samples, <split>_none.csv does not.
        now_data = load_data(args.split, include_adversarial=args.adversarial, fp=args.path)
        dataset = PassagesDataset(now_data.to_dict(orient="records"))
        dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        dataloader = fabric.setup_dataloaders(dataloader)
    elif path_str.endswith(".jsonl"):
        dataset = SCLDataset([path_str], fabric, tokenizer, need_ids=True, adv_p=0)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        dataloader = fabric.setup_dataloaders(dataloader, use_distributed_sampler=False)
    else:
        data_path = load_datapath(
            path_str,
            include_adversarial=args.adversarial,
            dataset_name=args.database_name,
        )[args.split]
        dataset = SCLDataset(data_path, fabric, tokenizer, need_ids=True, adv_p=0, has_mix=args.has_mix)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        dataloader = fabric.setup_dataloaders(dataloader, use_distributed_sampler=False)
    model = fabric.setup(model)
    classes = dataset.classes
    train_ids, train_embeddings, train_labels = infer(dataloader, fabric, tokenizer, model, args)
    torch.cuda.empty_cache()
    if fabric.global_rank == 0:
        args.savedir.mkdir(parents=True, exist_ok=True)
        emb_dict = {
            "embeddings": train_embeddings,
            "labels": train_labels,
            "ids": train_ids,
            "classes": classes,
        }
        output_path = args.savedir / f"{args.name}.pt"
        torch.save(emb_dict, output_path)
        print(f"Saved embedding database to {output_path}")
def main(argv: Optional[Iterable[str]] = None) -> None:
    parser = build_argument_parser()
    args = parser.parse_args(argv)
    generate_embeddings(args)


__all__ = ["build_argument_parser", "generate_embeddings", "main"]


if __name__ == "__main__":
    main()