Spaces:
Restarting
Restarting
File size: 8,549 Bytes
9afeeeb 15b2f1f 9afeeeb 44c2c6d 9afeeeb a2836e3 44c2c6d 9afeeeb a2836e3 9afeeeb 15b2f1f 9afeeeb a2836e3 9afeeeb a2836e3 9afeeeb a2836e3 9afeeeb 15b2f1f 9afeeeb a2836e3 15b2f1f 9afeeeb a2836e3 9afeeeb 15b2f1f 9afeeeb a2836e3 9afeeeb 15b2f1f 9afeeeb a2836e3 15b2f1f 9afeeeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 |
"""
Evaluator module for UncheatableEval visualization.
Provides single-sample evaluation functions for Qwen3 and RWKV7 models.
"""
import gc
import math
import os
import time
from typing import List, Dict, Any, Optional
import torch
import torch.nn.functional as F
from .helpers import TokenizerBytesConverter
# Convert an average cross-entropy loss (nats per token) into the
# "compression rate" percentage reported by the evaluators below:
#   1 / ln(2)  -- nats -> bits
#   0.125      -- bits -> bytes
#   100.0      -- express the result as a percent
# NOTE(review): the loss this multiplies is averaged per *token*, not per
# byte -- confirm the intended units of the resulting percentage.
COMPRESSION_RATE_FACTOR = (1.0 / math.log(2.0)) * 0.125 * 100.0
def get_device():
    """Return the preferred torch device string: "cuda" if available, else "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"
def calculate_log_sum(logits: torch.Tensor, target_token_ids: torch.Tensor) -> torch.Tensor:
    """Per-token cross-entropy between shifted logits and targets.

    Position i of the result is the loss of predicting target_token_ids[i+1]
    from logits[i]; the final logit row is dropped. Logits are computed in
    bfloat16 on CUDA and float32 everywhere else (CPU compatibility).
    """
    compute_dtype = torch.bfloat16 if logits.device.type == "cuda" else torch.float32
    shifted_logits = logits[:-1].to(compute_dtype)
    shifted_targets = target_token_ids[1:]
    return F.cross_entropy(shifted_logits, shifted_targets, reduction="none")
def extract_topk_predictions(logit: torch.Tensor, target_ids: torch.Tensor, k: int = 10) -> List:
    """Collect the top-k candidate tokens for every sequence position.

    Args:
        logit: [seq_length, vocab_size] logit tensor.
        target_ids: [seq_length] actual target token IDs.
        k: number of top candidates to keep per position (default: 10).

    Returns:
        One entry per position:
        [actual_id, rank, actual_prob, [[id1, prob1], ..., [idk, probk]]],
        where rank is 1-based (tokens with strictly higher probability
        count toward the rank, so ties share the best rank).
    """
    probs = F.softmax(logit, dim=-1)
    top_probs, top_ids = torch.topk(probs, k, dim=-1)
    predictions = []
    for pos in range(logit.shape[0]):
        actual_id = target_ids[pos].item()
        actual_prob = probs[pos, actual_id].item()
        rank = (probs[pos] > actual_prob).sum().item() + 1
        candidates = [
            [token_id, round(prob, 6)]
            for token_id, prob in zip(top_ids[pos].tolist(), top_probs[pos].tolist())
        ]
        predictions.append([actual_id, rank, actual_prob, candidates])
    return predictions
def count_model_parameters_in_billions(model) -> float:
    """Return the total number of parameters in `model`, expressed in billions."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total / 1e9
def count_rwkv_parameters_in_billions(rwkv_model) -> float:
    """Return the RWKV model's parameter count in billions.

    RWKV implementations store weights in a dict attribute named either
    `z` or `w`; whichever of the two exists is summed.
    """
    total_params = 0
    for weight_attr in ("z", "w"):
        if hasattr(rwkv_model, weight_attr):
            weights = getattr(rwkv_model, weight_attr)
            total_params += sum(tensor.numel() for tensor in weights.values())
    return total_params / 1e9
