subnet32-llm-detector / scripts /verify_var.py
ThaoTran7's picture
incomplete commit
485127c
import random
import numpy as np
import torch
from torch import nn
import tqdm
import argparse
import json
from data_builder import load_data
from model import load_tokenizer, load_model
from nuisance_func import BSplineTwoSample
from utils import load_training_data, separated_string
def sample_stat_value(logits_ref, logits_score, labels, w_func):
    """Return the running mean of the per-position variance of the weighted log-probs.

    For every position, the variance of ``w_func(log_softmax(logits_score))``
    is taken under the distribution ``softmax(logits_ref)``; the flattened
    per-position variances are then turned into a cumulative average.

    Args:
        logits_ref: Reference-model logits with a leading dimension of 1
            (presumably ``(1, seq_len, vocab)`` -- TODO confirm with caller).
        logits_score: Scoring-model logits with a leading dimension of 1.
        labels: Only checked for a leading dimension of 1; otherwise unused.
        w_func: Callable applied element-wise to the scoring log-probabilities.

    Returns:
        1-D tensor whose ``k``-th entry is the mean of the first ``k + 1``
        per-position variances.
    """
    for batched in (logits_ref, logits_score, labels):
        assert batched.shape[0] == 1

    ref_vocab = logits_ref.size(-1)
    score_vocab = logits_score.size(-1)
    if ref_vocab != score_vocab:
        # The two models disagree on vocabulary size: keep only the leading
        # entries that both of them have.
        shared_vocab = min(ref_vocab, score_vocab)
        logits_ref = logits_ref[:, :, :shared_vocab]
        logits_score = logits_score[:, :, :shared_vocab]

    weighted_lprobs = w_func(torch.log_softmax(logits_score, dim=-1))
    ref_probs = torch.softmax(logits_ref, dim=-1)

    # Var[X] = E[X^2] - E[X]^2, with the expectation taken over ref_probs.
    first_moment = (ref_probs * weighted_lprobs).sum(dim=-1)
    second_moment = (ref_probs * torch.square(weighted_lprobs)).sum(dim=-1)
    per_position_var = (second_moment - torch.square(first_moment)).flatten()

    counts = torch.arange(
        1,
        len(per_position_var) + 1,
        dtype=torch.float32,
        device=per_position_var.device,
    )
    return torch.cumsum(per_position_var, dim=0) / counts
def human_stat_value(logits_ref, logits_score, labels, w_func):
    """Return the running mean of an empirical prefix-variance gap for observed tokens.

    For each prefix length ``j`` the statistic is the (biased) variance of the
    weighted log-likelihoods of the observed tokens over the first ``j``
    positions, minus the (biased) variance of the reference-distribution
    means of the weighted log-probs over the same positions. The entry for
    the length-1 prefix is pinned to zero. The per-prefix values are then
    turned into a cumulative average.

    Args:
        logits_ref: Reference-model logits with a leading dimension of 1
            (presumably ``(1, seq_len, vocab)`` -- TODO confirm with caller).
        logits_score: Scoring-model logits with a leading dimension of 1.
        labels: Observed token ids with a leading dimension of 1; a trailing
            axis is added when it has one dimension fewer than the logits.
        w_func: Callable applied element-wise to the scoring log-probabilities.

    Returns:
        1-D float tensor with one entry per scored position. An input with
        zero positions yields an empty tensor instead of raising.
    """
    assert logits_ref.shape[0] == 1
    assert logits_score.shape[0] == 1
    assert labels.shape[0] == 1
    if logits_ref.size(-1) != logits_score.size(-1):
        # The two models disagree on vocabulary size: keep only the leading
        # entries that both of them have.
        vocab_size = min(logits_ref.size(-1), logits_score.size(-1))
        logits_ref = logits_ref[:, :, :vocab_size]
        logits_score = logits_score[:, :, :vocab_size]

    # gather() needs an index with the same rank as the log-prob tensor.
    labels = labels.unsqueeze(-1) if labels.ndim == logits_score.ndim - 1 else labels
    lprobs_score = w_func(torch.log_softmax(logits_score, dim=-1))
    probs_ref = torch.softmax(logits_ref, dim=-1)

    log_likelihood = lprobs_score.gather(dim=-1, index=labels).squeeze(-1).flatten()
    mean_ref = (probs_ref * lprobs_score).sum(dim=-1).flatten()

    n_positions = log_likelihood.shape[0]
    # Zero-initialised, so the length-1 prefix entry stays exactly 0 (even if
    # the first value is non-finite) and an empty sequence needs no special
    # case: indexing stat[0] unconditionally would raise IndexError there.
    stat = torch.zeros(n_positions, device=log_likelihood.device)
    for j in range(2, n_positions + 1):
        prefix_ll_var = torch.var(log_likelihood[:j], unbiased=False)
        prefix_mean_var = torch.var(mean_ref[:j], unbiased=False)
        stat[j - 1] = prefix_ll_var - prefix_mean_var

    counts = torch.arange(1, len(stat) + 1, dtype=torch.float32, device=stat.device)
    return torch.cumsum(stat, dim=0) / counts
