Respair committed on
Commit
79f54fe
·
verified ·
1 Parent(s): 1fb7ae9

Create loss.py

Browse files
Files changed (1) hide show
  1. loss.py +42 -0
loss.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
import torch
import torch.nn.functional as F
from torch.nn import Module


class ForwardSumLoss(Module):
    """CTC-based "forward sum" alignment loss for attention matrices.

    Treats each key (text-token) index as a CTC target symbol and the
    attention scores as per-query-frame emissions, so minimizing the CTC
    cost encourages monotonic query-to-key alignments.
    """

    def __init__(self, blank_logprob=-1, loss_scale=1.0):
        """
        Args:
            blank_logprob: raw score assigned to the inserted blank column
                (key index 0) before log-softmax normalization.
            loss_scale: multiplier applied to the final CTC cost.
        """
        super().__init__()
        self.log_softmax = torch.nn.LogSoftmax(dim=-1)
        # forward() prepends the blank column at key index 0 and shifts the
        # targets to 1..key_len, so CTC's blank label must be 0. The previous
        # hard-coded blank=16 labeled a real token index as blank, corrupting
        # the loss for any sequence with 16 or more keys.
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=True)
        self.blank_logprob = blank_logprob
        self.loss_scale = loss_scale

    def forward(self, attn_logprob, in_lens, out_lens):
        """Compute the forward-sum (CTC) alignment loss.

        Args:
            attn_logprob: attention scores, assumed shape
                (batch, 1, max_query_len, max_key_len) — the singleton
                channel dim is squeezed; confirm against the caller.
            in_lens: (batch,) lengths of the key (text) sequences.
            out_lens: (batch,) lengths of the query (frame) sequences.

        Returns:
            Scalar tensor: mean CTC cost over the batch, times loss_scale.
        """
        key_lens = in_lens
        query_lens = out_lens
        max_key_len = attn_logprob.size(-1)

        # Reorder input to (query_len, batch, key_len) as CTCLoss expects.
        attn_logprob = attn_logprob.squeeze(1)
        attn_logprob = attn_logprob.permute(1, 0, 2)

        # Prepend the blank-label column at key index 0.
        attn_logprob = F.pad(input=attn_logprob, pad=(1, 0, 0, 0, 0, 0),
                             value=self.blank_logprob)

        # Mask out scores beyond each sample's key length, then convert to
        # log-probabilities over the (blank + keys) axis.
        key_inds = torch.arange(max_key_len + 1,
                                device=attn_logprob.device, dtype=torch.long)
        attn_logprob.masked_fill_(
            key_inds.view(1, 1, -1) > key_lens.view(1, -1, 1),  # key_inds >= key_lens + 1
            -1e15)
        attn_logprob = self.log_softmax(attn_logprob)

        # Targets for each sample are 1..max_key_len (index 0 is the blank);
        # target_lengths=key_lens restricts each row to its true length.
        target_seqs = key_inds[1:].unsqueeze(0)
        target_seqs = target_seqs.repeat(key_lens.numel(), 1)

        # CTC cost = -log of the summed probability over monotonic paths.
        cost = self.ctc_loss(attn_logprob, target_seqs,
                             input_lengths=query_lens, target_lengths=key_lens)
        cost *= self.loss_scale

        return cost