Spaces:
Runtime error
Runtime error
Update BERT/BERT_explainability/modules/BERT/ExplanationGenerator.py
Browse files
BERT/BERT_explainability/modules/BERT/ExplanationGenerator.py
CHANGED
|
@@ -37,7 +37,7 @@ class Generator:
|
|
| 37 |
one_hot[0, index] = 1
|
| 38 |
one_hot_vector = one_hot
|
| 39 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
| 40 |
-
one_hot = torch.sum(one_hot
|
| 41 |
|
| 42 |
self.model.zero_grad()
|
| 43 |
one_hot.backward(retain_graph=True)
|
|
@@ -70,7 +70,7 @@ class Generator:
|
|
| 70 |
one_hot[0, index] = 1
|
| 71 |
one_hot_vector = one_hot
|
| 72 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
| 73 |
-
one_hot = torch.sum(one_hot
|
| 74 |
|
| 75 |
self.model.zero_grad()
|
| 76 |
one_hot.backward(retain_graph=True)
|
|
@@ -94,7 +94,7 @@ class Generator:
|
|
| 94 |
one_hot[0, index] = 1
|
| 95 |
one_hot_vector = one_hot
|
| 96 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
| 97 |
-
one_hot = torch.sum(one_hot
|
| 98 |
|
| 99 |
self.model.zero_grad()
|
| 100 |
one_hot.backward(retain_graph=True)
|
|
@@ -136,7 +136,7 @@ class Generator:
|
|
| 136 |
one_hot[0, index] = 1
|
| 137 |
one_hot_vector = one_hot
|
| 138 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
| 139 |
-
one_hot = torch.sum(one_hot
|
| 140 |
|
| 141 |
self.model.zero_grad()
|
| 142 |
one_hot.backward(retain_graph=True)
|
|
|
|
| 37 |
one_hot[0, index] = 1
|
| 38 |
one_hot_vector = one_hot
|
| 39 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
| 40 |
+
one_hot = torch.sum(one_hot * output)
|
| 41 |
|
| 42 |
self.model.zero_grad()
|
| 43 |
one_hot.backward(retain_graph=True)
|
|
|
|
| 70 |
one_hot[0, index] = 1
|
| 71 |
one_hot_vector = one_hot
|
| 72 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
| 73 |
+
one_hot = torch.sum(one_hot * output)
|
| 74 |
|
| 75 |
self.model.zero_grad()
|
| 76 |
one_hot.backward(retain_graph=True)
|
|
|
|
| 94 |
one_hot[0, index] = 1
|
| 95 |
one_hot_vector = one_hot
|
| 96 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
| 97 |
+
one_hot = torch.sum(one_hot * output)
|
| 98 |
|
| 99 |
self.model.zero_grad()
|
| 100 |
one_hot.backward(retain_graph=True)
|
|
|
|
| 136 |
one_hot[0, index] = 1
|
| 137 |
one_hot_vector = one_hot
|
| 138 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
| 139 |
+
one_hot = torch.sum(one_hot * output)
|
| 140 |
|
| 141 |
self.model.zero_grad()
|
| 142 |
one_hot.backward(retain_graph=True)
|