3ZadeSSG committed on
Commit
3b9f59f
·
1 Parent(s): cabef4a

initial commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # Models and Engines
7
+ *.onnx
8
+ *.onnx.data
9
+ *.pth
10
+ *.engine
11
+
12
+ # Images
13
+ *.png
14
+ *.jpeg
15
+ *.JPG
16
+
17
+ # Videos
18
+ *.mp4
19
+
20
+ # Logs
21
+ logs/
.huggingface.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ sdk: gradio
2
+ python_version: '3.12'
3
+ requirements_file: requirements.txt
app.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import cv2
5
+ import tempfile
6
+ from PIL import Image
7
+ import torchvision.transforms as transforms
8
+ import matplotlib.pyplot as plt
9
+ from models.pvsdnet_model import PVSDNet
10
+ from models.pvsdnet_lite_model import PVSDNet_Lite
11
+ import helperFunctions as helper
12
+ import parameters_pvsdnet as params
13
+ from huggingface_hub import hf_hub_download
14
+ import joblib
15
+
16
# Hugging Face Hub repository that hosts the trained checkpoints.
REPO_ID = "3ZadeSSG/PVSDNet"
print("Downloading/Loading checkpoints from Hugging Face Hub...")

# Each call yields the location that process_image later hands to
# helper.load_Checkpoint for the corresponding model.
MODEL_PVSDNET_LITE_LOCATION = hf_hub_download(
    repo_id=REPO_ID,
    filename="pvsdnet_lite_model.pth",
)

MODEL_PVSDNET_LOCATION = hf_hub_download(
    repo_id=REPO_ID,
    filename="pvsdnet_model.pth",
)

# Report every path under the label of the model it actually belongs to
# (the two labels used to be swapped).
print(f"Large Model loaded at: {MODEL_PVSDNET_LOCATION}")
print(f"Lite Model loaded at: {MODEL_PVSDNET_LITE_LOCATION}")

# Device used for both the model and the input tensors during inference.
DEVICE = params.DEVICE
32
+
33
def getPositionVector(x, y, z):
    """Return the normalised camera offset as a float tensor of shape (1, 3).

    Each coordinate is first rounded to 7 decimal places and then mapped
    linearly so that -0.1 becomes 0 and 0.1 becomes 1. Values outside that
    interval are not clamped.
    """
    lower, upper = -0.1, 0.1
    normalized = [
        (float(format(coordinate, '.7f')) - lower) / (upper - lower)
        for coordinate in (x, y, z)
    ]
    return torch.tensor([normalized], dtype=torch.float)
42
+
43
def generateCircularTrajectory(radius, num_frames):
    """Return ``num_frames`` [x, y, 0] points evenly spaced on a circle.

    The circle lies in the x/y plane, has the given radius and starts at
    angle 0; the end angle 2*pi is excluded so the loop does not repeat its
    first point.
    """
    trajectory = []
    for theta in np.linspace(0, 2 * np.pi, num_frames, endpoint=False):
        trajectory.append([radius * np.cos(theta), radius * np.sin(theta), 0])
    return trajectory
46
+
47
def generateSwingTrajectory(radius, num_frames):
    """Return ``num_frames`` [x, 0, z] points evenly spaced on a circle.

    Same sampling as the circular trajectory, but the circle lies in the
    x/z plane (y is always 0).
    """
    trajectory = []
    for theta in np.linspace(0, 2 * np.pi, num_frames, endpoint=False):
        trajectory.append([radius * np.cos(theta), 0, radius * np.sin(theta)])
    return trajectory
50
+
51
def create_video_from_memory(frames, fps=30):
    """Encode in-memory frames into a temporary ``.mp4`` file.

    Args:
        frames: Sequence of image arrays; the video size is taken from the
            shape of the first one, which is unpacked as
            (height, width, channels).
        fps: Frame rate passed to the video writer.

    Returns:
        Path of the written file, or ``None`` when ``frames`` is empty. The
        file is created with ``delete=False``, so removing it is up to the
        caller.
    """
    if not frames:
        return None

    height, width, _ = frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')

    # Only the unique path is needed. Close our own handle straight away so
    # it is not leaked and the writer can open the file itself.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
        video_path = temp_video.name

    out = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
    try:
        for frame in frames:
            out.write(frame)
    finally:
        # Release the writer even if a frame could not be written.
        out.release()
    return video_path
