Spaces:
Sleeping
Sleeping
| import time | |
| import html | |
| import uuid | |
| import zlib | |
| import json | |
| import torch | |
| import base64 | |
| import threading | |
| import numpy as np | |
| import gradio as gr | |
| from queue import Queue | |
| from user_agents import parse | |
| import plotly.graph_objects as go | |
| # 私有库 | |
| from make_model import make_model | |
| from LazyCache import ExpiringDict | |
| from train_and_use import El_text_continue_stream | |
| from tokenizer import tokenizer,vocab_size,token2str | |
| # 性能分析工具 | |
| # import sys | |
| # import line_profiler | |
| # profiler = line_profiler.LineProfiler() | |
# Select the compute device: CUDA GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the model.
model = make_model(
    # Token ids start at 1 (0 is padding); the extra 255 slots cover all raw-byte fallbacks.
    vocab_size = vocab_size+1+255,
    embedding_dim = 768,
    key_dim = 128,
    head_number = 12,
    position_information_type = "mask",
    enable_affine = True,
    enable_talking_head = True,
    use_diff = False,
    self_attention_block_size = 0,
    feed_forward_dim = 1536,
    enable_layer_norm = True,
    deep = 12,
    dropout_rate = 0.1,
    enable_el_cache = True
).to(device)
# Load the model weight file.
model.load_state_dict(torch.load('large_model_instruct_09291732.weight',map_location=device,weights_only=True))
# model.load_state_dict(torch.load('large_model_instruct_09302228.weight',map_location=device,weights_only=True))
# Switch to evaluation mode (disables dropout etc.).
model = model.eval()
# Global per-session dicts whose entries expire after 550 seconds.
user_queues = ExpiringDict(ttl=550)  # per-session streaming pool: holds the model's latest markdown string;
# the streaming loop replaces the model half of the last entry of user_history_sessions_show[session_id]
# and flushes it straight to the chatbot.
user_queues.start_auto_cleanup()  # start_auto_cleanup() launches a periodic sweep that frees expired entries
user_stop_flags = ExpiringDict(ttl=550)  # per-session output state (True means no streaming output in progress)
user_stop_flags.start_auto_cleanup()
user_history_sessions_show = ExpiringDict(ttl=550)  # per-session chat history for the chatbot; format: list of [user markdown, model markdown]
user_history_sessions_show.start_auto_cleanup()
user_history_sessions_text = ExpiringDict(ttl=550)  # per-session plain-text history, used for log printing
user_history_sessions_text.start_auto_cleanup()
total_querys = ExpiringDict(ttl=550)  # per-session token list, used by the attention visualization
total_querys.start_auto_cleanup()
chatml_idx = ExpiringDict(ttl=550)  # per-session chat-template start/end spans, used to segment regions in the attention view
chatml_idx.start_auto_cleanup()
| # 对于用户输入,分词后讲每个词汇用HTML span标签包裹在矩形中,用于显示美化 | |
| def token_wapper(token): | |
| # 对特殊字符进行HTML转义处理 | |
| escaped_token = html.escape(token) | |
| return f'<span class="token-box">{escaped_token}</span>' | |
| # 对于触发字节回退机制的生僻字,无法直接打印对应的词汇,需要合并包装显示 | |
| def token_split_wapper(token): | |
| escaped_token = html.escape(token) | |
| return f'<span class="multi-token-box">({escaped_token})[多token]</span>' | |
# Turn the user's input into HTML markup that shows the tokenization result.
def process_user_tokens(user_message):
    """Tokenize *user_message* and return the tokens wrapped in display HTML.

    Positive ids decode to printable words; non-positive ids are byte-fallback
    bytes (0 is padding) that must be merged and decoded together.
    """
    # 5.0 selects the deterministic max-probability tokenization path.
    ids = tokenizer(user_message, 5.0)
    pieces = []         # wrapped HTML fragments
    pending_bytes = []  # fallback bytes awaiting a merged decode
    for tid in ids:
        if tid <= 0:
            # Byte-fallback (negative) or padding: defer for merged decoding.
            pending_bytes.append(tid)
            continue
        # Flush any deferred fallback bytes before emitting a printable token.
        if pending_bytes:
            pieces.append(token_split_wapper(token2str(pending_bytes)))
            pending_bytes = []
        pieces.append(token_wapper(token2str([tid])))
    # User input always decodes fully; flush a possible trailing fallback run.
    if pending_bytes:
        pieces.append(token_split_wapper(token2str(pending_bytes)))
    return ''.join(pieces)
