File size: 12,469 Bytes
9afeeeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6fcf271
491ce2b
9afeeeb
d68c16d
 
 
 
 
 
15b2f1f
 
 
 
 
 
d68c16d
9afeeeb
 
 
 
 
 
 
 
 
 
 
 
 
 
cddd3a5
9afeeeb
 
 
 
 
 
 
 
 
cddd3a5
9afeeeb
 
 
cddd3a5
 
9afeeeb
cddd3a5
9afeeeb
cddd3a5
9afeeeb
cddd3a5
9afeeeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d68c16d
cddd3a5
d68c16d
 
9afeeeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15b2f1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d68c16d
 
15b2f1f
d68c16d
 
 
15b2f1f
 
 
d68c16d
 
 
 
 
 
 
 
 
 
 
 
15b2f1f
 
 
 
d68c16d
 
 
9afeeeb
 
491ce2b
 
cddd3a5
 
9afeeeb
 
 
 
 
cddd3a5
9afeeeb
 
98b668c
9afeeeb
 
 
 
d68c16d
15b2f1f
d68c16d
9afeeeb
 
 
 
 
 
 
d68c16d
8a0d82d
15b2f1f
 
 
 
8a0d82d
 
 
 
 
 
 
 
 
 
cddd3a5
9afeeeb
15b2f1f
 
 
 
 
 
 
d68c16d
8a0d82d
 
 
 
15b2f1f
cddd3a5
9afeeeb
15b2f1f
 
 
 
 
 
 
d620a8f
491ce2b
9afeeeb
 
fa6172d
 
 
 
 
 
 
 
 
cddd3a5
9afeeeb
 
 
 
 
820f694
9afeeeb
 
 
 
 
cddd3a5
9afeeeb
 
 
 
 
 
 
 
 
820f694
15b2f1f
 
 
 
 
 
 
 
820f694
15b2f1f
820f694
91f5d7c
 
9afeeeb
cddd3a5
68b02f7
cddd3a5
68b02f7
 
 
 
 
 
 
 
 
 
 
 
cddd3a5
 
9afeeeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
820f694
15b2f1f
820f694
9afeeeb
15b2f1f
820f694
9afeeeb
 
 
d68c16d
 
 
 
49eb0e6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
"""
UncheatableEval Visualization - Hugging Face Space

Compare byte-level prediction performance between Qwen3-1.7B-Base and RWKV7-G1C-1.5B.
"""

import gc
import os
from pathlib import Path

import gradio as gr
import torch

# Detect device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IS_CPU = DEVICE == "cpu"

# Model configuration
QWEN_MODEL_ID = "Qwen/Qwen3-1.7B-Base"
RWKV_MODEL_URL = "https://huggingface.co/BlinkDL/rwkv7-g1/resolve/main/rwkv7-g1c-1.5b-20260110-ctx8192.pth"
RWKV_MODEL_FILENAME = "rwkv7-g1c-1.5b-20260110-ctx8192.pth"

# Get the directory where this script is located
SCRIPT_DIR = Path(__file__).parent.absolute()
MODELS_DIR = SCRIPT_DIR / "models"
SUPPORT_DIR = SCRIPT_DIR / "support"

# Text length limits
MAX_TEXT_LENGTH = 8192
MIN_TEXT_LENGTH = 1

# Global model cache
_qwen_model = None
_qwen_tokenizer = None
_rwkv_model = None
_rwkv_tokenizer = None
_rwkv_model_path = None
_stats_manager = None

# Precomputed example cache
_precomputed_html = None
_precomputed_text = None
PRECOMPUTED_DIR = SCRIPT_DIR / "precomputed"


def download_rwkv_model(progress=None):
    """Download the RWKV7 checkpoint from the HuggingFace Hub if not cached.

    Args:
        progress: Optional progress callback. Currently unused; kept for
            interface compatibility with Gradio progress hooks.

    Returns:
        str: Path to the local ``.pth`` checkpoint file.
    """
    from huggingface_hub import hf_hub_download

    model_path = MODELS_DIR / RWKV_MODEL_FILENAME

    # Reuse a previously downloaded copy.
    if model_path.exists():
        return str(model_path)

    MODELS_DIR.mkdir(parents=True, exist_ok=True)

    # Download from HuggingFace Hub. NOTE: the deprecated
    # `local_dir_use_symlinks` argument was dropped — recent
    # huggingface_hub versions ignore it and emit a FutureWarning;
    # passing `local_dir` alone already materializes a real file there.
    downloaded_path = hf_hub_download(
        repo_id="BlinkDL/rwkv7-g1",
        filename=RWKV_MODEL_FILENAME,
        local_dir=str(MODELS_DIR),
    )

    return downloaded_path


