""" Evaluator module for UncheatableEval visualization. Provides single-sample evaluation functions for Qwen3 and RWKV7 models. """ import gc import math import os import time from typing import List, Dict, Any, Optional import torch import torch.nn.functional as F from .helpers import TokenizerBytesConverter # Compression rate conversion factor COMPRESSION_RATE_FACTOR = (1.0 / math.log(2.0)) * 0.125 * 100.0 def get_device(): """Get the best available device.""" if torch.cuda.is_available(): return "cuda" else: return "cpu" def calculate_log_sum(logits: torch.Tensor, target_token_ids: torch.Tensor) -> torch.Tensor: """Calculate cross entropy loss for each token.""" # Use float32 for CPU compatibility, bfloat16 for CUDA if logits.device.type == "cuda": return F.cross_entropy(logits[:-1].to(torch.bfloat16), target_token_ids[1:], reduction="none") else: return F.cross_entropy(logits[:-1].float(), target_token_ids[1:], reduction="none") def extract_topk_predictions(logit: torch.Tensor, target_ids: torch.Tensor, k: int = 10) -> List: """ Extract top-k predictions from logits. Args: logit: [seq_length, vocab_size] logit tensor target_ids: [seq_length] actual target token IDs k: number of top predictions to extract (default: 10) Returns: list: [[actual_id, rank, [[id1, prob1], [id2, prob2], ...]], ...] """ probs = F.softmax(logit, dim=-1) top_probs, top_ids = torch.topk(probs, k, dim=-1) results = [] for pos in range(logit.shape[0]): target_id = target_ids[pos].item() actual_prob = probs[pos, target_id].item() rank = (probs[pos] > actual_prob).sum().item() + 1 topk_list = [[top_ids[pos, i].item(), round(top_probs[pos, i].item(), 6)] for i in range(k)] results.append([target_id, rank, topk_list]) return results def count_model_parameters_in_billions(model) -> float: """Count model parameters in billions.""" total_params = sum(p.numel() for p in model.parameters()) return total_params / 1e9 def count_rwkv_parameters_in_billions(rwkv_model) -> float: """Count RWKV model parameters in billions.""" total_params = 0 if hasattr(rwkv_model, "z"): for param in rwkv_model.z.values(): total_params += param.numel() if hasattr(rwkv_model, "w"): for param in rwkv_model.w.values(): total_params += param.numel() return total_params / 1e9 @torch.no_grad() def evaluate_hf_single_sample(model, tokenizer, text: str, bos_mode: str = "add_newline_token") -> Dict[str, Any]: """ Evaluate a HuggingFace model on a single text sample. Args: model: HuggingFace model tokenizer: HuggingFace tokenizer text: Input text to evaluate bos_mode: How to handle BOS token Returns: dict with byte_wise_losses, top5_predictions, compression_rate, etc. """ start_time = time.time() # Create token-to-bytes converter token2bytes_converter = TokenizerBytesConverter(model_name_or_path=tokenizer.name_or_path, tokenizer=tokenizer) # Determine BOS token bos_token = tokenizer.encode("\n")[0] # if bos_mode in ["add_default_bos", "replace_with_bos"]: # bos_token = tokenizer.bos_token_id # elif bos_mode in ["add_default_eos", "replace_with_eos"]: # bos_token = tokenizer.eos_token_id # elif bos_mode in ["add_newline_token", "replace_with_newline_token"]: # bos_token = tokenizer.encode("\n")[0] # else: # bos_token = tokenizer.bos_token_id bos_tensor = torch.tensor([bos_token], device=model.device).unsqueeze(0) # Tokenize input inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False) inputs = inputs.to(model.device) seq_length = inputs["input_ids"].shape[-1] if seq_length < 2: raise ValueError(f"Text is too short (only {seq_length} tokens)") # Forward pass input_chunk = inputs["input_ids"] if bos_mode in ["add_default_bos", "add_default_eos", "add_newline_token"]: input_chunk = torch.cat([bos_tensor, input_chunk], dim=-1) if bos_mode in ["replace_with_bos", "replace_with_eos", "replace_with_newline_token"]: input_chunk[0, 0] = bos_token logit = model.forward(input_ids=input_chunk).logits[0, :, :] loss = calculate_log_sum(logit, input_chunk.squeeze(0)) # Get per-token bytes per_token_bytes = token2bytes_converter.encode_to_bytes(text) # Verify bytes match all_bytes = [byte for token in per_token_bytes for byte in token] expected_bytes = list(text.encode("utf-8")) if all_bytes != expected_bytes: raise ValueError("Token bytes don't match original text bytes") # Extract top-k predictions sample_topk = extract_topk_predictions(logit[:-1], input_chunk.squeeze(0)[1:]) # Calculate byte-wise losses byte_wise_losses = [] pending_loss = 0.0 for l, byte_values in zip(loss, per_token_bytes): current_loss = l.item() + pending_loss pending_loss = 0.0 if len(byte_values) == 0: pending_loss = current_loss continue per_byte_loss = current_loss / len(byte_values) for _ in range(len(byte_values)): byte_wise_losses.append(per_byte_loss) # Calculate overall metrics total_loss = loss.sum().item() num_bytes = len(text.encode("utf-8")) avg_loss = total_loss / seq_length compression_rate = avg_loss * COMPRESSION_RATE_FACTOR inference_time = time.time() - start_time return { "byte_wise_losses": byte_wise_losses, "top5_predictions": sample_topk, "compression_rate": compression_rate, "total_loss": total_loss, "num_tokens": seq_length, "num_bytes": num_bytes, "model_name": getattr(model.config, "_name_or_path", "unknown"), "tokenizer": tokenizer, "inference_time": inference_time, } @torch.no_grad() def evaluate_rwkv7_single_sample(model, tokenizer, text: str) -> Dict[str, Any]: """ Evaluate a RWKV7 model on a single text sample. Args: model: RWKV7 model tokenizer: RWKV tokenizer (TRIE_TOKENIZER) text: Input text to evaluate Returns: dict with byte_wise_losses, top5_predictions, compression_rate, etc. """ start_time = time.time() # Tokenize tokenized = tokenizer.encode(text) if hasattr(tokenized, "ids"): input_seq = tokenized.ids else: input_seq = tokenized input_length = len(input_seq) if input_length < 2: raise ValueError(f"Text is too short (only {input_length} tokens)") # Forward pass with state input_chunk = [0] + input_seq # Add BOS token (0) device = get_device() CHUNK_LEN = 1024 state = None logit = torch.empty((0, 65536), device=device) temp_input = input_chunk.copy() while len(temp_input) > 0: out, state = model.forward(temp_input[:CHUNK_LEN], state, full_output=True) if len(temp_input) == 1: out = out.unsqueeze(0) temp_input = temp_input[CHUNK_LEN:] logit = torch.concat((logit, out.to(device)), dim=0) if len(input_chunk) == 1: logit = logit.unsqueeze(0) loss = calculate_log_sum(logit, torch.tensor(input_chunk).to(device)) # Get per-token bytes token_bytes = [tokenizer.decodeBytes([token]) for token in input_chunk[1:]] # Extract top-k predictions sample_topk = extract_topk_predictions(logit[:-1], torch.tensor(input_chunk[1:]).to(device)) # Calculate byte-wise losses byte_wise_losses = [] for l, byte_values in zip(loss.tolist(), token_bytes): per_byte_loss = l / len(byte_values) for _ in range(len(byte_values)): byte_wise_losses.append(per_byte_loss) # Calculate overall metrics total_loss = loss.sum().item() num_bytes = len(text.encode("utf-8")) avg_loss = total_loss / input_length compression_rate = avg_loss * COMPRESSION_RATE_FACTOR inference_time = time.time() - start_time return { "byte_wise_losses": byte_wise_losses, "top5_predictions": sample_topk, "compression_rate": compression_rate, "total_loss": total_loss, "num_tokens": input_length, "num_bytes": num_bytes, "model_name": "RWKV7-G1C-1.5B", "tokenizer": tokenizer, "inference_time": inference_time, }