62
+
63
def process_image(img, video_type, radius, num_frames, num_loops, model_type):
    """Render a novel-view video and a depth video for one input image.

    Args:
        img: PIL image supplied by the UI, or ``None`` when nothing is loaded.
        video_type: ``"Swing"`` selects the x/z trajectory; any other value
            (including ``"Circle"``) selects the circular x/y trajectory.
        radius: Radius of the camera trajectory.
        num_frames: Number of poses rendered per loop. Cast to ``int`` because
            the sample count handed to ``np.linspace`` must be an integer.
        num_loops: How many times the rendered loop is repeated in the video.
        model_type: ``"PVSDNet Lite"`` selects the lite model; any other value
            selects the full model.

    Returns:
        ``(view_video_path, depth_video_path)``, or ``(None, None)`` when no
        image was given.
    """
    if img is None:
        return None, None

    height, width = 256, 256

    # Centre-crop to a square so the resize below does not distort the image.
    min_dim = min(img.width, img.height)
    left = (img.width - min_dim) / 2
    top = (img.height - min_dim) / 2
    right = (img.width + min_dim) / 2
    bottom = (img.height + min_dim) / 2
    img = img.crop((left, top, right, bottom))

    if model_type == "PVSDNet Lite":
        model = PVSDNet_Lite(total_image_input=params.params_number_input)
        checkpoint = MODEL_PVSDNET_LITE_LOCATION
    else:
        model = PVSDNet(total_image_input=params.params_number_input)
        checkpoint = MODEL_PVSDNET_LOCATION

    try:
        model = helper.load_Checkpoint(checkpoint, model, load_cpu=True)
    except Exception as e:
        # NOTE(review): best-effort by design - after a failed load the
        # freshly constructed model is still used for inference below.
        print(f"Error loading checkpoint {checkpoint}: {e}")

    model.to(DEVICE)
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((height, width)),
        transforms.ToTensor()
    ])

    img_input = img.convert('RGB')
    img_input = transform(img_input).unsqueeze(0).to(DEVICE)

    # Slider values may arrive as floats; the trajectory helpers need an
    # integer sample count (num_loops is cast the same way further down).
    frame_count = int(num_frames)
    if video_type == "Swing":
        trajectory = generateSwingTrajectory(radius, frame_count)
    else:
        # "Circle" and any unrecognised type share the circular path, whose
        # points already carry z = 0.
        trajectory = generateCircularTrajectory(radius, frame_count)

    view_frames = []
    depth_frames = []
    depth_colormap = plt.get_cmap('inferno')  # loop-invariant lookup

    # Run inference for a single loop (trajectory) to save computation
    for x, y, z in trajectory:
        pos = getPositionVector(x, y, z).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            predicted_img, predicted_depth = model(img_input, pos)

        # Predicted view: CHW -> HWC, clipped to [0, 1], 8-bit, BGR for OpenCV.
        p_img = predicted_img[0].detach().cpu().permute(1, 2, 0).numpy()
        p_img = np.clip(p_img, 0, 1)
        p_img = (p_img * 255).astype(np.uint8)
        view_frames.append(cv2.cvtColor(p_img, cv2.COLOR_RGB2BGR))

        # Predicted depth: min-max normalise each frame; a (near-)constant
        # map becomes all zeros to avoid dividing by ~0.
        d_img = predicted_depth.squeeze().detach().cpu().numpy()
        d_min, d_max = d_img.min(), d_img.max()
        if d_max - d_min > 1e-6:
            d_img = (d_img - d_min) / (d_max - d_min)
        else:
            d_img = np.zeros_like(d_img)

        # Colourise and drop the alpha channel returned by the colormap.
        d_img_colored = depth_colormap(d_img)[:, :, :3]
        d_img_colored = (d_img_colored * 255).astype(np.uint8)
        depth_frames.append(cv2.cvtColor(d_img_colored, cv2.COLOR_RGB2BGR))

    # Repeat the frames for the requested number of loops
    view_frames = view_frames * int(num_loops)
    depth_frames = depth_frames * int(num_loops)

    fps = 60
    view_video_path = create_video_from_memory(view_frames, fps=fps)
    depth_video_path = create_video_from_memory(depth_frames, fps=fps)

    return view_video_path, depth_video_path
147
+
148
def _make_sample_loader(path):
    """Return a zero-argument callback that opens the image stored at *path*."""
    return lambda: Image.open(path)


# Example thumbnails as (file path, label) pairs, one inner list per display row.
_SAMPLE_ROWS = [
    [
        ("./samples/PVSDNet_Samples/COCO_59_source_image.png", "COCO Example 59"),
        ("./samples/PVSDNet_Samples/COCO_16_source_image.png", "COCO Example 16"),
        ("./samples/PVSDNet_Samples/COCO_755_source_image.png", "COCO Example 755"),
    ],
    [
        ("./samples/PVSDNet_Samples/COCO_223_source_image.png", "COCO Example 223"),
        ("./samples/PVSDNet_Samples/COCO_23_source_image.png", "COCO Example 23"),
        ("./samples/PVSDNet_Samples/person.jpeg", "Person"),
    ],
    [
        ("./samples/PVSDNet_Samples/flower.png", "Flower"),
        ("./samples/PVSDNet_Samples/person_2.jpeg", "Person"),
        ("./samples/PVSDNet_Samples/bakery.jpeg", "Bakery"),
    ],
]

# NOTE(review): the nesting of the layout containers below was reconstructed
# from statement order - confirm against the rendered page.
with gr.Blocks(title="PVSDNet: View & Depth Synthesis", theme="default") as demo:
    gr.Markdown(
        """
        ## PVSDNet: Joint Depth Prediction and View Synthesis via Shared Latent Spaces in Real-Time
        * Upload an image and get a mini video showing capability of novel view and depth synthesis.

        **Note:** Huggingface demo is running on CPU so inference speeds will be slow. Inference might take around 2 mins.
        ### Head to our [Project Page](https://realistic3d-miun.github.io/PVSDNet/) for more details about the models.
        """)

    with gr.Row():
        # Left column: input image and generation controls.
        with gr.Column():
            img_input = gr.Image(type="pil", label="Input Image", height=256)

            with gr.Group():
                video_type = gr.Dropdown(["Circle", "Swing"], label="Trajectory Type", value="Swing")
                model_type = gr.Dropdown(["PVSDNet", "PVSDNet Lite"], label="Model Type", value="PVSDNet")

            with gr.Accordion("Advanced Settings", open=False):
                radius = gr.Slider(0.01, 0.1, value=0.06, label="Motion Radius")
                num_frames = gr.Slider(10, 120, value=60, step=1, label="Frames per Loop")
                num_loops = gr.Slider(1, 6, value=3, step=1, label="Number of Loops")

            submit_btn = gr.Button("Generate", variant="primary")

        # Right column: the two generated videos.
        with gr.Column():
            video_output = gr.Video(label="Generated View Video", height=256)
            depth_video_output = gr.Video(label="Generated Depth Video", height=256)

    submit_btn.click(
        fn=process_image,
        inputs=[img_input, video_type, radius, num_frames, num_loops, model_type],
        outputs=[video_output, depth_video_output]
    )

    gr.Markdown("### Example Images: Click to Load")
    sample_thumbnails = []
    with gr.Column():
        for sample_row in _SAMPLE_ROWS:
            with gr.Row():
                for sample_path, sample_label in sample_row:
                    thumbnail = gr.Image(sample_path, label=sample_label, height=150, interactive=False, show_label=True)
                    sample_thumbnails.append((thumbnail, sample_path))

    # Selecting a thumbnail loads that file into the input image component.
    for thumbnail, sample_path in sample_thumbnails:
        thumbnail.select(fn=_make_sample_loader(sample_path), outputs=img_input)

if __name__ == "__main__":
    demo.launch()
helperFunctions.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import torch.nn.functional as F
4
+
5
def save_checkpoint(model, filelocation, save_parallel = True):
    """Write a model's state dict to ``filelocation`` with ``torch.save``.

    When ``save_parallel`` is true the state dict is taken from
    ``model.module`` (the wrapped model), so the model must expose a
    ``module`` attribute; otherwise it is taken from ``model`` itself.
    """
    source = model.module if save_parallel else model
    torch.save(source.state_dict(), filelocation)
