from typing import Optional

import torch
from torch import nn
from tqdm import tqdm

from BSpline import BSpline


def optimal_beta(a: torch.Tensor, B: torch.Tensor, c: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Maximize a @ beta subject to beta @ B @ beta = 1 and, when c is given,
    the additional constraint c @ beta = 0.

    A small ridge term is added to B before solving so the linear systems stay
    well conditioned when B is close to singular.
    """
    REGULARIZATION = 0.01

    alpha = REGULARIZATION * torch.eye(B.shape[0], device=B.device, dtype=B.dtype)
    u = torch.linalg.solve(B + alpha, a)
    if c is None:
        w = u
    else:
        # Project out the component along c so that c @ w = 0 exactly.
        v = torch.linalg.solve(B + alpha, c)
        mu = torch.dot(c, u) / torch.dot(c, v)
        w = u - mu * v

    # Rescale to unit B-norm: (w / norm) @ B @ (w / norm) = 1.
    norm = torch.sqrt(torch.dot(w, B @ w))
    return w / norm
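

# A minimal sanity check for optimal_beta (illustration only, not part of the
# pipeline). With a random symmetric positive-definite B, the returned vector
# should have unit B-norm and, when c is supplied, be orthogonal to c.
def _example_optimal_beta():
    torch.manual_seed(0)
    M = torch.randn(4, 4)
    B = M @ M.t() + torch.eye(4)  # symmetric positive definite
    a, c = torch.randn(4), torch.randn(4)
    beta = optimal_beta(a, B, c)
    assert torch.allclose(beta @ B @ beta, torch.tensor(1.0), atol=1e-5)
    assert torch.allclose(torch.dot(c, beta), torch.tensor(0.0), atol=1e-5)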
|
|
def get_bias_vector_mc(token_list, model, bspline: BSpline, args, compute_cov: bool = False, sample_size: int = 5000):
    """Monte Carlo estimate of the bias vector (and optionally the covariance
    matrix) of the B-spline basis statistics.

    The expected basis value under the model's next-token distribution is
    approximated at every position by averaging over `sample_size` sampled
    tokens. Returns bias (n_bases,) and optionally cov (n_bases, n_bases).
    """
    model.eval()
    device = args.device
    n_bases = bspline.n_bases
    if bspline.add_intercept:
        n_bases += 1
    bias_list = []
    cov_list = []

    for tokens in tqdm(token_list):
        input_ids = tokens.input_ids[0]
        with torch.no_grad():
            # Drop the final position: its next token is never observed.
            logits = model(**tokens).logits[0, :-1]
        seq_len, vocab_size = logits.shape

        probs = torch.softmax(logits, dim=-1)
        logp = torch.log_softmax(logits, dim=-1)

        # Basis values at the observed next tokens.
        labels = input_ids[1:]
        log_ll = logp[torch.arange(seq_len), labels]
        w_j = bspline.predict(log_ll.clamp_min(bspline.start).reshape(-1)).to(device)
        w_j = w_j.reshape(seq_len, n_bases)

        # Monte Carlo estimate of the expected basis value: sample tokens from
        # the model distribution at each position and average their basis values.
        sampled_indices = torch.multinomial(probs, num_samples=sample_size, replacement=True)
        sampled_logp = logp.gather(1, sampled_indices)
        flat_sampled_logp = sampled_logp.clamp_min(bspline.start).reshape(-1)
        flat_basis_samples = bspline.predict(flat_sampled_logp).to(device)
        basis_samples = flat_basis_samples.reshape(seq_len, sample_size, n_bases)
        mean_ref = basis_samples.mean(dim=1)

        # Per-sequence bias: observed minus expected, summed over positions.
        bias_sample = (w_j - mean_ref).sum(dim=0)
        bias_list.append(bias_sample)

        if compute_cov:
            cov_sample = torch.zeros(n_bases, n_bases, device=device)
            for t in range(seq_len):
                phi_samples = basis_samples[t]
                sample_mean = phi_samples.mean(dim=0)
                centered_samples = phi_samples - sample_mean
                # Unbiased sample covariance of the basis values at position t.
                cov_t = (centered_samples.t() @ centered_samples) / (sample_size - 1)
                if bspline.add_intercept:
                    # The intercept basis is constant, so its covariance rows and
                    # columns are zero by construction; zero them out to remove
                    # pure sampling noise.
                    cov_t[0, :] = 0.0
                    cov_t[:, 0] = 0.0
                cov_sample += cov_t
            cov_list.append(cov_sample)

    # Average the per-sequence statistics over the corpus.
    bias_vector = torch.stack(bias_list, dim=0).mean(dim=0)
    if compute_cov:
        cov_matrix = torch.stack(cov_list, dim=0).mean(dim=0)
        return bias_vector, cov_matrix
    return bias_vector
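

# For reference, the per-sequence statistic estimated above is
#
#     sum_t [ phi(log p_t(y_t)) - E_{x ~ p_t}[ phi(log p_t(x)) ] ],
#
# where phi is the (clamped) B-spline basis map, y_t the observed next token,
# and p_t the model's next-token distribution at position t; the expectation
# is replaced by an average over `sample_size` draws. get_bias_vector below
# computes the same expectation in closed form (with the opposite sign
# convention for the difference).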
|
|
def get_bias_vector(token_list, model, bspline: BSpline, args, compute_cov: bool = False, speedup_rate: int = 1):
    """For each tokenized text in token_list, compute the bias vector
    (difference between the expected basis and the observed basis)
    and, if requested, the covariance matrix of the basis values.

    speedup_rate > 1 restricts the expectation to the renormalized top
    vocab_size / speedup_rate tokens at each position.
    Returns bias (n_bases,) and optionally cov (n_bases, n_bases).
    """
    model.eval()
    device = args.device
    n_bases = bspline.n_bases
    if bspline.add_intercept:
        n_bases += 1
    bias_list = []
    cov_list = []

    for tokens in tqdm(token_list):
        input_ids = tokens.input_ids[0]
        with torch.no_grad():
            # Drop the final position: its next token is never observed.
            logits = model(**tokens).logits[0, :-1]
        seq_len, vocab_size = logits.shape

        # Optionally keep only the top-k probabilities and renormalize; with
        # speedup_rate = 1 this is the full, exact expectation.
        k = vocab_size // speedup_rate
        probs = torch.softmax(logits, dim=-1)
        probs, _ = torch.topk(probs, k=k, dim=-1)
        probs = probs / probs.sum(dim=-1, keepdim=True)

        logp = torch.log_softmax(logits, dim=-1)

        # Basis values at the observed next tokens.
        labels = input_ids[1:]
        log_ll = logp[torch.arange(seq_len), labels]
        w_j = bspline.predict(log_ll.clamp_min(bspline.start).reshape(-1)).to(device)
        w_j = w_j.reshape(seq_len, n_bases)

        # log_softmax is monotone in the logits, so the top-k log-probabilities
        # pick out the same tokens as the top-k probabilities above.
        logp, _ = torch.topk(logp, k=k, dim=-1)
        flat_logp = logp.clamp_min(bspline.start).reshape(-1)
        flat_basis = bspline.predict(flat_logp).to(device)
        basis = flat_basis.reshape(seq_len, k, n_bases)

        # Exact expectation of the basis under the (renormalized) distribution.
        mean_ref = (probs.unsqueeze(-1) * basis).sum(dim=1)

        # Per-sequence bias, summed over positions. Note that the sign
        # convention (expected minus observed) is the opposite of
        # get_bias_vector_mc.
        bias_sample = (mean_ref - w_j).sum(dim=0)
        bias_list.append(bias_sample)

        if compute_cov:
            cov_sample = torch.zeros(n_bases, n_bases, device=device)
            for t in range(seq_len):
                p_t = probs[t]
                phi_t = basis[t]
                # Exact covariance under p_t: E[phi phi^T] - E[phi] E[phi]^T.
                Ex_t = (p_t.unsqueeze(1) * phi_t).sum(dim=0)
                ExxT_t = phi_t.t() @ (p_t.unsqueeze(1) * phi_t)
                cov_t = ExxT_t - Ex_t.unsqueeze(1) @ Ex_t.unsqueeze(0)
                cov_sample += cov_t
            cov_list.append(cov_sample)

    # Average the per-sequence statistics over the corpus.
    bias_vector = torch.stack(bias_list, dim=0).mean(dim=0)
    if compute_cov:
        cov_matrix = torch.stack(cov_list, dim=0).mean(dim=0)
        return bias_vector, cov_matrix
    return bias_vector
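

# Hypothetical usage sketch (illustration only; the "gpt2" checkpoint and the
# sample texts are placeholders). token_list entries are assumed to be Hugging
# Face tokenizer outputs with batch size 1, and `args` any object exposing a
# `device` field. The BSpline instance is taken as a parameter because its
# constructor is defined elsewhere.
def _example_get_bias_vector(bspline: BSpline):
    from types import SimpleNamespace

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    args = SimpleNamespace(device="cpu")
    texts = ["An example sentence.", "Another example sentence."]
    token_list = [tokenizer(t, return_tensors="pt") for t in texts]
    bias, cov = get_bias_vector(token_list, model, bspline, args, compute_cov=True)
    return bias, cov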


class BSplineTheory(nn.Module):
    """Learns a weight function w(x) = phi(x) @ beta_hat over the B-spline
    basis phi, fit from human and (optionally) machine-generated texts."""

    def __init__(self, bspline_args, machine_text: bool = False):
        super().__init__()
        self.bspline = BSpline(**bspline_args)
        self.machine_text = machine_text
        self.beta_hat = None

    def fit(self, human_token_list, machine_token_list, model, args):
        device = args.device
        print("Learning w function...")

        print("Fetching bias and covariance for human texts...")
        bias_a, cov_B = get_bias_vector(human_token_list, model, self.bspline, args, compute_cov=True)

        if self.machine_text:
            print("Fetching bias for machine-generated texts...")
            bias_c = get_bias_vector(machine_token_list, model, self.bspline, args, compute_cov=False)
        else:
            bias_c = None

        print("Computing beta_hat...")
        self.beta_hat = optimal_beta(
            bias_a.to(device), cov_B.to(device),
            bias_c.to(device) if bias_c is not None else None,
        )
        print("beta_hat:", torch.round(self.beta_hat, decimals=3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the learned w function at per-token log-likelihoods x."""
        device = x.device
        flat = x.clamp_min(self.bspline.start).reshape(-1)
        basis = self.bspline.predict(flat).to(device)
        w_flat = basis @ self.beta_hat.to(device)
        return w_flat.reshape(x.shape)
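

# Hypothetical end-to-end sketch (illustration only). It assumes a Hugging Face
# causal LM; "gpt2" and the sample texts are placeholders, and `bspline_args`
# must match the real BSpline constructor, which is defined elsewhere.
def _example_bspline_theory(bspline_args: dict):
    from types import SimpleNamespace

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    args = SimpleNamespace(device="cpu")

    human = [tokenizer(t, return_tensors="pt") for t in ["A human-written line."]]
    machine = [tokenizer(t, return_tensors="pt") for t in ["A model-generated line."]]

    detector = BSplineTheory(bspline_args, machine_text=True)
    detector.fit(human, machine, model, args)

    # Score new per-token log-likelihoods with the learned w function.
    log_likelihoods = torch.tensor([-2.3, -0.7, -5.1])
    return detector(log_likelihoods)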
|
|