Update app.py
Browse files
app.py
CHANGED
|
@@ -11,6 +11,7 @@ from interpret import InterpretationPrompt
|
|
| 11 |
|
| 12 |
MAX_PROMPT_TOKENS = 60
|
| 13 |
|
|
|
|
| 14 |
## info
|
| 15 |
dataset_info = [
|
| 16 |
{'name': 'Commonsense', 'hf_repo': 'tau/commonsense_qa', 'text_col': 'question'},
|
|
@@ -56,7 +57,7 @@ suggested_interpretation_prompts = ["Sure, I'll summarize your message:", "Sure,
|
|
| 56 |
def initialize_gpu():
|
| 57 |
pass
|
| 58 |
|
| 59 |
-
def get_hidden_states(raw_original_prompt
|
| 60 |
original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
|
| 61 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
| 62 |
tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
|
|
@@ -65,17 +66,13 @@ def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
|
|
| 65 |
token_btns = ([gr.Button(token, visible=True) for token in tokens]
|
| 66 |
+ [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
|
| 67 |
progress_dummy_output = ''
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
@spaces.GPU
|
| 72 |
-
def generate_interpretation_gpu(interpret_prompt, *args, **kwargs):
|
| 73 |
-
return interpret_prompt.generate(*args, **kwargs)
|
| 74 |
|
| 75 |
|
| 76 |
@spaces.GPU
|
| 77 |
def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
|
| 78 |
-
temperature, top_k, top_p, repetition_penalty, length_penalty,
|
| 79 |
num_beams=1):
|
| 80 |
|
| 81 |
interpreted_vectors = global_state[:, i]
|
|
@@ -98,8 +95,8 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
|
|
| 98 |
interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
|
| 99 |
|
| 100 |
# generate the interpretations
|
| 101 |
-
generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt, *args, **kwargs: interpretation_prompt.generate(*args, **kwargs)
|
| 102 |
-
generated = generate(
|
| 103 |
generation_texts = tokenizer.batch_decode(generated)
|
| 104 |
progress_dummy_output = ''
|
| 105 |
return ([progress_dummy_output] +
|
|
@@ -187,14 +184,14 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
| 187 |
|
| 188 |
with gr.Group():
|
| 189 |
original_prompt_raw.render()
|
| 190 |
-
original_prompt_btn = gr.Button('
|
| 191 |
|
| 192 |
tokens_container = []
|
| 193 |
with gr.Row():
|
|
|
|
| 194 |
for i in range(MAX_PROMPT_TOKENS):
|
| 195 |
btn = gr.Button('', visible=False, elem_classes=['token_btn'])
|
| 196 |
tokens_container.append(btn)
|
| 197 |
-
use_gpu = False # gr.Checkbox(value=False, label='Use GPU')
|
| 198 |
progress_dummy = gr.Markdown('', elem_id='progress_dummy')
|
| 199 |
|
| 200 |
|
|
@@ -226,12 +223,12 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
| 226 |
|
| 227 |
# event listeners
|
| 228 |
for i, btn in enumerate(tokens_container):
|
| 229 |
-
btn.click(partial(run_interpretation, i=i
|
| 230 |
num_tokens, do_sample, temperature,
|
| 231 |
top_k, top_p, repetition_penalty, length_penalty,
|
| 232 |
], [progress_dummy, *interpretation_bubbles])
|
| 233 |
|
| 234 |
original_prompt_btn.click(get_hidden_states,
|
| 235 |
[original_prompt_raw],
|
| 236 |
-
[progress_dummy, global_state, *tokens_container])
|
| 237 |
demo.launch()
|
|
|
|
| 11 |
|
| 12 |
MAX_PROMPT_TOKENS = 60
|
| 13 |
|
| 14 |
+
|
| 15 |
## info
|
| 16 |
dataset_info = [
|
| 17 |
{'name': 'Commonsense', 'hf_repo': 'tau/commonsense_qa', 'text_col': 'question'},
|
|
|
|
| 57 |
def initialize_gpu():
|
| 58 |
pass
|
| 59 |
|
| 60 |
+
def get_hidden_states(raw_original_prompt):
|
| 61 |
original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
|
| 62 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
| 63 |
tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
|
|
|
|
| 66 |
token_btns = ([gr.Button(token, visible=True) for token in tokens]
|
| 67 |
+ [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
|
| 68 |
progress_dummy_output = ''
|
| 69 |
+
invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_PROMPT_TOKENS)]
|
| 70 |
+
return [progress_dummy_output, hidden_states, *token_btns, *invisible_bubbles]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
|
| 73 |
@spaces.GPU
|
| 74 |
def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
|
| 75 |
+
temperature, top_k, top_p, repetition_penalty, length_penalty, i,
|
| 76 |
num_beams=1):
|
| 77 |
|
| 78 |
interpreted_vectors = global_state[:, i]
|
|
|
|
| 95 |
interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
|
| 96 |
|
| 97 |
# generate the interpretations
|
| 98 |
+
# generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt, *args, **kwargs: interpretation_prompt.generate(*args, **kwargs)
|
| 99 |
+
generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
|
| 100 |
generation_texts = tokenizer.batch_decode(generated)
|
| 101 |
progress_dummy_output = ''
|
| 102 |
return ([progress_dummy_output] +
|
|
|
|
| 184 |
|
| 185 |
with gr.Group():
|
| 186 |
original_prompt_raw.render()
|
| 187 |
+
original_prompt_btn = gr.Button('Output Token List', variant='primary')
|
| 188 |
|
| 189 |
tokens_container = []
|
| 190 |
with gr.Row():
|
| 191 |
+
gr.Markdown('### Here go the tokens of the prompt (click on the one to explore)')
|
| 192 |
for i in range(MAX_PROMPT_TOKENS):
|
| 193 |
btn = gr.Button('', visible=False, elem_classes=['token_btn'])
|
| 194 |
tokens_container.append(btn)
|
|
|
|
| 195 |
progress_dummy = gr.Markdown('', elem_id='progress_dummy')
|
| 196 |
|
| 197 |
|
|
|
|
| 223 |
|
| 224 |
# event listeners
|
| 225 |
for i, btn in enumerate(tokens_container):
|
| 226 |
+
btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt,
|
| 227 |
num_tokens, do_sample, temperature,
|
| 228 |
top_k, top_p, repetition_penalty, length_penalty,
|
| 229 |
], [progress_dummy, *interpretation_bubbles])
|
| 230 |
|
| 231 |
original_prompt_btn.click(get_hidden_states,
|
| 232 |
[original_prompt_raw],
|
| 233 |
+
[progress_dummy, global_state, *tokens_container, *interpretation_bubbles])
|
| 234 |
demo.launch()
|