10
+
11
def load_Checkpoint(fileLocation,model, load_cpu=False):
    """Load a saved state dict into ``model`` and return the same model.

    With ``load_cpu`` set, every storage is returned unchanged by the
    ``map_location`` callback instead of being moved to the device it was
    saved from.

    NOTE(review): ``torch.load`` unpickles the file; only use it on
    checkpoints from a trusted source.
    """
    load_kwargs = {}
    if load_cpu:
        load_kwargs["map_location"] = lambda storage, loc: storage
    state_dict = torch.load(fileLocation, **load_kwargs)
    model.load_state_dict(state_dict)
    return model
17
+
18
def writeLog(logList, filename):
    """Overwrite ``filename`` with the entries of ``logList``, one per line.

    Entries are joined with newlines; no trailing newline is written.
    """
    with open(filename, 'w') as log_file:
        joined = "\n".join(logList)
        log_file.write(joined)
21
+
22
+
23
def kl_loss(mu, logvar):
    """Return ``-0.5 * mean(1 + logvar - mu**2 - exp(logvar))``.

    The mean is taken over every element of the tensors.
    """
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * per_element.mean()
25
+
26
+
models/__init__.py ADDED
File without changes
models/pvsdnet_lite_model.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import warnings
8
+ warnings.filterwarnings("ignore")
9
+ import torchvision
10
+ import rff.layers as rff
11
+ import parameters_pvsdnet as params
12
+ import helperFunctions as helper
13
+
14
def getLinearLayer(in_feat, out_feat, activation=nn.ReLU(True)):
    """Return ``nn.Sequential(Linear(in_feat, out_feat), activation)``.

    The default activation is a single in-place ReLU instance created when
    the function is defined, so it is shared by every layer built without
    an explicit ``activation``.
    """
    linear = nn.Linear(in_features=in_feat, out_features=out_feat, bias=True)
    return nn.Sequential(linear, activation)
19
+
20
def getConvLayer(in_channel,out_channel,stride=1,padding=1,activation=nn.ReLU()):
    """Return a 3x3 reflect-padded ``Conv2d`` followed by ``activation``.

    The default activation instance is created once, at definition time,
    and shared by all layers that do not pass their own.
    """
    conv = nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size=3,
        stride=stride,
        padding=padding,
        padding_mode='reflect',
    )
    return nn.Sequential(conv, activation)
28
+
29
def getConvTransposeLayer(in_channel, out_channel,kernel=3,stride=1,padding=1,activation=nn.ReLU()):
    """Return a ``ConvTranspose2d`` followed by ``activation``.

    The default activation instance is created once, at definition time,
    and shared by all layers that do not pass their own.
    """
    deconv = nn.ConvTranspose2d(
        in_channel,
        out_channel,
        kernel_size=kernel,
        stride=stride,
        padding=padding,
    )
    return nn.Sequential(deconv, activation)
36
+
37
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input):
        batch = input.size(0)
        return input.view(batch, -1)
40
+
41
class UnFlatten(nn.Module):
    """Reshape a flat per-sample vector into a one-channel 2-D map.

    The map size is (params_height // 8, params_width // 8), read from the
    project parameters on every call. The ``size`` argument is accepted but
    not used.
    """

    def forward(self, input, size=1):
        batch = input.size(0)
        map_height = params.params_height//8
        map_width = params.params_width//8
        return input.view(batch, 1, map_height, map_width)