def load_qwen_model():
    """Load Qwen3-1.7B-Base and its tokenizer.

    Returns:
        tuple: ``(model, tokenizer)`` with the model already in eval mode.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID, trust_remote_code=True)

    if IS_CPU:
        # CPU path: fp32 weights, no device map, low-memory loading.
        cpu_kwargs = {
            "torch_dtype": torch.float32,
            "device_map": None,
            "trust_remote_code": True,
            "low_cpu_mem_usage": True,
        }
        model = AutoModelForCausalLM.from_pretrained(QWEN_MODEL_ID, **cpu_kwargs).eval()
        return model, tokenizer

    # GPU path: bf16 with automatic device placement. Prefer
    # flash-attention 2 when available, otherwise fall back to the
    # default attention implementation.
    gpu_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto", "trust_remote_code": True}
    try:
        model = AutoModelForCausalLM.from_pretrained(
            QWEN_MODEL_ID, attn_implementation="flash_attention_2", **gpu_kwargs
        ).eval()
    except Exception:
        model = AutoModelForCausalLM.from_pretrained(QWEN_MODEL_ID, **gpu_kwargs).eval()

    return model, tokenizer


def load_rwkv7_model(model_path: str):
    """Load RWKV7-G1C-1.5B and its TRIE tokenizer.

    Args:
        model_path: Path to the ``.pth`` checkpoint (extension optional).

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    # The rwkv package reads these environment variables at import time,
    # so they must be set before `from rwkv.model import RWKV`.
    os.environ["RWKV_JIT_ON"] = "1"
    os.environ["RWKV_V7_ON"] = "1"
    os.environ["RWKV_CUDA_ON"] = "0" if IS_CPU else "1"

    from rwkv.model import RWKV
    from rwkv.rwkv_tokenizer import TRIE_TOKENIZER

    # Pick the inference strategy matching the detected device.
    strategy = "cpu fp32" if IS_CPU else "cuda fp16"

    # The RWKV loader appends ".pth" itself, so strip it when present.
    checkpoint = model_path[:-4] if model_path.endswith(".pth") else model_path

    model = RWKV(model=checkpoint, strategy=strategy)

    tokenizer = TRIE_TOKENIZER(str(SUPPORT_DIR / "rwkv_vocab_v20230424.txt"))

    return model, tokenizer


def validate_input(text: str) -> tuple[bool, str]:
    """Check that *text* is suitable for evaluation.

    Returns:
        tuple: ``(True, cleaned_text)`` when valid, otherwise
        ``(False, error_message)``.
    """
    # Reject missing or whitespace-only input up front.
    if not text or not text.strip():
        return False, "Please enter some text to analyze."

    cleaned = text.strip()

    # Enforce the module-level length bounds on the stripped text.
    if len(cleaned) < MIN_TEXT_LENGTH:
        return False, f"Text is too short. Minimum {MIN_TEXT_LENGTH} characters required."
    if len(cleaned) > MAX_TEXT_LENGTH:
        return False, f"Text is too long. Maximum {MAX_TEXT_LENGTH} characters allowed. Current: {len(cleaned)}"

    return True, cleaned


def load_precomputed_example():
    """Populate the module-level precomputed-example cache from disk.

    Returns:
        bool: True when both the HTML and metadata files were found and
        loaded; False otherwise.
    """
    global _precomputed_html, _precomputed_text

    html_path = PRECOMPUTED_DIR / "example_visualization.html"
    metadata_path = PRECOMPUTED_DIR / "example_metadata.json"

    # Both artifacts are required; bail out early when either is missing.
    if not (html_path.exists() and metadata_path.exists()):
        print("No precomputed example found. Run precompute_example.py first.")
        return False

    import json
    _precomputed_html = html_path.read_text(encoding="utf-8")
    metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
    _precomputed_text = metadata.get("example_text", "")
    print(f"Loaded precomputed example ({len(_precomputed_text)} chars)")
    return True


