File size: 8,549 Bytes
9afeeeb
 
 
 
 
 
 
 
 
15b2f1f
9afeeeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44c2c6d
9afeeeb
 
 
 
 
 
 
 
 
 
a2836e3
44c2c6d
9afeeeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2836e3
9afeeeb
 
 
 
 
 
 
 
 
 
 
 
15b2f1f
 
9afeeeb
a2836e3
9afeeeb
 
a2836e3
 
 
 
 
 
 
 
 
9afeeeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2836e3
9afeeeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15b2f1f
9afeeeb
 
 
 
 
 
 
 
 
a2836e3
15b2f1f
9afeeeb
 
 
 
a2836e3
9afeeeb
 
 
 
 
 
 
 
 
 
 
15b2f1f
 
9afeeeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2836e3
9afeeeb
 
 
 
 
 
 
 
 
 
 
 
 
15b2f1f
9afeeeb
 
 
 
 
 
 
 
 
a2836e3
15b2f1f
9afeeeb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
"""
Evaluator module for UncheatableEval visualization.

Provides single-sample evaluation functions for Qwen3 and RWKV7 models.
"""

import gc
import math
import os
import time
from typing import List, Dict, Any, Optional

import torch
import torch.nn.functional as F

from .helpers import TokenizerBytesConverter


# Compression rate conversion factor:
# losses come out of cross_entropy in nats/token; (1 / ln 2) converts nats
# to bits, * 0.125 converts bits to bytes, * 100 expresses the result as a
# percentage (bits-per-byte compression rate in %).
COMPRESSION_RATE_FACTOR = (1.0 / math.log(2.0)) * 0.125 * 100.0


def get_device():
    """Return the preferred device string: "cuda" when available, else "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def calculate_log_sum(logits: torch.Tensor, target_token_ids: torch.Tensor) -> torch.Tensor:
    """Per-token cross entropy between shifted logits and targets.

    logits[t] is scored against target_token_ids[t + 1]; returns a 1-D tensor
    of unreduced losses with length len(target_token_ids) - 1.
    """
    shifted_logits = logits[:-1]
    shifted_targets = target_token_ids[1:]
    # bfloat16 on CUDA keeps memory low; CPU kernels want float32.
    compute_dtype = torch.bfloat16 if logits.device.type == "cuda" else torch.float32
    return F.cross_entropy(shifted_logits.to(compute_dtype), shifted_targets, reduction="none")


def extract_topk_predictions(logit: torch.Tensor, target_ids: torch.Tensor, k: int = 10) -> List:
    """
    Extract top-k predictions from logits.

    Args:
        logit: [seq_length, vocab_size] logit tensor
        target_ids: [seq_length] actual target token IDs
        k: number of top predictions to extract (default: 10)

    Returns:
        list: [[actual_id, rank, actual_prob, [[id1, prob1], [id2, prob2], ...]], ...]
    """
    probs = F.softmax(logit, dim=-1)
    top_probs, top_ids = torch.topk(probs, k, dim=-1)

    results = []
    for pos, row in enumerate(probs):
        actual_id = target_ids[pos].item()
        actual_prob = row[actual_id].item()
        # 1-based rank: number of tokens the model prefers strictly more, plus one.
        rank = (row > actual_prob).sum().item() + 1

        candidates = [
            [token_id, round(token_prob, 6)]
            for token_id, token_prob in zip(top_ids[pos].tolist(), top_probs[pos].tolist())
        ]
        results.append([actual_id, rank, actual_prob, candidates])

    return results


def count_model_parameters_in_billions(model) -> float:
    """Return the model's total parameter count, expressed in billions."""
    return sum(param.numel() for param in model.parameters()) / 1e9


def count_rwkv_parameters_in_billions(rwkv_model) -> float:
    """Count RWKV model parameters in billions.

    RWKV builds store their weights in dict attributes named ``z`` and/or
    ``w``; whichever of the two exists contributes to the total.
    """
    total_params = 0
    for attr_name in ("z", "w"):
        if hasattr(rwkv_model, attr_name):
            total_params += sum(p.numel() for p in getattr(rwkv_model, attr_name).values())
    return total_params / 1e9