| # 文本生成函数,为流式输出的数据池提供数据,通过 session_id 区分用户 | |
| def generate_text(sess, user_message, session_id, temperature, repeat_penalty, max_length, decay): | |
| # 这里立即刷出"⚡处理中⚡"更加人性化 | |
| user_queues[session_id].put("⚡处理中⚡", block=False) | |
| # 模型支持批量处理,但当前场景不需要高度并行,优先考虑响应速度,一个批次就一个用户问题 | |
| tokens_batch = [tokenizer(f"<|im_start|>user {user_message}<|im_end|><|im_start|>assistant ", 5.0)] | |
| # 获取历史长度(不包含当前用户问题) | |
| history_len = len(sess) - 1 | |
| if history_len: | |
| # 追加文本历史,用于日志打印 | |
| user_history_sessions_text[session_id] += f"<|im_start|>user {user_message}<|im_end|><|im_start|>assistant " | |
| # 追加用户输入模板的始末位置,用于在注意力可视化时进行区域分割 | |
| chatml_idx[session_id] += [(len(total_querys[session_id])-len("<|im_end|>"),len(total_querys[session_id])+len("<|im_start|>user "))] | |
| # 追加分词结果,用于在可视化时对应注意力权重 | |
| total_querys[session_id] += tokens_batch[0] | |
| # 追加助手回复模板的始末位置,用于在注意力可视化时进行区域分割 | |
| chatml_idx[session_id] += [(len(total_querys[session_id])-len("<|im_end|><|im_start|>assistant "),len(total_querys[session_id]))] | |
| else: | |
| # 没有历史时,直接赋值 | |
| user_history_sessions_text[session_id] = f"<|im_start|>user {user_message}<|im_end|><|im_start|>assistant " | |
| total_querys[session_id] = tokens_batch[0] | |
| chatml_idx[session_id] = [(0,len("<|im_start|>user "))] | |
| # 这里特殊处理,“助手模板始末位置”拼接在“用户模板始末位置”的后面 | |
| chatml_idx[session_id] += [(len(total_querys[session_id])-len("<|im_end|><|im_start|>assistant "),len(total_querys[session_id]))] | |
| # 讲分词结果映射到模型嵌入层的下标,由于字节回退的结果的结果是负数,+255全部映射成正数 | |
| tokens_batch = np.array(tokens_batch, dtype=np.int64) + 255 | |
| # 将问题转化为模型需要的torch张量格式 | |
| inputs = torch.from_numpy(tokens_batch).to(device).data | |
| # 模型输出 | |
| with torch.no_grad(): | |
| out = "" # 模型输出初始为空 | |
| last_len = -1 # 上次成功解码的位置,初始为-1 | |
| # 通过EL-Attention进行推理优化的流式输出方法 | |
| # 不同用户的KV-Cache用session_id区分 | |
| # history_len==0:直接处理整个输入,重置KV-Cache | |
| # history_len>0:在KV-Cache上增量追加用户问题中的新token | |
| # 注意:返回的o最多只有4个元素(假设了模型自己不会连续生成需要回退的生僻字) | |
| for o in El_text_continue_stream( | |
| model, inputs, out_length=max_length, | |
| repeat_penalty_value=repeat_penalty, | |
| temperature=temperature,decay=decay, | |
| session_id=session_id,history_len=history_len): | |
| # 如果当前位置可以完整解码 | |
| if o[0,-1] > 255: | |
| # 将未解码的部分一起解码 | |
| temp = token2str(o[0][last_len:].cpu().numpy()-255) | |
| total_querys[session_id] += list(o[0][last_len:].cpu().numpy()-255) | |
| out += temp | |
| user_history_sessions_text[session_id] += temp | |
| # 重置解码光标 | |
| last_len = -1 | |
| user_queues[session_id].put(out, block=False) | |
| else: | |
| # 无法解码,光标固定,等最后一起解码 | |
| last_len -= 1 | |
| # 如果用户主动断开连接,停止生成,如果有部分终止标记,统一去除 | |
| if user_stop_flags.get(session_id, True): | |
| if '<' + out.split('<')[-1] in '<|im_end|>': | |
| # 显示的部分去除标记 | |
| out = '<'+'<'.join(out.split('<')[:-1]) | |
| # 历史的部分保留标记 | |
| user_history_sessions_text[session_id] = '<'+'<'.join(user_history_sessions_text[session_id].split('<')[:-1])+'<|im_end|>' | |
| break | |
| # 模型输出了终止标记,直接终止输出 | |
| if '<|im_end|>' in out: | |
| # 显示的部分,去除标记 | |
| out = out.split('<|im_end|>')[0] | |
| user_queues[session_id].put(out, block=False) | |
| break | |
| # 设置输出暂停标记(给流式输出函数发送信号) | |
| user_stop_flags[session_id] = True | |
| # 打印对话内容,便于调试错误 | |
| print(user_history_sessions_text[session_id]) | |
| # 消息按钮处理逻辑:发送消息 / 停止生成 / 会话过期 | |
| def send_message(sess, btn_label, user_message, session_id, temperature, repeat_penalty, max_length, decay): | |
| # 如果发现会话过期了,直接返并在按钮上提示 | |
| if session_id not in user_history_sessions_text: | |
| return "", "会话过期!" | |
| # 按钮文本是消息发送,并且用户消息非空 | |
| if btn_label == "发送消息" and user_message: | |
| # chatbot中回显用户消息 | |
| user_tokens_display = process_user_tokens(user_message) | |
| # chatbot中的会话保存到历史消息中 | |
| user_history_sessions_show[session_id] = sess | |
| # 在历史消息中添加用户消息与空的模型输出 | |
| user_history_sessions_show[session_id] += [[user_tokens_display, ""]] | |
| # 创建文本生成的进程 | |
| thread = threading.Thread(target=generate_text, args=(sess, user_message, session_id, temperature, repeat_penalty, max_length, decay)) | |
| # 设置为主进程退出时退出,避免浪费资源 | |
| thread.daemon = True | |
| # 先设置正在生成的标记 | |
| user_stop_flags[session_id] = False | |
| # 再启动文本生成进行 | |
| thread.start() | |
| # 清空输入框,更新按钮文本为停止生成 | |
| return "", "停止生成" | |
| else: | |
| # 会话没过期,也不是发送消息,那就是终止生成,通过种植标记发送信号 | |
| user_stop_flags[session_id] = True | |
| # 这里要保留用户输入的文本,同时更新按钮文本为“发送消息” | |
| return user_message, "发送消息" | |
# Clear-session handler.
def clear_session():
    """Empty the chat history and zero both attention sliders to avoid stale indices."""
    # Profiling stats can be printed here while debugging:
    # profiler.print_stats(stream=sys.stdout)
    cleared_chat = []
    zeroed = dict(maximum=0, value=0)
    return cleared_chat, gr.update(**zeroed), gr.update(**zeroed)
