Update app.py
Browse files
app.py
CHANGED
|
@@ -107,7 +107,6 @@ def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_t
|
|
| 107 |
get_hidden_states(raw_original_prompt, force_hidden_states=True)
|
| 108 |
interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)
|
| 109 |
hidden_means = torch.tensor(global_state.local_state.hidden_states.mean(dim=1)).to(model.device).to(model.dtype)
|
| 110 |
-
hidden_norms = hidden_means.norm(keepdim=True, dim=-1)
|
| 111 |
length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
|
| 112 |
|
| 113 |
# generation parameters
|
|
@@ -133,7 +132,7 @@ def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_t
|
|
| 133 |
generation_texts = tokenizer.batch_decode(generated)
|
| 134 |
|
| 135 |
# create GUI output
|
| 136 |
-
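important_idxs = 1 + ((interpreted_vectors - hidden_means) / hidden_norms).diff(dim=0).norm(dim=-1).topk(k=int(np.ceil(0.2 * len(generation_texts)))).indices.cpu().numpy()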
|
| 137 |
print(f'{important_idxs=}')
|
| 138 |
progress_dummy_output = ''
|
| 139 |
elem_classes = [['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble'] +
|
|
|
|
| 107 |
get_hidden_states(raw_original_prompt, force_hidden_states=True)
|
| 108 |
interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)
|
| 109 |
hidden_means = torch.tensor(global_state.local_state.hidden_states.mean(dim=1)).to(model.device).to(model.dtype)
|
|
|
|
| 110 |
length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
|
| 111 |
|
| 112 |
# generation parameters
|
|
|
|
| 132 |
generation_texts = tokenizer.batch_decode(generated)
|
| 133 |
|
| 134 |
# create GUI output
|
| 135 |
+
important_idxs = 1 + ((interpreted_vectors - hidden_means)).diff(dim=0).norm(dim=-1).topk(k=int(np.ceil(0.2 * len(generation_texts)))).indices.cpu().numpy()
|
| 136 |
print(f'{important_idxs=}')
|
| 137 |
progress_dummy_output = ''
|
| 138 |
elem_classes = [['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble'] +
|