Update app.py
Browse files
app.py
CHANGED
|
@@ -27,6 +27,7 @@ class GlobalState:
|
|
| 27 |
tokenizer : Optional[PreTrainedTokenizer] = None
|
| 28 |
model : Optional[PreTrainedModel] = None
|
| 29 |
local_state : LocalState = LocalState()
|
|
|
|
| 30 |
interpretation_prompt_template : str = '{prompt}'
|
| 31 |
original_prompt_template : str = 'User: [X]\n\nAnswer: {prompt}'
|
| 32 |
layers_format : str = 'model.layers.{k}'
|
|
@@ -48,7 +49,6 @@ def initialize_gpu():
|
|
| 48 |
|
| 49 |
def reset_model(model_name, *extra_components, with_extra_components=True):
|
| 50 |
# extract model info
|
| 51 |
-
|
| 52 |
model_args = deepcopy(model_info[model_name])
|
| 53 |
model_path = model_args.pop('model_path')
|
| 54 |
global_state.original_prompt_template = model_args.pop('original_prompt_template')
|
|
@@ -57,6 +57,7 @@ def reset_model(model_name, *extra_components, with_extra_components=True):
|
|
| 57 |
tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
|
| 58 |
use_ctransformers = model_args.pop('ctransformers', False)
|
| 59 |
dont_cuda = model_args.pop('dont_cuda', False)
|
|
|
|
| 60 |
AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM
|
| 61 |
|
| 62 |
# get model
|
|
@@ -80,22 +81,28 @@ def get_hidden_states(raw_original_prompt):
|
|
| 80 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
| 81 |
tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
|
| 82 |
outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
token_btns = ([gr.Button(token, visible=True) for token in tokens]
|
| 85 |
+ [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
|
| 86 |
progress_dummy_output = ''
|
| 87 |
invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_NUM_LAYERS)]
|
| 88 |
-
global_state.local_state.hidden_states = hidden_states.cpu().detach()
|
| 89 |
return [progress_dummy_output, *token_btns, *invisible_bubbles]
|
| 90 |
|
| 91 |
|
| 92 |
@spaces.GPU
|
| 93 |
-
def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
|
| 94 |
temperature, top_k, top_p, repetition_penalty, length_penalty, i,
|
| 95 |
num_beams=1):
|
| 96 |
model = global_state.model
|
| 97 |
tokenizer = global_state.tokenizer
|
| 98 |
print(f'run {model}')
|
|
|
|
|
|
|
| 99 |
interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)
|
| 100 |
length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
|
| 101 |
|
|
@@ -218,7 +225,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
| 218 |
|
| 219 |
# event listeners
|
| 220 |
for i, btn in enumerate(tokens_container):
|
| 221 |
-
btn.click(partial(run_interpretation, i=i), [interpretation_prompt,
|
| 222 |
num_tokens, do_sample, temperature,
|
| 223 |
top_k, top_p, repetition_penalty, length_penalty
|
| 224 |
], [progress_dummy, *interpretation_bubbles])
|
|
|
|
| 27 |
tokenizer : Optional[PreTrainedTokenizer] = None
|
| 28 |
model : Optional[PreTrainedModel] = None
|
| 29 |
local_state : LocalState = LocalState()
|
| 30 |
+
wait_with_hidden_states : bool = False
|
| 31 |
interpretation_prompt_template : str = '{prompt}'
|
| 32 |
original_prompt_template : str = 'User: [X]\n\nAnswer: {prompt}'
|
| 33 |
layers_format : str = 'model.layers.{k}'
|
|
|
|
| 49 |
|
| 50 |
def reset_model(model_name, *extra_components, with_extra_components=True):
|
| 51 |
# extract model info
|
|
|
|
| 52 |
model_args = deepcopy(model_info[model_name])
|
| 53 |
model_path = model_args.pop('model_path')
|
| 54 |
global_state.original_prompt_template = model_args.pop('original_prompt_template')
|
|
|
|
| 57 |
tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
|
| 58 |
use_ctransformers = model_args.pop('ctransformers', False)
|
| 59 |
dont_cuda = model_args.pop('dont_cuda', False)
|
| 60 |
+
global_state.wait_with_hidden_states = model_args.pop('wait_with_hidden_states', False)
|
| 61 |
AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM
|
| 62 |
|
| 63 |
# get model
|
|
|
|
| 81 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
| 82 |
tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
|
| 83 |
outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
|
| 84 |
+
if global_state.wait_with_hidden_states:
|
| 85 |
+
global_state.local_state.hidden_states = None
|
| 86 |
+
else:
|
| 87 |
+
hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
|
| 88 |
+
global_state.local_state.hidden_states = hidden_states.cpu().detach()
|
| 89 |
+
|
| 90 |
token_btns = ([gr.Button(token, visible=True) for token in tokens]
|
| 91 |
+ [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
|
| 92 |
progress_dummy_output = ''
|
| 93 |
invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_NUM_LAYERS)]
|
|
|
|
| 94 |
return [progress_dummy_output, *token_btns, *invisible_bubbles]
|
| 95 |
|
| 96 |
|
| 97 |
@spaces.GPU
|
| 98 |
+
def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
|
| 99 |
temperature, top_k, top_p, repetition_penalty, length_penalty, i,
|
| 100 |
num_beams=1):
|
| 101 |
model = global_state.model
|
| 102 |
tokenizer = global_state.tokenizer
|
| 103 |
print(f'run {model}')
|
| 104 |
+
if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
|
| 105 |
+
get_hidden_states(raw_original_prompt)
|
| 106 |
interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)
|
| 107 |
length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
|
| 108 |
|
|
|
|
| 225 |
|
| 226 |
# event listeners
|
| 227 |
for i, btn in enumerate(tokens_container):
|
| 228 |
+
btn.click(partial(run_interpretation, i=i), [original_prompt_raw, interpretation_prompt,
|
| 229 |
num_tokens, do_sample, temperature,
|
| 230 |
top_k, top_p, repetition_penalty, length_penalty
|
| 231 |
], [progress_dummy, *interpretation_bubbles])
|