Commit
·
32335bd
1
Parent(s):
7f0e78a
Fix error prob bug
Browse files- gec_model.py +6 -3
gec_model.py
CHANGED
|
@@ -78,6 +78,8 @@ class GecBERTModel(torch.nn.Module):
|
|
| 78 |
List of punctuations.
|
| 79 |
"""
|
| 80 |
super().__init__()
|
|
|
|
|
|
|
| 81 |
self.model_weights = list(map(float, weights)) if weights else [1] * len(model_paths)
|
| 82 |
self.device = (
|
| 83 |
torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
|
|
@@ -106,8 +108,6 @@ class GecBERTModel(torch.nn.Module):
|
|
| 106 |
|
| 107 |
self.indexers = []
|
| 108 |
self.models = []
|
| 109 |
-
if isinstance(model_paths, str):
|
| 110 |
-
model_paths = [model_paths]
|
| 111 |
for model_path in model_paths:
|
| 112 |
model = Seq2LabelsModel.from_pretrained(model_path)
|
| 113 |
config = model.config
|
|
@@ -337,7 +337,10 @@ class GecBERTModel(torch.nn.Module):
|
|
| 337 |
for output, weight in zip(data, self.model_weights):
|
| 338 |
class_probabilities_labels = torch.softmax(output['logits'], dim=-1)
|
| 339 |
all_class_probs += weight * class_probabilities_labels / sum(self.model_weights)
|
| 340 |
-
|
|
|
|
|
|
|
|
|
|
| 341 |
|
| 342 |
max_vals = torch.max(all_class_probs, dim=-1)
|
| 343 |
probs = max_vals[0].tolist()
|
|
|
|
| 78 |
List of punctuations.
|
| 79 |
"""
|
| 80 |
super().__init__()
|
| 81 |
+
if isinstance(model_paths, str):
|
| 82 |
+
model_paths = [model_paths]
|
| 83 |
self.model_weights = list(map(float, weights)) if weights else [1] * len(model_paths)
|
| 84 |
self.device = (
|
| 85 |
torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
|
|
|
|
| 108 |
|
| 109 |
self.indexers = []
|
| 110 |
self.models = []
|
|
|
|
|
|
|
| 111 |
for model_path in model_paths:
|
| 112 |
model = Seq2LabelsModel.from_pretrained(model_path)
|
| 113 |
config = model.config
|
|
|
|
| 337 |
for output, weight in zip(data, self.model_weights):
|
| 338 |
class_probabilities_labels = torch.softmax(output['logits'], dim=-1)
|
| 339 |
all_class_probs += weight * class_probabilities_labels / sum(self.model_weights)
|
| 340 |
+
class_probabilities_d = torch.softmax(output['detect_logits'], dim=-1)
|
| 341 |
+
error_probs_d = class_probabilities_d[:, :, self.incorr_index]
|
| 342 |
+
incorr_prob = torch.max(error_probs_d, dim=-1)[0]
|
| 343 |
+
error_probs += weight * incorr_prob / sum(self.model_weights)
|
| 344 |
|
| 345 |
max_vals = torch.max(all_class_probs, dim=-1)
|
| 346 |
probs = max_vals[0].tolist()
|