zh_0.18B_LLM / app.py
mdokl's picture
更新移动端识别逻辑
d9a5c17 verified
import time
import html
import uuid
import zlib
import json
import torch
import base64
import threading
import numpy as np
import gradio as gr
from queue import Queue
from user_agents import parse
import plotly.graph_objects as go
# 私有库
from make_model import make_model
from LazyCache import ExpiringDict
from train_and_use import El_text_continue_stream
from tokenizer import tokenizer,vocab_size,token2str
# 性能分析工具
# import sys
# import line_profiler
# profiler = line_profiler.LineProfiler()
# 计算设备检查
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 创建模型
model = make_model(
#token是从1开始的,0填充,剩下的用来覆盖全部字节
vocab_size = vocab_size+1+255,
embedding_dim = 768,
key_dim = 128,
head_number = 12,
position_information_type = "mask",
enable_affine = True,
enable_talking_head = True,
use_diff = False,
self_attention_block_size = 0,
feed_forward_dim = 1536,
enable_layer_norm = True,
deep = 12,
dropout_rate = 0.1,
enable_el_cache = True
).to(device)
# 加载模型权重文件
model.load_state_dict(torch.load('large_model_instruct_09291732.weight',map_location=device,weights_only=True))
# model.load_state_dict(torch.load('large_model_instruct_09302228.weight',map_location=device,weights_only=True))
# 设置为评估模式(关闭dropout等)
model = model.eval()
# 元素保质期550秒的全局字典
user_queues = ExpiringDict(ttl=550) # session_id对应的流式输出数据池,元素为模型输出的最新markdown_string,
# 替换user_history_sessions_show[session_id]最后一个元素的模型输出部分直接刷到chatbot
user_queues.start_auto_cleanup() # start_auto_cleanup()用于启动检查,定期释放超时的数据
user_stop_flags = ExpiringDict(ttl=550) # session_id对应的输出状态表示(True 表示未在进行流式输出)
user_stop_flags.start_auto_cleanup()
user_history_sessions_show = ExpiringDict(ttl=550) # session_id对应的元组形式会话列表,用于chatbot显示,格式:list([用户提问:markdown_string,模型输出:markdown_string])
user_history_sessions_show.start_auto_cleanup()
user_history_sessions_text = ExpiringDict(ttl=550) # session_id对应的文本会话历史,用于打印日志
user_history_sessions_text.start_auto_cleanup()
total_querys = ExpiringDict(ttl=550) # session_id对应的token列表,在注意力可视化
total_querys.start_auto_cleanup()
chatml_idx = ExpiringDict(ttl=550) # session_id对应的对话模板起始位置,在注意力可视化时进行区域分割
chatml_idx.start_auto_cleanup()
# 对于用户输入,分词后讲每个词汇用HTML span标签包裹在矩形中,用于显示美化
def token_wapper(token):
# 对特殊字符进行HTML转义处理
escaped_token = html.escape(token)
return f'<span class="token-box">{escaped_token}</span>'
# 对于触发字节回退机制的生僻字,无法直接打印对应的词汇,需要合并包装显示
def token_split_wapper(token):
escaped_token = html.escape(token)
return f'<span class="multi-token-box">({escaped_token})[多token]</span>'
# 将用户的输入转化为用于展示分词结果的HTML标记
def process_user_tokens(user_message):
# 5.0代表使用非随机的最大概率路径分词算法分词,产生整数token_id列表
user_tokens = tokenizer(user_message, 5.0)
# 将token还原并进行安全包装
words = [] # token列表
temp = [] # 要合并显示的特殊字节
for token in user_tokens:
if token > 0: # 字节回退对应的token_id是负数,0代表填充,剩余的是可以打印的词汇
if len(temp): # 在加入前,需要将可能存在的回退字节先打包解码加入
words.append(token_split_wapper(token2str(temp)))
temp = []
# 讲解码后的词汇加入列表
words.append(token_wapper(token2str([token])))
else:
# 讲回退字节送去合并
temp.append(token)
# 用输入如的一定可以正确解码,这里要对可能存在的回退字节打包。
if len(temp):
words.append(token_split_wapper(token2str(temp)))
# 返回包装好HTML标记形式的分词结果
return ''.join(words)
# 文本生成函数,为流式输出的数据池提供数据,通过 session_id 区分用户
def generate_text(sess, user_message, session_id, temperature, repeat_penalty, max_length, decay):
# 这里立即刷出"⚡处理中⚡"更加人性化
user_queues[session_id].put("⚡处理中⚡", block=False)
# 模型支持批量处理,但当前场景不需要高度并行,优先考虑响应速度,一个批次就一个用户问题
tokens_batch = [tokenizer(f"<|im_start|>user {user_message}<|im_end|><|im_start|>assistant ", 5.0)]
# 获取历史长度(不包含当前用户问题)
history_len = len(sess) - 1
if history_len:
# 追加文本历史,用于日志打印
user_history_sessions_text[session_id] += f"<|im_start|>user {user_message}<|im_end|><|im_start|>assistant "
# 追加用户输入模板的始末位置,用于在注意力可视化时进行区域分割
chatml_idx[session_id] += [(len(total_querys[session_id])-len("<|im_end|>"),len(total_querys[session_id])+len("<|im_start|>user "))]
# 追加分词结果,用于在可视化时对应注意力权重
total_querys[session_id] += tokens_batch[0]
# 追加助手回复模板的始末位置,用于在注意力可视化时进行区域分割
chatml_idx[session_id] += [(len(total_querys[session_id])-len("<|im_end|><|im_start|>assistant "),len(total_querys[session_id]))]
else:
# 没有历史时,直接赋值
user_history_sessions_text[session_id] = f"<|im_start|>user {user_message}<|im_end|><|im_start|>assistant "
total_querys[session_id] = tokens_batch[0]
chatml_idx[session_id] = [(0,len("<|im_start|>user "))]
# 这里特殊处理,“助手模板始末位置”拼接在“用户模板始末位置”的后面
chatml_idx[session_id] += [(len(total_querys[session_id])-len("<|im_end|><|im_start|>assistant "),len(total_querys[session_id]))]
# 讲分词结果映射到模型嵌入层的下标,由于字节回退的结果的结果是负数,+255全部映射成正数
tokens_batch = np.array(tokens_batch, dtype=np.int64) + 255
# 将问题转化为模型需要的torch张量格式
inputs = torch.from_numpy(tokens_batch).to(device).data
# 模型输出
with torch.no_grad():
out = "" # 模型输出初始为空
last_len = -1 # 上次成功解码的位置,初始为-1
# 通过EL-Attention进行推理优化的流式输出方法
# 不同用户的KV-Cache用session_id区分
# history_len==0:直接处理整个输入,重置KV-Cache
# history_len>0:在KV-Cache上增量追加用户问题中的新token
# 注意:返回的o最多只有4个元素(假设了模型自己不会连续生成需要回退的生僻字)
for o in El_text_continue_stream(
model, inputs, out_length=max_length,
repeat_penalty_value=repeat_penalty,
temperature=temperature,decay=decay,
session_id=session_id,history_len=history_len):
# 如果当前位置可以完整解码
if o[0,-1] > 255:
# 将未解码的部分一起解码
temp = token2str(o[0][last_len:].cpu().numpy()-255)
total_querys[session_id] += list(o[0][last_len:].cpu().numpy()-255)
out += temp
user_history_sessions_text[session_id] += temp
# 重置解码光标
last_len = -1
user_queues[session_id].put(out, block=False)
else:
# 无法解码,光标固定,等最后一起解码
last_len -= 1
# 如果用户主动断开连接,停止生成,如果有部分终止标记,统一去除
if user_stop_flags.get(session_id, True):
if '<' + out.split('<')[-1] in '<|im_end|>':
# 显示的部分去除标记
out = '<'+'<'.join(out.split('<')[:-1])
# 历史的部分保留标记
user_history_sessions_text[session_id] = '<'+'<'.join(user_history_sessions_text[session_id].split('<')[:-1])+'<|im_end|>'
break
# 模型输出了终止标记,直接终止输出
if '<|im_end|>' in out:
# 显示的部分,去除标记
out = out.split('<|im_end|>')[0]
user_queues[session_id].put(out, block=False)
break
# 设置输出暂停标记(给流式输出函数发送信号)
user_stop_flags[session_id] = True
# 打印对话内容,便于调试错误
print(user_history_sessions_text[session_id])
# 消息按钮处理逻辑:发送消息 / 停止生成 / 会话过期
def send_message(sess, btn_label, user_message, session_id, temperature, repeat_penalty, max_length, decay):
# 如果发现会话过期了,直接返并在按钮上提示
if session_id not in user_history_sessions_text:
return "", "会话过期!"
# 按钮文本是消息发送,并且用户消息非空
if btn_label == "发送消息" and user_message:
# chatbot中回显用户消息
user_tokens_display = process_user_tokens(user_message)
# chatbot中的会话保存到历史消息中
user_history_sessions_show[session_id] = sess
# 在历史消息中添加用户消息与空的模型输出
user_history_sessions_show[session_id] += [[user_tokens_display, ""]]
# 创建文本生成的进程
thread = threading.Thread(target=generate_text, args=(sess, user_message, session_id, temperature, repeat_penalty, max_length, decay))
# 设置为主进程退出时退出,避免浪费资源
thread.daemon = True
# 先设置正在生成的标记
user_stop_flags[session_id] = False
# 再启动文本生成进行
thread.start()
# 清空输入框,更新按钮文本为停止生成
return "", "停止生成"
else:
# 会话没过期,也不是发送消息,那就是终止生成,通过种植标记发送信号
user_stop_flags[session_id] = True
# 这里要保留用户输入的文本,同时更新按钮文本为“发送消息”
return user_message, "发送消息"
# 清空会话方法
def clear_session():
# 调试时可以在这里打印性能分析结果
# profiler.print_stats(stream=sys.stdout)
# 注意力可视化的参数全部清零,避免异常
return [], gr.update(maximum=0, value=0), gr.update(maximum=0, value=0)
# 通过不断检查模型输出队列并更新chatbot来模拟流式输出
def stream_output(sess):
global user_queues, user_stop_flags, user_history_sessions_show, user_history_sessions_text
# 页面加载时初始化 session_id 作为用户的唯一标识
session_id = str(uuid.uuid4())
# 创建用于流式输出的消息队列
user_queues[session_id] = Queue()
# 默认是停止状态,等待用户发送消息
user_stop_flags[session_id] = True
user_history_sessions_show[session_id] = [] # 初始化历史会话记录,用于chatbot
user_history_sessions_text[session_id] = "" # 初始化历史会话记录,用于日志打印
chatml_idx[session_id] = [] # 会话模板始末位置,用于在注意力可视化时分割区域
# 设置页面的初始状态
yield [], "发送消息", session_id, gr.update(maximum=0), gr.update(maximum=0)
# 流式输出循环
t = time.time()
while True:
# 处理队列中的消息
if not user_queues[session_id].empty():
t = time.time()
# 取到最后一个加入的数据
while user_queues[session_id].qsize() > 1:
user_queues[session_id].get()
out = user_queues[session_id].get()
# 用队列中最新的的模型输出替换当前的模型输出
sess = user_history_sessions_show[session_id]
sess[-1][1] = out
# 更新按钮文本为“停止生成”,如果发现输出暂停了就恢复为“发送消息”
button_label = "停止生成" if not user_stop_flags.get(session_id, True) else "发送消息"
# 更新注意力可视化部分的滑动条最大长度
if model.encoder.encoder_layers[-1].multi_head_attention.cnt != None and session_id in model.encoder.encoder_layers[-1].multi_head_attention.cnt:
new_max_tokens = model.encoder.encoder_layers[-1].multi_head_attention.cnt[session_id]
yield sess, button_label, session_id, gr.update(maximum=new_max_tokens), gr.update(maximum=new_max_tokens)
else:
yield sess, button_label, session_id, gr.update(maximum=0), gr.update(maximum=0)
else:
time.sleep(0.1) # 防止 busy-wait 占满 CPU
# 长期没有输出需求,退出循环
if time.time() - t > 600:
break
# 注意力可视化组件
def create_attention_visualization():
"""创建桌面端注意力可视化界面"""
gr.Markdown("# 注意力可视化", elem_classes="title")
# 固定在底部的控制区域
with gr.Row():
token_slider = gr.Slider(
0, 0, value=0, step=1,
label="查看第几个token"
)
highlight_slider = gr.Slider(
0, 0, value=0, step=1,
label="高亮token位置"
)
# 可视化信息显示
info_box = gr.Textbox(
label="进度显示",
lines=2,
max_lines=2,
autoscroll=False,
interactive=False
)
# 单一的合并热力图
combined_plot = gr.HTML(
value="<div id='heatmap-container' style='height:600px;width:100%;'></div>",
show_label=False,
container=False
)
return token_slider, highlight_slider, info_box, combined_plot
# 创建合并的热力图(只有数据)
def create_combined_heatmap(session_id, token_idx=0, highlight_col=None):
"""创建合并的单一热力图数据,不包含任何渲染配置"""
# 检查数据有效性
for layer in model.encoder.encoder_layers:
if (layer.multi_head_attention.attention_matrix == None or
session_id not in layer.multi_head_attention.attention_matrix):
return None
# 获取最大长度
max_len = max(0, model.encoder.encoder_layers[-1].multi_head_attention.cnt[session_id] + 1)
token_idx = min(token_idx, max_len)
# 构建合并矩阵:12个head层 + 11个分隔层 = 155行
combined_matrix = []
# 准备token标签
if session_id in total_querys:
tokens = [(token2str([token_id]) if token_id > 0 else f'byte{-token_id}')
for token_id in total_querys[session_id]]
tokens = (tokens + [''] * max_len)[:max_len]
else:
tokens = [''] * max_len
# 准备数据
for layer_n, layer in enumerate(reversed(model.encoder.encoder_layers)):
layer_n = 11 - layer_n # 计算实际的层索引
# 添加当前层的12个head
if (layer.multi_head_attention.attention_matrix and
session_id in layer.multi_head_attention.attention_matrix):
for head_i in range(11, -1, -1):
head_data = (list(layer.multi_head_attention.attention_matrix[session_id][head_i][token_idx]) + [0] * max_len)[:max_len]
combined_matrix.append(head_data)
# 添加分隔层
if layer_n > 0:
separator = [1.0] * max_len
combined_matrix.append(separator)
# 返回纯数据字典,不包含任何图形对象
return {
'matrix': combined_matrix,
'tokens': tokens,
'shape': [len(combined_matrix), max_len],
'highlight_col': highlight_col,
'token_idx': token_idx,
'chatml_regions': chatml_idx.get(session_id, [])
}
# 当前信息可视化
def update_attention_visualization(session_id, token_idx, highlight_col):
if session_id in total_querys:
info = f"[高亮token:{token2str([total_querys[session_id][highlight_col]])}]{token2str(total_querys[session_id][:token_idx+1])}"
else:
info = ""
data_dict = create_combined_heatmap(session_id, token_idx, highlight_col)
if data_dict and data_dict['matrix']:
# 后端:将对数值量化
matrix_np = np.array(data_dict['matrix'], dtype=np.float32)
# 将对数值映射到0~65535
matrix_log = np.log10(matrix_np + 1e-8) # 避免log(0)
# 将对数范围映射到0~1(假设注意力权重在1e-8到1之间)
matrix_norm = (matrix_log - (-8)) / (0 - (-8)) # 从-8到0映射到0~1
matrix_quantized = (matrix_norm * 65535).astype(np.uint16)
compressed = base64.b64encode(zlib.compress(matrix_quantized.tobytes())).decode('ascii')
# 只传输最基础的数据
html_content = f"""
<div id="heatmap-container" style="height:600px;width:100%;"></div>
<script id="heatmap-data"
data-compressed="{compressed}"
data-shape='{json.dumps(data_dict["shape"])}'
data-tokens='{json.dumps(data_dict["tokens"])}'
data-highlight-col='{data_dict["highlight_col"]}'
data-token-idx='{data_dict["token_idx"]}'
data-chatml-regions='{json.dumps(data_dict["chatml_regions"])}'>
</script>
"""
return html_content, info
return "<div id='heatmap-container' style='height:600px;width:100%;'></div><script id='heatmap-data'></script>", info
# 检测设备类型
def check_device(request: gr.Request):
ua_string = parse(request.headers.get("user-agent", ""))
print(str(ua_string))
return "PC" in str(ua_string)[:2]
def toggle_attention_visibility(is_desktop):
"""根据设备类型切换注意力可视化可见性"""
return gr.update(visible=is_desktop)
# UI美化
css = """
/* 大标题居中 */
.title {
text-align: center;
}
/* 高级设置按键样式 */
#adv-param button {
justify-content: center;
}
#adv-param > button > span {
font-size: 16px !important;
font-weight: 600 !important;
}
/* 用户输入包装样式 */
.token-box {
display: inline-block;
background-color: #f0f0f0;
border: 1px solid #ddd;
border-radius: 4px;
padding: 2px 4px;
margin: 2px;
font-family: monospace;
}
.multi-token-box {
display: inline-block;
background-color: #e6f7ff;
border: 1px solid #91d5ff;
border-radius: 4px;
padding: 2px 4px;
margin: 2px;
font-family: monospace;
}
/* 隐藏悬浮触发的plotly logo */
.js-plotly-plot .plotly .modebar {
display: none !important;
}
/* 隐藏小垃圾桶 */
.icon-button-wrapper button[aria-label="清空对话"] {
display: none !important;
}
/* 避免图像更新时闪烁 */
.html-container.padding.pending {
opacity: 1 !important;
filter: none !important;
}
"""
# 主UI构建
# profiler.add_function(create_combined_heatmap)
with gr.Blocks(css=css, theme=gr.themes.Default(), head='''
<link rel="icon" href="data:image/svg+xml,<svg xmlns=%22http://www.w3.org/2000/svg%22 viewBox=%220 0 100 100%22><text y=%22.9em%22 font-size=%2290%22>🔬</text></svg>">
<script src="https://cdnjs.cloudflare.com/ajax/libs/pako/2.1.0/pako.min.js"></script>
<script src="https://cdn.plot.ly/plotly-2.24.1.min.js"></script>
<script>
// 全局渲染函数 - 完整版本
window.renderCompressedHeatmap = function(compressedData, tokens, highlightCol, tokenIdx, chatmlRegions) {
// console.log('开始渲染压缩数据');
if (!compressedData || !compressedData.data) {
console.error('没有压缩数据');
return;
}
try {
// Base64解码
const binaryStr = atob(compressedData.data);
const bytes = new Uint8Array(binaryStr.length);
for (let i = 0; i < binaryStr.length; i++) {
bytes[i] = binaryStr.charCodeAt(i);
}
// zlib解压
const decompressed = pako.inflate(bytes);
const shape = compressedData.shape;
// 前端:还原对数域
// 关键修改:使用Uint16Array而不是Float32Array
const uint16Array = new Uint16Array(decompressed.buffer);
const float32Array = new Float32Array(uint16Array.length);
for (let i = 0; i < uint16Array.length; i++) {
const logValue = (uint16Array[i] / 65535.0) * 8 - 8; // 从0~1映射回-8~0
float32Array[i] = Math.pow(10, logValue);
}
// 重建矩阵
const matrix = [];
for (let i = 0; i < shape[0]; i++) {
matrix.push(Array.from(float32Array.slice(i * shape[1], (i + 1) * shape[1])));
}
// 创建自定义悬停信息
const customData = [];
for (let i = 0; i < shape[0]; i++) {
const rowData = [];
for (let j = 0; j < shape[1]; j++) {
const layer = Math.floor(i / 13);
const head = 11 - (i % 13); // 反转head顺序
const token = tokens && tokens[j] ? tokens[j] : '';
rowData.push(`Head:${head}<br>Token:${token}`);
}
customData.push(rowData);
}
// 创建矩形框
const shapes = [];
// 2. 区域分割矩形(对话模板)
const colors = ["green", "blue"];
if (chatmlRegions && chatmlRegions.length > 0) {
chatmlRegions.forEach((region, i) => {
if (region && region.length === 2) {
shapes.push({
type: 'rect',
x0: region[0] - 0.5,
x1: region[1] - 0.5,
y0: -0.5,
y1: shape[0] - 0.5,
line: {
color: colors[i % 2],
width: 2
},
fillcolor: 'rgba(0,0,0,0)'
});
// console.log(`添加区域分割矩形 ${i}:`, region, '颜色:', colors[i % 2]);
}
});
}
// 1. 高亮矩形
if (highlightCol !== null && highlightCol !== undefined && highlightCol <= tokenIdx) {
shapes.push({
type: 'rect',
x0: highlightCol - 0.5,
x1: highlightCol + 0.5,
y0: -0.5,
y1: shape[0] - 0.5,
line: {
color: 'white',
width: 1
},
fillcolor: 'rgba(0,0,0,0)'
});
// console.log('添加高亮矩形,位置:', highlightCol);
}
// 渲染
const container = document.getElementById('heatmap-container');
if (container) {
Plotly.purge(container);
const trace = {
z: matrix,
type: 'heatmap',
colorscale: [
[0.0, '#000000'],
[0.04, '#4d0000'],
[0.16, '#cc0000'],
[0.36, '#ff8800'],
[0.64, '#ffff00'],
[1.0, '#ffffff']
],
showscale: false,
customdata: customData,
hovertemplate: '%{customdata}<br>Position:%{x}<br>Attention:%{z:.3f}<extra></extra>',
hoverongaps: false
};
const layout = {
margin: {l: 30, r: 5, t: 0, b: 0},
height: 600,
yaxis: {
tickmode: 'array',
tickvals: Array.from({length: 12}, (_, i) => i * 13 + 6),
ticktext: Array.from({length: 12}, (_, i) => `L${11 - i}`),
showgrid: false,
zeroline: false,
fixedrange: true
},
xaxis: {
showticklabels: false,
showgrid: false,
zeroline: false,
fixedrange: true
},
shapes: shapes
};
Plotly.react(container, [trace], layout).then(() => {
// console.log('热力图渲染完成!包含:',
// shapes.length, '个矩形框',
// '- 高亮:', highlightCol !== null ? '有' : '无',
// '- 区域:', chatmlRegions ? chatmlRegions.length : 0
// );
});
}
} catch (error) {
console.error('渲染失败:', error);
}
}
</script>
''') as demo:
with gr.Row():
with gr.Column():
gr.Markdown("# 0.18B中文大语言模型", elem_classes="title")
# 左侧:原有功能
chatbot = gr.Chatbot(
label="对话",
autoscroll=False,
show_copy_button=True,
height=400,
type="tuples"
)
with gr.Column():
msg = gr.Textbox(
placeholder="请输入你的问题...",
label="",
lines=3,
show_label=False
)
with gr.Row():
send_btn = gr.Button("发送消息")
clear_btn = gr.Button("清空会话")
# 参数设置区域
with gr.Accordion("高级参数设置", open=False, elem_id="adv-param"):
with gr.Row():
temperature = gr.Slider(0.0001, 3.0001, value=0.0001, step=0.1, label="Temperature")
repeat_penalty = gr.Slider(0.0, 5.0, value=1.3, step=0.1, label="Repeat Penalty")
with gr.Row():
max_length = gr.Slider(64, 8192, value=512, step=64, label="Max Length")
decay = gr.Slider(0.90, 1.0, value=0.99, step=0.005, label="Repeat Penalty Decay Rate")
with gr.Column(visible=False) as attention_col:
# 右侧:注意力可视化 - 现在只返回4个值
(token_slider, highlight_slider,
info_box, combined_plot) = create_attention_visualization()
# 状态存储
session_id = gr.State()
# 绑定聊天交互事件
send_btn.click(
send_message,
inputs=[chatbot, send_btn, msg, session_id, temperature, repeat_penalty, max_length, decay],
outputs=[msg, send_btn],
show_progress="hidden"
)
clear_btn.click(
clear_session,
inputs=[],
outputs=[chatbot, token_slider, highlight_slider],
show_progress="hidden"
)
# 绑定注意力可视化交互事件
# 修改当前token
token_event = token_slider.change(
update_attention_visualization,
inputs=[session_id, token_slider, highlight_slider],
outputs=[combined_plot, info_box],
show_progress="hidden",
api_name=False
).then(
None,
None,
None,
js="""
function() {
// console.log('数据更新完成,触发渲染');
const dataScript = document.getElementById('heatmap-data');
if (dataScript && dataScript.dataset.compressed) {
const compressedData = {
data: dataScript.dataset.compressed,
shape: JSON.parse(dataScript.dataset.shape)
};
const tokens = JSON.parse(dataScript.dataset.tokens || '[]');
const highlightCol = dataScript.dataset.highlightCol ? parseInt(dataScript.dataset.highlightCol) : null;
const tokenIdx = dataScript.dataset.tokenIdx ? parseInt(dataScript.dataset.tokenIdx) : 0;
const chatmlRegions = JSON.parse(dataScript.dataset.chatmlRegions || '[]');
// console.log('找到压缩数据,开始渲染');
if (window.renderCompressedHeatmap) {
window.renderCompressedHeatmap(compressedData, tokens, highlightCol, tokenIdx, chatmlRegions);
}
} else {
// console.log('未找到压缩数据');
}
return [];
}
"""
)
# 修改高亮token
highlight_event = highlight_slider.change(
None, # 不调用Python函数
inputs=[highlight_slider, info_box], # 输入:滑块值 + 当前info_box内容
outputs=[info_box], # 输出:更新后的info_box
js="""
function(highlightCol, currentInfo) {
const dataScript = document.getElementById('heatmap-data');
if (!dataScript) return [currentInfo]; // 返回原内容
// 更新高亮字段
dataScript.dataset.highlightCol = highlightCol;
// 获取必要信息
const tokens = JSON.parse(dataScript.dataset.tokens || '[]');
const shape = JSON.parse(dataScript.dataset.shape || '[0,0]');
const tokenIdx = dataScript.dataset.tokenIdx ? parseInt(dataScript.dataset.tokenIdx) : 0;
const chatmlRegions = JSON.parse(dataScript.dataset.chatmlRegions || '[]');
let newInfo = currentInfo;
// 1. 更新info_box内容
if (tokens[highlightCol]) {
newInfo = currentInfo.replace(
/\\[高亮token:[^\\]]*\\]/,
`[高亮token:${tokens[highlightCol]}]`
);
}
// 2. 更新热力图矩形框
const container = document.getElementById('heatmap-container');
if (container) {
const shapes = [];
const totalRows = shape[0];
// 区域分割矩形
const colors = ["green", "blue"];
chatmlRegions.forEach((region, i) => {
if (region && region.length === 2) {
shapes.push({
type: 'rect',
x0: region[0] - 0.5, x1: region[1] - 0.5,
y0: -0.5, y1: totalRows - 0.5,
line: { color: colors[i % 2], width: 2 },
fillcolor: 'rgba(0,0,0,0)'
});
}
});
// 高亮矩形
if (highlightCol !== null && highlightCol !== undefined && highlightCol <= tokenIdx) {
shapes.push({
type: 'rect',
x0: highlightCol - 0.5, x1: highlightCol + 0.5,
y0: -0.5, y1: totalRows - 0.5,
line: { color: 'white', width: 1 },
fillcolor: 'rgba(0,0,0,0)'
});
}
// 只更新布局中的矩形框
Plotly.relayout(container, { shapes: shapes });
}
return [newInfo]; // 返回更新后的内容
}
""",
show_progress="hidden",
api_name=False
)
# 存储设备类型
is_pc = gr.State()
# 加载时执行
demo.load(
check_device,
inputs=[],
outputs=[is_pc],
show_progress="hidden"
).then(
toggle_attention_visibility,
inputs=[is_pc],
outputs=[attention_col],
show_progress="hidden"
).then(
stream_output,
inputs=[chatbot],
outputs=[chatbot, send_btn, session_id, token_slider, highlight_slider],
show_progress="hidden"
)
if __name__ == "__main__":
demo.queue(
max_size=128,
default_concurrency_limit=128,
api_open=False
)
demo.launch(
share=False,
show_error=True,
show_api=False
)