| |
| |
| |
|
|
| import pdb |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from nets.sampler import * |
| from nets.repeatability_loss import * |
| from nets.reliability_loss import * |
|
|
|
|
class MultiLoss(nn.Module):
    """Combine several loss modules into a single weighted sum.

    *args: [loss weight (float), loss creator, ... ]
        Flat, alternating sequence of weights and ``nn.Module`` losses.

    Example:
        loss = MultiLoss( 1, MyFirstLoss(), 0.5, MySecondLoss() )
    """

    def __init__(self, *args, dbg=()):
        """Register the (weight, loss) pairs given in ``args``.

        Args:
            *args: alternating ``weight, loss, weight, loss, ...``; each
                weight is converted with ``float`` and each loss must be an
                ``nn.Module``.
            dbg: accepted for interface compatibility; not used by this class.

        Raises:
            AssertionError: if ``args`` has odd length or a loss is not an
                ``nn.Module``.
        """
        super().__init__()
        assert len(args) % 2 == 0, "args must be a list of (float, loss)"
        # Weights are kept as plain Python floats; the losses live in a
        # ModuleList so that they are registered as sub-modules.
        self.weights = []
        self.losses = nn.ModuleList()
        for raw_weight, loss in zip(args[0::2], args[1::2]):
            weight = float(raw_weight)
            # Wrap in a 1-tuple so a tuple-valued `loss` cannot break the
            # %-formatting of the error message.
            assert isinstance(loss, nn.Module), "%s is not a loss!" % (loss,)
            self.weights.append(weight)
            self.losses.append(loss)

    def forward(self, select=None, **variables):
        """Evaluate the (selected) losses and return their weighted sum.

        Args:
            select: optional collection of 1-based loss indices to evaluate;
                ``None`` evaluates every loss.
            **variables: keyword arguments forwarded unchanged to every loss.

        Returns:
            ``(cum_loss, details)`` where ``cum_loss`` is the weighted sum of
            the evaluated losses and ``details`` maps ``"loss_<key>"`` to the
            float value of each reported term, plus ``"loss"`` for the total.

        Each loss may return either a bare value or a ``(value, dict)`` pair.
        A bare value is reported under the loss's ``name`` attribute, or under
        its class name when it has no ``name``.
        """
        assert not select or all(1 <= n <= len(self.losses) for n in select)
        details = {}
        cum_loss = 0
        for num, (weight, loss_func) in enumerate(zip(self.weights, self.losses), 1):
            if select is not None and num not in select:
                continue
            result = loss_func(**variables)
            if isinstance(result, tuple):
                assert len(result) == 2 and isinstance(result[1], dict)
                value, parts = result
            else:
                # Fall back to the class name so that losses without a
                # `name` attribute do not raise AttributeError here.
                label = getattr(loss_func, "name", type(loss_func).__name__)
                value, parts = result, {label: result}
            cum_loss = cum_loss + weight * value
            for key, val in parts.items():
                details["loss_" + key] = float(val)
        details["loss"] = float(cum_loss)
        return cum_loss, details
|
|