ReaganWZY committed on
Commit
5acc7ae
·
verified ·
1 Parent(s): b7b2f5c

Upload DepthPolyp model artifacts

Browse files
.gitattributes CHANGED
@@ -1,35 +1,7 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.pth filter=lfs diff=lfs merge=lfs -text
2
+ *.onnx filter=lfs diff=lfs merge=lfs -text
3
+ *.gif filter=lfs diff=lfs merge=lfs -text
4
+ *.png filter=lfs diff=lfs merge=lfs -text
5
+ *.jpg filter=lfs diff=lfs merge=lfs -text
6
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
7
+
 
 
 
 
 
 
DepthPolyp_Kvasir.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:883ff8a825a5f51f59d46a9a2c2e9f0a505519140495dfa6800e6b48297c9f5b
3
+ size 14588196
DepthPolyp_Kvasir.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2bfad787ccc259b25bb28ee77ec39c4ae4a579aba971facc3e579aa8debd6257
3
+ size 14410152
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 ZHUOYU WU
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ library_name: pytorch
4
+ pipeline_tag: image-segmentation
5
+ tags:
6
+ - medical-image-segmentation
7
+ - image-segmentation
8
+ - semantic-segmentation
9
+ - polyp-segmentation
10
+ - colonoscopy
11
+ - depth-estimation
12
+ - pseudo-depth
13
+ - real-time
14
+ - onnx
15
+ - pytorch
16
+ - arxiv:2605.16519
17
+ metrics:
18
+ - dice
19
+ - iou
20
+ - recall
21
+ ---
22
+
23
+ # DepthPolyp: Pseudo-Depth Guided Lightweight Segmentation for Real-Time Colonoscopy
24
+
25
+ DepthPolyp is a lightweight pseudo-depth guided model for real-time colonoscopic polyp segmentation. Given an RGB colonoscopy frame, it jointly predicts:
26
+
27
+ 1. a binary polyp segmentation probability map
28
+ 2. a pseudo-depth probability map for depth-aware structural guidance
29
+
30
+ The model uses a MiT-B0 encoder and lightweight fusion/gating modules to keep deployment cost low while improving robustness under blur, illumination changes, reflections, and other real-world colonoscopy degradations.
31
+
32
+ - Paper: [arXiv:2605.16519](https://arxiv.org/abs/2605.16519)
33
+ - Code: [github.com/ReaganWu/DepthPolyp](https://github.com/ReaganWu/DepthPolyp)
34
+ - License: MIT
35
+
36
+ ## Model Details
37
+
38
+ | Item | Value |
39
+ | --- | --- |
40
+ | Model | DepthPolyp |
41
+ | Encoder | MiT-B0 |
42
+ | Input | RGB image, 224 x 224 |
43
+ | Outputs | segmentation, pseudo-depth |
44
+ | Parameters | 3.57M |
45
+ | Complexity | 0.86 GMACs |
46
+ | Training data | Kvasir-SEG with degradation-aware training |
47
+ | PyTorch checkpoint | `DepthPolyp_Kvasir.pth` |
48
+ | ONNX checkpoint | `DepthPolyp_Kvasir.onnx` |
49
+
50
+ ONNX I/O names:
51
+
52
+ ```text
53
+ input: image
54
+ outputs: segmentation, depth
55
+ ```
56
+
57
+ ## Intended Use
58
+
59
+ DepthPolyp is intended for research on colonoscopic polyp segmentation, lightweight medical image segmentation, robustness under endoscopic video degradation, and deployment-oriented model comparison.
60
+
61
+ This model is not a standalone medical device and is not intended for clinical diagnosis without appropriate validation, regulatory review, and expert oversight.
62
+
63
+ ## Quick Start: ONNX Runtime
64
+
65
+ ```bash
66
+ pip install onnxruntime pillow numpy
67
+
68
+ python scripts/infer_onnx.py \
69
+ --onnx DepthPolyp_Kvasir.onnx \
70
+ --input samples/kvasir/images \
71
+ --output outputs
72
+ ```
73
+
74
+ The script writes binary masks, pseudo-depth visualizations, and mask overlays.
75
+
76
+ ## Quick Start: PyTorch
77
+
78
+ ```bash
79
+ pip install torch torchvision pillow numpy
80
+ ```
81
+
82
+ ```python
83
+ import torch
84
+ from PIL import Image
85
+ from torchvision import transforms
86
+
87
+ from model.depthpolyp import build_depthpolyp
88
+
89
+ device = "cuda" if torch.cuda.is_available() else "cpu"
90
+
91
+ model = build_depthpolyp(
92
+ encoder_name="b0",
93
+ in_channels=3,
94
+ num_classes=2,
95
+ decoder_channels=256,
96
+ activation=None,
97
+ )
98
+ state_dict = torch.load("DepthPolyp_Kvasir.pth", map_location="cpu", weights_only=True)
99
+ model.load_state_dict(state_dict, strict=True)
100
+ model.to(device).eval()
101
+
102
+ image = Image.open("samples/kvasir/images/sample_01.jpg").convert("RGB")
103
+ transform = transforms.Compose([
104
+ transforms.Resize((224, 224)),
105
+ transforms.ToTensor(),
106
+ ])
107
+ x = transform(image).unsqueeze(0).to(device)
108
+
109
+ with torch.no_grad():
110
+ seg_prob, depth_prob = model(x)
111
+
112
+ print(seg_prob.shape) # [1, 1, 224, 224]
113
+ print(depth_prob.shape) # [1, 1, 224, 224]
114
+ ```
115
+
116
+ ## Loading Files with `huggingface_hub`
117
+
118
+ ```python
119
+ from huggingface_hub import hf_hub_download
120
+
121
+ repo_id = "ReaganWZY/DepthPolyp"
122
+ pth_path = hf_hub_download(repo_id=repo_id, filename="DepthPolyp_Kvasir.pth")
123
+ onnx_path = hf_hub_download(repo_id=repo_id, filename="DepthPolyp_Kvasir.onnx")
124
+ ```
125
+
126
+ If you publish under a different Hugging Face repo id, replace `ReaganWZY/DepthPolyp` with that id.
127
+
128
+ ## Evaluation
129
+
130
+ Paper-reported reference results:
131
+
132
+ | Protocol | Kvasir Dice/IoU/Recall | ClinicDB Dice/IoU/Recall | ColonDB Dice/IoU/Recall |
133
+ | --- | --- | --- | --- |
134
+ | `N->C` | 0.891 / 0.805 / 0.885 | 0.854 / 0.748 / 0.845 | 0.801 / 0.669 / 0.759 |
135
+ | `N->N` | 0.853 / 0.745 / 0.854 | 0.751 / 0.608 / 0.759 | 0.734 / 0.582 / 0.697 |
136
+
137
+ Real-world robustness and deployment results from the paper:
138
+
139
+ | Params | GMACs | Avg. Dice | PolypGen Dice | iPhone FPS | Raspberry Pi 4 FPS |
140
+ | ---: | ---: | ---: | ---: | ---: | ---: |
141
+ | 3.57M | 0.86 | 0.779 | 0.679 | 181.54 | 4.05 |
142
+
143
+ ## Training Data and Protocol
144
+
145
+ The released checkpoint is trained on Kvasir-SEG with degradation-aware training. Pseudo-depth targets are generated with Depth-Anything v2 Small and are used only during training; depth targets are not required at inference time.
146
+
147
+ Reference training settings from the paper:
148
+
149
+ - Input resolution: 224 x 224
150
+ - Optimizer: AdamW
151
+ - Learning rate: 1e-4
152
+ - Weight decay: 1e-4
153
+ - Batch size: 16
154
+ - Epochs: 200
155
+ - Schedule: 10% warm-up followed by cosine annealing
156
+
157
+ ## Citation
158
+
159
+ ```bibtex
160
+ @misc{wu2026depthpolyp,
161
+ title={DepthPolyp: Pseudo-Depth Guided Lightweight Segmentation for Real-Time Colonoscopy},
162
+ author={Wu, Zhuoyu and Ou, Wenhui and Zhang, Lexi and Tan, Pei-Sze and Wu, Dongjun and Zhao, Junhe and Fang, Wenqi and Phan, Raphaël C.-W.},
163
+ year={2026},
164
+ eprint={2605.16519},
165
+ archivePrefix={arXiv},
166
+ primaryClass={cs.CV}
167
+ }
168
+ ```
assets/depthpolyp_architecture.png ADDED

Git LFS Details

  • SHA256: 164e1f204f551b849e0d30f9633840d899df5382a5686050ece67228763d10d6
  • Pointer size: 131 Bytes
  • Size of remote file: 683 kB
assets/seq19.gif ADDED

Git LFS Details

  • SHA256: b5bf1f06e43007f48de9c17f455f370be4b11a8e41aacda55f9f12e60146087a
  • Pointer size: 132 Bytes
  • Size of remote file: 4.04 MB
assets/seq22.gif ADDED

Git LFS Details

  • SHA256: af9cb8212c7a8a33e782d4c75678c85831309eb78743b7afa479bf991f7eebc3
  • Pointer size: 132 Bytes
  • Size of remote file: 3.4 MB
config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "DepthPolyp"
4
+ ],
5
+ "model_type": "depthpolyp",
6
+ "encoder_name": "b0",
7
+ "in_channels": 3,
8
+ "num_classes": 2,
9
+ "decoder_channels": 256,
10
+ "activation": null,
11
+ "image_size": 224,
12
+ "outputs": [
13
+ "segmentation",
14
+ "depth"
15
+ ],
16
+ "training_dataset": "Kvasir-SEG",
17
+ "paper": "https://arxiv.org/abs/2605.16519",
18
+ "github": "https://github.com/ReaganWu/DepthPolyp"
19
+ }
20
+
model/depthpolyp.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .modules.HF_Decoder import HiF_Decoder
5
+ from .modules.MiT_Encoder import MixVisionTransformer
6
+ from .modules.Seg_Head import SegmentationHead
7
+
8
class DepthPolyp(nn.Module):
    """Two-output network: polyp segmentation plus pseudo-depth.

    The model chains a ``MixVisionTransformer`` encoder, a ``HiF_Decoder`` and
    one ``SegmentationHead`` that emits ``num_classes`` channels.  ``forward``
    reads channel 0 as the segmentation map and channel 1 as the pseudo-depth
    map, so ``num_classes`` is expected to be at least 2.

    Args:
        in_channels: Number of input image channels.
        num_classes: Number of head output channels (1 for seg, 1 for depth).
        encoder_name: Encoder variant, one of ``'b0'`` ... ``'b5'``.
        decoder_channels: Channel width given to the decoder; the head is
            built with ``decoder_channels // 4`` input channels.
        activation: Forwarded unchanged to ``SegmentationHead``.
            NOTE(review): ``forward`` already applies a sigmoid to both
            channels, so a non-None value here presumably results in two
            stacked activations -- confirm against ``SegmentationHead``.
        upsampling: Upsampling factor forwarded to ``SegmentationHead``.

    Raises:
        ValueError: If ``encoder_name`` is not a known variant.
    """

    # Encoder hyper-parameters per variant.  Kept at class level so the table
    # is built once; __init__ copies the lists before handing them out.
    _ENCODER_CONFIGS = {
        'b0': {
            'embed_dims': [32, 64, 160, 256],
            'num_heads': [1, 2, 5, 8],
            'mlp_ratios': [4, 4, 4, 4],
            'depths': [2, 2, 2, 2],
            'sr_ratios': [8, 4, 2, 1],
        },
        'b1': {
            'embed_dims': [64, 128, 320, 512],
            'num_heads': [1, 2, 5, 8],
            'mlp_ratios': [4, 4, 4, 4],
            'depths': [2, 2, 2, 2],
            'sr_ratios': [8, 4, 2, 1],
        },
        'b2': {
            'embed_dims': [64, 128, 320, 512],
            'num_heads': [1, 2, 5, 8],
            'mlp_ratios': [4, 4, 4, 4],
            'depths': [3, 4, 6, 3],
            'sr_ratios': [8, 4, 2, 1],
        },
        'b3': {
            'embed_dims': [64, 128, 320, 512],
            'num_heads': [1, 2, 5, 8],
            'mlp_ratios': [4, 4, 4, 4],
            'depths': [3, 4, 18, 3],
            'sr_ratios': [8, 4, 2, 1],
        },
        'b4': {
            'embed_dims': [64, 128, 320, 512],
            'num_heads': [1, 2, 5, 8],
            'mlp_ratios': [4, 4, 4, 4],
            'depths': [3, 8, 27, 3],
            'sr_ratios': [8, 4, 2, 1],
        },
        'b5': {
            'embed_dims': [64, 128, 320, 512],
            'num_heads': [1, 2, 5, 8],
            'mlp_ratios': [4, 4, 4, 4],
            'depths': [3, 6, 40, 3],
            'sr_ratios': [8, 4, 2, 1],
        },
    }

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 2,  # 1 for seg, 1 for depth
        encoder_name: str = 'b0',
        decoder_channels: int = 256,
        activation: str = None,
        upsampling: int = 4,
    ):
        super().__init__()

        if encoder_name not in self._ENCODER_CONFIGS:
            raise ValueError(f"encoder_name should be one of {list(self._ENCODER_CONFIGS.keys())}, got {encoder_name}")

        # Copy the lists so sub-modules never share (or mutate) the class table.
        config = {key: list(value) for key, value in self._ENCODER_CONFIGS[encoder_name].items()}

        # Build encoder
        self.encoder = MixVisionTransformer(
            in_chans=in_channels,
            embed_dims=config['embed_dims'],
            num_heads=config['num_heads'],
            mlp_ratios=config['mlp_ratios'],
            qkv_bias=True,
            depths=config['depths'],
            sr_ratios=config['sr_ratios'],
            drop_rate=0.0,
            drop_path_rate=0.1,
        )

        # Build decoder
        self.decoder = HiF_Decoder(
            encoder_channels=config['embed_dims'],
            decoder_channels=decoder_channels,
        )

        # Build the shared output head; it consumes decoder_channels // 4
        # channels, which is what HiF_Decoder is configured to emit.
        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels // 4,
            out_channels=num_classes,
            activation=activation,
            kernel_size=1,
            upsampling=upsampling,
        )

        self.name = f"DepthPolyp-{encoder_name}"

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Input tensor of shape (B, C, H, W).

        Returns:
            Tuple ``(pred_seg, pred_depth)``.  Each element is the sigmoid of
            one single-channel slice of the head output (channel 0 and
            channel 1 respectively), i.e. shape (B, 1, h, w) where the
            spatial size is whatever ``SegmentationHead`` produces.
        """
        # Encoder - multi-scale features (original note: [H/4, H/8, H/16, H/32])
        encoder_features = self.encoder(x)

        # Decoder - fused feature map (original note: at H/4)
        decoder_output = self.decoder(encoder_features)

        # Head - one tensor with both task channels
        masks = self.segmentation_head(decoder_output)

        # Both channels are squashed with a sigmoid here, including depth.
        pred_seg = torch.sigmoid(masks[:, 0:1, :, :])    # segmentation channel
        pred_depth = torch.sigmoid(masks[:, 1:2, :, :])  # depth channel

        return pred_seg, pred_depth

    @torch.no_grad()
    def predict(self, x):
        """Inference helper: switch to eval mode if needed and run ``forward`` without gradients."""
        if self.training:
            self.eval()
        return self(x)

    def load_pretrained(self, checkpoint_path, strict=True):
        """Load pretrained weights into this model.

        Accepts a raw state dict, or a checkpoint dict that stores the state
        dict under ``'state_dict'`` or ``'model'``.  A leading ``module.``
        (left by DataParallel) is stripped from every key.

        Args:
            checkpoint_path: Path to checkpoint file.
            strict: Whether to strictly enforce key matching.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source, or pass weights_only=True where the installed
        # torch supports it (the README snippet does).
        state_dict = torch.load(checkpoint_path, map_location='cpu')

        # Handle different checkpoint formats
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        elif 'model' in state_dict:
            state_dict = state_dict['model']

        # Remove module. prefix if present (from DataParallel)
        prefix = 'module.'
        new_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        self.load_state_dict(new_state_dict, strict=strict)
        print(f"✓ Loaded pretrained weights from {checkpoint_path}")
