File size: 343 Bytes
4c1e73e
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
import torch.nn as nn
import torch.nn.functional as F

class KLDivLoss(nn.Module):
    """Kullback-Leibler divergence loss as a thin module wrapper over ``F.kl_div``.

    The module stores the ``F.kl_div`` options at construction time so the
    loss can be used like any other ``nn.Module`` criterion.

    Note:
        ``F.kl_div`` expects ``inputs`` to already be log-probabilities
        (e.g. the output of ``log_softmax``). Passing plain probabilities or
        raw logits does not raise, but yields a meaningless value.

    Args:
        reduction: Reduction mode forwarded to ``F.kl_div``. Defaults to
            ``'batchmean'``; any value ``F.kl_div`` accepts may be used
            (PyTorch documents ``'none'``, ``'batchmean'``, ``'sum'`` and
            ``'mean'``). The value is not validated here; an unsupported
            value is reported by ``F.kl_div`` when ``forward`` runs.
        log_target: Keyword-only. Set to ``True`` when ``targets`` are given
            as log-probabilities, which avoids exponentiating them first.
            Defaults to ``False`` (targets are probabilities), matching the
            ``F.kl_div`` default.
    """

    def __init__(self, reduction: str = 'batchmean', *, log_target: bool = False) -> None:
        super().__init__()
        self.reduction = reduction
        self.log_target = log_target

    def forward(self, inputs, targets):
        """Return the KL divergence between ``targets`` and ``inputs``.

        Args:
            inputs: Predicted distribution, in log-space.
            targets: Reference distribution; probabilities by default, or
                log-probabilities when the module was built with
                ``log_target=True``.

        Returns:
            The result of ``F.kl_div`` under the configured reduction.
        """
        return F.kl_div(
            inputs,
            targets,
            reduction=self.reduction,
            log_target=self.log_target,
        )