Typony commited on
Commit
0d820e3
·
verified ·
1 Parent(s): e863bb2

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +4 -4
src/streamlit_app.py CHANGED
@@ -204,7 +204,7 @@ def render_attribution_html(result, log_color: bool = False, cmap_name: str = "B
204
  grad_magnitude[target_idx_in_grad] = 1e-8
205
  bar_idx = target_idx_in_grad
206
 
207
- grad_np = grad_magnitude.cpu().numpy()
208
  log_norm = Log_Norm(vmin=grad_np.min(), vmax=grad_np.max())
209
  norm = Norm(vmin=grad_np.min(), vmax=grad_np.max())
210
 
@@ -289,9 +289,9 @@ def render_attribution_barplot(result, log_color: bool = False, cmap_name: str =
289
  front_pad = 2 # assumed
290
 
291
  if len(grads.shape) == 2:
292
- grad_magnitude = grads.norm(dim=-1).squeeze().detach().clone().cpu().numpy()
293
  else:
294
- grad_magnitude = grads.detach().clone().cpu().numpy()
295
 
296
  hardset_target_grad = True
297
  target_bar_index = None
@@ -506,7 +506,7 @@ def main():
506
  probs = torch.softmax(logit_vector, dim=-1)
507
  top_probs, top_indices = torch.topk(probs, k)
508
  top_tokens = [tokenizer.decode([idx]) for idx in top_indices]
509
- for i, (tok, prob) in enumerate(zip(top_tokens, top_probs.cpu().numpy()), 1):
510
  st.write(f"{i}. P(**{repr(tok)}**)={prob:.3f}")
511
 
512
 
 
204
  grad_magnitude[target_idx_in_grad] = 1e-8
205
  bar_idx = target_idx_in_grad
206
 
207
+ grad_np = grad_magnitude.float().cpu().numpy()
208
  log_norm = Log_Norm(vmin=grad_np.min(), vmax=grad_np.max())
209
  norm = Norm(vmin=grad_np.min(), vmax=grad_np.max())
210
 
 
289
  front_pad = 2 # assumed
290
 
291
  if len(grads.shape) == 2:
292
+ grad_magnitude = grads.norm(dim=-1).squeeze().detach().clone().float().cpu().numpy()
293
  else:
294
+ grad_magnitude = grads.detach().clone().float().cpu().numpy()
295
 
296
  hardset_target_grad = True
297
  target_bar_index = None
 
506
  probs = torch.softmax(logit_vector, dim=-1)
507
  top_probs, top_indices = torch.topk(probs, k)
508
  top_tokens = [tokenizer.decode([idx]) for idx in top_indices]
509
+ for i, (tok, prob) in enumerate(zip(top_tokens, top_probs.float().cpu().numpy()), 1):
510
  st.write(f"{i}. P(**{repr(tok)}**)={prob:.3f}")
511
 
512