keivalya committed on
Commit
ef17af7
·
verified ·
1 Parent(s): df30dec

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +73 -13
model.py CHANGED
@@ -2,22 +2,82 @@ import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
- class HybridDepthModel(nn.Module):
6
- def __init__(self):
7
- super(HybridDepthModel, self).__init__()
8
- self.encoder = nn.Sequential(
9
- nn.Conv2d(3, 64, 3, padding=1),
10
- nn.ReLU(),
11
- nn.Conv2d(64, 128, 3, stride=2, padding=1),
12
  nn.ReLU(),
 
 
13
  )
14
- self.decoder = nn.Sequential(
15
- nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  nn.ReLU(),
17
- nn.Conv2d(64, 1, 3, padding=1),
18
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  def forward(self, x):
 
21
  feat = self.encoder(x)
22
- out = self.decoder(feat)
23
- return out
 
 
 
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
+ # --- Residual Block ---
6
class ResidualBlock(nn.Module):
    """Residual unit: conv-BN-ReLU-conv-BN with an identity skip connection.

    Both convolutions are 3x3 with padding 1 and keep the channel count, so
    the block's output has the same shape as its input.

    Args:
        channels: Number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        # The layer order fixes the Sequential indices (0..4) and therefore
        # the parameter names stored in checkpoints; do not reorder.
        stages = [
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``relu(x + block(x))``."""
        residual = self.block(x)
        return F.relu(x + residual)
19
+
20
+ # --- DepthSTAR Model ---
21
class DepthSTAR(nn.Module):
    """Convolutional encoder/decoder with an optional transformer bottleneck.

    The encoder downsamples the input twice (two stride-2 convolutions) and
    ends with ``embed_dim`` channels. When ``use_transformer`` is set, the
    encoder feature map is flattened into one token per spatial location and
    passed through a ``nn.TransformerEncoder``. The decoder upsamples twice
    with transposed convolutions and ends in a single-channel ``Sigmoid``.

    Args:
        use_residual_blocks: Insert ``ResidualBlock`` modules after each
            encoder stage and after the first decoder stage.
        use_transformer: Build and apply the transformer bottleneck.
        transformer_layers: Number of ``TransformerEncoderLayer`` layers.
        transformer_heads: Attention heads per transformer layer.
        embed_dim: Channel count at the bottleneck; also the transformer
            ``d_model`` (its feed-forward width is ``embed_dim * 4``).
    """

    def __init__(
        self,
        use_residual_blocks=True,
        use_transformer=True,
        transformer_layers=8,
        transformer_heads=8,
        embed_dim=512,
    ):
        super().__init__()
        self.use_residual_blocks = use_residual_blocks
        self.use_transformer = use_transformer

        # NOTE: module construction order determines both the Sequential
        # indices (checkpoint key names) and the order in which random
        # initialisation is drawn, so it is kept exactly as before.
        encoder_layers = [
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        ]
        if use_residual_blocks:
            encoder_layers.append(ResidualBlock(128))
        encoder_layers += [
            nn.Conv2d(128, embed_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        ]
        if use_residual_blocks:
            encoder_layers.append(ResidualBlock(embed_dim))

        self.encoder = nn.Sequential(*encoder_layers)

        if use_transformer:
            # batch_first=True: tokens are laid out as (batch, tokens, embed_dim).
            self.bottleneck = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model=embed_dim,
                    nhead=transformer_heads,
                    dim_feedforward=embed_dim * 4,
                    batch_first=True,
                ),
                num_layers=transformer_layers,
            )

        decoder_layers = [
            nn.ConvTranspose2d(embed_dim, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        ]
        if use_residual_blocks:
            decoder_layers.append(ResidualBlock(128))
        decoder_layers += [
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Run the network on a batch of 3-channel images.

        Args:
            x: Tensor of shape ``(B, 3, H, W)``.

        Returns:
            Tensor of shape ``(B, 1, H, W)`` produced by the final sigmoid.
            When ``H`` or ``W`` is not a multiple of 4 the decoder output is
            larger than the input (the stride-2 convolutions round up while
            the transposed convolutions exactly double), so it is bilinearly
            resized back to ``(H, W)``. Inputs whose sides are multiples of 4
            are returned without any resizing.
        """
        in_size = tuple(x.shape[-2:])
        feat = self.encoder(x)
        if self.use_transformer:
            batch, channels, height, width = feat.shape
            # (B, C, H', W') -> (B, H'*W', C): one token per spatial location.
            tokens = feat.flatten(2).transpose(1, 2)
            tokens = self.bottleneck(tokens)
            # Back to the image layout expected by the decoder.
            feat = tokens.transpose(1, 2).reshape(batch, channels, height, width)
        depth = self.decoder(feat)
        if tuple(depth.shape[-2:]) != in_size:
            # Keep the prediction pixel-aligned with the input resolution.
            depth = F.interpolate(
                depth, size=in_size, mode="bilinear", align_corners=False
            )
        return depth