Spaces:
Build error
Build error
| from __future__ import annotations | |
| import glob | |
| import io | |
| import os | |
| import random | |
| import struct | |
| from contextlib import contextmanager | |
| from html import escape | |
| import msgpack | |
| import streamlit as st | |
| import torch | |
| import tqdm | |
| from huggingface_hub import HfFileSystem | |
| from transformers import AutoTokenizer | |
st.set_page_config(layout="wide")

# Configuration, overridable through environment variables.
# Model whose viewer data lives in the "datasets/{MODEL_NAME}-viewer-data" repo.
MODEL_NAME = os.getenv("MODEL_NAME", "MonetLLM/codemonet-vd-1.4B-100BT-hf")
# Context rendered around each activating token: up to CONTEXT_WINDOW tokens
# before it and CONTEXT_WINDOW - 1 tokens after it.
CONTEXT_WINDOW = int(os.getenv("CONTEXT_WINDOW", "12"))
# Experts with strictly more than this many stored examples are eligible for
# the "Random" button.
CANDIDATE_THRESHOLD = int(os.getenv("CANDIDATE_THRESHOLD", "50"))
# CSS injected by st_horizontal(). It does two things:
# - hides every element container that holds a `.hide-element` (the <style>
#   tag itself and the invisible marker span);
# - turns any vertical block containing a `.horizontal-marker` into a
#   wrapping flex row, so its children (e.g. buttons) sit side by side.
HORIZONTAL_STYLE = """<style class="hide-element">
/* Hides the style container and removes the extra spacing */
.element-container:has(.hide-element) {
display: none;
}
/*
The selector for >.element-container is necessary to avoid selecting the whole
body of the streamlit app, which is also a stVerticalBlock.
*/
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) {
display: flex;
flex-direction: row !important;
flex-wrap: wrap;
gap: 0.5rem;
align-items: baseline;
}
/* Buttons and their parent container all have a width of 704px, which we need to override */
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) div {
width: max-content !important;
}
/* Just an example of how you would style buttons, if desired */
/*
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) button {
border-color: red;
}
*/
</style>"""
def _read_offset_table(path: str):
    """Return the msgpack object stored at the tail of the file at ``path``.

    Expected tail layout: the encoded table, followed by a 4-byte big-endian
    unsigned integer giving the table's encoded size in bytes, ending the file.
    """
    with open(path, "rb") as fp:
        fp.seek(-4, io.SEEK_END)
        (table_size,) = struct.unpack(">I", fp.read(4))
        fp.seek(-(table_size + 4), io.SEEK_END)
        return msgpack.Unpacker(fp).unpack()