44
+
45
class ResidualBlock(nn.Module):
    """Two bias-free 3x3 convolutions with an additive shortcut.

    The shortcut is the identity (an empty ``nn.Sequential``) unless the
    stride or the channel count changes; in that case it is a bias-free 1x1
    convolution followed by batch normalisation. Attribute names are part
    of the saved checkpoints' key names and must not be changed.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.stride = stride

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            # Project the input so it can be added to the conv branch output.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        branch = self.conv2(self.relu(self.conv1(x)))
        return self.relu(branch + self.shortcut(x))
71
+
72
class MLPEncoder(nn.Module):
    """Encode a target position into a single-channel spatial feature map.

    The position is passed through a positional encoding and a three-layer
    MLP (with dropout after the first two layers), reshaped by ``UnFlatten``
    to (params_height // 8, params_width // 8) and upsampled by 2 three
    times with nearest-neighbour interpolation.
    """

    def __init__(self):
        super().__init__()
        self.m = params.params_m
        self.positional_encoding = rff.PositionalEncoding(sigma=1,m=self.m)
        # Input width is 2 * 3 * m. NOTE(review): presumably 3 position
        # coordinates times a sin/cos pair per frequency - confirm against
        # the rff package.
        self.layer1 = getLinearLayer(2*3*self.m, 1024)
        self.dropout1 = nn.Dropout(0.2)
        self.layer2 = getLinearLayer(1024, 2048)
        self.dropout2 = nn.Dropout(0.2)
        low_res_pixels = (params.params_height//8)*(params.params_width//8)
        self.layer3 = getLinearLayer(2048, low_res_pixels)
        self.unflat = UnFlatten()

        self.up_layer1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up_layer2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up_layer3 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        encoded = self.positional_encoding(x)

        hidden = self.dropout1(self.layer1(encoded))
        hidden = self.dropout2(self.layer2(hidden))
        feature_map = self.unflat(self.layer3(hidden))

        # Three successive x2 upsamples enlarge the map by a factor of 8.
        for upsample in (self.up_layer1, self.up_layer2, self.up_layer3):
            feature_map = upsample(feature_map)
        return feature_map
104
+
105
class UpperEncoder(nn.Module):
    """Feature extractor built from the first five children of a ResNet-152.

    The backbone is created without pretrained weights. Only the first
    three input channels are fed to it.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet152(pretrained=False)
        leading_children = list(backbone.children())[:5]
        self.ResNetEncoder = torch.nn.Sequential(*leading_children)
        del backbone

    def forward(self, x):
        rgb = x[:, 0:3, :, :]
        return self.ResNetEncoder(rgb)

    def apply_resnet_encoder(self, x):
        # Same computation as forward(); the network classes call this name.
        rgb = x[:, 0:3, :, :]
        return self.ResNetEncoder(rgb)
122
+
123
+
124
class LowerEncoder(nn.Module):
    """Residual encoder over the stacked image and position map.

    Expects ``total_image_input * 3 + 1`` input channels. Four of the stages
    end in a 2x2 max-pool; their outputs are returned as skip connections
    together with the final 500-channel feature.
    """

    def __init__(self,total_image_input=1):
        super().__init__()

        def pooled(in_ch, out_ch):
            # Residual block followed by a 2x2 max-pool that halves H and W.
            return nn.Sequential(
                ResidualBlock(in_ch, out_ch),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )

        self.encoder_pre = ResidualBlock((total_image_input*3)+1, 20)
        self.encoder_layer1 = ResidualBlock(20, 30)
        self.encoder_layer2 = ResidualBlock(30, 50)

        self.encoder_layer3 = pooled(50, 100)

        self.encoder_layer4 = ResidualBlock(100, 200)
        self.encoder_layer5 = pooled(200, 200)

        self.encoder_layer6 = ResidualBlock(200, 200)
        self.encoder_layer7 = pooled(200, 200)

        self.encoder_layer8 = ResidualBlock(200, 500)
        self.encoder_layer9 = pooled(500, 500)

        self.encoder_layer10 = ResidualBlock(500, 500)
        self.encoder_layer11 = ResidualBlock(500, 500)

    def forward(self, x):
        """Return ``(feature, [skip1, skip2, skip3, skip4])``."""
        stem = self.encoder_layer2(self.encoder_layer1(self.encoder_pre(x)))
        skip1 = self.encoder_layer3(stem)
        skip2 = self.encoder_layer5(self.encoder_layer4(skip1))
        skip3 = self.encoder_layer7(self.encoder_layer6(skip2))
        skip4 = self.encoder_layer9(self.encoder_layer8(skip3))
        feature = self.encoder_layer11(self.encoder_layer10(skip4))

        return feature, [skip1, skip2, skip3, skip4]
176
+
177
class MergeDecoder(nn.Module):
    """Decode the encoder feature into a 3-channel, sigmoid-activated image.

    Skip connections from both encoder branches are added element-wise
    between stages, so they must already match the decoder's channel counts.
    Four stride-2 transposed convolutions each double the spatial size.
    """

    def __init__(self):
        super().__init__()

        def upsample(in_ch, out_ch):
            # 2x2 transposed convolution with stride 2, then in-place ReLU.
            return nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2, padding=0),
                nn.ReLU(True),
            )

        self.decoder_layer1 = ResidualBlock(500, 500)
        self.decoder_layer2 = ResidualBlock(500, 500)
        self.decoder_layer3 = ResidualBlock(500, 500)

        self.decoder_layer4 = upsample(500, 200)
        self.decoder_layer5 = ResidualBlock(200, 200)

        self.decoder_layer6 = upsample(200, 200)
        self.decoder_layer7 = ResidualBlock(200, 200)

        self.decoder_layer8 = upsample(200, 100)
        self.decoder_layer9 = ResidualBlock(100, 100)

        self.decoder_layer10 = upsample(100, 100)
        self.decoder_layer11 = ResidualBlock(100, 100)
        self.decoder_layer12 = ResidualBlock(100, 50)
        self.decoder_layer13 = ResidualBlock(50, 40)
        self.decoder_layer14 = ResidualBlock(40, 20)
        self.decoder_layer15 = nn.Sequential(
            nn.Conv2d(20, 8, 3, stride=1, padding=1),
            nn.Sigmoid(),
        )
        self.decoder_layer16 = nn.Sequential(
            nn.Conv2d(8, 3, 3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x, lower_skip_list, upper_skip_list):
        """Decode ``x`` using four lower-branch and two upper-branch skips."""
        x = self.decoder_layer2(self.decoder_layer1(x))
        x = x + lower_skip_list[3] + upper_skip_list[1]

        x = self.decoder_layer4(self.decoder_layer3(x))
        x = x + lower_skip_list[2] + upper_skip_list[0]

        x = self.decoder_layer6(self.decoder_layer5(x))
        x = x + lower_skip_list[1]

        x = self.decoder_layer8(self.decoder_layer7(x))
        x = x + lower_skip_list[0]

        # The remaining stages run back to back without further skips.
        tail = (
            self.decoder_layer9, self.decoder_layer10, self.decoder_layer11,
            self.decoder_layer12, self.decoder_layer13, self.decoder_layer14,
            self.decoder_layer15, self.decoder_layer16,
        )
        for stage in tail:
            x = stage(x)
        return x
246
+
247
+
248
class DepthDecoder(nn.Module):
    """Decode the encoder feature into a 1-channel, ReLU-activated depth map.

    This decoder is wider than the view decoder, so most skip connections
    pass through a "refinement" residual block that raises their channel
    count before they are added. Only the shallowest lower skip is added
    as it is.
    """

    def __init__(self):
        super().__init__()

        def upsample(in_ch, out_ch):
            # 2x2 transposed convolution with stride 2, then in-place ReLU.
            return nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2, padding=0),
                nn.ReLU(True),
            )

        self.decoder_layer1 = ResidualBlock(500, 1400)
        self.decoder_layer2 = ResidualBlock(1400, 1200)
        self.decoder_layer3 = ResidualBlock(1200, 1000)

        self.decoder_layer4 = upsample(1000, 800)
        self.decoder_layer5 = ResidualBlock(800, 600)

        self.decoder_layer6 = upsample(600, 400)
        self.decoder_layer7 = ResidualBlock(400, 200)

        self.decoder_layer8 = upsample(200, 100)
        self.decoder_layer9 = ResidualBlock(100, 100)

        self.decoder_layer10 = upsample(100, 100)
        self.decoder_layer11 = ResidualBlock(100, 100)
        self.decoder_layer12 = ResidualBlock(100, 50)
        self.decoder_layer13 = ResidualBlock(50, 40)
        self.decoder_layer14 = ResidualBlock(40, 20)
        self.decoder_layer15 = nn.Sequential(
            nn.Conv2d(20, 8, 3, stride=1, padding=1),
            nn.ReLU(True),
        )
        self.decoder_layer16 = nn.Sequential(
            nn.Conv2d(8, 1, 3, stride=1, padding=1),
            nn.ReLU(True),
        )

        # Channel adapters for the skip connections of the upper branch ...
        self.up_refinement_0 = ResidualBlock(200, 800)
        self.up_refinement_1 = ResidualBlock(500, 1200)

        # ... and of the lower branch.
        self.low_refinement_1 = ResidualBlock(200, 400)
        self.low_refinement_2 = ResidualBlock(200, 800)
        self.low_refinement_3 = ResidualBlock(500, 1200)

    def forward(self, x, lower_skip_list, upper_skip_list):
        """Decode ``x`` using four lower-branch and two upper-branch skips."""
        x = self.decoder_layer2(self.decoder_layer1(x))

        low_skip_3 = self.low_refinement_3(lower_skip_list[3])
        up_skip_1 = self.up_refinement_1(upper_skip_list[1])
        x = x + low_skip_3 + up_skip_1

        x = self.decoder_layer4(self.decoder_layer3(x))

        low_skip_2 = self.low_refinement_2(lower_skip_list[2])
        up_skip_0 = self.up_refinement_0(upper_skip_list[0])
        x = x + low_skip_2 + up_skip_0

        x = self.decoder_layer6(self.decoder_layer5(x))

        low_skip_1 = self.low_refinement_1(lower_skip_list[1])
        x = x + low_skip_1

        x = self.decoder_layer8(self.decoder_layer7(x))
        x = x + lower_skip_list[0]

        # The remaining stages run back to back without further skips.
        tail = (
            self.decoder_layer9, self.decoder_layer10, self.decoder_layer11,
            self.decoder_layer12, self.decoder_layer13, self.decoder_layer14,
            self.decoder_layer15, self.decoder_layer16,
        )
        for stage in tail:
            x = stage(x)
        return x
