|
|
from torch import nn |
|
|
|
|
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = ['MultiLoss']
|
|
|
|
|
|
|
|
class MultiLoss(nn.Module):
    """Wrapper to compute the weighted sum of multiple criteria.

    The i-th criterion is applied to the i-th elements of the inputs
    passed to `forward` and its output is scaled by the i-th lambda.

    :param criteria: List(callable)
        List of criteria. They are stored in an `nn.ModuleList`, so
        each of them must be an `nn.Module`
    :param lambdas: List(float or Tensor)
        One multiplicative coefficient per criterion, in the same order
        as `criteria`
    """

    def __init__(self, criteria, lambdas):
        super().__init__()
        # Explicit exception rather than `assert`, which is stripped
        # when Python runs with `-O`
        if len(criteria) != len(lambdas):
            raise ValueError(
                f"Expected as many lambdas as criteria, got "
                f"{len(criteria)} criteria and {len(lambdas)} lambdas")
        self.criteria = nn.ModuleList(criteria)
        self.lambdas = lambdas

    def __len__(self):
        """Number of wrapped criteria."""
        return len(self.criteria)

    def to(self, *args, **kwargs):
        """Move the criteria, and the lambdas that support it, with
        the same arguments as `nn.Module.to`.

        :return: self, like `nn.Module.to`
        """
        # The criteria are registered submodules, so the parent
        # implementation moves them in place
        super().to(*args, **kwargs)

        # Lambdas may be plain python numbers, which have no `to`
        for i, lamb in enumerate(self.lambdas):
            if hasattr(lamb, 'to'):
                self.lambdas[i] = lamb.to(*args, **kwargs)

        return self

    def extra_repr(self) -> str:
        return f'lambdas={self.lambdas}'

    def forward(self, a, b, **kwargs):
        """Weighted sum of the criteria.

        :param a: iterable holding the first argument of each criterion
        :param b: iterable holding the second argument of each
            criterion
        :param kwargs: forwarded as-is to every criterion
        :return: sum over i of `lambdas[i] * criteria[i](a[i], b[i])`.
            Iteration stops at the shortest of `lambdas`, `criteria`,
            `a` and `b`
        """
        loss = 0
        for lamb, criterion, a_, b_ in zip(self.lambdas, self.criteria, a, b):
            loss = loss + lamb * criterion(a_, b_, **kwargs)
        return loss

    @property
    def weight(self):
        """MultiLoss supports `weight` if all its criteria support it.
        The weight of the first criterion is returned.
        """
        return self.criteria[0].weight

    @weight.setter
    def weight(self, weight):
        """MultiLoss supports `weight` if all its criteria support it.
        The same `weight` object is assigned to every criterion.
        """
        for criterion in self.criteria:
            criterion.weight = weight

    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        """Normal `state_dict` behavior, except for the shared criterion
        weights, which are not saved under `prefix.criteria.i.weight`
        but under `prefix.weight`.
        """
        destination = super().state_dict(
            *args, destination=destination, prefix=prefix, keep_vars=keep_vars)

        # Remove the per-criterion weights. A criterion without any
        # weight has no such entry, hence the default value
        for i in range(len(self)):
            destination.pop(f"{prefix}criteria.{i}.weight", None)

        # Only save the shared weight, when there is one to save
        shared_weight = getattr(self, 'weight', None) if len(self) > 0 else None
        if shared_weight is not None:
            destination[f"{prefix}weight"] = shared_weight

        return destination

    def load_state_dict(self, state_dict, strict=True):
        """Normal `load_state_dict` behavior, except for the shared
        criterion weights, which are not saved under `criteria.i.weight`
        but under `weight`. The older `criteria.i.weight` layout is
        still accepted, in which case `criteria.0.weight` is used.

        The weight found in `state_dict` (None if absent) is assigned
        to all criteria. The input `state_dict` is left untouched.

        :raises RuntimeError: if `strict` and, apart from the weight
            entries handled here, some keys are missing or unexpected
        """
        import copy

        # Shallow copy so that the caller's dictionary is not modified.
        # `copy.copy` also carries over the `_metadata` attribute that
        # torch attaches to the state dictionaries it produces
        state_dict = copy.copy(state_dict)

        # Recover the shared weight and drop every weight-related key:
        # none of them must reach the parent implementation
        criterion_keys = [f"criteria.{i}.weight" for i in range(len(self))]
        old_format = state_dict.get('criteria.0.weight')
        new_format = state_dict.pop('weight', None)
        weight = new_format if new_format is not None else old_format
        for key in criterion_keys:
            state_dict.pop(key, None)

        # Load everything else. Strictness is enforced below, because
        # the parent would otherwise report the per-criterion weights
        # as missing
        out = super().load_state_dict(state_dict, strict=False)
        handled = set(criterion_keys)
        out.missing_keys[:] = [k for k in out.missing_keys if k not in handled]

        if strict:
            errors = []
            if out.unexpected_keys:
                keys = ", ".join(f'"{k}"' for k in out.unexpected_keys)
                errors.append(f"Unexpected key(s) in state_dict: {keys}. ")
            if out.missing_keys:
                keys = ", ".join(f'"{k}"' for k in out.missing_keys)
                errors.append(f"Missing key(s) in state_dict: {keys}. ")
            if errors:
                raise RuntimeError(
                    f"Error(s) in loading state_dict for "
                    f"{type(self).__name__}:\n\t" + "\n\t".join(errors))

        # Set the shared weight on all criteria
        self.weight = weight

        return out
|
|
|