PeteBleackley commited on
Commit
4d8fdac
·
1 Parent(s): 798488e

Wrap CrossEntropyLoss in callable that makes it appplicable to sequences

Browse files
Files changed (1) hide show
  1. scripts.py +12 -3
scripts.py CHANGED
@@ -22,14 +22,23 @@ import scipy.spatial
22
  import seaborn
23
  import tqdm
24
 
25
- EPSILON = torch.tensor(1.0e-12)
 
 
 
 
 
 
 
 
 
26
 
27
  class CombinedLoss(torch.nn.Module):
28
  def __init__(self):
29
  super(CombinedLoss,self).__init__()
30
- self.component_losses = (torch.nn.CrossEntropyLoss(),
31
  torch.nn.MSELoss(),
32
- torch.nn.CrossEntropyLoss(),
33
  torch.nn.MSELoss())
34
 
35
  def forward(self,y_pred,y_true):
 
22
  import seaborn
23
  import tqdm
24
 
25
+ class SequenceCrossEntropyLoss(torch.nn.Module):
26
+ def __init__(self):
27
+ super(SequenceCrossEntropyLoss,self).__init__()
28
+ self.crossentropy = torch.nn.CrossEntropyLoss()
29
+
30
+ def forward(self,y_pred,y_true):
31
+ (batch_size,sequence_length,n_classes) = y_pred.shape
32
+ predictions = y_pred.view(-1,n_classes)
33
+ labels = y_true.view(-1)
34
+ return self.crossentropy(predictions,labels)
35
 
36
  class CombinedLoss(torch.nn.Module):
37
  def __init__(self):
38
  super(CombinedLoss,self).__init__()
39
+ self.component_losses = (SequenceCrossEntropyLoss(),
40
  torch.nn.MSELoss(),
41
+ SequenceCrossEntropyLoss(),
42
  torch.nn.MSELoss())
43
 
44
  def forward(self,y_pred,y_true):