# Source: Authentica/detree/utils/detectors/Fast_DetectGPT_evaluation.py
# Hugging Face file-viewer metadata: uploaded by MAS-AI-0000 ("Upload 6 files"),
# commit 3cdaafb (verified), 5.85 kB.
import logging
import random
import numpy as np
import torch
import argparse
import json
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from ..utils import evaluate_metrics
# Configure the root logger at import time: INFO level, each record rendered as
# "<timestamp> <level name> <message>".
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
def get_sampling_discrepancy_analytic(logits_ref, logits_score, labels):
    """Return the analytic sampling discrepancy of one token sequence.

    For every position, the log-probability the scoring model assigns to the
    observed token is compared against the mean and variance of the scoring
    model's log-probabilities under the reference model's distribution. The
    result is ``(sum(log_likelihood) - sum(mean)) / sqrt(sum(variance))``.

    Args:
        logits_ref: Reference-model logits; the leading (batch) dimension
            must be 1.
        logits_score: Scoring-model logits; the leading dimension must be 1.
        labels: Observed token ids; the leading dimension must be 1. A
            trailing singleton axis is added when ``labels`` has one
            dimension fewer than ``logits_score``.

    Returns:
        The discrepancy as a Python float (may be NaN/inf when the summed
        variance is zero).

    Raises:
        AssertionError: If any input has a leading dimension other than 1.
    """
    # Only a single sequence per call is supported.
    for tensor in (logits_ref, logits_score, labels):
        assert tensor.shape[0] == 1

    # When the two vocabularies differ in size, keep only the shared prefix.
    ref_vocab = logits_ref.size(-1)
    score_vocab = logits_score.size(-1)
    if ref_vocab != score_vocab:
        shared_vocab = min(ref_vocab, score_vocab)
        logits_ref = logits_ref[:, :, :shared_vocab]
        logits_score = logits_score[:, :, :shared_vocab]

    # gather() needs the index tensor to have the same rank as the source.
    if labels.ndim == logits_score.ndim - 1:
        labels = labels.unsqueeze(-1)

    score_log_probs = logits_score.log_softmax(dim=-1)
    ref_probs = logits_ref.softmax(dim=-1)

    # Log-probability of each observed token under the scoring model.
    token_log_likelihood = score_log_probs.gather(dim=-1, index=labels).squeeze(-1)

    # Per-position mean and variance of the scoring log-probabilities when
    # tokens are weighted by the reference distribution.
    expected_log_prob = (ref_probs * score_log_probs).sum(dim=-1)
    second_moment = (ref_probs * score_log_probs.square()).sum(dim=-1)
    variance = second_moment - expected_log_prob.square()

    numerator = token_log_likelihood.sum(dim=-1) - expected_log_prob.sum(dim=-1)
    denominator = variance.sum(dim=-1).sqrt()
    return (numerator / denominator).mean().item()
def get_text_crit(text, args, model_config):
    """Score one text with the analytic sampling-discrepancy criterion.

    The text is tokenized and run through the scoring model; when
    ``args.reference_model`` differs from ``args.scoring_model`` it is also
    tokenized and run through the reference model. Both sets of logits and
    the shifted token ids are passed to ``get_sampling_discrepancy_analytic``.

    Args:
        text: The raw text to score.
        args: Object exposing ``reference_model`` and ``scoring_model``;
            the two are compared for equality to decide whether the scoring
            logits can be reused as reference logits.
        model_config: Mapping with the keys ``"scoring_tokenizer"``,
            ``"scoring_model"``, ``"reference_tokenizer"`` and
            ``"reference_model"``.

    Returns:
        The value returned by ``get_sampling_discrepancy_analytic``.

    Raises:
        AssertionError: If the reference tokenizer produces different token
            ids (or a different number of them) than the scoring tokenizer.
    """
    scoring_model = model_config["scoring_model"]
    tokenized = model_config["scoring_tokenizer"](
        text, return_tensors="pt", return_token_type_ids=False
    )
    # Targets are the inputs shifted by one: position i predicts token i + 1.
    labels = tokenized.input_ids[:, 1:]

    with torch.no_grad():
        # Tokenizers return CPU tensors; move them to wherever the model
        # lives so the forward pass does not hit a device mismatch.
        # NOTE(review): assumes the model exposes `.device` as Hugging Face
        # `PreTrainedModel` does — confirm for custom model wrappers.
        tokenized = tokenized.to(scoring_model.device)
        logits_score = scoring_model(**tokenized).logits[:, :-1]

        if args.reference_model == args.scoring_model:
            logits_ref = logits_score
        else:
            reference_model = model_config["reference_model"]
            ref_tokenized = model_config["reference_tokenizer"](
                text, return_tensors="pt", return_token_type_ids=False
            )
            # torch.equal is False for differing lengths as well as differing
            # ids, so a length mismatch surfaces as this assertion rather
            # than as an unrelated shape error. Both tensors are compared
            # before either is moved off the tokenizer's device.
            assert torch.equal(ref_tokenized.input_ids[:, 1:], labels), "Tokenizer is mismatch."
            ref_tokenized = ref_tokenized.to(reference_model.device)
            logits_ref = reference_model(**ref_tokenized).logits[:, :-1]
            # The two models may sit on different devices; combine on one.
            logits_ref = logits_ref.to(logits_score.device)

        text_crit = get_sampling_discrepancy_analytic(
            logits_ref, logits_score, labels.to(logits_score.device)
        )
    return text_crit
