Spaces:
Sleeping
Sleeping
| """ | |
| Full ITDF model: ViT encoder → SIREN implicit decoder. | |
| Input: makeup image (B, 3, H, W) | |
| Output: restored bare face (B, 3, H, W) | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from models.encoder import ViTEncoder | |
| from models.decoder import SIRENDecoder, make_coord_grid | |
class ITDF(nn.Module):
    """Full ITDF model: a ViT encoder feeding a SIREN implicit decoder.

    The input image is padded so its spatial size is a multiple of the
    encoder patch size, encoded to a feature map, decoded at one coordinate
    per padded pixel, and finally cropped back to the original size.
    """

    def __init__(self, cfg: dict):
        """Build the encoder and decoder from a config dictionary.

        Args:
            cfg: Must contain a ``"model"`` section with the keys
                ``encoder``, ``encoder_out_dim``, ``siren_hidden``,
                ``siren_layers``, ``siren_omega0``, ``fourier_bands`` and
                ``decoder_out_channels``, plus a ``"data"`` section.
                Optional ``"model"`` keys: ``encoder_img_size`` (falls back
                to ``data.img_size``, then 512), ``pretrained`` (default
                ``True``) and ``patch_size`` (default 16).

        Raises:
            KeyError: If a required config key is missing.
        """
        super().__init__()
        mc = cfg["model"]
        data_cfg = cfg["data"]

        self.encoder = ViTEncoder(
            model_name=mc["encoder"],
            out_dim=mc["encoder_out_dim"],
            img_size=mc.get("encoder_img_size", data_cfg.get("img_size", 512)),
            # Optional so that callers restoring a checkpoint can skip the
            # pretrained-weight initialisation; defaults to the old behavior.
            pretrained=mc.get("pretrained", True),
        )
        self.decoder = SIRENDecoder(
            feat_dim=mc["encoder_out_dim"],
            hidden_dim=mc["siren_hidden"],
            n_layers=mc["siren_layers"],
            omega_0=mc["siren_omega0"],
            fourier_bands=mc["fourier_bands"],
            out_channels=mc["decoder_out_channels"],
        )
        # The encoder model name is configurable, so the patch size must be
        # too; 16 remains the default when the key is absent.
        self.patch_size = int(mc.get("patch_size", 16))

    def _pad_to_patch_multiple(self, makeup: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
        """Pad on the bottom/right so H and W are divisible by the patch size.

        Args:
            makeup: Tensor whose last two dimensions are (H, W).

        Returns:
            ``(padded, (H, W))`` where ``(H, W)`` is the size before
            padding. The input tensor itself is returned when no padding
            is needed.
        """
        h, w = makeup.shape[-2:]
        pad_h = -h % self.patch_size
        pad_w = -w % self.patch_size
        if pad_h == 0 and pad_w == 0:
            return makeup, (h, w)
        # F.pad order for the last two dims is (left, right, top, bottom);
        # the default mode fills with zeros.
        padded = torch.nn.functional.pad(makeup, (0, pad_w, 0, pad_h))
        return padded, (h, w)

    @staticmethod
    def _tokens_to_image(values: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Fold per-coordinate outputs (B, H*W, C) into an image (B, C, H, W).

        The channel count is taken from ``values`` rather than assumed to be
        3, so it follows whatever ``decoder_out_channels`` was configured.
        """
        batch, _, channels = values.shape
        return values.permute(0, 2, 1).reshape(batch, channels, height, width)

    def forward(self, makeup: torch.Tensor) -> torch.Tensor:
        """Restore a bare face from a makeup image.

        Args:
            makeup: (B, 3, H, W), documented as lying in [-1, 1].

        Returns:
            restored: (B, C, H, W) with C equal to the decoder's output
            channel count (3 in the documented configuration). The value
            range depends on the decoder; the original documentation states
            [-1, 1].
        """
        B = makeup.shape[0]
        makeup_padded, (orig_h, orig_w) = self._pad_to_patch_multiple(makeup)
        H, W = makeup_padded.shape[-2:]

        # encode
        Z = self.encoder(makeup_padded)  # (B, C, h, w) per original comment

        # build coordinate grid: one query per padded pixel
        coords = make_coord_grid(H, W, makeup.device)  # (1, H*W, 2)
        coords = coords.expand(B, -1, -1)  # (B, H*W, 2)

        # decode
        rgb = self.decoder(coords, Z)  # (B, H*W, out_channels)

        # NOTE(review): folding assumes make_coord_grid enumerates pixels in
        # row-major (H outer, W inner) order - confirm against its definition.
        restored = self._tokens_to_image(rgb, H, W)

        # Drop the padding that was added on the bottom/right.
        return restored[..., :orig_h, :orig_w]