Update app.py
Browse files
app.py
CHANGED
|
@@ -7,6 +7,7 @@ from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
|
|
| 7 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 8 |
from interpret import InterpretationPrompt
|
| 9 |
|
|
|
|
| 10 |
|
| 11 |
## info
|
| 12 |
model_info = {
|
|
@@ -52,7 +53,7 @@ def get_hidden_states(raw_original_prompt):
|
|
| 52 |
# with gr.Row() as tokens_container:
|
| 53 |
# for token in tokens:
|
| 54 |
# gr.Button(token)
|
| 55 |
-
return
|
| 56 |
|
| 57 |
|
| 58 |
def run_model(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
|
|
@@ -105,6 +106,7 @@ AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCau
|
|
| 105 |
model = AutoModelClass.from_pretrained(model_name, **model_args)
|
| 106 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=os.environ['hf_token'])
|
| 107 |
|
|
|
|
| 108 |
with gr.Blocks(theme=gr.themes.Default()) as demo:
|
| 109 |
with gr.Row():
|
| 110 |
with gr.Column(scale=5):
|
|
@@ -124,8 +126,8 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|
| 124 |
gr.Markdown('<span style="font-size:180px;">🤔</span>')
|
| 125 |
|
| 126 |
with gr.Group():
|
| 127 |
-
|
| 128 |
-
|
| 129 |
|
| 130 |
with gr.Accordion(open=False, label='Settings'):
|
| 131 |
with gr.Row():
|
|
@@ -144,11 +146,12 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|
| 144 |
interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
|
| 145 |
|
| 146 |
with gr.Group('Output'):
|
| 147 |
-
|
|
|
|
| 148 |
with gr.Column() as interpretations_container:
|
| 149 |
pass
|
| 150 |
|
| 151 |
-
|
| 152 |
# btn.click(run_model,
|
| 153 |
# [text, interpretation_prompt, num_tokens, do_sample, temperature,
|
| 154 |
# top_k, top_p, repetition_penalty, length_penalty],
|
|
|
|
| 7 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 8 |
from interpret import InterpretationPrompt
|
| 9 |
|
| 10 |
+
MAX_PROMPT_TOKENS = 30
|
| 11 |
|
| 12 |
## info
|
| 13 |
model_info = {
|
|
|
|
| 53 |
# with gr.Row() as tokens_container:
|
| 54 |
# for token in tokens:
|
| 55 |
# gr.Button(token)
|
| 56 |
+
return [gr.Button(tokens[i], visible=True) if i < len(tokens) else gr.Button('', visible=False) for i in range(MAX_PROMPT_TOKENS)]
|
| 57 |
|
| 58 |
|
| 59 |
def run_model(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
|
|
|
|
| 106 |
model = AutoModelClass.from_pretrained(model_name, **model_args)
|
| 107 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=os.environ['hf_token'])
|
| 108 |
|
| 109 |
+
# demo
|
| 110 |
with gr.Blocks(theme=gr.themes.Default()) as demo:
|
| 111 |
with gr.Row():
|
| 112 |
with gr.Column(scale=5):
|
|
|
|
| 126 |
gr.Markdown('<span style="font-size:180px;">🤔</span>')
|
| 127 |
|
| 128 |
with gr.Group():
|
| 129 |
+
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail', container=True, label='Original Prompt')
|
| 130 |
+
original_prompt_btn = gr.Button('Compute', variant='primary')
|
| 131 |
|
| 132 |
with gr.Accordion(open=False, label='Settings'):
|
| 133 |
with gr.Row():
|
|
|
|
| 146 |
interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
|
| 147 |
|
| 148 |
with gr.Group('Output'):
|
| 149 |
+
with gr.Row():
|
| 150 |
+
tokens_container = [gr.Button('', visible=False) for range(MAX_PROMPT_TOKENS)]
|
| 151 |
with gr.Column() as interpretations_container:
|
| 152 |
pass
|
| 153 |
|
| 154 |
+
original_prompt_btn.click(get_hidden_states, [original_prompt_raw], [*tokens_container])
|
| 155 |
# btn.click(run_model,
|
| 156 |
# [text, interpretation_prompt, num_tokens, do_sample, temperature,
|
| 157 |
# top_k, top_p, repetition_penalty, length_penalty],
|