334
+
335
+
336
class PVSNet_Lite(nn.Module):
    """View-synthesis network: two encoder branches and one image decoder.

    The upper branch runs the ResNet-based encoder plus two pooled residual
    stages on the image; the lower branch encodes the image concatenated
    with the position feature map. Both feed the merge decoder.
    """

    def __init__(self,total_image_input=1):
        super().__init__()
        self.target_positional_embedding = MLPEncoder()
        self.upper_encoder = UpperEncoder()
        self.lower_encoder = LowerEncoder(total_image_input)
        self.merge_decoder = MergeDecoder()

        # Bring the upper-branch features to 200 and then 500 channels,
        # halving the resolution each time, for use as decoder skips.
        self.upper_encoder_extra_1 = nn.Sequential(
            ResidualBlock(256, 200),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.upper_encoder_extra_2 = nn.Sequential(
            ResidualBlock(200, 500),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x, pos):
        """Return the synthesised view for image ``x`` at position ``pos``."""
        position_map = self.target_positional_embedding(pos)

        # Upper branch: image only.
        upper_skip_0 = self.upper_encoder_extra_1(self.upper_encoder.apply_resnet_encoder(x))
        upper_skip_1 = self.upper_encoder_extra_2(upper_skip_0)

        # Lower branch: image stacked with the position map along channels.
        lower_input = torch.cat((x,position_map),dim=1)
        lower_feature, lower_skips = self.lower_encoder(lower_input)

        return self.merge_decoder(lower_feature, lower_skips, [upper_skip_0, upper_skip_1])
368
+
369
+
370
+
371
+
372
class PVSDNet_Lite(nn.Module):
    """Joint view-synthesis and depth network (lite variant).

    Reuses every sub-network of a ``PVSNet_Lite`` and adds a
    ``DepthDecoder`` that reads the same encoder features and skips.

    Each shared sub-network is built exactly once. Earlier code built its
    own copies first (including a second ResNet-based encoder) and then
    replaced them with the base network's modules. Attribute names and
    their registration order are unchanged, so state-dict keys are the
    same as before.
    """

    def __init__(self,total_image_input=1):
        super().__init__()

        print("Loading pre-trained nvs net")
        base_net = PVSNet_Lite(total_image_input)
        # A pre-trained PVSNet checkpoint can be loaded into base_net here:
        #base_net = helper.load_Checkpoint("./checkpoint/checkpoint_init_pvsnet.pth", base_net, load_cpu=True)

        self.target_positional_embedding = base_net.target_positional_embedding
        self.upper_encoder = base_net.upper_encoder
        self.lower_encoder = base_net.lower_encoder
        self.merge_decoder = base_net.merge_decoder
        self.depth_decoder = DepthDecoder()
        self.upper_encoder_extra_1 = base_net.upper_encoder_extra_1
        self.upper_encoder_extra_2 = base_net.upper_encoder_extra_2
        del base_net
        print("Loading pre-trained nvs net: Done")

    def forward(self, x, pos):
        """Return ``(view, depth)`` for image ``x`` at position ``pos``."""
        target_position_feature = self.target_positional_embedding(pos)

        # First Encoder Branch
        upper_features_1 = self.upper_encoder.apply_resnet_encoder(x)
        upper_features_1 = self.upper_encoder_extra_1(upper_features_1)
        upper_features_2 = self.upper_encoder_extra_2(upper_features_1)

        # Second Encoder Branch
        stacked_tensor = torch.cat((x,target_position_feature),dim=1)
        lower_feature, skip_list = self.lower_encoder(stacked_tensor)

        upper_skips = [upper_features_1, upper_features_2]

        # Decoder
        merged_feature = self.merge_decoder(lower_feature, skip_list, upper_skips)

        # Depth Decoder
        depth_feature = self.depth_decoder(lower_feature, skip_list, upper_skips)
        return merged_feature, depth_feature
models/pvsdnet_model.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import warnings
8
+ warnings.filterwarnings("ignore")
9
+ import torchvision
10
+ import rff.layers as rff
11
+ import parameters_pvsdnet as params
12
+ import helperFunctions as helper
13
+
14
def getLinearLayer(in_feat, out_feat, activation=nn.ReLU(True)):
    """Return ``nn.Sequential(Linear(in_feat, out_feat), activation)``.

    The default activation is a single in-place ReLU instance created when
    the function is defined, so it is shared by every layer built without
    an explicit ``activation``.
    """
    linear = nn.Linear(in_features=in_feat, out_features=out_feat, bias=True)
    return nn.Sequential(linear, activation)
19
+
20
def getConvLayer(in_channel,out_channel,stride=1,padding=1,activation=nn.ReLU()):
    """Return a 3x3 reflect-padded ``Conv2d`` followed by ``activation``.

    The default activation instance is created once, at definition time,
    and shared by all layers that do not pass their own.
    """
    conv = nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size=3,
        stride=stride,
        padding=padding,
        padding_mode='reflect',
    )
    return nn.Sequential(conv, activation)