def initialize_models():
    """Initialize and cache both models at startup.

    Populates the module-level caches (_qwen_model, _qwen_tokenizer,
    _rwkv_model, _rwkv_tokenizer, _rwkv_model_path, _stats_manager) so
    that run_evaluation() never loads a model per request. Order matters:
    the RWKV checkpoint is downloaded before either model is loaded.
    """
    global _qwen_model, _qwen_tokenizer, _rwkv_model, _rwkv_tokenizer, _rwkv_model_path, _stats_manager

    print("Initializing models...")

    # Load precomputed example first (cheap; best-effort — returns False
    # when the precomputed files are absent)
    load_precomputed_example()

    # Download RWKV model if needed (no-op when already cached on disk)
    print("Checking RWKV7 model...")
    _rwkv_model_path = download_rwkv_model()

    # Load Qwen model
    print("Loading Qwen3-1.7B-Base...")
    _qwen_model, _qwen_tokenizer = load_qwen_model()

    # Load RWKV7 model
    print("Loading RWKV7-G1C-1.5B...")
    _rwkv_model, _rwkv_tokenizer = load_rwkv7_model(_rwkv_model_path)

    # Initialize stats manager (project-local; used for inference-time
    # prediction in run_evaluation)
    from core.inference_stats import InferenceStatsManager
    _stats_manager = InferenceStatsManager()

    print("Models loaded successfully!")


def wrap_html_in_iframe(html: str) -> str:
    """Embed a full HTML document in a sandboxed iframe for Gradio.

    Only double quotes are escaped for the ``srcdoc`` attribute —
    deliberately so: pre-existing HTML entities inside the document
    (e.g. &quot;, &#10;) must pass through untouched.
    """
    srcdoc = html.replace('"', "&quot;")
    return f"""
    <div style="width:100%;height:700px;border:1px solid #ddd;border-radius:8px;overflow:hidden;">
        <iframe srcdoc="{srcdoc}"
                style="width:100%;height:100%;border:none;"
                sandbox="allow-scripts"></iframe>
    </div>
    """


def run_evaluation(text: str, progress=gr.Progress()):
    """Run evaluation on both models and generate visualization.

    Evaluates `text` with both cached models (Qwen3 then RWKV7), records
    inference-time stats, and returns the comparison visualization as
    iframe-wrapped HTML.

    Args:
        text: Raw user input; validated and stripped before use.
        progress: Gradio progress tracker (injected by Gradio at call time).

    Returns:
        str: iframe-wrapped HTML produced by generate_comparison_html().

    Raises:
        gr.Error: On invalid input, GPU OOM, or any evaluation failure.
    """
    from core.evaluator import evaluate_hf_single_sample, evaluate_rwkv7_single_sample
    from visualization.html_generator import generate_comparison_html

    # Use cached models (populated once by initialize_models())
    global _qwen_model, _qwen_tokenizer, _rwkv_model, _rwkv_tokenizer, _stats_manager

    # Validate input
    valid, result = validate_input(text)
    if not valid:
        raise gr.Error(result)

    text = result  # Use cleaned text

    try:
        # Get token counts for prediction first, so progress messages can
        # include an estimated runtime before each model runs
        qwen_inputs = _qwen_tokenizer(text, return_tensors="pt", add_special_tokens=False)
        qwen_token_count = qwen_inputs["input_ids"].shape[-1]
        qwen_predicted_time = _stats_manager.predict_time("qwen", qwen_token_count)

        # NOTE(review): encode() return type varies — handled by checking
        # for an `.ids` attribute, otherwise treating it as a plain sequence
        rwkv_tokenized = _rwkv_tokenizer.encode(text)
        rwkv_token_count = len(rwkv_tokenized.ids if hasattr(rwkv_tokenized, "ids") else rwkv_tokenized)
        rwkv_predicted_time = _stats_manager.predict_time("rwkv", rwkv_token_count)

        # Step 1: Evaluate Qwen (using cached model)
        # predict_time() may return None when no history exists yet
        if qwen_predicted_time is not None:
            progress(0, desc=f"Evaluating with Qwen3... (estimated: {qwen_predicted_time:.1f}s)")
        else:
            progress(0, desc="Evaluating with Qwen3...")

        result_qwen = evaluate_hf_single_sample(_qwen_model, _qwen_tokenizer, text, bos_mode="add_newline_token")

        # Save stats and print comparison (actual vs. predicted runtime)
        _stats_manager.add_record("qwen", qwen_token_count, result_qwen["inference_time"])
        if qwen_predicted_time is not None:
            print(f"Qwen3 completed in {result_qwen['inference_time']:.2f}s (predicted: {qwen_predicted_time:.2f}s)")
        else:
            print(f"Qwen3 completed in {result_qwen['inference_time']:.2f}s")

        # Step 2: Evaluate RWKV7 (using cached model)
        if rwkv_predicted_time is not None:
            progress(0, desc=f"Evaluating with RWKV7... (estimated: {rwkv_predicted_time:.1f}s)")
        else:
            progress(0, desc="Evaluating with RWKV7...")

        result_rwkv = evaluate_rwkv7_single_sample(_rwkv_model, _rwkv_tokenizer, text)

        # Save stats and print comparison
        _stats_manager.add_record("rwkv", rwkv_token_count, result_rwkv["inference_time"])
        if rwkv_predicted_time is not None:
            print(f"RWKV7 completed in {result_rwkv['inference_time']:.2f}s (predicted: {rwkv_predicted_time:.2f}s)")
        else:
            print(f"RWKV7 completed in {result_rwkv['inference_time']:.2f}s")

        # Step 3: Generate visualization — RWKV is model A, Qwen is model B
        progress(0, desc="Generating visualization...")
        html = generate_comparison_html(
            text=text,
            byte_losses_a=result_rwkv["byte_wise_losses"],
            byte_losses_b=result_qwen["byte_wise_losses"],
            model_a_name="RWKV7-G1C-1.5B",
            model_b_name="Qwen3-1.7B-Base",
            topk_predictions_a=result_rwkv["top5_predictions"],
            topk_predictions_b=result_qwen["top5_predictions"],
            tokenizer_a=result_rwkv["tokenizer"],
            tokenizer_b=result_qwen["tokenizer"],
            model_type_a="rwkv7",
            model_type_b="hf",
        )

        # Wrap HTML for iframe display
        wrapped_html = wrap_html_in_iframe(html)

        return wrapped_html

    except torch.cuda.OutOfMemoryError:
        # Free GPU + host memory before surfacing a user-facing error
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        raise gr.Error("GPU memory insufficient. Please try:\n" "1. Use shorter text\n" "2. Wait a moment and try again")
    except Exception as e:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        raise gr.Error(f"Evaluation failed: {str(e)}")