158
+
159
+
160
def build_depthpolyp(
    encoder_name='b0',
    in_channels=3,
    num_classes=2,
    decoder_channels=256,
    activation=None,
):
    """Create a DepthPolyp model.

    Args:
        encoder_name: Encoder variant ('b0', 'b1', 'b2', 'b3', 'b4', 'b5')
        in_channels: Number of input channels
        num_classes: Number of head output channels.  ``DepthPolyp.forward``
            reads channel 0 (segmentation) and channel 1 (depth), so 2 is the
            intended value.
        decoder_channels: Number of channels in decoder
        activation: Forwarded to the segmentation head.  ``DepthPolyp.forward``
            applies its own sigmoid, so ``None`` is the intended value.

    Returns:
        DepthPolyp model

    Example:
        >>> model = build_depthpolyp('b0', num_classes=2, activation=None)
        >>> print(model)
    """
    return DepthPolyp(
        in_channels=in_channels,
        num_classes=num_classes,
        encoder_name=encoder_name,
        decoder_channels=decoder_channels,
        activation=activation,
    )
193
+
194
if __name__ == '__main__':
    # Smoke test: build the b0 model, report its size/complexity and run one
    # dummy forward pass.
    # NOTE(review): this module uses relative imports, so it presumably has to
    # be launched as a package module (``python -m model.depthpolyp``).
    print("="*60)
    print("Loading Model .....")
    model = build_depthpolyp(
        encoder_name='b0',
        in_channels=3,     # Input channels
        num_classes=2,     # Total 2. 1 for seg, 1 for depth
        decoder_channels=256,
        # forward() already applies sigmoid to both channels; None matches
        # the README quick-start and config.json.
        activation=None,
    )
    # Evaluate in inference mode so dropout / drop-path are inactive.
    model.eval()
    print("="*60)
    print("Validating Model .....")
    print("Check the Param and Complexity(GMACs)")
    import ptflops
    macs, params = ptflops.get_model_complexity_info(
        model, (3, 224, 224), as_strings=True,
        print_per_layer_stat=False, verbose=False
    )
    print(f" MACs: {macs}, Params: {params}")
    # output is MACs: 862.17 MMac, Params: 3.57 M
    print("="*60)
    print("Check the output .....")
    dummy_input = torch.randn(1, 3, 224, 224)  # B, C, H, W, single RGB image
    with torch.no_grad():  # shape check only, no autograd graph needed
        output_seg, output_depth = model(dummy_input)
    print("input_shape is:", dummy_input.shape)
    print("output_seg shape is:", output_seg.shape)
    print("output_depth shape is:", output_depth.shape)