| # 通过不断检查模型输出队列并更新chatbot来模拟流式输出 | |
| def stream_output(sess): | |
| global user_queues, user_stop_flags, user_history_sessions_show, user_history_sessions_text | |
| # 页面加载时初始化 session_id 作为用户的唯一标识 | |
| session_id = str(uuid.uuid4()) | |
| # 创建用于流式输出的消息队列 | |
| user_queues[session_id] = Queue() | |
| # 默认是停止状态,等待用户发送消息 | |
| user_stop_flags[session_id] = True | |
| user_history_sessions_show[session_id] = [] # 初始化历史会话记录,用于chatbot | |
| user_history_sessions_text[session_id] = "" # 初始化历史会话记录,用于日志打印 | |
| chatml_idx[session_id] = [] # 会话模板始末位置,用于在注意力可视化时分割区域 | |
| # 设置页面的初始状态 | |
| yield [], "发送消息", session_id, gr.update(maximum=0), gr.update(maximum=0) | |
| # 流式输出循环 | |
| t = time.time() | |
| while True: | |
| # 处理队列中的消息 | |
| if not user_queues[session_id].empty(): | |
| t = time.time() | |
| # 取到最后一个加入的数据 | |
| while user_queues[session_id].qsize() > 1: | |
| user_queues[session_id].get() | |
| out = user_queues[session_id].get() | |
| # 用队列中最新的的模型输出替换当前的模型输出 | |
| sess = user_history_sessions_show[session_id] | |
| sess[-1][1] = out | |
| # 更新按钮文本为“停止生成”,如果发现输出暂停了就恢复为“发送消息” | |
| button_label = "停止生成" if not user_stop_flags.get(session_id, True) else "发送消息" | |
| # 更新注意力可视化部分的滑动条最大长度 | |
| if model.encoder.encoder_layers[-1].multi_head_attention.cnt != None and session_id in model.encoder.encoder_layers[-1].multi_head_attention.cnt: | |
| new_max_tokens = model.encoder.encoder_layers[-1].multi_head_attention.cnt[session_id] | |
| yield sess, button_label, session_id, gr.update(maximum=new_max_tokens), gr.update(maximum=new_max_tokens) | |
| else: | |
| yield sess, button_label, session_id, gr.update(maximum=0), gr.update(maximum=0) | |
| else: | |
| time.sleep(0.1) # 防止 busy-wait 占满 CPU | |
| # 长期没有输出需求,退出循环 | |
| if time.time() - t > 600: | |
| break | |
# Attention-visualization components.
def create_attention_visualization():
    """Build the desktop attention-visualization panel; returns its 4 components."""
    gr.Markdown("# 注意力可视化", elem_classes="title")
    # Control row with the two sliders.
    with gr.Row():
        token_slider = gr.Slider(
            0, 0, value=0, step=1,
            label="查看第几个token"
        )
        highlight_slider = gr.Slider(
            0, 0, value=0, step=1,
            label="高亮token位置"
        )
    # Progress / context display.
    info_box = gr.Textbox(
        label="进度显示",
        lines=2,
        max_lines=2,
        autoscroll=False,
        interactive=False
    )
    # Single merged heat-map container (rendered client-side into this div).
    combined_plot = gr.HTML(
        value="<div id='heatmap-container' style='height:600px;width:100%;'></div>",
        show_label=False,
        container=False
    )
    return token_slider, highlight_slider, info_box, combined_plot
