import argparse
import json
import logging
import random
import time

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# NOTE: this relative import assumes the script is run as part of its package (e.g. python -m ...);
# evaluate_metrics is only used in the commented-out reporting block below.
from ..utils import evaluate_metrics

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")


def get_sampling_discrepancy_analytic(logits_ref, logits_score, labels):
    """Analytic sampling-discrepancy criterion (Fast-DetectGPT) for a single sequence."""
    assert logits_ref.shape[0] == 1
    assert logits_score.shape[0] == 1
    assert labels.shape[0] == 1
    # Ensure all tensors share a device (models loaded with device_map="auto" may live on a GPU
    # while the tokenized labels stay on CPU).
    logits_ref = logits_ref.to(logits_score.device)
    labels = labels.to(logits_score.device)
    if logits_ref.size(-1) != logits_score.size(-1):
        # Truncate to the shared vocabulary when the two models disagree on vocab size.
        # print(f"WARNING: vocabulary size mismatch {logits_ref.size(-1)} vs {logits_score.size(-1)}.")
        vocab_size = min(logits_ref.size(-1), logits_score.size(-1))
        logits_ref = logits_ref[:, :, :vocab_size]
        logits_score = logits_score[:, :, :vocab_size]

    labels = labels.unsqueeze(-1) if labels.ndim == logits_score.ndim - 1 else labels
    lprobs_score = torch.log_softmax(logits_score, dim=-1)
    probs_ref = torch.softmax(logits_ref, dim=-1)
    log_likelihood = lprobs_score.gather(dim=-1, index=labels).squeeze(-1)
    mean_ref = (probs_ref * lprobs_score).sum(dim=-1)
    var_ref = (probs_ref * torch.square(lprobs_score)).sum(dim=-1) - torch.square(mean_ref)
    discrepancy = (log_likelihood.sum(dim=-1) - mean_ref.sum(dim=-1)) / var_ref.sum(dim=-1).sqrt()
    discrepancy = discrepancy.mean()
    return discrepancy.item()


def get_text_crit(text, args, model_config):
    """Compute the detection criterion for a single text."""
    tokenized = model_config["scoring_tokenizer"](text, return_tensors="pt", return_token_type_ids=False)
    labels = tokenized.input_ids[:, 1:]
    with torch.no_grad():
        logits_score = model_config["scoring_model"](**tokenized).logits[:, :-1]
        if args.reference_model == args.scoring_model:
            logits_ref = logits_score
        else:
            tokenized = model_config["reference_tokenizer"](text, return_tensors="pt", return_token_type_ids=False)
            assert torch.all(tokenized.input_ids[:, 1:] == labels), "Tokenizer mismatch."
            logits_ref = model_config["reference_model"](**tokenized).logits[:, :-1]
        text_crit = get_sampling_discrepancy_analytic(logits_ref, logits_score, labels)
    return text_crit


def load_jsonl(file_path):
    out = []
    with open(file_path, mode='r', encoding='utf-8') as jsonl_file:
        for line in jsonl_file:
            item = json.loads(line)
            out.append(item)
    print(f"Loaded {len(out)} examples from {file_path}")
    return out


def dict2str(metrics):
    out_str = ''
    for key in metrics.keys():
        out_str += f"{key}:{metrics[key]} "
    return out_str


def experiment(args):
    # Load the reference and scoring models.
    logging.info(f"Loading reference model of type {args.reference_model}...")
    reference_tokenizer = AutoTokenizer.from_pretrained(args.reference_model)
    reference_model = AutoModelForCausalLM.from_pretrained(args.reference_model, device_map="auto")
    reference_model.eval()

    logging.info(f"Loading scoring model of type {args.scoring_model}...")
    scoring_tokenizer = AutoTokenizer.from_pretrained(args.scoring_model)
    scoring_model = AutoModelForCausalLM.from_pretrained(args.scoring_model, device_map="auto")
    scoring_model.eval()

    model_config = {
        "reference_tokenizer": reference_tokenizer,
        "reference_model": reference_model,
        "scoring_tokenizer": scoring_tokenizer,
        "scoring_model": scoring_model,
    }

    logging.info(f"Testing on {args.test_data_path}")
    test_data = load_jsonl(args.test_data_path)

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.shuffle(test_data)

    predictions = []
    labels = []
    st = time.time()
    for i, item in tqdm(enumerate(test_data), total=len(test_data)):
        if i >= 100:  # for debugging, only use the first 100 samples
            break
        text = item["text"]
        src = item["src"]
        text_crit = get_text_crit(text, args, model_config)
        if text_crit is None or np.isnan(text_crit) or np.isinf(text_crit):
            text_crit = 0
        # Examples whose source contains "human" are treated as human-written (0), all others as machine-generated (1).
        if 'human' in src:
            labels.append(0)
        else:
            labels.append(1)
        predictions.append(text_crit)
    ed = time.time()
    print((ed - st) / 100)  # average seconds per sample over the 100-sample debug run

    # metric = evaluate_metrics(labels, predictions)
    # print(dict2str(metric))
    # with open("runs/val-other_detector.txt", 'a+') as f:
    #     f.write(f"Fast DetectGPT {args.test_data_path} {args.scoring_model} {args.reference_model}\n")
    #     f.write(f"{dict2str(metric)}\n")
    # logging.info(f"{result}")
    # with open(filename.split(".json")[0] + "_Fast_DetectGPT_data.json", "w") as f:
    #     json.dump(test_data, f, indent=4)
    # with open(filename.split(".json")[0] + "_Fast_DetectGPT_result.json", "w") as f:
    #     json.dump(result, f, indent=4)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_data_path', type=str,
                        default='/path/to/RealBench/Beemo/Llama_edited/test.jsonl',
                        help="Path to the test data; can be several files separated by ','. "
                             "Note that the data should have been perturbed.")
    parser.add_argument('--reference_model', type=str, default="EleutherAI/gpt-neo-2.7B")
    parser.add_argument('--scoring_model', type=str, default="EleutherAI/gpt-j-6B")
    # DEVICE0/DEVICE1 are currently unused; models are placed with device_map="auto".
    parser.add_argument('--DEVICE0', default="cuda:0", type=str, required=False)
    parser.add_argument('--DEVICE1', default="cuda:1", type=str, required=False)
    parser.add_argument('--seed', default=2023, type=int, required=False)
    args = parser.parse_args()

    experiment(args)
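
# Example invocation -- a sketch only; the script filename and the data path below are
# placeholders, not part of the original repository:
#
#   python run_fast_detect_gpt.py \
#       --test_data_path /data/RealBench/Beemo/Llama_edited/test.jsonl \
#       --reference_model EleutherAI/gpt-neo-2.7B \
#       --scoring_model EleutherAI/gpt-j-6B \
#       --seed 2023
#
# Each JSONL line is expected to provide at least "text" and "src" fields; entries whose
# "src" contains "human" are scored as label 0, everything else as label 1.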