def compute_sample_variance(cumsum_stat_list, L):
    """Return the running mean of the cross-sequence variance at each position.

    For each position ``0 .. L-2`` the biased variance is taken across all
    sequences in ``cumsum_stat_list`` that are long enough to have that
    position; the resulting per-position variances are then averaged
    cumulatively.

    Args:
        cumsum_stat_list: Collection of indexable sequences of statistics,
            possibly of different lengths.
        L: Horizon; ``L - 1`` positions are evaluated.

    Returns:
        1-D float tensor of length ``L - 1``.
    """
    horizon = L - 1
    per_position_var = torch.zeros(horizon)
    for pos in range(horizon):
        # Only sequences that actually reach this position contribute.
        column = [seq[pos] for seq in cumsum_stat_list if len(seq) > pos]
        per_position_var[pos] = torch.tensor(column).var(unbiased=False)

    divisors = torch.arange(1, L, dtype=torch.float32)
    return per_position_var.cumsum(dim=0) / divisors
def experiment(args):
    """Compute per-position mean/variance of the human-vs-sample statistic ratio and dump them to JSON.

    For every text in the dataset the scoring (and, if different, sampling)
    model is run once; ``human_stat_value`` and ``sample_stat_value`` are
    evaluated on the resulting logits, and their element-wise ratio is
    collected. Across texts, the mean and biased variance of the ratio are
    computed per position and written to
    ``{args.output_file}.{args.w_func}.json``.

    Args:
        args: Namespace providing ``scoring_model_name``,
            ``sampling_model_name``, ``cache_dir``, ``device``,
            ``dataset_file``, ``w_func``, ``config``, ``train_dataset``,
            ``seed``, ``compute_text`` and ``output_file``.

    Raises:
        ValueError: If ``args.w_func`` is not ``'identity'`` or ``'bspline'``
            (no weight function can be built for any other value).
    """
    # Fail fast, before any model is loaded: only these two weight functions
    # are ever constructed below, so any other name would leave `w_func`
    # unbound and crash much later.
    supported_w_funcs = ('identity', 'bspline')
    if args.w_func not in supported_w_funcs:
        raise ValueError(
            f"Unsupported w_func {args.w_func!r}; expected one of {supported_w_funcs}."
        )

    # load model
    scoring_tokenizer = load_tokenizer(args.scoring_model_name, args.cache_dir)
    scoring_model = load_model(args.scoring_model_name, args.device, args.cache_dir)
    scoring_model.eval()
    same_model = args.sampling_model_name == args.scoring_model_name
    if not same_model:
        sampling_tokenizer = load_tokenizer(args.sampling_model_name, args.cache_dir)
        sampling_model = load_model(args.sampling_model_name, args.device, args.cache_dir)
        sampling_model.eval()

    # load data
    data = load_data(args.dataset_file)
    n_samples = len(data["sampled"])

    if args.w_func == 'identity':
        w_func = nn.Identity()
    else:
        # 'bspline': fit the weight function on the training texts first.
        bspline_args = args.config
        print(f"Datasets for learning BSpline: {args.train_dataset}")
        train_data = load_training_data(args.train_dataset)
        human_token_list = [
            scoring_tokenizer(x, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device)
            for x in train_data['original']
        ]
        machine_token_list = [
            scoring_tokenizer(x, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device)
            for x in train_data['sampled']
        ]
        w_func = BSplineTwoSample(bspline_args, args.device)
        w_func.fit(human_token_list, machine_token_list, scoring_model, args)

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    original_crit_list = []
    sampled_crit_list = []
    text_key = "original" if args.compute_text == 'human' else "sampled"
    for idx in tqdm.tqdm(range(n_samples), desc="Computing sample variance"):
        computed_text = data[text_key][idx]

        # Both statistics condition on the same text, so the tokenisation and
        # the forward passes are done once and shared between them.
        tokenized = scoring_tokenizer(computed_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device)
        labels = tokenized.input_ids[:, 1:]
        with torch.no_grad():
            # Drop the last position: it has no next-token label.
            logits_score = scoring_model(**tokenized).logits[:, :-1]
            if same_model:
                logits_ref = logits_score
            else:
                tokenized = sampling_tokenizer(computed_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device)
                assert torch.all(tokenized.input_ids[:, 1:] == labels), "Tokenizer is mismatch."
                logits_ref = sampling_model(**tokenized).logits[:, :-1]
            # when next token is drawn from human
            original_crit = human_stat_value(logits_ref, logits_score, labels, w_func)
            # when next token is drawn from machine
            sampled_crit = sample_stat_value(logits_ref, logits_score, labels, w_func)
        original_crit_list.append(original_crit)
        sampled_crit_list.append(sampled_crit)

    results_file = f'{args.output_file}.{args.w_func}.json'

    # Element-wise ratio of the two statistics, truncated to the common length.
    ratio_list = []
    for original_crit, sampled_crit in zip(original_crit_list, sampled_crit_list):
        common_len = min(len(original_crit), len(sampled_crit))
        ratio_list.append(original_crit[:common_len] / sampled_crit[:common_len])

    # NOTE(review): the horizon is the 0.81 quantile of the ratio lengths; the
    # rationale for 0.81 is not visible in this file.
    L = int(torch.quantile(torch.tensor([len(ratio) * 1.0 for ratio in ratio_list]), q=0.81).item())
    n_positions = max(L - 1, 0)  # guard against a zero horizon
    ratio_var_list = torch.zeros(n_positions)
    ratio_mean_list = torch.zeros(n_positions)
    for pos in range(n_positions):
        # Only texts long enough to reach this position contribute.
        ratio_at_pos = torch.tensor([ratio[pos] for ratio in ratio_list if len(ratio) > pos])
        ratio_var_list[pos] = torch.var(ratio_at_pos, unbiased=False)
        ratio_mean_list[pos] = torch.mean(ratio_at_pos)

    # Position 0 is dropped: human_stat_value pins its first entry to zero, so
    # the first ratio carries no information.
    ratio_mean_list = ratio_mean_list[1:]
    ratio_var_list = ratio_var_list[1:]
    results = {
        'ratio_var': ratio_var_list.tolist(),
        'ratio_mean': ratio_mean_list.tolist(),
        'ratio_var/mean': (ratio_var_list / ratio_mean_list).tolist(),
    }
    with open(results_file, 'w') as fout:
        json.dump(results, fout)
    print(f'Results written into {results_file}')
