|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForMaskedLM |
|
|
import math |
|
|
|
|
|
MODEL_NAME = "roberta-base"

print(f"Loading {MODEL_NAME}...")

# Tokenizer and masked-LM head must come from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)

# Prefer a GPU when one is visible; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model.to(device)
model.eval()  # inference only: disables dropout etc.
|
|
|
|
|
def calculate_perplexity(text):
    """Compute the pseudo-perplexity (PPPL) of *text* under the masked LM.

    Each non-special token (positions 1 .. seq_len-2, skipping the BOS/EOS
    markers the tokenizer adds) is replaced by the mask token in turn; the
    model's negative log-likelihood of the true token at that position is
    collected, and the result is exp(mean NLL).  Lower values mean the text
    is more "expected" by the model.

    Args:
        text: raw input string; tokenized with truncation to 512 tokens.

    Returns:
        float: pseudo-perplexity.  Returns 100.0 for degenerate (too-short)
        input and 0.0 when there are no maskable tokens — sentinels kept
        for backward compatibility with existing callers.
    """
    encodings = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = encodings.input_ids.to(device)
    seq_len = input_ids.shape[1]

    # Degenerate input: nothing to score.
    if seq_len < 2:
        return 100.0

    nlls = []
    BATCH_SIZE = 8

    # One template row per batch slot; each row will get a different
    # position masked.
    tensor_input_ids = input_ids.repeat(BATCH_SIZE, 1)
    start_idx = 1          # skip BOS
    end_idx = seq_len - 1  # skip EOS
    # reduction='none' keeps per-position losses; ignore_index defaults to
    # -100, so unmasked positions (labeled -100) contribute exactly 0.
    loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

    print(f"Processing {seq_len} tokens with Batch Size {BATCH_SIZE}...")

    for i in range(start_idx, end_idx, BATCH_SIZE):
        current_batch_size = min(BATCH_SIZE, end_idx - i)
        batch_input_ids = tensor_input_ids[:current_batch_size].clone()
        # Explicit dtype: CrossEntropyLoss requires integer (long) targets;
        # relying on torch.full's dtype inference is fragile across torch
        # versions.  Allocating directly on `device` avoids a CPU round-trip.
        batch_labels = torch.full(
            batch_input_ids.shape, -100, dtype=torch.long, device=device
        )

        # Vectorized masking: row j masks token position i + j.
        rows = torch.arange(current_batch_size, device=device)
        positions = rows + i
        batch_labels[rows, positions] = batch_input_ids[rows, positions]
        batch_input_ids[rows, positions] = tokenizer.mask_token_id

        with torch.no_grad():
            outputs = model(batch_input_ids)
        predictions = outputs.logits

        # (batch, seq, vocab) -> (batch, vocab, seq), the layout
        # CrossEntropyLoss expects for sequence targets.
        predictions = predictions.permute(0, 2, 1)
        loss = loss_fct(predictions, batch_labels)
        # Only the single masked position per row carries a real label; every
        # other position is -100 and contributes 0, so the row sum is exactly
        # that token's NLL.
        masked_losses = loss.sum(dim=1)
        nlls.append(masked_losses)

    if not nlls:
        return 0.0
    mean_nll = torch.cat(nlls).mean()
    ppl = torch.exp(mean_nll)
    return ppl.item()
|
|
|
|
|
|
|
|
|
|
|
human_text = "The specific nuance of that joke totally flew over my head, causing a bit of an awkward silence at the dinner table that lasted for what felt like an eternity."

ai_text = "Artificial Intelligence is a branch of computer science that involves the development of systems capable of performing tasks deemed intelligent."

# Score both samples with identical reporting; dict insertion order keeps
# the human sample first, matching the original output exactly.
_samples = {
    "Human": human_text,
    "AI": ai_text,
}
_results = {}
for _label, _sample in _samples.items():
    print(f"\n--- Testing {_label} Text ---")
    _results[_label] = calculate_perplexity(_sample)
    print(f"{_label} PPL: {_results[_label]:.2f}")

# Preserve the individually named results for any downstream use.
ppl_human = _results["Human"]
ppl_ai = _results["AI"]
|
|
|