# Build the merged heat-map (data only — no rendering configuration).
def create_combined_heatmap(session_id, token_idx=0, highlight_col=None):
    """Assemble merged attention heat-map data for one session.

    Returns a plain-data dict (matrix, token labels, shape, highlight/region
    metadata) for client-side rendering, or None when any layer has no
    attention recorded for this session.
    """
    # Validate that every layer holds attention data for this session.
    for layer in model.encoder.encoder_layers:
        attn = layer.multi_head_attention.attention_matrix
        # FIX: compare against None with `is`, not `==` (PEP 8; also avoids an
        # ambiguous truth value should attention_matrix ever be a numpy array).
        if attn is None or session_id not in attn:
            return None
    # Sequence length recorded for this session (`cnt` is the last token index — assumed; confirm against model code).
    max_len = max(0, model.encoder.encoder_layers[-1].multi_head_attention.cnt[session_id] + 1)
    token_idx = min(token_idx, max_len)
    # Combined matrix: 12 layers x 12 heads + 11 separator rows = 155 rows.
    combined_matrix = []
    # Token labels; negative ids are byte-fallback tokens.
    if session_id in total_querys:
        tokens = [(token2str([token_id]) if token_id > 0 else f'byte{-token_id}')
                  for token_id in total_querys[session_id]]
        tokens = (tokens + [''] * max_len)[:max_len]
    else:
        tokens = [''] * max_len
    # Stack layers top (layer 11) down to bottom (layer 0).
    for layer_n, layer in enumerate(reversed(model.encoder.encoder_layers)):
        layer_n = 11 - layer_n  # actual layer index
        attn = layer.multi_head_attention.attention_matrix
        # Add this layer's 12 heads (reversed so head 11 sits on top).
        if attn is not None and session_id in attn:
            for head_i in range(11, -1, -1):
                head_data = (list(attn[session_id][head_i][token_idx]) + [0] * max_len)[:max_len]
                combined_matrix.append(head_data)
        # Separator row between layers (not after the bottom one).
        if layer_n > 0:
            combined_matrix.append([1.0] * max_len)
    # Pure data dictionary only — no figure objects.
    return {
        'matrix': combined_matrix,
        'tokens': tokens,
        'shape': [len(combined_matrix), max_len],
        'highlight_col': highlight_col,
        'token_idx': token_idx,
        'chatml_regions': chatml_idx.get(session_id, [])
    }
