# NOTE(review): removed non-Python banner text ("Spaces / Running") left over
# from a hosting-page export; it is not part of the module.
| """ | |
| Evaluator module for UncheatableEval visualization. | |
| Provides single-sample evaluation functions for Qwen3 and RWKV7 models. | |
| """ | |
| import gc | |
| import math | |
| import os | |
| import time | |
| from typing import List, Dict, Any, Optional | |
| import torch | |
| import torch.nn.functional as F | |
| from .helpers import TokenizerBytesConverter | |
# Compression rate conversion factor: a per-token loss in nats is divided by
# ln(2) to get bits, scaled by 0.125 (= 1/8) to get bytes, then by 100 to
# express the result as a percentage.
COMPRESSION_RATE_FACTOR = (1.0 / math.log(2.0)) * 0.125 * 100.0
def get_device():
    """Return the preferred compute device: "cuda" when available, else "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"
def calculate_log_sum(logits: torch.Tensor, target_token_ids: torch.Tensor) -> torch.Tensor:
    """Per-token cross-entropy between next-token logits and shifted targets.

    The final logit row has no target and is dropped; the first target has no
    preceding prediction and is skipped, so the result has
    ``len(target_token_ids) - 1`` entries (``reduction="none"``).
    """
    shifted_logits = logits[:-1]
    shifted_targets = target_token_ids[1:]
    # bfloat16 on CUDA keeps memory down; CPU kernels want float32.
    compute_dtype = torch.bfloat16 if logits.device.type == "cuda" else torch.float32
    return F.cross_entropy(shifted_logits.to(compute_dtype), shifted_targets, reduction="none")
def extract_topk_predictions(logit: torch.Tensor, target_ids: torch.Tensor, k: int = 10) -> List:
    """
    Extract top-k predictions from logits.

    Args:
        logit: [seq_length, vocab_size] logit tensor
        target_ids: [seq_length] actual target token IDs
        k: number of top predictions to extract (default: 10)

    Returns:
        list: [[actual_id, rank, [[id1, prob1], [id2, prob2], ...]], ...]
        where rank is 1-based and ties on probability share the best rank.
    """
    probs = F.softmax(logit, dim=-1)
    top_probs, top_ids = torch.topk(probs, k, dim=-1)
    # Vectorized target probability and rank: the original per-position
    # Python loop scanned the whole vocab row for every token (O(seq*vocab)
    # interpreter work); these two tensor ops produce identical values.
    actual_probs = probs.gather(-1, target_ids.unsqueeze(-1))  # [seq, 1]
    ranks = (probs > actual_probs).sum(dim=-1) + 1  # strict '>' keeps tie behavior
    results = []
    for pos in range(logit.shape[0]):
        topk_list = [[top_ids[pos, i].item(), round(top_probs[pos, i].item(), 6)] for i in range(k)]
        results.append([target_ids[pos].item(), ranks[pos].item(), topk_list])
    return results
def count_model_parameters_in_billions(model) -> float:
    """Return the total parameter count of *model*, expressed in billions."""
    n_params = 0
    for param in model.parameters():
        n_params += param.numel()
    return n_params / 1e9
def count_rwkv_parameters_in_billions(rwkv_model) -> float:
    """Return the total RWKV parameter count in billions.

    The weights live in a plain dict attached to the model as ``z`` and/or
    ``w``; whichever attributes exist contribute to the total.
    """
    total = 0
    for attr_name in ("z", "w"):
        if hasattr(rwkv_model, attr_name):
            weight_dict = getattr(rwkv_model, attr_name)
            total += sum(p.numel() for p in weight_dict.values())
    return total / 1e9
def evaluate_hf_single_sample(model, tokenizer, text: str, bos_mode: str = "add_newline_token") -> Dict[str, Any]:
    """
    Evaluate a HuggingFace causal LM on a single text sample.

    Args:
        model: HuggingFace model (must expose ``.device``, ``.config`` and a
            ``forward`` whose output has ``.logits``)
        tokenizer: matching HuggingFace tokenizer
        text: input text to evaluate
        bos_mode: how to handle the BOS token. "add_*" modes prepend the
            token; "replace_with_*" modes overwrite the first input token.
            NOTE(review): all modes currently use the first token of "\n"
            as the BOS — confirm whether the bos/eos variants should use
            ``bos_token_id``/``eos_token_id`` instead.

    Returns:
        dict with byte_wise_losses, top5_predictions, compression_rate,
        total_loss, num_tokens, num_bytes, model_name, tokenizer and
        inference_time.

    Raises:
        ValueError: if the text tokenizes to fewer than 2 tokens, or the
            tokenizer's byte decomposition does not round-trip the text.
    """
    start_time = time.time()
    # Maps each token id back to the exact UTF-8 bytes it covers.
    token2bytes_converter = TokenizerBytesConverter(model_name_or_path=tokenizer.name_or_path, tokenizer=tokenizer)
    # Every bos_mode currently resolves to the newline token (see NOTE above).
    bos_token = tokenizer.encode("\n")[0]
    bos_tensor = torch.tensor([bos_token], device=model.device).unsqueeze(0)
    # Tokenize without the tokenizer's own special tokens; BOS is handled here.
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    inputs = inputs.to(model.device)
    seq_length = inputs["input_ids"].shape[-1]
    if seq_length < 2:
        raise ValueError(f"Text is too short (only {seq_length} tokens)")
    # Forward pass over the (possibly BOS-prefixed) sequence.
    input_chunk = inputs["input_ids"]
    if bos_mode in ["add_default_bos", "add_default_eos", "add_newline_token"]:
        input_chunk = torch.cat([bos_tensor, input_chunk], dim=-1)
    if bos_mode in ["replace_with_bos", "replace_with_eos", "replace_with_newline_token"]:
        input_chunk[0, 0] = bos_token
    logit = model.forward(input_ids=input_chunk).logits[0, :, :]
    # Per-token next-token cross-entropy; length = len(input_chunk) - 1.
    loss = calculate_log_sum(logit, input_chunk.squeeze(0))
    # Per-token UTF-8 bytes; these must re-assemble exactly to the input text.
    per_token_bytes = token2bytes_converter.encode_to_bytes(text)
    all_bytes = [byte for token in per_token_bytes for byte in token]
    expected_bytes = list(text.encode("utf-8"))
    if all_bytes != expected_bytes:
        raise ValueError("Token bytes don't match original text bytes")
    # Top-k candidate tokens for every predicted position.
    sample_topk = extract_topk_predictions(logit[:-1], input_chunk.squeeze(0)[1:])
    # Spread each token's loss uniformly over the bytes it covers; a token
    # that maps to zero bytes defers its loss onto the next non-empty token.
    byte_wise_losses = []
    pending_loss = 0.0
    for l, byte_values in zip(loss, per_token_bytes):
        current_loss = l.item() + pending_loss
        pending_loss = 0.0
        if len(byte_values) == 0:
            pending_loss = current_loss
            continue
        per_byte_loss = current_loss / len(byte_values)
        for _ in range(len(byte_values)):
            byte_wise_losses.append(per_byte_loss)
    # Overall metrics: avg_loss is nats/token; COMPRESSION_RATE_FACTOR
    # converts nats -> bits -> bytes -> percent.
    total_loss = loss.sum().item()
    num_bytes = len(text.encode("utf-8"))
    avg_loss = total_loss / seq_length
    compression_rate = avg_loss * COMPRESSION_RATE_FACTOR
    inference_time = time.time() - start_time
    return {
        "byte_wise_losses": byte_wise_losses,
        "top5_predictions": sample_topk,
        "compression_rate": compression_rate,
        "total_loss": total_loss,
        "num_tokens": seq_length,
        "num_bytes": num_bytes,
        "model_name": getattr(model.config, "_name_or_path", "unknown"),
        "tokenizer": tokenizer,
        "inference_time": inference_time,
    }
def evaluate_rwkv7_single_sample(model, tokenizer, text: str) -> Dict[str, Any]:
    """
    Evaluate a RWKV7 model on a single text sample.

    Args:
        model: RWKV7 model exposing ``forward(tokens, state, full_output=...)``
        tokenizer: RWKV tokenizer (TRIE_TOKENIZER-style, with ``encode`` and
            ``decodeBytes``)
        text: input text to evaluate

    Returns:
        dict with byte_wise_losses, top5_predictions, compression_rate,
        total_loss, num_tokens, num_bytes, model_name, tokenizer and
        inference_time.

    Raises:
        ValueError: if the text tokenizes to fewer than 2 tokens.
    """
    start_time = time.time()
    # Tokenize; some tokenizers return an object with .ids, others a list.
    tokenized = tokenizer.encode(text)
    input_seq = tokenized.ids if hasattr(tokenized, "ids") else tokenized
    input_length = len(input_seq)
    if input_length < 2:
        raise ValueError(f"Text is too short (only {input_length} tokens)")
    # Chunked forward pass, threading the recurrent state between chunks.
    input_chunk = [0] + input_seq  # token 0 acts as BOS
    device = get_device()
    CHUNK_LEN = 1024
    state = None
    # NOTE(review): 65536 is presumably the RWKV world-tokenizer vocab size —
    # confirm it matches the loaded model's output width.
    logit = torch.empty((0, 65536), device=device)
    temp_input = input_chunk.copy()
    while len(temp_input) > 0:
        out, state = model.forward(temp_input[:CHUNK_LEN], state, full_output=True)
        if len(temp_input) == 1:
            # A single-token chunk comes back without the sequence dimension.
            out = out.unsqueeze(0)
        temp_input = temp_input[CHUNK_LEN:]
        logit = torch.concat((logit, out.to(device)), dim=0)
    if len(input_chunk) == 1:
        logit = logit.unsqueeze(0)
    # Per-token next-token cross-entropy; length = input_length.
    loss = calculate_log_sum(logit, torch.tensor(input_chunk).to(device))
    # UTF-8 bytes covered by each predicted token (BOS excluded).
    token_bytes = [tokenizer.decodeBytes([token]) for token in input_chunk[1:]]
    # Top-k candidate tokens for every predicted position.
    sample_topk = extract_topk_predictions(logit[:-1], torch.tensor(input_chunk[1:]).to(device))
    # Spread each token's loss uniformly over its bytes. A token with zero
    # bytes defers its loss onto the next non-empty token — this mirrors
    # evaluate_hf_single_sample and avoids the ZeroDivisionError the
    # unguarded division had.
    byte_wise_losses = []
    pending_loss = 0.0
    for token_loss, byte_values in zip(loss.tolist(), token_bytes):
        current_loss = token_loss + pending_loss
        pending_loss = 0.0
        if len(byte_values) == 0:
            pending_loss = current_loss
            continue
        per_byte_loss = current_loss / len(byte_values)
        for _ in range(len(byte_values)):
            byte_wise_losses.append(per_byte_loss)
    # Overall metrics; see COMPRESSION_RATE_FACTOR for the unit conversion.
    total_loss = loss.sum().item()
    num_bytes = len(text.encode("utf-8"))
    avg_loss = total_loss / input_length
    compression_rate = avg_loss * COMPRESSION_RATE_FACTOR
    inference_time = time.time() - start_time
    return {
        "byte_wise_losses": byte_wise_losses,
        "top5_predictions": sample_topk,
        "compression_rate": compression_rate,
        "total_loss": total_loss,
        "num_tokens": input_length,
        "num_bytes": num_bytes,
        "model_name": "RWKV7-G1C-1.5B",
        "tokenizer": tokenizer,
        "inference_time": inference_time,
    }