|
|
""" Contains functions not directly linked to coreference resolution """ |
|
|
|
|
|
from typing import List, Set |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from stanza.models.coref.const import EPSILON |
|
|
|
|
|
|
|
|
class GraphNode:
    """ A node of an undirected graph, identified by an integer id.

    Edges are stored symmetrically in each endpoint's ``links`` set;
    ``visited`` is a bookkeeping flag for graph traversals.
    """

    def __init__(self, node_id: int):
        self.id = node_id
        # Neighboring nodes; kept in sync on both endpoints by link()
        self.links: Set[GraphNode] = set()
        # Traversal flag, managed by external graph-walking code
        self.visited = False

    def link(self, another: "GraphNode"):
        """ Adds an undirected edge between this node and another. """
        self.links |= {another}
        another.links |= {self}

    def __repr__(self) -> str:
        return str(self.id)
|
|
|
|
|
|
|
|
def add_dummy(tensor: torch.Tensor, eps: bool = False):
    """ Prepends zeros (or a very small value if eps is True)
    to the first (not zeroth) dimension of tensor.
    """
    # The dummy slab matches the input everywhere except dim 1, which is 1
    dummy_shape: List[int] = list(tensor.shape)
    dummy_shape[1] = 1
    if eps:
        dummy = torch.full(dummy_shape, EPSILON,
                           device=tensor.device, dtype=tensor.dtype)
    else:
        dummy = torch.zeros(dummy_shape,
                            device=tensor.device, dtype=tensor.dtype)
    return torch.cat((dummy, tensor), dim=1)
|
|
|
|
|
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs (Tensor): A float tensor of arbitrary shape.
                The predictions for each example.
        targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
                classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha (float): Weighting factor in range [0, 1] to balance
                positive vs negative examples or -1 for ignore. Default: ``0.25``.
        gamma (float): Exponent of the modulating factor (1 - p_t) to
                balance easy vs hard examples. Default: ``2``.
        reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
                ``'none'``: No reduction will be applied to the output.
                ``'mean'``: The output will be averaged.
                ``'sum'``: The output will be summed. Default: ``'none'``.
    Returns:
        Loss tensor with the reduction option applied.

    Raises:
        ValueError: If ``alpha`` is outside [0, 1] and not -1, or if
            ``reduction`` is not one of 'none', 'mean', 'sum'.
    """
    # Validate arguments up front so a bad call fails before any computation.
    if not (0 <= alpha <= 1) and alpha != -1:
        raise ValueError(f"Invalid alpha value: {alpha}. alpha must be in the range [0,1] or -1 for ignore.")
    if reduction not in ("none", "mean", "sum"):
        # Fixed: the original f-string was missing the closing quote after {reduction}
        raise ValueError(
            f"Invalid Value for arg 'reduction': '{reduction}'\n Supported reduction modes: 'none', 'mean', 'sum'"
        )

    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the model's estimated probability of the *true* class per element
    p_t = p * targets + (1 - p) * (1 - targets)
    # Modulating factor (1 - p_t)^gamma down-weights easy examples (p_t near 1)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        # Class-balancing weight: alpha for positives, (1 - alpha) for negatives
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    return loss
|
|
|