File size: 3,262 Bytes
9a0518c
53fc829
9a0518c
 
53fc829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da2a63c
53fc829
 
9a0518c
53fc829
c800b92
4dc666a
 
53fc829
9a0518c
 
7dff699
 
9a0518c
 
 
53fc829
9a0518c
 
 
 
 
53fc829
9a0518c
 
 
 
 
 
7dff699
9a0518c
 
 
 
 
 
53fc829
 
 
 
 
9a0518c
53fc829
9a0518c
53fc829
 
 
 
 
9a0518c
53fc829
9a0518c
 
 
7dff699
 
 
 
 
 
 
9a0518c
 
 
53fc829
9a0518c
53fc829
 
 
 
 
 
 
 
 
 
 
 
 
 
9a0518c
 
 
4dc666a
 
da2a63c
4dc666a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
ROSA-QKV-1bit interactive demo application - main program
"""

import gradio as gr
from constants import ROSA_CODE, ROSA_QUICK_CODE, GRADIO_MAJOR, MAX_DEMO_LEN
from code_highlighting import build_plain_code_html
from code_builder import build_code_html
from algorithm import initialize_line_numbers
from ui_styles import CSS
from javascript_handler import JS_FUNC, get_js_boot
from ui_handlers import on_demo, on_random

# Initialize the line-number mapping for ROSA_CODE. The result is passed to
# build_code_html below and to on_demo each time the demo handler runs.
LINE_NUMBERS = initialize_line_numbers(ROSA_CODE)

# Build the code HTML: the main listing (built together with LINE_NUMBERS) and
# the plain "quick" listing (second argument "rosa-code-quick" -- presumably an
# element id/class; confirm in code_highlighting). Both are embedded in the
# code pane of the page.
CODE_HTML = build_code_html(ROSA_CODE, LINE_NUMBERS)
QUICK_CODE_HTML = build_plain_code_html(ROSA_QUICK_CODE, "rosa-code-quick")

# Get the JavaScript boot code derived from JS_FUNC; it is handed to Gradio
# through the `js` argument.
JS_BOOT = get_js_boot(JS_FUNC)

# Create the Gradio application
demo_context = gr.Blocks(js=JS_BOOT, css=CSS)

with demo_context as demo:
    # Page title and subtitle
    page_header_html = "".join(
        [
            '<div class="page-header">',
            '<div class="page-title">RWKV-8 ROSA-QKV-1bit Demo</div>',
            '<div class="page-subtitle">This is using naive algorithm (not suffix automaton). Enter or randomize q/k/v (0/1 only), then click [Start Demo].</div>',
            "</div>",
        ]
    )
    gr.HTML(page_header_html)

    # Input row: the three bit sequences
    with gr.Row():
        q_text = gr.Textbox(value="01010101010101010101", label="q sequence", lines=1)
        k_text = gr.Textbox(value="10100110100110100110", label="k sequence", lines=1)
        v_text = gr.Textbox(value="11001100110011001100", label="v sequence", lines=1)

    # Control row: random length, buttons, playback speed, theme switch
    with gr.Row():
        length = gr.Slider(
            minimum=4, maximum=20, value=20, step=1, label="Random length"
        )
        random_btn = gr.Button("Randomize")
        demo_btn = gr.Button("Start Demo", variant="primary")
        speed = gr.Slider(
            minimum=0.1,
            maximum=10.0,
            value=2.0,
            step=0.05,
            label="Playback speed",
            elem_id="speed_slider",
            interactive=True,
        )
        theme_toggle = gr.Checkbox(
            value=False,
            label="Dark mode",
            elem_id="theme_toggle",
        )

    # Outputs: the visible result box plus the steps payload box.
    out_text = gr.Textbox(label="Output", interactive=False)
    # The steps box stays visible=True but carries the "rosa-hidden" class
    # (presumably hidden via CSS so the front-end script can read it -- confirm
    # in ui_styles / javascript_handler).
    steps_json = gr.Textbox(
        elem_id="steps_json",
        elem_classes=["rosa-hidden"],
        visible=True,
    )

    # Visualization pane and code pane
    shell_html = (
        '<div id="rosa-shell" class="rosa-shell">'
        '<div class="rosa-pane"><div id="rosa-vis"></div></div>'
        f'<div class="rosa-code-pane">{CODE_HTML}'
        '<details class="quick-code-details">'
        "<summary>Fast version (click to expand)</summary>"
        f"{QUICK_CODE_HTML}</details>"
        "</div></div>"
    )
    gr.HTML(shell_html)

    # Bind the event handlers
    random_btn.click(
        fn=on_random,
        inputs=[length],
        outputs=[q_text, k_text, v_text],
    )

    def on_demo_with_lines(q, k, v):
        """Run on_demo on the three sequences, supplying the module-level LINE_NUMBERS."""
        return on_demo(q, k, v, LINE_NUMBERS)

    # The same handler runs on the button click and on page load; register
    # the click first, then the load, each with its own fresh component lists.
    for register in (demo_btn.click, demo.load):
        register(
            on_demo_with_lines,
            inputs=[q_text, k_text, v_text],
            outputs=[steps_json, out_text],
        )


if __name__ == "__main__":
    # For GRADIO_MAJOR >= 6 the CSS, the JS boot code and ssr_mode=False are
    # passed to launch(); otherwise launch() is called without extra arguments.
    launch_kwargs = (
        {"css": CSS, "js": JS_BOOT, "ssr_mode": False} if GRADIO_MAJOR >= 6 else {}
    )
    demo.launch(**launch_kwargs)