model/modules/DGG.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class DGG_Module(nn.Module):
    """Group-wise gating of a feature map.

    The channels are split into ``groups`` consecutive chunks.  Each chunk is
    reduced to one scalar (mean over its channels and all pixels), the
    per-group scalars pass through a ``groups -> groups`` linear layer and a
    sigmoid, and every chunk is multiplied by its resulting gate.

    Args:
        channels: Accepted for interface compatibility; not used.
        groups: Number of channel groups; the input channel count must be
            divisible by it.
    """

    def __init__(self, channels, groups):
        super().__init__()
        self.groups = groups
        self.fc = nn.Linear(groups, groups)

    def forward(self, x):
        """Return ``x`` with each channel group scaled by its learned gate."""
        batch, total_channels, height, width = x.shape
        per_group = total_channels // self.groups

        # (B, groups, per_group, H, W) view shared by the pooling and gating steps
        grouped = x.view(batch, self.groups, per_group, height, width)

        descriptor = grouped.mean(dim=(2, 3, 4))        # (B, groups)
        gates = torch.sigmoid(self.fc(descriptor))      # (B, groups)
        gates = gates[..., None, None, None]            # (B, groups, 1, 1, 1)

        return (grouped * gates).reshape(batch, total_channels, height, width)
model/modules/GFM.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
class GFM_Module(nn.Module):
    """Two-branch feature generator.

    A 1x1 convolution produces ``out_channels // ratio`` primary channels; a
    grouped 3x3 convolution then derives the remaining
    ``out_channels - out_channels // ratio`` channels from the primary ones.
    Both branches are followed by BatchNorm and ReLU.  The two outputs are
    returned separately rather than concatenated.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Combined channel budget of both branches.
        ratio: Divisor that determines the primary branch width.
    """

    def __init__(self, in_channels, out_channels, ratio=2):
        super().__init__()
        primary_width = out_channels // ratio
        cheap_width = out_channels - primary_width

        self.primary_conv = nn.Sequential(
            nn.Conv2d(in_channels, primary_width, kernel_size=1, bias=False),
            nn.BatchNorm2d(primary_width),
            nn.ReLU(inplace=True),
        )

        self.cheap_operation = nn.Sequential(
            nn.Conv2d(
                primary_width,
                cheap_width,
                kernel_size=3,
                stride=1,
                padding=1,
                groups=primary_width,
                bias=False,
            ),
            nn.BatchNorm2d(cheap_width),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return ``(primary, cheap)`` where ``cheap`` is computed from ``primary``."""
        primary = self.primary_conv(x)
        cheap = self.cheap_operation(primary)
        return primary, cheap
model/modules/HF_Decoder.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .GFM import GFM_Module
6
+ from .DGG import DGG_Module
7
+ from .ISF import ISF_Module
8
+
9
+
10
class MLP(nn.Module):
    """Per-pixel linear projection used by the decoder.

    Flattens a (B, C, H, W) map into (B, H*W, C) tokens and projects the
    channel dimension from ``input_dim`` to ``embed_dim``.
    """

    def __init__(self, input_dim, embed_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        """Return the projected tokens with shape (B, H*W, embed_dim)."""
        tokens = x.flatten(2).transpose(1, 2)
        return self.proj(tokens)
20
+
21
+
22
class HiF_Decoder(nn.Module):
    """Hierarchical Factorized Decoder.

    Each of the four encoder feature maps is projected to
    ``decoder_channels`` and brought to the spatial size of the first
    (highest-resolution) map.  Three stages of ``GFM_Module`` pairs then
    progressively halve the channel count, and a ``DGG_Module`` gate is added
    residually to the final ``decoder_channels // 4``-channel map.

    Args:
        encoder_channels: Channel counts of the four encoder stages, in the
            order the encoder returns them.
        decoder_channels: Common width every stage is projected to.
    """

    def __init__(
        self,
        encoder_channels=(64, 128, 320, 512),  # tuple: avoids a shared mutable default
        decoder_channels=256,
    ):
        super().__init__()

        # MLP layers to unify channel dimensions
        self.linear_c4 = MLP(input_dim=encoder_channels[3], embed_dim=decoder_channels)
        self.linear_c3 = MLP(input_dim=encoder_channels[2], embed_dim=decoder_channels)
        self.linear_c2 = MLP(input_dim=encoder_channels[1], embed_dim=decoder_channels)
        self.linear_c1 = MLP(input_dim=encoder_channels[0], embed_dim=decoder_channels)

        self.dropout = nn.Dropout2d(0.1)

        # Stage 1a: one GFM per encoder scale
        self.gfm_c4_1 = GFM_Module(decoder_channels, decoder_channels//2)
        self.gfm_c3_1 = GFM_Module(decoder_channels, decoder_channels//2)
        self.gfm_c2_1 = GFM_Module(decoder_channels, decoder_channels//2)
        self.gfm_c1_1 = GFM_Module(decoder_channels, decoder_channels//2)

        # Stage 1b: one GFM per fused stream ("o" and "e")
        self.gfm_c_o_1 = GFM_Module(decoder_channels, decoder_channels//2)
        self.gfm_c_e_1 = GFM_Module(decoder_channels, decoder_channels//2)

        # Stage 2
        self.gfm_c_o_2 = GFM_Module(decoder_channels//2, decoder_channels//4)
        self.gfm_c_e_2 = GFM_Module(decoder_channels//2, decoder_channels//4)

        # Stage 3
        self.gfm_c_o_3 = GFM_Module(decoder_channels//4, decoder_channels//8)
        self.gfm_c_e_3 = GFM_Module(decoder_channels//4, decoder_channels//8)

        self.cyclic_shuffle_enhancer_o = ISF_Module(channels=decoder_channels, groups=4, kernel_size=3, cyclic_percent=0.0)
        self.cyclic_shuffle_enhancer_e = ISF_Module(channels=decoder_channels, groups=4, kernel_size=3, cyclic_percent=0.0)

        self.gatefuser = DGG_Module(channels=decoder_channels//4, groups=4)

    @staticmethod
    def _project(linear, feat, size=None):
        """Project ``feat`` with ``linear`` back to a (B, C', H, W) map, optionally resized to ``size``."""
        n, _, feat_h, feat_w = feat.shape
        out = linear(feat).permute(0, 2, 1).reshape(n, -1, feat_h, feat_w)
        if size is not None:
            out = F.interpolate(out, size=size, mode='bilinear', align_corners=False)
        return out

    def forward(self, encoder_features):
        """Fuse the four encoder maps into one map at the resolution of the first.

        Args:
            encoder_features: Sequence ``[c1, c2, c3, c4]`` of 4-D feature
                maps (original note: at H/4, H/8, H/16, H/32).

        Returns:
            Tensor at the spatial size of ``c1``.  The channel counts quoted
            in the comments below are for ``decoder_channels=256``.
        """
        c1, c2, c3, c4 = encoder_features

        # Target size is that of c1
        _, _, h, w = c1.shape

        # Transform each feature and upsample to the target size
        _c4 = self._project(self.linear_c4, c4, size=(h, w))
        _c3 = self._project(self.linear_c3, c3, size=(h, w))
        _c2 = self._project(self.linear_c2, c2, size=(h, w))
        # c1 is already at the target size, no need to interpolate
        _c1 = self._project(self.linear_c1, c1)

        # First stage: split every scale into an "o" and an "e" half
        _c4_g1_o, _c4_g2_e = self.gfm_c4_1(_c4)
        _c3_g1_o, _c3_g2_e = self.gfm_c3_1(_c3)
        _c2_g1_o, _c2_g2_e = self.gfm_c2_1(_c2)
        _c1_g1_o, _c1_g2_e = self.gfm_c1_1(_c1)

        _c_o_1 = torch.cat([_c4_g1_o, _c3_g1_o, _c2_g1_o, _c1_g1_o], dim=1)  # (B, 256, H, W)
        _c_e_1 = torch.cat([_c4_g2_e, _c3_g2_e, _c2_g2_e, _c1_g2_e], dim=1)  # (B, 256, H, W)
        _c_o_1_f = self.cyclic_shuffle_enhancer_o(_c_o_1)  # fused _c_o_1 feature
        _c_e_1_f = self.cyclic_shuffle_enhancer_e(_c_e_1)  # fused _c_e_1 feature

        _c_o_1_o, _c_o_1_e = self.gfm_c_o_1(_c_o_1_f)
        _c_e_1_o, _c_e_1_e = self.gfm_c_e_1(_c_e_1_f)

        # Second stage
        _c_o_2 = torch.cat([_c_o_1_o, _c_e_1_o], dim=1)  # (B, 128, H, W)
        _c_e_2 = torch.cat([_c_o_1_e, _c_e_1_e], dim=1)  # (B, 128, H, W)
        _c_o_2_o, _c_o_2_e = self.gfm_c_o_2(_c_o_2)      # (B, 32, H, W), (B, 32, H, W)
        _c_e_2_o, _c_e_2_e = self.gfm_c_e_2(_c_e_2)      # (B, 32, H, W), (B, 32, H, W)

        # Third stage
        _c_o_3 = torch.cat([_c_o_2_o, _c_e_2_o], dim=1)  # (B, 64, H, W)
        _c_e_3 = torch.cat([_c_o_2_e, _c_e_2_e], dim=1)  # (B, 64, H, W)
        _c_o_3_o, _c_o_3_e = self.gfm_c_o_3(_c_o_3)      # (B, 16, H, W), (B, 16, H, W)
        _c_e_3_o, _c_e_3_e = self.gfm_c_e_3(_c_e_3)      # (B, 16, H, W), (B, 16, H, W)

        x = torch.cat([_c_o_3_o, _c_e_3_o, _c_o_3_e, _c_e_3_e], dim=1)  # (B, 64, H, W)

        # Residual gated refinement, then channel dropout
        x_f = self.gatefuser(x)
        x = x + x_f
        x = self.dropout(x)
        return x
model/modules/ISF.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class GroupChannelShuffle(nn.Module):
    """Interleave channels across groups, with an optional cyclic rotation.

    Args:
        groups: Number of consecutive channel groups to interleave
            (e.g. 4 for c1..c4).
        cyclic_percent: If strictly between 0 and 1, the shuffled channels
            are additionally rolled by ``int(C * cyclic_percent)`` positions.
    """

    def __init__(self, groups: int = 4, cyclic_percent: float = 0.0):
        super().__init__()
        assert groups >= 1
        self.groups = groups
        self.cyclic_percent = cyclic_percent

    def forward(self, x):
        """Shuffle the channel axis of a (B, C, H, W) tensor."""
        batch, channels, height, width = x.shape
        n_groups = self.groups
        assert channels % n_groups == 0, f"channels {channels} not divisible by groups {n_groups}"
        per_group = channels // n_groups

        # (B, groups, per_group, H, W) -> swap the two channel axes -> flatten back
        shuffled = (
            x.view(batch, n_groups, per_group, height, width)
            .transpose(1, 2)
            .contiguous()
            .view(batch, channels, height, width)
        )

        fraction = self.cyclic_percent
        if not fraction or not (0 < fraction < 1.0):
            return shuffled

        # deterministic rotation along the channel axis
        return torch.roll(shuffled, shifts=int(channels * fraction), dims=1)
32
+
33
class ISF_Module(nn.Module):
    """Shuffle + depthwise conv + group-wise scaling, added back residually.

    Args:
        channels: Total channels of the input (must be divisible by ``groups``).
        groups: Number of logical channel groups.
        kernel_size: Kernel size of the depthwise convolution.
        cyclic_percent: Forwarded to ``GroupChannelShuffle``.
    """

    def __init__(self, channels: int, groups: int = 4, kernel_size: int = 3, cyclic_percent: float = 0.0):
        super().__init__()
        assert channels % groups == 0
        self.groups = groups
        self.channels = channels
        self.shuffle = GroupChannelShuffle(groups=groups, cyclic_percent=cyclic_percent)

        # depthwise conv (one filter per channel) followed by BN + ReLU
        self.dw = nn.Conv2d(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=channels,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(channels)
        self.act = nn.ReLU(inplace=True)

        # one learnable scalar per group, initialised to 1
        self.group_scale = nn.Parameter(torch.ones(groups), requires_grad=True)

    def forward(self, x):
        """Return ``x`` plus its shuffled, refined and group-scaled copy."""
        _, num_channels, _, _ = x.shape

        # 1) interleave, 2) per-channel spatial refinement
        refined = self.act(self.bn(self.dw(self.shuffle(x))))

        # 3) expand the per-group scalars to one factor per channel
        per_group = num_channels // self.groups
        scale = (
            self.group_scale.to(x.device)
            .repeat_interleave(per_group)
            .view(1, num_channels, 1, 1)
        )

        # 4) residual add keeps the original information
        return x + refined * scale
78
+
model/modules/MiT_Encoder.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Based on NVIDIA's SegFormer code, cleaned and made independent
3
+ """
4
+
5
+ import math
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from functools import partial
10
+ from typing import Dict, Sequence, List, Optional, Union, Callable, Any
11
+ import warnings
12
+
13
+
14
+ # ============================================================================
15
+ # Utility Functions
16
+ # ============================================================================
17
+
18
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place from a normal(mean, std) truncated to [a, b] (from timm)."""
    def _std_normal_cdf(value):
        return 0.5 * (1.0 + math.erf(value / math.sqrt(2.0)))

    # CDF values of the two truncation bounds
    lower = _std_normal_cdf((a - mean) / std)
    upper = _std_normal_cdf((b - mean) / std)

    with torch.no_grad():
        # uniform sample in [2l-1, 2u-1], then inverse error function
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)
        tensor.erfinv_()
        # rescale to the requested std / mean
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)
        # guard against values marginally outside the bounds
        tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Truncated normal initialization; fills ``tensor`` in place and returns it."""
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
37
+
38
+
39
def to_2tuple(x):
    """Return lists/tuples as a tuple (length unchanged); duplicate anything else into a pair."""
    return tuple(x) if isinstance(x, (list, tuple)) else (x, x)
44
+
45
+
46
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample.

    During training each sample in the batch is zeroed with probability
    ``drop_prob``; surviving samples are rescaled by ``1 / (1 - drop_prob)``.
    In eval mode, or when ``drop_prob`` is ``None`` or 0, the input is
    returned unchanged.

    Args:
        drop_prob: Probability of dropping a sample.  ``None`` (the default)
            is treated as 0.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # The default drop_prob is None; treat it as "never drop" instead of
        # failing on ``1 - None`` in training mode.
        drop_prob = self.drop_prob or 0.0
        if drop_prob == 0.0 or not self.training:
            return x

        keep_prob = 1 - drop_prob
        # one Bernoulli(keep_prob) draw per sample, broadcast over the rest
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()

        if keep_prob > 0.0:
            return x.div(keep_prob) * random_tensor
        # drop_prob == 1: everything is dropped; skip the division by zero
        return x * random_tensor
61
+
62
+
63
+
64
+
65
+ # ============================================================================
66
+ # Core Modules
67
+ # ============================================================================
68
+
69
class LayerNorm(nn.LayerNorm):
    """LayerNorm that accepts 4D (B, C, H, W) inputs in addition to (..., C) inputs.

    4D inputs are normalized over the channel axis: the map is flattened to
    (B, H*W, C), normalized, and reshaped back to (B, C, H, W).
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.ndim != 4:
            # channels-last input: normalize directly
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)

        batch_size, channels, height, width = x.shape
        tokens = x.view(batch_size, channels, -1).transpose(1, 2)
        tokens = F.layer_norm(tokens, self.normalized_shape, self.weight, self.bias, self.eps)
        return tokens.transpose(1, 2).view(batch_size, channels, height, width)
