import argparse
import json
import random

import numpy as np
import torch
import tqdm
from torch import nn

from data_builder import load_data
from model import load_tokenizer, load_model
from nuisance_func import BSplineTwoSample
from utils import load_training_data, separated_string


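# The two statistics below are computed per text from next-token logits:
#   - sample_stat_value: the model-based variance of w(log p_score) under the
#     reference distribution p_ref at each token, averaged over growing
#     prefixes;
#   - human_stat_value: the running difference between the empirical variance
#     of the observed token log-likelihoods and the variance of their
#     per-token conditional means, again averaged over growing prefixes.
# The experiment then studies the ratio of the two across texts.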
def sample_stat_value(logits_ref, logits_score, labels, w_func):
    assert logits_ref.shape[0] == 1
    assert logits_score.shape[0] == 1
    assert labels.shape[0] == 1
    if logits_ref.size(-1) != logits_score.size(-1):
        # align vocabularies when the two models use different tokenizers
        vocab_size = min(logits_ref.size(-1), logits_score.size(-1))
        logits_ref = logits_ref[:, :, :vocab_size]
        logits_score = logits_score[:, :, :vocab_size]

    lprobs_score = w_func(torch.log_softmax(logits_score, dim=-1))
    probs_ref = torch.softmax(logits_ref, dim=-1)
    # per-token variance of w(log p_score) under the reference distribution
    mean_ref = (probs_ref * lprobs_score).sum(dim=-1)
    var_ref = (probs_ref * torch.square(lprobs_score)).sum(dim=-1) - torch.square(mean_ref)
    stat = var_ref.flatten()
    # running average over prefixes of increasing length
    stat = torch.cumsum(stat, dim=0) / torch.arange(1, len(stat) + 1, dtype=torch.float32, device=stat.device)
    return stat


def human_stat_value(logits_ref, logits_score, labels, w_func):
    assert logits_ref.shape[0] == 1
    assert logits_score.shape[0] == 1
    assert labels.shape[0] == 1
    if logits_ref.size(-1) != logits_score.size(-1):
        # align vocabularies when the two models use different tokenizers
        vocab_size = min(logits_ref.size(-1), logits_score.size(-1))
        logits_ref = logits_ref[:, :, :vocab_size]
        logits_score = logits_score[:, :, :vocab_size]

    labels = labels.unsqueeze(-1) if labels.ndim == logits_score.ndim - 1 else labels
    lprobs_score = w_func(torch.log_softmax(logits_score, dim=-1))
    probs_ref = torch.softmax(logits_ref, dim=-1)
    log_likelihood = lprobs_score.gather(dim=-1, index=labels).squeeze(-1)
    mean_ref = (probs_ref * lprobs_score).sum(dim=-1)

    mean_ref = mean_ref.flatten()
    log_likelihood = log_likelihood.flatten()
    # running difference between the empirical variance of the observed
    # log-likelihoods and the variance of the per-token conditional means
    stat = torch.zeros(log_likelihood.shape[0], device=log_likelihood.device)
    for j in range(1, log_likelihood.shape[0] + 1):
        term1 = torch.var(log_likelihood[:j], unbiased=False)
        term2 = torch.var(mean_ref[:j], unbiased=False)
        stat[j - 1] = term1 - term2
    stat[0] = 0.0  # the single-token variance difference is identically zero

    # running average over prefixes of increasing length
    stat = torch.cumsum(stat, dim=0) / torch.arange(1, len(stat) + 1, dtype=torch.float32, device=stat.device)
    return stat


def compute_sample_variance(cumsum_stat_list, L):
    # variance across texts of the cumulative statistic at each prefix
    # length 1..L-1, then averaged over growing prefixes
    sample_var_list = torch.zeros(L - 1)
    for l in range(L - 1):
        sample_var = torch.var(
            torch.tensor([cumsum_stat[l] for cumsum_stat in cumsum_stat_list if len(cumsum_stat) >= l + 1]),
            unbiased=False)
        sample_var_list[l] = sample_var
    sample_var_list = torch.cumsum(sample_var_list, dim=0) / torch.arange(1, L, dtype=torch.float32)
    return sample_var_list


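# Note: compute_sample_variance is not called in this script; it can be used
# to summarize a list of per-text cumulative statistics directly, e.g.
# (illustrative): sample_var = compute_sample_variance(sampled_crit_list, L).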
def experiment(args):
    # load models
    scoring_tokenizer = load_tokenizer(args.scoring_model_name, args.cache_dir)
    scoring_model = load_model(args.scoring_model_name, args.device, args.cache_dir)
    scoring_model.eval()
    if args.sampling_model_name != args.scoring_model_name:
        sampling_tokenizer = load_tokenizer(args.sampling_model_name, args.cache_dir)
        sampling_model = load_model(args.sampling_model_name, args.device, args.cache_dir)
        sampling_model.eval()
    # load data
    data = load_data(args.dataset_file)
    n_samples = len(data["sampled"])

    sample_criterion_fn = sample_stat_value
    human_criterion_fn = human_stat_value

    # build the transform w applied to the scoring log-probabilities
    if args.w_func == 'identity':
        w_func = nn.Identity()
    else:
        bspline_args = args.config
        print(f"Datasets for learning BSpline: {args.train_dataset}")
        train_data = load_training_data(args.train_dataset)
        human_token_list = [scoring_tokenizer(x, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device) for x in train_data['original']]
        if args.w_func in ('bspline', 'bspline_theory_constrained'):
            machine_token_list = [scoring_tokenizer(x, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device) for x in train_data['sampled']]
        if args.w_func == 'bspline':
            w_func = BSplineTwoSample(bspline_args, args.device)
            w_func.fit(human_token_list, machine_token_list, scoring_model, args)
        else:
            # only 'bspline' is constructed here; fail loudly instead of
            # hitting a NameError on an unbound w_func further down
            raise ValueError(f"Unsupported w_func: {args.w_func}")

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    original_crit_list = []
    sampled_crit_list = []

    for idx in tqdm.tqdm(range(n_samples), desc="Computing sample variance"):
        original_text = data["original"][idx]
        sampled_text = data["sampled"][idx]
        if args.compute_text == 'human':
            computed_text = original_text
        else:
            computed_text = sampled_text
        # tokenize and score the text once, then evaluate both statistics on
        # the same logits
        tokenized = scoring_tokenizer(computed_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device)
        labels = tokenized.input_ids[:, 1:]
        with torch.no_grad():
            logits_score = scoring_model(**tokenized).logits[:, :-1]
            if args.sampling_model_name == args.scoring_model_name:
                logits_ref = logits_score
            else:
                tokenized = sampling_tokenizer(computed_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(args.device)
                assert torch.all(tokenized.input_ids[:, 1:] == labels), "Tokenizer mismatch."
                logits_ref = sampling_model(**tokenized).logits[:, :-1]
            original_crit = human_criterion_fn(logits_ref, logits_score, labels, w_func)
            sampled_crit = sample_criterion_fn(logits_ref, logits_score, labels, w_func)
        original_crit_list.append(original_crit)
        sampled_crit_list.append(sampled_crit)

    results_file = f'{args.output_file}.{args.w_func}.json'
    ratio_list = []
    for idx in range(n_samples):
        L = min(len(original_crit_list[idx]), len(sampled_crit_list[idx]))
        ratio = original_crit_list[idx][:L] / sampled_crit_list[idx][:L]
        ratio_list.append(ratio)

    # cap the prefix length at the 0.81 quantile of the per-text lengths
    L = int(torch.quantile(torch.tensor([float(len(ratio)) for ratio in ratio_list]), q=0.81).item())

    ratio_var_list = torch.zeros(L - 1)
    ratio_mean_list = torch.zeros(L - 1)
    for l in range(L - 1):
        ratio_l = torch.tensor([ratio[l] for ratio in ratio_list if len(ratio) >= l + 1])
        ratio_var_list[l] = torch.var(ratio_l, unbiased=False)
        ratio_mean_list[l] = torch.mean(ratio_l)

    # drop the first position, where the human statistic is identically zero
    ratio_mean_list = ratio_mean_list[1:]
    ratio_var_list = ratio_var_list[1:]
    results = {
        'ratio_var': ratio_var_list.tolist(),
        'ratio_mean': ratio_mean_list.tolist(),
        'ratio_var/mean': (ratio_var_list / ratio_mean_list).tolist(),
    }

    with open(results_file, 'w') as fout:
        json.dump(results, fout)
    print(f'Results written into {results_file}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_file', type=str, default="./exp_variance/results/xsum_gpt2-xl")
    parser.add_argument('--dataset', type=str, default="xsum")
    parser.add_argument('--dataset_file', type=str, default="./exp_variance/data/xsum_gpt2-xl")
    parser.add_argument('--train_dataset', type=separated_string, default=[])
    parser.add_argument('--sampling_model_name', type=str, default="gpt2-xl")
    parser.add_argument('--scoring_model_name', type=str, default="gpt2-xl")
    parser.add_argument('--w_func', type=str, default='identity', choices=['identity', 'absolute', 'bspline'])
    parser.add_argument("--config", type=json.loads, default='{"start": -32, "end": 0, "n_bases": 7, "spline_order": 2, "intercept": 1}', help='A JSON dict')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--device', type=str, default="cuda")
    parser.add_argument('--compute_text', type=str, default='llm', choices=['human', 'llm'])
    parser.add_argument('--cache_dir', type=str, default="../cache")
    args = parser.parse_args()

    experiment(args)
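
# Example invocation (the script file name is illustrative; paths match the
# defaults above):
#   python exp_variance.py --dataset_file ./exp_variance/data/xsum_gpt2-xl \
#       --output_file ./exp_variance/results/xsum_gpt2-xl --w_func identity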