English
Shanci's picture
Upload folder using huggingface_hub
26225c5 verified
from torch import nn
# Public API of this module: `MultiLoss` is the only name exported by a
# star-import.
__all__ = ['MultiLoss']
class MultiLoss(nn.Module):
    """Wrapper to compute the weighted sum of multiple criteria.

    The i-th criterion is applied to the i-th element of the inputs
    passed to `forward`, scaled by the i-th lambda, and all terms are
    summed.

    The criteria may share a `weight` attribute (see the `weight`
    property). When they do, it is serialized once, under
    `prefix.weight`, rather than once per criterion.

    :param criteria: List(nn.Module)
        List of criteria. They are stored in an `nn.ModuleList`
    :param lambdas: List(float or Tensor)
        One multiplicative coefficient per criterion, in the same order
        as `criteria`
    """

    def __init__(self, criteria, lambdas):
        super().__init__()
        # Explicit check rather than `assert`: asserts are stripped
        # under `python -O`, after which `forward` would silently
        # truncate to the shortest sequence through `zip`
        if len(criteria) != len(lambdas):
            raise ValueError(
                f"Expected as many lambdas as criteria, got "
                f"{len(criteria)} criteria and {len(lambdas)} lambdas")
        self.criteria = nn.ModuleList(criteria)
        self.lambdas = lambdas

    def __len__(self):
        """Number of wrapped criteria."""
        return len(self.criteria)

    def to(self, *args, **kwargs):
        """Move the criteria, and the lambdas that support it, using
        the same arguments as `nn.Module.to`.

        Lambdas exposing a callable `to` (e.g. tensors) are converted,
        plain Python numbers are kept as is. The lambdas are stored in
        a new list, so the sequence passed to the constructor is not
        modified and may be a tuple.

        :return: self, like `nn.Module.to`
        """
        for i in range(len(self.criteria)):
            self.criteria[i] = self.criteria[i].to(*args, **kwargs)

        moved = []
        for lamb in self.lambdas:
            mover = getattr(lamb, 'to', None)
            moved.append(mover(*args, **kwargs) if callable(mover) else lamb)
        self.lambdas = moved

        return self

    def extra_repr(self) -> str:
        """Show the lambdas in the module's printed representation."""
        return f'lambdas={self.lambdas}'

    def forward(self, a, b, **kwargs):
        """Compute the weighted sum of the criteria.

        :param a: iterable holding one first argument per criterion
        :param b: iterable holding one second argument per criterion
        :param kwargs: forwarded unchanged to every criterion
        :return: sum over i of `lambdas[i] * criteria[i](a[i], b[i])`.
            This is the integer 0 if there is nothing to iterate on
        """
        loss = 0
        for lamb, criterion, a_, b_ in zip(self.lambdas, self.criteria, a, b):
            loss = loss + lamb * criterion(a_, b_, **kwargs)
        return loss

    @property
    def weight(self):
        """MultiLoss supports `weight` if all its criteria support it.

        The value is read from the first criterion only.
        """
        return self.criteria[0].weight

    @weight.setter
    def weight(self, weight):
        """MultiLoss supports `weight` if all its criteria support it.

        The same object is assigned to every criterion.
        """
        for i in range(len(self)):
            self.criteria[i].weight = weight

    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        """Normal `state_dict` behavior, except for the shared criterion
        weights, which are not saved under `prefix.criteria.i.weight`
        but under `prefix.weight`.

        Criteria without a `weight` entry are tolerated. If the first
        criterion has no `weight`, or if it is None, no `prefix.weight`
        entry is written.
        """
        destination = super().state_dict(
            *args, destination=destination, prefix=prefix, keep_vars=keep_vars)

        # Remove the 'weight' from the criteria. The default avoids a
        # KeyError for criteria which do not store any weight
        for i in range(len(self)):
            destination.pop(f"{prefix}criteria.{i}.weight", None)

        # Only save the global shared weight, if there is one
        weight = None
        if len(self) > 0:
            weight = getattr(self.criteria[0], 'weight', None)
        if weight is not None:
            destination[f"{prefix}weight"] = weight

        return destination

    def load_state_dict(self, state_dict, strict=True):
        """Normal `load_state_dict` behavior, except for the shared
        criterion weights, which are not saved under `criteria.i.weight`
        but under `prefix.weight`.

        Both the new format (a single `weight` key) and the old format
        (`criteria.i.weight` keys) are accepted, the new one taking
        precedence. The input `state_dict` is left unmodified.

        :param state_dict: mapping of names to values to load
        :param strict: whether to raise if keys other than the shared
            weight keys are missing or unexpected
        :return: the missing and unexpected keys, as returned by
            `nn.Module.load_state_dict`, without the `criteria.i.weight`
            keys covered by the shared weight
        :raises RuntimeError: if `strict` and some keys are missing or
            unexpected
        """
        import copy

        per_criterion_keys = {
            f"criteria.{i}.weight" for i in range(len(self))}

        # Get the weight from the state_dict
        old_format = state_dict.get('criteria.0.weight')
        new_format = state_dict.get('weight')
        weight = new_format if new_format is not None else old_format

        # Work on a shallow copy so the caller's mapping keeps all its
        # keys. `copy.copy` also carries over instance attributes such
        # as the `_metadata` attached by `state_dict`
        filtered = copy.copy(state_dict)
        for k in per_criterion_keys | {'weight'}:
            filtered.pop(k, None)

        # Non-strict load, so that the weight keys removed above can be
        # filtered out of the report before strictness is enforced
        out = super().load_state_dict(filtered, strict=False)

        # The per-criterion weight keys are only forgiven if a shared
        # weight was found to replace them
        missing = [
            k for k in out.missing_keys
            if weight is None or k not in per_criterion_keys]
        unexpected = list(out.unexpected_keys)
        out.missing_keys[:] = missing

        if strict and (missing or unexpected):
            errors = []
            if unexpected:
                errors.append(
                    'Unexpected key(s) in state_dict: {}. '.format(
                        ', '.join(f'"{k}"' for k in unexpected)))
            if missing:
                errors.append(
                    'Missing key(s) in state_dict: {}. '.format(
                        ', '.join(f'"{k}"' for k in missing)))
            raise RuntimeError(
                'Error(s) in loading state_dict for {}:\n\t{}'.format(
                    self.__class__.__name__, '\n\t'.join(errors)))

        # Set the weight
        self.weight = weight

        return out