@torch.no_grad()
def evaluate_hf_single_sample(model, tokenizer, text: str, bos_mode: str = "add_newline_token") -> Dict[str, Any]:
    """
    Evaluate a HuggingFace model on a single text sample.

    Args:
        model: HuggingFace causal LM (must expose ``.device`` and ``.config``)
        tokenizer: matching HuggingFace tokenizer
        text: input text to evaluate
        bos_mode: how to handle the BOS token. "add_*" modes prepend the
            chosen token; "replace_with_*" modes overwrite the first input
            token. The token is the tokenizer's BOS id, EOS id, or the
            newline token depending on the mode name; unknown modes fall
            back to the newline token.

    Returns:
        dict with byte_wise_losses, top5_predictions, compression_rate, etc.

    Raises:
        ValueError: if the text tokenizes to fewer than 2 tokens, or the
            tokenizer's byte reconstruction does not match the input text.
    """
    start_time = time.time()

    # Create token-to-bytes converter
    token2bytes_converter = TokenizerBytesConverter(model_name_or_path=tokenizer.name_or_path, tokenizer=tokenizer)

    # Select the token bos_mode refers to.  Previously this was hard-coded to
    # the newline token, silently ignoring the BOS/EOS modes.
    newline_token = tokenizer.encode("\n")[0]
    if bos_mode in ["add_default_bos", "replace_with_bos"]:
        bos_token = tokenizer.bos_token_id
    elif bos_mode in ["add_default_eos", "replace_with_eos"]:
        bos_token = tokenizer.eos_token_id
    else:
        # Newline modes and any unrecognized mode use "\n" (the previous
        # unconditional behavior, so the default mode is unchanged).
        bos_token = newline_token
    if bos_token is None:
        # Some tokenizers define no BOS/EOS id; fall back to the newline token.
        bos_token = newline_token

    bos_tensor = torch.tensor([bos_token], device=model.device).unsqueeze(0)

    # Tokenize input without the tokenizer's own special tokens.
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    inputs = inputs.to(model.device)
    seq_length = inputs["input_ids"].shape[-1]

    if seq_length < 2:
        raise ValueError(f"Text is too short (only {seq_length} tokens)")

    # Build the model input according to bos_mode.
    input_chunk = inputs["input_ids"]
    if bos_mode in ["add_default_bos", "add_default_eos", "add_newline_token"]:
        input_chunk = torch.cat([bos_tensor, input_chunk], dim=-1)
    if bos_mode in ["replace_with_bos", "replace_with_eos", "replace_with_newline_token"]:
        input_chunk[0, 0] = bos_token

    # Single forward pass; logit[t] predicts input_chunk[t + 1].
    logit = model.forward(input_ids=input_chunk).logits[0, :, :]
    loss = calculate_log_sum(logit, input_chunk.squeeze(0))

    # Bytes covered by each token of the original text (no prepended token).
    per_token_bytes = token2bytes_converter.encode_to_bytes(text)

    # Sanity check: the tokenizer round-trip must reproduce the text's bytes.
    all_bytes = [byte for token in per_token_bytes for byte in token]
    expected_bytes = list(text.encode("utf-8"))
    if all_bytes != expected_bytes:
        raise ValueError("Token bytes don't match original text bytes")

    # Top-k predictions for every target position.
    sample_topk = extract_topk_predictions(logit[:-1], input_chunk.squeeze(0)[1:])

    # Spread each token's loss uniformly over the bytes it covers.  Tokens
    # that map to zero bytes carry their loss forward to the next token.
    byte_wise_losses = []
    pending_loss = 0.0
    for token_loss, byte_values in zip(loss, per_token_bytes):
        current_loss = token_loss.item() + pending_loss
        pending_loss = 0.0

        if len(byte_values) == 0:
            pending_loss = current_loss
            continue

        per_byte_loss = current_loss / len(byte_values)
        byte_wise_losses.extend([per_byte_loss] * len(byte_values))

    # Overall metrics.  NOTE(review): in "replace_with_*" modes `loss` has
    # seq_length - 1 entries, so avg_loss divides by one extra token; kept
    # as-is for parity with the original behavior — confirm intended.
    total_loss = loss.sum().item()
    num_bytes = len(text.encode("utf-8"))
    avg_loss = total_loss / seq_length
    compression_rate = avg_loss * COMPRESSION_RATE_FACTOR
    inference_time = time.time() - start_time

    return {
        "byte_wise_losses": byte_wise_losses,
        "top5_predictions": sample_topk,
        "compression_rate": compression_rate,
        "total_loss": total_loss,
        "num_tokens": seq_length,
        "num_bytes": num_bytes,
        "model_name": getattr(model.config, "_name_or_path", "unknown"),
        "tokenizer": tokenizer,
        "inference_time": inference_time,
    }


