Update app.py
Browse files
app.py
CHANGED
|
@@ -74,7 +74,7 @@ def initialize_gpu():
|
|
| 74 |
pass
|
| 75 |
|
| 76 |
|
| 77 |
-
def reset_model(model_name, global_state):
|
| 78 |
# extract model info
|
| 79 |
model_args = deepcopy(model_info[model_name])
|
| 80 |
model_path = model_args.pop('model_path')
|
|
@@ -90,10 +90,11 @@ def reset_model(model_name, global_state):
|
|
| 90 |
global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
|
| 91 |
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
|
| 92 |
gc.collect()
|
| 93 |
-
|
|
|
|
| 94 |
|
| 95 |
|
| 96 |
-
def get_hidden_states(global_state, raw_original_prompt):
|
| 97 |
model, tokenizer = global_state.model, global_state.tokenizer
|
| 98 |
original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
|
| 99 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
|
@@ -105,11 +106,11 @@ def get_hidden_states(global_state, raw_original_prompt):
|
|
| 105 |
progress_dummy_output = ''
|
| 106 |
invisible_bubbles = [gr.Textbox('', visible=False) for i in range(len(interpretation_bubbles))]
|
| 107 |
global_state.hidden_states = hidden_states
|
| 108 |
-
return [progress_dummy_output,
|
| 109 |
|
| 110 |
|
| 111 |
@spaces.GPU
|
| 112 |
-
def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
|
| 113 |
temperature, top_k, top_p, repetition_penalty, length_penalty, i,
|
| 114 |
num_beams=1):
|
| 115 |
|
|
@@ -143,7 +144,7 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
|
|
| 143 |
|
| 144 |
|
| 145 |
## main
|
| 146 |
-
|
| 147 |
torch.set_grad_enabled(False)
|
| 148 |
model_name = 'LLAMA2-7B'
|
| 149 |
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
|
|
@@ -153,7 +154,6 @@ for i in range(MAX_PROMPT_TOKENS):
|
|
| 153 |
tokens_container.append(btn)
|
| 154 |
|
| 155 |
with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
| 156 |
-
global_state = gr.State(reset_model(model_name, GlobalState()))
|
| 157 |
|
| 158 |
with gr.Row():
|
| 159 |
with gr.Column(scale=5):
|
|
@@ -236,8 +236,9 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
| 236 |
elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
|
| 237 |
) for i in range(model.config.num_hidden_layers)]
|
| 238 |
|
|
|
|
| 239 |
# event listeners
|
| 240 |
-
model_chooser.change(reset_new_model, [model_chooser
|
| 241 |
|
| 242 |
for i, btn in enumerate(tokens_container):
|
| 243 |
btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt,
|
|
@@ -247,6 +248,6 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
| 247 |
|
| 248 |
original_prompt_btn.click(get_hidden_states,
|
| 249 |
[original_prompt_raw],
|
| 250 |
-
[progress_dummy,
|
| 251 |
original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
|
| 252 |
demo.launch()
|
|
|
|
| 74 |
pass
|
| 75 |
|
| 76 |
|
| 77 |
+
def reset_model(model_name, return_state=False):
|
| 78 |
# extract model info
|
| 79 |
model_args = deepcopy(model_info[model_name])
|
| 80 |
model_path = model_args.pop('model_path')
|
|
|
|
| 90 |
global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
|
| 91 |
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
|
| 92 |
gc.collect()
|
| 93 |
+
if return_state:
|
| 94 |
+
return global_state
|
| 95 |
|
| 96 |
|
| 97 |
+
def get_hidden_states(raw_original_prompt):
|
| 98 |
model, tokenizer = global_state.model, global_state.tokenizer
|
| 99 |
original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
|
| 100 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
|
|
|
| 106 |
progress_dummy_output = ''
|
| 107 |
invisible_bubbles = [gr.Textbox('', visible=False) for i in range(len(interpretation_bubbles))]
|
| 108 |
global_state.hidden_states = hidden_states
|
| 109 |
+
return [progress_dummy_output, *token_btns, *invisible_bubbles]
|
| 110 |
|
| 111 |
|
| 112 |
@spaces.GPU
|
| 113 |
+
def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
|
| 114 |
temperature, top_k, top_p, repetition_penalty, length_penalty, i,
|
| 115 |
num_beams=1):
|
| 116 |
|
|
|
|
| 144 |
|
| 145 |
|
| 146 |
## main
|
| 147 |
+
global_state = reset_model(model_name, return_state=True)
|
| 148 |
torch.set_grad_enabled(False)
|
| 149 |
model_name = 'LLAMA2-7B'
|
| 150 |
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
|
|
|
|
| 154 |
tokens_container.append(btn)
|
| 155 |
|
| 156 |
with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
|
|
| 157 |
|
| 158 |
with gr.Row():
|
| 159 |
with gr.Column(scale=5):
|
|
|
|
| 236 |
elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
|
| 237 |
) for i in range(model.config.num_hidden_layers)]
|
| 238 |
|
| 239 |
+
|
| 240 |
# event listeners
|
| 241 |
+
model_chooser.change(reset_new_model, [model_chooser], [])
|
| 242 |
|
| 243 |
for i, btn in enumerate(tokens_container):
|
| 244 |
btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt,
|
|
|
|
| 248 |
|
| 249 |
original_prompt_btn.click(get_hidden_states,
|
| 250 |
[original_prompt_raw],
|
| 251 |
+
[progress_dummy, *tokens_container, *interpretation_bubbles])
|
| 252 |
original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
|
| 253 |
demo.launch()
|