Spaces:
Runtime error
Runtime error
| import torch | |
| import transformers | |
| from transformers import AutoModelForCausalLM | |
| import pandas as pd | |
| import gradio as gr | |
# --- Model setup ---------------------------------------------------------------
# Stock GPT-2 tokenizer plus the Backpack checkpoint, which relies on custom model
# code (hence trust_remote_code=True). The model is put in evaluation mode since
# the app only runs inference.
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
m = AutoModelForCausalLM.from_pretrained("lora-x/backpack-gpt2", trust_remote_code=True)
m.eval()

# Backpack components reused by the callbacks below (shape notes are the original
# author's annotations).
lm_head = m.get_lm_head()  # (V, d)
word_embeddings = m.backpack.get_word_embeddings()  # (V, d)
sense_network = m.backpack.get_sense_network()  # (V, nv, d)
num_senses = m.backpack.get_num_senses()
# Default sense labels: 0, 1, ..., num_senses - 1.
sense_names = list(range(num_senses))
def visualize_word(word, count=10, remove_space=False):
    """Single token sense lookup.

    Shows, for every sense vector of the first token of `word`, the vocabulary
    words whose logits (sense vector projected through the LM head) are highest
    and lowest.

    Args:
        word: the text to look up.
        count: how many top / bottom words to list per sense.
        remove_space: when False (default), a space is prepended to `word`,
            since GPT-2 tokens normally include their leading space.

    Returns:
        (pos_df, neg_df, tokens):
        - pos_df / neg_df are DataFrames with one column per sense ("sense 0",
          "sense 1", ...) and `count` rows of "word (logit)" strings.
        - tokens is a comma-separated string showing how `word` was tokenized.
        If `word` tokenizes to nothing, pos_df and neg_df are None.
    """
    # Gradio sliders may deliver a float; range() below needs an int.
    count = int(count)
    if not remove_space:
        word = ' ' + word
    print(f"Looking up word '{word}'...")
    token_ids = tokenizer(word)['input_ids']
    # Display the tokenization for the user.
    tokens = ", ".join(tokenizer.decode(token_id) for token_id in token_ids)
    print(f"Tokenized as: {tokens}")
    if not token_ids:
        # e.g. empty input with "Remove space before word" checked: nothing to look up.
        return None, None, tokens

    def create_dataframe(word_lists, sense_names, count):
        # One column per sense; row j holds the j-th ranked "word (logit)" entry.
        data = {
            name: ["{} ({:.2f})".format(w, logit) for w, logit in words]
            for name, words in zip(sense_names, word_lists)
        }
        return pd.DataFrame(data, index=range(count))

    sense_names = []  # column headers
    pos_word_lists = []  # per sense: [(word, logit), ...] best first
    neg_word_lists = []  # per sense: [(word, logit), ...] worst first
    with torch.no_grad():  # inference only; don't build autograd graphs
        # Sense vectors for the first token only. sense_network expects a batch
        # dimension, so the input is shaped (bs=1, s=1).
        first_token = torch.tensor([[token_ids[0]]], dtype=torch.long)
        sense_input_embeds = word_embeddings(first_token)  # (bs=1, s=1, d)
        senses = sense_network(sense_input_embeds)  # (bs=1, nv, s=1, d) per original comment
        senses = torch.squeeze(senses)  # drops the singleton bs and s dims -> (nv, d)
        for i in range(senses.shape[0]):
            logits = lm_head(senses[i, :])
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            sense_names.append('sense {}'.format(i))
            pos_word_lists.append([
                (tokenizer.decode(sorted_indices[j]), sorted_logits[j].item())
                for j in range(count)
            ])
            neg_word_lists.append([
                (tokenizer.decode(sorted_indices[-j - 1]), sorted_logits[-j - 1].item())
                for j in range(count)
            ])

    pos_df = create_dataframe(pos_word_lists, sense_names, count)
    neg_df = create_dataframe(neg_word_lists, sense_names, count)
    return pos_df, neg_df, tokens
