"""
Full ITDF model: ViT encoder → SIREN implicit decoder.

Input:  makeup image (B, 3, H, W)
Output: restored bare face (B, C_out, H, W), where C_out is the decoder's
        ``decoder_out_channels`` (3 for an RGB output).
"""

import torch
import torch.nn as nn

from models.encoder import ViTEncoder
from models.decoder import SIRENDecoder, make_coord_grid


class ITDF(nn.Module):
    """ViT encoder followed by a coordinate-based SIREN decoder.

    The input image is padded so its spatial size is a multiple of the ViT
    patch size, encoded into a feature map, and then the decoder is queried at
    every pixel coordinate of the padded image (conditioned on that feature
    map). The decoded image is finally cropped back to the input size.
    """

    def __init__(self, cfg: dict):
        """
        Args:
            cfg: configuration dict. ``cfg["model"]`` must provide
                ``encoder``, ``encoder_out_dim``, ``siren_hidden``,
                ``siren_layers``, ``siren_omega0``, ``fourier_bands`` and
                ``decoder_out_channels``. Optional model keys:
                  - ``encoder_img_size``: encoder input size; falls back to
                    ``cfg["data"]["img_size"]`` and then 512.
                  - ``encoder_pretrained`` (default True): passed to the
                    encoder as ``pretrained``.
                  - ``patch_size`` (default 16): inputs are padded to a
                    multiple of this. NOTE(review): must match the patch size
                    of the chosen ViT encoder — confirm per model name.

        Raises:
            KeyError: if a required model key is missing.
            ValueError: if ``patch_size`` is not a positive integer.
        """
        super().__init__()
        mc = cfg["model"]
        data_cfg = cfg.get("data", {})  # only consulted as an img_size fallback

        self.encoder = ViTEncoder(
            model_name=mc["encoder"],
            out_dim=mc["encoder_out_dim"],
            img_size=mc.get("encoder_img_size", data_cfg.get("img_size", 512)),
            pretrained=mc.get("encoder_pretrained", True),
        )
        self.decoder = SIRENDecoder(
            feat_dim=mc["encoder_out_dim"],
            hidden_dim=mc["siren_hidden"],
            n_layers=mc["siren_layers"],
            omega_0=mc["siren_omega0"],
            fourier_bands=mc["fourier_bands"],
            out_channels=mc["decoder_out_channels"],
        )

        patch_size = int(mc.get("patch_size", 16))
        if patch_size <= 0:
            raise ValueError(f"patch_size must be a positive integer, got {patch_size}")
        self.patch_size = patch_size

    def _pad_to_patch_multiple(
        self, makeup: torch.Tensor
    ) -> tuple[torch.Tensor, tuple[int, int]]:
        """Pad on the bottom/right so H and W are divisible by the patch size.

        Padding uses ``F.pad``'s default constant value 0. With inputs in
        [-1, 1] that is mid-gray. The padded region is cropped away again
        in ``forward``.

        Returns:
            (padded tensor, (original H, original W))
        """
        h, w = makeup.shape[-2:]
        pad_h = (-h) % self.patch_size
        pad_w = (-w) % self.patch_size
        if pad_h == 0 and pad_w == 0:
            return makeup, (h, w)
        # F.pad's pad tuple runs from the last dim backwards: (left, right, top, bottom)
        padded = torch.nn.functional.pad(makeup, (0, pad_w, 0, pad_h))
        return padded, (h, w)

    def forward(self, makeup: torch.Tensor) -> torch.Tensor:
        """
        Args:
            makeup: (B, 3, H, W) in [-1, 1]

        Returns:
            restored: (B, C_out, H, W), where C_out is the number of channels
            the decoder produces (``decoder_out_channels``). The value range
            is whatever the decoder emits; the intended range is [-1, 1].

        Raises:
            ValueError: if ``makeup`` is not a 4-D tensor.
        """
        if makeup.dim() != 4:
            raise ValueError(
                f"expected a 4-D (B, C, H, W) tensor, got shape {tuple(makeup.shape)}"
            )
        batch = makeup.shape[0]

        padded, (orig_h, orig_w) = self._pad_to_patch_multiple(makeup)
        pad_h, pad_w = padded.shape[-2:]

        # encode
        feats = self.encoder(padded)  # (B, C, h, w)

        # one query coordinate per pixel of the *padded* image, shared across the batch
        coords = make_coord_grid(pad_h, pad_w, makeup.device)  # (1, H*W, 2)
        coords = coords.expand(batch, -1, -1)  # (B, H*W, 2)

        # decode
        decoded = self.decoder(coords, feats)  # (B, H*W, C_out)

        # The channel count comes from the decoder output rather than being
        # hard-coded to 3, so configs with decoder_out_channels != 3 work.
        out_channels = decoded.shape[-1]
        restored = decoded.permute(0, 2, 1).reshape(batch, out_channels, pad_h, pad_w)

        # drop the bottom/right padding added before encoding
        return restored[..., :orig_h, :orig_w]