# Code modified from https://github.com/cvlab-stonybrook/DM-Count/blob/master/losses/bregman_pytorch.py
import torch
from torch import Tensor
from typing import Union, Tuple, Dict

M_EPS = 1e-16  # small epsilon to avoid division by zero / log(0)
def sinkhorn(
    a: Tensor,
    b: Tensor,
    C: Tensor,
    reg: float = 1e-1,
    maxIter: int = 1000,
    stopThr: float = 1e-9,
    verbose: bool = False,
    log: bool = True,
    eval_freq: int = 10,
    print_freq: int = 200,
) -> Union[Tensor, Tuple[Tensor, Dict[str, Tensor]]]:
| """ | |
| Solve the entropic regularization optimal transport | |
| The input should be PyTorch tensors | |
| The function solves the following optimization problem: | |
| .. math:: | |
| \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma) | |
| s.t. \gamma 1 = a | |
| \gamma^T 1= b | |
| \gamma\geq 0 | |
| where : | |
| - C is the (ns,nt) metric cost matrix | |
| - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` | |
| - a and b are target and source measures (sum to 1) | |
| The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1]. | |
| Parameters | |
| ---------- | |
| a : torch.tensor (na,) | |
| samples measure in the target domain | |
| b : torch.tensor (nb,) | |
| samples in the source domain | |
| C : torch.tensor (na,nb) | |
| loss matrix | |
| reg : float | |
| Regularization term > 0 | |
| maxIter : int, optional | |
| Max number of iterations | |
| stopThr : float, optional | |
| Stop threshol on error ( > 0 ) | |
| verbose : bool, optional | |
| Print information along iterations | |
| log : bool, optional | |
| record log if True | |
| Returns | |
| ------- | |
| gamma : (na x nb) torch.tensor | |
| Optimal transportation matrix for the given parameters | |
| log : dict | |
| log dictionary return only if log==True in parameters | |
| References | |
| ---------- | |
| [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 | |
| See Also | |
| -------- | |
| """ | |
    device = a.device
    na, nb = C.shape

    assert na >= 1 and nb >= 1, f"C needs to be a non-empty 2d matrix. Found C.shape = {C.shape}"
    assert na == a.shape[0] and nb == b.shape[0], f"Shape of a ({a.shape}) or b ({b.shape}) does not match that of C ({C.shape})"
    assert reg > 0, f"reg should be greater than 0. Found reg = {reg}"
    assert a.min() >= 0. and b.min() >= 0., f"Elements in a and b should be nonnegative. Found a.min() = {a.min()}, b.min() = {b.min()}"

    if log:
        log = {"err": []}

    # initialize the scaling vectors uniformly
    u = torch.ones(na, dtype=a.dtype, device=device) / na
    v = torch.ones(nb, dtype=b.dtype, device=device) / nb

    # Gibbs kernel K = exp(-C / reg), computed in place to save memory
    K = torch.empty(C.shape, dtype=C.dtype, device=device)
    torch.div(C, -reg, out=K)
    torch.exp(K, out=K)

    it = 1
    err = 1
    while err > stopThr and it <= maxIter:
        upre, vpre = u, v

        # Sinkhorn-Knopp updates: v = b / (K^T u), then u = a / (K v)
        KTu = torch.matmul(u.view(1, -1), K).view(-1)
        v = torch.div(b, KTu + M_EPS)
        Kv = torch.matmul(K, v.view(-1, 1)).view(-1)
        u = torch.div(a, Kv + M_EPS)

        if torch.any(torch.isnan(u)) or torch.any(torch.isnan(v)) or \
                torch.any(torch.isinf(u)) or torch.any(torch.isinf(v)):
            # numerical overflow: fall back to the last stable iterate and stop
            print("Warning: numerical errors at iteration", it)
            u, v = upre, vpre
            break

        if it % eval_freq == 0:
            # Checking the error only every eval_freq iterations speeds things
            # up. The line below is equivalent to
            #   b_hat = torch.sum(u.reshape(-1, 1) * K * v.reshape(1, -1), 0)
            # but is more memory efficient.
            b_hat = (torch.matmul(u.view(1, -1), K) * v.view(1, -1)).view(-1)
            err = (b - b_hat).pow(2).sum().item()
            if log:
                log["err"].append(err)

        if verbose and it % print_freq == 0:
            print("iteration {:5d}, constraint error {:5e}".format(it, err))

        it += 1
    if log:
        log["u"] = u
        log["v"] = v
        # dual potentials recovered from the scaling vectors
        log["alpha"] = reg * torch.log(u + M_EPS)
        log["beta"] = reg * torch.log(v + M_EPS)

    # transport plan P = diag(u) @ K @ diag(v)
    P = u.reshape(-1, 1) * K * v.reshape(1, -1)
    if log:
        return P, log
    else:
        return P
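

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original DM-Count file): builds a
# random cost matrix with uniform marginals, runs sinkhorn, and checks that
# the marginals of the returned plan approximately recover a and b. The
# sizes, seed, and reg value below are illustrative assumptions, not values
# taken from the source.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    na, nb = 8, 5
    a = torch.full((na,), 1.0 / na)  # uniform target measure, sums to 1
    b = torch.full((nb,), 1.0 / nb)  # uniform source measure, sums to 1
    C = torch.rand(na, nb)           # random nonnegative cost matrix

    P, info = sinkhorn(a, b, C, reg=1e-1)

    # row/column sums of the plan should match the prescribed marginals
    print("row-sum error:", (P.sum(dim=1) - a).abs().max().item())
    print("col-sum error:", (P.sum(dim=0) - b).abs().max().item())
    print("transport cost <P, C>:", (P * C).sum().item())
    print("final constraint error:", info["err"][-1])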