def predict_next_word(sentence, top_k=5, contextualization_weights=None):
    """Predict the most likely next words for `sentence` with the Backpack model.

    If `contextualization_weights` is None, the model runs normally and the
    contextualization weights it produces are returned, so they can be kept in
    a gr.State and edited through the sense sliders. Otherwise the given
    (possibly edited) weights are used via `m.run_with_custom_contextualization`.

    Returns:
        - tokens: [[tok_0, ..., tok_{n-1}]], the tokenization of the input sentence
          as a single-row table. These are also the options to choose from for
          get_token_contextual_weights.
        - top_k_words_df: a DataFrame of the top-k predicted words. It has rows
          'word' and 'probability' and columns 1..top_k.
        - length: the number of input tokens. It is stored as a gr.State variable
          so other callbacks can find the contextualization weights of the *last*
          token.
        - contextualization_weights: the weights used for this prediction, kept in
          a gr.State variable.

        For empty or whitespace-only input, all four are None.

    Raises:
        gr.Error: if the sentence is longer than the 512-token input buffer.
    """
    # Strip first so whitespace-only input counts as empty. Previously it reached
    # sentence[0] on an empty string and raised IndexError.
    sentence = sentence.strip() if sentence else ""
    if not sentence:
        return None, None, None, None
    # After strip() there is never a leading space. Always prepend one, since GPT-2
    # tokens carry their leading space and this gives better tokenization.
    sentence = ' ' + sentence
    print(f"Sentence: '{sentence}'")

    # Build the fixed-size input, keeping track of the real length.
    token_ids = tokenizer(sentence)['input_ids']
    tokens = [[tokenizer.decode(token_id) for token_id in token_ids]]  # list of one list: shown as a dataframe row
    length = len(token_ids)
    max_length = 512  # size of the zero-padded model input buffer below
    if length > max_length:
        raise gr.Error(f"Input is too long: {length} tokens (maximum is {max_length}).")
    inp = torch.zeros((1, max_length), dtype=torch.long)
    inp[0, :length] = torch.tensor(token_ids, dtype=torch.long)

    # no_grad: otherwise the contextualization tensor kept in gr.State would hold
    # on to the whole forward pass's autograd graph, and the sliders' later
    # in-place edits would be recorded into it.
    with torch.no_grad():
        if contextualization_weights is None:
            print("contextualization_weights IS None, freshly computing contextualization_weights")
            output = m(inp)
            # Returned as a gr.State var for use by get_token_contextual_weights.
            contextualization_weights = output.contextualization
        else:
            print("contextualization_weights is NOT None, using passed in contextualization_weights")
            output = m.run_with_custom_contextualization(inp, contextualization_weights)
        # Next-word distribution comes from the last real (non-padding) position.
        logits = output.logits[0, length - 1, :]
        probs = logits.softmax(dim=-1)

    probs, indices = torch.sort(probs, descending=True)
    top_k = int(top_k)
    top_k_words = [(tokenizer.decode(indices[i]), round(probs[i].item(), 3)) for i in range(top_k)]
    top_k_words_df = pd.DataFrame(top_k_words, columns=['word', 'probability'], index=range(1, top_k + 1))
    top_k_words_df = top_k_words_df.T
    print(top_k_words_df)
    return tokens, top_k_words_df, length, contextualization_weights
def get_token_contextual_weights(contextualization_weights, length, token, token_index, pos_count=5, neg_count=3):
    """Return the senses and contextualization weights of the selected token.

    Args:
        contextualization_weights: a gr.State variable that stores the
            contextualization weights for the input sentence. It is indexed as
            [0, sense, query_position, key_position].
        length: the length of the input sentence, used to get the
            contextualization weights for the last token.
        token: the selected token (decoded text).
        token_index: the index of the selected token in the input sentence.
        pos_count: how many top positive words to display for each sense.
        neg_count: how many top negative words to display for each sense.

    Returns:
        A flat tuple of 2 + 3 * num_senses values, in this order:
        - token and token_index;
        - one single-column DataFrame per sense with its top positive words;
        - one single-column DataFrame per sense with its top negative words;
        - one slider value per sense, i.e. the weight (rounded to 3 decimals)
          the last token gives to that sense of the selected token.

        If there is nothing to show (no prediction has been run yet, or no token
        is selected), every element is None.
    """
    print(">>>>>in get_token_contextual_weights")
    print(f"Selected {token_index}th token: {token}")
    if contextualization_weights is None or not token or token_index is None or not length:
        # Reached, e.g., via "Reset weights" before any token was clicked. Clear
        # the outputs instead of crashing on the missing state.
        return (None,) * (2 + 3 * num_senses)

    # Only the weights seen by the last word matter, since that's what contributes
    # to the output.
    token_weights = contextualization_weights[0, :, length - 1, token_index]
    slider_values = [round(x, 3) for x in token_weights.tolist()]

    # Recover the token id from its decoded text. Keep only the first id, in case
    # re-tokenizing the text yields more than one piece; otherwise the per-sense
    # shapes below would be wrong.
    token_ids = tokenizer(token)['input_ids'][:1]

    pos_dfs, neg_dfs = [], []
    with torch.no_grad():  # inference only
        # sense_network expects a batch dim: input is (bs=1, s=1, d).
        sense_input_embeds = word_embeddings(torch.tensor([token_ids], dtype=torch.long))
        senses = sense_network(sense_input_embeds)  # (bs=1, nv, s=1, d) per original comment
        senses = torch.squeeze(senses)  # -> (nv, d)
        for i in range(num_senses):
            logits = lm_head(senses[i, :])  # scores over the vocabulary
            _, sorted_indices = torch.sort(logits, descending=True)
            pos_words = [tokenizer.decode(sorted_indices[j]) for j in range(pos_count)]
            pos_dfs.append(pd.DataFrame(pos_words, columns=["Sense {}".format(i)]))
            neg_words = [tokenizer.decode(sorted_indices[-j - 1]) for j in range(neg_count)]
            neg_dfs.append(pd.DataFrame(neg_words, columns=["Top Negative"]))

    return (token, token_index, *pos_dfs, *neg_dfs, *slider_values)
