# Authentica / detree / utils / detectors / Fast_DetectGPT_evaluation.py
# Uploaded by MAS-AI-0000 ("Upload 6 files", commit 3cdaafb, verified).
import argparse
import json
import logging
import random
import time

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

from ..utils import evaluate_metrics
# Root logger setup: INFO and above, each record prefixed with timestamp and level.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
def get_sampling_discrepancy_analytic(logits_ref, logits_score, labels):
    """Return the analytic sampling-discrepancy criterion for a single sequence.

    For every position t the score model's log-probability of the observed
    token, log p_score(x_t), is compared against its expectation and variance
    when the token is instead drawn from the reference distribution:

        mean_t = sum_v p_ref(v) * log p_score(v)
        var_t  = sum_v p_ref(v) * log p_score(v)**2 - mean_t**2
        crit   = (sum_t log p_score(x_t) - sum_t mean_t) / sqrt(sum_t var_t)

    Args:
        logits_ref: reference-model logits, batch size 1; the vocabulary
            truncation below indexes three dims, i.e. (1, seq_len, vocab).
        logits_score: scoring-model logits with the same layout.
        labels: observed token ids, (1, seq_len) or (1, seq_len, 1).

    Returns:
        The criterion as a Python float (may be nan/inf when the summed
        variance is zero, e.g. for an empty sequence).

    Raises:
        ValueError: if any input does not have batch size 1.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    for name, tensor in (("logits_ref", logits_ref),
                         ("logits_score", logits_score),
                         ("labels", labels)):
        if tensor.shape[0] != 1:
            raise ValueError(f"{name} must have batch size 1, got shape {tuple(tensor.shape)}")

    # The two models may be placed on different devices; compute everything
    # on the scoring logits' device so the element-wise ops below are valid.
    device = logits_score.device
    logits_ref = logits_ref.to(device)
    labels = labels.to(device)

    if logits_ref.size(-1) != logits_score.size(-1):
        # When the vocabulary sizes differ, keep only the shared leading token
        # ids so the two distributions can be combined element-wise.
        vocab_size = min(logits_ref.size(-1), logits_score.size(-1))
        logits_ref = logits_ref[:, :, :vocab_size]
        logits_score = logits_score[:, :, :vocab_size]

    # `gather` needs an index with the same rank as the logits.
    if labels.ndim == logits_score.ndim - 1:
        labels = labels.unsqueeze(-1)

    lprobs_score = torch.log_softmax(logits_score, dim=-1)
    probs_ref = torch.softmax(logits_ref, dim=-1)
    log_likelihood = lprobs_score.gather(dim=-1, index=labels).squeeze(-1)
    mean_ref = (probs_ref * lprobs_score).sum(dim=-1)
    var_ref = (probs_ref * torch.square(lprobs_score)).sum(dim=-1) - torch.square(mean_ref)
    discrepancy = (log_likelihood.sum(dim=-1) - mean_ref.sum(dim=-1)) / var_ref.sum(dim=-1).sqrt()
    return discrepancy.mean().item()
def get_text_crit(text, args, model_config):
    """Score ``text`` with the Fast-DetectGPT analytic criterion.

    Args:
        text: the passage to score.
        args: namespace with ``reference_model`` and ``scoring_model`` names;
            when they are equal, the scoring logits double as reference logits.
        model_config: dict with keys ``"scoring_tokenizer"``,
            ``"scoring_model"``, ``"reference_tokenizer"`` and
            ``"reference_model"``.

    Returns:
        The criterion as a float, or ``nan`` when ``text`` tokenizes to fewer
        than two tokens (there is no next-token target to score).

    Raises:
        ValueError: if the reference tokenizer yields different token ids than
            the scoring tokenizer (the criterion compares per-token
            distributions, so both must share a tokenization).
    """
    scoring_model = model_config["scoring_model"]
    tokenized = model_config["scoring_tokenizer"](text, return_tensors="pt",
                                                  return_token_type_ids=False)
    # Logits at position t predict token t+1, so targets start at index 1.
    labels = tokenized.input_ids[:, 1:]
    if labels.shape[-1] == 0:
        return float("nan")

    with torch.no_grad():
        # Inputs must be on the model's device: with a single-device
        # device_map the model has no hooks that would move CPU inputs.
        tokenized = tokenized.to(scoring_model.device)
        logits_score = scoring_model(**tokenized).logits[:, :-1]
        if args.reference_model == args.scoring_model:
            logits_ref = logits_score
        else:
            reference_model = model_config["reference_model"]
            ref_tokenized = model_config["reference_tokenizer"](text, return_tensors="pt",
                                                                return_token_type_ids=False)
            # torch.equal is False (rather than raising) on a length mismatch.
            if not torch.equal(ref_tokenized.input_ids[:, 1:], labels):
                raise ValueError("Tokenizer is mismatch.")
            ref_tokenized = ref_tokenized.to(reference_model.device)
            logits_ref = reference_model(**ref_tokenized).logits[:, :-1]

    labels = labels.to(logits_score.device)
    logits_ref = logits_ref.to(logits_score.device)
    return get_sampling_discrepancy_analytic(logits_ref, logits_score, labels)
def load_jsonl(file_path):
    """Read a JSON Lines file into a list of parsed records.

    Blank or whitespace-only lines are skipped, so stray empty lines do not
    raise ``json.JSONDecodeError``. Prints the number of records loaded.

    Args:
        file_path: path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        List with one decoded JSON value per non-blank line, in file order.
    """
    with open(file_path, mode='r', encoding='utf-8') as jsonl_file:
        out = [json.loads(line) for line in jsonl_file if line.strip()]
    print(f"Loaded {len(out)} examples from {file_path}")
    return out
def dict2str(metrics):
    """Format ``metrics`` as ``"key:value "`` pairs in iteration order.

    Every pair, including the last, is followed by a single space; an empty
    mapping yields ``""``.
    """
    return "".join(f"{name}:{value} " for name, value in metrics.items())
def _load_causal_lm(model_name):
    """Load ``model_name``'s tokenizer and causal LM (``device_map="auto"``) in eval mode."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    model.eval()
    return tokenizer, model