@torch.no_grad()
def evaluate_hf_single_sample(model, tokenizer, text: str, bos_mode: str = "add_newline_token") -> Dict[str, Any]:
    """
    Evaluate a HuggingFace model on a single text sample.

    Runs one forward pass over the tokenized text (optionally prefixed with a
    BOS-like token), then derives per-byte losses, top-k next-token
    predictions, and an overall compression-rate metric.

    Args:
        model: HuggingFace causal LM (must expose .device, .config, .forward).
        tokenizer: HuggingFace tokenizer matching the model.
        text: Input text to evaluate (must tokenize to at least 2 tokens).
        bos_mode: How to handle the BOS token. "add_*" modes prepend a token;
            "replace_with_*" modes overwrite the first input token in place.
            NOTE(review): regardless of mode, the token used is currently
            always the newline token (the per-mode dispatch is commented
            out below) -- confirm this is intentional.

    Returns:
        dict with byte_wise_losses, top5_predictions (actually top-10 -- see
        note at the return), compression_rate, total_loss, num_tokens,
        num_bytes, model_name, tokenizer, and inference_time.

    Raises:
        ValueError: if the text is under 2 tokens, or the tokenizer's byte
            round-trip does not reproduce the original UTF-8 bytes.
    """
    start_time = time.time()
    # Create token-to-bytes converter so losses can be attributed per byte.
    token2bytes_converter = TokenizerBytesConverter(model_name_or_path=tokenizer.name_or_path, tokenizer=tokenizer)
    # Determine BOS token (always the newline token -- see docstring note).
    bos_token = tokenizer.encode("\n")[0]
    # if bos_mode in ["add_default_bos", "replace_with_bos"]:
    #     bos_token = tokenizer.bos_token_id
    # elif bos_mode in ["add_default_eos", "replace_with_eos"]:
    #     bos_token = tokenizer.eos_token_id
    # elif bos_mode in ["add_newline_token", "replace_with_newline_token"]:
    #     bos_token = tokenizer.encode("\n")[0]
    # else:
    #     bos_token = tokenizer.bos_token_id
    bos_tensor = torch.tensor([bos_token], device=model.device).unsqueeze(0)
    # Tokenize input without the tokenizer's own special tokens.
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    inputs = inputs.to(model.device)
    seq_length = inputs["input_ids"].shape[-1]
    if seq_length < 2:
        raise ValueError(f"Text is too short (only {seq_length} tokens)")
    # Forward pass. "add_*" prepends the BOS tensor; "replace_with_*"
    # overwrites the first token (mutating the tokenizer output in place).
    input_chunk = inputs["input_ids"]
    if bos_mode in ["add_default_bos", "add_default_eos", "add_newline_token"]:
        input_chunk = torch.cat([bos_tensor, input_chunk], dim=-1)
    if bos_mode in ["replace_with_bos", "replace_with_eos", "replace_with_newline_token"]:
        input_chunk[0, 0] = bos_token
    logit = model.forward(input_ids=input_chunk).logits[0, :, :]
    # Per-token cross-entropy; length is len(input_chunk) - 1.
    loss = calculate_log_sum(logit, input_chunk.squeeze(0))
    # Get per-token byte sequences for the original text (no BOS entry).
    per_token_bytes = token2bytes_converter.encode_to_bytes(text)
    # Verify the tokenizer byte round-trip reproduces the input exactly.
    all_bytes = [byte for token in per_token_bytes for byte in token]
    expected_bytes = list(text.encode("utf-8"))
    if all_bytes != expected_bytes:
        raise ValueError("Token bytes don't match original text bytes")
    # Extract top-k (k=10 default) predictions aligned with shifted targets.
    sample_topk = extract_topk_predictions(logit[:-1], input_chunk.squeeze(0)[1:])
    # Spread each token's loss evenly over its bytes; a token that maps to
    # zero bytes defers its loss onto the next non-empty token.
    # NOTE(review): in "replace_with_*" modes `loss` has seq_length - 1
    # entries while `per_token_bytes` has seq_length, so zip() silently
    # drops the last token's bytes -- confirm intended. A trailing
    # zero-byte token's pending_loss is also discarded.
    byte_wise_losses = []
    pending_loss = 0.0
    for l, byte_values in zip(loss, per_token_bytes):
        current_loss = l.item() + pending_loss
        pending_loss = 0.0
        if len(byte_values) == 0:
            pending_loss = current_loss
            continue
        per_byte_loss = current_loss / len(byte_values)
        for _ in range(len(byte_values)):
            byte_wise_losses.append(per_byte_loss)
    # Calculate overall metrics.
    total_loss = loss.sum().item()
    num_bytes = len(text.encode("utf-8"))
    # NOTE(review): divides by the original token count; this matches the
    # number of loss terms only in "add_*" modes -- confirm for replace modes.
    avg_loss = total_loss / seq_length
    compression_rate = avg_loss * COMPRESSION_RATE_FACTOR
    inference_time = time.time() - start_time
    return {
        "byte_wise_losses": byte_wise_losses,
        # NOTE(review): key says top5 but extract_topk_predictions defaults
        # to k=10 -- confirm the consumer's expectation.
        "top5_predictions": sample_topk,
        "compression_rate": compression_rate,
        "total_loss": total_loss,
        "num_tokens": seq_length,
        "num_bytes": num_bytes,
        "model_name": getattr(model.config, "_name_or_path", "unknown"),
        "tokenizer": tokenizer,
        "inference_time": inference_time,
    }