def new_token_contextual_weights(contextualization_weights, length, evt: gr.SelectData, pos_count=5, neg_count=3):
    """Callback for clicking a cell in the tokens table.

    Translates the Gradio select event into the `token` / `token_index`
    arguments expected by get_token_contextual_weights and forwards to it.
    Selecting an empty cell clears all outputs instead: selected token, token
    index, 16 positive-word tables, 16 negative-word tables and 16 sliders,
    which is 50 Nones in total.
    """
    print(">>>>>in new_token_contextual_weights")
    # The tokens are laid out along a single row, so the column position of the
    # selected cell is the token's position in the sentence.
    token_index, token = evt.index[1], evt.value
    if token:
        return get_token_contextual_weights(contextualization_weights, length, token, token_index, pos_count, neg_count)
    return (None,) * 50
def _change_sense_weight(sense, contextualization_weights, length, token_index, new_weight):
    """Set how much the last token weights sense `sense` of token `token_index`.

    Writes `new_weight` into contextualization_weights[0, sense, length - 1,
    token_index] in place and returns the same object, so it can be stored
    back into the gr.State.

    Nothing is changed, and the input is returned unchanged, when there is
    nothing to edit yet. That covers three cases: no prediction has been made
    (weights are None or length is 0), no token is selected (token_index is
    None), or the slider was cleared to None. Each of these previously raised.
    """
    if contextualization_weights is None or token_index is None or new_weight is None or not length:
        return contextualization_weights
    # Manual edit of a stored tensor; keep it out of any autograd history.
    with torch.no_grad():
        contextualization_weights[0, sense, length - 1, token_index] = new_weight
    return contextualization_weights


# One callback per sense slider. Gradio passes only component values, so each
# slider's sense index is fixed by the function it is wired to.
def change_sense0_weight(contextualization_weights, length, token_index, new_weight):
    return _change_sense_weight(0, contextualization_weights, length, token_index, new_weight)


def change_sense1_weight(contextualization_weights, length, token_index, new_weight):
    return _change_sense_weight(1, contextualization_weights, length, token_index, new_weight)


def change_sense2_weight(contextualization_weights, length, token_index, new_weight):
    return _change_sense_weight(2, contextualization_weights, length, token_index, new_weight)


def change_sense3_weight(contextualization_weights, length, token_index, new_weight):
    return _change_sense_weight(3, contextualization_weights, length, token_index, new_weight)


def change_sense4_weight(contextualization_weights, length, token_index, new_weight):
    return _change_sense_weight(4, contextualization_weights, length, token_index, new_weight)


def change_sense5_weight(contextualization_weights, length, token_index, new_weight):
    return _change_sense_weight(5, contextualization_weights, length, token_index, new_weight)


def change_sense6_weight(contextualization_weights, length, token_index, new_weight):
    return _change_sense_weight(6, contextualization_weights, length, token_index, new_weight)


def change_sense7_weight(contextualization_weights, length, token_index, new_weight):
    return _change_sense_weight(7, contextualization_weights, length, token_index, new_weight)


def change_sense8_weight(contextualization_weights, length, token_index, new_weight):
    return _change_sense_weight(8, contextualization_weights, length, token_index, new_weight)


def change_sense9_weight(contextualization_weights, length, token_index, new_weight):
    return _change_sense_weight(9, contextualization_weights, length, token_index, new_weight)


def change_sense10_weight(contextualization_weights, length, token_index, new_weight):
    return _change_sense_weight(10, contextualization_weights, length, token_index, new_weight)


