# itdf-space / models/itdf.py — uploaded by priyadip with huggingface_hub (revision 488ce43, verified).
"""
Full ITDF model: ViT encoder → SIREN implicit decoder.
Input: makeup image (B, 3, H, W)
Output: restored bare face (B, 3, H, W)
"""
import torch
import torch.nn as nn
from models.encoder import ViTEncoder
from models.decoder import SIRENDecoder, make_coord_grid
class ITDF(nn.Module):
    """Makeup-removal model: ViT encoder feeding a SIREN implicit decoder.

    The encoder maps the (padded) input image to a feature map ``Z``. The
    decoder is then queried at one coordinate per pixel of the padded image,
    conditioned on ``Z``. Its per-pixel outputs are reshaped into an image and
    cropped back to the input size.

    Config keys read from ``cfg["model"]``:
        encoder, encoder_out_dim, siren_hidden, siren_layers, siren_omega0,
        fourier_bands, decoder_out_channels (all required);
        encoder_img_size (falls back to ``cfg["data"]["img_size"]``, then 512);
        patch_size (optional, default 16);
        encoder_pretrained (optional, default True).
    ``cfg["data"]`` must be present.
    """

    def __init__(self, cfg: dict):
        super().__init__()
        mc = cfg["model"]
        data_cfg = cfg["data"]
        self.encoder = ViTEncoder(
            model_name=mc["encoder"],
            out_dim=mc["encoder_out_dim"],
            img_size=mc.get("encoder_img_size", data_cfg.get("img_size", 512)),
            # Defaults to True (the previous hard-coded value). Setting it to
            # False allows building the model without fetching pretrained weights.
            pretrained=mc.get("encoder_pretrained", True),
        )
        self.decoder = SIRENDecoder(
            feat_dim=mc["encoder_out_dim"],
            hidden_dim=mc["siren_hidden"],
            n_layers=mc["siren_layers"],
            omega_0=mc["siren_omega0"],
            fourier_bands=mc["fourier_bands"],
            out_channels=mc["decoder_out_channels"],
        )
        # Spatial multiple that encoder inputs are padded to. It defaults to 16,
        # the previously hard-coded value. NOTE(review): it must match the ViT
        # patch size of mc["encoder"]; that cannot be checked from this file.
        self.patch_size = int(mc.get("patch_size", 16))
        if self.patch_size < 1:
            raise ValueError(f"patch_size must be >= 1, got {self.patch_size}")

    def _pad_to_patch_multiple(
        self, makeup: torch.Tensor
    ) -> tuple[torch.Tensor, tuple[int, int]]:
        """Zero-pad on the bottom/right so H and W are multiples of ``self.patch_size``.

        Args:
            makeup: tensor whose last two dims are (H, W).

        Returns:
            A pair ``(padded, (H, W))``. ``padded`` is the input itself when no
            padding is needed. ``(H, W)`` is the original spatial size, used
            later to crop the output.
        """
        h, w = makeup.shape[-2:]
        p = self.patch_size
        # Distance up to the next multiple of p (0 when already aligned).
        pad_h = (p - h % p) % p
        pad_w = (p - w % p) % p
        if pad_h or pad_w:
            # F.pad pads the last two dims as (left, right, top, bottom). The
            # fill is 0, i.e. mid-gray for inputs in [-1, 1].
            makeup = nn.functional.pad(makeup, (0, pad_w, 0, pad_h))
        return makeup, (h, w)

    @staticmethod
    def _points_to_image(values: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Reshape per-point outputs (B, height*width, C) into an image (B, C, height, width).

        The channel count C is taken from ``values`` rather than assumed to be
        3, so any ``decoder_out_channels`` works. The point order is assumed
        to be row-major over (height, width), matching the coordinate grid
        (TODO confirm against ``make_coord_grid``).
        """
        batch, _, channels = values.shape
        return values.permute(0, 2, 1).reshape(batch, channels, height, width)

    def forward(self, makeup: torch.Tensor) -> torch.Tensor:
        """Restore a bare face from a makeup image.

        Args:
            makeup: (B, 3, H, W), expected in [-1, 1]. H and W need not be
                multiples of ``self.patch_size``. The input is padded for the
                encoder and the output is cropped back to (H, W).

        Returns:
            restored: (B, C_out, H, W), where C_out is the decoder's output
            channel count (3 for RGB). The value range is whatever the decoder
            emits; upstream this is documented as [-1, 1].

        Raises:
            ValueError: if ``makeup`` is not a 4-D tensor.
        """
        if makeup.dim() != 4:
            raise ValueError(
                f"expected a (B, C, H, W) tensor, got shape {tuple(makeup.shape)}"
            )
        batch = makeup.shape[0]
        padded, (orig_h, orig_w) = self._pad_to_patch_multiple(makeup)
        H, W = padded.shape[-2:]

        # encode
        Z = self.encoder(padded)  # (B, C, h, w)

        # One query coordinate per pixel of the *padded* image, shared across
        # the batch. expand() creates a broadcast view rather than a copy.
        coords = make_coord_grid(H, W, makeup.device)  # (1, H*W, 2)
        coords = coords.expand(batch, -1, -1)  # (B, H*W, 2)

        # decode
        values = self.decoder(coords, Z)  # (B, H*W, C_out)
        restored = self._points_to_image(values, H, W)
        # Drop the rows/columns that were added by padding.
        return restored[..., :orig_h, :orig_w]