def _load_test_data(test_data_path):
    """Load one JSONL file, or several given as comma-separated paths, into one list."""
    records = []
    for path in test_data_path.split(","):
        path = path.strip()
        if path:
            records.extend(load_jsonl(path))
    return records


def experiment(args):
    """Score shuffled test samples with the Fast-DetectGPT analytic criterion.

    Args:
        args: namespace providing ``reference_model``, ``scoring_model``,
            ``test_data_path`` (one path, or comma-separated paths, to JSONL
            files whose records have ``"text"`` and ``"src"`` keys), ``seed``
            and optionally ``max_samples`` (default 100; negative means all).

    Returns:
        ``(labels, predictions)``: ``labels[i]`` is 0 when the record's
        ``src`` contains ``"human"`` and 1 otherwise; ``predictions[i]`` is the
        criterion, with None/nan/inf scores replaced by 0.
    """
    logging.info(f"Loading reference model of type {args.reference_model}...")
    reference_tokenizer, reference_model = _load_causal_lm(args.reference_model)
    if args.scoring_model == args.reference_model:
        # get_text_crit reuses the scoring logits as reference logits here,
        # so loading the same checkpoint a second time would only cost memory.
        scoring_tokenizer, scoring_model = reference_tokenizer, reference_model
    else:
        logging.info(f"Loading scoring model of type {args.scoring_model}...")
        scoring_tokenizer, scoring_model = _load_causal_lm(args.scoring_model)
    model_config = {
        "reference_tokenizer": reference_tokenizer,
        "reference_model": reference_model,
        "scoring_tokenizer": scoring_tokenizer,
        "scoring_model": scoring_model,
    }

    logging.info(f"Test in {args.test_data_path}")
    test_data = _load_test_data(args.test_data_path)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.shuffle(test_data)

    # Evaluate only a prefix of the shuffled data. Slicing up front (instead of
    # breaking out of the loop) keeps tqdm's total and the timing honest.
    max_samples = getattr(args, "max_samples", 100)
    if max_samples is not None and max_samples >= 0:
        test_data = test_data[:max_samples]

    predictions = []
    labels = []
    num_non_finite = 0
    start = time.perf_counter()
    for item in tqdm(test_data):
        text = item["text"]
        src = item["src"]
        text_crit = get_text_crit(text, args, model_config)
        if text_crit is None or not np.isfinite(text_crit):
            num_non_finite += 1
            text_crit = 0
        labels.append(0 if 'human' in src else 1)
        predictions.append(text_crit)
    elapsed = time.perf_counter() - start

    num_scored = len(predictions)
    # Divide by the number actually scored; the data set may be smaller than the cap.
    logging.info(f"Scored {num_scored} samples in {elapsed / max(num_scored, 1):.4f} s per sample on average")
    if num_non_finite:
        logging.warning(f"{num_non_finite} non-finite scores were replaced by 0")

    # Metric reporting is currently disabled; to restore it:
    # metric = evaluate_metrics(labels, predictions)
    # print(dict2str(metric))
    # with open("runs/val-other_detector.txt", 'a+') as f:
    #     f.write(f"Fast DetectGPT {args.test_data_path} {args.scoring_model} {args.reference_model}\n")
    #     f.write(f"{dict2str(metric)}\n")
    return labels, predictions


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Evaluate the Fast-DetectGPT analytic criterion on JSONL test data.")
    parser.add_argument('--test_data_path', type=str, default='/path/to/RealBench/Beemo/Llama_edited/test.jsonl',
                        help="Path to the JSONL test data; several files may be given separated by ','.")
    parser.add_argument('--reference_model', type=str, default="EleutherAI/gpt-neo-2.7B")
    parser.add_argument('--scoring_model', type=str, default="EleutherAI/gpt-j-6B")
    # Kept for command-line compatibility only: models are placed with device_map="auto".
    parser.add_argument('--DEVICE0', default="cuda:0", type=str, required=False)
    parser.add_argument('--DEVICE1', default="cuda:1", type=str, required=False)
    parser.add_argument('--seed', default=2023, type=int, required=False)
    parser.add_argument('--max_samples', default=100, type=int, required=False,
                        help="Evaluate at most this many shuffled samples; a negative value evaluates all.")
    args = parser.parse_args()
    experiment(args)