# Build the info text and the compressed heat-map payload for the current slider positions.
def update_attention_visualization(session_id, token_idx, highlight_col):
    # Info line: highlighted token plus the decoded prefix up to the current token.
    if session_id in total_querys:
        info = f"[高亮token:{token2str([total_querys[session_id][highlight_col]])}]{token2str(total_querys[session_id][:token_idx+1])}"
    else:
        info = ""
    data_dict = create_combined_heatmap(session_id, token_idx, highlight_col)
    if data_dict and data_dict['matrix']:
        # Backend: quantize values in the log domain.
        matrix_np = np.array(data_dict['matrix'], dtype=np.float32)
        # Map into log10 domain; epsilon avoids log(0).
        matrix_log = np.log10(matrix_np + 1e-8)
        # Normalize the log range to 0~1 (assumes attention weights lie in 1e-8..1).
        matrix_norm = (matrix_log - (-8)) / (0 - (-8))  # map -8..0 onto 0..1
        matrix_quantized = (matrix_norm * 65535).astype(np.uint16)
        compressed = base64.b64encode(zlib.compress(matrix_quantized.tobytes())).decode('ascii')
        # Ship only the minimal payload; the client decompresses and renders it.
        # NOTE(review): tokens are embedded via json.dumps inside single-quoted HTML
        # attributes; a token containing a single quote would break the attribute —
        # consider html-escaping the JSON. TODO confirm against the real vocabulary.
        html_content = f"""
        <div id="heatmap-container" style="height:600px;width:100%;"></div>
        <script id="heatmap-data"
            data-compressed="{compressed}"
            data-shape='{json.dumps(data_dict["shape"])}'
            data-tokens='{json.dumps(data_dict["tokens"])}'
            data-highlight-col='{data_dict["highlight_col"]}'
            data-token-idx='{data_dict["token_idx"]}'
            data-chatml-regions='{json.dumps(data_dict["chatml_regions"])}'>
        </script>
        """
        return html_content, info
    return "<div id='heatmap-container' style='height:600px;width:100%;'></div><script id='heatmap-data'></script>", info
# Device-type detection from the User-Agent header.
def check_device(request: gr.Request):
    """Return True when the parsed user-agent string identifies a PC client."""
    agent = parse(request.headers.get("user-agent", ""))
    agent_text = str(agent)
    print(agent_text)
    # `"PC" in s[:2]` in the original is exactly a startswith("PC") check.
    return agent_text.startswith("PC")
def toggle_attention_visibility(is_desktop):
    """Show or hide the attention-visualization column based on device type."""
    visibility_update = gr.update(visible=is_desktop)
    return visibility_update
