liyang-ict commited on
Commit
4b35303
·
verified ·
1 Parent(s): e0fa27d

Upload logic_consistency_loss.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. logic_consistency_loss.py +71 -0
logic_consistency_loss.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class LogicConsistencyLoss(nn.Module):
    """Logic consistency loss between sentence- and token-level predictions.

    Penalizes violations of the soft implication "sentence is positive ->
    some token is positive" using the product fuzzy-logic form
    ``1 - P(fc) + P(fc) * P(gc)``.

    Reference:
        "Leveraging Declarative Knowledge in Text and First-Order Logic for
        Fine-Grained Propaganda Detection".
        https://arxiv.org/abs/2004.14201

    Args:
        n_classes (int): Number of classes. The probability of the last
            class index (``[..., -1]`` after softmax) is used by the loss.
        reduce (str): Pooling applied over the token dimension:
            ``'max'``, ``'min'``, or ``'avg'``. Default: ``'max'``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of elements in the output,
            ``'sum'``: the output will be summed.
            Default: ``'mean'``.

    Raises:
        ValueError: If ``reduction`` or ``reduce`` is not one of the
            accepted values.
    """

    def __init__(self, n_classes, reduce="max", reduction="mean"):
        super(LogicConsistencyLoss, self).__init__()
        if reduction not in ["sum", "mean", "none"]:
            raise ValueError(
                f"reduction should be 'sum', 'mean', or 'none', but got {reduction}"
            )
        if reduce not in ["max", "min", "avg"]:
            raise ValueError(
                f"reduce should be 'max', 'avg' or 'min', but got {reduce}"
            )
        self.n_classes = n_classes
        self.reduce = reduce
        self.reduction = reduction

    def forward(self, fc, gc, gc_mask):
        """Calculate the logic consistency loss.

        Args:
            fc (float tensor of size [batch_num, n_classes]):
                Sentence level prediction with logits.
            gc (float tensor of size [batch_num, token_num, n_classes]):
                Token level prediction with logits.
            gc_mask (float tensor of size [batch_num, token_num]):
                Mask for token level prediction (presumably 1 for valid
                tokens, 0 for padding — confirm against caller).

        Returns:
            Scalar tensor when ``reduction`` is ``'mean'`` or ``'sum'``;
            tensor of size [batch_num] when ``reduction`` is ``'none'``.
        """
        batch_num, token_num, class_num = gc.size()
        assert (
            class_num == self.n_classes
        ), f"Class number mismatch: {class_num} vs {self.n_classes}"

        # Probability of the last ("positive") class at each level.
        fc = F.softmax(fc, dim=-1)[..., -1]  # [batch_num,]
        gc = F.softmax(gc, dim=-1)[..., -1]  # [batch_num, token_num]
        # Zero out masked tokens. NOTE(review): for reduce='min' the masked
        # zeros can be selected as the minimum — confirm this is intended.
        gc = torch.mul(gc, gc_mask)

        if self.reduce == "max":
            gc_pooled, _ = torch.max(gc, dim=-1)  # [batch_num,]
        elif self.reduce == "min":
            gc_pooled, _ = torch.min(gc, dim=-1)
        else:  # "avg" (validated in __init__)
            gc_pooled = torch.mean(gc, dim=-1)  # [batch_num,]

        # Soft truth value of the implication fc -> gc_pooled:
        # 1 - P(fc) + P(fc) * P(gc).
        pf = 1 - fc + torch.mul(fc, gc_pooled)
        pf = torch.clamp(pf, min=1e-8)  # Clamp to prevent log(0)

        out = torch.neg(torch.log(pf))

        if self.reduction == "mean":
            return torch.mean(out)
        if self.reduction == "sum":
            return torch.sum(out)
        # Bug fix: with reduction='none' the original fell through and raised
        # UnboundLocalError; return the per-sample losses instead.
        return out