Update app.py
app.py

@@ -45,25 +45,22 @@ suggested_interpretation_prompts = ["Before responding, let me repeat the messag
 def initialize_gpu():
     pass
 
-def get_hidden_states(raw_original_prompt
+def get_hidden_states(raw_original_prompt):
     original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
     model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
     tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
     outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
     hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
-    token_btns = []
-
-
-
-
-
-
-
-
-
-def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
-                       temperature, top_k, top_p, repetition_penalty, length_penalty, interpreted_vectors, num_beams=1):
-
+    token_btns = ([gr.Button(token, visible=True) for token in tokens]
+                  + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
+    return [hidden_states, *token_btns]
+
+
+def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
+                       temperature, top_k, top_p, repetition_penalty, length_penalty, i,
+                       num_beams=1):
+
+    interpreted_vectors = global_state[:, i]
     length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
 
     # generation parameters

@@ -83,7 +80,7 @@ def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
     interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
 
     # generate the interpretations
-    generated = interpretation_prompt.generate(model, {0:
+    generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
     generation_texts = tokenizer.batch_decode(generated)
     return generation_texts
 
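For orientation: the tensor get_hidden_states stores in global_state is stacked as (num_layers + 1, seq_len, hidden_dim), one slice per embedding/decoder layer, so global_state[:, i] in run_interpretation selects token i's hidden state at every layer at once. A minimal sketch of that indexing, with illustrative sizes (the real dimensions come from the loaded model):

import torch

# illustrative sizes; the real values come from the loaded model
num_layers, seq_len, hidden_dim = 33, 10, 4096

# what get_hidden_states builds: one (seq_len, hidden_dim) slice per layer
hidden_states = torch.randn(num_layers, seq_len, hidden_dim)

i = 3                                      # index of the clicked token button
interpreted_vectors = hidden_states[:, i]  # (num_layers, hidden_dim): token i at every layer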
@@ -105,6 +102,8 @@ model = AutoModelClass.from_pretrained(model_name, **model_args)
 tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=os.environ['hf_token'])
 
 # demo
+global_state = gr.State([])
+json_output = gr.JSON()
 with gr.Blocks(theme=gr.themes.Default()) as demo:
     with gr.Row():
         with gr.Column(scale=5):
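Note that global_state and json_output are instantiated before the gr.Blocks context opens, and json_output is only placed into the layout later with json_output.render(). Gradio allows a component created outside a layout to be rendered at a chosen spot; a minimal sketch of that pattern, with illustrative component names:

import gradio as gr

result_box = gr.Textbox(label='Result')   # created outside any layout

with gr.Blocks() as demo:
    inp = gr.Textbox(label='Input')
    result_box.render()                   # placed here, inside the Blocks context
    inp.change(lambda s: s.upper(), inp, result_box)

demo.launch()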
@@ -144,15 +143,15 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
             interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
 
         with gr.Group('Output'):
+            tokens_container = []
             with gr.Row():
-
-
-
-
-
-
-
-
-
-), [original_prompt_raw], [*tokens_container])
+                for i in range(MAX_PROMPT_TOKENS):
+                    btn = gr.Button('', visible=False)
+                    btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt, num_tokens, do_sample, temperature,
+                                                                 top_k, top_p, repetition_penalty, length_penalty
+                                                                 ], [json_output])
+                    tokens_container.append(btn)
+            json_output.render()
+
+    original_prompt_btn.click(get_hidden_states, [original_prompt_raw], [global_state, *tokens_container])
 demo.launch()
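Since Gradio cannot add components after the interface is built, the commit pre-allocates MAX_PROMPT_TOKENS invisible buttons and has get_hidden_states return one button update per slot: visible with the token text where a token exists, hidden otherwise. A self-contained sketch of the same pattern; the whitespace split and MAX_TOKENS cap stand in for the real tokenizer and MAX_PROMPT_TOKENS:

import gradio as gr

MAX_TOKENS = 8  # stand-in for MAX_PROMPT_TOKENS

def show_tokens(text):
    tokens = text.split()[:MAX_TOKENS]  # stand-in for the real tokenizer
    # one update per pre-allocated slot: reveal a button per token, hide the rest
    return ([gr.Button(tok, visible=True) for tok in tokens]
            + [gr.Button('', visible=False) for _ in range(MAX_TOKENS - len(tokens))])

with gr.Blocks() as demo:
    prompt = gr.Textbox(label='Prompt')
    with gr.Row():
        token_btns = [gr.Button('', visible=False) for _ in range(MAX_TOKENS)]
    prompt.submit(show_tokens, [prompt], token_btns)

demo.launch()

Each pre-allocated button can then carry its own click handler bound to its slot index, which is what the partial(run_interpretation, i=i) wiring in the diff does.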