Update modeling_wavlm_spkreg.py
Browse files — modeling_wavlm_spkreg.py (+18 −4)
modeling_wavlm_spkreg.py
CHANGED
|
@@ -456,7 +456,7 @@ class AAMSoftmaxLoss(nn.Module):
|
|
| 456 |
def __init__(
|
| 457 |
self,
|
| 458 |
scale: float = 30.0,
|
| 459 |
-
margin: float = 0.
|
| 460 |
easy_margin: bool = False,
|
| 461 |
label_smoothing: float = 0.0,
|
| 462 |
reduction: str = "mean"
|
|
@@ -489,9 +489,23 @@ class AAMSoftmaxLoss(nn.Module):
|
|
| 489 |
"""
|
| 490 |
_, num_labels = inputs.shape
|
| 491 |
# `inputs` are the outputs from AngularLinear()
|
| 492 |
-
|
| 493 |
-
theta = torch.acos(cos_theta)
|
| 494 |
-
psi = torch.cos(theta + self.margin)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
one_hot = nn.functional.one_hot(targets, num_labels)
|
| 496 |
outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
|
| 497 |
loss = F.cross_entropy(
|
|
|
|
| 456 |
def __init__(
|
| 457 |
self,
|
| 458 |
scale: float = 30.0,
|
| 459 |
+
margin: float = 0.2,
|
| 460 |
easy_margin: bool = False,
|
| 461 |
label_smoothing: float = 0.0,
|
| 462 |
reduction: str = "mean"
|
|
|
|
| 489 |
"""
|
| 490 |
_, num_labels = inputs.shape
|
| 491 |
# `inputs` are the outputs from AngularLinear()
|
| 492 |
+
epsilon = 1e-6
|
| 493 |
+
# theta = torch.acos(cos_theta)
|
| 494 |
+
# psi = torch.cos(theta + self.margin)
|
| 495 |
+
cos_theta = torch.clamp(inputs, -1.0 + epsilon, 1.0 - epsilon)
|
| 496 |
+
sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
|
| 497 |
+
sin_theta = torch.clamp(sin_theta, 0.0 + epsilon, 1.0 - epsilon)
|
| 498 |
+
|
| 499 |
+
cos_m = math.cos(self.margin)
|
| 500 |
+
sin_m = math.sin(self.margin)
|
| 501 |
+
psi = cos_theta * cos_m - sin_theta * sin_m # cos(theta + m)
|
| 502 |
+
|
| 503 |
+
if self.easy_margin:
|
| 504 |
+
psi = torch.where(cos_theta > 0, psi, cos_theta)
|
| 505 |
+
else:
|
| 506 |
+
# Make the function cos(theta+m) monotonic decreasing while theta in [0°, 180°]
|
| 507 |
+
psi = torch.where((cos_theta - math.cos(math.pi - self.margin)) > 0, psi, cos_theta - self.margin)
|
| 508 |
+
|
| 509 |
one_hot = nn.functional.one_hot(targets, num_labels)
|
| 510 |
outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
|
| 511 |
loss = F.cross_entropy(
|