28
+
29
def getConvTransposeLayer(in_channel, out_channel,kernel=3,stride=1,padding=1,activation=nn.ReLU()):
    """Return a ``ConvTranspose2d`` followed by ``activation``.

    The default activation instance is created once, at definition time,
    and shared by all layers that do not pass their own.
    """
    deconv = nn.ConvTranspose2d(
        in_channel,
        out_channel,
        kernel_size=kernel,
        stride=stride,
        padding=padding,
    )
    return nn.Sequential(deconv, activation)
36
+
37
+
38
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input):
        batch = input.size(0)
        return input.view(batch, -1)
41
+
42
class UnFlatten(nn.Module):
    """Reshape a flat per-sample vector into a one-channel 2-D map.

    The map size is (params_height // 8, params_width // 8), read from the
    project parameters on every call. The ``size`` argument is accepted but
    not used.
    """

    def forward(self, input, size=1):
        batch = input.size(0)
        map_height = params.params_height//8
        map_width = params.params_width//8
        return input.view(batch, 1, map_height, map_width)
45
+
46
class ResidualBlock(nn.Module):
    """Two bias-free 3x3 convolutions with an additive shortcut.

    The shortcut is the identity (an empty ``nn.Sequential``) unless the
    stride or the channel count changes; in that case it is a bias-free 1x1
    convolution followed by batch normalisation. Attribute names are part
    of the saved checkpoints' key names and must not be changed.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.stride = stride

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            # Project the input so it can be added to the conv branch output.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        branch = self.conv2(self.relu(self.conv1(x)))
        return self.relu(branch + self.shortcut(x))
72
+
73
class MLPEncoder(nn.Module):
    """Encode a target position into a single-channel spatial feature map.

    The position is passed through a positional encoding and a three-layer
    MLP (with dropout after the first two layers), reshaped by ``UnFlatten``
    to (params_height // 8, params_width // 8) and upsampled by 2 three
    times with nearest-neighbour interpolation.
    """

    def __init__(self):
        super().__init__()
        self.m = params.params_m
        self.positional_encoding = rff.PositionalEncoding(sigma=1,m=self.m)
        # Input width is 2 * 3 * m. When training with rotation data, set it
        # to 2 * 6 * m. NOTE(review): presumably 3 position coordinates
        # times a sin/cos pair per frequency - confirm against the rff
        # package.
        self.layer1 = getLinearLayer(2*3*self.m, 1024)
        self.dropout1 = nn.Dropout(0.2)
        self.layer2 = getLinearLayer(1024, 2048)
        self.dropout2 = nn.Dropout(0.2)
        low_res_pixels = (params.params_height//8)*(params.params_width//8)
        self.layer3 = getLinearLayer(2048, low_res_pixels)
        self.unflat = UnFlatten()

        self.up_layer1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up_layer2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up_layer3 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        encoded = self.positional_encoding(x)

        hidden = self.dropout1(self.layer1(encoded))
        hidden = self.dropout2(self.layer2(hidden))
        feature_map = self.unflat(self.layer3(hidden))

        # Three successive x2 upsamples enlarge the map by a factor of 8.
        for upsample in (self.up_layer1, self.up_layer2, self.up_layer3):
            feature_map = upsample(feature_map)
        return feature_map
105
+
106
class UpperEncoder(nn.Module):
    """Feature extractor built from the first five children of a ResNet-152.

    The backbone is created without pretrained weights. Only the first
    three input channels are fed to it.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet152(pretrained=False)
        leading_children = list(backbone.children())[:5]
        self.ResNetEncoder = torch.nn.Sequential(*leading_children)
        del backbone

    def forward(self, x):
        rgb = x[:, 0:3, :, :]
        return self.ResNetEncoder(rgb)

    def apply_resnet_encoder(self, x):
        # Same computation as forward(), kept as a separately named method.
        rgb = x[:, 0:3, :, :]
        return self.ResNetEncoder(rgb)
123
+
124
class LowerEncoder(nn.Module):
    """Residual encoder over the stacked image and position map.

    Expects ``total_image_input * 3 + 1`` input channels. Four of the stages
    end in a 2x2 max-pool; their outputs are returned as skip connections
    together with the final 1600-channel feature.
    """

    def __init__(self,total_image_input=1):
        super().__init__()

        def pooled(in_ch, out_ch):
            # Residual block followed by a 2x2 max-pool that halves H and W.
            return nn.Sequential(
                ResidualBlock(in_ch, out_ch),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )

        self.encoder_pre = ResidualBlock((total_image_input*3)+1, 20)
        self.encoder_layer1 = ResidualBlock(20, 30)
        self.encoder_layer2 = ResidualBlock(30, 50)

        self.encoder_layer3 = pooled(50, 100)

        self.encoder_layer4 = ResidualBlock(100, 200)
        self.encoder_layer5 = pooled(200, 400)

        self.encoder_layer6 = ResidualBlock(400, 600)
        self.encoder_layer7 = pooled(600, 800)

        self.encoder_layer8 = ResidualBlock(800, 1000)
        self.encoder_layer9 = pooled(1000, 1200)

        self.encoder_layer10 = ResidualBlock(1200, 1400)
        self.encoder_layer11 = ResidualBlock(1400, 1600)

    def forward(self, x):
        """Return ``(feature, [skip1, skip2, skip3, skip4])``."""
        stem = self.encoder_layer2(self.encoder_layer1(self.encoder_pre(x)))
        skip1 = self.encoder_layer3(stem)
        skip2 = self.encoder_layer5(self.encoder_layer4(skip1))
        skip3 = self.encoder_layer7(self.encoder_layer6(skip2))
        skip4 = self.encoder_layer9(self.encoder_layer8(skip3))
        feature = self.encoder_layer11(self.encoder_layer10(skip4))

        return feature, [skip1, skip2, skip3, skip4]
176
+
177
class MergeDecoder(nn.Module):
    """Decoder that maps the 1600-channel bottleneck to a 3-channel output.

    Sixteen stages: residual blocks interleaved with four stride-2 transposed
    convolutions (each doubling the spatial size), followed by two 3x3
    convolutions with sigmoid activations. Skip tensors from both encoder
    branches are added after stages 2, 4, 6 and 8.
    """

    def __init__(self):
        super().__init__()

        def upsample(in_channels, out_channels):
            # 2x2 transposed convolution with stride 2, then in-place ReLU.
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2, padding=0),
                nn.ReLU(True),
            )

        def sigmoid_conv(in_channels, out_channels):
            # Size-preserving 3x3 convolution squashed into (0, 1).
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
                nn.Sigmoid(),
            )

        self.decoder_layer1 = ResidualBlock(1600, 1400)
        self.decoder_layer2 = ResidualBlock(1400, 1200)
        self.decoder_layer3 = ResidualBlock(1200, 1000)

        self.decoder_layer4 = upsample(1000, 800)
        self.decoder_layer5 = ResidualBlock(800, 600)

        self.decoder_layer6 = upsample(600, 400)
        self.decoder_layer7 = ResidualBlock(400, 200)

        self.decoder_layer8 = upsample(200, 100)
        self.decoder_layer9 = ResidualBlock(100, 100)

        self.decoder_layer10 = upsample(100, 100)
        self.decoder_layer11 = ResidualBlock(100, 100)
        self.decoder_layer12 = ResidualBlock(100, 50)
        self.decoder_layer13 = ResidualBlock(50, 40)
        self.decoder_layer14 = ResidualBlock(40, 20)
        self.decoder_layer15 = sigmoid_conv(20, 8)
        self.decoder_layer16 = sigmoid_conv(8, 3)

    def forward(self, x, lower_skip_list, upper_skip_list):
        """Decode ``x`` while adding the encoder skip tensors.

        Args:
            x: Bottleneck feature fed to ``decoder_layer1``.
            lower_skip_list: Four skip tensors; index 3 is consumed first and
                index 0 last.
            upper_skip_list: Two skip tensors; index 1 is added together with
                ``lower_skip_list[3]`` and index 0 with ``lower_skip_list[2]``.

        Returns:
            The output of ``decoder_layer16``.
        """
        x = self.decoder_layer2(self.decoder_layer1(x))
        x = x + lower_skip_list[3] + upper_skip_list[1]

        x = self.decoder_layer4(self.decoder_layer3(x))
        x = x + lower_skip_list[2] + upper_skip_list[0]

        x = self.decoder_layer6(self.decoder_layer5(x))
        x = x + lower_skip_list[1]

        x = self.decoder_layer8(self.decoder_layer7(x))
        x = x + lower_skip_list[0]

        # Stages 9-16 run back to back without any further skip additions.
        for index in range(9, 17):
            x = getattr(self, f"decoder_layer{index}")(x)
        return x
