File size: 3,059 Bytes
65b0806
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualRenderBlock(nn.Module):
    """Shape-preserving residual unit: ``x + GN(conv(SiLU(GN(conv(x)))))``.

    Both convolutions are 3x3 with padding 1 and keep ``dim`` channels, so the
    branch output has the same shape as the input and can be added back to it.

    Args:
        dim: Channel count of both input and output. ``nn.GroupNorm(8, dim)``
            requires it to be divisible by 8.
    """

    def __init__(self, dim):
        super().__init__()
        layers = [
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.GroupNorm(8, dim),
            nn.SiLU(),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.GroupNorm(8, dim),
        ]
        # Kept as one nn.Sequential named ``block`` so parameter names
        # (block.0.*, block.1.*, block.3.*, block.4.*) stay stable for
        # state dicts saved from this module.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        # Identity skip connection; no activation is applied after the sum.
        update = self.block(x)
        return x + update

class RenderEncoder(nn.Module):
    """Decode a feature map into an ``out_channels`` image passed through a sigmoid.

    Every variant keeps the spatial size of its input.

    Args:
        encoder_type: Which head to build:
            ``"1d"``         -- one 1x1 convolution, then sigmoid.
            ``"residual"``   -- a ``ResidualBlockRender`` module.
            ``"expressive"`` -- 3x3 conv stem to 256 channels with GroupNorm
                                and SiLU, three ``ResidualRenderBlock`` units,
                                then a 1x1 conv and sigmoid.
        in_channels: Channels of the input feature map.
        out_channels: Channels of the produced image.

    Raises:
        ValueError: If ``encoder_type`` is not one of the names above.
    """

    def __init__(self, encoder_type="1d", in_channels=768, out_channels=3):
        super().__init__()
        self.encoder_type = encoder_type

        if encoder_type == "1d":
            head = nn.Conv2d(in_channels, out_channels, kernel_size=1)
            self.model = nn.Sequential(head, nn.Sigmoid())
        elif encoder_type == "residual":
            self.model = ResidualBlockRender(in_channels, out_channels)
        elif encoder_type == "expressive":
            width = 256
            stem = [
                nn.Conv2d(in_channels, width, kernel_size=3, padding=1),
                nn.GroupNorm(8, width),
                nn.SiLU(),
            ]
            trunk = [ResidualRenderBlock(width) for _ in range(3)]
            head = [nn.Conv2d(width, out_channels, kernel_size=1), nn.Sigmoid()]
            # A single flat Sequential keeps the state_dict indices
            # model.0 ... model.7 (stem 0-2, trunk 3-5, head 6-7).
            self.model = nn.Sequential(*stem, *trunk, *head)
        else:
            raise ValueError(
                f"Unknown encoder_type '{encoder_type}'. Use '1d', 'residual', or 'expressive'."
            )

    def forward(self, x):
        """Run the head selected at construction time on ``x``."""
        decoder = self.model
        return decoder(x)

class ResidualBlockRender(nn.Module):
    """Conv stack with a skip connection, squashed by a sigmoid.

    Computes ``sigmoid(conv3(relu(conv2(relu(conv1(x))))) + proj(x))``:

    * ``conv1``/``conv2`` are 3x3 (padding 1) with 256 hidden channels.
    * ``conv3`` is a 1x1 head to ``out_channels``.
    * ``proj`` is a 1x1 convolution when the channel counts differ, and the
      identity otherwise.

    Spatial size is preserved.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output.
    """

    def __init__(self, in_channels=768, out_channels=3):
        super().__init__()
        hidden = 256
        # Registration order determines both parameter-initialization order
        # and state_dict key order; keep it stable.
        self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(hidden, out_channels, kernel_size=1)
        self.out = nn.Sigmoid()

        # A 1x1 projection is only needed to align channel counts for the sum.
        self.residual_proj = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x):
        skip = self.residual_proj(x)
        features = self.relu2(self.conv2(self.relu1(self.conv1(x))))
        logits = self.conv3(features)
        return self.out(logits + skip)

def load_render_encoder(checkpoint_path, device='cpu'):
    """Rebuild a standalone RenderEncoder from a checkpoint file.

    The checkpoint must contain:

    * a ``'model_config'`` entry holding ``'encoder_type'``, ``'in_channels'``
      and ``'out_channels'``;
    * a ``'model_state_dict'`` entry with the matching weights.

    Args:
        checkpoint_path: Path handed to ``torch.load``.
        device: Target device, used both as ``map_location`` and for ``.to()``.

    Returns:
        The encoder moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: If a required checkpoint or config entry is missing.
        RuntimeError: From ``load_state_dict`` if the weights do not match
            the architecture described by the config.
    """
    # NOTE(review): torch.load is pickle-based. Unless weights_only=True is in
    # effect (the default only since PyTorch 2.6), load trusted files only.
    ckpt = torch.load(checkpoint_path, map_location=device)

    cfg = ckpt['model_config']
    encoder = RenderEncoder(
        **{key: cfg[key] for key in ('encoder_type', 'in_channels', 'out_channels')}
    )
    encoder.load_state_dict(ckpt['model_state_dict'])
    encoder.to(device).eval()

    print(f"Loaded RenderEncoder: {cfg}")
    return encoder