import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Residual Block ---
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers wrapped in an identity skip connection.

    Padding of 1 keeps the spatial size unchanged, so the skip addition
    is always shape-compatible.
    """

    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels)
        )

    def forward(self, x):
        # Post-activation residual: ReLU is applied after the skip addition.
        return F.relu(x + self.block(x))

# --- DepthSTAR Model ---
class DepthSTAR(nn.Module):
    """Encoder-transformer-decoder network producing a one-channel map in [0, 1].

    A convolutional encoder downsamples the input by 4x, an optional
    transformer bottleneck mixes the resulting feature tokens globally,
    and a transposed-convolution decoder upsamples back to full resolution.
    """
    def __init__(
        self,
        use_residual_blocks=True,
        use_transformer=True,
        transformer_layers=8,
        transformer_heads=8,
        embed_dim=512,
    ):
        super().__init__()
        self.use_residual_blocks = use_residual_blocks
        self.use_transformer = use_transformer

        # Encoder: 3 -> 64 -> 128 -> embed_dim channels; the two stride-2
        # convolutions reduce the spatial resolution by 4x overall.
        encoder_layers = [
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        ]
        if use_residual_blocks:
            encoder_layers.append(ResidualBlock(128))
        encoder_layers += [
            nn.Conv2d(128, embed_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        ]
        if use_residual_blocks:
            encoder_layers.append(ResidualBlock(embed_dim))

        self.encoder = nn.Sequential(*encoder_layers)

        # Optional global bottleneck: a transformer over the flattened
        # feature map, one token per spatial position (see forward()).
        if use_transformer:
            self.bottleneck = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model=embed_dim,
                    nhead=transformer_heads,
                    dim_feedforward=embed_dim * 4,
                    batch_first=True
                ),
                num_layers=transformer_layers
            )

        # Decoder: two stride-2 transposed convolutions restore the input
        # resolution; the final Sigmoid bounds the output to [0, 1].
        decoder_layers = [
            nn.ConvTranspose2d(embed_dim, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU()
        ]
        if use_residual_blocks:
            decoder_layers.append(ResidualBlock(128))
        decoder_layers += [
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
            nn.Sigmoid()
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        B = x.size(0)
        feat = self.encoder(x)  # (B, embed_dim, H', W') with H' = H/4, W' = W/4
        if self.use_transformer:
            # Flatten the spatial grid into a sequence of H'*W' tokens.
            tokens = feat.flatten(2).transpose(1, 2)  # (B, H'*W', embed_dim)
            tokens = self.bottleneck(tokens)
            # Fold the tokens back into a feature map for the conv decoder.
            feat = tokens.transpose(1, 2).reshape(B, *feat.shape[1:])
        return self.decoder(feat)  # (B, 1, H, W)
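
# ------------------------------------------------------------------
# Minimal smoke test: a sketch, not part of the model definition.
# Assumptions (not from the original code): a 128x128 RGB input and
# smaller-than-default embed_dim / transformer_layers to keep the check
# fast. Spatial dims must be divisible by 4 so the two stride-2 encoder
# convs round-trip exactly through the two stride-2 transposed convs.
# ------------------------------------------------------------------
if __name__ == "__main__":
    model = DepthSTAR(embed_dim=128, transformer_layers=2, transformer_heads=8)
    model.eval()
    x = torch.randn(2, 3, 128, 128)  # (B, C, H, W) dummy batch
    with torch.no_grad():
        out = model(x)
    # One output channel per pixel, squashed to [0, 1] by the final Sigmoid.
    assert out.shape == (2, 1, 128, 128), out.shape
    print("output shape:", tuple(out.shape),
          "min/max:", round(out.min().item(), 3), round(out.max().item(), 3))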