"""Embedding generation CLI for DETree.""" from __future__ import annotations import argparse from pathlib import Path from typing import Iterable, Literal, Optional import pandas as pd import torch import torch.nn.functional as F from lightning import Fabric from torch.utils.data import DataLoader, Dataset from tqdm import tqdm from detree.model.text_embedding import TextEmbeddingModel from detree.utils.dataset import SCLDataset, load_datapath def infer(passages_dataloader, fabric, tokenizer, model, args): if fabric.global_rank == 0: passages_dataloader = tqdm(passages_dataloader) all_ids, all_embeddings, all_labels = [], {}, [] for layer in args.need_layer: all_embeddings[layer] = [] with torch.no_grad(): for batch in passages_dataloader: text, label, write_model, ids = batch encoded_batch = tokenizer.batch_encode_plus( text, return_tensors="pt", max_length=args.max_length, padding="max_length", truncation=True, ) encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()} embeddings = model(encoded_batch, hidden_states=True) embeddings = fabric.all_gather(embeddings).view(-1, embeddings.size(-2), embeddings.size(-1)) label = fabric.all_gather(write_model).view(-1) ids = fabric.all_gather(ids).view(-1) if fabric.global_rank == 0: embeddings = F.normalize(embeddings, dim=-1).cpu().to(torch.bfloat16) for layer in args.need_layer: all_embeddings[layer].append(embeddings[:, layer, :].clone()) all_ids.extend(ids.cpu().tolist()) all_labels.extend(label.cpu().tolist()) del embeddings, label, ids if fabric.global_rank == 0: for layer in args.need_layer: all_embeddings[layer] = torch.cat(all_embeddings[layer], dim=0) return torch.tensor(all_ids), all_embeddings, torch.tensor(all_labels) return [], [], [] def stable_long_hash(input_string: str) -> int: import hashlib hash_object = hashlib.sha256(input_string.encode()) hex_digest = hash_object.hexdigest() int_hash = int(hex_digest, 16) return int_hash & ((1 << 63) - 1) def load_data(split: Literal["train", "test", "extra"], include_adversarial: bool, fp: Path) -> pd.DataFrame: if split not in ("train", "test", "extra"): raise ValueError("`split` must be one of (\"train\", \"test\", \"extra\")") fname = f"{split}.csv" if include_adversarial else f"{split}_none.csv" fp = fp / fname return pd.read_csv(fp) class PassagesDataset(Dataset): def __init__(self, data): self.passages = [] for item in data: if item["attack"] not in ("none", "paraphrase") and stable_long_hash(item["generation"]) % 10 < 5: continue self.passages.append(item) classes = sorted({item["model"] for item in data}) self.classes = list(classes) self.human_id = self.classes.index("human") def __len__(self): return len(self.passages) def __getitem__(self, idx): data_now = self.passages[idx] text = data_now["generation"] model = self.classes.index(data_now["model"]) label = int(model == self.human_id) ids = stable_long_hash(text) return text, int(label), int(model), int(ids) def build_argument_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="Generate embedding databases for DETree evaluators", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("--device-num", type=int, default=1) parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--num-workers", type=int, default=8) parser.add_argument("--max-length", type=int, default=512) parser.add_argument("--path", type=Path, required=True, help="Dataset root directory or JSONL file path.") parser.add_argument("--database-name", type=str, default="M4_monolingual") parser.add_argument( "--model-name", type=str, default="FacebookAI/roberta-large", help=( "Model identifier for embeddings generation. Accepts either a Hugging Face " "model hub name or a local path to a directory in Hugging Face format." ), ) parser.add_argument("--pooling", type=str, default="max", choices=("max", "average", "cls")) parser.add_argument("--need-layer", type=int, nargs="+", default=[16, 17, 18, 19, 22, 23]) parser.add_argument("--adversarial", dest="adversarial", action="store_true") parser.add_argument("--no-adversarial", dest="adversarial", action="store_false") parser.set_defaults(adversarial=True) parser.add_argument("--has-mix", dest="has_mix", action="store_true") parser.add_argument("--no-has-mix", dest="has_mix", action="store_false") parser.set_defaults(has_mix=False) parser.add_argument("--savedir", type=Path, required=True, help="Output directory for the embedding database.") parser.add_argument("--name", type=str, required=True, help="Filename (without extension) for the saved embeddings.") parser.add_argument("--split", type=str, default="train", choices=("train", "test", "extra")) return parser def generate_embeddings(args: argparse.Namespace) -> None: if args.device_num > 1: fabric = Fabric(accelerator="cuda", precision="bf16-mixed", devices=args.device_num, strategy="ddp") else: fabric = Fabric(accelerator="cuda", precision="bf16-mixed", devices=args.device_num) fabric.launch() model = TextEmbeddingModel( args.model_name, output_hidden_states=True, infer=True, use_pooling=args.pooling, ).cuda() tokenizer = model.tokenizer model.eval() path_str = str(args.path) if "LLM_detect_data" in path_str: now_data = load_data(args.split, include_adversarial=args.adversarial, fp=args.path) now_data = now_data.to_dict(orient="records") dataset = PassagesDataset(now_data) dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers) dataloader = fabric.setup_dataloaders(dataloader) elif path_str.endswith(".jsonl"): dataset = SCLDataset([path_str], fabric, tokenizer, need_ids=True, adv_p=0) dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers) dataloader = fabric.setup_dataloaders(dataloader, use_distributed_sampler=False) else: data_path = load_datapath( path_str, include_adversarial=args.adversarial, dataset_name=args.database_name, )[args.split] dataset = SCLDataset(data_path, fabric, tokenizer, need_ids=True, adv_p=0, has_mix=args.has_mix) dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers) dataloader = fabric.setup_dataloaders(dataloader, use_distributed_sampler=False) model = fabric.setup(model) classes = dataset.classes train_ids, train_embeddings, train_labels = infer(dataloader, fabric, tokenizer, model, args) torch.cuda.empty_cache() if fabric.global_rank == 0: args.savedir.mkdir(parents=True, exist_ok=True) emb_dict = { "embeddings": train_embeddings, "labels": train_labels, "ids": train_ids, "classes": classes, } output_path = args.savedir / f"{args.name}.pt" torch.save(emb_dict, output_path) print(f"Saved embedding database to {output_path}") def main(argv: Optional[Iterable[str]] = None) -> None: parser = build_argument_parser() args = parser.parse_args(argv) generate_embeddings(args) if __name__ == "__main__": main() __all__ = ["build_argument_parser", "generate_embeddings", "main"]