# Source: Compression-Lens / core/evaluator.py
# Commit: 15b2f1f ("update") by Jellyfish042
"""
Evaluator module for UncheatableEval visualization.
Provides single-sample evaluation functions for Qwen3 and RWKV7 models.
"""
import gc
import math
import os
import time
from typing import List, Dict, Any, Optional
import torch
import torch.nn.functional as F
from .helpers import TokenizerBytesConverter
# Converts an average per-token loss in nats to a compression-rate percentage:
# nats -> bits (1/ln 2), bits -> bytes (x 0.125), then x 100 for percent.
COMPRESSION_RATE_FACTOR = (1.0 / math.log(2.0)) * 0.125 * 100.0
def get_device():
    """Return the preferred torch device string: "cuda" if available, else "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"
def calculate_log_sum(logits: torch.Tensor, target_token_ids: torch.Tensor) -> torch.Tensor:
    """Per-position cross-entropy between next-token logits and shifted targets.

    Drops the final logit row and the first target so that position ``i``
    is scored on predicting token ``i + 1``. Logits are cast to bfloat16 on
    CUDA and to float32 everywhere else (CPU compatibility).
    """
    compute_dtype = torch.bfloat16 if logits.device.type == "cuda" else torch.float32
    shifted_logits = logits[:-1].to(compute_dtype)
    shifted_targets = target_token_ids[1:]
    return F.cross_entropy(shifted_logits, shifted_targets, reduction="none")
def extract_topk_predictions(logit: torch.Tensor, target_ids: torch.Tensor, k: int = 10) -> List:
    """
    Extract top-k predictions from logits.

    Args:
        logit: [seq_length, vocab_size] logit tensor
        target_ids: [seq_length] actual target token IDs
        k: number of top predictions to extract (default: 10)
    Returns:
        list: [[actual_id, rank, [[id1, prob1], [id2, prob2], ...]], ...]
    """
    probs = F.softmax(logit, dim=-1)
    top_probs, top_ids = torch.topk(probs, k, dim=-1)
    predictions = []
    for pos, target in enumerate(target_ids.tolist()):
        row = probs[pos]
        target_prob = row[target].item()
        # Rank = 1 + count of tokens strictly more probable than the target.
        rank = (row > target_prob).sum().item() + 1
        candidates = [
            [tok_id, round(tok_prob, 6)]
            for tok_id, tok_prob in zip(top_ids[pos].tolist(), top_probs[pos].tolist())
        ]
        predictions.append([target, rank, candidates])
    return predictions
def count_model_parameters_in_billions(model) -> float:
    """Return the total element count of all parameters in ``model``, in billions."""
    n_params = 0
    for param in model.parameters():
        n_params += param.numel()
    return n_params / 1e9
def count_rwkv_parameters_in_billions(rwkv_model) -> float:
    """Sum parameter counts from an RWKV model's ``z``/``w`` weight dicts, in billions.

    Either dict may be absent; missing attributes contribute zero.
    """
    total_params = sum(
        param.numel()
        for attr in ("z", "w")
        if hasattr(rwkv_model, attr)
        for param in getattr(rwkv_model, attr).values()
    )
    return total_params / 1e9
@torch.no_grad()
def evaluate_hf_single_sample(model, tokenizer, text: str, bos_mode: str = "add_newline_token") -> Dict[str, Any]:
    """
    Evaluate a HuggingFace model on a single text sample.

    Computes per-token cross-entropy of ``text`` under ``model``, spreads each
    token's loss uniformly over that token's UTF-8 bytes, and converts the
    average per-token loss into a compression-rate percentage.

    Args:
        model: HuggingFace causal LM (uses ``.device``, ``.config`` and a
            ``forward`` that returns ``.logits``)
        tokenizer: HuggingFace tokenizer matching ``model``
        text: Input text to evaluate
        bos_mode: How to handle the BOS position. ``add_*`` modes prepend a
            token; ``replace_with_*`` modes overwrite the first input token
            in place. NOTE(review): every mode currently uses the newline
            token — the per-mode selection is commented out below, so
            ``add_default_bos``/``add_default_eos`` do NOT use
            ``tokenizer.bos_token_id``/``eos_token_id``. Confirm intended.
    Returns:
        dict with byte_wise_losses (one loss per UTF-8 byte of ``text``),
        top5_predictions (actually top-10 — see extract_topk_predictions'
        default k), compression_rate (percent), total_loss (nats),
        num_tokens, num_bytes, model_name, tokenizer, inference_time.
    Raises:
        ValueError: if ``text`` is under 2 tokens, or if the tokenizer's
            byte reconstruction does not round-trip the original UTF-8 bytes.
    """
    start_time = time.time()
    # Create token-to-bytes converter
    token2bytes_converter = TokenizerBytesConverter(model_name_or_path=tokenizer.name_or_path, tokenizer=tokenizer)
    # Determine BOS token — always the first token of "\n" (see NOTE in docstring)
    bos_token = tokenizer.encode("\n")[0]
    # if bos_mode in ["add_default_bos", "replace_with_bos"]:
    #     bos_token = tokenizer.bos_token_id
    # elif bos_mode in ["add_default_eos", "replace_with_eos"]:
    #     bos_token = tokenizer.eos_token_id
    # elif bos_mode in ["add_newline_token", "replace_with_newline_token"]:
    #     bos_token = tokenizer.encode("\n")[0]
    # else:
    #     bos_token = tokenizer.bos_token_id
    bos_tensor = torch.tensor([bos_token], device=model.device).unsqueeze(0)
    # Tokenize input (no special tokens; BOS handling is done manually below)
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    inputs = inputs.to(model.device)
    seq_length = inputs["input_ids"].shape[-1]
    if seq_length < 2:
        raise ValueError(f"Text is too short (only {seq_length} tokens)")
    # Forward pass
    input_chunk = inputs["input_ids"]
    if bos_mode in ["add_default_bos", "add_default_eos", "add_newline_token"]:
        input_chunk = torch.cat([bos_tensor, input_chunk], dim=-1)
    if bos_mode in ["replace_with_bos", "replace_with_eos", "replace_with_newline_token"]:
        input_chunk[0, 0] = bos_token  # in-place overwrite of the first token
    logit = model.forward(input_ids=input_chunk).logits[0, :, :]
    # Per-token cross-entropy: loss[i] scores the prediction of input_chunk[i+1]
    loss = calculate_log_sum(logit, input_chunk.squeeze(0))
    # Get per-token bytes (list of byte sequences, one per token of `text`)
    per_token_bytes = token2bytes_converter.encode_to_bytes(text)
    # Verify the per-token bytes concatenate back to the original UTF-8 bytes
    all_bytes = [byte for token in per_token_bytes for byte in token]
    expected_bytes = list(text.encode("utf-8"))
    if all_bytes != expected_bytes:
        raise ValueError("Token bytes don't match original text bytes")
    # Extract top-k predictions (default k=10, despite the "top5" result key)
    sample_topk = extract_topk_predictions(logit[:-1], input_chunk.squeeze(0)[1:])
    # Calculate byte-wise losses: each token's loss is spread uniformly over
    # its bytes; a zero-byte token carries its loss forward to the next token.
    byte_wise_losses = []
    pending_loss = 0.0
    for l, byte_values in zip(loss, per_token_bytes):
        current_loss = l.item() + pending_loss
        pending_loss = 0.0
        if len(byte_values) == 0:
            pending_loss = current_loss
            continue
        per_byte_loss = current_loss / len(byte_values)
        for _ in range(len(byte_values)):
            byte_wise_losses.append(per_byte_loss)
    # Calculate overall metrics
    total_loss = loss.sum().item()
    num_bytes = len(text.encode("utf-8"))
    # NOTE(review): in "replace_with_*" modes `loss` has seq_length-1 entries,
    # so dividing by seq_length slightly understates the mean — confirm intended.
    avg_loss = total_loss / seq_length
    compression_rate = avg_loss * COMPRESSION_RATE_FACTOR
    inference_time = time.time() - start_time
    return {
        "byte_wise_losses": byte_wise_losses,
        "top5_predictions": sample_topk,
        "compression_rate": compression_rate,
        "total_loss": total_loss,
        "num_tokens": seq_length,
        "num_bytes": num_bytes,
        "model_name": getattr(model.config, "_name_or_path", "unknown"),
        "tokenizer": tokenizer,
        "inference_time": inference_time,
    }
@torch.no_grad()
def evaluate_rwkv7_single_sample(model, tokenizer, text: str) -> Dict[str, Any]:
    """
    Evaluate a RWKV7 model on a single text sample.

    Runs the model over the token sequence in fixed-size chunks, threading
    the recurrent state through, then spreads each token's cross-entropy
    loss uniformly over that token's bytes.

    Args:
        model: RWKV7 model exposing ``forward(tokens, state, full_output=True)``
            returning ``(logits, state)``
        tokenizer: RWKV tokenizer (TRIE_TOKENIZER) with ``encode`` and
            ``decodeBytes``
        text: Input text to evaluate
    Returns:
        dict with byte_wise_losses, top5_predictions (top-10; see
        extract_topk_predictions' default k), compression_rate (percent),
        total_loss (nats), num_tokens, num_bytes, model_name, tokenizer,
        inference_time.
    Raises:
        ValueError: if ``text`` tokenizes to fewer than 2 tokens
    """
    start_time = time.time()
    # Tokenize — some tokenizers return an object exposing .ids, others a plain list
    tokenized = tokenizer.encode(text)
    input_seq = tokenized.ids if hasattr(tokenized, "ids") else tokenized
    input_length = len(input_seq)
    if input_length < 2:
        raise ValueError(f"Text is too short (only {input_length} tokens)")
    # Forward pass with state; token 0 acts as BOS
    input_chunk = [0] + input_seq
    device = get_device()
    CHUNK_LEN = 1024
    state = None
    # Collect per-chunk logits and concatenate once at the end: avoids the
    # O(n^2) cost of torch.concat inside the loop and does not hard-code the
    # vocabulary size or dtype.
    chunk_logits = []
    remaining = input_chunk.copy()
    while len(remaining) > 0:
        out, state = model.forward(remaining[:CHUNK_LEN], state, full_output=True)
        if len(remaining) == 1:
            # A single-token chunk comes back without the sequence dimension.
            out = out.unsqueeze(0)
        remaining = remaining[CHUNK_LEN:]
        chunk_logits.append(out.to(device))
    logit = torch.cat(chunk_logits, dim=0)
    # loss[i] scores the prediction of input_chunk[i+1]
    loss = calculate_log_sum(logit, torch.tensor(input_chunk).to(device))
    # Per-token byte representations (BOS excluded, aligning with `loss`)
    token_bytes = [tokenizer.decodeBytes([token]) for token in input_chunk[1:]]
    # Extract top-k predictions (default k=10, despite the "top5" result key)
    sample_topk = extract_topk_predictions(logit[:-1], torch.tensor(input_chunk[1:]).to(device))
    # Spread each token's loss uniformly over its bytes. A token that decodes
    # to zero bytes carries its loss forward to the next token (mirrors the
    # HF evaluator) instead of dividing by zero.
    byte_wise_losses = []
    pending_loss = 0.0
    for token_loss, byte_values in zip(loss.tolist(), token_bytes):
        current_loss = token_loss + pending_loss
        pending_loss = 0.0
        if len(byte_values) == 0:
            pending_loss = current_loss
            continue
        per_byte_loss = current_loss / len(byte_values)
        byte_wise_losses.extend([per_byte_loss] * len(byte_values))
    # Overall metrics: average per-token loss in nats -> compression-rate percent
    total_loss = loss.sum().item()
    num_bytes = len(text.encode("utf-8"))
    avg_loss = total_loss / input_length
    compression_rate = avg_loss * COMPRESSION_RATE_FACTOR
    inference_time = time.time() - start_time
    return {
        "byte_wise_losses": byte_wise_losses,
        "top5_predictions": sample_topk,
        "compression_rate": compression_rate,
        "total_loss": total_loss,
        "num_tokens": input_length,
        "num_bytes": num_bytes,
        "model_name": "RWKV7-G1C-1.5B",
        "tokenizer": tokenizer,
        "inference_time": inference_time,
    }