Spaces:
Running
Running
File size: 7,469 Bytes
3cdaafb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
from typing import Union
import os
import numpy as np
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import transformers
ce_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
softmax_fn = torch.nn.Softmax(dim=-1)
torch.set_grad_enabled(False)
huggingface_config = {
# Only required for private models from Huggingface (e.g. LLaMA models)
"TOKEN": os.environ.get("HF_TOKEN", None)
}
# selected using Falcon-7B and Falcon-7B-Instruct at bfloat16
BINOCULARS_ACCURACY_THRESHOLD = 0.9015310749276843 # optimized for f1-score
BINOCULARS_FPR_THRESHOLD = 0.8536432310785527 # optimized for low-fpr [chosen at 0.01%]
DEVICE_1 = "cuda:0" if torch.cuda.is_available() else "cpu"
DEVICE_2 = "cuda:1" if torch.cuda.device_count() > 1 else DEVICE_1
def assert_tokenizer_consistency(model_id_1, model_id_2):
identical_tokenizers = (
AutoTokenizer.from_pretrained(model_id_1).vocab
== AutoTokenizer.from_pretrained(model_id_2).vocab
)
if not identical_tokenizers:
raise ValueError(f"Tokenizers are not identical for {model_id_1} and {model_id_2}.")
def perplexity(encoding: transformers.BatchEncoding,
logits: torch.Tensor,
median: bool = False,
temperature: float = 1.0):
shifted_logits = logits[..., :-1, :].contiguous() / temperature
shifted_labels = encoding.input_ids[..., 1:].contiguous()
shifted_attention_mask = encoding.attention_mask[..., 1:].contiguous()
if median:
ce_nan = (ce_loss_fn(shifted_logits.transpose(1, 2), shifted_labels).
masked_fill(~shifted_attention_mask.bool(), float("nan")))
ppl = np.nanmedian(ce_nan.cpu().float().numpy(), 1)
else:
ppl = (ce_loss_fn(shifted_logits.transpose(1, 2), shifted_labels) *
shifted_attention_mask).sum(1) / shifted_attention_mask.sum(1)
ppl = ppl.to("cpu").float().numpy()
return ppl
def entropy(p_logits: torch.Tensor,
q_logits: torch.Tensor,
encoding: transformers.BatchEncoding,
pad_token_id: int,
median: bool = False,
sample_p: bool = False,
temperature: float = 1.0):
vocab_size = p_logits.shape[-1]
total_tokens_available = q_logits.shape[-2]
p_scores, q_scores = p_logits / temperature, q_logits / temperature
p_proba = softmax_fn(p_scores).view(-1, vocab_size)
if sample_p:
p_proba = torch.multinomial(p_proba.view(-1, vocab_size), replacement=True, num_samples=1).view(-1)
q_scores = q_scores.view(-1, vocab_size)
ce = ce_loss_fn(input=q_scores, target=p_proba).view(-1, total_tokens_available)
padding_mask = (encoding.input_ids != pad_token_id).type(torch.uint8)
if median:
ce_nan = ce.masked_fill(~padding_mask.bool(), float("nan"))
agg_ce = np.nanmedian(ce_nan.cpu().float().numpy(), 1)
else:
agg_ce = (((ce * padding_mask).sum(1) / padding_mask.sum(1)).to("cpu").float().numpy())
return agg_ce
class Binoculars(object):
def __init__(self,
observer_name_or_path: str = "tiiuae/falcon-7b",
performer_name_or_path: str = "tiiuae/falcon-7b-instruct",
use_bfloat16: bool = True,
max_token_observed: int = 512,
mode: str = "low-fpr",
) -> None:
assert_tokenizer_consistency(observer_name_or_path, performer_name_or_path)
self.change_mode(mode)
self.observer_model = AutoModelForCausalLM.from_pretrained(observer_name_or_path,
device_map={"": DEVICE_1},
trust_remote_code=True,
torch_dtype=torch.bfloat16 if use_bfloat16
else torch.float32,
token=huggingface_config["TOKEN"]
)
self.performer_model = AutoModelForCausalLM.from_pretrained(performer_name_or_path,
device_map={"": DEVICE_2},
trust_remote_code=True,
torch_dtype=torch.bfloat16 if use_bfloat16
else torch.float32,
token=huggingface_config["TOKEN"]
)
self.observer_model.eval()
self.performer_model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(observer_name_or_path)
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.max_token_observed = max_token_observed
def change_mode(self, mode: str) -> None:
if mode == "low-fpr":
self.threshold = BINOCULARS_FPR_THRESHOLD
elif mode == "accuracy":
self.threshold = BINOCULARS_ACCURACY_THRESHOLD
else:
raise ValueError(f"Invalid mode: {mode}")
def _tokenize(self, batch: list[str]) -> transformers.BatchEncoding:
batch_size = len(batch)
encodings = self.tokenizer(
batch,
return_tensors="pt",
padding="longest" if batch_size > 1 else False,
truncation=True,
max_length=self.max_token_observed,
return_token_type_ids=False).to(self.observer_model.device)
return encodings
@torch.inference_mode()
def _get_logits(self, encodings: transformers.BatchEncoding) -> torch.Tensor:
observer_logits = self.observer_model(**encodings.to(DEVICE_1)).logits
performer_logits = self.performer_model(**encodings.to(DEVICE_2)).logits
if DEVICE_1 != "cpu":
torch.cuda.synchronize()
return observer_logits, performer_logits
def compute_score(self, input_text: Union[list[str], str]) -> Union[float, list[float]]:
batch = [input_text] if isinstance(input_text, str) else input_text
encodings = self._tokenize(batch)
observer_logits, performer_logits = self._get_logits(encodings)
ppl = perplexity(encodings, performer_logits)
x_ppl = entropy(observer_logits.to(DEVICE_1), performer_logits.to(DEVICE_1),
encodings.to(DEVICE_1), self.tokenizer.pad_token_id)
binoculars_scores = ppl / x_ppl
binoculars_scores = binoculars_scores.tolist()
return binoculars_scores[0] if isinstance(input_text, str) else binoculars_scores
def predict(self, input_text: Union[list[str], str]) -> Union[list[str], str]:
binoculars_scores = np.array(self.compute_score(input_text))
pred = np.where(binoculars_scores < self.threshold,
"Most likely AI-generated",
"Most likely human-generated"
).tolist()
return pred |