Spaces:
Runtime error
Runtime error
| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| from typing import Any, Dict | |
| import numpy as np | |
| import torch as th | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
def impaint_batch(value: th.Tensor, dst_ij: th.Tensor, src_ij: th.Tensor) -> th.Tensor:
    """Return a copy of ``value`` whose ``dst_ij`` pixels are overwritten from ``src_ij``.

    For every index ``k``, the element at ``(dst_ij[k, 0], dst_ij[k, 1])`` along
    dims 2 and 3 is replaced by the element at ``(src_ij[k, 0], src_ij[k, 1])``.
    This is done for all entries of dims 0 and 1 (presumably batch and channel;
    confirm against callers).

    Args:
        value: 4D tensor. Dims 2 and 3 are addressed by the coordinate pairs.
        dst_ij: index tensor. Column 0 indexes dim 2 and column 1 indexes dim 3
            of ``value``. These are the positions that get written.
        src_ij: index tensor with the same column layout as ``dst_ij``. These
            are the positions that get read. Presumably it has the same number
            of rows as ``dst_ij``; verify.

    Returns:
        A new tensor with the same shape as ``value``. ``value`` itself is not
        modified.

    Raises:
        AssertionError: if ``value`` is not 4D.
    """
    assert len(value.shape) == 4, "expecting a 4D tensor"
    # clone(), not value[:]: basic slicing returns a view, so writing into it
    # mutated the caller's tensor. It also raised under autograd when value is
    # a leaf tensor that requires grad.
    preds = value.clone()
    # The right-hand side reads from the untouched input, so overlapping
    # src/dst positions see the original values.
    preds[:, :, dst_ij[:, 0], dst_ij[:, 1]] = value[:, :, src_ij[:, 0], src_ij[:, 1]]
    return preds
def resample_tex(tex: th.Tensor, uvs: th.Tensor, weights: th.Tensor) -> th.Tensor:
    """Blend ``tex`` with a copy of itself resampled at ``uvs``.

    Args:
        tex: 4D input to ``F.grid_sample``. Dim 0 is treated as the batch.
        uvs: 3D tensor of sample coordinates. A batch axis is added and
            broadcast to ``tex.shape[0]``, so it is shared by every batch
            entry. Presumably the values are in [0, 1]; they are mapped
            affinely onto grid_sample's [-1, 1] range.
        weights: blend factor broadcast against ``tex``. 0 keeps ``tex`` and
            1 takes the resampled value.

    Returns:
        ``(1 - weights) * tex + weights * resampled``.
    """
    batch_size = tex.shape[0]
    # Share one coordinate grid across the whole batch without copying.
    batched_uvs = uvs.unsqueeze(0).expand(batch_size, -1, -1, -1)
    # Affine map [0, 1] -> [-1, 1], the normalized range grid_sample expects.
    sample_grid = 2.0 * (batched_uvs - 0.5)
    resampled = F.grid_sample(
        tex,
        sample_grid,
        align_corners=False,
        padding_mode="border",
    )
    keep_fraction = 1.0 - weights
    return keep_fraction * tex + weights * resampled
class SeamSampler(nn.Module):
    """Module wrapping ``impaint_batch`` and ``resample_tex`` with stored data.

    ``seamless_data`` must provide the keys ``"dst_ij"``, ``"src_ij"``,
    ``"uvs"`` and ``"weights"``. Each value is registered as a module buffer,
    so it moves with ``.to()`` and appears in ``state_dict``, but it is not a
    trainable parameter.
    """

    def __init__(self, seamless_data: Dict[str, Any]) -> None:
        super().__init__()
        # Same registration order as the individual calls it replaces.
        for buffer_name in ("dst_ij", "src_ij", "uvs", "weights"):
            self.register_buffer(buffer_name, seamless_data[buffer_name])

    def impaint(self, value: th.Tensor) -> th.Tensor:
        """Copy ``src_ij`` pixels into ``dst_ij`` positions via ``impaint_batch``."""
        return impaint_batch(value, self.dst_ij, self.src_ij)

    def resample(self, tex: th.Tensor) -> th.Tensor:
        """Blend ``tex`` with its resampling at ``self.uvs`` using ``self.weights``."""
        return resample_tex(tex, self.uvs, self.weights)

    def resample_border_only(self, tex: th.Tensor) -> th.Tensor:
        """Perform the same computation as ``resample``.

        NOTE(review): the name suggests a border restriction, but nothing here
        restricts it. Any such restriction can only come from the contents of
        ``self.weights``; confirm the intended semantics.
        """
        return resample_tex(tex, self.uvs, self.weights)

    def forward(self, tex: th.Tensor) -> th.Tensor:
        """Apply ``impaint`` and then ``resample`` to ``tex``."""
        return self.resample(self.impaint(tex))