Update app.py
Browse files
app.py
CHANGED
|
@@ -57,7 +57,7 @@ suggested_interpretation_prompts = [
|
|
| 57 |
def initialize_gpu():
|
| 58 |
pass
|
| 59 |
|
| 60 |
-
def reset_model(
|
| 61 |
# extract model info
|
| 62 |
model_args = deepcopy(model_info[model_name])
|
| 63 |
model_path = model_args.pop('model_path')
|
|
@@ -84,15 +84,15 @@ def reset_model(global_state, model_name, load_on_gpu, *extra_components, reset_
|
|
| 84 |
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
|
| 85 |
gc.collect()
|
| 86 |
if with_extra_components:
|
| 87 |
-
return ([
|
| 88 |
+ [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
|
| 89 |
+ [gr.Button('', visible=False) for _ in range(len(tokens_container))]
|
| 90 |
+ [*extra_components])
|
| 91 |
else:
|
| 92 |
-
return
|
| 93 |
|
| 94 |
|
| 95 |
-
def get_hidden_states(
|
| 96 |
model, tokenizer = global_state.model, global_state.tokenizer
|
| 97 |
original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
|
| 98 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
|
@@ -120,7 +120,7 @@ def get_hidden_states(global_state, raw_original_prompt, force_hidden_states=Fal
|
|
| 120 |
|
| 121 |
|
| 122 |
@spaces.GPU
|
| 123 |
-
def run_interpretation(
|
| 124 |
temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
|
| 125 |
num_beams=1):
|
| 126 |
model = global_state.model
|
|
@@ -197,8 +197,8 @@ for i in range(MAX_PROMPT_TOKENS):
|
|
| 197 |
tokens_container.append(btn)
|
| 198 |
|
| 199 |
with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
| 200 |
-
global_state =
|
| 201 |
-
|
| 202 |
with gr.Row():
|
| 203 |
with gr.Column(scale=5):
|
| 204 |
gr.Markdown('# 😎 Self-Interpreting Models')
|
|
@@ -278,19 +278,19 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
| 278 |
|
| 279 |
# event listeners
|
| 280 |
for i, btn in enumerate(tokens_container):
|
| 281 |
-
btn.click(partial(run_interpretation, i=i), [
|
| 282 |
num_tokens, do_sample, temperature,
|
| 283 |
top_k, top_p, repetition_penalty, length_penalty,
|
| 284 |
use_gpu
|
| 285 |
], [progress_dummy, *interpretation_bubbles])
|
| 286 |
|
| 287 |
original_prompt_btn.click(get_hidden_states,
|
| 288 |
-
[
|
| 289 |
[progress_dummy, *tokens_container, *interpretation_bubbles])
|
| 290 |
raw_original_prompt.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
|
| 291 |
|
| 292 |
extra_components = [raw_interpretation_prompt, raw_original_prompt, original_prompt_btn]
|
| 293 |
-
model_chooser.change(reset_model, [
|
| 294 |
-
[
|
| 295 |
|
| 296 |
demo.launch()
|
|
|
|
| 57 |
def initialize_gpu():
|
| 58 |
pass
|
| 59 |
|
| 60 |
+
def reset_model(model_name, load_on_gpu, *extra_components, reset_sentence_transformer=False, with_extra_components=True):
|
| 61 |
# extract model info
|
| 62 |
model_args = deepcopy(model_info[model_name])
|
| 63 |
model_path = model_args.pop('model_path')
|
|
|
|
| 84 |
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
|
| 85 |
gc.collect()
|
| 86 |
if with_extra_components:
|
| 87 |
+
return ([welcome_message.format(model_name=model_name)]
|
| 88 |
+ [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
|
| 89 |
+ [gr.Button('', visible=False) for _ in range(len(tokens_container))]
|
| 90 |
+ [*extra_components])
|
| 91 |
else:
|
| 92 |
+
return None
|
| 93 |
|
| 94 |
|
| 95 |
+
def get_hidden_states(raw_original_prompt, force_hidden_states=False):
|
| 96 |
model, tokenizer = global_state.model, global_state.tokenizer
|
| 97 |
original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
|
| 98 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
|
|
|
| 120 |
|
| 121 |
|
| 122 |
@spaces.GPU
|
| 123 |
+
def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
|
| 124 |
temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
|
| 125 |
num_beams=1):
|
| 126 |
model = global_state.model
|
|
|
|
| 197 |
tokens_container.append(btn)
|
| 198 |
|
| 199 |
with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
| 200 |
+
global_state = GlobalState()
|
| 201 |
+
reset_model(model_name, load_on_gpu=True, with_extra_components=False, reset_sentence_transformer=True)
|
| 202 |
with gr.Row():
|
| 203 |
with gr.Column(scale=5):
|
| 204 |
gr.Markdown('# 😎 Self-Interpreting Models')
|
|
|
|
| 278 |
|
| 279 |
# event listeners
|
| 280 |
for i, btn in enumerate(tokens_container):
|
| 281 |
+
btn.click(partial(run_interpretation, i=i), [raw_original_prompt, raw_interpretation_prompt,
|
| 282 |
num_tokens, do_sample, temperature,
|
| 283 |
top_k, top_p, repetition_penalty, length_penalty,
|
| 284 |
use_gpu
|
| 285 |
], [progress_dummy, *interpretation_bubbles])
|
| 286 |
|
| 287 |
original_prompt_btn.click(get_hidden_states,
|
| 288 |
+
[raw_original_prompt],
|
| 289 |
[progress_dummy, *tokens_container, *interpretation_bubbles])
|
| 290 |
raw_original_prompt.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
|
| 291 |
|
| 292 |
extra_components = [raw_interpretation_prompt, raw_original_prompt, original_prompt_btn]
|
| 293 |
+
model_chooser.change(reset_model, [model_chooser, load_on_gpu, *extra_components],
|
| 294 |
+
[welcome_model, *interpretation_bubbles, *tokens_container, *extra_components])
|
| 295 |
|
| 296 |
demo.launch()
|