def change_sense11_weight(contextualization_weights, length, token_index, new_weight):
    return _change_sense_weight(11, contextualization_weights, length, token_index, new_weight)


def change_sense12_weight(contextualization_weights, length, token_index, new_weight):
    return _change_sense_weight(12, contextualization_weights, length, token_index, new_weight)


def change_sense13_weight(contextualization_weights, length, token_index, new_weight):
    return _change_sense_weight(13, contextualization_weights, length, token_index, new_weight)


def change_sense14_weight(contextualization_weights, length, token_index, new_weight):
    return _change_sense_weight(14, contextualization_weights, length, token_index, new_weight)


def change_sense15_weight(contextualization_weights, length, token_index, new_weight):
    return _change_sense_weight(15, contextualization_weights, length, token_index, new_weight)
def clear_states(contextualization_weights, token_index, length):
    """Reset the per-session gr.State values when the input sentence changes.

    The incoming values are ignored. The result is always
    (contextualization_weights=None, token_index=None, length=0).
    """
    fresh_weights, fresh_index, fresh_length = None, None, 0
    return fresh_weights, fresh_index, fresh_length
def reset_weights(contextualization_weights):
    """Discard any slider edits by dropping the stored contextualization weights.

    Always returns None. With None stored, the next predict_next_word call
    computes fresh weights from the model.
    """
    print("Resetting weights...")
    return None
# --- UI ------------------------------------------------------------------------
# Slider callbacks in sense order: entry i edits sense i of the selected token.
_sense_weight_callbacks = [
    change_sense0_weight, change_sense1_weight, change_sense2_weight, change_sense3_weight,
    change_sense4_weight, change_sense5_weight, change_sense6_weight, change_sense7_weight,
    change_sense8_weight, change_sense9_weight, change_sense10_weight, change_sense11_weight,
    change_sense12_weight, change_sense13_weight, change_sense14_weight, change_sense15_weight,
]
_num_ui_senses = len(_sense_weight_callbacks)
# Sense widgets are laid out in groups of this many columns.
_SENSES_PER_ROW = 8

# Vertical sliders. elem_ids are "sense{i}slider". The old code used
# "sense1slider0" for sense 10 in both the CSS and the elem_id.
_slider_css = (
    ", ".join(f"#sense{i}slider" for i in range(_num_ui_senses))
    + " { height: 200px; width: 200px; transform: rotate(270deg); }"
)

