Update app.py
Browse files
app.py
CHANGED
|
@@ -13,6 +13,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, AutoModelForCausa
|
|
| 13 |
from interpret import InterpretationPrompt
|
| 14 |
|
| 15 |
MAX_PROMPT_TOKENS = 60
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
## info
|
|
@@ -102,7 +103,7 @@ def get_hidden_states(raw_original_prompt):
|
|
| 102 |
token_btns = ([gr.Button(token, visible=True) for token in tokens]
|
| 103 |
+ [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
|
| 104 |
progress_dummy_output = ''
|
| 105 |
-
invisible_bubbles = [gr.Textbox('', visible=False) for i in range(
|
| 106 |
global_state.hidden_states = hidden_states
|
| 107 |
return [progress_dummy_output, *token_btns, *invisible_bubbles]
|
| 108 |
|
|
@@ -136,9 +137,9 @@ def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
|
|
| 136 |
generated = interpretation_prompt.generate(global_state.model, {0: interpreted_vectors}, k=3, **generation_kwargs)
|
| 137 |
generation_texts = tokenizer.batch_decode(generated)
|
| 138 |
progress_dummy_output = ''
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
|
| 143 |
|
| 144 |
## main
|
|
@@ -235,7 +236,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
| 235 |
progress_dummy = gr.Markdown('', elem_id='progress_dummy')
|
| 236 |
interpretation_bubbles = [gr.Textbox('', container=False, visible=False,
|
| 237 |
elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
|
| 238 |
-
) for i in range(
|
| 239 |
|
| 240 |
|
| 241 |
# event listeners
|
|
|
|
| 13 |
from interpret import InterpretationPrompt
|
| 14 |
|
| 15 |
MAX_PROMPT_TOKENS = 60
|
| 16 |
+
MAX_NUM_LAYERS = 50
|
| 17 |
|
| 18 |
|
| 19 |
## info
|
|
|
|
| 103 |
token_btns = ([gr.Button(token, visible=True) for token in tokens]
|
| 104 |
+ [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
|
| 105 |
progress_dummy_output = ''
|
| 106 |
+
invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_NUM_LAYERS)]
|
| 107 |
global_state.hidden_states = hidden_states
|
| 108 |
return [progress_dummy_output, *token_btns, *invisible_bubbles]
|
| 109 |
|
|
|
|
| 137 |
generated = interpretation_prompt.generate(global_state.model, {0: interpreted_vectors}, k=3, **generation_kwargs)
|
| 138 |
generation_texts = tokenizer.batch_decode(generated)
|
| 139 |
progress_dummy_output = ''
|
| 140 |
+
bubble_outputs = [gr.Textbox(text.replace('\n', ' '), visible=True, container=False, label=f'Layer {i}') for text in generation_texts]
|
| 141 |
+
bubble_outputs += [gr.Textbox(visible=False) for _ in range(MAX_NUM_LAYERS - len(bubble_outputs))]
|
| 142 |
+
return [progress_dummy_output, *bubble_outputs]
|
| 143 |
|
| 144 |
|
| 145 |
## main
|
|
|
|
| 236 |
progress_dummy = gr.Markdown('', elem_id='progress_dummy')
|
| 237 |
interpretation_bubbles = [gr.Textbox('', container=False, visible=False,
|
| 238 |
elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
|
| 239 |
+
) for i in range(MAX_NUM_LAYERS)]
|
| 240 |
|
| 241 |
|
| 242 |
# event listeners
|