@torch.no_grad()
def evaluate_rwkv7_single_sample(model, tokenizer, text: str) -> Dict[str, Any]:
    """
    Evaluate a RWKV7 model on a single text sample.

    The text is tokenized, prefixed with BOS token 0, and fed through the
    model in chunks of 1024 tokens while threading the recurrent state.
    Per-token losses are then spread over each token's UTF-8 bytes.

    Args:
        model: RWKV7 model exposing forward(tokens, state, full_output=True).
        tokenizer: RWKV tokenizer (TRIE_TOKENIZER); must provide encode()
            and decodeBytes().
        text: Input text to evaluate (must tokenize to at least 2 tokens).

    Returns:
        dict with byte_wise_losses, top5_predictions (actually top-10 -- see
        note at the return), compression_rate, total_loss, num_tokens,
        num_bytes, model_name, tokenizer, and inference_time.

    Raises:
        ValueError: if the text tokenizes to fewer than 2 tokens.
    """
    start_time = time.time()
    # Tokenize; some tokenizers return an object with .ids, others a plain list.
    tokenized = tokenizer.encode(text)
    if hasattr(tokenized, "ids"):
        input_seq = tokenized.ids
    else:
        input_seq = tokenized
    input_length = len(input_seq)
    if input_length < 2:
        raise ValueError(f"Text is too short (only {input_length} tokens)")
    # Forward pass with recurrent state, chunked to bound per-call length.
    input_chunk = [0] + input_seq  # Add BOS token (0)
    device = get_device()
    CHUNK_LEN = 1024
    state = None
    # 65536 appears to be the RWKV "world" tokenizer vocab size --
    # NOTE(review): confirm it matches the loaded model's output dimension.
    logit = torch.empty((0, 65536), device=device)
    temp_input = input_chunk.copy()
    while len(temp_input) > 0:
        out, state = model.forward(temp_input[:CHUNK_LEN], state, full_output=True)
        # A single-token final chunk apparently yields a 1-D tensor; restore
        # the sequence dimension. NOTE(review): confirm against the model API.
        if len(temp_input) == 1:
            out = out.unsqueeze(0)
        temp_input = temp_input[CHUNK_LEN:]
        logit = torch.concat((logit, out.to(device)), dim=0)
    # Unreachable here: input_length >= 2 guarantees len(input_chunk) >= 3.
    if len(input_chunk) == 1:
        logit = logit.unsqueeze(0)
    # Per-token cross-entropy; length is len(input_chunk) - 1 == input_length.
    loss = calculate_log_sum(logit, torch.tensor(input_chunk).to(device))
    # Per-token byte sequences (BOS excluded), aligned 1:1 with `loss`.
    token_bytes = [tokenizer.decodeBytes([token]) for token in input_chunk[1:]]
    # Extract top-k (k=10 default) predictions aligned with shifted targets.
    sample_topk = extract_topk_predictions(logit[:-1], torch.tensor(input_chunk[1:]).to(device))
    # Spread each token's loss evenly over its bytes.
    # NOTE(review): unlike the HF evaluator, a token that decodes to zero
    # bytes would raise ZeroDivisionError here -- confirm that cannot happen
    # with this tokenizer.
    byte_wise_losses = []
    for l, byte_values in zip(loss.tolist(), token_bytes):
        per_byte_loss = l / len(byte_values)
        for _ in range(len(byte_values)):
            byte_wise_losses.append(per_byte_loss)
    # Calculate overall metrics.
    total_loss = loss.sum().item()
    num_bytes = len(text.encode("utf-8"))
    avg_loss = total_loss / input_length
    compression_rate = avg_loss * COMPRESSION_RATE_FACTOR
    inference_time = time.time() - start_time
    return {
        "byte_wise_losses": byte_wise_losses,
        # NOTE(review): key says top5 but extract_topk_predictions defaults
        # to k=10 -- confirm the consumer's expectation.
        "top5_predictions": sample_topk,
        "compression_rate": compression_rate,
        "total_loss": total_loss,
        "num_tokens": input_length,
        "num_bytes": num_bytes,
        "model_name": "RWKV7-G1C-1.5B",
        "tokenizer": tokenizer,
        "inference_time": inference_time,
    }
|