Spaces:
Running
Running
| """kNN evaluation using pre-computed embedding databases.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| from multiprocessing import Pool, cpu_count | |
| from pathlib import Path | |
| from typing import Iterable, List, Optional, Sequence, Tuple | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from lightning import Fabric | |
| from torch.nn.functional import softmax as F_softmax | |
| from torch.utils.data import DataLoader, Dataset | |
| from tqdm import tqdm | |
| from detree.model.text_embedding import TextEmbeddingModel | |
| from detree.utils.index import Indexer | |
| from detree.utils.utils import evaluate_metrics | |
# Opt in to HF tokenizers' internal parallelism; presumably safe here despite
# DataLoader workers / multiprocessing forks -- TODO confirm no fork warnings.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "true")
def load_jsonl(file_path: Path) -> List[dict]:
    """Read a JSONL file and return one dict per line."""
    with file_path.open(mode="r", encoding="utf-8") as jsonl_file:
        records = [json.loads(line) for line in jsonl_file]
    print(f"Loaded {len(records)} examples from {file_path}")
    return records
def gen_data(dict_data):
    """Unpack a saved embedding-database dict into (embeddings, labels, ids, classes)."""
    return (
        dict_data["embeddings"],
        dict_data["labels"],
        dict_data["ids"],
        dict_data["classes"],
    )
class PassagesDataset(Dataset):
    """Map-style dataset yielding (text, label, id) tuples from JSONL records."""

    def __init__(self, data: Sequence[dict]):
        # Copy into a list so indexing is stable regardless of the input sequence type.
        self.passages = list(data)

    def __len__(self) -> int:
        return len(self.passages)

    def __getitem__(self, idx: int):
        record = self.passages[idx]
        # Labels/ids may arrive as strings in JSONL; coerce to int for collation.
        return record["text"], int(record["label"]), int(record["id"])
def infer(passages_dataloader, fabric, tokenizer, model, need_layers: Sequence[int], max_length: int = 512):
    """Embed every passage in the dataloader across all ranks, gathering on rank 0.

    Returns ``(ids, {layer: embeddings}, labels)`` on global rank 0, where the
    embeddings are L2-normalized numpy arrays keyed by the layer indices in
    ``need_layers``; every other rank returns ``([], [], [])``.
    """
    # Show the progress bar on rank 0 only to avoid duplicated output.
    if fabric.global_rank == 0:
        passages_dataloader = tqdm(passages_dataloader)
    all_ids: List[int] = []
    all_embeddings: List[torch.Tensor] = []
    all_labels: List[int] = []
    with torch.no_grad():
        for batch in passages_dataloader:
            text, label, ids = batch
            encoded_batch = tokenizer.batch_encode_plus(
                text,
                return_tensors="pt",
                max_length=max_length,
                padding="max_length",
                truncation=True,
            )
            encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
            # NOTE(review): the flattening below implies the model output is
            # (batch, num_layers, dim) -- confirm against TextEmbeddingModel.
            embeddings = model(encoded_batch, hidden_states=True)
            # Gather each rank's batch and fold the gather dimension into the batch axis.
            embeddings = fabric.all_gather(embeddings).view(-1, embeddings.size(-2), embeddings.size(-1))
            label = fabric.all_gather(label).view(-1)
            ids = fabric.all_gather(ids).view(-1)
            # Only rank 0 accumulates; other ranks discard the gathered tensors.
            if fabric.global_rank == 0:
                all_embeddings.append(embeddings.cpu())
                all_ids.extend(ids.cpu().tolist())
                all_labels.extend(label.cpu().tolist())
    if fabric.global_rank == 0:
        embeddings_tensor = torch.cat(all_embeddings, dim=0)
        # Normalize per-vector, then reorder to (layer, sample, dim) so each
        # requested layer can be sliced out by its index.
        embeddings_tensor = F.normalize(embeddings_tensor, dim=-1).permute(1, 0, 2).numpy()
        embeddings_tensor = {layer: embeddings_tensor[layer] for layer in need_layers}
        return all_ids, embeddings_tensor, all_labels
    # NOTE(review): non-zero ranks return a list for the middle element while
    # rank 0 returns a dict; callers appear to use only the rank-0 result.
    return [], [], []
def dict2str(metrics: dict) -> str:
    """Render a metrics dict as space-separated 'key:value' pairs, layer/k first."""
    parts = [f"{key}:{metrics[key]}" for key in ("layer", "k") if key in metrics]
    parts.extend(
        f"{key}:{value}"
        for key, value in metrics.items()
        if key not in ("layer", "k")
    )
    return " ".join(parts)
def process_element(args: Tuple[Sequence[int], Sequence[float], Sequence[int], float]):
    """Score one query's neighbor list.

    Accumulates temperature-scaled similarity votes per binary class, walking
    neighbors from highest to lowest score, and returns a dict mapping each
    prefix size k (1-based) to the softmax probability of class 1.
    """
    _ids, scores, labels, temperature = args
    vote_scores = torch.zeros(2)
    # Visit neighbors in descending similarity order.
    descending = np.argsort(scores)[::-1]
    preds_by_k = {}
    for rank, neighbor in enumerate(descending, start=1):
        vote_scores[labels[neighbor]] += scores[neighbor] * temperature
        preds_by_k[rank] = F_softmax(vote_scores, dim=-1)[1].item()
    return preds_by_k
def build_argument_parser() -> argparse.ArgumentParser:
    """Create the CLI parser for database-backed kNN evaluation."""
    parser = argparse.ArgumentParser(
        description="Evaluate DETree with a precomputed embedding database.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = parser.add_argument
    # Hardware / dataloader settings.
    add("--device-num", type=int, default=1)
    add("--batch-size", type=int, default=32)
    add("--num-workers", type=int, default=8)
    add("--max-length", type=int, default=512)
    # Required inputs.
    add("--database-path", type=Path, required=True, help="Path to the saved embedding database (.pt).")
    add("--test-dataset-path", type=Path, required=True, help="Evaluation JSONL file.")
    add("--model-name-or-path", type=str, required=True)
    # kNN / model hyper-parameters.
    add("--temperature", type=float, default=0.05)
    add("--max-k", type=int, default=51, dest="max_K")
    add("--pooling", type=str, default="max", choices=("max", "average", "cls"))
    add("--embedding-dim", type=int, default=1024)
    # Parallel scoring and logging.
    add("--pool-workers", type=int, default=min(32, cpu_count()))
    add("--log-file", type=Path, default=Path("runs/val.txt"))
    return parser
def evaluate(args: argparse.Namespace) -> None:
    """Embed the test set, search the precomputed database, and report best kNN metrics.

    Rank 0 loads the embedding database, all ranks embed
    ``args.test_dataset_path`` with the model, then (rank 0 only) each stored
    layer is indexed and searched; k is swept over 1..max_K and the
    (layer, k) pair with the highest AUROC is printed and appended to
    ``args.log_file``.
    """
    # Fabric provides device placement and bf16 autocast; DDP only for multi-GPU.
    if args.device_num > 1:
        fabric = Fabric(accelerator="cuda", devices=args.device_num, strategy="ddp", precision="bf16-mixed")
    else:
        fabric = Fabric(accelerator="cuda", devices=args.device_num, precision="bf16-mixed")
    fabric.launch()
    model = TextEmbeddingModel(
        args.model_name_or_path,
        output_hidden_states=True,
        infer=True,
        use_pooling=args.pooling,
    ).cuda()
    tokenizer = model.tokenizer
    model.eval()
    # Only rank 0 holds the database in memory; other ranks just need the layer list.
    if fabric.global_rank == 0:
        db_embeddings, db_labels, db_ids, classes = gen_data(torch.load(args.database_path))
        need_layers = list(db_embeddings.keys())
    else:
        db_embeddings = db_labels = db_ids = classes = None
        need_layers = []
    need_layers = fabric.broadcast(need_layers)
    test_database = load_jsonl(args.test_dataset_path)
    test_dataset = PassagesDataset(test_database)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)
    test_dataloader = fabric.setup_dataloaders(test_dataloader)
    model = fabric.setup(model)
    test_ids, test_embeddings, test_labels = infer(test_dataloader, fabric, tokenizer, model, need_layers, args.max_length)
    torch.cuda.empty_cache()
    # Everything below (indexing, kNN scoring, logging) runs on rank 0 only.
    if fabric.global_rank != 0:
        return
    test_labels = [int(label) for label in test_labels]
    index = Indexer(args.embedding_dim)
    # Binary task: the "human" class vs everything else.
    human_idx = classes.index("human")
    all_details = []
    with Pool(processes=args.pool_workers) as pool:
        for layer in need_layers:
            now_best_metrics = None
            label_dict = {}
            train_embeddings = db_embeddings[layer].float().numpy()
            # The database may store per-layer labels/ids (dict) or one shared set.
            if isinstance(db_labels, dict):
                train_labels = db_labels[layer].tolist()
                train_ids = db_ids[layer].tolist()
            else:
                train_labels = db_labels.tolist()
                train_ids = db_ids.tolist()
            # Map each database id to a binary label: 1 == human, 0 == otherwise.
            for i in range(len(train_ids)):
                label_dict[int(train_ids[i])] = int(train_labels[i] == human_idx)
            index.label_dict = label_dict
            index.reset()
            index.index_data(train_ids, train_embeddings)
            preds = {k: [] for k in range(1, args.max_K + 1)}
            top_ids_and_scores = index.search_knn(test_embeddings[layer], args.max_K, index_batch_size=128)
            # Score each test example's neighbor list in parallel worker processes.
            args_list = [
                (ids, scores, labels, args.temperature)
                for ids, scores, labels in top_ids_and_scores
            ]
            for result in tqdm(pool.imap(process_element, args_list), total=len(args_list)):
                for k, value in result.items():
                    preds[k].append(value)
            # Sweep k and keep this layer's best AUROC.
            for k in range(1, args.max_K + 1):
                metric = evaluate_metrics(test_labels, preds[k], threshold_param=-1)
                if now_best_metrics is None or now_best_metrics["auroc"] < metric["auroc"]:
                    now_best_metrics = metric
                    now_best_metrics["k"] = k
                    now_best_metrics["layer"] = layer
            if now_best_metrics:
                print(dict2str(now_best_metrics))
                all_details.append(now_best_metrics)
    if not all_details:
        return
    # Pick the best layer overall by AUROC; also log the last layer's result.
    max_ids = max(range(len(all_details)), key=lambda idx: all_details[idx]["auroc"])
    best_metrics = all_details[max_ids]
    print("Best " + dict2str(best_metrics))
    args.log_file.parent.mkdir(parents=True, exist_ok=True)
    with args.log_file.open("a+", encoding="utf-8") as fp:
        fp.write(f"test model:{args.model_name_or_path} mode:{args.test_dataset_path} database_path:{args.database_path}\n")
        fp.write(f"Last {dict2str(all_details[-1])}\n")
        fp.write(f"Best {dict2str(best_metrics)}\n")
        fp.write("------------------------------------------\n")
def main(argv: Optional[Iterable[str]] = None) -> None:
    """CLI entry point: parse arguments and run the evaluation."""
    evaluate(build_argument_parser().parse_args(argv))
# Declare the public API before the script entry point, per convention.
__all__ = ["build_argument_parser", "evaluate", "main"]

if __name__ == "__main__":
    main()