# Streamlit re-executes this whole script on every widget interaction; without
# caching, every click would re-download/re-parse all files and reload the
# tokenizer. cache_resource keeps one shared copy per server process.
@st.cache_resource
def prepare_routing_resources():
    """Download the viewer data for ``MODEL_NAME`` and load it into memory.

    Files from ``datasets/{MODEL_NAME}-viewer-data`` are fetched into the
    working directory unless a file with the same basename already exists.

    Returns:
        ``(input_tokens, examples_tables, routing_tables, candidates, tokenizer)``:

        - input_tokens: whatever ``torch.load("inputs.pt")`` yields (the
          renderer indexes it as ``input_tokens[i, j]`` and uses ``.size(1)``).
        - examples_tables: per routing group ``i``, the tail table of
          ``examples-{i}.msgpack`` (offsets indexed by expert id).
        - routing_tables: per routing group ``i``, the tail table of
          ``routings-{i}.msgpack``.
        - candidates: per routing group, the expert ids that have strictly
          more than ``CANDIDATE_THRESHOLD`` stored examples.
        - tokenizer: ``AutoTokenizer`` for ``MODEL_NAME``.
    """
    fs = HfFileSystem()
    for filename in fs.glob(f"datasets/{MODEL_NAME}-viewer-data/*"):
        if not os.path.exists(os.path.basename(filename)):
            print(f"[*] Download {filename}...")
            fs.download(filename, ".")

    input_tokens = torch.load("inputs.pt")

    # Assumes the example files are numbered contiguously from 0.
    num_groups = len(glob.glob("examples-*.msgpack"))
    examples_tables = [
        _read_offset_table(f"examples-{i}.msgpack") for i in tqdm.trange(num_groups)
    ]

    candidates = []
    for i, table in enumerate(tqdm.tqdm(examples_tables)):
        group_candidates = []
        with open(f"examples-{i}.msgpack", "rb") as fp:
            # Assumes the per-expert example lists are stored back to back
            # from the start of the file, in expert-id order.
            unpacker = msgpack.Unpacker(fp)
            for j in range(len(table)):
                if len(unpacker.unpack()) > CANDIDATE_THRESHOLD:
                    group_candidates.append(j)
        candidates.append(group_candidates)

    routing_tables = [
        _read_offset_table(f"routings-{i}.msgpack")
        for i in tqdm.trange(len(examples_tables))
    ]

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return input_tokens, examples_tables, routing_tables, candidates, tokenizer
# Load the viewer data and tokenizer once at module level; the HTML renderer
# and the UI below read these names as globals.
(
    input_tokens,
    examples_tables,
    routing_tables,
    candidates,
    tokenizer,
) = prepare_routing_resources()
def render_routing_examples_in_html(router_index: int, expert_id: int) -> str:
    """Render the stored activating examples of one expert as an HTML table.

    Args:
        router_index: Routing group; selects ``examples-{router_index}.msgpack``,
            ``routings-{router_index}.msgpack`` and their offset tables.
        expert_id: Expert index within that routing group.

    Returns:
        An HTML fragment with one row per stored example ``(i, j)``. Each row
        shows:

        - the token at position ``j`` of input ``i`` with its routing score;
        - the surrounding context, with each span shaded green at opacity
          equal to its routing score for ``expert_id``;
        - the ``(i, j)`` coordinates.
    """
    # The expert's example list starts at the offset recorded in the tail table.
    with open(f"examples-{router_index}.msgpack", "rb") as fp:
        fp.seek(examples_tables[router_index][expert_id])
        examples = msgpack.Unpacker(fp).unpack()

    table = []
    with open(f"routings-{router_index}.msgpack", "rb") as fp:
        for i, j, _ in examples:
            # CONTEXT_WINDOW tokens before j, up to CONTEXT_WINDOW - 1 after it.
            start = max(j - CONTEXT_WINDOW, 0)
            end = min(j + CONTEXT_WINDOW, len(routing_tables[router_index][i]))

            # Per-token routing maps are stored consecutively from the offset
            # of token `start`. They are looked up with the integer expert_id,
            # which requires strict_map_key=False.
            fp.seek(routing_tables[router_index][i][start])
            unpacker = msgpack.Unpacker(fp, strict_map_key=False)
            activated = [unpacker.unpack().get(expert_id, 0) for _ in range(start, end)]

            # Re-tokenize the decoded text to obtain character offsets. `offset`
            # presumably compensates for tokens added/removed by the
            # decode -> encode round trip.
            full_text = tokenizer.decode(input_tokens[i])
            encodings = tokenizer(full_text, add_special_tokens=False)
            offset = len(encodings.input_ids) - input_tokens.size(1)

            spans, lslice = [], None
            for k in range(start, end):
                if offset + k >= 0 and (sslice := encodings.token_to_chars(offset + k)):
                    span, score = full_text[slice(*sslice)], activated[k - start]
                    # Consecutive tokens mapping to the same characters are
                    # merged into one span carrying the highest score.
                    if lslice == sslice:
                        score = max(spans.pop(-1)[1], score)
                    spans.append((escape(span), score))
                    lslice = sslice

            # Fix: the rgba(...) value was missing its closing parenthesis,
            # which made the CSS invalid so no highlighting was ever shown.
            context_html = "".join(
                f"<span style='background-color: rgba(144, 238, 144, {score})' "
                f"title='Routing: {score*100:.2f}%'>{span}</span>"
                for span, score in spans
            )
            table.append(
                f"""
<tr>
<td align='right'>
<span style='font-weight: bold'>
{escape(tokenizer.decode(input_tokens[i, j]))} ({activated[j - start] * 100:.2f}%)
</span>
</td>
<td align='left'>
(...) {context_html} (...)
</td>
<td align='right'>
({i}, {j})
</td>
</tr>
"""
            )
    rows_html = "".join(table)
    return f"""
<div style='background-color: white; color: black; padding: 1em 3em; font-size: 12pt'>
<h2 style='font-size: 18pt'> Activated Examples of Group {router_index} / Expert {expert_id} </h2>
<table>
{rows_html}
</table>
</div>
"""
@contextmanager
def st_horizontal():
    """Lay out the Streamlit elements created inside the ``with`` body in a row.

    Injects ``HORIZONTAL_STYLE`` and opens a container whose first child is a
    hidden ``horizontal-marker`` span. The CSS flex rules keyed on that marker
    then arrange the container's children horizontally.

    Usage::

        with st_horizontal():
            st.button("A")
            st.button("B")
    """
    # The @contextmanager decorator is required: a bare generator function
    # cannot be used in a `with` statement (it raises TypeError).
    st.markdown(HORIZONTAL_STYLE, unsafe_allow_html=True)
    with st.container():
        st.markdown(
            '<span class="hide-element horizontal-marker"></span>',
            unsafe_allow_html=True,
        )
        yield
# Controls: routing-group selector and expert-index input, side by side.
col1, col2 = st.columns(2)
with col1:
    router_groups = [f"Routing Group {i}" for i in range(len(examples_tables))]
    router_label = st.selectbox(
        "Expert Routing Group",
        router_groups,
        # Prefer group 4 as the default, but stay in range for smaller models.
        index=min(4, len(router_groups) - 1),
    )
    router_index = router_groups.index(router_label)
with col2:
    # st.number_input's max_value is inclusive, so the largest valid expert
    # index for the selected group is len(table) - 1.
    max_expert = max(len(examples_tables[router_index]) - 1, 0)
    expert_id = st.number_input(
        "Expert Index",
        min_value=0,
        max_value=max_expert,
        value=min(52338, max_expert),
    )

with st_horizontal():
    show_btn = st.button("Show")
    random_btn = st.button("Random")

if show_btn or random_btn:
    group_candidates = candidates[router_index]
    if random_btn and not group_candidates:
        # random.choice would raise IndexError on an empty candidate list.
        st.warning(
            f"No expert in Routing Group {router_index} has more than "
            f"{CANDIDATE_THRESHOLD} examples."
        )
    else:
        if random_btn:
            expert_id = random.choice(group_candidates)
        st.html(render_routing_examples_in_html(router_index, expert_id))