with gr.Blocks(theme=gr.themes.Base(), css=_slider_css) as demo:
    gr.Markdown("## Backpack Sense Visualization")

    with gr.Tab("Language Modeling"):
        # Session data shared between callbacks.
        contextualization_weights = gr.State(None)
        token_index = gr.State(None)
        length = gr.State(0)
        top_k = gr.State(10)

        with gr.Row():
            with gr.Column(scale=8):
                input_sentence = gr.Textbox(label="Input Sentence", placeholder='Enter a sentence and click "Predict next word". Then, you can go to the Tokens section, click on a token, and see its contextualization weights.')
            with gr.Column(scale=1):
                predict = gr.Button(value="Predict next word", variant="primary")
                reset_weights_button = gr.Button("Reset weights")
        gr.Markdown("#### Top-k predicted next word")
        top_k_words = gr.Dataframe(interactive=False)
        gr.Markdown("### **Token Breakdown:** click on a token below to see its senses and contextualization weights")
        tokens = gr.DataFrame()
        with gr.Row():
            with gr.Column(scale=1):
                selected_token = gr.Textbox(label="Current Selected Token", interactive=False)
            with gr.Column(scale=8):
                gr.Markdown(
                    "####\n"
                    "Once a token is chosen, you can **use the sliders below to change the weight of any sense "
                    "or multiple senses** for that token, and then click \"Predict next word\" to see updated "
                    "next-word predictions. Erase all changes with \"Reset weights\"."
                )

        # Sense sliders and top-word tables: for each group of senses, one row of
        # sliders, one row of top positive words, one row of top negative words.
        sense_sliders, sense_word_tables, sense_negword_tables = [], [], []
        for row_start in range(0, _num_ui_senses, _SENSES_PER_ROW):
            row_senses = range(row_start, min(row_start + _SENSES_PER_ROW, _num_ui_senses))
            with gr.Row():
                for i in row_senses:
                    with gr.Column(scale=0, min_width=120):
                        sense_sliders.append(gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label=f"Sense {i}", elem_id=f"sense{i}slider", interactive=True))
            with gr.Row():
                for i in row_senses:
                    with gr.Column(scale=0, min_width=120):
                        sense_word_tables.append(gr.DataFrame(headers=[f"Sense {i}"]))
            with gr.Row():
                for i in row_senses:
                    with gr.Column(scale=0, min_width=120):
                        sense_negword_tables.append(gr.DataFrame(headers=["Top Negative"]))
        gr.Markdown("""Note: **"Top Negative"** shows words that have the most negative dot products with the sense vector, which can exhibit more coherent meaning than those with the most positive dot products.
To see more representative words of each sense, scroll to the top and use the **"Individual Word Sense Look Up"** tab.""")

        # Each slider writes its value into the stored contextualization weights.
        for slider, callback in zip(sense_sliders, _sense_weight_callbacks):
            slider.change(fn=callback,
                          inputs=[contextualization_weights, length, token_index, slider],
                          outputs=[contextualization_weights])

        # Outputs of (new_)get_token_contextual_weights, in their return order.
        token_detail_outputs = [selected_token, token_index,
                                *sense_word_tables, *sense_negword_tables, *sense_sliders]

        predict.click(
            fn=predict_next_word,
            inputs=[input_sentence, top_k, contextualization_weights],
            outputs=[tokens, top_k_words, length, contextualization_weights],
        )
        tokens.select(fn=new_token_contextual_weights,
                      inputs=[contextualization_weights, length],
                      outputs=token_detail_outputs)
        # Reset: drop edited weights, re-predict, then refresh the selected token's view.
        reset_weights_button.click(
            fn=reset_weights,
            inputs=[contextualization_weights],
            outputs=[contextualization_weights],
        ).success(
            fn=predict_next_word,
            inputs=[input_sentence, top_k, contextualization_weights],
            outputs=[tokens, top_k_words, length, contextualization_weights],
        ).success(
            fn=get_token_contextual_weights,
            inputs=[contextualization_weights, length, selected_token, token_index],
            outputs=token_detail_outputs,
        )
        # A new sentence invalidates all stored per-sentence state.
        input_sentence.change(
            fn=clear_states,
            inputs=[contextualization_weights, token_index, length],
            outputs=[contextualization_weights, token_index, length],
        )

    with gr.Tab("Individual Word Sense Look Up"):
        gr.Markdown(
            "> Note on tokenization: Backpack uses the GPT-2 tokenizer, which includes the space before a word as part "
            "of the token, so by default, a space character `' '` is added to the beginning of the word "
            "you look up. You can disable this by checking `Remove space before word`, but know this might "
            "cause strange behaviors like breaking `afraid` into `af` and `raid`, or `slight` into `s` and `light`."
        )
        with gr.Row():
            word = gr.Textbox(label="Word", placeholder="e.g. science")
            token_breakdown = gr.Textbox(label="Token Breakdown (senses are for the first token only)")
            # `value=` is the Checkbox initial-state parameter; `default=` is not accepted.
            remove_space = gr.Checkbox(label="Remove space before word", value=False)
            count = gr.Slider(minimum=1, maximum=20, value=10, label="Top K", step=1)
        look_up_button = gr.Button("Look up")
        pos_outputs = gr.Dataframe(label="Highest Scoring Senses")
        neg_outputs = gr.Dataframe(label="Lowest Scoring Senses")
        gr.Examples(
            examples=["science", "afraid", "book", "slight"],
            inputs=[word],
            outputs=[pos_outputs, neg_outputs, token_breakdown],
            fn=visualize_word,
            cache_examples=True,
        )
        look_up_button.click(
            fn=visualize_word,
            inputs=[word, count, remove_space],
            outputs=[pos_outputs, neg_outputs, token_breakdown],
        )

demo.launch()
| # Code for generating slider functions & event listners | |
| # for i in range(16): | |
| # print( | |
| # f"""def change_sense{i}_weight(contextualization_weights, length, token_index, new_weight): | |
| # print(f"Changing weight for the {i}th sense of the {{token_index}}th token.") | |
| # print("new_weight to be assigned = ", new_weight) | |
| # contextualization_weights[0, {i}, length-1, token_index] = new_weight | |
| # print("contextualization_weights: ", contextualization_weights[0, :, length-1, token_index]) | |
| # return contextualization_weights""" | |
| # ) | |
| # for i in range(16): | |
| # print( | |
| # f""" sense{i}slider.change(fn=change_sense{i}_weight, | |
| # inputs=[contextualization_weights, length, token_index, sense{i}slider], | |
| # outputs=[contextualization_weights])""" | |
| # ) |