Leo-Lyu committed on
Commit 1715fda · verified · 1 Parent(s): 80d692a

Upload 13 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ docs/MCP-MedSAM.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Leo-Lyu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,47 @@
1
- ---
2
- license: mit
3
- ---
1
+ # MCP-MedSAM
2
+
3
+ PyTorch implementation of the paper:
4
+ "[MCP-MedSAM: A Powerful Lightweight Medical Segment Anything Model Trained with a Single GPU in Just One Day](https://arxiv.org/abs/2412.05888)"
5
+
6
+ ![MCP-MedSAM Architecture](docs/MCP-MedSAM.png)
7
+
8
+ ## 📄 Overview
9
+
10
+ This work proposes a lightweight variant of MedSAM by integrating:
11
+
12
+ - A **pre-trained Tiny ViT** as the vision backbone
13
+ - Two novel prompt types:
14
+ - **Modality Prompt**
15
+ - **Content Prompt**
16
+ - A **modified mask decoder** adapted to these prompts
17
+
18
+ To further improve performance across imaging modalities, we introduce a **modality-aware data sampling strategy** that balances the training data across modalities and improves generalization.
19
+
20
+ With these enhancements, our model achieves strong multi-modality segmentation performance and can be trained in approximately **1 day on a single A100 (40GB)** GPU.
21
+
22
+ <!--
23
+ We are currently releasing the inference code along with the model weights. You can download them from [here](https://drive.google.com/drive/folders/1NW4aSNhk-dtiK-dicTAUp0g0eR2fryNi?usp=sharing).
24
+
25
+ The training code has been released, and you can train your own model. -->
26
+
27
+ ## Requirements
28
+
29
+ * Python==3.10.14
30
+ * torch==2.0.0
31
+ * torchvision==0.15.0
32
+ * transformers==4.49.0
33
+
34
+ ## Training and Inference
35
+
36
+ Training and inference can be done by running `train.py` and `infer.py` (see the input-preparation sketch after this diff). We also release the model weights for inference, which can be downloaded from [here](https://drive.google.com/drive/folders/1NW4aSNhk-dtiK-dicTAUp0g0eR2fryNi?usp=sharing).
37
+
38
+ ## Citation
39
+
40
+ ```bibtex
41
+ @article{lyu2024mcp,
42
+ title={MCP-MedSAM: A Powerful Lightweight Medical Segment Anything Model Trained with a Single GPU in Just One Day},
43
+ author={Lyu, Donghang and Gao, Ruochen and Staring, Marius},
44
+ journal={arXiv preprint arXiv:2412.05888},
45
+ year={2024}
46
+ }
47
+ ```
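
For reference, `infer.py` (added below in this commit) expects `.npz` cases whose file names encode the modality as the second underscore-separated token (names starting with `3D` are routed to the volumetric path) and whose arrays use the keys `imgs` and `boxes`. A minimal sketch of preparing one 2D case and invoking the script — the file name and paths here are illustrative assumptions, not part of the commit:

```python
import numpy as np

# One 2D case in the layout infer.py reads: `imgs` is an HxWx3 uint8 image in
# [0, 255]; `boxes` is an (N, 4) array of xyxy box prompts.
img = np.zeros((512, 512, 3), dtype=np.uint8)
boxes = np.array([[100, 120, 300, 340]])
np.savez_compressed("2D_XRay_demo.npz", imgs=img, boxes=boxes)

# Then (flags as defined in infer.py's argparse):
# python infer.py -i <input_dir> -o <output_dir> \
#     -lite_medsam_checkpoint_path <checkpoint.pth> --save_overlay -png_save_dir <png_dir>
```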
docs/MCP-MedSAM.png ADDED

Git LFS Details

  • SHA256: 2b082ffd221532ee5590679539cf4eacb13f729f1d71aa987e202b5d889323d0
  • Pointer size: 132 Bytes
  • Size of remote file: 2.09 MB
infer.py ADDED
@@ -0,0 +1,738 @@
1
+ from os import makedirs
2
+ from os.path import join, basename
3
+ from glob import glob
4
+ from tqdm import tqdm
5
+ from time import time
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torchvision import transforms
11
+ from models import PromptEncoder, TwoWayTransformer, TinyViT, MaskDecoder_F4
12
+ from matplotlib import pyplot as plt
13
+ import cv2
14
+ import argparse
15
+ from collections import OrderedDict
16
+ import pandas as pd
17
+ from datetime import datetime
18
+ from transformers import CLIPModel, CLIPTokenizer
19
+
20
+ torch.set_float32_matmul_precision('high')
21
+ torch.manual_seed(42)
22
+ torch.cuda.manual_seed(42)
23
+ np.random.seed(42)
24
+
25
+ parser = argparse.ArgumentParser()
26
+
27
+ parser.add_argument(
28
+ '-i',
29
+ '--input_dir',
30
+ type=str,
31
+ default='',
32
+ # required=True,
33
+ help='root directory of the data',
34
+ )
35
+ parser.add_argument(
36
+ '-o',
37
+ '--output_dir',
38
+ type=str,
39
+ default='',
40
+ help='directory to save the prediction',
41
+ )
42
+ parser.add_argument(
43
+ '-lite_medsam_checkpoint_path',
44
+ type=str,
45
+ default="",
46
+ help='path to the checkpoint of MedSAM-Lite',
47
+ )
48
+ parser.add_argument(
49
+ '-device',
50
+ type=str,
51
+ default="cuda:0",
52
+ help='device to run the inference',
53
+ )
54
+ parser.add_argument(
55
+ '-num_workers',
56
+ type=int,
57
+ default=4,
58
+ help='number of workers for inference with multiprocessing',
59
+ )
60
+ parser.add_argument(
61
+ '--save_overlay',
62
+ default=False,
63
+ action='store_true',
64
+ help='whether to save the overlay image'
65
+ )
66
+
67
+ parser.add_argument(
68
+ '-png_save_dir',
69
+ type=str,
70
+ default=None,
71
+ help='directory to save the overlay image'
72
+ )
73
+
74
+ args = parser.parse_args()
75
+
76
+ data_root = args.input_dir
77
+ pred_save_dir = args.output_dir
78
+ save_overlay = args.save_overlay
79
+ num_workers = args.num_workers
80
+
81
+ if save_overlay:
82
+ assert args.png_save_dir is not None, "Please specify the directory to save the overlay image"
83
+ png_save_dir = args.png_save_dir
84
+ makedirs(png_save_dir, exist_ok=True)
85
+
86
+ lite_medsam_checkpoint_path = args.lite_medsam_checkpoint_path
87
+ makedirs(pred_save_dir, exist_ok=True)
88
+ device = torch.device(args.device)
89
+ image_size = 256
90
+ model1 = CLIPModel.from_pretrained("flaviagiammarino/pubmed-clip-vit-base-patch32", resume_download=True)
91
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16", resume_download=True)
92
+ model1.requires_grad_(False)
93
+
94
+
95
+ def resize_longest_side(image, target_length=256):
96
+ """
97
+ Resize image to target_length while keeping the aspect ratio
98
+ Expects a numpy array with shape HxWxC in uint8 format.
99
+ """
100
+ oldh, oldw = image.shape[0], image.shape[1]
101
+ scale = target_length * 1.0 / max(oldh, oldw)
102
+ newh, neww = oldh * scale, oldw * scale
103
+ neww, newh = int(neww + 0.5), int(newh + 0.5)
104
+ target_size = (neww, newh)
105
+
106
+ return cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)
107
+
108
+ def pad_image(image, target_size=256):
109
+ """
110
+ Pad image to target_size
111
+ Expects a numpy array with shape HxWxC in uint8 format.
112
+ """
113
+ # Pad
114
+ h, w = image.shape[0], image.shape[1]
115
+ padh = target_size - h
116
+ padw = target_size - w
117
+ if len(image.shape) == 3: ## Pad image
118
+ image_padded = np.pad(image, ((0, padh), (0, padw), (0, 0)))
119
+ else: ## Pad gt mask
120
+ image_padded = np.pad(image, ((0, padh), (0, padw)))
121
+
122
+ return image_padded
123
+
124
+ class MedSAM_Lite(nn.Module):
125
+ def __init__(
126
+ self,
127
+ image_encoder,
128
+ mask_decoder,
129
+ prompt_encoder
130
+ ):
131
+ super().__init__()
132
+ self.image_encoder = image_encoder
133
+ self.mask_decoder = mask_decoder
134
+ self.prompt_encoder = prompt_encoder
135
+
136
+ def forward(self, image, points, boxes, masks, features, crops, text_features, category_idx):
137
+ image_embedding = self.image_encoder(image)
138
+ with torch.no_grad():
139
+ boxes = torch.as_tensor(boxes, dtype=torch.float32, device=image.device)
140
+ if len(boxes.shape) == 2:
141
+ boxes = boxes[:, None, :] # (B, 1, 4)
142
+
143
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
144
+ points=points,
145
+ boxes=boxes,
146
+ masks=masks,
147
+ features=features,
148
+ crops=crops,
149
+ text_features=text_features,
150
+ category_idx=category_idx
151
+ )
152
+ low_res_masks, iou_predictions, category_predictions, clip_vec, img_vec = self.mask_decoder(
153
+ image_embeddings=image_embedding, # (B, 256, 64, 64)
154
+ image_pe=self.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
155
+ sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
156
+ dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
157
+ multimask_output=False,
158
+ ) # (B, 1, 256, 256)
159
+
160
+ return low_res_masks
161
+
162
+ @torch.no_grad()
163
+ def postprocess_masks(self, masks, new_size, original_size):
164
+ """
165
+ Do cropping and resizing
166
+
167
+ Parameters
168
+ ----------
169
+ masks : torch.Tensor
170
+ masks predicted by the model
171
+ new_size : tuple
172
+ the shape of the image after resizing to the longest side of 256
173
+ original_size : tuple
174
+ the original shape of the image
175
+
176
+ Returns
177
+ -------
178
+ torch.Tensor
179
+ the upsampled mask to the original size
180
+ """
181
+ # Crop
182
+ masks = masks[..., :new_size[0], :new_size[1]]
183
+ # Resize
184
+ masks = F.interpolate(
185
+ masks,
186
+ size=(original_size[0], original_size[1]),
187
+ mode="bilinear",
188
+ align_corners=False,
189
+ )
190
+
191
+ return masks
192
+
193
+
194
+ def show_mask(mask, ax, mask_color=None, alpha=0.5):
195
+ """
196
+ show mask on the image
197
+
198
+ Parameters
199
+ ----------
200
+ mask : numpy.ndarray
201
+ mask of the image
202
+ ax : matplotlib.axes.Axes
203
+ axes to plot the mask
204
+ mask_color : numpy.ndarray
205
+ color of the mask
206
+ alpha : float
207
+ transparency of the mask
208
+ """
209
+ if mask_color is not None:
210
+ color = np.concatenate([mask_color, np.array([alpha])], axis=0)
211
+ else:
212
+ color = np.array([251/255, 252/255, 30/255, alpha])
213
+ h, w = mask.shape[-2:]
214
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
215
+ ax.imshow(mask_image)
216
+
217
+
218
+ def show_box(box, ax, edgecolor='blue'):
219
+ """
220
+ show bounding box on the image
221
+
222
+ Parameters
223
+ ----------
224
+ box : numpy.ndarray
225
+ bounding box coordinates in the original image
226
+ ax : matplotlib.axes.Axes
227
+ axes to plot the bounding box
228
+ edgecolor : str
229
+ color of the bounding box
230
+ """
231
+ x0, y0 = box[0], box[1]
232
+ w, h = box[2] - box[0], box[3] - box[1]
233
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=edgecolor, facecolor=(0,0,0,0), lw=2))
234
+
235
+ def show_points(points, ax):
236
+ points = points.numpy()
237
+ for i, (x, y) in enumerate(points):
238
+ ax.scatter(x, y, color='yellow', s=15)
239
+
240
+ def get_bbox256(mask_256, bbox_shift=3):
241
+ """
242
+ Get the bounding box coordinates from the mask (256x256)
243
+
244
+ Parameters
245
+ ----------
246
+ mask_256 : numpy.ndarray
247
+ the mask of the resized image
248
+
249
+ bbox_shift : int
250
+ Add perturbation to the bounding box coordinates
251
+
252
+ Returns
253
+ -------
254
+ numpy.ndarray
255
+ bounding box coordinates in the resized image
256
+ """
257
+ y_indices, x_indices = np.where(mask_256 > 0)
258
+ x_min, x_max = np.min(x_indices), np.max(x_indices)
259
+ y_min, y_max = np.min(y_indices), np.max(y_indices)
260
+ # add perturbation to bounding box coordinates and test the robustness
261
+ # this can be removed if you do not want to test the robustness
262
+ H, W = mask_256.shape
263
+ x_min = max(0, x_min - bbox_shift)
264
+ x_max = min(W, x_max + bbox_shift)
265
+ y_min = max(0, y_min - bbox_shift)
266
+ y_max = min(H, y_max + bbox_shift)
267
+
268
+ bboxes256 = np.array([x_min, y_min, x_max, y_max])
269
+
270
+ return bboxes256
271
+
272
+ def resize_box_to_256(box, original_size):
273
+ """
274
+ the input bounding box is obtained from the original image
275
+ here, we rescale it to the coordinates of the resized image
276
+
277
+ Parameters
278
+ ----------
279
+ box : numpy.ndarray
280
+ bounding box coordinates in the original image
281
+ original_size : tuple
282
+ the original size of the image
283
+
284
+ Returns
285
+ -------
286
+ numpy.ndarray
287
+ bounding box coordinates in the resized image
288
+ """
289
+ new_box = np.zeros_like(box)
290
+ ratio = 256 / max(original_size)
291
+ for i in range(len(box)):
292
+ new_box[i] = int(box[i] * ratio)
293
+
294
+ return new_box, ratio
295
+
296
+
297
+ def get_points_256(box, gt2D):
298
+ gt2D = np.mean(gt2D, axis=-1)
299
+ if len(box)==1:
300
+ x_min, y_min, x_max, y_max = box[0]
301
+ else:
302
+ x_min, y_min, x_max, y_max = box
303
+
304
+ try:
305
+ bounder_shiftx = np.random.randint(int((x_max-x_min)/5), int(2*(x_max-x_min)/5), (1,))
306
+ # bounder_shiftx = int((x_max-x_min)/5)
307
+ except ValueError:  # box too narrow for a random shift
308
+ bounder_shiftx = 0
309
+ try:
310
+ bounder_shifty = np.random.randint(int((y_max-y_min)/5), int(2*(y_max-y_min)/5), (1,))
311
+ # bounder_shifty = int((y_max-y_min)/5)
312
+ except ValueError:  # box too short for a random shift
313
+ bounder_shifty = 0
314
+
315
+ mid_x = int((x_min+x_max)//2)
316
+ mid_y = int((y_min+y_max)//2)
317
+ x_min = int(x_min+bounder_shiftx)
318
+ x_max = int(x_max-bounder_shiftx)
319
+ y_min = int(y_min+bounder_shifty)
320
+ y_max = int(y_max-bounder_shifty)
321
+ cl = [[y_min, mid_y, x_min, mid_x], [mid_y,y_max,x_min,mid_x], [mid_y,y_max, mid_x,x_max], [y_min,mid_y, mid_x,x_max]]
322
+
323
+ coords = []
324
+ for i in range(4):
325
+ gt2D_tmp = np.zeros((256, 256))
326
+ gt2D_tmp[cl[i][0]:cl[i][1], cl[i][2]:cl[i][3]] = gt2D[cl[i][0]:cl[i][1], cl[i][2]:cl[i][3]]
327
+ y_indices, x_indices = np.where(gt2D_tmp > 0)
328
+ if y_indices.size==0:
329
+ coords.append([mid_x, mid_y])
330
+ else:
331
+ x_point = np.random.choice(x_indices)
332
+ y_point = np.random.choice(y_indices)
333
+ coords.append([x_point, y_point])
334
+ coords = np.array(coords).reshape(4, 2)
335
+ coords = torch.tensor(coords).float()
336
+ return coords
337
+
338
+ def get_points_256_v0(box, gt2D):
339
+ gt2D = np.mean(gt2D, axis=-1)
340
+ if len(box)==1:
341
+ x_min, y_min, x_max, y_max = box[0]
342
+ else:
343
+ x_min, y_min, x_max, y_max = box
344
+ mid_x = int((x_min+x_max)//2)
345
+ mid_y = int((y_min+y_max)//2)
346
+ try:
347
+ bounder_shiftx = np.random.randint(int((x_max-x_min)/3), int(2*(x_max-x_min)/4)-1, (1,))
348
+ # bounder_shiftx = 0
349
+ except ValueError:  # box too narrow for a random shift
350
+ bounder_shiftx = 0
351
+ try:
352
+ bounder_shifty = np.random.randint(int((y_max-y_min)/3), int(2*(y_max-y_min)/4)-1, (1,))
353
+ # bounder_shifty = 0
354
+ except ValueError:  # box too short for a random shift
355
+ bounder_shifty = 0
356
+ x_min = int(x_min+bounder_shiftx)
357
+ x_max = int(x_max-bounder_shiftx)
358
+ y_min = int(y_min+bounder_shifty)
359
+ y_max = int(y_max-bounder_shifty)
360
+ # cl = [[y_min, mid_y, x_min, mid_x], [mid_y,y_max,x_min,mid_x], [mid_y,y_max, mid_x,x_max], [y_min,mid_y, mid_x,x_max]]
361
+
362
+ coords = []
363
+ gt2D_tmp = np.zeros((256, 256))
364
+ gt2D_tmp[y_min:y_max, x_min:x_max] = gt2D[y_min:y_max, x_min:x_max]
365
+ for i in range(4):
366
+ y_indices, x_indices = np.where(gt2D_tmp > 0)
367
+ if y_indices.size==0:
368
+ coords.append([mid_x, mid_y])
369
+ else:
370
+ x_point = np.random.choice(x_indices)
371
+ y_point = np.random.choice(y_indices)
372
+ coords.append([x_point, y_point])
373
+ coords = np.array(coords).reshape(4, 2)
374
+ coords = torch.tensor(coords).float()
375
+ return coords
376
+
377
+ @torch.no_grad()
378
+ def medsam_inference(medsam_model, img_embed, box_256, features, crops, text_features, category_idx, new_size, original_size):
379
+ """
380
+ Perform inference using the LiteMedSAM model.
381
+
382
+ Args:
383
+ medsam_model (MedSAMModel): The MedSAM model.
384
+ img_embed (torch.Tensor): The image embeddings.
385
+ box_256 (numpy.ndarray): The bounding box coordinates.
386
+ new_size (tuple): The new size of the image.
387
+ original_size (tuple): The original size of the image.
388
+ Returns:
389
+ tuple: A tuple containing the segmented image and the intersection over union (IoU) score.
390
+ """
391
+ box_torch = torch.as_tensor(box_256[None, None, ...], dtype=torch.float, device=img_embed.device)
392
+ features = features.unsqueeze(0).to(device)
393
+ crops = crops.unsqueeze(0).to(device)
394
+ category_idx = torch.tensor([category_idx]).to(device)
395
+ sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
396
+ points=None,
397
+ boxes=box_torch,
398
+ masks=None,
399
+ features=features,
400
+ crops=crops,
401
+ text_features=text_features,
402
+ category_idx=category_idx
403
+ )
404
+
405
+ low_res_logits, iou, _, _, _ = medsam_model.mask_decoder(
406
+ image_embeddings=img_embed, # (B, 256, 64, 64)
407
+ image_pe=medsam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
408
+ sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
409
+ dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
410
+ multimask_output=False
411
+ )
412
+
413
+ low_res_pred = medsam_model.postprocess_masks(low_res_logits, new_size, original_size)
414
+ low_res_pred = torch.sigmoid(low_res_pred)
415
+ low_res_pred = low_res_pred.squeeze().cpu().numpy()
416
+ medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
417
+ return medsam_seg, iou
418
+
419
+ medsam_lite_image_encoder = TinyViT(
420
+ img_size=256,
421
+ in_chans=3,
422
+ embed_dims=[
423
+ 64, ## (64, 256, 256)
424
+ 128, ## (128, 128, 128)
425
+ 160, ## (160, 64, 64)
426
+ 320 ## (320, 64, 64)
427
+ ],
428
+ depths=[2, 2, 6, 2],
429
+ num_heads=[2, 4, 5, 10],
430
+ window_sizes=[7, 7, 14, 7],
431
+ mlp_ratio=4.,
432
+ drop_rate=0.,
433
+ drop_path_rate=0.0,
434
+ use_checkpoint=False,
435
+ mbconv_expand_ratio=4.0,
436
+ local_conv_size=3,
437
+ layer_lr_decay=0.8
438
+ )
439
+
440
+ medsam_lite_prompt_encoder = PromptEncoder(
441
+ embed_dim=256,
442
+ image_embedding_size=(64, 64),
443
+ input_image_size=(256, 256),
444
+ mask_in_chans=16
445
+ )
446
+
447
+ medsam_lite_mask_decoder = MaskDecoder_F4(
448
+ num_multimask_outputs=3,
449
+ transformer=TwoWayTransformer(
450
+ depth=2,
451
+ embedding_dim=256,
452
+ mlp_dim=2048,
453
+ num_heads=8,
454
+ ),
455
+ modality=True,
456
+ contents=True,
457
+ transformer_dim=256,
458
+ iou_head_depth=3,
459
+ iou_head_hidden_dim=256,
460
+ )
461
+
462
+
463
+ medsam_lite_model = MedSAM_Lite(
464
+ image_encoder = medsam_lite_image_encoder,
465
+ mask_decoder = medsam_lite_mask_decoder,
466
+ prompt_encoder = medsam_lite_prompt_encoder
467
+ )
468
+
469
+ lite_medsam_checkpoint = torch.load(lite_medsam_checkpoint_path, map_location='cpu')
470
+ medsam_lite_model.load_state_dict(lite_medsam_checkpoint["model"])
471
+ medsam_lite_model.to(device)
472
+ medsam_lite_model.eval()
473
+
474
+
475
+ def m2_pre_img(image_data, image_size=224):
476
+ transform1 = transforms.Compose([
477
+ transforms.ToTensor(), # normalize to [0.0,1.0]
478
+ transforms.Resize([image_size, image_size], interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)
479
+ ]
480
+ )
481
+
482
+ resize_img_torch = transform1(image_data)
483
+ return resize_img_torch
484
+
485
+ def get_contents(img, box):
486
+ if len(box)==1:
487
+ x_mino, y_mino, x_maxo, y_maxo = box[0]
488
+ else:
489
+ x_mino, y_mino, x_maxo, y_maxo = box
490
+ crops = img[y_mino:y_maxo,x_mino:x_maxo,:]
491
+ crops_128 = m2_pre_img(crops, image_size=64)
492
+ crops_224 = m2_pre_img(crops)
493
+ crops_224 = crops_224.unsqueeze(0)
494
+ with torch.no_grad():
495
+ image_features = model1.get_image_features(crops_224)
496
+ return crops_128, image_features
497
+
498
+ def get_text_features(modality_text):
499
+
500
+ text_token = tokenizer(modality_text, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
501
+ with torch.no_grad():
502
+ text_features = model1.get_text_features(text_token)
503
+ return text_features
504
+
505
+
506
+ def get_category(idx):
507
+ categories_map = {
508
+ "CT": 0,
509
+ "MR": 1,
510
+ "Endoscopy": 2,
511
+ "XRay": 3,
512
+ "X-Ray": 3,
513
+ "PET": 4,
514
+ "Dermoscopy": 5,
515
+ "Mammography": 6,
516
+ "Mammo": 6,
517
+ "US": 7,
518
+ "OCT": 8,
519
+ "Fundus": 9,
520
+ "Microscopy": 10,
521
+ "Microscope": 10
522
+ }
523
+ return categories_map[idx]
524
+
525
+ def change_name(name):
526
+ if name=="Microscope":
527
+ name = "Microscopy"
528
+ return name
529
+
530
+ def MedSAM_infer_npz_2D(img_npz_file):
531
+ npz_name = basename(img_npz_file)
532
+ c_name = change_name(npz_name.split('_')[1])
533
+ modality_text = f"{c_name} Image"
534
+ category_idx = get_category(c_name)
535
+ npz_data = np.load(img_npz_file, 'r', allow_pickle=True) # (H, W, 3)
536
+ img_3c = npz_data['imgs'] # (H, W, 3)
537
+ assert np.max(img_3c)<256, f'input data should be in range [0, 255], but got {np.unique(img_3c)}'
538
+ H, W = img_3c.shape[:2]
539
+ boxes = npz_data['boxes']
540
+ segs = np.zeros(img_3c.shape[:2], dtype=np.uint8)
541
+ text_features = get_text_features(modality_text)
542
+ text_features = text_features.unsqueeze(0).to(device)  # already a tensor; avoid torch.tensor() copy warning
543
+
544
+ ## preprocessing
545
+ img_256 = resize_longest_side(img_3c, 256)
546
+ newh, neww = img_256.shape[:2]
547
+ img_256_norm = (img_256 - img_256.min()) / np.clip(
548
+ img_256.max() - img_256.min(), a_min=1e-8, a_max=None
549
+ )
550
+ img_256_padded = pad_image(img_256_norm, 256)
551
+ img_256_tensor = torch.tensor(img_256_padded).float().permute(2, 0, 1).unsqueeze(0).to(device)
552
+ with torch.no_grad():
553
+ image_embedding = medsam_lite_model.image_encoder(img_256_tensor)
554
+
555
+ for idx, box in enumerate(boxes, start=1):
556
+ crops, features = get_contents(img_3c, box)
557
+ box256, ratio = resize_box_to_256(box, original_size=(H, W))
558
+ box256 = box256[None, ...] # (1, 4)
559
+ medsam_mask, iou_pred = medsam_inference(medsam_lite_model, image_embedding, box256, features, crops, text_features, category_idx, (newh, neww), (H, W))
560
+ segs[medsam_mask>0] = idx%256
561
+ # print(f'{npz_name}, box: {box}, predicted iou: {np.round(iou_pred.item(), 4)}')
562
+
563
+ np.savez_compressed(
564
+ join(pred_save_dir, npz_name),
565
+ segs=segs,
566
+ )
567
+
568
+ # visualize image, mask and bounding box
569
+ if save_overlay and "Microscope" not in npz_name:
570
+ fig, ax = plt.subplots(1, 2, figsize=(10, 5))
571
+ ax[0].imshow(img_3c)
572
+ ax[1].imshow(img_3c)
573
+ ax[0].set_title("Image")
574
+ ax[1].set_title("LiteMedSAM Segmentation")
575
+ ax[0].axis('off')
576
+ ax[1].axis('off')
577
+
578
+ for i, box in enumerate(boxes):
579
+ color = np.random.rand(3)
580
+ box_viz = box
581
+ show_box(box_viz, ax[1], edgecolor=color)
582
+ # show_points(points[i], ax[1])
583
+ show_mask((segs == i+1).astype(np.uint8), ax[1], mask_color=color)
584
+
585
+ plt.tight_layout()
586
+ plt.savefig(join(png_save_dir, npz_name.split(".")[0] + '.png'), dpi=300)
587
+ plt.close()
588
+
589
+
590
+ def MedSAM_infer_npz_3D(img_npz_file):
591
+ npz_name = basename(img_npz_file)
592
+ c_name = change_name(npz_name.split('_')[1])
593
+ modality_text = f"{c_name} Image"
594
+ category_idx = get_category(c_name)
595
+ npz_data = np.load(img_npz_file, 'r', allow_pickle=True)
596
+ img_3D = npz_data['imgs'] # (D, H, W)
597
+ # not used in this demo because it treats each slice independently
598
+ # spacing = npz_data['spacing']
599
+ segs = np.zeros_like(img_3D, dtype=np.uint8)
600
+ boxes_3D = npz_data['boxes'] # [[x_min, y_min, z_min, x_max, y_max, z_max]]
601
+ text_features = get_text_features(modality_text)
602
+ text_features = text_features.unsqueeze(0).to(device)  # already a tensor; avoid torch.tensor() copy warning
603
+
604
+ for idx, box3D in enumerate(boxes_3D, start=1):
605
+ segs_3d_temp = np.zeros_like(img_3D, dtype=np.uint8)
606
+ x_min, y_min, z_min, x_max, y_max, z_max = box3D
607
+ assert z_min < z_max, f"z_min should be smaller than z_max, but got {z_min=} and {z_max=}"
608
+ mid_slice_bbox_2d = np.array([x_min, y_min, x_max, y_max])
609
+ z_middle = int((z_max - z_min)/2 + z_min)
610
+
611
+ # infer from the middle slice up to z_max
612
+ # print(npz_name, 'infer from middle slice to the z_max')
613
+ for z in range(z_middle, z_max):
614
+ img_2d = img_3D[z, :, :]
615
+ if len(img_2d.shape) == 2:
616
+ img_3c = np.repeat(img_2d[:, :, None], 3, axis=-1)
617
+ else:
618
+ img_3c = img_2d
619
+ H, W, _ = img_3c.shape
620
+
621
+ img_256 = resize_longest_side(img_3c, 256)
622
+ new_H, new_W = img_256.shape[:2]
623
+
624
+ img_256 = (img_256 - img_256.min()) / np.clip(
625
+ img_256.max() - img_256.min(), a_min=1e-8, a_max=None
626
+ ) # normalize to [0, 1], (H, W, 3)
627
+ ## Pad image to 256x256
628
+ img_256 = pad_image(img_256)
629
+
630
+ # convert the shape to (3, H, W)
631
+ img_256_tensor = torch.tensor(img_256).float().permute(2, 0, 1).unsqueeze(0).to(device)
632
+ # get the image embedding
633
+ with torch.no_grad():
634
+ image_embedding = medsam_lite_model.image_encoder(img_256_tensor) # (1, 256, 64, 64)
635
+ if z == z_middle:
636
+ crops, features = get_contents(img_3c, mid_slice_bbox_2d)
637
+ box_256, _ = resize_box_to_256(mid_slice_bbox_2d, original_size=(H, W))
638
+ else:
639
+ pre_seg = segs_3d_temp[z-1, :, :]
640
+ if np.max(pre_seg) > 0:
641
+ box_original = get_bbox256(pre_seg)
642
+ crops, features = get_contents(img_3c, box_original)
643
+ pre_seg256 = resize_longest_side(pre_seg)
644
+ pre_seg256 = pad_image(pre_seg256)
645
+ box_256 = get_bbox256(pre_seg256)
646
+ else:
647
+ crops, features = get_contents(img_3c, mid_slice_bbox_2d)
648
+ box_256, _ = resize_box_to_256(mid_slice_bbox_2d, original_size=(H, W))
649
+ img_2d_seg, iou_pred = medsam_inference(medsam_lite_model, image_embedding, box_256, features, crops, text_features, category_idx, [new_H, new_W], [H, W])
650
+ segs_3d_temp[z, img_2d_seg>0] = idx
651
+
652
+ # infer from middle slice to the z_max
653
+ # print(npz_name, 'infer from middle slice to the z_min')
654
+ for z in range(z_middle-1, z_min, -1):
655
+ img_2d = img_3D[z, :, :]
656
+ if len(img_2d.shape) == 2:
657
+ img_3c = np.repeat(img_2d[:, :, None], 3, axis=-1)
658
+ else:
659
+ img_3c = img_2d
660
+ H, W, _ = img_3c.shape
661
+
662
+ img_256 = resize_longest_side(img_3c)
663
+ new_H, new_W = img_256.shape[:2]
664
+
665
+ img_256 = (img_256 - img_256.min()) / np.clip(
666
+ img_256.max() - img_256.min(), a_min=1e-8, a_max=None
667
+ ) # normalize to [0, 1], (H, W, 3)
668
+ ## Pad image to 256x256
669
+ img_256 = pad_image(img_256)
670
+
671
+ img_256_tensor = torch.tensor(img_256).float().permute(2, 0, 1).unsqueeze(0).to(device)
672
+ # get the image embedding
673
+ with torch.no_grad():
674
+ image_embedding = medsam_lite_model.image_encoder(img_256_tensor) # (1, 256, 64, 64)
675
+
676
+ pre_seg = segs_3d_temp[z+1, :, :]
677
+ # pre_seg = segs[z+1, :, :]
678
+ if np.max(pre_seg) > 0:
679
+ box_original = get_bbox256(pre_seg)
680
+ crops, features = get_contents(img_3c, box_original)
681
+ pre_seg256 = resize_longest_side(pre_seg)
682
+ pre_seg256 = pad_image(pre_seg256)
683
+ box_256 = get_bbox256(pre_seg256)
684
+ else:
685
+ crops, features = get_contents(img_3c, mid_slice_bbox_2d)
686
+ scale_256 = 256 / max(H, W)
687
+ box_256 = mid_slice_bbox_2d * scale_256
688
+ img_2d_seg, iou_pred = medsam_inference(medsam_lite_model, image_embedding, box_256, features, crops, text_features, category_idx, [new_H, new_W], [H, W])
689
+ segs_3d_temp[z, img_2d_seg>0] = idx
690
+ segs[segs_3d_temp>0] = idx
691
+ np.savez_compressed(
692
+ join(pred_save_dir, npz_name),
693
+ segs=segs,
694
+ )
695
+
696
+ # visualize image, mask and bounding box
697
+ if save_overlay and "Microscope" not in npz_name:
698
+ idx = int(segs.shape[0] / 2)
699
+ fig, ax = plt.subplots(1, 2, figsize=(10, 5))
700
+ ax[0].imshow(img_3D[idx], cmap='gray')
701
+ ax[1].imshow(img_3D[idx], cmap='gray')
702
+ ax[0].set_title("Image")
703
+ ax[1].set_title("LiteMedSAM Segmentation")
704
+ ax[0].axis('off')
705
+ ax[1].axis('off')
706
+
707
+ for i, box3D in enumerate(boxes_3D, start=1):
708
+ if np.sum(segs[idx]==i) > 0:
709
+ color = np.random.rand(3)
710
+ x_min, y_min, z_min, x_max, y_max, z_max = box3D
711
+ box_viz = np.array([x_min, y_min, x_max, y_max])
712
+ show_box(box_viz, ax[1], edgecolor=color)
713
+ show_mask(segs[idx]==i, ax[1], mask_color=color)
714
+
715
+ plt.tight_layout()
716
+ plt.savefig(join(png_save_dir, npz_name.split(".")[0] + '.png'), dpi=300)
717
+ plt.close()
718
+
719
+
720
+ if __name__ == '__main__':
721
+
722
+ img_npz_files = sorted(glob(join(data_root, '*.npz'), recursive=True))
723
+ efficiency = OrderedDict()
724
+ efficiency['case'] = []
725
+ efficiency['time'] = []
726
+ for img_npz_file in tqdm(img_npz_files):
727
+ start_time = time()
728
+ if basename(img_npz_file).startswith('3D'):
729
+ MedSAM_infer_npz_3D(img_npz_file)
730
+ else:
731
+ MedSAM_infer_npz_2D(img_npz_file)
732
+ end_time = time()
733
+ efficiency['case'].append(basename(img_npz_file))
734
+ efficiency['time'].append(end_time - start_time)
735
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
736
+ # print(current_time, 'file name:', basename(img_npz_file), 'time cost:', np.round(end_time - start_time, 4))
737
+ efficiency_df = pd.DataFrame(efficiency)
738
+ efficiency_df.to_csv(join(pred_save_dir, 'efficiency.csv'), index=False)
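
A self-contained sketch (not part of the commit; demo inputs are made up) of the preprocessing contract used throughout `infer.py` above — resize the longest side to 256 with `INTER_AREA`, zero-pad the short side, and rescale box prompts by the same ratio, mirroring `resize_longest_side`, `pad_image`, and `resize_box_to_256`:

```python
import numpy as np
import cv2

def demo_preprocess(img, box, target=256):
    h, w = img.shape[:2]
    scale = target / max(h, w)
    newh, neww = int(h * scale + 0.5), int(w * scale + 0.5)
    resized = cv2.resize(img, (neww, newh), interpolation=cv2.INTER_AREA)
    # Pad bottom/right so the canvas is target x target, as pad_image does
    padded = np.pad(resized, ((0, target - newh), (0, target - neww), (0, 0)))
    box256 = (np.asarray(box) * scale).astype(int)  # boxes follow the same ratio
    return padded, box256, (newh, neww)

img = np.random.randint(0, 256, (400, 300, 3), dtype=np.uint8)
padded, box256, new_size = demo_preprocess(img, [50, 60, 200, 220])
assert padded.shape == (256, 256, 3) and new_size == (256, 192)
```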
modality_npz_dataset.py ADDED
@@ -0,0 +1,317 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import os
4
+ from torchvision import transforms
5
+ from torch.utils.data import Dataset
6
+ import torch
7
+ import cv2
8
+ from transformers import CLIPModel, CLIPTokenizer
9
+ from os.path import join, exists, isfile, isdir, basename
10
+ import random
11
+
12
+ join = os.path.join
13
+ import json
14
+
15
+
16
+ def reshape_MR(img):
17
+
18
+ original_shape = img.shape
19
+ sorted_axes = np.argsort(original_shape)
20
+ new_img = img.transpose(sorted_axes)
21
+
22
+ return new_img
23
+
24
+ class ModalityNpzDataset(Dataset):
25
+ def __init__(self,
26
+ data_root,
27
+ points=True,
28
+ contents=True,
29
+ image_size=256,
30
+ bbox_shift=5,
31
+ data_aug=True):
32
+
33
+ self.data_root = data_root
34
+
35
+
36
+ json_data = json.load(open("case_data.json", "r"))
37
+ self.file_paths = json_data
38
+
39
+ assert len(self.file_paths) == 11
40
+
41
+ self.image_size = image_size
42
+ self.target_length = image_size
43
+ self.bbox_shift = bbox_shift
44
+ self.data_aug = data_aug
45
+ self.points = points
46
+ self.contents = contents
47
+
48
+ self.categories_map = {
49
+ "CT": 0,
50
+ "MR": 1,
51
+ "Endoscopy": 2,
52
+ "XRay": 3,
53
+ "X-Ray": 3,
54
+ "PET": 4,
55
+ "Dermoscopy": 5,
56
+ "Mammography": 6,
57
+ "Mammo": 6,
58
+ "US": 7,
59
+ "OCT": 8,
60
+ "Fundus": 9,
61
+ "Microscopy": 10,
62
+ "Microscope": 10
63
+ }
64
+
65
+ self.model1 = CLIPModel.from_pretrained("flaviagiammarino/pubmed-clip-vit-base-patch32")
66
+ self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16")
67
+ self.model1.requires_grad_(False)
68
+
69
+
70
+
71
+ def show_box(self, box, ax):
72
+ x0, y0 = box[0], box[1]
73
+ w, h = box[2] - box[0], box[3] - box[1]
74
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2))
75
+
76
+ def vis(self, image, bboxes, title):
77
+ _, axs = plt.subplots(1, 2, figsize=(10, 10))
78
+
79
+ axs[0].imshow(image, cmap="gray")
80
+ self.show_box(bboxes, axs[0])
81
+ axs[0].axis('off')
82
+ axs[0].set_title(title)
83
+
84
+ plt.subplots_adjust(wspace=0.01, hspace=0)
85
+ plt.savefig(
86
+ "test.png",
87
+ bbox_inches='tight',
88
+ dpi=300
89
+ )
90
+ plt.close()
91
+
92
+ def vis_crop(self, image, title):
93
+
94
+ plt.imshow(np.transpose(image, (1,2,0)))
95
+ plt.axis('off')
96
+ plt.title(title)
97
+
98
+ plt.savefig(
99
+ "test.png",
100
+ bbox_inches='tight',
101
+ dpi=300
102
+ )
103
+ plt.close()
104
+
105
+ def __getitem__(self, index):
106
+ #! add the random index
107
+
108
+ modality_map = [
109
+ "CT",
110
+ "MR",
111
+ "Endoscopy",
112
+ "X-ray",
113
+ "PET",
114
+ "Dermoscopy",
115
+ "Mammography",
116
+ "US",
117
+ "OCT",
118
+ "Fundus",
119
+ "Microscopy"
120
+ ]
121
+ modality_index = random.randint(0, 10)
122
+ index = random.randint(0, len(self.file_paths[modality_map[modality_index]])-1)
123
+ file_path = self.file_paths[modality_map[modality_index]][index][0]
124
+ temp = '/'.join(file_path.split('/')[7:])  # drop the absolute prefix recorded in case_data.json, then re-root under data_root
125
+ file_path = self.data_root+'/'+temp
126
+
127
+
128
+ npz = np.load(file_path, 'r', allow_pickle=True)
129
+ img_name = basename(file_path)
130
+
131
+ mt = img_name.split("_")[0]
132
+ if mt=="2D" or mt=="3D":
133
+ mt = img_name.split("_")[1]
134
+ category_text = f"{mt} Image"
135
+ category_idx = self.categories_map[mt]
136
+ gts = npz["gts"]
137
+ img = npz["imgs"]
138
+
139
+ # special case for MR_totalseg
140
+ if "MR_totalseg" in img_name:
141
+ img = reshape_MR(img)
142
+ gts = reshape_MR(gts)
143
+ if img.shape[1] <=100:
144
+ return self.__getitem__(random.randint(0,len(self)-1))
145
+
146
+ if len(gts.shape) > 2: ## 3D image
147
+ i=random.randint(0,gts.shape[0]-1)
148
+ img = img[i, :, :]
149
+ gts = gts[i, :, :]
150
+ img_3c = np.repeat(img[:, :, None], 3, axis=-1) # (H, W, 3)
151
+ img_resized = self.resize_longest_side(img_3c)
152
+ else:
153
+ if len(img.shape) < 3:
154
+ img_3c = np.repeat(img[:, :, None], 3, axis=-1)
155
+ else:
156
+ img_3c = img
157
+ img_resized = self.resize_longest_side(img_3c)
158
+ gts = np.uint16(gts)
159
+
160
+ # Resizing
161
+ img_resized = (img_resized - img_resized.min()) / np.clip(img_resized.max() - img_resized.min(), a_min=1e-8, a_max=None) # normalize to [0, 1], (H, W, 3)
162
+ img_padded = self.pad_image(img_resized) # (256, 256, 3)
163
+ # convert the shape to (3, H, W)
164
+ img_padded = np.transpose(img_padded, (2, 0, 1)) # (3, 256, 256)
165
+ assert np.max(img_padded)<=1.0 and np.min(img_padded)>=0.0, 'image should be normalized to [0, 1]'
166
+
167
+ label_ids = np.unique(gts)
168
+ label_ids = label_ids.tolist()
169
+
170
+ try:
171
+ label_ids.remove(0)
172
+ label_id = random.choice(label_ids)
173
+ gt2D_original = np.uint8(gts == label_id)
174
+ gt = cv2.resize(
175
+ gt2D_original,
176
+ (img_resized.shape[1], img_resized.shape[0]),
177
+ interpolation=cv2.INTER_NEAREST
178
+ ).astype(np.uint8)
179
+ gt2D = self.pad_image(gt)
180
+
181
+ except Exception:  # e.g. no foreground label to sample; resample another case
182
+ return self.__getitem__(random.randint(0,len(self)-1))
183
+
184
+
185
+ box_original = self.get_bbox(gt2D_original)
186
+ x_mino, y_mino, x_maxo, y_maxo = box_original
187
+
188
+ if self.data_aug:
189
+ if random.random() > 0.5:
190
+ img_padded = np.ascontiguousarray(np.flip(img_padded, axis=-1))
191
+ gt2D = np.ascontiguousarray(np.flip(gt2D, axis=-1))
192
+ if random.random() > 0.5:
193
+ img_padded = np.ascontiguousarray(np.flip(img_padded, axis=-2))
194
+ gt2D = np.ascontiguousarray(np.flip(gt2D, axis=-2))
195
+
196
+ try:
197
+ gt2D = np.uint8(gt2D > 0)
198
+ y_indices, x_indices = np.where(gt2D > 0)
199
+ x_min, x_max = np.min(x_indices), np.max(x_indices)
200
+ y_min, y_max = np.min(y_indices), np.max(y_indices)
201
+ H, W = gt2D.shape
202
+ x_min = max(0, x_min - random.randint(0, self.bbox_shift))
203
+ x_max = min(W, x_max + random.randint(0, self.bbox_shift))
204
+ y_min = max(0, y_min - random.randint(0, self.bbox_shift))
205
+ y_max = min(H, y_max + random.randint(0, self.bbox_shift))
206
+ bboxes = np.array([x_min, y_min, x_max, y_max])
207
+ except ValueError:  # mask became empty after augmentation; resample another case
208
+ return self.__getitem__(random.randint(0,len(self)-1))
209
+
210
+ if self.points:
211
+ mid_x = (x_min+x_max)//2
212
+ mid_y = (y_min+y_max)//2
213
+ cl = [[y_min, mid_y, x_min, mid_x], [mid_y,y_max,x_min,mid_x], [mid_y,y_max, mid_x,x_max], [y_min,mid_y, mid_x,x_max]]
214
+ coords = []
215
+ for i in range(4):
216
+ gt2D_tmp = np.zeros((H, W))
217
+ gt2D_tmp[cl[i][0]:cl[i][1], cl[i][2]:cl[i][3]] = gt2D[cl[i][0]:cl[i][1], cl[i][2]:cl[i][3]]
218
+ y_indices, x_indices = np.where(gt2D_tmp > 0)
219
+ if y_indices.size==0:
220
+ coords.append([mid_x, mid_y])
221
+ else:
222
+ x_point = np.random.choice(x_indices)
223
+ y_point = np.random.choice(y_indices)
224
+ coords.append([x_point, y_point])
225
+ coords = np.array(coords).reshape(4, 2)
226
+ coords = torch.tensor(coords).float()
227
+ else:
228
+ coords = None
229
+
230
+ if self.contents:
231
+ try:
232
+ crops = img_3c[y_mino:y_maxo,x_mino:x_maxo,:]
233
+ crops_64 = self.m2_pre_img(crops, image_size=64) # change here for the size of cropped part
234
+ crops_224 = self.m2_pre_img(crops)
235
+ except Exception:  # degenerate/empty crop; fall back to zero tensors
236
+ crops_64 = torch.zeros((3, 64, 64))
237
+ crops_224 = torch.zeros((3, 224, 224))
238
+ crops_224 = crops_224.unsqueeze(0)
239
+ text_token = self.tokenizer(category_text, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
240
+ with torch.no_grad():
241
+ image_features = self.model1.get_image_features(crops_224)
242
+ text_features = self.model1.get_text_features(text_token)
243
+ else:
244
+ crops_64 = None
245
+ image_features = None
246
+ text_features = None
247
+
248
+
249
+ return {
250
+ "image": torch.tensor(img_padded).float(),
251
+ "gt2D": torch.tensor(gt2D[None, :,:]).long(),
252
+ "coords": coords,
253
+ "bboxes": torch.tensor(bboxes[None, None, ...]).float(),
254
+ "image_crop": crops_64.float(),
255
+ "image_feature": image_features.float(),
256
+ "text_feature": text_features.float(),
257
+ "category_idx": category_idx,
258
+ "image_name": img_name,
259
+ "new_size": torch.tensor(np.array([img_padded.shape[0], img_padded.shape[1]])).long(),
260
+ "original_size": torch.tensor(np.array([img_3c.shape[0], img_3c.shape[1]])).long()
261
+ }
262
+
263
+ def __len__(self):
264
+ return 108714  # nominal epoch length; __getitem__ samples a random case regardless of index
265
+
266
+ def get_bbox(self, mask_256, bbox_shift=5):
267
+ y_indices, x_indices = np.where(mask_256 > 0)
268
+ x_min, x_max = np.min(x_indices), np.max(x_indices)
269
+ y_min, y_max = np.min(y_indices), np.max(y_indices)
270
+ H, W = mask_256.shape
271
+ x_min = max(0, x_min - random.randint(0, bbox_shift))
272
+ x_max = min(W, x_max + random.randint(0, bbox_shift))
273
+ y_min = max(0, y_min - random.randint(0, bbox_shift))
274
+ y_max = min(H, y_max + random.randint(0, bbox_shift))
275
+
276
+ bboxes256 = np.array([x_min, y_min, x_max, y_max])
277
+
278
+ return bboxes256
279
+
280
+ def m2_pre_img(self, image_data, image_size=224):
281
+ transform1 = transforms.Compose([
282
+ transforms.ToTensor(), # normalize to [0.0,1.0]
283
+ transforms.Resize([image_size, image_size], interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)
284
+ ]
285
+ )
286
+
287
+ resize_img_torch = transform1(image_data)
288
+ return resize_img_torch
289
+
290
+ def resize_longest_side(self, image):
291
+ """
292
+ Expects a numpy array with shape HxWxC in uint8 format.
293
+ """
294
+ long_side_length = self.target_length
295
+ oldh, oldw = image.shape[0], image.shape[1]
296
+ scale = long_side_length * 1.0 / max(oldh, oldw)
297
+ newh, neww = oldh * scale, oldw * scale
298
+ neww, newh = int(neww + 0.5), int(newh + 0.5)
299
+ target_size = (neww, newh)
300
+
301
+ return cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)
302
+
303
+ def pad_image(self, image):
304
+ """
305
+ Expects a numpy array with shape HxWxC in uint8 format.
306
+ """
307
+ # Pad
308
+ h, w = image.shape[0], image.shape[1]
309
+ padh = self.image_size - h
310
+ padw = self.image_size - w
311
+ if len(image.shape) == 3: ## Pad image
312
+ image_padded = np.pad(image, ((0, padh), (0, padw), (0, 0)))
313
+ else: ## Pad gt mask
314
+ image_padded = np.pad(image, ((0, padh), (0, padw)))
315
+
316
+ return image_padded
317
+
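
A minimal usage sketch for the dataset above (the data root and `case_data.json` contents are assumptions about the local setup). Note that `__getitem__` ignores the incoming index and samples a modality uniformly at random, so the modality-aware balancing is built into the dataset itself:

```python
from torch.utils.data import DataLoader
from modality_npz_dataset import ModalityNpzDataset

dataset = ModalityNpzDataset(data_root="/path/to/npz_root", image_size=256)
loader = DataLoader(dataset, batch_size=8, num_workers=4)

batch = next(iter(loader))
print(batch["image"].shape)   # torch.Size([8, 3, 256, 256])
print(batch["bboxes"].shape)  # torch.Size([8, 1, 1, 4])
```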
models/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .mask_decoder import MaskDecoder, MaskDecoder_F4
2
+ from .prompt_encoder import PromptEncoder
3
+ from .transformer import TwoWayTransformer
4
+ from .tiny_vit import TinyViT
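
These re-exports are what let `infer.py` pull all model components from the package root in one line:

```python
from models import PromptEncoder, TwoWayTransformer, TinyViT, MaskDecoder_F4
```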
models/common.py ADDED
@@ -0,0 +1,44 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from typing import Type
12
+
13
+
14
+ class MLPBlock(nn.Module):
15
+ def __init__(
16
+ self,
17
+ embedding_dim: int,
18
+ mlp_dim: int,
19
+ act: Type[nn.Module] = nn.GELU,
20
+ ) -> None:
21
+ super().__init__()
22
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
23
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
24
+ self.act = act()
25
+
26
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
27
+ return self.lin2(self.act(self.lin1(x)))
28
+
29
+
30
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
31
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
32
+ class LayerNorm2d(nn.Module):
33
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
34
+ super().__init__()
35
+ self.weight = nn.Parameter(torch.ones(num_channels))
36
+ self.bias = nn.Parameter(torch.zeros(num_channels))
37
+ self.eps = eps
38
+
39
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
40
+ u = x.mean(1, keepdim=True)
41
+ s = (x - u).pow(2).mean(1, keepdim=True)
42
+ x = (x - u) / torch.sqrt(s + self.eps)
43
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
44
+ return x
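
A quick sanity check (not part of the commit) that `LayerNorm2d` normalizes over the channel dimension of an NCHW tensor; with default affine parameters it matches `nn.LayerNorm` applied channels-last:

```python
import torch
from models.common import LayerNorm2d

x = torch.randn(2, 8, 4, 4)
ln2d = LayerNorm2d(8, eps=1e-6)
ref = torch.nn.LayerNorm(8, eps=1e-6)(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(ln2d(x), ref, atol=1e-5))  # True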
models/lite_medsam.py ADDED
@@ -0,0 +1,54 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .mask_decoder import MaskDecoder
6
+ from .prompt_encoder import PromptEncoder
7
+ from .transformer import TwoWayTransformer
8
+
9
+ class MedSAM_Lite(nn.Module):
10
+ def __init__(self,
11
+ image_encoder,
12
+ mask_decoder,
13
+ prompt_encoder
14
+ ):
15
+ super().__init__()
16
+ self.image_encoder = image_encoder
17
+ self.mask_decoder = mask_decoder
18
+ self.prompt_encoder = prompt_encoder
19
+
20
+ def forward(self, image, boxes):
21
+ image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
22
+
23
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
24
+ points=None,
25
+ boxes=boxes,
26
+ masks=None,
27
+ )
28
+ low_res_masks, iou_predictions = self.mask_decoder(
29
+ image_embeddings=image_embedding, # (B, 256, 64, 64)
30
+ image_pe=self.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
31
+ sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
32
+ dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
33
+ multimask_output=False,
34
+ ) # (B, 1, 256, 256)
35
+
36
+ return low_res_masks, iou_predictions
37
+
38
+ @torch.no_grad()
39
+ def postprocess_masks(self, masks, new_size, original_size):
40
+ """
41
+ Do cropping and resizing
42
+ """
43
+ # Crop
44
+ masks = masks[:, :, :new_size[0], :new_size[1]]
45
+ # Resize
46
+ masks = F.interpolate(
47
+ masks,
48
+ size=(original_size[0], original_size[1]),
49
+ mode="bilinear",
50
+ align_corners=False,
51
+ )
52
+
53
+ return masks
54
+
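
A small sketch of the `postprocess_masks` contract shared by this class and `infer.py`: crop away the padding introduced by `pad_image`, then bilinearly resize back to the original resolution (shapes here are illustrative):

```python
import torch
import torch.nn.functional as F

low_res_logits = torch.randn(1, 1, 256, 256)  # decoder output on the padded canvas
new_size, original_size = (256, 192), (400, 300)

cropped = low_res_logits[:, :, :new_size[0], :new_size[1]]
restored = F.interpolate(cropped, size=original_size, mode="bilinear", align_corners=False)
print(restored.shape)  # torch.Size([1, 1, 400, 300])
```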
models/mask_decoder.py ADDED
@@ -0,0 +1,465 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+
12
+ from typing import List, Tuple, Type
13
+
14
+ from .common import LayerNorm2d
15
+ from .transformer import TwoWayTransformer
16
+
17
+ class Classifier(nn.Module):
18
+ def __init__(self, in_dim, hid_dim=None, out_dim=None, act=nn.GELU, drop=0.):
19
+ super().__init__()
20
+ out_dim = out_dim or in_dim
21
+ hid_dim = hid_dim or in_dim
22
+ self.fc1 = nn.Linear(in_dim, hid_dim)
23
+ self.act = act()
24
+ self.fc2 = nn.Linear(hid_dim, out_dim)
25
+ self.drop = nn.Dropout(drop)
26
+
27
+ def forward(self, x):
28
+ x = self.fc1(x)
29
+ x = self.act(x)
30
+ x = self.drop(x)
31
+ x = self.fc2(x)
32
+ return x
33
+
34
+ class Block(nn.Module):
35
+ def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
36
+ super(Block, self).__init__()
37
+
38
+
39
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
40
+ self.batch_norm1 = nn.BatchNorm2d(out_channels)
41
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
42
+
43
+ self.i_downsample = i_downsample
44
+ self.stride = stride
45
+ self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
46
+
47
+ def forward(self, x):
48
+ identity = x.clone()
49
+
50
+ x = self.relu(self.batch_norm1(self.conv1(x)))
51
+ x = self.conv2(x)
52
+
53
+ if self.i_downsample is not None:
54
+ identity = self.i_downsample(identity)
55
+
56
+ x += identity
57
+ return x
58
+
59
+ class MaskDecoder(nn.Module):
60
+ def __init__(
61
+ self,
62
+ *,
63
+ transformer_dim: int,
64
+ transformer: nn.Module,
65
+ modality,
66
+ contents,
67
+ num_multimask_outputs: int = 3,
68
+ activation: Type[nn.Module] = nn.GELU,
69
+ iou_head_depth: int = 3,
70
+ iou_head_hidden_dim: int = 256,
71
+ category_num = 11
72
+ ) -> None:
73
+ """
74
+ Predicts masks given an image and prompt embeddings, using a
75
+ transformer architecture.
76
+
77
+ Arguments:
78
+ transformer_dim (int): the channel dimension of the transformer
79
+ transformer (nn.Module): the transformer used to predict masks
80
+ num_multimask_outputs (int): the number of masks to predict
81
+ when disambiguating masks
82
+ activation (nn.Module): the type of activation to use when
83
+ upscaling masks
84
+ iou_head_depth (int): the depth of the MLP used to predict
85
+ mask quality
86
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
87
+ used to predict mask quality
88
+ """
89
+ super().__init__()
90
+ self.transformer_dim = transformer_dim
91
+ self.transformer = transformer
92
+ self.category_num = category_num
93
+ self.modality = modality
94
+ self.contents = contents
95
+
96
+ self.num_multimask_outputs = num_multimask_outputs
97
+
98
+ self.iou_token = nn.Embedding(1, transformer_dim)
99
+ self.num_mask_tokens = num_multimask_outputs + 1
100
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
101
+
102
+ self.convs = Block(transformer_dim, transformer_dim)
103
+ self.w_lin = nn.Linear(transformer_dim, transformer_dim)
104
+ self.b_lin = nn.Linear(transformer_dim, transformer_dim)
105
+
106
+ self.output_upscaling = nn.Sequential(
107
+ nn.ConvTranspose2d(
108
+ transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
109
+ ),
110
+ LayerNorm2d(transformer_dim // 4),
111
+ activation(),
112
+ nn.ConvTranspose2d(
113
+ transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
114
+ ),
115
+ activation(),
116
+ )
117
+ self.output_hypernetworks_mlps = nn.ModuleList(
118
+ [
119
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
120
+ for i in range(self.num_mask_tokens)
121
+ ]
122
+ )
123
+
124
+ self.iou_prediction_head = MLP(
125
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
126
+ )
127
+
128
+ self.category_prediction_head = Classifier(
129
+ transformer_dim, transformer_dim//4, category_num
130
+ )
131
+
132
+ def forward(
133
+ self,
134
+ image_embeddings: torch.Tensor,
135
+ image_pe: torch.Tensor,
136
+ sparse_prompt_embeddings: torch.Tensor,
137
+ dense_prompt_embeddings: torch.Tensor,
138
+ multimask_output: bool,
139
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
140
+ """
141
+ Predict masks given image and prompt embeddings.
142
+
143
+ Arguments:
144
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
145
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
146
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
147
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
148
+ multimask_output (bool): Whether to return multiple masks or a single
149
+ mask.
150
+
151
+ Returns:
152
+ torch.Tensor: batched predicted masks
153
+ torch.Tensor: batched predictions of mask quality
154
+ """
155
+ masks, iou_pred, category_pred, clip_tokens_out, image_tokens_out = self.predict_masks(
156
+ image_embeddings=image_embeddings,
157
+ image_pe=image_pe,
158
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
159
+ dense_prompt_embeddings=dense_prompt_embeddings,
160
+ )
161
+
162
+ # Select the correct mask or masks for output
163
+ if multimask_output:
164
+ mask_slice = slice(1, None)
165
+ else:
166
+ mask_slice = slice(0, 1)
167
+ masks = masks[:, mask_slice, :, :]
168
+ iou_pred = iou_pred[:, mask_slice]
169
+
170
+ # Prepare output
171
+ return masks, iou_pred, category_pred, clip_tokens_out, image_tokens_out
172
+
173
+ def predict_masks(
174
+ self,
175
+ image_embeddings: torch.Tensor,
176
+ image_pe: torch.Tensor,
177
+ sparse_prompt_embeddings: torch.Tensor,
178
+ dense_prompt_embeddings: torch.Tensor,
179
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
180
+ """Predicts masks. See 'forward' for more details."""
181
+ # Concatenate output tokens
182
+ output_tokens = torch.cat(
183
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
184
+ )
185
+ output_tokens = output_tokens.unsqueeze(0).expand(
186
+ sparse_prompt_embeddings.size(0), -1, -1
187
+ )
188
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
189
+
190
+ # Expand per-image data in batch direction to be per-mask
191
+ if image_embeddings.shape[0] != tokens.shape[0]:
192
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
193
+ else:
194
+ src = image_embeddings
195
+ src = src + dense_prompt_embeddings
196
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
197
+ b, c, h, w = src.shape
198
+
199
+ # Run the transformer
200
+ hs, src = self.transformer(src, pos_src, tokens)
201
+ iou_token_out = hs[:, 0, :]
202
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
203
+
204
+ # Upscale mask embeddings and predict masks using the mask tokens
205
+ src = src.transpose(1, 2).view(b, c, h, w)
206
+ if self.contents:
207
+ clip_tokens_out = tokens[:,-2,:]
208
+ image_tokens_out = F.adaptive_avg_pool2d(dense_prompt_embeddings, output_size=(1, 1)).squeeze(-1).squeeze(-1)
209
+ clip_new_out = hs[:,-2,:].unsqueeze(-1).unsqueeze(-1)
210
+ src = dense_prompt_embeddings+src+clip_new_out
211
+ src = self.convs(src)
212
+ else:
213
+ clip_tokens_out = None
214
+ image_tokens_out = None
215
+
216
+ if self.modality:
217
+ category_tokens_out = hs[:,-1,:]
218
+ wc = self.w_lin(category_tokens_out).unsqueeze(-1).unsqueeze(-1)
219
+ bc = self.b_lin(category_tokens_out).unsqueeze(-1).unsqueeze(-1)
220
+ src = wc*src+bc+src
221
+ category_pred = self.category_prediction_head(category_tokens_out)
222
+ else:
223
+ category_pred = None
224
+
225
+ upscaled_embedding = self.output_upscaling(src)
226
+ hyper_in_list: List[torch.Tensor] = []
227
+ for i in range(self.num_mask_tokens):
228
+ hyper_in_list.append(
229
+ self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
230
+ )
231
+ hyper_in = torch.stack(hyper_in_list, dim=1)
232
+ b, c, h, w = upscaled_embedding.shape
233
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
234
+
235
+ # Generate mask quality predictions
236
+ iou_pred = self.iou_prediction_head(iou_token_out)
237
+
238
+ return masks, iou_pred, category_pred, clip_tokens_out, image_tokens_out
239
+
240
+ # Lightly adapted from
241
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
242
+ class MLP(nn.Module):
243
+ def __init__(
244
+ self,
245
+ input_dim: int,
246
+ hidden_dim: int,
247
+ output_dim: int,
248
+ num_layers: int,
249
+ sigmoid_output: bool = False,
250
+ ) -> None:
251
+ super().__init__()
252
+ self.num_layers = num_layers
253
+ h = [hidden_dim] * (num_layers - 1)
254
+ self.layers = nn.ModuleList(
255
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
256
+ )
257
+ self.sigmoid_output = sigmoid_output
258
+
259
+ def forward(self, x):
260
+ for i, layer in enumerate(self.layers):
261
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
262
+ if self.sigmoid_output:
263
+ x = F.sigmoid(x)
264
+ return x
265
+
266
+ class MaskDecoder_F4(nn.Module):
267
+ def __init__(
268
+ self,
269
+ *,
270
+ transformer_dim: int,
271
+ transformer: nn.Module,
272
+ modality,
273
+ contents,
274
+ num_multimask_outputs: int = 3,
275
+ activation: Type[nn.Module] = nn.GELU,
276
+ iou_head_depth: int = 3,
277
+ iou_head_hidden_dim: int = 256,
278
+ category_num = 11
279
+ ) -> None:
280
+ """
281
+ Predicts masks given an image and prompt embeddings, using a
282
+ transformer architecture.
283
+
284
+ Arguments:
285
+ transformer_dim (int): the channel dimension of the transformer
286
+ transformer (nn.Module): the transformer used to predict masks
287
+ num_multimask_outputs (int): the number of masks to predict
288
+ when disambiguating masks
289
+ activation (nn.Module): the type of activation to use when
290
+ upscaling masks
291
+ iou_head_depth (int): the depth of the MLP used to predict
292
+ mask quality
293
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
294
+ used to predict mask quality
295
+ """
296
+ super().__init__()
297
+ self.transformer_dim = transformer_dim
298
+ self.transformer = transformer
299
+ self.category_num = category_num
300
+ self.modality = modality
301
+ self.contents = contents
302
+
303
+ self.num_multimask_outputs = num_multimask_outputs
304
+
305
+ self.iou_token = nn.Embedding(1, transformer_dim)
306
+ self.num_mask_tokens = num_multimask_outputs + 1
307
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
308
+
309
+ self.convs = Block(transformer_dim, transformer_dim)
310
+ self.conv1 = nn.Conv2d(transformer_dim*2, transformer_dim, 1)
311
+ self.c_conv = Block(transformer_dim, transformer_dim)
312
+ self.w_lin = nn.Linear(transformer_dim, transformer_dim)
313
+ self.b_lin = nn.Linear(transformer_dim, transformer_dim)
314
+ self.m_conv = Block(transformer_dim, transformer_dim)
315
+
316
+ self.output_upscaling = nn.Sequential(
317
+ nn.ConvTranspose2d(
318
+ transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
319
+ ),
320
+ LayerNorm2d(transformer_dim // 4),
321
+ activation(),
322
+ nn.ConvTranspose2d(
323
+ transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
324
+ ),
325
+ activation(),
326
+ )
327
+ self.output_hypernetworks_mlps = nn.ModuleList(
328
+ [
329
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
330
+ for i in range(self.num_mask_tokens)
331
+ ]
332
+ )
333
+
334
+ self.iou_prediction_head = MLP(
335
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
336
+ )
337
+
338
+ # self.category_prediction_head = Classifier(
339
+ # transformer_dim, transformer_dim//4, category_num
340
+ # )
341
+ self.category_prediction_head = Classifier(
342
+ transformer_dim, transformer_dim//4, category_num
343
+ )
344
+
345
+ def forward(
346
+ self,
347
+ image_embeddings: torch.Tensor,
348
+ image_pe: torch.Tensor,
349
+ sparse_prompt_embeddings: torch.Tensor,
350
+ dense_prompt_embeddings: torch.Tensor,
351
+ multimask_output: bool,
352
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
353
+ """
354
+ Predict masks given image and prompt embeddings.
355
+
356
+ Arguments:
357
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
358
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
359
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
360
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
361
+ multimask_output (bool): Whether to return multiple masks or a single
362
+ mask.
363
+
364
+ Returns:
365
+ torch.Tensor: batched predicted masks
366
+ torch.Tensor: batched predictions of mask quality
367
+ """
368
+ masks, iou_pred, category_pred, clip_tokens_out, image_tokens_out = self.predict_masks(
369
+ image_embeddings=image_embeddings,
370
+ image_pe=image_pe,
371
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
372
+ dense_prompt_embeddings=dense_prompt_embeddings,
373
+ )
374
+
375
+ # Select the correct mask or masks for output
376
+ if multimask_output:
377
+ mask_slice = slice(1, None)
378
+ else:
379
+ mask_slice = slice(0, 1)
380
+ masks = masks[:, mask_slice, :, :]
381
+ iou_pred = iou_pred[:, mask_slice]
382
+
383
+ # Prepare output
384
+ return masks, iou_pred, category_pred, clip_tokens_out, image_tokens_out
385
+
386
+ def predict_masks(
387
+ self,
388
+ image_embeddings: torch.Tensor,
389
+ image_pe: torch.Tensor,
390
+ sparse_prompt_embeddings: torch.Tensor,
391
+ dense_prompt_embeddings: torch.Tensor,
392
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
393
+ """Predicts masks. See 'forward' for more details."""
394
+ # Concatenate output tokens
395
+ output_tokens = torch.cat(
396
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
397
+ )
398
+ output_tokens = output_tokens.unsqueeze(0).expand(
399
+ sparse_prompt_embeddings.size(0), -1, -1
400
+ )
401
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
402
+ m_token = tokens[:,-1,:]
403
+
404
+ # Expand per-image data in batch direction to be per-mask
405
+ if image_embeddings.shape[0] != tokens.shape[0]:
406
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
407
+ else:
408
+ src = image_embeddings
409
+ src = src + dense_prompt_embeddings
410
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
411
+ b, c, h, w = src.shape
412
+
413
+ # Run the transformer
414
+ hs, src = self.transformer(src, pos_src, tokens)
415
+ iou_token_out = hs[:, 0, :]
416
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
417
+
418
+ # Upscale mask embeddings and predict masks using the mask tokens
419
+ src = src.transpose(1, 2).view(b, c, h, w)
420
+
421
+ if self.modality:
422
+ category_tokens_out = hs[:,-1,:]
423
+ wc = self.w_lin(category_tokens_out).unsqueeze(-1).unsqueeze(-1)
424
+ bc = self.b_lin(category_tokens_out).unsqueeze(-1).unsqueeze(-1)
425
+ src_m = wc*src+bc+src
426
+ m_info = wc.squeeze(-1).squeeze(-1)+bc.squeeze(-1).squeeze(-1)+category_tokens_out
427
+ category_pred = self.category_prediction_head(m_info)
428
+ src_m = self.m_conv(src_m)
429
+ else:
430
+ category_pred = None
431
+
432
+ if self.contents:
433
+ clip_tokens_out = tokens[:,-2,:]
434
+ image_tokens_out = F.adaptive_avg_pool2d(dense_prompt_embeddings, output_size=(1, 1)).squeeze(-1).squeeze(-1)
435
+ clip_new_out = hs[:,-2,:].unsqueeze(-1).unsqueeze(-1)
436
+ src_vp = dense_prompt_embeddings+src+clip_new_out
437
+ src_vp = self.convs(src_vp)
438
+ else:
439
+ clip_tokens_out = None
440
+ image_tokens_out = None
441
+
442
+ if self.contents and self.modality:
443
+ src = torch.cat((src_m, src_vp), dim=1)
444
+ src = self.conv1(src)
445
+ src = self.c_conv(src)
446
+ elif self.contents:
447
+ src = src_vp
448
+ elif self.modality:
449
+ src = src_m
450
+
451
+ upscaled_embedding = self.output_upscaling(src)
452
+ hyper_in_list: List[torch.Tensor] = []
453
+ for i in range(self.num_mask_tokens):
454
+ hyper_in_list.append(
455
+ self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
456
+ )
457
+ hyper_in = torch.stack(hyper_in_list, dim=1)
458
+ b, c, h, w = upscaled_embedding.shape
459
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
460
+
461
+ # Generate mask quality predictions
462
+ iou_pred = self.iou_prediction_head(iou_token_out)
463
+
464
+ return masks, iou_pred, category_pred, clip_tokens_out, image_tokens_out
465
+
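The decoder's two prompt paths are toggled at construction time. A minimal construction sketch, mirroring the settings that train.py (further below) passes in; `Block` and `Classifier` are assumed to be defined earlier in this file:

```python
# Sketch: building the prompt-aware decoder as train.py does below.
from models import TwoWayTransformer, MaskDecoder_F4

decoder = MaskDecoder_F4(
    transformer_dim=256,
    transformer=TwoWayTransformer(depth=2, embedding_dim=256,
                                  mlp_dim=2048, num_heads=8),
    modality=True,   # modality prompt: scale/shift modulation + 11-way modality head
    contents=True,   # content prompt: dense crop embedding + CLIP token fusion
    num_multimask_outputs=3,
)
```

When both flags are set, the two modulated feature maps are concatenated and fused by the 1x1 `conv1` before mask upscaling.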
models/prompt_encoder.py ADDED
@@ -0,0 +1,306 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import numpy as np
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from typing import Any, Optional, Tuple, Type
+
+ from .common import LayerNorm2d
+
+ class PositionEmbeddingRandom(nn.Module):
+     """
+     Positional encoding using random spatial frequencies.
+     """
+
+     def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+         super().__init__()
+         if scale is None or scale <= 0.0:
+             scale = 1.0
+         self.register_buffer(
+             "positional_encoding_gaussian_matrix",
+             scale * torch.randn((2, num_pos_feats)),
+         )
+
+     def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+         """Positionally encode points that are normalized to [0,1]."""
+         # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+         coords = 2 * coords - 1
+         coords = coords @ self.positional_encoding_gaussian_matrix
+         coords = 2 * np.pi * coords
+         # outputs d_1 x ... x d_n x C shape
+         return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+     def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+         """Generate positional encoding for a grid of the specified size."""
+         h, w = size
+         device: Any = self.positional_encoding_gaussian_matrix.device
+         grid = torch.ones((h, w), device=device, dtype=torch.float32)
+         y_embed = grid.cumsum(dim=0) - 0.5
+         x_embed = grid.cumsum(dim=1) - 0.5
+         y_embed = y_embed / h
+         x_embed = x_embed / w
+
+         pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+         return pe.permute(2, 0, 1)  # C x H x W
+
+     def forward_with_coords(
+         self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+     ) -> torch.Tensor:
+         """Positionally encode points that are not normalized to [0,1]."""
+         coords = coords_input.clone()
+         coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+         coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+         return self._pe_encoding(coords.to(torch.float))  # B x N x C
+
+ class Block(nn.Module):
+     def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
+         super(Block, self).__init__()
+
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
+         self.batch_norm1 = nn.BatchNorm2d(out_channels)
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
+         self.batch_norm2 = nn.BatchNorm2d(out_channels)
+
+         self.i_downsample = i_downsample
+         self.stride = stride
+         self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+     def forward(self, x):
+         identity = x.clone()
+
+         x = self.relu(self.batch_norm1(self.conv1(x)))
+         x = self.batch_norm2(self.conv2(x))
+
+         if self.i_downsample is not None:
+             identity = self.i_downsample(identity)
+
+         x += identity
+         x = self.relu(x)
+         return x
+
+ class Crop_Net_New(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.conv = nn.Conv2d(3, dim, 3, 1, 1)
+
+         self.conv1 = Block(dim, dim)
+         self.conv2 = Block(dim, dim)
+         self.conv3 = Block(dim, dim)
+
+         self.conv4 = nn.Conv2d(dim, dim, 5, 1, 2)
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.conv1(x)
+         x = self.conv2(x)
+         x = self.conv3(x)
+         return self.conv4(x)
+
+ class Mlp(nn.Module):
+     def __init__(self, in_dim, hid_dim=None, out_dim=None, act=nn.GELU, drop=0.):
+         super().__init__()
+         out_dim = out_dim or in_dim
+         hid_dim = hid_dim or in_dim
+         self.fc1 = nn.Linear(in_dim, hid_dim)
+         self.act = act()
+         self.fc2 = nn.Linear(hid_dim, out_dim)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+ class PromptEncoder(nn.Module):
+     def __init__(
+         self,
+         embed_dim: int,
+         image_embedding_size: Tuple[int, int],
+         input_image_size: Tuple[int, int],
+         mask_in_chans: int,
+         activation: Type[nn.Module] = nn.GELU,
+     ) -> None:
+         """
+         Encodes prompts for input to SAM's mask decoder.
+
+         Arguments:
+           embed_dim (int): The prompts' embedding dimension
+           image_embedding_size (tuple(int, int)): The spatial size of the
+             image embedding, as (H, W).
+           input_image_size (tuple(int, int)): The padded size of the image as input
+             to the image encoder, as (H, W).
+           mask_in_chans (int): The number of hidden channels used for
+             encoding input masks.
+           activation (nn.Module): The activation to use when encoding
+             input masks.
+         """
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.input_image_size = input_image_size
+         self.image_embedding_size = image_embedding_size
+         self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
+
+         self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
+         point_embeddings = [
+             nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
+         ]
+         self.point_embeddings = nn.ModuleList(point_embeddings)
+         self.not_a_point_embed = nn.Embedding(1, embed_dim)
+
+         self.mask_input_size = (
+             4 * image_embedding_size[0],
+             4 * image_embedding_size[1],
+         )
+
+         self.no_mask_embed = nn.Embedding(1, embed_dim)
+
+         self.crop_nets = Crop_Net_New(embed_dim)  # dense content prompt from image crops
+
+         self.clip_img_mlp = Mlp(in_dim=512, hid_dim=256, out_dim=256)   # CLIP image features -> content token
+         self.clip_text_mlp = Mlp(in_dim=512, hid_dim=256, out_dim=256)  # CLIP text features
+         self.mlps = Mlp(in_dim=512, hid_dim=512, out_dim=256)           # fuses text + learned modality embedding
+
+         self.categories = nn.Embedding(11, 256)  # one learned embedding per imaging modality
+
+     def get_dense_pe(self) -> torch.Tensor:
+         """
+         Returns the positional encoding used to encode point prompts,
+         applied to a dense set of points the shape of the image encoding.
+
+         Returns:
+           torch.Tensor: Positional encoding with shape
+             1x(embed_dim)x(embedding_h)x(embedding_w)
+         """
+         return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+
+     def _embed_points(
+         self,
+         points: torch.Tensor,
+         labels: torch.Tensor,
+         pad: bool,
+     ) -> torch.Tensor:
+         """Embeds point prompts."""
+         points = points + 0.5  # Shift to center of pixel
+         if pad:
+             padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
+             padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+             points = torch.cat([points, padding_point], dim=1)
+             labels = torch.cat([labels, padding_label], dim=1)
+         point_embedding = self.pe_layer.forward_with_coords(
+             points, self.input_image_size
+         )
+         point_embedding[labels == -1] = 0.0
+         point_embedding[labels == -1] += self.not_a_point_embed.weight
+         point_embedding[labels == 0] += self.point_embeddings[0].weight
+         point_embedding[labels == 1] += self.point_embeddings[1].weight
+         return point_embedding
+
+     def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+         """Embeds box prompts."""
+         boxes = boxes + 0.5  # Shift to center of pixel
+         coords = boxes.reshape(-1, 2, 2)
+         corner_embedding = self.pe_layer.forward_with_coords(
+             coords, self.input_image_size
+         )
+         corner_embedding[:, 0, :] += self.point_embeddings[2].weight
+         corner_embedding[:, 1, :] += self.point_embeddings[3].weight
+         return corner_embedding
+
+     # def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
+     #     """Embeds mask inputs."""
+     #     mask_embedding = self.mask_downscaling(masks)
+     #     return mask_embedding
+
+     def _get_batch_size(
+         self,
+         points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+         boxes: Optional[torch.Tensor],
+         masks: Optional[torch.Tensor],
+     ) -> int:
+         """
+         Gets the batch size of the output given the batch size of the input prompts.
+         """
+         if points is not None:
+             return points[0].shape[0]
+         elif boxes is not None:
+             return boxes.shape[0]
+         # elif tokens is not None:
+         #     return tokens.shape[0]
+         elif masks is not None:
+             return masks.shape[0]
+         else:
+             return 1
+
+     def _get_device(self) -> torch.device:
+         return self.point_embeddings[0].weight.device
+
+     def forward(
+         self,
+         points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+         boxes: Optional[torch.Tensor],
+         masks,
+         features,
+         crops,
+         text_features,
+         category_idx
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Embeds different types of prompts, returning both sparse and dense
+         embeddings.
+
+         Arguments:
+           points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
+             and labels to embed.
+           boxes (torch.Tensor or none): boxes to embed
+           masks (torch.Tensor or none): masks to embed
+           features (torch.Tensor or none): CLIP image embeddings used to
+             build the sparse content prompt token
+           crops (torch.Tensor or none): cropped image regions encoded into
+             the dense content prompt
+           text_features (torch.Tensor or none): CLIP text embeddings for the
+             imaging modality
+           category_idx (torch.Tensor or none): index of the imaging modality,
+             used to look up a learned modality embedding
+
+         Returns:
+           torch.Tensor: sparse embeddings for the points and boxes, with shape
+             BxNx(embed_dim), where N is determined by the number of input points
+             and boxes (plus the content and modality tokens when given).
+           torch.Tensor: dense embeddings for the masks, in the shape
+             Bx(embed_dim)x(embed_H)x(embed_W)
+         """
+         bs = self._get_batch_size(points, boxes, masks)
+         sparse_embeddings = torch.empty(
+             (bs, 0, self.embed_dim), device=self._get_device()
+         )
+         if points is not None:
+             coords, labels = points
+             point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
+             sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+
+         if boxes is not None:
+             box_embeddings = self._embed_boxes(boxes)
+             sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
+
+         if features is not None:
+             clip_embeddings = self.clip_img_mlp(features)
+             sparse_embeddings = torch.cat([sparse_embeddings, clip_embeddings], dim=1)
+
+         if category_idx is not None:
+             text_embeddings = self.clip_text_mlp(text_features)
+             category_embeddings = torch.zeros((bs, 1, 256)).to(boxes.device)
+             for i in range(bs):
+                 category_embeddings[i, 0, :] = self.categories(category_idx[i].long())
+             modality_embeddings = torch.cat((text_embeddings, category_embeddings), dim=-1)
+             text_embeddings = self.mlps(modality_embeddings)
+             sparse_embeddings = torch.cat([sparse_embeddings, text_embeddings], dim=1)
+
+         if crops is not None:
+             dense_embeddings = self.crop_nets(crops)
+         else:
+             dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+                 bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+             )
+
+         return sparse_embeddings, dense_embeddings
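A usage sketch for the encoder above. The prompt shapes here are assumptions inferred from the module definitions (512-dim CLIP features; 64x64 crops so that `Crop_Net_New`, which is stride-1 throughout, emits a 64x64 dense map):

```python
# Sketch: one box prompt plus CLIP-based content and modality prompts.
import torch

pe = PromptEncoder(embed_dim=256, image_embedding_size=(64, 64),
                   input_image_size=(256, 256), mask_in_chans=16)

B = 2
boxes = torch.tensor([[10., 20., 120., 180.]]).repeat(B, 1)  # (B, 4) in input-image coords
clip_img = torch.randn(B, 1, 512)   # CLIP image features -> clip_img_mlp
clip_txt = torch.randn(B, 1, 512)   # CLIP text features  -> clip_text_mlp
crops = torch.randn(B, 3, 64, 64)   # cropped region for the dense content prompt
cat_idx = torch.zeros(B)            # modality index in [0, 11)

sparse, dense = pe(points=None, boxes=boxes, masks=None, features=clip_img,
                   crops=crops, text_features=clip_txt, category_idx=cat_idx)
# sparse: (B, 4, 256) -> 2 box-corner tokens + 1 content token + 1 modality token
# dense:  (B, 256, 64, 64)
```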
models/tiny_vit.py ADDED
@@ -0,0 +1,645 @@
+ # --------------------------------------------------------
+ # TinyViT Model Architecture
+ # Copyright (c) 2022 Microsoft
+ # Adapted from LeViT and Swin Transformer
+ #   LeViT: (https://github.com/facebookresearch/levit)
+ #   Swin: (https://github.com/microsoft/swin-transformer)
+ # Build the TinyViT Model
+ # --------------------------------------------------------
+ # The TinyViT model is adapted from MobileSAM's variant.
+ # --------------------------------------------------------
+
+ import itertools
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint as checkpoint
+ from timm.models.layers import DropPath as TimmDropPath,\
+     to_2tuple, trunc_normal_
+ from typing import Tuple
+
+ class Conv2d_BN(torch.nn.Sequential):
+     def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
+                  groups=1, bn_weight_init=1):
+         super().__init__()
+         self.add_module('c', torch.nn.Conv2d(
+             a, b, ks, stride, pad, dilation, groups, bias=False))
+         bn = torch.nn.BatchNorm2d(b)
+         torch.nn.init.constant_(bn.weight, bn_weight_init)
+         torch.nn.init.constant_(bn.bias, 0)
+         self.add_module('bn', bn)
+
+     @torch.no_grad()
+     def fuse(self):
+         c, bn = self._modules.values()
+         w = bn.weight / (bn.running_var + bn.eps)**0.5
+         w = c.weight * w[:, None, None, None]
+         b = bn.bias - bn.running_mean * bn.weight / \
+             (bn.running_var + bn.eps)**0.5
+         m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
+             0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
+         m.weight.data.copy_(w)
+         m.bias.data.copy_(b)
+         return m
+
+
+ class DropPath(TimmDropPath):
+     def __init__(self, drop_prob=None):
+         super().__init__(drop_prob=drop_prob)
+         self.drop_prob = drop_prob
+
+     def __repr__(self):
+         msg = super().__repr__()
+         msg += f'(drop_prob={self.drop_prob})'
+         return msg
+
+
+ class PatchEmbed(nn.Module):
+     def __init__(self, in_chans, embed_dim, resolution, activation):
+         super().__init__()
+         img_size: Tuple[int, int] = to_2tuple(resolution)
+         #self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
+         self.patches_resolution = img_size
+         self.num_patches = self.patches_resolution[0] * \
+             self.patches_resolution[1]
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+         n = embed_dim
+         #self.seq = nn.Sequential(
+         #    Conv2d_BN(in_chans, n // 2, 3, 2, 1),
+         #    activation(),
+         #    Conv2d_BN(n // 2, n, 3, 2, 1),
+         #)
+         # 1x1 stride-1 stem: keeps the full input resolution
+         # (the original 4x-downsampling stem is commented out above)
+         self.seq = nn.Sequential(
+             Conv2d_BN(in_chans, n // 2, 1, 1, 0),
+             activation(),
+             Conv2d_BN(n // 2, n, 1, 1, 0),
+         )
+
+     def forward(self, x):
+         return self.seq(x)
+
+
+ class MBConv(nn.Module):
+     def __init__(self, in_chans, out_chans, expand_ratio,
+                  activation, drop_path):
+         super().__init__()
+         self.in_chans = in_chans
+         self.hidden_chans = int(in_chans * expand_ratio)
+         self.out_chans = out_chans
+
+         self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
+         self.act1 = activation()
+
+         self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans,
+                                ks=3, stride=1, pad=1, groups=self.hidden_chans)
+         self.act2 = activation()
+
+         self.conv3 = Conv2d_BN(
+             self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
+         self.act3 = activation()
+
+         self.drop_path = DropPath(
+             drop_path) if drop_path > 0. else nn.Identity()
+
+     def forward(self, x):
+         shortcut = x
+
+         x = self.conv1(x)
+         x = self.act1(x)
+
+         x = self.conv2(x)
+         x = self.act2(x)
+
+         x = self.conv3(x)
+
+         x = self.drop_path(x)
+
+         x += shortcut
+         x = self.act3(x)
+
+         return x
+
+
+ class PatchMerging(nn.Module):
+     def __init__(self, input_resolution, dim, out_dim, activation):
+         super().__init__()
+
+         self.input_resolution = input_resolution
+         self.dim = dim
+         self.out_dim = out_dim
+         self.act = activation()
+         self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
+         stride_c = 2
+         # keep the 64x64 resolution in the later stages
+         if out_dim == 320 or out_dim == 448 or out_dim == 576:
+             stride_c = 1
+         self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
+         self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
+
+     def forward(self, x):
+         if x.ndim == 3:
+             H, W = self.input_resolution
+             B = len(x)
+             # (B, C, H, W)
+             x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
+
+         x = self.conv1(x)
+         x = self.act(x)
+
+         x = self.conv2(x)
+         x = self.act(x)
+         x = self.conv3(x)
+         x = x.flatten(2).transpose(1, 2)
+         return x
+
+
+ class ConvLayer(nn.Module):
+     def __init__(self, dim, input_resolution, depth,
+                  activation,
+                  drop_path=0., downsample=None, use_checkpoint=False,
+                  out_dim=None,
+                  conv_expand_ratio=4.,
+                  ):
+
+         super().__init__()
+         self.dim = dim
+         self.input_resolution = input_resolution
+         self.depth = depth
+         self.use_checkpoint = use_checkpoint
+
+         # build blocks
+         self.blocks = nn.ModuleList([
+             MBConv(dim, dim, conv_expand_ratio, activation,
+                    drop_path[i] if isinstance(drop_path, list) else drop_path,
+                    )
+             for i in range(depth)])
+
+         # patch merging layer
+         if downsample is not None:
+             self.downsample = downsample(
+                 input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+         else:
+             self.downsample = None
+
+     def forward(self, x):
+         for blk in self.blocks:
+             if self.use_checkpoint:
+                 x = checkpoint.checkpoint(blk, x)
+             else:
+                 x = blk(x)
+         if self.downsample is not None:
+             x = self.downsample(x)
+         return x
+
+
+ class Mlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None,
+                  out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.norm = nn.LayerNorm(in_features)
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.act = act_layer()
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.norm(x)
+
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class Attention(torch.nn.Module):
+     def __init__(self, dim, key_dim, num_heads=8,
+                  attn_ratio=4,
+                  resolution=(14, 14),
+                  ):
+         super().__init__()
+         # (h, w)
+         assert isinstance(resolution, tuple) and len(resolution) == 2
+         self.num_heads = num_heads
+         self.scale = key_dim ** -0.5
+         self.key_dim = key_dim
+         self.nh_kd = nh_kd = key_dim * num_heads
+         self.d = int(attn_ratio * key_dim)
+         self.dh = int(attn_ratio * key_dim) * num_heads
+         self.attn_ratio = attn_ratio
+         h = self.dh + nh_kd * 2
+
+         self.norm = nn.LayerNorm(dim)
+         self.qkv = nn.Linear(dim, h)
+         self.proj = nn.Linear(self.dh, dim)
+
+         points = list(itertools.product(
+             range(resolution[0]), range(resolution[1])))
+         N = len(points)
+         attention_offsets = {}
+         idxs = []
+         for p1 in points:
+             for p2 in points:
+                 offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
+                 if offset not in attention_offsets:
+                     attention_offsets[offset] = len(attention_offsets)
+                 idxs.append(attention_offsets[offset])
+         self.attention_biases = torch.nn.Parameter(
+             torch.zeros(num_heads, len(attention_offsets)))
+         self.register_buffer('attention_bias_idxs',
+                              torch.LongTensor(idxs).view(N, N),
+                              persistent=False)
+
+     @torch.no_grad()
+     def train(self, mode=True):
+         super().train(mode)
+         if mode and hasattr(self, 'ab'):
+             del self.ab
+         else:
+             self.register_buffer('ab',
+                                  self.attention_biases[:, self.attention_bias_idxs],
+                                  persistent=False)
+
+     def forward(self, x):  # x (B,N,C)
+         B, N, _ = x.shape
+
+         # Normalization
+         x = self.norm(x)
+
+         qkv = self.qkv(x)
+         # (B, N, num_heads, d)
+         q, k, v = qkv.view(B, N, self.num_heads, -
+                            1).split([self.key_dim, self.key_dim, self.d], dim=3)
+         # (B, num_heads, N, d)
+         q = q.permute(0, 2, 1, 3)
+         k = k.permute(0, 2, 1, 3)
+         v = v.permute(0, 2, 1, 3)
+
+         attn = (
+             (q @ k.transpose(-2, -1)) * self.scale
+             +
+             (self.attention_biases[:, self.attention_bias_idxs]
+              if self.training else self.ab)
+         )
+         attn = attn.softmax(dim=-1)
+         x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
+         x = self.proj(x)
+         return x
+
+
+ class TinyViTBlock(nn.Module):
+     r""" TinyViT Block.
+
+     Args:
+         dim (int): Number of input channels.
+         input_resolution (tuple[int, int]): Input resolution.
+         num_heads (int): Number of attention heads.
+         window_size (int): Window size.
+         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+         drop (float, optional): Dropout rate. Default: 0.0
+         drop_path (float, optional): Stochastic depth rate. Default: 0.0
+         local_conv_size (int): the kernel size of the convolution between
+             Attention and MLP. Default: 3
+         activation: the activation function. Default: nn.GELU
+     """
+
+     def __init__(self, dim, input_resolution, num_heads, window_size=7,
+                  mlp_ratio=4., drop=0., drop_path=0.,
+                  local_conv_size=3,
+                  activation=nn.GELU,
+                  ):
+         super().__init__()
+         self.dim = dim
+         self.input_resolution = input_resolution
+         self.num_heads = num_heads
+         assert window_size > 0, 'window_size must be greater than 0'
+         self.window_size = window_size
+         self.mlp_ratio = mlp_ratio
+
+         self.drop_path = DropPath(
+             drop_path) if drop_path > 0. else nn.Identity()
+
+         assert dim % num_heads == 0, 'dim must be divisible by num_heads'
+         head_dim = dim // num_heads
+
+         window_resolution = (window_size, window_size)
+         self.attn = Attention(dim, head_dim, num_heads,
+                               attn_ratio=1, resolution=window_resolution)
+
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         mlp_activation = activation
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
+                        act_layer=mlp_activation, drop=drop)
+
+         pad = local_conv_size // 2
+         self.local_conv = Conv2d_BN(
+             dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
+
+     def forward(self, x):
+         H, W = self.input_resolution
+         B, L, C = x.shape
+         assert L == H * W, "input feature has wrong size"
+         res_x = x
+         if H == self.window_size and W == self.window_size:
+             x = self.attn(x)
+         else:
+             x = x.view(B, H, W, C)
+             pad_b = (self.window_size - H %
+                      self.window_size) % self.window_size
+             pad_r = (self.window_size - W %
+                      self.window_size) % self.window_size
+             padding = pad_b > 0 or pad_r > 0
+
+             if padding:
+                 x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
+
+             pH, pW = H + pad_b, W + pad_r
+             nH = pH // self.window_size
+             nW = pW // self.window_size
+             # window partition
+             x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape(
+                 B * nH * nW, self.window_size * self.window_size, C)
+             x = self.attn(x)
+             # window reverse
+             x = x.view(B, nH, nW, self.window_size, self.window_size,
+                        C).transpose(2, 3).reshape(B, pH, pW, C)
+
+             if padding:
+                 x = x[:, :H, :W].contiguous()
+
+             x = x.view(B, L, C)
+
+         x = res_x + self.drop_path(x)
+
+         x = x.transpose(1, 2).reshape(B, C, H, W)
+         x = self.local_conv(x)
+         x = x.view(B, C, L).transpose(1, 2)
+
+         x = x + self.drop_path(self.mlp(x))
+         return x
+
+     def extra_repr(self) -> str:
+         return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+                f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
+
+
+ class BasicLayer(nn.Module):
+     """ A basic TinyViT layer for one stage.
+
+     Args:
+         dim (int): Number of input channels.
+         input_resolution (tuple[int]): Input resolution.
+         depth (int): Number of blocks.
+         num_heads (int): Number of attention heads.
+         window_size (int): Local window size.
+         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+         drop (float, optional): Dropout rate. Default: 0.0
+         drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+         downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+         use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+         local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3
+         activation: the activation function. Default: nn.GELU
+         out_dim: the output dimension of the layer. Default: dim
+     """
+
+     def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+                  mlp_ratio=4., drop=0.,
+                  drop_path=0., downsample=None, use_checkpoint=False,
+                  local_conv_size=3,
+                  activation=nn.GELU,
+                  out_dim=None,
+                  ):
+
+         super().__init__()
+         self.dim = dim
+         self.input_resolution = input_resolution
+         self.depth = depth
+         self.use_checkpoint = use_checkpoint
+
+         # build blocks
+         self.blocks = nn.ModuleList([
+             TinyViTBlock(dim=dim, input_resolution=input_resolution,
+                          num_heads=num_heads, window_size=window_size,
+                          mlp_ratio=mlp_ratio,
+                          drop=drop,
+                          drop_path=drop_path[i] if isinstance(
+                              drop_path, list) else drop_path,
+                          local_conv_size=local_conv_size,
+                          activation=activation,
+                          )
+             for i in range(depth)])
+
+         # patch merging layer
+         if downsample is not None:
+             self.downsample = downsample(
+                 input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+         else:
+             self.downsample = None
+
+     def forward(self, x):
+         for blk in self.blocks:
+             if self.use_checkpoint:
+                 x = checkpoint.checkpoint(blk, x)
+             else:
+                 x = blk(x)
+         if self.downsample is not None:
+             x = self.downsample(x)
+         return x
+
+     def extra_repr(self) -> str:
+         return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+ class LayerNorm2d(nn.Module):
+     def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(num_channels))
+         self.bias = nn.Parameter(torch.zeros(num_channels))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         u = x.mean(1, keepdim=True)
+         s = (x - u).pow(2).mean(1, keepdim=True)
+         x = (x - u) / torch.sqrt(s + self.eps)
+         x = self.weight[:, None, None] * x + self.bias[:, None, None]
+         return x
+
+ class TinyViT(nn.Module):
+     def __init__(self,
+                  img_size=224,
+                  in_chans=3,
+                  #num_classes=1000,
+                  embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2],
+                  num_heads=[3, 6, 12, 24],
+                  window_sizes=[7, 7, 14, 7],
+                  mlp_ratio=4.,
+                  drop_rate=0.,
+                  drop_path_rate=0.1,
+                  use_checkpoint=False,
+                  mbconv_expand_ratio=4.0,
+                  local_conv_size=3,
+                  layer_lr_decay=1.0,
+                  ):
+         super().__init__()
+         self.img_size = img_size
+         #self.num_classes = num_classes
+         self.depths = depths
+         self.num_layers = len(depths)
+         self.mlp_ratio = mlp_ratio
+
+         activation = nn.GELU
+
+         self.patch_embed = PatchEmbed(in_chans=in_chans,
+                                       embed_dim=embed_dims[0],
+                                       resolution=img_size,
+                                       activation=activation)
+
+         patches_resolution = self.patch_embed.patches_resolution
+         self.patches_resolution = patches_resolution
+
+         # stochastic depth
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
+                                                 sum(depths))]  # stochastic depth decay rule
+
+         # build layers
+         self.layers = nn.ModuleList()
+         for i_layer in range(self.num_layers):
+             kwargs = dict(dim=embed_dims[i_layer],
+                           input_resolution=(
+                               patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
+                               patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))
+                           ),
+                           # input_resolution=(patches_resolution[0] // (2 ** i_layer),
+                           #                   patches_resolution[1] // (2 ** i_layer)),
+                           depth=depths[i_layer],
+                           drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                           downsample=PatchMerging if (
+                               i_layer < self.num_layers - 1) else None,
+                           use_checkpoint=use_checkpoint,
+                           out_dim=embed_dims[min(
+                               i_layer + 1, len(embed_dims) - 1)],
+                           activation=activation,
+                           )
+             if i_layer == 0:
+                 layer = ConvLayer(
+                     conv_expand_ratio=mbconv_expand_ratio,
+                     **kwargs,
+                 )
+             else:
+                 layer = BasicLayer(
+                     num_heads=num_heads[i_layer],
+                     window_size=window_sizes[i_layer],
+                     mlp_ratio=self.mlp_ratio,
+                     drop=drop_rate,
+                     local_conv_size=local_conv_size,
+                     **kwargs)
+             self.layers.append(layer)
+
+         # init weights
+         self.apply(self._init_weights)
+         self.set_layer_lr_decay(layer_lr_decay)
+
+         self.neck = nn.Sequential(
+             nn.Conv2d(
+                 embed_dims[-1],
+                 256,
+                 kernel_size=1,
+                 bias=False,
+             ),
+             LayerNorm2d(256),
+             nn.Conv2d(
+                 256,
+                 256,
+                 kernel_size=3,
+                 padding=1,
+                 bias=False,
+             ),
+             LayerNorm2d(256),
+         )
+
+     def set_layer_lr_decay(self, layer_lr_decay):
+         decay_rate = layer_lr_decay
+
+         # layers -> blocks (depth)
+         depth = sum(self.depths)
+         lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
+
+         def _set_lr_scale(m, scale):
+             for p in m.parameters():
+                 p.lr_scale = scale
+
+         self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
+         i = 0
+         for layer in self.layers:
+             for block in layer.blocks:
+                 block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
+                 i += 1
+             if layer.downsample is not None:
+                 layer.downsample.apply(
+                     lambda x: _set_lr_scale(x, lr_scales[i - 1]))
+         assert i == depth
+
+         for k, p in self.named_parameters():
+             p.param_name = k
+
+         def _check_lr_scale(m):
+             for p in m.parameters():
+                 assert hasattr(p, 'lr_scale'), p.param_name
+
+         self.apply(_check_lr_scale)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     @torch.jit.ignore
+     def no_weight_decay_keywords(self):
+         return {'attention_biases'}
+
+     def forward_features(self, x):
+         # x: (N, C, H, W)
+         x = self.patch_embed(x)
+         x = self.layers[0](x)
+         start_i = 1
+
+         for i in range(start_i, len(self.layers)):
+             layer = self.layers[i]
+             x = layer(x)
+         B, _, C = x.size()
+         x = x.view(B, 64, 64, C)  # final token grid is 64x64 for 256x256 inputs
+         x = x.permute(0, 3, 1, 2)
+         x = self.neck(x)
+
+         return x
+
+     def forward(self, x):
+         x = self.forward_features(x)
+         return x
+
+ # model = TinyViT(
+ #     img_size=256,
+ #     in_chans=3,
+ #     embed_dims=[
+ #         64,   ## (64, 256, 256)
+ #         128,  ## (128, 128, 128)
+ #         160,  ## (160, 64, 64)
+ #         320   ## (320, 64, 64)
+ #     ],
+ #     depths=[2, 2, 6, 2],
+ #     num_heads=[2, 4, 5, 10],
+ #     window_sizes=[7, 7, 14, 7],
+ #     mlp_ratio=4.,
+ #     drop_rate=0.,
+ #     drop_path_rate=0.0,
+ #     use_checkpoint=False,
+ #     mbconv_expand_ratio=4.0,
+ #     local_conv_size=3,
+ #     layer_lr_decay=0.8
+ # )
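With the configuration sketched in the comment above (the one train.py uses), the stride-1 stem and the stride-1 merging at `out_dim=320` leave a 64x64 token grid, which the neck projects to 256 channels. A minimal smoke test, assuming that configuration:

```python
# Sketch: encoder output shape under the 256x256 configuration above.
import torch

encoder = TinyViT(img_size=256, in_chans=3,
                  embed_dims=[64, 128, 160, 320], depths=[2, 2, 6, 2],
                  num_heads=[2, 4, 5, 10], window_sizes=[7, 7, 14, 7],
                  mlp_ratio=4., drop_rate=0., drop_path_rate=0.0,
                  use_checkpoint=False, mbconv_expand_ratio=4.0,
                  local_conv_size=3, layer_lr_decay=0.8)
feats = encoder(torch.randn(1, 3, 256, 256))
print(feats.shape)  # torch.Size([1, 256, 64, 64])
```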
models/transformer.py ADDED
@@ -0,0 +1,243 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ from torch import Tensor, nn
+
+ import math
+ from typing import Tuple, Type
+
+ from .common import MLPBlock
+
+
+ class TwoWayTransformer(nn.Module):
+     def __init__(
+         self,
+         depth: int,
+         embedding_dim: int,
+         num_heads: int,
+         mlp_dim: int,
+         activation: Type[nn.Module] = nn.ReLU,
+         attention_downsample_rate: int = 2,
+     ) -> None:
+         """
+         A transformer decoder that attends to an input image using
+         queries whose positional embedding is supplied.
+
+         Args:
+           depth (int): number of layers in the transformer
+           embedding_dim (int): the channel dimension for the input embeddings
+           num_heads (int): the number of heads for multihead attention. Must
+             divide embedding_dim
+           mlp_dim (int): the channel dimension internal to the MLP block
+           activation (nn.Module): the activation to use in the MLP block
+         """
+         super().__init__()
+         self.depth = depth
+         self.embedding_dim = embedding_dim
+         self.num_heads = num_heads
+         self.mlp_dim = mlp_dim
+         self.layers = nn.ModuleList()
+
+         for i in range(depth):
+             self.layers.append(
+                 TwoWayAttentionBlock(
+                     embedding_dim=embedding_dim,
+                     num_heads=num_heads,
+                     mlp_dim=mlp_dim,
+                     activation=activation,
+                     attention_downsample_rate=attention_downsample_rate,
+                     skip_first_layer_pe=(i == 0),
+                 )
+             )
+
+         self.final_attn_token_to_image = Attention(
+             embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+         )
+         self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+     def forward(
+         self,
+         image_embedding: Tensor,
+         image_pe: Tensor,
+         point_embedding: Tensor,
+     ) -> Tuple[Tensor, Tensor]:
+         """
+         Args:
+           image_embedding (torch.Tensor): image to attend to. Should be shape
+             B x embedding_dim x h x w for any h and w.
+           image_pe (torch.Tensor): the positional encoding to add to the image. Must
+             have the same shape as image_embedding.
+           point_embedding (torch.Tensor): the embedding to add to the query points.
+             Must have shape B x N_points x embedding_dim for any N_points.
+
+         Returns:
+           torch.Tensor: the processed point_embedding
+           torch.Tensor: the processed image_embedding
+         """
+         # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+         bs, c, h, w = image_embedding.shape
+         image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+         image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+         # Prepare queries
+         queries = point_embedding
+         keys = image_embedding
+
+         # Apply transformer blocks and final layernorm
+         for layer in self.layers:
+             queries, keys = layer(
+                 queries=queries,
+                 keys=keys,
+                 query_pe=point_embedding,
+                 key_pe=image_pe,
+             )
+
+         # Apply the final attention layer from the points to the image
+         q = queries + point_embedding
+         k = keys + image_pe
+         attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+         queries = queries + attn_out
+         queries = self.norm_final_attn(queries)
+
+         return queries, keys
+
+
+ class TwoWayAttentionBlock(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         num_heads: int,
+         mlp_dim: int = 2048,
+         activation: Type[nn.Module] = nn.ReLU,
+         attention_downsample_rate: int = 2,
+         skip_first_layer_pe: bool = False,
+     ) -> None:
+         """
+         A transformer block with four layers: (1) self-attention of sparse
+         inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+         block on sparse inputs, and (4) cross attention of dense inputs to sparse
+         inputs.
+
+         Arguments:
+           embedding_dim (int): the channel dimension of the embeddings
+           num_heads (int): the number of heads in the attention layers
+           mlp_dim (int): the hidden dimension of the mlp block
+           activation (nn.Module): the activation of the mlp block
+           skip_first_layer_pe (bool): skip the PE on the first layer
+         """
+         super().__init__()
+         self.self_attn = Attention(embedding_dim, num_heads)
+         self.norm1 = nn.LayerNorm(embedding_dim)
+
+         self.cross_attn_token_to_image = Attention(
+             embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+         )
+         self.norm2 = nn.LayerNorm(embedding_dim)
+
+         self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
+         self.norm3 = nn.LayerNorm(embedding_dim)
+
+         self.norm4 = nn.LayerNorm(embedding_dim)
+         self.cross_attn_image_to_token = Attention(
+             embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+         )
+
+         self.skip_first_layer_pe = skip_first_layer_pe
+
+     def forward(
+         self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+     ) -> Tuple[Tensor, Tensor]:
+         # Self attention block
+         if self.skip_first_layer_pe:
+             queries = self.self_attn(q=queries, k=queries, v=queries)
+         else:
+             q = queries + query_pe
+             attn_out = self.self_attn(q=q, k=q, v=queries)
+             queries = queries + attn_out
+         queries = self.norm1(queries)
+
+         # Cross attention block, tokens attending to image embedding
+         q = queries + query_pe
+         k = keys + key_pe
+         attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+         queries = queries + attn_out
+         queries = self.norm2(queries)
+
+         # MLP block
+         mlp_out = self.mlp(queries)
+         queries = queries + mlp_out
+         queries = self.norm3(queries)
+
+         # Cross attention block, image embedding attending to tokens
+         q = queries + query_pe
+         k = keys + key_pe
+         attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+         keys = keys + attn_out
+         keys = self.norm4(keys)
+
+         return queries, keys
+
+
+ class Attention(nn.Module):
+     """
+     An attention layer that allows for downscaling the size of the embedding
+     after projection to queries, keys, and values.
+     """
+
+     def __init__(
+         self,
+         embedding_dim: int,
+         num_heads: int,
+         downsample_rate: int = 1,
+     ) -> None:
+         super().__init__()
+         self.embedding_dim = embedding_dim
+         self.internal_dim = embedding_dim // downsample_rate
+         self.num_heads = num_heads
+         assert (
+             self.internal_dim % num_heads == 0
+         ), "num_heads must divide embedding_dim."
+
+         self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+         self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+         self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+         self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+     def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+         b, n, c = x.shape
+         x = x.reshape(b, n, num_heads, c // num_heads)
+         return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
+
+     def _recombine_heads(self, x: Tensor) -> Tensor:
+         b, n_heads, n_tokens, c_per_head = x.shape
+         x = x.transpose(1, 2)
+         return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
+
+     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+         # Input projections
+         q = self.q_proj(q)
+         k = self.k_proj(k)
+         v = self.v_proj(v)
+
+         # Separate into heads
+         q = self._separate_heads(q, self.num_heads)
+         k = self._separate_heads(k, self.num_heads)
+         v = self._separate_heads(v, self.num_heads)
+
+         # Attention
+         _, _, _, c_per_head = q.shape
+         attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
+         attn = attn / math.sqrt(c_per_head)
+         attn = torch.softmax(attn, dim=-1)
+
+         # Get output
+         out = attn @ v
+         out = self._recombine_heads(out)
+         out = self.out_proj(out)
+
+         return out
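A sketch of how the mask decoder drives this transformer (cf. `predict_masks` in models/mask_decoder.py). The token count of 9 is an assumption matching the training setup: box prompts plus both extra prompt tokens:

```python
# Sketch: token/image cross-attention with the decoder's shapes.
import torch

tfm = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
src = torch.randn(2, 256, 64, 64)  # image embedding + dense prompt embedding
pos = torch.randn(2, 256, 64, 64)  # dense positional encoding, same shape
tokens = torch.randn(2, 9, 256)    # 1 IoU + 4 mask + 2 box-corner + content + modality tokens
hs, keys = tfm(src, pos, tokens)
print(hs.shape, keys.shape)  # torch.Size([2, 9, 256]) torch.Size([2, 4096, 256])
```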
train.py ADDED
@@ -0,0 +1,502 @@
+ import os
+ import random
+ import monai
+ from os import listdir, makedirs
+ from os.path import join, exists, isfile, isdir, basename
+ from tqdm import tqdm
+ from time import time
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader
+ from datetime import datetime
+ from shutil import copyfile
+ from models import PromptEncoder, TwoWayTransformer, TinyViT, MaskDecoder_F4
+ import torch.nn.functional as F
+ import gc
+ from matplotlib import pyplot as plt
+ import argparse
+ from modality_npz_dataset import ModalityNpzDataset
+
+ torch.cuda.empty_cache()
+ os.environ["OMP_NUM_THREADS"] = "4"         # export OMP_NUM_THREADS=4
+ os.environ["OPENBLAS_NUM_THREADS"] = "4"    # export OPENBLAS_NUM_THREADS=4
+ os.environ["MKL_NUM_THREADS"] = "6"         # export MKL_NUM_THREADS=6
+ os.environ["VECLIB_MAXIMUM_THREADS"] = "4"  # export VECLIB_MAXIMUM_THREADS=4
+ os.environ["NUMEXPR_NUM_THREADS"] = "6"     # export NUMEXPR_NUM_THREADS=6
+
+ def setup_seed(seed):
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+ setup_seed(2024)
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--data_root",
+                         type=str,
+                         default="",
+                         help="Path to the npy data root.")
+
+     parser.add_argument('--task_name', type=str, default='MedSAM-Lite-All')
+
+     parser.add_argument("--pretrained_checkpoint",
+                         type=str,
+                         default=None,
+                         help="Path to the pretrained Lite-MedSAM checkpoint.")
+
+     parser.add_argument("--resume",
+                         type=str,
+                         default=None,
+                         help="Path to the checkpoint to continue training.")
+     parser.add_argument(
+         "--work_dir",
+         type=str,
+         default="./work_dir",
+         help=
+         "Path to the working directory where checkpoints and logs will be saved."
+     )
+
+     parser.add_argument('--data_aug',
+                         action='store_true',
+                         default=False,
+                         help='use data augmentation during training')
+
+     parser.add_argument("--num_epochs",
+                         type=int,
+                         default=25,
+                         help="Number of epochs to train.")
+     parser.add_argument("--batch_size",
+                         type=int,
+                         default=16,
+                         help="Batch size.")
+     parser.add_argument("--num_workers",
+                         type=int,
+                         default=8,
+                         help="Number of workers for dataloader.")
+
+     parser.add_argument(
+         "--bbox_shift",
+         type=int,
+         default=5,
+         help="Perturbation to bounding box coordinates during training.")
+
+     parser.add_argument("-lr", type=float, default=2e-4, help="Learning rate.")
+
+     parser.add_argument("-weight_decay",
+                         type=float,
+                         default=0.001,
+                         help="Weight decay.")
+
+     parser.add_argument("-iou_loss_weight",
+                         type=float,
+                         default=1.0,
+                         help="Weight of IoU loss.")
+
+     parser.add_argument("-seg_loss_weight",
+                         type=float,
+                         default=1.0,
+                         help="Weight of segmentation loss.")
+     parser.add_argument("-ce_loss_weight",
+                         type=float,
+                         default=1.0,
+                         help="Weight of cross entropy loss.")
+
+     parser.add_argument("--sanity_check",
+                         action="store_true",
+                         default=True,  # note: default True, so the check always runs
+                         help="Whether to do sanity check for dataloading.")
+
+     args = parser.parse_args()
+     return args
+
+
+ def show_mask(mask, ax, random_color=True):
+     if random_color:
+         color = np.concatenate([np.random.random(3), np.array([0.45])], axis=0)
+     else:
+         color = np.array([251 / 255, 252 / 255, 30 / 255, 0.45])
+     h, w = mask.shape[-2:]
+     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+     ax.imshow(mask_image)
+
+
+ def show_box(box, ax):
+     x0, y0 = box[0], box[1]
+     w, h = box[2] - box[0], box[3] - box[1]
+     ax.add_patch(
+         plt.Rectangle((x0, y0),
+                       w,
+                       h,
+                       edgecolor='blue',
+                       facecolor=(0, 0, 0, 0),
+                       lw=2))
+
+
+ def show_points(points, ax):
+     for i, (x, y) in enumerate(points):
+         ax.scatter(x, y, color='red', s=10)
+
+
+ def cal_iou(result, reference):
+
+     intersection = torch.count_nonzero(torch.logical_and(result, reference),
+                                        dim=[i for i in range(1, result.ndim)])
+     union = torch.count_nonzero(torch.logical_or(result, reference),
+                                 dim=[i for i in range(1, result.ndim)])
+
+     iou = intersection.float() / union.float()
+
+     return iou.unsqueeze(1)
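`cal_iou` counts intersection and union over every non-batch dimension, so it yields one IoU per sample, shaped (B, 1) to match the decoder's `iou_pred`. A worked example on a single 2x2 mask pair:

```python
# Example: intersection 1 pixel, union 2 pixels -> IoU 0.5.
import torch

pred = torch.tensor([[[[1, 1], [0, 0]]]], dtype=torch.bool)  # (1, 1, 2, 2)
gt = torch.tensor([[[[1, 0], [0, 0]]]], dtype=torch.bool)
print(cal_iou(pred, gt))  # tensor([[0.5000]])
```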
+
+
+ def sanity_check_dataset(args):
+
+     tr_dataset = ModalityNpzDataset(args.data_root, data_aug=True)
+     tr_dataloader = DataLoader(tr_dataset, batch_size=8, shuffle=True)
+
+     for step, batch in enumerate(tr_dataloader):
+         # show the example
+         _, axs = plt.subplots(1, 2, figsize=(10, 10))
+         idx = random.randint(0, 4)
+
+         image = batch["image"]
+         gt = batch["gt2D"]
+         bboxes = batch["bboxes"]
+         names_temp = batch["image_name"]
+
+         axs[0].imshow(image[idx].cpu().permute(1, 2, 0).numpy())
+         show_mask(gt[idx].cpu().squeeze().numpy(), axs[0])
+         show_box(bboxes[idx].numpy().squeeze(), axs[0])
+         axs[0].axis('off')
+         # set title
+         axs[0].set_title(names_temp[idx])
+         idx = random.randint(4, 7)
+         axs[1].imshow(image[idx].cpu().permute(1, 2, 0).numpy())
+         show_mask(gt[idx].cpu().squeeze().numpy(), axs[1])
+         show_box(bboxes[idx].numpy().squeeze(), axs[1])
+         axs[1].axis('off')
+         # set title
+         axs[1].set_title(names_temp[idx])
+         plt.subplots_adjust(wspace=0.01, hspace=0)
+         plt.savefig(join(args.work_dir, 'Sanitycheck_DA.png'),
+                     bbox_inches='tight',
+                     dpi=300)
+         plt.close()
+         break
+
+
+ class MedSAM_Lite(nn.Module):
+
+     def __init__(
+         self,
+         image_encoder,
+         mask_decoder,
+         prompt_encoder,
+     ):
+         super().__init__()
+         self.image_encoder = image_encoder
+         self.mask_decoder = mask_decoder
+         self.prompt_encoder = prompt_encoder
+         encoder_weight_file = ""  # path for vision encoder (TinyViT) weights; must be set before use
+
+         self.image_encoder.load_state_dict(torch.load(encoder_weight_file))
+
+     def forward(self, image, points, boxes, masks, features, crops,
+                 text_features, category_idx):
+         image_embedding = self.image_encoder(image)
+
+         sparse_embeddings, dense_embeddings = self.prompt_encoder(
+             points=points,
+             boxes=boxes,
+             masks=masks,
+             features=features,
+             crops=crops,
+             text_features=text_features,
+             category_idx=category_idx)
+
+         low_res_masks, iou_predictions, category_predictions, clip_vec, img_vec = self.mask_decoder(
+             image_embeddings=image_embedding,  # (B, 256, 64, 64)
+             image_pe=self.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
+             sparse_prompt_embeddings=sparse_embeddings,  # (B, 4, 256): box corners + content + modality tokens
+             dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
+             multimask_output=False,
+         )  # (B, 1, 256, 256)
+
+         return low_res_masks, iou_predictions, category_predictions, clip_vec, img_vec
+
+     @torch.no_grad()
+     def postprocess_masks(self, masks, new_size, original_size):
+         """
+         Do cropping and resizing
+         """
+         # Crop
+         masks = masks[:, :, :new_size[0], :new_size[1]]
+         # Resize
+         masks = F.interpolate(
+             masks,
+             size=(original_size[0], original_size[1]),
+             mode="bilinear",
+             align_corners=False,
+         )
+
+         return masks
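A forward-pass sketch for the assembled model, under two assumptions: the three modules are the ones constructed in the `__main__` block below, and `encoder_weight_file` in `MedSAM_Lite.__init__` has been pointed at the TinyViT weights (as shipped it is an empty placeholder). `Classifier`'s output width follows `category_num`, so the modality logits should be (B, 11):

```python
# Sketch: one training-style forward pass (shapes follow the comments in forward).
import torch

model = MedSAM_Lite(image_encoder=medsam_lite_image_encoder,
                    mask_decoder=medsam_lite_mask_decoder,
                    prompt_encoder=medsam_lite_prompt_encoder)
image = torch.randn(2, 3, 256, 256)
boxes = torch.tensor([[10., 20., 120., 180.], [30., 40., 200., 210.]])
logits, iou_pred, cat_pred, clip_vec, img_vec = model(
    image, points=None, boxes=boxes, masks=None,
    features=torch.randn(2, 1, 512), crops=torch.randn(2, 3, 64, 64),
    text_features=torch.randn(2, 1, 512), category_idx=torch.zeros(2))
# logits: (2, 1, 256, 256) mask logits; iou_pred: (2, 1); cat_pred: (2, 11)
```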
246
+
247
+
+
+ def collate_fn(batch):
+     """
+     Collate function for PyTorch DataLoader.
+     """
+     batch_dict = {}
+     for key in batch[0].keys():
+         if key == "image_name" or key == "category_idx":
+             # Keep non-tensor fields (strings, ints) as plain Python lists.
+             batch_dict[key] = [sample[key] for sample in batch]
+         else:
+             batch_dict[key] = torch.stack([sample[key] for sample in batch],
+                                           dim=0)
+
+     return batch_dict
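+
+ # A minimal sketch of what this collate function produces, assuming each
+ # dataset sample is a dict of tensors plus an "image_name" string and an
+ # integer "category_idx" (sample fields here are illustrative only):
+ #
+ #   samples = [{"image": torch.zeros(3, 256, 256),
+ #               "image_name": f"case_{i}.npz",
+ #               "category_idx": i} for i in range(2)]
+ #   batch = collate_fn(samples)
+ #   # batch["image"].shape == (2, 3, 256, 256)
+ #   # batch["image_name"] and batch["category_idx"] stay Python lists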
+
+
+ if __name__ == "__main__":
+
+     args = get_args()
+     sanity_check_dataset(args)
+
+     run_id = datetime.now().strftime("%Y%m%d-%H%M")
+     print(f"Run ID: {run_id}")
+
+     model_save_path = join(args.work_dir, args.task_name + "-" + run_id)
+     makedirs(model_save_path, exist_ok=True)
+     # Keep a copy of this script next to the checkpoints for reproducibility.
+     copyfile(__file__,
+              join(model_save_path, run_id + "_" + os.path.basename(__file__)))
+
+     device = torch.device("cuda")
+
+     num_epochs = args.num_epochs
+     batch_size = args.batch_size
+     num_workers = args.num_workers
+
+     medsam_lite_image_encoder = TinyViT(
+         img_size=256,
+         in_chans=3,
+         embed_dims=[
+             64,  ## (64, 256, 256)
+             128,  ## (128, 128, 128)
+             160,  ## (160, 64, 64)
+             320  ## (320, 64, 64)
+         ],
+         depths=[2, 2, 6, 2],
+         num_heads=[2, 4, 5, 10],
+         window_sizes=[7, 7, 14, 7],
+         mlp_ratio=4.,
+         drop_rate=0.,
+         drop_path_rate=0.0,
+         use_checkpoint=False,
+         mbconv_expand_ratio=4.0,
+         local_conv_size=3,
+         layer_lr_decay=0.8)
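+
+     # Note: the final TinyViT stage yields 64x64 feature maps, matching the
+     # prompt encoder's image_embedding_size below; the encoder is assumed to
+     # project them to the 256-channel embedding the mask decoder expects
+     # (see the (B, 256, 64, 64) annotation in MedSAM_Lite.forward).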
+
+     medsam_lite_prompt_encoder = PromptEncoder(embed_dim=256,
+                                                image_embedding_size=(64, 64),
+                                                input_image_size=(256, 256),
+                                                mask_in_chans=16)
+
+     medsam_lite_mask_decoder = MaskDecoder_F4(
+         num_multimask_outputs=3,
+         transformer=TwoWayTransformer(
+             depth=2,
+             embedding_dim=256,
+             mlp_dim=2048,
+             num_heads=8,
+         ),
+         modality=True,
+         contents=True,
+         transformer_dim=256,
+         iou_head_depth=3,
+         iou_head_hidden_dim=256,
+     )
+
+     medsam_lite_model = MedSAM_Lite(image_encoder=medsam_lite_image_encoder,
+                                     mask_decoder=medsam_lite_mask_decoder,
+                                     prompt_encoder=medsam_lite_prompt_encoder)
+
+     if args.resume is None and args.pretrained_checkpoint is not None:
+         ## Load the pretrained checkpoint only when there is no run to resume from
+         print(
+             f"Loading pretrained checkpoint from {args.pretrained_checkpoint}")
+         medsam_lite_checkpoint = torch.load(args.pretrained_checkpoint,
+                                             map_location="cpu")
+         medsam_lite_model.load_state_dict(medsam_lite_checkpoint["model"],
+                                           strict=True)
+
+     medsam_lite_model = medsam_lite_model.to(device)
+
+     medsam_lite_model.train()
+
+     print(
+         f"MedSAM Lite size: {sum(p.numel() for p in medsam_lite_model.parameters())}"
+     )
+
+     print('lr:', args.lr)
+
+     optimizer = optim.AdamW(
+         medsam_lite_model.parameters(),
+         lr=args.lr,
+         betas=(0.9, 0.999),
+         eps=1e-08,
+         weight_decay=args.weight_decay,
+     )
+     lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
+                                                         mode='min',
+                                                         factor=0.9,
+                                                         patience=5,
+                                                         cooldown=0)
+     seg_loss = monai.losses.DiceLoss(sigmoid=True,
+                                      squared_pred=True,
+                                      reduction='mean')
+     bce_loss = nn.BCEWithLogitsLoss(reduction='mean')
+     iou_loss = nn.MSELoss(reduction='mean')
+     ce_loss = nn.CrossEntropyLoss(reduction='mean')
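+
+     # The total objective used below combines Dice + BCE on the mask, an MSE
+     # term regressing the predicted IoU toward the actual IoU, and two
+     # auxiliary cross-entropy terms (each weighted 0.01): a symmetric
+     # image/CLIP feature alignment loss and a category classification loss.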
+
+     train_dataset = ModalityNpzDataset(data_root=args.data_root, data_aug=True)
+
+     train_loader = DataLoader(train_dataset,
+                               batch_size=batch_size,
+                               shuffle=True,
+                               num_workers=num_workers,
+                               pin_memory=True)
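+
+     # Note: the custom collate_fn defined above is not passed to this
+     # DataLoader, so PyTorch's default collation is used; torch.as_tensor in
+     # the loop below handles "category_idx" in either case.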
+
+     if args.resume is not None:
+         ckpt_folders = sorted(listdir(args.resume))
+         ckpt_folders = [
+             f for f in ckpt_folders
+             if (f.startswith(args.task_name)
+                 and isfile(join(args.resume, f, 'medsam_lite_latest.pth')))
+         ]
+         print('*' * 20)
+         print('existing ckpts in', args.resume, ckpt_folders)
+         # Find the latest checkpoint folder by its timestamp suffix.
+         time_strings = [
+             f.split(args.task_name + '-')[-1] for f in ckpt_folders
+         ]
+         dates = [datetime.strptime(f, '%Y%m%d-%H%M') for f in time_strings]
+         latest_date = max(dates)
+         latest_ckpt = join(
+             args.work_dir,
+             args.task_name + '-' + latest_date.strftime('%Y%m%d-%H%M'),
+             'medsam_lite_latest.pth')
+         print('Loading from', latest_ckpt)
+         checkpoint = torch.load(latest_ckpt, map_location=device)
+         # The model is not wrapped in DataParallel/DDP in this script, so load
+         # into the model directly rather than through a .module attribute.
+         medsam_lite_model.load_state_dict(checkpoint["model"])
+         optimizer.load_state_dict(checkpoint["optimizer"])
+         start_epoch = checkpoint["epoch"] + 1
+         best_loss = checkpoint["loss"]
+         print(f"Loaded checkpoint from epoch {start_epoch}")
+     else:
+         start_epoch = 0
+         best_loss = 1e10
+
+     train_losses = []
+     epoch_times = []
+
+     print("Training")
+     for epoch in range(start_epoch, num_epochs):
+         if epoch == num_epochs - 1:
+             # Drop to a fixed small learning rate for the final epoch.
+             for param_group in optimizer.param_groups:
+                 param_group['lr'] = 5e-5
+
+         epoch_loss = [1e10 for _ in range(len(train_loader))]
+         epoch_start_time = time()
+         pbar = tqdm(train_loader)
+         for step, batch in enumerate(pbar):
+             gc.collect()
+             torch.cuda.empty_cache()
+             image = batch["image"]
+             gt2D = batch["gt2D"]
+             boxes = batch["bboxes"]
+             coords = batch["coords"]
+             crops = batch["image_crop"]
+             features = batch["image_feature"]
+             text_features = batch["text_feature"]
+             # as_tensor avoids the copy warning when category_idx is already a tensor
+             class_idx = torch.as_tensor(batch["category_idx"])
+
+             optimizer.zero_grad()
+             image, gt2D, boxes, coords, crops, features, text_features, class_idx = image.to(
+                 device), gt2D.to(device), boxes.to(device), coords.to(
+                     device), crops.to(device), features.to(
+                         device), text_features.to(device), class_idx.to(device)
+             # Point prompts are prepared here but not used in this run
+             # (the forward call below passes points=None).
+             labels_torch = torch.ones(coords.shape[0]).long()
+             labels_torch = labels_torch.unsqueeze(1).expand(-1, 4)
+             labels_torch = labels_torch.to(device)
+             point_prompt = (coords, labels_torch)
+             logits_pred, iou_pred, category_predictions, clip_vec, img_vec = medsam_lite_model(
+                 image, None, boxes, None, features, crops, text_features,
+                 class_idx)
+
+             clip_img_features = clip_vec / clip_vec.norm(dim=-1, keepdim=True)
+             img_features = img_vec / img_vec.norm(dim=-1, keepdim=True)
+             similarity1 = torch.matmul(clip_img_features, img_features.T)
+             similarity2 = torch.matmul(img_features, clip_img_features.T)
+             sim_labels = torch.arange(similarity1.shape[0]).to(image.device)
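+
+             # CLIP-style symmetric objective: each image embedding should be
+             # most similar to its own CLIP feature (and vice versa), so the
+             # in-batch index serves as the classification label for l_ce_sim.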
+
+             l_seg = seg_loss(logits_pred, gt2D)
+             l_bce = bce_loss(logits_pred, gt2D.float())
+             l_ce_sim = 0.5 * (ce_loss(similarity1, sim_labels.long()) +
+                               ce_loss(similarity2, sim_labels.long()))
+             l_ce = ce_loss(category_predictions, class_idx.long())
+             mask_loss = l_seg + l_bce
+             with torch.no_grad():
+                 iou_gt = cal_iou(torch.sigmoid(logits_pred) > 0.5, gt2D.bool())
+             l_iou = iou_loss(iou_pred, iou_gt)
+             loss = mask_loss + l_iou + 0.01 * l_ce_sim + 0.01 * l_ce
+             epoch_loss[step] = loss.item()
+             loss.backward()
+             optimizer.step()
+             optimizer.zero_grad()
+             pbar.set_description(
+                 f"Epoch {epoch} at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, loss: {loss.item():.4f}"
+             )
+
+         epoch_end_time = time()
+         epoch_duration = epoch_end_time - epoch_start_time
+         epoch_times.append(epoch_duration)
+
+         epoch_loss_reduced = sum(epoch_loss) / len(epoch_loss)
+
+         train_losses.append(epoch_loss_reduced)
+         lr_scheduler.step(epoch_loss_reduced)
+
+         model_weights = medsam_lite_model.state_dict()
+
+         checkpoint = {
+             "model": model_weights,
+             "epoch": epoch,
+             "optimizer": optimizer.state_dict(),
+             "loss": epoch_loss_reduced,
+             "best_loss": best_loss,
+         }
+         torch.save(checkpoint, join(model_save_path, "medsam_lite_latest.pth"))
+
+         if epoch_loss_reduced < best_loss:
+             print(
+                 f"New best loss: {best_loss:.4f} -> {epoch_loss_reduced:.4f}")
+             best_loss = epoch_loss_reduced
+             checkpoint["best_loss"] = best_loss
+             torch.save(checkpoint,
+                        join(model_save_path, "medsam_lite_best.pth"))
+
+         # Plot the training curves after every epoch.
+         fig, axes = plt.subplots(2, 1, figsize=(10, 8))
+         axes[0].title.set_text("Dice + Binary Cross Entropy + IoU Loss")
+         axes[0].plot(train_losses)
+         axes[0].set_ylabel("Loss")
+         axes[1].plot(epoch_times)
+         axes[1].title.set_text("Epoch Duration")
+         axes[1].set_ylabel("Duration (s)")
+         axes[1].set_xlabel("Epoch")
+         plt.tight_layout()
+         plt.savefig(join(model_save_path, "log.png"))
+         plt.close()
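+
+ # A hypothetical launch command, assuming get_args exposes the attributes used
+ # above (data_root, work_dir, task_name, num_epochs, batch_size, num_workers,
+ # lr, weight_decay, resume, pretrained_checkpoint) as same-named CLI flags;
+ # flag spellings and values here are illustrative, not confirmed defaults:
+ #
+ #   python train.py --data_root /path/to/npz_data --work_dir ./work_dir \
+ #       --task_name MCP-MedSAM --num_epochs 10 --batch_size 8 --num_workers 4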