File size: 5,623 Bytes
15b2f1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
Precompute example evaluation results for the default demo.

This script runs the evaluation on the example text and saves the results
so they can be loaded instantly when users visit the page.
"""

import json
import os
import sys
from pathlib import Path

# Put this script's own directory (not its parent) at the front of the import
# path, presumably so the local `core` / `visualization` packages imported in
# precompute_example() resolve regardless of the current working directory.
sys.path.insert(0, str(Path(__file__).parent))

import torch

# Filesystem layout, all anchored at the directory that contains this script.
SCRIPT_DIR = Path(__file__).parent.absolute()
MODELS_DIR = SCRIPT_DIR.joinpath("models")  # downloaded checkpoints
SUPPORT_DIR = SCRIPT_DIR.joinpath("support")  # tokenizer vocab files
PRECOMPUTED_DIR = SCRIPT_DIR.joinpath("precomputed")  # generated demo outputs

# Which checkpoints are compared.
QWEN_MODEL_ID = "Qwen/Qwen3-1.7B-Base"
RWKV_MODEL_FILENAME = "rwkv7-g1c-1.5b-20260110-ctx8192.pth"

# Run on CUDA when a GPU is visible, otherwise fall back to CPU.
IS_CPU = not torch.cuda.is_available()
DEVICE = "cpu" if IS_CPU else "cuda"


def download_rwkv_model():
    """Return a local path to the RWKV7 checkpoint, downloading it if missing.

    The file ``RWKV_MODEL_FILENAME`` is looked up in ``MODELS_DIR`` first; only
    when it is absent is it fetched from the ``BlinkDL/rwkv7-g1`` Hugging Face
    repository into that directory.

    Returns:
        str: Filesystem path of the ``.pth`` checkpoint.
    """
    from huggingface_hub import hf_hub_download

    model_path = MODELS_DIR / RWKV_MODEL_FILENAME
    if model_path.exists():
        return str(model_path)

    MODELS_DIR.mkdir(parents=True, exist_ok=True)

    # `local_dir_use_symlinks` is intentionally not passed: it is deprecated (and
    # ignored with a warning) in recent huggingface_hub releases, which always
    # place real files in `local_dir`.
    return hf_hub_download(
        repo_id="BlinkDL/rwkv7-g1",
        filename=RWKV_MODEL_FILENAME,
        local_dir=str(MODELS_DIR),
    )


def load_qwen_model():
    """Load the Qwen3-1.7B-Base causal LM and its tokenizer, in eval mode.

    On CPU the weights are loaded in float32 with no device map and
    ``low_cpu_mem_usage``. On GPU they are loaded in bfloat16 with
    ``device_map="auto"``. FlashAttention-2 is tried first, and the default
    attention implementation is used if that load raises.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID, trust_remote_code=True)

    if IS_CPU:
        model_kwargs = {"torch_dtype": torch.float32, "device_map": None, "trust_remote_code": True, "low_cpu_mem_usage": True}
        model = AutoModelForCausalLM.from_pretrained(QWEN_MODEL_ID, **model_kwargs)
        return model.eval(), tokenizer

    model_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto", "trust_remote_code": True}
    model = None
    try:
        model = AutoModelForCausalLM.from_pretrained(QWEN_MODEL_ID, attn_implementation="flash_attention_2", **model_kwargs)
    except Exception as exc:
        # Deliberately broad: FlashAttention-2 is a best-effort optimization
        # (e.g. the flash_attn package may be missing or the GPU unsupported).
        # Report why instead of failing silently.
        print(f"FlashAttention-2 unavailable ({type(exc).__name__}: {exc}); using default attention.")

    if model is None:
        # Retry outside the `except` clause so the failed attempt's traceback
        # (and anything its frames reference) is released before reloading.
        model = AutoModelForCausalLM.from_pretrained(QWEN_MODEL_ID, **model_kwargs)

    return model.eval(), tokenizer