@torch.no_grad()
def evaluate_rwkv7_single_sample(model, tokenizer, text: str) -> Dict[str, Any]:
    """
    Evaluate a RWKV7 model on a single text sample.

    Args:
        model: RWKV7 model exposing ``forward(tokens, state, full_output=True)``
        tokenizer: RWKV tokenizer (TRIE_TOKENIZER); must provide ``encode``
            and ``decodeBytes``
        text: input text to evaluate

    Returns:
        dict with byte_wise_losses, top5_predictions, compression_rate, etc.

    Raises:
        ValueError: if the text tokenizes to fewer than 2 tokens.
    """
    start_time = time.time()

    # Tokenize; some tokenizer implementations return an object with `.ids`.
    tokenized = tokenizer.encode(text)
    input_seq = tokenized.ids if hasattr(tokenized, "ids") else tokenized
    input_length = len(input_seq)

    if input_length < 2:
        raise ValueError(f"Text is too short (only {input_length} tokens)")

    # Prepend BOS (token 0) and run the model chunk-by-chunk, threading the
    # recurrent state between chunks.
    input_chunk = [0] + input_seq
    device = get_device()

    CHUNK_LEN = 1024
    state = None
    # Collect chunk logits and concatenate once: concatenating inside the
    # loop was quadratic and required a hard-coded 65536 vocab size.
    logit_chunks = []

    remaining = input_chunk.copy()
    while remaining:
        out, state = model.forward(remaining[:CHUNK_LEN], state, full_output=True)
        if len(remaining) == 1:
            # A single-token chunk yields a 1-D [vocab] tensor; restore [1, vocab].
            out = out.unsqueeze(0)
        remaining = remaining[CHUNK_LEN:]
        logit_chunks.append(out.to(device))
    logit = torch.cat(logit_chunks, dim=0)

    loss = calculate_log_sum(logit, torch.tensor(input_chunk).to(device))

    # Bytes covered by each real (non-BOS) token.
    token_bytes = [tokenizer.decodeBytes([token]) for token in input_chunk[1:]]

    # Top-k predictions for every target position.
    sample_topk = extract_topk_predictions(logit[:-1], torch.tensor(input_chunk[1:]).to(device))

    # Spread each token's loss uniformly over the bytes it covers.  Tokens
    # that decode to zero bytes carry their loss forward to the next token
    # (matches evaluate_hf_single_sample and avoids a ZeroDivisionError).
    byte_wise_losses = []
    pending_loss = 0.0
    for token_loss, byte_values in zip(loss.tolist(), token_bytes):
        current_loss = token_loss + pending_loss
        pending_loss = 0.0

        if len(byte_values) == 0:
            pending_loss = current_loss
            continue

        per_byte_loss = current_loss / len(byte_values)
        byte_wise_losses.extend([per_byte_loss] * len(byte_values))

    # Overall metrics.
    total_loss = loss.sum().item()
    num_bytes = len(text.encode("utf-8"))
    avg_loss = total_loss / input_length
    compression_rate = avg_loss * COMPRESSION_RATE_FACTOR
    inference_time = time.time() - start_time

    return {
        "byte_wise_losses": byte_wise_losses,
        "top5_predictions": sample_topk,
        "compression_rate": compression_rate,
        "total_loss": total_loss,
        "num_tokens": input_length,
        "num_bytes": num_bytes,
        "model_name": "RWKV7-G1C-1.5B",
        "tokenizer": tokenizer,
        "inference_time": inference_time,
    }