import torch
from torch.amp import autocast
from torch import Tensor
from typing import Any, Union, Tuple, Dict

# Small constant added to denominators to guard against division by zero.
M_EPS = 1e-16


@torch.no_grad()
# Keep the whole computation in full float32 even when the caller is inside a
# mixed-precision region: exp(-C / reg) underflows quickly in float16.
@autocast(device_type="cuda", enabled=False)
def sinkhorn(
    a: Tensor,
    b: Tensor,
    C: Tensor,
    reg: float = 1e-1,
    maxIter: int = 1000,
    stopThr: float = 1e-9,
    verbose: bool = False,
    log: bool = True,
    eval_freq: int = 10,
    print_freq: int = 200,
) -> Union[Tensor, Tuple[Tensor, Dict[str, Any]]]:
    """Entropy-regularized optimal transport via Sinkhorn-Knopp iterations.

    Computes the transport plan P coupling the histograms a (shape (na,))
    and b (shape (nb,)) under the cost matrix C (shape (na, nb)), minimizing
    the transport cost <P, C> plus an entropic penalty scaled by reg. If log
    is True, also returns a dict holding the scaling vectors u and v, the
    dual potentials alpha and beta, and the history of constraint errors.
    """
    device = a.device
    na, nb = C.shape
    assert na == a.shape[0] and nb == b.shape[0], (
        f"Shapes of a ({a.shape}) or b ({b.shape}) do not match that of C ({C.shape})"
    )
    assert reg > 0, f"reg should be greater than 0. Found reg = {reg}"
    assert a.min() >= 0. and b.min() >= 0., (
        f"Elements in a and b should be nonnegative. Found a.min() = {a.min()}, b.min() = {b.min()}"
    )

    log_dict: Dict[str, Any] = {"err": []}

    # Initialize the scaling vectors uniformly and form the Gibbs kernel.
    u = torch.ones(na, dtype=a.dtype, device=device) / na
    v = torch.ones(nb, dtype=b.dtype, device=device) / nb
    K = torch.exp(-C / reg)

    it, err = 1, 1
    while err > stopThr and it <= maxIter:
        u_pre, v_pre = u.clone(), v.clone()
        # Sinkhorn updates: alternately rescale v and u so that the
        # corresponding marginals of diag(u) K diag(v) match b and a.
        KTu = torch.matmul(K.T, u)
        v = b / (KTu + M_EPS)
        Kv = torch.matmul(K, v)
        u = a / (Kv + M_EPS)

        # Roll back to the last stable scalings and stop on numerical blow-up.
        if torch.any(torch.isnan(u)) or torch.any(torch.isnan(v)) \
                or torch.any(torch.isinf(u)) or torch.any(torch.isinf(v)):
            print(f"Warning: numerical errors at iteration {it}")
            u, v = u_pre, v_pre
            break

        # Evaluate the b-marginal violation every eval_freq iterations. This
        # runs regardless of `log` so that the stopping criterion still fires.
        if it % eval_freq == 0:
            b_hat = torch.matmul(u, K) * v
            err = (b - b_hat).pow(2).sum().item()
            log_dict["err"].append(err)

        if verbose and it % print_freq == 0:
            print(f"Iteration {it}, constraint error {err}")

        it += 1

    if log:
        log_dict["u"] = u
        log_dict["v"] = v
        # Dual potentials recovered from the scalings: u = exp(alpha / reg).
        log_dict["alpha"] = reg * torch.log(u + M_EPS)
        log_dict["beta"] = reg * torch.log(v + M_EPS)

    # Transport plan P = diag(u) K diag(v).
    P = u.view(-1, 1) * K * v.view(1, -1)
    if log:
        return P, log_dict
    return P
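

# A minimal usage sketch (illustrative only; the demo data and sizes below are
# assumptions, not part of the module): couple two uniform histograms under a
# squared-Euclidean cost on random 1-D samples.
if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    na, nb = 50, 60
    x = torch.randn(na, 1, device=device)
    y = torch.randn(nb, 1, device=device)
    C = (x - y.T).pow(2)  # (na, nb) pairwise squared-Euclidean costs
    a = torch.full((na,), 1.0 / na, device=device)  # uniform source marginal
    b = torch.full((nb,), 1.0 / nb, device=device)  # uniform target marginal

    P, info = sinkhorn(a, b, C, reg=1e-1)
    print("transport cost <P, C>:", (P * C).sum().item())
    print("max row-marginal violation:", (P.sum(dim=1) - a).abs().max().item())
    print("final constraint error:", info["err"][-1] if info["err"] else None)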