Update app.py
app.py CHANGED
@@ -104,11 +104,15 @@ def get_hidden_states(raw_original_prompt, force_hidden_states=False):
 
 @spaces.GPU
 def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
-                       temperature, top_k, top_p, repetition_penalty, length_penalty, i,
+                       temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
                        num_beams=1):
     model = global_state.model
     tokenizer = global_state.tokenizer
     print(f'run {model}')
+    if use_gpu:
+        model = model.cuda()
+    else:
+        model = model.cpu()
     if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
         get_hidden_states(raw_original_prompt, force_hidden_states=True)
     interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)

@@ -251,6 +255,7 @@ with gr.Blocks(theme=gr.themes.Glass(), css='styles.css') as demo:
     with gr.Row():
         for btn in tokens_container:
             btn.render()
+    use_gpu = gr.Radio('Use GPU', value=True)
 
     progress_dummy = gr.Markdown('', elem_id='progress_dummy')
     interpretation_bubbles = [gr.Textbox('', container=False, visible=False) for i in range(MAX_NUM_LAYERS)]

@@ -259,7 +264,8 @@ with gr.Blocks(theme=gr.themes.Glass(), css='styles.css') as demo:
     for i, btn in enumerate(tokens_container):
         btn.click(partial(run_interpretation, i=i), [raw_original_prompt, raw_interpretation_prompt,
                                                      num_tokens, do_sample, temperature,
-                                                     top_k, top_p, repetition_penalty, length_penalty
+                                                     top_k, top_p, repetition_penalty, length_penalty,
+                                                     use_gpu
                                                      ], [progress_dummy, *interpretation_bubbles])
 
     original_prompt_btn.click(get_hidden_states,
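For context, a minimal standalone sketch of the pattern this commit uses: a Gradio boolean control is listed as an extra input to `click`, and the callback moves the model to the requested device before running. The toy `run` callback, the `torch.nn.Linear` stand-in model, and the use of `gr.Checkbox` as the toggle are illustrative assumptions, not taken from app.py.

import gradio as gr
import torch

# Toy stand-in for the real model in app.py (illustrative assumption).
model = torch.nn.Linear(4, 4)


def run(prompt, use_gpu):
    global model
    # Move the model to the requested device before running, mirroring the
    # use_gpu branch added to run_interpretation in this commit.
    if use_gpu and torch.cuda.is_available():
        model = model.cuda()
    else:
        model = model.cpu()
    x = torch.randn(1, 4, device=model.weight.device, dtype=model.weight.dtype)
    return f'{prompt!r} ran on {model.weight.device}'


with gr.Blocks() as demo:
    prompt = gr.Textbox(label='Prompt')
    use_gpu = gr.Checkbox(value=True, label='Use GPU')  # boolean toggle used here for illustration
    out = gr.Textbox(label='Output')
    btn = gr.Button('Run')
    # The toggle is passed as an input component, so its current value
    # reaches the callback alongside the prompt.
    btn.click(run, [prompt, use_gpu], [out])

if __name__ == '__main__':
    demo.launch()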