Update app.py
Browse files
app.py
CHANGED
|
@@ -74,8 +74,7 @@ def reset_model(model_name, *extra_components, with_extra_components=True):
|
|
| 74 |
+ [*extra_components])
|
| 75 |
|
| 76 |
|
| 77 |
-
|
| 78 |
-
def get_hidden_states(local_state, raw_original_prompt):
|
| 79 |
model, tokenizer = global_state.model, global_state.tokenizer
|
| 80 |
original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
|
| 81 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
|
@@ -86,8 +85,8 @@ def get_hidden_states(local_state, raw_original_prompt):
|
|
| 86 |
+ [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
|
| 87 |
progress_dummy_output = ''
|
| 88 |
invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_NUM_LAYERS)]
|
| 89 |
-
local_state.hidden_states = hidden_states.cpu().detach()
|
| 90 |
-
return [progress_dummy_output,
|
| 91 |
|
| 92 |
|
| 93 |
@spaces.GPU
|
|
@@ -216,10 +215,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
| 216 |
interpretation_bubbles = [gr.Textbox('', container=False, visible=False,
|
| 217 |
elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
|
| 218 |
) for i in range(MAX_NUM_LAYERS)]
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
local_state = gr.State(global_state.local_state)
|
| 222 |
-
|
| 223 |
# event listeners
|
| 224 |
for i, btn in enumerate(tokens_container):
|
| 225 |
btn.click(partial(run_interpretation, i=i), [interpretation_prompt,
|
|
@@ -228,8 +224,8 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
| 228 |
], [progress_dummy, *interpretation_bubbles])
|
| 229 |
|
| 230 |
original_prompt_btn.click(get_hidden_states,
|
| 231 |
-
[
|
| 232 |
-
[progress_dummy,
|
| 233 |
original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
|
| 234 |
|
| 235 |
extra_components = [interpretation_prompt, original_prompt_raw, original_prompt_btn]
|
|
|
|
| 74 |
+ [*extra_components])
|
| 75 |
|
| 76 |
|
| 77 |
+
def get_hidden_states(raw_original_prompt):
|
|
|
|
| 78 |
model, tokenizer = global_state.model, global_state.tokenizer
|
| 79 |
original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
|
| 80 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
|
|
|
| 85 |
+ [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
|
| 86 |
progress_dummy_output = ''
|
| 87 |
invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_NUM_LAYERS)]
|
| 88 |
+
global_state.local_state.hidden_states = hidden_states.cpu().detach()
|
| 89 |
+
return [progress_dummy_output, *token_btns, *invisible_bubbles]
|
| 90 |
|
| 91 |
|
| 92 |
@spaces.GPU
|
|
|
|
| 215 |
interpretation_bubbles = [gr.Textbox('', container=False, visible=False,
|
| 216 |
elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
|
| 217 |
) for i in range(MAX_NUM_LAYERS)]
|
| 218 |
+
|
|
|
|
|
|
|
|
|
|
| 219 |
# event listeners
|
| 220 |
for i, btn in enumerate(tokens_container):
|
| 221 |
btn.click(partial(run_interpretation, i=i), [interpretation_prompt,
|
|
|
|
| 224 |
], [progress_dummy, *interpretation_bubbles])
|
| 225 |
|
| 226 |
original_prompt_btn.click(get_hidden_states,
|
| 227 |
+
[original_prompt_raw],
|
| 228 |
+
[progress_dummy, *tokens_container, *interpretation_bubbles])
|
| 229 |
original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
|
| 230 |
|
| 231 |
extra_components = [interpretation_prompt, original_prompt_raw, original_prompt_btn]
|