anoushka2000 committed on
Commit
6c36c0b
·
verified ·
1 Parent(s): 9f072eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -24
app.py CHANGED
@@ -42,26 +42,16 @@ def get_channels(model):
42
  return None
43
 
44
 
45
- def forward_fn(input_ids, attention_mask, model, target_idx=None):
46
  output = model(input_ids=input_ids, attention_mask=attention_mask)
47
 
48
  if hasattr(output, "logits"):
49
- logits = output.logits
50
- if target_idx is not None and len(logits.shape) > 1 and logits.shape[-1] > 1:
51
- return logits[:, target_idx]
52
- return logits.squeeze(-1)
53
-
54
- if hasattr(model, "encoder"):
55
- encoder_output = model.encoder(
56
- input_ids=input_ids, attention_mask=attention_mask
57
- )
58
- if hasattr(encoder_output, "last_hidden_state"):
59
- return encoder_output.last_hidden_state[:, 0, :].mean(dim=-1)
60
- return encoder_output[0][:, 0, :].mean(dim=-1)
61
 
62
- if hasattr(output, "last_hidden_state"):
63
- return output.last_hidden_state[:, 0, :].mean(dim=-1)
64
- return output[0][:, 0, :].mean(dim=-1)
65
 
66
 
67
  @torch.no_grad()
@@ -93,17 +83,21 @@ def compute_attributions(
93
 
94
  baseline_ids = torch.full_like(input_ids, pad_id)
95
  lig = LayerIntegratedGradients(
96
- lambda ids, am: forward_fn(ids, am, model, target_idx),
97
  get_embedding_layer(model),
98
  )
99
 
100
- attributions, delta = lig.attribute(
101
- inputs=input_ids,
102
- baselines=baseline_ids,
103
- additional_forward_args=(attention_mask,),
104
- return_convergence_delta=True,
105
- n_steps=n_steps,
106
- )
 
 
 
 
107
 
108
  token_scores = attributions.sum(dim=-1) * attention_mask
109
  return token_scores, delta
 
42
  return None
43
 
44
 
45
def forward_fn(input_ids, attention_mask, model):
    """Run *model* on one batch and return its primary output tensor.

    Resolution order for the model's return value:
      1. a ``.logits`` attribute (Hugging Face ``ModelOutput`` style),
      2. the first element, when the model returns a plain tuple,
      3. the raw output object itself, as a last resort.

    Target/class selection is intentionally NOT done here — callers pass
    ``target=`` to Captum's ``attribute`` instead.
    """
    result = model(input_ids=input_ids, attention_mask=attention_mask)

    # Sentinel-based getattr so an explicit ``logits=None`` is still honored
    # exactly like hasattr() would honor it.
    _missing = object()
    logits = getattr(result, "logits", _missing)
    if logits is not _missing:
        return logits

    return result[0] if isinstance(result, tuple) else result
 
 
55
 
56
 
57
  @torch.no_grad()
 
83
 
84
  baseline_ids = torch.full_like(input_ids, pad_id)
85
  lig = LayerIntegratedGradients(
86
+ lambda ids, am: forward_fn(ids, am, model),
87
  get_embedding_layer(model),
88
  )
89
 
90
+ attr_kwargs = {
91
+ "inputs": input_ids,
92
+ "baselines": baseline_ids,
93
+ "additional_forward_args": (attention_mask,),
94
+ "return_convergence_delta": True,
95
+ "n_steps": n_steps,
96
+ }
97
+ if target_idx is not None:
98
+ attr_kwargs["target"] = target_idx
99
+
100
+ attributions, delta = lig.attribute(**attr_kwargs)
101
 
102
  token_scores = attributions.sum(dim=-1) * attention_mask
103
  return token_scores, delta