if __name__ == '__main__':
    # Command-line entry point: declare the options in a table (order is the
    # order shown in --help), parse them, and hand the namespace to experiment().
    option_table = [
        # output file prefix
        ('--output_file', {'type': str, 'default': "./exp_variance/results/xsum_gpt2-xl"}),
        # NOTE(review): --dataset is parsed but experiment() never reads it.
        ('--dataset', {'type': str, 'default': "xsum"}),
        ('--dataset_file', {'type': str, 'default': "./exp_variance/data/xsum_gpt2-xl"}),
        ('--train_dataset', {'type': separated_string, 'default': []}),
        ('--sampling_model_name', {'type': str, 'default': "gpt2-xl"}),
        ('--scoring_model_name', {'type': str, 'default': "gpt2-xl"}),
        # NOTE(review): 'absoluate' is accepted here, but experiment() only
        # builds a weight function for 'identity' and 'bspline'.
        ('--w_func', {'type': str, 'default': 'identity', 'choices': ['identity', 'absoluate', 'bspline']}),
        # The string default is run through json.loads by argparse as well.
        ('--config', {
            'type': json.loads,
            'default': '{"start": -32, "end": 0, "n_bases": 7, "spline_order": 2, "intercept": 1}',
            'help': 'A JSON dict',
        }),
        ('--seed', {'type': int, 'default': 0}),
        ('--device', {'type': str, 'default': "cuda"}),
        ('--compute_text', {'type': str, 'default': 'llm', 'choices': ['human', 'llm']}),
        ('--cache_dir', {'type': str, 'default': "../cache"}),
    ]
    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    experiment(cli.parse_args())