Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +4 -4
src/streamlit_app.py
CHANGED
|
@@ -204,7 +204,7 @@ def render_attribution_html(result, log_color: bool = False, cmap_name: str = "B
|
|
| 204 |
grad_magnitude[target_idx_in_grad] = 1e-8
|
| 205 |
bar_idx = target_idx_in_grad
|
| 206 |
|
| 207 |
-
grad_np = grad_magnitude.cpu().numpy()
|
| 208 |
log_norm = Log_Norm(vmin=grad_np.min(), vmax=grad_np.max())
|
| 209 |
norm = Norm(vmin=grad_np.min(), vmax=grad_np.max())
|
| 210 |
|
|
@@ -289,9 +289,9 @@ def render_attribution_barplot(result, log_color: bool = False, cmap_name: str =
|
|
| 289 |
front_pad = 2 # assumed
|
| 290 |
|
| 291 |
if len(grads.shape) == 2:
|
| 292 |
-
grad_magnitude = grads.norm(dim=-1).squeeze().detach().clone().cpu().numpy()
|
| 293 |
else:
|
| 294 |
-
grad_magnitude = grads.detach().clone().cpu().numpy()
|
| 295 |
|
| 296 |
hardset_target_grad = True
|
| 297 |
target_bar_index = None
|
|
@@ -506,7 +506,7 @@ def main():
|
|
| 506 |
probs = torch.softmax(logit_vector, dim=-1)
|
| 507 |
top_probs, top_indices = torch.topk(probs, k)
|
| 508 |
top_tokens = [tokenizer.decode([idx]) for idx in top_indices]
|
| 509 |
-
for i, (tok, prob) in enumerate(zip(top_tokens, top_probs.cpu().numpy()), 1):
|
| 510 |
st.write(f"{i}. P(**{repr(tok)}**)={prob:.3f}")
|
| 511 |
|
| 512 |
|
|
|
|
| 204 |
grad_magnitude[target_idx_in_grad] = 1e-8
|
| 205 |
bar_idx = target_idx_in_grad
|
| 206 |
|
| 207 |
+
grad_np = grad_magnitude.float().cpu().numpy()
|
| 208 |
log_norm = Log_Norm(vmin=grad_np.min(), vmax=grad_np.max())
|
| 209 |
norm = Norm(vmin=grad_np.min(), vmax=grad_np.max())
|
| 210 |
|
|
|
|
| 289 |
front_pad = 2 # assumed
|
| 290 |
|
| 291 |
if len(grads.shape) == 2:
|
| 292 |
+
grad_magnitude = grads.norm(dim=-1).squeeze().detach().clone().float().cpu().numpy()
|
| 293 |
else:
|
| 294 |
+
grad_magnitude = grads.detach().clone().float().cpu().numpy()
|
| 295 |
|
| 296 |
hardset_target_grad = True
|
| 297 |
target_bar_index = None
|
|
|
|
| 506 |
probs = torch.softmax(logit_vector, dim=-1)
|
| 507 |
top_probs, top_indices = torch.topk(probs, k)
|
| 508 |
top_tokens = [tokenizer.decode([idx]) for idx in top_indices]
|
| 509 |
+
for i, (tok, prob) in enumerate(zip(top_tokens, top_probs.float().cpu().numpy()), 1):
|
| 510 |
st.write(f"{i}. P(**{repr(tok)}**)={prob:.3f}")
|
| 511 |
|
| 512 |
|