246
+
247
class DepthDecoder(nn.Module):
    """Decoder that maps the 1600-channel bottleneck to a 1-channel output.

    The layout matches the 3-channel decoder stage for stage, except that the
    two final 3x3 convolutions use ReLU activations and the last one emits a
    single channel, so the output is non-negative and unbounded above.
    """

    def __init__(self):
        super().__init__()

        def upsample(in_channels, out_channels):
            # 2x2 transposed convolution with stride 2, then in-place ReLU.
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2, padding=0),
                nn.ReLU(True),
            )

        def relu_conv(in_channels, out_channels):
            # Size-preserving 3x3 convolution followed by in-place ReLU.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
                nn.ReLU(True),
            )

        self.decoder_layer1 = ResidualBlock(1600, 1400)
        self.decoder_layer2 = ResidualBlock(1400, 1200)
        self.decoder_layer3 = ResidualBlock(1200, 1000)

        self.decoder_layer4 = upsample(1000, 800)
        self.decoder_layer5 = ResidualBlock(800, 600)

        self.decoder_layer6 = upsample(600, 400)
        self.decoder_layer7 = ResidualBlock(400, 200)

        self.decoder_layer8 = upsample(200, 100)
        self.decoder_layer9 = ResidualBlock(100, 100)

        self.decoder_layer10 = upsample(100, 100)
        self.decoder_layer11 = ResidualBlock(100, 100)
        self.decoder_layer12 = ResidualBlock(100, 50)
        self.decoder_layer13 = ResidualBlock(50, 40)
        self.decoder_layer14 = ResidualBlock(40, 20)
        self.decoder_layer15 = relu_conv(20, 8)
        self.decoder_layer16 = relu_conv(8, 1)

    def forward(self, x, lower_skip_list, upper_skip_list):
        """Decode ``x`` while adding the encoder skip tensors.

        Args:
            x: Bottleneck feature fed to ``decoder_layer1``.
            lower_skip_list: Four skip tensors, consumed from index 3 down to
                index 0.
            upper_skip_list: Two skip tensors, consumed as index 1 then
                index 0, alongside the first two lower skips.

        Returns:
            The output of ``decoder_layer16``.
        """
        # Stages 1-4 additionally receive a skip from the upper branch.
        x = self.decoder_layer1(x)
        x = self.decoder_layer2(x)
        x = x + lower_skip_list[3] + upper_skip_list[1]

        x = self.decoder_layer3(x)
        x = self.decoder_layer4(x)
        x = x + lower_skip_list[2] + upper_skip_list[0]

        # Stages 5-8 only receive skips from the lower branch.
        x = self.decoder_layer6(self.decoder_layer5(x)) + lower_skip_list[1]
        x = self.decoder_layer8(self.decoder_layer7(x)) + lower_skip_list[0]

        tail = (
            self.decoder_layer9,
            self.decoder_layer10,
            self.decoder_layer11,
            self.decoder_layer12,
            self.decoder_layer13,
            self.decoder_layer14,
            self.decoder_layer15,
            self.decoder_layer16,
        )
        for layer in tail:
            x = layer(x)
        return x
