File size: 2,384 Bytes
488ce43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
Full ITDF model: ViT encoder → SIREN implicit decoder.

Input:  makeup image  (B, 3, H, W)
Output: restored bare face  (B, 3, H, W)
"""

import torch
import torch.nn as nn

from models.encoder import ViTEncoder
from models.decoder import SIRENDecoder, make_coord_grid


class ITDF(nn.Module):
    """ITDF model: a ViT encoder followed by a SIREN implicit decoder.

    The input image is zero-padded so its height and width are multiples of
    ``patch_size``. The padded image is encoded into a feature map ``Z``.
    The decoder is then queried at every pixel coordinate of the padded grid,
    and the decoded image is cropped back to the original input size.

    Args:
        cfg: configuration dict with ``"model"`` and ``"data"`` sections.
            Required ``cfg["model"]`` keys are ``encoder``, ``encoder_out_dim``,
            ``siren_hidden``, ``siren_layers``, ``siren_omega0``,
            ``fourier_bands`` and ``decoder_out_channels``.
            Optional ``cfg["model"]`` keys are:
            - ``encoder_img_size``: falls back to ``cfg["data"]["img_size"]``,
              then to 512.
            - ``encoder_pretrained``: defaults to ``True``.
            - ``patch_size``: defaults to 16.

    Raises:
        ValueError: if ``patch_size`` is not a positive integer.
    """

    def __init__(self, cfg: dict):
        super().__init__()
        mc = cfg["model"]
        data_cfg = cfg["data"]

        self.encoder = ViTEncoder(
            model_name=mc["encoder"],
            out_dim=mc["encoder_out_dim"],
            img_size=mc.get("encoder_img_size", data_cfg.get("img_size", 512)),
            # Configurable so inference/tests can skip downloading weights;
            # defaults to the previous hard-coded behaviour.
            pretrained=mc.get("encoder_pretrained", True),
        )
        self.decoder = SIRENDecoder(
            feat_dim=mc["encoder_out_dim"],
            hidden_dim=mc["siren_hidden"],
            n_layers=mc["siren_layers"],
            omega_0=mc["siren_omega0"],
            fourier_bands=mc["fourier_bands"],
            out_channels=mc["decoder_out_channels"],
        )
        # Must agree with the patch size of the ViT named by mc["encoder"].
        # NOTE(review): check whether ViTEncoder exposes this directly.
        self.patch_size = int(mc.get("patch_size", 16))
        if self.patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {self.patch_size}")

    def _pad_to_patch_multiple(self, makeup: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
        """Pad on the bottom/right so H and W are divisible by the ViT patch size.

        Padding is constant zeros (``F.pad`` default), which is mid-range for
        inputs in [-1, 1]. Padding only the bottom/right keeps the original
        pixels at the top-left, so cropping ``[:h, :w]`` recovers them.

        Returns:
            The (possibly) padded tensor and the original ``(h, w)``. When no
            padding is needed, the input tensor itself is returned.
        """
        h, w = makeup.shape[-2:]
        pad_h = (self.patch_size - h % self.patch_size) % self.patch_size
        pad_w = (self.patch_size - w % self.patch_size) % self.patch_size
        if pad_h == 0 and pad_w == 0:
            return makeup, (h, w)
        # F.pad's pad tuple runs from the last dim backwards:
        # (W_left, W_right, H_top, H_bottom).
        padded = torch.nn.functional.pad(makeup, (0, pad_w, 0, pad_h))
        return padded, (h, w)

    @staticmethod
    def _tokens_to_image(
        tokens: torch.Tensor,
        padded_hw: tuple[int, int],
        original_hw: tuple[int, int],
    ) -> torch.Tensor:
        """Reshape per-pixel decoder output into an image and crop the padding.

        Args:
            tokens: ``(B, H*W, C_out)`` tensor with one row per pixel of the
                padded grid, in row-major (H outer, W inner) order.
            padded_hw: ``(H, W)`` of the padded grid.
            original_hw: ``(h, w)`` to crop back to.

        Returns:
            ``(B, C_out, h, w)`` tensor.
        """
        batch, _, channels = tokens.shape
        H, W = padded_hw
        # Take the channel count from the data rather than assuming 3, since
        # the decoder's out_channels comes from config.
        image = tokens.permute(0, 2, 1).reshape(batch, channels, H, W)
        h, w = original_hw
        return image[..., :h, :w]

    def forward(self, makeup: torch.Tensor) -> torch.Tensor:
        """
        Args:
            makeup: (B, 3, H, W) in [-1, 1]
        Returns:
            restored: (B, C_out, H, W), where C_out is
            ``cfg["model"]["decoder_out_channels"]`` (3 for RGB). The value
            range is whatever SIRENDecoder emits; nominally [-1, 1].
            NOTE(review): confirm the decoder bounds its output.
        """
        B = makeup.shape[0]
        makeup_padded, original_hw = self._pad_to_patch_multiple(makeup)
        H, W = makeup_padded.shape[-2:]

        # encode
        Z = self.encoder(makeup_padded)                 # presumably (B, C, h, w)

        # One query coordinate per pixel of the padded grid, shared across the batch.
        coords = make_coord_grid(H, W, makeup.device)   # (1, H*W, 2)
        coords = coords.expand(B, -1, -1)               # (B, H*W, 2)

        # decode
        out = self.decoder(coords, Z)                   # (B, H*W, C_out)
        return self._tokens_to_image(out, (H, W), original_hw)