def clear_inputs():
    """Reset the UI: empty the text box and drop the visualization."""
    return ("", None)


def get_default_example():
    """Return the cached example ``(text, wrapped_html)`` for page load.

    Falls back to an empty textbox and no visualization when no
    precomputed example was loaded at startup.
    """
    global _precomputed_html, _precomputed_text

    # Both cache entries must be non-empty for the example to be usable.
    if not (_precomputed_html and _precomputed_text):
        return "", None

    return _precomputed_text, wrap_html_in_iframe(_precomputed_html)


# Build Gradio UI. Layout: header HTML, input column with Clear/Run
# buttons, a separator, then the visualization output.
with gr.Blocks(title="Compression-Lens: RWKV-7 vs Qwen3", theme=gr.themes.Soft()) as demo:
    # Static header with title, description, and badge links
    gr.HTML(
        """
    <div style="text-align: center; margin-bottom: 20px;">
        <h1 style="margin-bottom: 10px;">🔬 Compression-Lens: RWKV-7 vs Qwen3 Byte-Level Comparison</h1>
        <p style="margin-bottom: 15px; color: #666;">Compare the byte-level prediction performance between <strong>RWKV7-G1C-1.5B</strong> and <strong>Qwen3-1.7B-Base</strong>.</p>
        <div style="display: flex; justify-content: center; align-items: center; gap: 10px;">
            <a href="https://github.com/Jellyfish042/uncheatable_eval" target="_blank" style="text-decoration: none;">
                <img src="https://img.shields.io/badge/GitHub-Project-181717?logo=github" alt="GitHub Project">
            </a>
            <a href="https://huggingface.co/spaces/Jellyfish042/UncheatableEval" target="_blank" style="text-decoration: none;">
                <img src="https://img.shields.io/badge/%F0%9F%8F%86%20Leaderboard-Gradio-ff7c00" alt="Leaderboard">
            </a>
        </div>
    </div>
    """
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Free-text input; hard length limit is enforced server-side
            # by validate_input(), the placeholder only advertises it
            text_input = gr.Textbox(
                label="Input Text",
                placeholder=f"Enter text to analyze (max {MAX_TEXT_LENGTH} characters)...",
                lines=10,
                max_lines=20,
            )

            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary")
                run_btn = gr.Button("▶ Run Comparison", variant="primary")

    gr.Markdown("---")

    with gr.Row():
        with gr.Column():
            # Receives the iframe-wrapped comparison HTML
            output_html = gr.HTML(label="Visualization")

    # Event handlers
    clear_btn.click(fn=clear_inputs, outputs=[text_input, output_html])

    run_btn.click(fn=run_evaluation, inputs=[text_input], outputs=[output_html])

    # Load default example on page load (precomputed, so no model call)
    demo.load(fn=get_default_example, outputs=[text_input, output_html])


if __name__ == "__main__":
    # Initialize models before launching the app so the first request
    # never pays the download/load cost
    initialize_models()

    # Launch the Gradio app; bind all interfaces on the standard
    # HF Spaces port 7860, no public share link
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)