316
+
317
+
318
class PVSNet(nn.Module):
    """Two-branch encoder with a single decoder.

    The upper branch runs the truncated ResNet plus two pooled residual
    blocks; the lower branch encodes the input concatenated with the target
    position feature. ``merge_decoder`` combines both.

    Args:
        total_image_input: Number of stacked input images, forwarded to
            ``LowerEncoder``.
    """

    def __init__(self,total_image_input=1):
        super().__init__()
        self.target_positional_embedding = MLPEncoder()
        self.upper_encoder = UpperEncoder()
        self.lower_encoder = LowerEncoder(total_image_input)
        self.merge_decoder = MergeDecoder()

        # Widen the 256-channel ResNet feature to the widths of the two
        # decoder skips it is added to, halving the resolution each time.
        self.upper_encoder_extra_1 = nn.Sequential(
            ResidualBlock(256, 800),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.upper_encoder_extra_2 = nn.Sequential(
            ResidualBlock(800, 1200),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x, pos):
        """Run the network for input ``x`` and target position ``pos``.

        Args:
            x: Input tensor; it is concatenated with the position feature
                along dimension 1.
            pos: Target position, passed to ``target_positional_embedding``.

        Returns:
            The output of ``merge_decoder``.
        """
        position_map = self.target_positional_embedding(pos)

        # Upper branch: ResNet stem, then two pooled residual blocks.
        resnet_feature = self.upper_encoder.apply_resnet_encoder(x)
        upper_skip_a = self.upper_encoder_extra_1(resnet_feature)
        upper_skip_b = self.upper_encoder_extra_2(upper_skip_a)

        # Lower branch: input and position feature stacked along channels.
        lower_input = torch.cat((x, position_map), dim=1)
        bottleneck, lower_skips = self.lower_encoder(lower_input)

        return self.merge_decoder(bottleneck, lower_skips, [upper_skip_a, upper_skip_b])
351
+
352
class PVSDNet(nn.Module):
    """PVSNet extended with a second decoder head.

    All encoder parts and the merge decoder are taken from a ``PVSNet``
    instance; a ``DepthDecoder`` is added that consumes exactly the same
    bottleneck feature and skip tensors as the merge decoder.

    Args:
        total_image_input: Number of stacked input images, forwarded to the
            underlying ``PVSNet``.
    """

    def __init__(self,total_image_input=1):
        super().__init__()

        print("Loading pre-trained nvs net")
        # Build the shared sub-networks exactly once, through PVSNet. They
        # used to be constructed here as well (including a full ResNet-152)
        # only to be replaced immediately by the PVSNet ones.
        base_net = PVSNet(total_image_input)
        # To start from a pre-trained PVSNet (e.g. to train with LoRA), load
        # its checkpoint into ``base_net`` at this point:
        # base_net = helper.load_Checkpoint("./checkpoint/checkpoint_init_pvsnet.pth", base_net, load_cpu=True)

        # Registration order is kept identical to the previous implementation
        # so that ``state_dict()`` and ``parameters()`` ordering is unchanged.
        self.target_positional_embedding = base_net.target_positional_embedding
        self.upper_encoder = base_net.upper_encoder
        self.lower_encoder = base_net.lower_encoder
        self.merge_decoder = base_net.merge_decoder
        self.depth_decoder = DepthDecoder()
        self.upper_encoder_extra_1 = base_net.upper_encoder_extra_1
        self.upper_encoder_extra_2 = base_net.upper_encoder_extra_2
        del base_net
        print("Loading pre-trained nvs net: Done")

    def forward(self, x, pos):
        """Run both decoder heads for input ``x`` and target position ``pos``.

        Args:
            x: Input tensor; it is concatenated with the position feature
                along dimension 1.
            pos: Target position, passed to ``target_positional_embedding``.

        Returns:
            A tuple ``(merged_feature, depth_feature)`` holding the outputs of
            ``merge_decoder`` and ``depth_decoder`` respectively.
        """
        target_position_feature = self.target_positional_embedding(pos)

        # First encoder branch: ResNet stem plus two pooled residual blocks.
        upper_features_1 = self.upper_encoder.apply_resnet_encoder(x)
        upper_features_1 = self.upper_encoder_extra_1(upper_features_1)
        upper_features_2 = self.upper_encoder_extra_2(upper_features_1)
        upper_skips = [upper_features_1, upper_features_2]

        # Second encoder branch: input stacked with the position feature.
        stacked_tensor = torch.cat((x, target_position_feature), dim=1)
        lower_feature, skip_list = self.lower_encoder(stacked_tensor)

        # Both heads decode the same bottleneck with the same skips.
        merged_feature = self.merge_decoder(lower_feature, skip_list, upper_skips)
        depth_feature = self.depth_decoder(lower_feature, skip_list, upper_skips)
        return merged_feature, depth_feature
parameters_pvsdnet.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# --- Model / data geometry -------------------------------------------------
params_height = 256
params_width = 256
params_m = 32            # multiplier in the MLP encoder's input width (2*3*m)
params_number_input = 1
# NOTE(review): step size / gamma look like learning-rate scheduler settings;
# confirm against the training script.
params_step_size = 5
params_gamma = 0.2

# --- File locations --------------------------------------------------------
TRAIN_LOCATION = "./lf_train.txt"
VALIDATION_LOCATION = "./lf_validate.txt"
TEST_LOCATION = "./lf_test.txt"
LOG_FILE_LOCATION = "./logs/training_log_0.txt"
CHECKPOINT_LOCATION = "./checkpoint/"
RESUME_CHECKPOINT_LOCATION = "./checkpoint/checkpoint_best.pth"
START_CHECKPOINT_LOCATION = "./checkpoint/checkpoint_init.pth"
DEVICE = "cpu"

# --- Training schedule -----------------------------------------------------
BATCH_SIZE = 24
LEARNING_RATE = 1e-5
NUM_EPOCHS = 50
START_EPOCH = 0
T_max = 50
PRINT_INTERVAL = 20

# Importing this module guarantees that the working directories exist.
for _directory in ("./logs", "./checkpoint", "./output"):
    os.makedirs(_directory, exist_ok=True)
29
+
30
+
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ # tempfile is part of the Python standard library and must not be listed as a pip requirement
3
+ torch==2.9.1
4
+ torchvision==0.24.1
5
+ pytorch-msssim==1.0.0
6
+ pytorchvideo==0.1.5
7
+ gradio==6.2.0
8
+ gradio_client==2.0.2
9
+ opencv-python==4.6.0.66
10
+ pillow==10.4.0
11
+ pillow_heif==0.15.0
12
+ matplotlib==3.10.8
13
+ matplotlib-inline==0.1.6
14
+ tqdm==4.65.0
15
+ moviepy==1.0.3
16
+ scikit-image==0.26.0
17
+ scikit-learn==1.8.0
18
+ scipy==1.11.4