Update app.py
Browse files
app.py
CHANGED
|
@@ -29,11 +29,12 @@ class GlobalState:
|
|
| 29 |
local_state : LocalState = LocalState()
|
| 30 |
wait_with_hidden_state : bool = False
|
| 31 |
interpretation_prompt_template : str = '{prompt}'
|
| 32 |
-
original_prompt_template : str = 'User: [X]\n\
|
| 33 |
layers_format : str = 'model.layers.{k}'
|
| 34 |
|
| 35 |
|
| 36 |
suggested_interpretation_prompts = [
|
|
|
|
| 37 |
"Sure, here's a bullet list of the key words in your message:",
|
| 38 |
"Sure, I'll summarize your message:",
|
| 39 |
"Sure, here are the words in your message:",
|
|
@@ -139,7 +140,7 @@ global_state = GlobalState()
|
|
| 139 |
|
| 140 |
model_name = 'LLAMA2-7B'
|
| 141 |
reset_model(model_name, with_extra_components=False)
|
| 142 |
-
|
| 143 |
tokens_container = []
|
| 144 |
|
| 145 |
for i in range(MAX_PROMPT_TOKENS):
|
|
@@ -185,17 +186,17 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
| 185 |
dataset = dataset.filter(info['filter'])
|
| 186 |
dataset = dataset.shuffle(buffer_size=2000).take(num_examples)
|
| 187 |
dataset = [[row[info['text_col']]] for row in dataset]
|
| 188 |
-
gr.Examples(dataset, [
|
| 189 |
|
| 190 |
with gr.Group():
|
| 191 |
-
|
| 192 |
original_prompt_btn = gr.Button('Output Token List', variant='primary')
|
| 193 |
|
| 194 |
gr.Markdown('## Choose Your Interpretation Prompt')
|
| 195 |
with gr.Group('Interpretation'):
|
| 196 |
-
|
| 197 |
interpretation_prompt_examples = gr.Examples([[p] for p in suggested_interpretation_prompts],
|
| 198 |
-
[
|
| 199 |
|
| 200 |
with gr.Accordion(open=False, label='Generation Settings'):
|
| 201 |
with gr.Row():
|
|
@@ -225,17 +226,17 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
| 225 |
|
| 226 |
# event listeners
|
| 227 |
for i, btn in enumerate(tokens_container):
|
| 228 |
-
btn.click(partial(run_interpretation, i=i), [
|
| 229 |
num_tokens, do_sample, temperature,
|
| 230 |
top_k, top_p, repetition_penalty, length_penalty
|
| 231 |
], [progress_dummy, *interpretation_bubbles])
|
| 232 |
|
| 233 |
original_prompt_btn.click(get_hidden_states,
|
| 234 |
-
[
|
| 235 |
[progress_dummy, *tokens_container, *interpretation_bubbles])
|
| 236 |
-
|
| 237 |
|
| 238 |
-
extra_components = [
|
| 239 |
model_chooser.change(reset_model, [model_chooser, *extra_components],
|
| 240 |
[welcome_model, *interpretation_bubbles, *tokens_container, *extra_components])
|
| 241 |
|
|
|
|
| 29 |
local_state : LocalState = LocalState()
|
| 30 |
wait_with_hidden_state : bool = False
|
| 31 |
interpretation_prompt_template : str = '{prompt}'
|
| 32 |
+
original_prompt_template : str = 'User: [X]\n\nAssistant: {prompt}'
|
| 33 |
layers_format : str = 'model.layers.{k}'
|
| 34 |
|
| 35 |
|
| 36 |
suggested_interpretation_prompts = [
|
| 37 |
+
"The meaning of [X] is",
|
| 38 |
"Sure, here's a bullet list of the key words in your message:",
|
| 39 |
"Sure, I'll summarize your message:",
|
| 40 |
"Sure, here are the words in your message:",
|
|
|
|
| 140 |
|
| 141 |
model_name = 'LLAMA2-7B'
|
| 142 |
reset_model(model_name, with_extra_components=False)
|
| 143 |
+
raw_original_prompt = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
|
| 144 |
tokens_container = []
|
| 145 |
|
| 146 |
for i in range(MAX_PROMPT_TOKENS):
|
|
|
|
| 186 |
dataset = dataset.filter(info['filter'])
|
| 187 |
dataset = dataset.shuffle(buffer_size=2000).take(num_examples)
|
| 188 |
dataset = [[row[info['text_col']]] for row in dataset]
|
| 189 |
+
gr.Examples(dataset, [raw_original_prompt], cache_examples=False)
|
| 190 |
|
| 191 |
with gr.Group():
|
| 192 |
+
raw_original_prompt.render()
|
| 193 |
original_prompt_btn = gr.Button('Output Token List', variant='primary')
|
| 194 |
|
| 195 |
gr.Markdown('## Choose Your Interpretation Prompt')
|
| 196 |
with gr.Group('Interpretation'):
|
| 197 |
+
raw_interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
|
| 198 |
interpretation_prompt_examples = gr.Examples([[p] for p in suggested_interpretation_prompts],
|
| 199 |
+
[raw_interpretation_prompt], cache_examples=False)
|
| 200 |
|
| 201 |
with gr.Accordion(open=False, label='Generation Settings'):
|
| 202 |
with gr.Row():
|
|
|
|
| 226 |
|
| 227 |
# event listeners
|
| 228 |
for i, btn in enumerate(tokens_container):
|
| 229 |
+
btn.click(partial(run_interpretation, i=i), [raw_original_prompt, raw_interpretation_prompt,
|
| 230 |
num_tokens, do_sample, temperature,
|
| 231 |
top_k, top_p, repetition_penalty, length_penalty
|
| 232 |
], [progress_dummy, *interpretation_bubbles])
|
| 233 |
|
| 234 |
original_prompt_btn.click(get_hidden_states,
|
| 235 |
+
[raw_original_prompt],
|
| 236 |
[progress_dummy, *tokens_container, *interpretation_bubbles])
|
| 237 |
+
raw_original_prompt.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
|
| 238 |
|
| 239 |
+
extra_components = [raw_interpretation_prompt, raw_original_prompt, original_prompt_btn]
|
| 240 |
model_chooser.change(reset_model, [model_chooser, *extra_components],
|
| 241 |
[welcome_model, *interpretation_bubbles, *tokens_container, *extra_components])
|
| 242 |
|