def load_jsonl(file_path):
    """Load a JSON Lines file into a list of decoded objects.

    Blank and whitespace-only lines (for example a trailing empty line) are
    skipped instead of being handed to the JSON decoder.

    Args:
        file_path: Path of the UTF-8 encoded ``.jsonl`` file to read.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, mode='r', encoding='utf-8') as jsonl_file:
        out = [json.loads(line) for line in jsonl_file if line.strip()]
    print(f"Loaded {len(out)} examples from {file_path}")
    return out
def dict2str(metrics):
    """Render a metrics mapping as ``"key:value "`` pairs on a single line.

    Every pair, including the last one, is followed by one space; an empty
    mapping yields the empty string.
    """
    return ''.join(f"{name}:{value} " for name, value in metrics.items())
def experiment(args):
    """Score a sample of the test set and print the mean time per text.

    Loads the reference and scoring language models, reads the JSONL test
    set, shuffles it with a fixed seed, and computes the criterion from
    ``get_text_crit`` for at most the first 100 shuffled items. Non-finite
    scores are replaced by 0. Labels are derived from each item's ``"src"``
    field: 0 when it contains ``"human"``, otherwise 1.

    Args:
        args: Namespace providing ``reference_model``, ``scoring_model``,
            ``test_data_path`` and ``seed``.

    Returns:
        None. The average number of seconds spent per scored text is printed.
    """
    # `time` is not part of the module's import block, so import it here.
    import time

    # Debugging cap: only this many shuffled samples are scored.
    max_samples = 100

    # load model
    logging.info(f"Loading reference model of type {args.reference_model}...")
    reference_tokenizer = AutoTokenizer.from_pretrained(args.reference_model)
    reference_model = AutoModelForCausalLM.from_pretrained(args.reference_model, device_map="auto")
    reference_model.eval()

    if args.scoring_model == args.reference_model:
        # get_text_crit reuses the scoring logits in this case, so loading
        # the same checkpoint a second time would only waste memory.
        scoring_tokenizer = reference_tokenizer
        scoring_model = reference_model
    else:
        scoring_tokenizer = AutoTokenizer.from_pretrained(args.scoring_model)
        scoring_model = AutoModelForCausalLM.from_pretrained(args.scoring_model, device_map="auto")
        scoring_model.eval()

    model_config = {
        "reference_tokenizer": reference_tokenizer,
        "reference_model": reference_model,
        "scoring_tokenizer": scoring_tokenizer,
        "scoring_model": scoring_model,
    }

    logging.info(f"Test in {args.test_data_path}")
    test_data = load_jsonl(args.test_data_path)

    # Seed every RNG in use so the shuffled subset is reproducible.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.shuffle(test_data)

    samples = test_data[:max_samples]
    predictions = []
    labels = []

    start = time.perf_counter()
    for item in tqdm(samples, total=len(samples)):
        text = item["text"]
        src = item["src"]
        text_crit = get_text_crit(text, args, model_config)
        # NaN/inf scores are neutralised to 0 so they cannot poison metrics.
        if text_crit is None or not np.isfinite(text_crit):
            text_crit = 0
        labels.append(0 if 'human' in src else 1)
        predictions.append(text_crit)
    elapsed = time.perf_counter() - start

    # Average over the texts actually scored rather than a fixed 100.
    if samples:
        print(elapsed / len(samples))
    else:
        logging.warning("No samples were scored; skipping the timing report.")

    # Metric evaluation is currently disabled; only timing is reported.
    # metric = evaluate_metrics(labels, predictions)
    # print(dict2str(metric))
    # with open("runs/val-other_detector.txt",'a+') as f:
    #     f.write(f"Fast DetectGPT {args.test_data_path} {args.scoring_model} {args.reference_model}\n")
    #     f.write(f"{dict2str(metric)}\n")
if __name__ == '__main__':
    # Command-line options as (flag, keyword-arguments) pairs; they are
    # registered in this order, which is also the order shown by --help.
    cli_options = (
        ('--test_data_path', dict(
            type=str,
            default='/path/to/RealBench/Beemo/Llama_edited/test.jsonl',
            help="Path to the test data. could be several files with ','. "
                 "Note that the data should have been perturbed.",
        )),
        ('--reference_model', dict(type=str, default="EleutherAI/gpt-neo-2.7B")),
        ('--scoring_model', dict(type=str, default="EleutherAI/gpt-j-6B")),
        ('--DEVICE0', dict(default="cuda:0", type=str, required=False)),
        ('--DEVICE1', dict(default="cuda:1", type=str, required=False)),
        ('--seed', dict(default=2023, type=int, required=False)),
    )

    cli = argparse.ArgumentParser()
    for flag, options in cli_options:
        cli.add_argument(flag, **options)

    # Parse the command line and run the evaluation with the result.
    experiment(cli.parse_args())