80
+
81
+
82
class DWConv(nn.Module):
    """3x3 depthwise convolution applied to a token sequence.

    The (B, N, C) tokens are laid out as a (B, C, height, width) map,
    convolved with one 3x3 filter per channel (stride 1, padding 1, with
    bias), and flattened back to (B, N, C).
    """

    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim)

    def forward(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
        batch_size, _, channels = x.shape
        feature_map = x.transpose(1, 2).view(batch_size, channels, height, width)
        return self.dwconv(feature_map).flatten(2).transpose(1, 2)
94
+
95
+
96
class Mlp(nn.Module):
    """Feed-forward block with a depthwise convolution between its two linears.

    Pipeline: ``fc1`` -> ``DWConv`` -> activation -> dropout -> ``fc2`` -> dropout.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Hidden width; falls back to ``in_features`` when falsy.
        out_features: Output width; falls back to ``in_features`` when falsy.
        act_layer: Zero-argument factory for the activation module.
        drop: Probability of the single ``nn.Dropout`` used after both the
            activation and ``fc2``.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.dwconv = DWConv(hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize ``m`` if it is a Linear or Conv2d layer; otherwise do nothing."""
        if isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out taken per group.
            fan_out = (m.kernel_size[0] * m.kernel_size[1] * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
            return
        if not isinstance(m, nn.Linear):
            return
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Run the block on ``x``; ``height`` and ``width`` are handed to ``DWConv``."""
        hidden = self.fc1(x)
        hidden = self.dwconv(hidden, height, width)
        hidden = self.drop(self.act(hidden))
        return self.drop(self.fc2(hidden))
137
+
138
+
139
class Attention(nn.Module):
    """Multi-head self-attention with optional spatial reduction of keys/values.

    When ``sr_ratio > 1`` the tokens are reshaped to a feature map, reduced
    by a strided convolution and layer-normalized before the key/value
    projection. Queries always use the full-resolution tokens.

    Args:
        dim: Token embedding size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query and key/value projections have a bias.
        qk_scale: Attention scale; ``head_dim ** -0.5`` is used when falsy.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
        sr_ratio: Kernel size and stride of the reduction convolution.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        sr_ratio=1,
    ):
        super().__init__()
        assert dim % num_heads == 0, (
            f"dim {dim} should be divided by num_heads {num_heads}."
        )

        self.dim = dim
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        reduce_tokens = sr_ratio > 1
        self.sr = (
            nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            if reduce_tokens
            else nn.Identity()
        )
        self.norm = LayerNorm(dim) if reduce_tokens else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize ``m`` if it is a Linear, LayerNorm or Conv2d layer."""
        if isinstance(m, LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out taken per group.
            fan_out = (m.kernel_size[0] * m.kernel_size[1] * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Attend over tokens ``x`` of shape (B, N, C) and return (B, N, C).

        ``height`` and ``width`` are only used when ``sr_ratio > 1``, to
        reshape the tokens into a (B, C, height, width) map for reduction.
        """
        batch_size, N, C = x.shape
        head_dim = C // self.num_heads

        q = self.q(x).reshape(batch_size, N, self.num_heads, head_dim).permute(0, 2, 1, 3)

        # Keys/values come from spatially reduced tokens when requested.
        context = x
        if self.sr_ratio > 1:
            context = x.permute(0, 2, 1).reshape(batch_size, C, height, width)
            context = self.sr(context).reshape(batch_size, C, -1).permute(0, 2, 1)
            context = self.norm(context)

        kv = (
            self.kv(context)
            .reshape(batch_size, -1, 2, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv[0], kv[1]

        weights = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        merged = (weights @ v).transpose(1, 2).reshape(batch_size, N, C)
        return self.proj_drop(self.proj(merged))
226
+
227
+
228
class Block(nn.Module):
    """Pre-norm transformer block operating on (B, C, H, W) feature maps.

    The map is flattened to tokens, passed through attention and the MLP
    (each with a residual connection and optional drop-path), and reshaped
    back to (B, C, H, W).

    Args:
        dim: Channel / embedding size.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Forwarded to ``Attention``.
        qk_scale: Forwarded to ``Attention``.
        drop: Dropout probability for the attention projection and the MLP.
        attn_drop: Dropout probability on the attention weights.
        drop_path: Stochastic-depth probability; ``nn.Identity`` is used
            when it is not greater than 0.
        act_layer: Activation factory for the MLP.
        norm_layer: Factory called as ``norm_layer(dim)`` for both norms.
        sr_ratio: Spatial-reduction ratio forwarded to ``Attention``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=LayerNorm,
        sr_ratio=1,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio,
        )
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize ``m`` if it is a Linear, LayerNorm or Conv2d layer."""
        if isinstance(m, LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out taken per group.
            fan_out = (m.kernel_size[0] * m.kernel_size[1] * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Transform a (B, C, H, W) map and return a map of the same shape."""
        batch_size, _, height, width = x.shape
        tokens = x.flatten(2).transpose(1, 2)

        attended = self.attn(self.norm1(tokens), height, width)
        tokens = tokens + self.drop_path(attended)

        expanded = self.mlp(self.norm2(tokens), height, width)
        tokens = tokens + self.drop_path(expanded)

        return tokens.transpose(1, 2).view(batch_size, -1, height, width)
289
+
290
+
291
class OverlapPatchEmbed(nn.Module):
    """Image to patch embedding with overlapping patches.

    A strided convolution with padding ``patch_size // 2`` projects the
    input to ``embed_dim`` channels, followed by a channel-wise LayerNorm.
    The output stays in (B, embed_dim, H', W') layout.

    Args:
        img_size: Nominal input size (int or pair); only used to derive the
            informational ``H``, ``W`` and ``num_patches`` attributes.
        patch_size: Convolution kernel size (int or pair).
        stride: Convolution stride.
        in_chans: Number of input channels.
        embed_dim: Number of output channels.
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride_hw = to_2tuple(stride)
        padding = (patch_size[0] // 2, patch_size[1] // 2)

        self.img_size = img_size
        self.patch_size = patch_size
        # Token grid actually produced by the projection for an input of
        # ``img_size`` (standard conv output-size formula). The previous
        # ``img_size // patch_size`` ignored stride and padding, so it did
        # not describe the real grid (e.g. 32 instead of 56 for 224/7/4).
        self.H = (img_size[0] + 2 * padding[0] - patch_size[0]) // stride_hw[0] + 1
        self.W = (img_size[1] + 2 * padding[1] - patch_size[1]) // stride_hw[1] + 1
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding,
        )
        self.norm = LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize ``m`` if it is a Linear, LayerNorm or Conv2d layer."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out taken per group.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project and normalize ``x``; returns a 4D feature map."""
        x = self.proj(x)
        x = self.norm(x)
        return x
332
+
333
+
334
+ # ============================================================================
335
+ # Mix Vision Transformer (Encoder)
336
+ # ============================================================================
337
+
338
class MixVisionTransformer(nn.Module):
    """Mix Vision Transformer - hierarchical transformer encoder with four stages.

    Each stage is an ``OverlapPatchEmbed`` (downsampling), a sequence of
    ``Block`` modules and a final norm. ``forward`` returns the four stage
    outputs as a list.

    Args:
        img_size: Nominal input size; stages 2-4 receive ``img_size // 4``,
            ``// 8`` and ``// 16`` for their patch embeddings.
        in_chans: Number of input image channels.
        embed_dims: Channel width of each of the four stages.
        num_heads: Attention heads per stage.
        mlp_ratios: MLP expansion ratio per stage.
        qkv_bias: Forwarded to every ``Block``.
        qk_scale: Forwarded to every ``Block``.
        drop_rate: Dropout probability forwarded to every ``Block``.
        attn_drop_rate: Attention dropout forwarded to every ``Block``.
        drop_path_rate: Maximum stochastic-depth rate; rates grow linearly
            from 0 to this value across all blocks.
        norm_layer: Norm factory used inside blocks and after each stage.
        depths: Number of blocks in each stage.
        sr_ratios: Spatial-reduction ratio of the attention in each stage.
    """

    def __init__(
        self,
        img_size=224,
        in_chans=3,
        # Tuple defaults: list defaults are shared between calls and
        # ``self.depths`` used to alias that shared object.
        embed_dims=(64, 128, 256, 512),
        num_heads=(1, 2, 4, 8),
        mlp_ratios=(4, 4, 4, 4),
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=LayerNorm,
        depths=(3, 4, 6, 3),
        sr_ratios=(8, 4, 2, 1),
    ):
        super().__init__()
        # Private list copy so later mutation cannot leak back to the caller.
        self.depths = list(depths)

        # Patch embeddings for each stage
        self.patch_embed1 = OverlapPatchEmbed(
            img_size=img_size, patch_size=7, stride=4,
            in_chans=in_chans, embed_dim=embed_dims[0],
        )
        self.patch_embed2 = OverlapPatchEmbed(
            img_size=img_size // 4, patch_size=3, stride=2,
            in_chans=embed_dims[0], embed_dim=embed_dims[1],
        )
        self.patch_embed3 = OverlapPatchEmbed(
            img_size=img_size // 8, patch_size=3, stride=2,
            in_chans=embed_dims[1], embed_dim=embed_dims[2],
        )
        self.patch_embed4 = OverlapPatchEmbed(
            img_size=img_size // 16, patch_size=3, stride=2,
            in_chans=embed_dims[2], embed_dim=embed_dims[3],
        )

        # Stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        # bounds[i]:bounds[i + 1] is the slice of ``dpr`` owned by stage i.
        bounds = [0]
        for depth in depths:
            bounds.append(bounds[-1] + depth)

        shared = dict(
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            norm_layer=norm_layer,
        )

        def stage(index):
            return self._make_stage(
                embed_dims[index],
                num_heads[index],
                mlp_ratios[index],
                sr_ratios[index],
                dpr[bounds[index]:bounds[index + 1]],
                **shared,
            )

        # Transformer blocks for each stage. Attribute names and creation
        # order are kept so existing checkpoints keep loading.
        self.block1 = stage(0)
        self.norm1 = norm_layer(embed_dims[0])

        self.block2 = stage(1)
        self.norm2 = norm_layer(embed_dims[1])

        self.block3 = stage(2)
        self.norm3 = norm_layer(embed_dims[2])

        self.block4 = stage(3)
        self.norm4 = norm_layer(embed_dims[3])

        self.apply(self._init_weights)

    @staticmethod
    def _make_stage(dim, heads, mlp_ratio, sr_ratio, drop_path_rates, **block_kwargs):
        """Build one stage: a Sequential with one Block per drop-path rate."""
        return nn.Sequential(
            *[
                Block(
                    dim=dim,
                    num_heads=heads,
                    mlp_ratio=mlp_ratio,
                    drop_path=rate,
                    sr_ratio=sr_ratio,
                    **block_kwargs,
                )
                for rate in drop_path_rates
            ]
        )

    def _init_weights(self, m):
        """Initialize ``m`` if it is a Linear, LayerNorm or Conv2d layer."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out taken per group.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Run all four stages and return their outputs in stage order.

        With the strides configured above the stages downsample by 4, 8, 16
        and 32 relative to the input.
        """
        stages = (
            (self.patch_embed1, self.block1, self.norm1),
            (self.patch_embed2, self.block2, self.norm2),
            (self.patch_embed3, self.block3, self.norm3),
            (self.patch_embed4, self.block4, self.norm4),
        )
        outs = []
        for patch_embed, blocks, norm in stages:
            x = patch_embed(x)
            x = blocks(x)
            x = norm(x).contiguous()
            outs.append(x)
        return outs
model/modules/Seg_Head.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
+ # ============================================================================
5
+ # Activation Module
6
+ # ============================================================================
7
+
8
class Activation(nn.Module):
    """Activation wrapper selected by name, or wrapping a caller-supplied callable.

    Args:
        activation: ``None`` or one of ``'identity'``, ``'sigmoid'``,
            ``'softmax'``, ``'softmax2d'``, ``'logsoftmax'``, ``'tanh'``,
            ``'relu'``, or any callable, which is stored and called as-is.
            The softmax variants all operate over dimension 1.

    Raises:
        ValueError: If ``activation`` is neither a supported name nor callable.
    """

    # Name -> zero-argument factory for the corresponding module.
    _FACTORIES = {
        'identity': lambda: nn.Identity(),
        'sigmoid': lambda: nn.Sigmoid(),
        'softmax': lambda: nn.Softmax(dim=1),
        'softmax2d': lambda: nn.Softmax(dim=1),
        'logsoftmax': lambda: nn.LogSoftmax(dim=1),
        'tanh': lambda: nn.Tanh(),
        'relu': lambda: nn.ReLU(inplace=True),
    }

    def __init__(self, activation=None):
        super().__init__()

        if activation is None:
            self.activation = nn.Identity()
        elif isinstance(activation, str) and activation in self._FACTORIES:
            self.activation = self._FACTORIES[activation]()
        elif callable(activation):
            self.activation = activation
        else:
            # The message now lists every accepted name; it previously
            # omitted identity, softmax2d and relu.
            raise ValueError(
                'Activation should be callable/identity/sigmoid/softmax/softmax2d/'
                f'logsoftmax/tanh/relu/None; got {activation}'
            )

    def forward(self, x):
        """Apply the configured activation to ``x``."""
        return self.activation(x)
36
+
37
+ # ============================================================================
38
+ # Segmentation Head (nn.Sequential style)
39
+ # ============================================================================
40
+
41
class SegmentationHead(nn.Sequential):
    """Sequential head: convolution -> optional bilinear upsampling -> activation.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size; padding is ``kernel_size // 2``.
        activation: Passed to ``Activation`` to choose the final activation.
        upsampling: Scale factor; upsampling is applied only when it is
            greater than 1, otherwise an identity layer takes its place.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        activation=None,
        upsampling=1
    ):
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            )
        ]
        if upsampling > 1:
            layers.append(nn.UpsamplingBilinear2d(scale_factor=upsampling))
        else:
            layers.append(nn.Identity())
        layers.append(Activation(activation))
        super().__init__(*layers)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ numpy
4
+ pillow
5
+ onnxruntime
6
+
samples/kvasir/images/sample_01.jpg ADDED
samples/kvasir/images/sample_02.jpg ADDED

Git LFS Details

  • SHA256: e7046f72d982bf65c853e1465f8f45d7f29bca4f2d0ceac286641dc27e4ac872
  • Pointer size: 131 Bytes
  • Size of remote file: 220 kB
samples/kvasir/outputs/depth/sample_01.png ADDED
samples/kvasir/outputs/depth/sample_02.png ADDED

Git LFS Details

  • SHA256: d441ffac673c06779f38b811751849225f5381b97c4e0a3ecbe49c390adfbdd4
  • Pointer size: 131 Bytes
  • Size of remote file: 142 kB
samples/kvasir/outputs/masks/sample_01.png ADDED
samples/kvasir/outputs/masks/sample_02.png ADDED
samples/kvasir/outputs/overlay/sample_01.jpg ADDED
samples/kvasir/outputs/overlay/sample_02.jpg ADDED

Git LFS Details

  • SHA256: 28f808f52574e0fc443abcde81a1ec88fd6822b2997f111a438c629c3802a7cb
  • Pointer size: 131 Bytes
  • Size of remote file: 209 kB
scripts/export_onnx.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ import torch
6
+
7
+ REPO_ROOT = Path(__file__).resolve().parents[1]
8
+ sys.path.insert(0, str(REPO_ROOT))
9
+
10
+ from model.depthpolyp import build_depthpolyp
11
+
12
+
13
def load_checkpoint(path):
    """Load a checkpoint from ``path`` onto the CPU.

    Uses ``weights_only=True`` whenever the installed ``torch.load`` accepts
    that argument. The capability is detected from the function signature
    instead of catching ``TypeError``, so a ``TypeError`` raised for any
    other reason can no longer silently trigger the unrestricted fallback.

    Args:
        path: Location of the checkpoint file.

    Returns:
        Whatever object ``torch.load`` deserializes from ``path``.
    """
    import inspect

    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(path, map_location="cpu", weights_only=True)
    # SECURITY: legacy PyTorch without ``weights_only`` falls back to full
    # pickle deserialization, which can execute arbitrary code. Only load
    # checkpoints from trusted sources on such versions.
    return torch.load(path, map_location="cpu")
18
+
19
+
20
def parse_args():
    """Parse the command-line options of the ONNX export script.

    Returns:
        An ``argparse.Namespace`` with ``checkpoint``, ``output``,
        ``image_size`` and ``opset``.
    """
    options = (
        ("--checkpoint", {"default": "checkpoints/DepthPolyp_Kvasir.pth"}),
        ("--output", {"default": "checkpoints/DepthPolyp_Kvasir.onnx"}),
        ("--image-size", {"type": int, "default": 224}),
        ("--opset", {"type": int, "default": 17}),
    )
    parser = argparse.ArgumentParser(description="Export DepthPolyp to ONNX.")
    for flag, settings in options:
        parser.add_argument(flag, **settings)
    return parser.parse_args()
27
+
28
+
29
def main():
    """Build DepthPolyp, load the checkpoint and export the model to ONNX.

    The model is created with a fixed configuration (encoder ``"b0"``,
    3 input channels, 2 classes, 256 decoder channels, no activation), the
    checkpoint is loaded strictly, and the export traces a random square
    input of ``--image-size``. Only the batch axis is declared dynamic.
    """
    args = parse_args()
    model = build_depthpolyp(
        encoder_name="b0",
        in_channels=3,
        num_classes=2,
        decoder_channels=256,
        activation=None,
    )
    state_dict = load_checkpoint(args.checkpoint)
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    dummy = torch.randn(1, 3, args.image_size, args.image_size)
    torch.onnx.export(
        model,
        dummy,
        # Pass a plain string: older torch releases document ``f`` as a file
        # name string or file object, and this script already caters for
        # older torch in ``load_checkpoint``.
        str(output_path),
        input_names=["image"],
        output_names=["segmentation", "depth"],
        opset_version=args.opset,
        do_constant_folding=True,
        dynamic_axes={
            "image": {0: "batch"},
            "segmentation": {0: "batch"},
            "depth": {0: "batch"},
        },
    )
    print(f"Exported ONNX model to {output_path}")
61
+
62
+
63
+ if __name__ == "__main__":
64
+ main()
scripts/infer_onnx.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ from PIL import Image
7
+
8
+
9
+ IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
10
+
11
+
12
def parse_args():
    """Parse the command-line options of the ONNX inference script.

    Returns:
        An ``argparse.Namespace`` with ``onnx``, ``input``, ``output``,
        ``image_size`` and ``threshold``.
    """
    options = (
        ("--onnx", {"default": "checkpoints/DepthPolyp_Kvasir.onnx"}),
        ("--input", {"default": "samples/kvasir/images"}),
        ("--output", {"default": "samples/kvasir/outputs"}),
        ("--image-size", {"type": int, "default": 224}),
        ("--threshold", {"type": float, "default": 0.3}),
    )
    parser = argparse.ArgumentParser(description="Run DepthPolyp ONNX inference on images.")
    for flag, settings in options:
        parser.add_argument(flag, **settings)
    return parser.parse_args()
20
+
21
+
22
def list_images(input_path: Path):
    """Return the image files to process for ``input_path``.

    A path pointing at a file is returned as a single-element list without
    checking its extension. Otherwise the path is searched recursively and
    every regular file whose lower-cased suffix is in ``IMAGE_EXTENSIONS``
    is returned, sorted.
    """
    if input_path.is_file():
        return [input_path]
    return sorted(
        path
        for path in input_path.rglob("*")
        # is_file() skips directories that merely look like images
        # (e.g. a folder named "frames.png"), which would fail on open.
        if path.suffix.lower() in IMAGE_EXTENSIONS and path.is_file()
    )
26
+
27
+
28
def preprocess(image_path: Path, image_size: int):
    """Load an image and build the network input array.

    The image is converted to RGB, resized to ``image_size`` x ``image_size``
    with bilinear resampling, scaled to ``[0, 1]`` as float32 and rearranged
    to shape (1, 3, image_size, image_size). No mean/std normalization is
    applied here.

    Returns:
        Tuple of (original RGB image, its ``(width, height)`` size, input array).
    """
    image = Image.open(image_path).convert("RGB")
    original_size = image.size
    target_size = (image_size, image_size)
    pixels = np.asarray(image.resize(target_size, Image.BILINEAR)).astype(np.float32)
    pixels = pixels / 255.0
    # HWC -> CHW, then add the leading batch axis.
    tensor = np.transpose(pixels, (2, 0, 1))[None, ...]
    return image, original_size, tensor
35
+
36
+
37
def to_grayscale(probability: np.ndarray, size):
    """Render a 2D map as an 8-bit grayscale image resized to ``size``.

    Values are clipped to ``[0, 1]``, scaled to 0-255 and truncated to
    uint8 before bilinear resizing.
    """
    clipped = np.clip(probability, 0.0, 1.0)
    levels = (clipped * 255).astype(np.uint8)
    gray = Image.fromarray(levels, mode="L")
    return gray.resize(size, Image.BILINEAR)
41
+
42
+
43
def colorize_purple_yellow(probability: np.ndarray, size):
    """Map a 2D array to a purple-to-yellow RGB image resized to ``size``.

    Values are clipped to ``[0, 1]`` and linearly interpolated between six
    evenly spaced color stops, from dark purple (0) to yellow (1).
    """
    palette = np.array(
        [
            [38, 5, 84],
            [86, 33, 132],
            [141, 48, 140],
            [203, 71, 119],
            [245, 135, 48],
            [252, 231, 37],
        ],
        dtype=np.float32,
    )
    last_index = len(palette) - 1
    position = np.clip(probability, 0.0, 1.0) * last_index
    below = np.floor(position).astype(np.int32)
    # Clamp so a value of exactly 1.0 does not index past the last stop.
    above = np.clip(below + 1, 0, last_index)
    weight = (position - below)[..., None]
    blended = palette[below] * (1.0 - weight) + palette[above] * weight
    rendered = Image.fromarray(blended.astype(np.uint8), mode="RGB")
    return rendered.resize(size, Image.BILINEAR)
63
+
64
+
65
def make_overlay(image: Image.Image, mask: Image.Image):
    """Blend a yellow tint over ``image`` with opacity driven by ``mask``.

    The tint color is (252, 231, 37); its alpha is the mask value scaled so
    that 255 maps to 155. ``mask`` is read as a 2D array and must match the
    size of ``image`` for the compositing call. Returns an RGB image.
    """
    opacity = np.asarray(mask).astype(np.float32) / 255.0
    rows, cols = opacity.shape[0], opacity.shape[1]
    tint = np.zeros((rows, cols, 4), dtype=np.uint8)
    tint[..., :3] = (252, 231, 37)
    tint[..., 3] = (opacity * 155).astype(np.uint8)
    layer = Image.fromarray(tint, mode="RGBA")
    composed = Image.alpha_composite(image.convert("RGBA"), layer)
    return composed.convert("RGB")
74
+
75
+
76
def main():
    """Run ONNX inference on every input image and save masks, depth maps and overlays.

    Results go to ``<output>/masks``, ``<output>/depth`` and
    ``<output>/overlay``. When the input is a directory, the sub-folder of
    each image relative to that directory is mirrored under each output
    folder, so images with the same file name in different sub-folders no
    longer overwrite each other. Images directly inside the input directory
    are written to the same locations as before.

    Raises:
        FileNotFoundError: If no image is found under the input path.
    """
    args = parse_args()
    input_path = Path(args.input)
    output_root = Path(args.output)
    mask_dir = output_root / "masks"
    depth_dir = output_root / "depth"
    overlay_dir = output_root / "overlay"
    for directory in (mask_dir, depth_dir, overlay_dir):
        directory.mkdir(parents=True, exist_ok=True)

    session = ort.InferenceSession(args.onnx, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    images = list_images(input_path)
    if not images:
        raise FileNotFoundError(f"No images found under {input_path}")

    # Directory against which each image's sub-folder is computed.
    base_dir = input_path if input_path.is_dir() else input_path.parent
    # Hoisted: the cut-off is constant for the whole run.
    threshold_level = int(args.threshold * 255)

    for image_path in images:
        image, original_size, tensor = preprocess(image_path, args.image_size)
        segmentation, depth = session.run(None, {input_name: tensor})
        # NOTE(review): takes batch 0 / channel 0 of each output and treats
        # the values as being in [0, 1] (they are clipped downstream).
        # Confirm the exported model emits probabilities with the
        # foreground in channel 0.
        seg_prob = segmentation[0, 0]
        depth_prob = depth[0, 0]

        seg_image = to_grayscale(seg_prob, original_size)
        depth_image = colorize_purple_yellow(depth_prob, original_size)
        binary_mask = seg_image.point(lambda value: 255 if value >= threshold_level else 0)
        # The overlay uses the soft mask, not the thresholded one.
        overlay = make_overlay(image, seg_image)

        relative_dir = image_path.parent.relative_to(base_dir)
        stem = image_path.stem
        mask_path = mask_dir / relative_dir / f"{stem}.png"
        depth_path = depth_dir / relative_dir / f"{stem}.png"
        overlay_path = overlay_dir / relative_dir / f"{stem}.jpg"
        for target in (mask_path, depth_path, overlay_path):
            target.parent.mkdir(parents=True, exist_ok=True)

        binary_mask.save(mask_path)
        depth_image.save(depth_path)
        overlay.save(overlay_path, quality=95)

    print(f"Processed {len(images)} image(s). Outputs saved to {output_root}")
109
+
110
+
111
+ if __name__ == "__main__":
112
+ main()