Spaces:
Build error
Build error
PeteBleackley
commited on
Commit
·
4d8fdac
1
Parent(s):
798488e
Wrap CrossEntropyLoss in callable that makes it appplicable to sequences
Browse files- scripts.py +12 -3
scripts.py
CHANGED
|
@@ -22,14 +22,23 @@ import scipy.spatial
|
|
| 22 |
import seaborn
|
| 23 |
import tqdm
|
| 24 |
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
class CombinedLoss(torch.nn.Module):
|
| 28 |
def __init__(self):
|
| 29 |
super(CombinedLoss,self).__init__()
|
| 30 |
-
self.component_losses = (
|
| 31 |
torch.nn.MSELoss(),
|
| 32 |
-
|
| 33 |
torch.nn.MSELoss())
|
| 34 |
|
| 35 |
def forward(self,y_pred,y_true):
|
|
|
|
| 22 |
import seaborn
|
| 23 |
import tqdm
|
| 24 |
|
| 25 |
+
class SequenceCrossEntropyLoss(torch.nn.Module):
|
| 26 |
+
def __init__(self):
|
| 27 |
+
super(SequenceCrossEntropyLoss,self).__init__()
|
| 28 |
+
self.crossentropy = torch.nn.CrossEntropyLoss()
|
| 29 |
+
|
| 30 |
+
def forward(self,y_pred,y_true):
|
| 31 |
+
(batch_size,sequence_length,n_classes) = y_pred.shape
|
| 32 |
+
predictions = y_pred.view(-1,n_classes)
|
| 33 |
+
labels = y_true.view(-1)
|
| 34 |
+
return self.crossentropy(predictions,labels)
|
| 35 |
|
| 36 |
class CombinedLoss(torch.nn.Module):
|
| 37 |
def __init__(self):
|
| 38 |
super(CombinedLoss,self).__init__()
|
| 39 |
+
self.component_losses = (SequenceCrossEntropyLoss(),
|
| 40 |
torch.nn.MSELoss(),
|
| 41 |
+
SequenceCrossEntropyLoss(),
|
| 42 |
torch.nn.MSELoss())
|
| 43 |
|
| 44 |
def forward(self,y_pred,y_true):
|