# UI styling (CSS injected into the Gradio page; string content left verbatim).
css = """
/* 大标题居中 */
.title {
    text-align: center;
}
/* 高级设置按键样式 */
#adv-param button {
    justify-content: center;
}
#adv-param > button > span {
    font-size: 16px !important;
    font-weight: 600 !important;
}
/* 用户输入包装样式 */
.token-box {
    display: inline-block;
    background-color: #f0f0f0;
    border: 1px solid #ddd;
    border-radius: 4px;
    padding: 2px 4px;
    margin: 2px;
    font-family: monospace;
}
.multi-token-box {
    display: inline-block;
    background-color: #e6f7ff;
    border: 1px solid #91d5ff;
    border-radius: 4px;
    padding: 2px 4px;
    margin: 2px;
    font-family: monospace;
}
/* 隐藏悬浮触发的plotly logo */
.js-plotly-plot .plotly .modebar {
    display: none !important;
}
/* 隐藏小垃圾桶 */
.icon-button-wrapper button[aria-label="清空对话"] {
    display: none !important;
}
/* 避免图像更新时闪烁 */
.html-container.padding.pending {
    opacity: 1 !important;
    filter: none !important;
}
"""
# Main UI construction. The head string injects pako + Plotly and defines the
# client-side renderer window.renderCompressedHeatmap (JS content left verbatim).
# profiler.add_function(create_combined_heatmap)
with gr.Blocks(css=css, theme=gr.themes.Default(), head='''
<link rel="icon" href="data:image/svg+xml,<svg xmlns=%22http://www.w3.org/2000/svg%22 viewBox=%220 0 100 100%22><text y=%22.9em%22 font-size=%2290%22>🔬</text></svg>">
<script src="https://cdnjs.cloudflare.com/ajax/libs/pako/2.1.0/pako.min.js"></script>
<script src="https://cdn.plot.ly/plotly-2.24.1.min.js"></script>
<script>
// 全局渲染函数 - 完整版本
window.renderCompressedHeatmap = function(compressedData, tokens, highlightCol, tokenIdx, chatmlRegions) {
    // console.log('开始渲染压缩数据');
    if (!compressedData || !compressedData.data) {
        console.error('没有压缩数据');
        return;
    }
    try {
        // Base64解码
        const binaryStr = atob(compressedData.data);
        const bytes = new Uint8Array(binaryStr.length);
        for (let i = 0; i < binaryStr.length; i++) {
            bytes[i] = binaryStr.charCodeAt(i);
        }
        // zlib解压
        const decompressed = pako.inflate(bytes);
        const shape = compressedData.shape;
        // 前端:还原对数域
        // 关键修改:使用Uint16Array而不是Float32Array
        const uint16Array = new Uint16Array(decompressed.buffer);
        const float32Array = new Float32Array(uint16Array.length);
        for (let i = 0; i < uint16Array.length; i++) {
            const logValue = (uint16Array[i] / 65535.0) * 8 - 8; // 从0~1映射回-8~0
            float32Array[i] = Math.pow(10, logValue);
        }
        // 重建矩阵
        const matrix = [];
        for (let i = 0; i < shape[0]; i++) {
            matrix.push(Array.from(float32Array.slice(i * shape[1], (i + 1) * shape[1])));
        }
        // 创建自定义悬停信息
        const customData = [];
        for (let i = 0; i < shape[0]; i++) {
            const rowData = [];
            for (let j = 0; j < shape[1]; j++) {
                const layer = Math.floor(i / 13);
                const head = 11 - (i % 13); // 反转head顺序
                const token = tokens && tokens[j] ? tokens[j] : '';
                rowData.push(`Head:${head}<br>Token:${token}`);
            }
            customData.push(rowData);
        }
        // 创建矩形框
        const shapes = [];
        // 2. 区域分割矩形(对话模板)
        const colors = ["green", "blue"];
        if (chatmlRegions && chatmlRegions.length > 0) {
            chatmlRegions.forEach((region, i) => {
                if (region && region.length === 2) {
                    shapes.push({
                        type: 'rect',
                        x0: region[0] - 0.5,
                        x1: region[1] - 0.5,
                        y0: -0.5,
                        y1: shape[0] - 0.5,
                        line: {
                            color: colors[i % 2],
                            width: 2
                        },
                        fillcolor: 'rgba(0,0,0,0)'
                    });
                    // console.log(`添加区域分割矩形 ${i}:`, region, '颜色:', colors[i % 2]);
                }
            });
        }
        // 1. 高亮矩形
        if (highlightCol !== null && highlightCol !== undefined && highlightCol <= tokenIdx) {
            shapes.push({
                type: 'rect',
                x0: highlightCol - 0.5,
                x1: highlightCol + 0.5,
                y0: -0.5,
                y1: shape[0] - 0.5,
                line: {
                    color: 'white',
                    width: 1
                },
                fillcolor: 'rgba(0,0,0,0)'
            });
            // console.log('添加高亮矩形,位置:', highlightCol);
        }
        // 渲染
        const container = document.getElementById('heatmap-container');
        if (container) {
            Plotly.purge(container);
            const trace = {
                z: matrix,
                type: 'heatmap',
                colorscale: [
                    [0.0, '#000000'],
                    [0.04, '#4d0000'],
                    [0.16, '#cc0000'],
                    [0.36, '#ff8800'],
                    [0.64, '#ffff00'],
                    [1.0, '#ffffff']
                ],
                showscale: false,
                customdata: customData,
                hovertemplate: '%{customdata}<br>Position:%{x}<br>Attention:%{z:.3f}<extra></extra>',
                hoverongaps: false
            };
            const layout = {
                margin: {l: 30, r: 5, t: 0, b: 0},
                height: 600,
                yaxis: {
                    tickmode: 'array',
                    tickvals: Array.from({length: 12}, (_, i) => i * 13 + 6),
                    ticktext: Array.from({length: 12}, (_, i) => `L${11 - i}`),
                    showgrid: false,
                    zeroline: false,
                    fixedrange: true
                },
                xaxis: {
                    showticklabels: false,
                    showgrid: false,
                    zeroline: false,
                    fixedrange: true
                },
                shapes: shapes
            };
            Plotly.react(container, [trace], layout).then(() => {
                // console.log('热力图渲染完成!包含:',
                //     shapes.length, '个矩形框',
                //     '- 高亮:', highlightCol !== null ? '有' : '无',
                //     '- 区域:', chatmlRegions ? chatmlRegions.length : 0
                // );
            });
        }
    } catch (error) {
        console.error('渲染失败:', error);
    }
}
</script>
''') as demo:
    with gr.Row():
        # Left side: the original chat functionality.
        # NOTE(review): nesting reconstructed from context (indentation was lost
        # in extraction) — the second Column is assumed to sit inside the first.
        with gr.Column():
            gr.Markdown("# 0.18B中文大语言模型", elem_classes="title")
            chatbot = gr.Chatbot(
                label="对话",
                autoscroll=False,
                show_copy_button=True,
                height=400,
                type="tuples"
            )
            with gr.Column():
                msg = gr.Textbox(
                    placeholder="请输入你的问题...",
                    label="",
                    lines=3,
                    show_label=False
                )
                with gr.Row():
                    send_btn = gr.Button("发送消息")
                    clear_btn = gr.Button("清空会话")
                # Parameter-settings area.
                with gr.Accordion("高级参数设置", open=False, elem_id="adv-param"):
                    with gr.Row():
                        temperature = gr.Slider(0.0001, 3.0001, value=0.0001, step=0.1, label="Temperature")
                        repeat_penalty = gr.Slider(0.0, 5.0, value=1.3, step=0.1, label="Repeat Penalty")
                    with gr.Row():
                        max_length = gr.Slider(64, 8192, value=512, step=64, label="Max Length")
                        decay = gr.Slider(0.90, 1.0, value=0.99, step=0.005, label="Repeat Penalty Decay Rate")
        with gr.Column(visible=False) as attention_col:
            # Right side: attention visualization — now returns 4 components.
            (token_slider, highlight_slider,
             info_box, combined_plot) = create_attention_visualization()
    # State storage: per-tab session id.
    session_id = gr.State()
    # Bind chat interaction events.
    send_btn.click(
        send_message,
        inputs=[chatbot, send_btn, msg, session_id, temperature, repeat_penalty, max_length, decay],
        outputs=[msg, send_btn],
        show_progress="hidden"
    )
    clear_btn.click(
        clear_session,
        inputs=[],
        outputs=[chatbot, token_slider, highlight_slider],
        show_progress="hidden"
    )
    # Bind attention-visualization interaction events.
    # Changing the current-token slider: recompute data server-side, then render client-side.
    token_event = token_slider.change(
        update_attention_visualization,
        inputs=[session_id, token_slider, highlight_slider],
        outputs=[combined_plot, info_box],
        show_progress="hidden",
        api_name=False
    ).then(
        None,
        None,
        None,
        js="""
        function() {
            // console.log('数据更新完成,触发渲染');
            const dataScript = document.getElementById('heatmap-data');
            if (dataScript && dataScript.dataset.compressed) {
                const compressedData = {
                    data: dataScript.dataset.compressed,
                    shape: JSON.parse(dataScript.dataset.shape)
                };
                const tokens = JSON.parse(dataScript.dataset.tokens || '[]');
                const highlightCol = dataScript.dataset.highlightCol ? parseInt(dataScript.dataset.highlightCol) : null;
                const tokenIdx = dataScript.dataset.tokenIdx ? parseInt(dataScript.dataset.tokenIdx) : 0;
                const chatmlRegions = JSON.parse(dataScript.dataset.chatmlRegions || '[]');
                // console.log('找到压缩数据,开始渲染');
                if (window.renderCompressedHeatmap) {
                    window.renderCompressedHeatmap(compressedData, tokens, highlightCol, tokenIdx, chatmlRegions);
                }
            } else {
                // console.log('未找到压缩数据');
            }
            return [];
        }
        """
    )
    # Changing the highlighted token: handled fully client-side, no Python round-trip.
    highlight_event = highlight_slider.change(
        None,  # no Python function is called
        inputs=[highlight_slider, info_box],  # slider value + current info_box content
        outputs=[info_box],  # updated info_box
        js="""
        function(highlightCol, currentInfo) {
            const dataScript = document.getElementById('heatmap-data');
            if (!dataScript) return [currentInfo]; // 返回原内容
            // 更新高亮字段
            dataScript.dataset.highlightCol = highlightCol;
            // 获取必要信息
            const tokens = JSON.parse(dataScript.dataset.tokens || '[]');
            const shape = JSON.parse(dataScript.dataset.shape || '[0,0]');
            const tokenIdx = dataScript.dataset.tokenIdx ? parseInt(dataScript.dataset.tokenIdx) : 0;
            const chatmlRegions = JSON.parse(dataScript.dataset.chatmlRegions || '[]');
            let newInfo = currentInfo;
            // 1. 更新info_box内容
            if (tokens[highlightCol]) {
                newInfo = currentInfo.replace(
                    /\\[高亮token:[^\\]]*\\]/,
                    `[高亮token:${tokens[highlightCol]}]`
                );
            }
            // 2. 更新热力图矩形框
            const container = document.getElementById('heatmap-container');
            if (container) {
                const shapes = [];
                const totalRows = shape[0];
                // 区域分割矩形
                const colors = ["green", "blue"];
                chatmlRegions.forEach((region, i) => {
                    if (region && region.length === 2) {
                        shapes.push({
                            type: 'rect',
                            x0: region[0] - 0.5, x1: region[1] - 0.5,
                            y0: -0.5, y1: totalRows - 0.5,
                            line: { color: colors[i % 2], width: 2 },
                            fillcolor: 'rgba(0,0,0,0)'
                        });
                    }
                });
                // 高亮矩形
                if (highlightCol !== null && highlightCol !== undefined && highlightCol <= tokenIdx) {
                    shapes.push({
                        type: 'rect',
                        x0: highlightCol - 0.5, x1: highlightCol + 0.5,
                        y0: -0.5, y1: totalRows - 0.5,
                        line: { color: 'white', width: 1 },
                        fillcolor: 'rgba(0,0,0,0)'
                    });
                }
                // 只更新布局中的矩形框
                Plotly.relayout(container, { shapes: shapes });
            }
            return [newInfo]; // 返回更新后的内容
        }
        """,
        show_progress="hidden",
        api_name=False
    )
    # State holding the detected device type.
    is_pc = gr.State()
    # On page load: detect the device, toggle the attention panel, then start the streaming loop.
    demo.load(
        check_device,
        inputs=[],
        outputs=[is_pc],
        show_progress="hidden"
    ).then(
        toggle_attention_visibility,
        inputs=[is_pc],
        outputs=[attention_col],
        show_progress="hidden"
    ).then(
        stream_output,
        inputs=[chatbot],
        outputs=[chatbot, send_btn, session_id, token_slider, highlight_slider],
        show_progress="hidden"
    )
if __name__ == "__main__":
    # Large queue limits so many concurrent sessions can stream simultaneously.
    demo.queue(
        max_size=128,
        default_concurrency_limit=128,
        api_open=False
    )
    demo.launch(
        share=False,
        show_error=True,
        show_api=False
    )