def load_rwkv7_model(model_path: str):
    """Build the RWKV7-G1C-1.5B model and its TRIE tokenizer.

    The ``RWKV_*`` environment variables are set before ``rwkv`` is imported,
    presumably because the package reads them at import time. CUDA kernels and
    fp16 are selected on GPU; fp32 is used on CPU.

    Args:
        model_path: Path to the checkpoint. A trailing ``.pth`` is stripped before
            the path is handed to ``RWKV`` (presumably because ``RWKV`` appends
            the extension itself — confirm against the installed rwkv version).

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    on_gpu = not IS_CPU

    os.environ["RWKV_JIT_ON"] = "1"
    os.environ["RWKV_V7_ON"] = "1"
    os.environ["RWKV_CUDA_ON"] = "1" if on_gpu else "0"

    from rwkv.model import RWKV
    from rwkv.rwkv_tokenizer import TRIE_TOKENIZER

    checkpoint = model_path[: -len(".pth")] if model_path.endswith(".pth") else model_path
    model = RWKV(model=checkpoint, strategy="cuda fp16" if on_gpu else "cpu fp32")

    tokenizer = TRIE_TOKENIZER(str(SUPPORT_DIR / "rwkv_vocab_v20230424.txt"))
    return model, tokenizer


def _read_example_text():
    """Return the contents of ``the_bitter_lesson.txt`` located next to this script."""
    example_file = SCRIPT_DIR / "the_bitter_lesson.txt"
    with open(example_file, "r", encoding="utf-8") as f:
        return f.read()


def _free_accelerator_memory():
    """Collect garbage and, when running on CUDA, release cached allocator blocks."""
    import gc

    gc.collect()
    if not IS_CPU:
        torch.cuda.empty_cache()


def _save_outputs(html, metadata):
    """Write the HTML visualization and the JSON metadata into ``PRECOMPUTED_DIR``."""
    PRECOMPUTED_DIR.mkdir(parents=True, exist_ok=True)

    html_path = PRECOMPUTED_DIR / "example_visualization.html"
    with open(html_path, "w", encoding="utf-8") as f:
        f.write(html)
    print(f"Saved HTML to {html_path}")

    metadata_path = PRECOMPUTED_DIR / "example_metadata.json"
    with open(metadata_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, ensure_ascii=False, indent=2)
    print(f"Saved metadata to {metadata_path}")


def precompute_example():
    """Evaluate both models on the example text and save the demo artifacts.

    Steps: read ``the_bitter_lesson.txt`` and ensure the RWKV checkpoint is
    present (it may be downloaded). Then evaluate Qwen3 and RWKV7 one after the
    other and render a side-by-side HTML comparison (RWKV as model A, Qwen as
    model B). Finally write ``example_visualization.html`` and
    ``example_metadata.json`` into ``PRECOMPUTED_DIR``.
    """
    from core.evaluator import evaluate_hf_single_sample, evaluate_rwkv7_single_sample
    from visualization.html_generator import generate_comparison_html

    example_text = _read_example_text()
    print(f"Example text length: {len(example_text)} characters")

    # Fetch the checkpoint first so network problems surface before the slow model loads.
    print("Downloading RWKV model if needed...")
    rwkv_model_path = download_rwkv_model()

    # Load and evaluate the models one at a time, dropping each model before
    # loading the next, so peak memory holds a single model rather than both.
    print("Loading Qwen3-1.7B-Base...")
    qwen_model, qwen_tokenizer = load_qwen_model()
    print("Evaluating with Qwen3...")
    result_qwen = evaluate_hf_single_sample(qwen_model, qwen_tokenizer, example_text, bos_mode="add_newline_token")
    print(f"Qwen3 completed in {result_qwen['inference_time']:.2f}s")
    # NOTE(review): the weights are only actually freed if `result_qwen` does not
    # keep a reference to the model — confirm in core.evaluator.
    del qwen_model
    _free_accelerator_memory()

    print("Loading RWKV7-G1C-1.5B...")
    rwkv_model, rwkv_tokenizer = load_rwkv7_model(rwkv_model_path)
    print("Evaluating with RWKV7...")
    result_rwkv = evaluate_rwkv7_single_sample(rwkv_model, rwkv_tokenizer, example_text)
    print(f"RWKV7 completed in {result_rwkv['inference_time']:.2f}s")
    del rwkv_model
    _free_accelerator_memory()

    print("Generating visualization...")
    html = generate_comparison_html(
        text=example_text,
        byte_losses_a=result_rwkv["byte_wise_losses"],
        byte_losses_b=result_qwen["byte_wise_losses"],
        model_a_name="RWKV7-G1C-1.5B",
        model_b_name="Qwen3-1.7B-Base",
        topk_predictions_a=result_rwkv["top5_predictions"],
        topk_predictions_b=result_qwen["top5_predictions"],
        tokenizer_a=result_rwkv["tokenizer"],
        tokenizer_b=result_qwen["tokenizer"],
        model_type_a="rwkv7",
        model_type_b="hf",
    )

    metadata = {
        "example_text": example_text,
        "qwen_inference_time": result_qwen["inference_time"],
        "rwkv_inference_time": result_rwkv["inference_time"],
        "qwen_compression_rate": result_qwen["compression_rate"],
        "rwkv_compression_rate": result_rwkv["compression_rate"],
    }
    _save_outputs(html, metadata)

    print("Done! Precomputed example is ready.")


if __name__ == "__main__":
    # Executed directly (not imported): regenerate the precomputed demo assets.
    precompute_example()