Spaces:
Build error
Build error
Update lxmert/src/ExplanationGenerator.py
Browse files
lxmert/src/ExplanationGenerator.py
CHANGED
|
@@ -317,7 +317,7 @@ class GeneratorOursAblationNoAggregation:
|
|
| 317 |
one_hot[0, index] = 1
|
| 318 |
one_hot_vector = one_hot
|
| 319 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
| 320 |
-
one_hot = torch.sum(one_hot
|
| 321 |
|
| 322 |
model.zero_grad()
|
| 323 |
one_hot.backward(retain_graph=True)
|
|
|
|
| 317 |
one_hot[0, index] = 1
|
| 318 |
one_hot_vector = one_hot
|
| 319 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
| 320 |
+
one_hot = torch.sum(one_hot * output)
|
| 321 |
|
| 322 |
model.zero_grad()
|
| 323 |
one_hot.backward(retain_graph=True)
|