yuxin commited on
Commit ·
d0601bd
1
Parent(s): 6f1b94e
add config
Browse files- config.json +0 -1
- model_segvol_single.py +28 -36
config.json
CHANGED
|
@@ -6,7 +6,6 @@
|
|
| 6 |
"AutoConfig": "model_segvol_single.SegVolConfig",
|
| 7 |
"AutoModel": "model_segvol_single.SegVolModel"
|
| 8 |
},
|
| 9 |
-
"custom_device": "cpu",
|
| 10 |
"model_type": "segvol",
|
| 11 |
"patch_size": [
|
| 12 |
4,
|
|
|
|
| 6 |
"AutoConfig": "model_segvol_single.SegVolConfig",
|
| 7 |
"AutoModel": "model_segvol_single.SegVolModel"
|
| 8 |
},
|
|
|
|
| 9 |
"model_type": "segvol",
|
| 10 |
"patch_size": [
|
| 11 |
4,
|
model_segvol_single.py
CHANGED
|
@@ -9,15 +9,11 @@ class SegVolConfig(PretrainedConfig):
|
|
| 9 |
def __init__(
|
| 10 |
self,
|
| 11 |
test_mode=True,
|
| 12 |
-
custom_device='cpu',
|
| 13 |
-
# clip_model='.',
|
| 14 |
**kwargs,
|
| 15 |
):
|
| 16 |
self.spatial_size = [32, 256, 256]
|
| 17 |
self.patch_size = [4, 16, 16]
|
| 18 |
self.test_mode = test_mode
|
| 19 |
-
self.custom_device = custom_device
|
| 20 |
-
# self.clip_model = clip_model
|
| 21 |
super().__init__(**kwargs)
|
| 22 |
|
| 23 |
class SegVolModel(PreTrainedModel):
|
|
@@ -38,14 +34,11 @@ class SegVolModel(PreTrainedModel):
|
|
| 38 |
prompt_encoder=sam_model.prompt_encoder,
|
| 39 |
roi_size=self.config.spatial_size,
|
| 40 |
patch_size=self.config.patch_size,
|
| 41 |
-
custom_device=self.config.custom_device,
|
| 42 |
# clip_model=self.config.clip_model,
|
| 43 |
test_mode=self.config.test_mode,
|
| 44 |
)
|
| 45 |
|
| 46 |
-
self.processor = SegVolProcessor(spatial_size=self.config.spatial_size
|
| 47 |
-
|
| 48 |
-
self.custom_device = self.config.custom_device
|
| 49 |
|
| 50 |
def forward_test(self,
|
| 51 |
image,
|
|
@@ -53,7 +46,8 @@ class SegVolModel(PreTrainedModel):
|
|
| 53 |
text_prompt=None,
|
| 54 |
bbox_prompt_group=None,
|
| 55 |
point_prompt_group=None,
|
| 56 |
-
use_zoom=True):
|
|
|
|
| 57 |
assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
|
| 58 |
assert not (text_prompt is None and bbox_prompt_group is None and point_prompt_group is None), 'Drive SegVol using at least one type of prompt'
|
| 59 |
bbox_prompt, bbox_prompt_map, point_prompt, point_prompt_map=None, None, None, None
|
|
@@ -110,7 +104,7 @@ class SegVolModel(PreTrainedModel):
|
|
| 110 |
## inference
|
| 111 |
with torch.no_grad():
|
| 112 |
logits_single_cropped = sliding_window_inference(
|
| 113 |
-
image_single_cropped.to(
|
| 114 |
self.config.spatial_size, 1, self.model, 0.5,
|
| 115 |
text=text_prompt,
|
| 116 |
use_box=bbox_prompt is not None,
|
|
@@ -128,7 +122,7 @@ class SegVolModel(PreTrainedModel):
|
|
| 128 |
|
| 129 |
# processor
|
| 130 |
class SegVolProcessor():
|
| 131 |
-
def __init__(self, spatial_size
|
| 132 |
self.img_loader = transforms.LoadImage()
|
| 133 |
self.transform4test = transforms.Compose(
|
| 134 |
[
|
|
@@ -140,7 +134,6 @@ class SegVolProcessor():
|
|
| 140 |
]
|
| 141 |
)
|
| 142 |
self.zoom_out_transform = transforms.Resized(keys=["image", "label"], spatial_size=spatial_size, mode='nearest-exact')
|
| 143 |
-
self.custom_device = custom_device
|
| 144 |
self.transform4train = transforms.Compose(
|
| 145 |
[
|
| 146 |
# transforms.AddChanneld(keys=["image"]),
|
|
@@ -217,24 +210,24 @@ class SegVolProcessor():
|
|
| 217 |
item['zoom_out_label'] = item_zoom_out['label']
|
| 218 |
return item
|
| 219 |
|
| 220 |
-
def point_prompt_b(self, label_single_resize, num_positive_extra=4, num_negative_extra=0):
|
| 221 |
point, point_label = select_points(label_single_resize, num_positive_extra=num_positive_extra, num_negative_extra=num_negative_extra)
|
| 222 |
-
points_single = (point.unsqueeze(0).float().to(
|
| 223 |
binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape).unsqueeze(0).unsqueeze(0)
|
| 224 |
return points_single, binary_points_resize
|
| 225 |
|
| 226 |
-
def bbox_prompt_b(self, label_single_resize):
|
| 227 |
-
box_single = generate_box(label_single_resize).unsqueeze(0).float().to(
|
| 228 |
binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape).unsqueeze(0).unsqueeze(0)
|
| 229 |
return box_single, binary_cube_resize
|
| 230 |
|
| 231 |
-
def dice_score(self, preds, labels):
|
| 232 |
assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match\n" + str(preds.shape) + str(labels.shape)
|
| 233 |
predict = preds.view(1, -1)
|
| 234 |
target = labels.view(1, -1)
|
| 235 |
if target.shape[1] < 1e8:
|
| 236 |
-
predict = predict.to(
|
| 237 |
-
target = target.to(
|
| 238 |
predict = torch.sigmoid(predict)
|
| 239 |
predict = torch.where(predict > 0.5, 1., 0.)
|
| 240 |
|
|
@@ -425,20 +418,18 @@ class SegVol(nn.Module):
|
|
| 425 |
prompt_encoder,
|
| 426 |
roi_size,
|
| 427 |
patch_size,
|
| 428 |
-
custom_device,
|
| 429 |
# clip_model,
|
| 430 |
test_mode=False,
|
| 431 |
):
|
| 432 |
super().__init__()
|
| 433 |
-
self.custom_device = custom_device
|
| 434 |
self.image_encoder = image_encoder
|
| 435 |
self.mask_decoder = mask_decoder
|
| 436 |
self.prompt_encoder = prompt_encoder
|
| 437 |
-
self.text_encoder = TextEncoder(
|
| 438 |
self.feat_shape = np.array(roi_size)/np.array(patch_size)
|
| 439 |
self.test_mode = test_mode
|
| 440 |
-
self.dice_loss = BinaryDiceLoss()
|
| 441 |
-
self.bce_loss = BCELoss()
|
| 442 |
self.decoder_iter = 6
|
| 443 |
|
| 444 |
def forward(self, image, text=None, boxes=None, points=None, **kwargs):
|
|
@@ -459,12 +450,13 @@ class SegVol(nn.Module):
|
|
| 459 |
return sl_loss
|
| 460 |
|
| 461 |
def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
|
|
|
|
| 462 |
with torch.no_grad():
|
| 463 |
if boxes is not None:
|
| 464 |
if len(boxes.shape) == 2:
|
| 465 |
boxes = boxes[:, None, :] # (B, 1, 6)
|
| 466 |
if text is not None:
|
| 467 |
-
text_embedding = self.text_encoder(text) # (B, 768)
|
| 468 |
else:
|
| 469 |
text_embedding = None
|
| 470 |
sparse_embeddings, dense_embeddings = self.prompt_encoder(
|
|
@@ -487,7 +479,8 @@ class SegVol(nn.Module):
|
|
| 487 |
return logits
|
| 488 |
|
| 489 |
def supervised_forward(self, image, image_embedding, img_shape, training_organs, train_labels):
|
| 490 |
-
|
|
|
|
| 491 |
# select prompt
|
| 492 |
prompt_options = [[None, iter_points, iter_organs], [iter_bboxes, None, iter_organs],
|
| 493 |
[None, None, iter_organs], [iter_bboxes, None, None], [None, iter_points, None],
|
|
@@ -517,7 +510,7 @@ class SegVol(nn.Module):
|
|
| 517 |
# sll_loss += sll_loss_dice + sll_loss_bce
|
| 518 |
# return sll_loss
|
| 519 |
|
| 520 |
-
def build_prompt_label(self, bs, training_organs, train_labels):
|
| 521 |
# generate prompt & label
|
| 522 |
iter_organs = []
|
| 523 |
iter_bboxes = []
|
|
@@ -541,10 +534,10 @@ class SegVol(nn.Module):
|
|
| 541 |
iter_points_ax.append(point)
|
| 542 |
iter_point_labels.append(point_label)
|
| 543 |
# batched prompt
|
| 544 |
-
iter_points_ax = torch.stack(iter_points_ax, dim=0).to(
|
| 545 |
-
iter_point_labels = torch.stack(iter_point_labels, dim=0).to(
|
| 546 |
iter_points = (iter_points_ax, iter_point_labels)
|
| 547 |
-
iter_bboxes = torch.stack(iter_bboxes, dim=0).float().to(
|
| 548 |
return iter_points, iter_bboxes, iter_organs
|
| 549 |
|
| 550 |
# def build_pseudo_point_prompt_label(self, input_shape, seg_labels):
|
|
@@ -611,9 +604,8 @@ class SegVol(nn.Module):
|
|
| 611 |
# return pseudo_labels, bboxes
|
| 612 |
|
| 613 |
class TextEncoder(nn.Module):
|
| 614 |
-
def __init__(self
|
| 615 |
super().__init__()
|
| 616 |
-
self.custom_device = custom_device
|
| 617 |
config = CLIPTextConfig()
|
| 618 |
self.clip_text_model = CLIPTextModel(config)
|
| 619 |
self.tokenizer = None
|
|
@@ -622,20 +614,20 @@ class TextEncoder(nn.Module):
|
|
| 622 |
for param in self.clip_text_model.parameters():
|
| 623 |
param.requires_grad = False
|
| 624 |
|
| 625 |
-
def organ2tokens(self, organ_names):
|
| 626 |
text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
|
| 627 |
tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
|
| 628 |
for key in tokens.keys():
|
| 629 |
-
tokens[key] = tokens[key].to(
|
| 630 |
return tokens
|
| 631 |
|
| 632 |
-
def forward(self, text):
|
| 633 |
if text is None:
|
| 634 |
return None
|
| 635 |
if type(text) is str:
|
| 636 |
# text is supposed to be list
|
| 637 |
text = [text]
|
| 638 |
-
tokens = self.organ2tokens(text)
|
| 639 |
clip_outputs = self.clip_text_model(**tokens)
|
| 640 |
text_embedding = clip_outputs.pooler_output
|
| 641 |
text_embedding = self.dim_align(text_embedding)
|
|
|
|
| 9 |
def __init__(
|
| 10 |
self,
|
| 11 |
test_mode=True,
|
|
|
|
|
|
|
| 12 |
**kwargs,
|
| 13 |
):
|
| 14 |
self.spatial_size = [32, 256, 256]
|
| 15 |
self.patch_size = [4, 16, 16]
|
| 16 |
self.test_mode = test_mode
|
|
|
|
|
|
|
| 17 |
super().__init__(**kwargs)
|
| 18 |
|
| 19 |
class SegVolModel(PreTrainedModel):
|
|
|
|
| 34 |
prompt_encoder=sam_model.prompt_encoder,
|
| 35 |
roi_size=self.config.spatial_size,
|
| 36 |
patch_size=self.config.patch_size,
|
|
|
|
| 37 |
# clip_model=self.config.clip_model,
|
| 38 |
test_mode=self.config.test_mode,
|
| 39 |
)
|
| 40 |
|
| 41 |
+
self.processor = SegVolProcessor(spatial_size=self.config.spatial_size)
|
|
|
|
|
|
|
| 42 |
|
| 43 |
def forward_test(self,
|
| 44 |
image,
|
|
|
|
| 46 |
text_prompt=None,
|
| 47 |
bbox_prompt_group=None,
|
| 48 |
point_prompt_group=None,
|
| 49 |
+
use_zoom=True,):
|
| 50 |
+
device = image.device
|
| 51 |
assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
|
| 52 |
assert not (text_prompt is None and bbox_prompt_group is None and point_prompt_group is None), 'Drive SegVol using at least one type of prompt'
|
| 53 |
bbox_prompt, bbox_prompt_map, point_prompt, point_prompt_map=None, None, None, None
|
|
|
|
| 104 |
## inference
|
| 105 |
with torch.no_grad():
|
| 106 |
logits_single_cropped = sliding_window_inference(
|
| 107 |
+
image_single_cropped.to(device), prompt_reflection,
|
| 108 |
self.config.spatial_size, 1, self.model, 0.5,
|
| 109 |
text=text_prompt,
|
| 110 |
use_box=bbox_prompt is not None,
|
|
|
|
| 122 |
|
| 123 |
# processor
|
| 124 |
class SegVolProcessor():
|
| 125 |
+
def __init__(self, spatial_size) -> None:
|
| 126 |
self.img_loader = transforms.LoadImage()
|
| 127 |
self.transform4test = transforms.Compose(
|
| 128 |
[
|
|
|
|
| 134 |
]
|
| 135 |
)
|
| 136 |
self.zoom_out_transform = transforms.Resized(keys=["image", "label"], spatial_size=spatial_size, mode='nearest-exact')
|
|
|
|
| 137 |
self.transform4train = transforms.Compose(
|
| 138 |
[
|
| 139 |
# transforms.AddChanneld(keys=["image"]),
|
|
|
|
| 210 |
item['zoom_out_label'] = item_zoom_out['label']
|
| 211 |
return item
|
| 212 |
|
| 213 |
+
def point_prompt_b(self, label_single_resize, num_positive_extra=4, num_negative_extra=0, device='cpu'):
|
| 214 |
point, point_label = select_points(label_single_resize, num_positive_extra=num_positive_extra, num_negative_extra=num_negative_extra)
|
| 215 |
+
points_single = (point.unsqueeze(0).float().to(device), point_label.unsqueeze(0).float().to(device))
|
| 216 |
binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape).unsqueeze(0).unsqueeze(0)
|
| 217 |
return points_single, binary_points_resize
|
| 218 |
|
| 219 |
+
def bbox_prompt_b(self, label_single_resize, device='cpu'):
|
| 220 |
+
box_single = generate_box(label_single_resize).unsqueeze(0).float().to(device)
|
| 221 |
binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape).unsqueeze(0).unsqueeze(0)
|
| 222 |
return box_single, binary_cube_resize
|
| 223 |
|
| 224 |
+
def dice_score(self, preds, labels, device='cpu'):
|
| 225 |
assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match\n" + str(preds.shape) + str(labels.shape)
|
| 226 |
predict = preds.view(1, -1)
|
| 227 |
target = labels.view(1, -1)
|
| 228 |
if target.shape[1] < 1e8:
|
| 229 |
+
predict = predict.to(device)
|
| 230 |
+
target = target.to(device)
|
| 231 |
predict = torch.sigmoid(predict)
|
| 232 |
predict = torch.where(predict > 0.5, 1., 0.)
|
| 233 |
|
|
|
|
| 418 |
prompt_encoder,
|
| 419 |
roi_size,
|
| 420 |
patch_size,
|
|
|
|
| 421 |
# clip_model,
|
| 422 |
test_mode=False,
|
| 423 |
):
|
| 424 |
super().__init__()
|
|
|
|
| 425 |
self.image_encoder = image_encoder
|
| 426 |
self.mask_decoder = mask_decoder
|
| 427 |
self.prompt_encoder = prompt_encoder
|
| 428 |
+
self.text_encoder = TextEncoder()
|
| 429 |
self.feat_shape = np.array(roi_size)/np.array(patch_size)
|
| 430 |
self.test_mode = test_mode
|
| 431 |
+
self.dice_loss = BinaryDiceLoss()
|
| 432 |
+
self.bce_loss = BCELoss()
|
| 433 |
self.decoder_iter = 6
|
| 434 |
|
| 435 |
def forward(self, image, text=None, boxes=None, points=None, **kwargs):
|
|
|
|
| 450 |
return sl_loss
|
| 451 |
|
| 452 |
def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
|
| 453 |
+
device = image_embedding.device
|
| 454 |
with torch.no_grad():
|
| 455 |
if boxes is not None:
|
| 456 |
if len(boxes.shape) == 2:
|
| 457 |
boxes = boxes[:, None, :] # (B, 1, 6)
|
| 458 |
if text is not None:
|
| 459 |
+
text_embedding = self.text_encoder(text, device) # (B, 768)
|
| 460 |
else:
|
| 461 |
text_embedding = None
|
| 462 |
sparse_embeddings, dense_embeddings = self.prompt_encoder(
|
|
|
|
| 479 |
return logits
|
| 480 |
|
| 481 |
def supervised_forward(self, image, image_embedding, img_shape, training_organs, train_labels):
|
| 482 |
+
device = image_embedding.device
|
| 483 |
+
iter_points, iter_bboxes, iter_organs = self.build_prompt_label(image.shape[0], training_organs, train_labels, device)
|
| 484 |
# select prompt
|
| 485 |
prompt_options = [[None, iter_points, iter_organs], [iter_bboxes, None, iter_organs],
|
| 486 |
[None, None, iter_organs], [iter_bboxes, None, None], [None, iter_points, None],
|
|
|
|
| 510 |
# sll_loss += sll_loss_dice + sll_loss_bce
|
| 511 |
# return sll_loss
|
| 512 |
|
| 513 |
+
def build_prompt_label(self, bs, training_organs, train_labels, device):
|
| 514 |
# generate prompt & label
|
| 515 |
iter_organs = []
|
| 516 |
iter_bboxes = []
|
|
|
|
| 534 |
iter_points_ax.append(point)
|
| 535 |
iter_point_labels.append(point_label)
|
| 536 |
# batched prompt
|
| 537 |
+
iter_points_ax = torch.stack(iter_points_ax, dim=0).to(device)
|
| 538 |
+
iter_point_labels = torch.stack(iter_point_labels, dim=0).to(device)
|
| 539 |
iter_points = (iter_points_ax, iter_point_labels)
|
| 540 |
+
iter_bboxes = torch.stack(iter_bboxes, dim=0).float().to(device)
|
| 541 |
return iter_points, iter_bboxes, iter_organs
|
| 542 |
|
| 543 |
# def build_pseudo_point_prompt_label(self, input_shape, seg_labels):
|
|
|
|
| 604 |
# return pseudo_labels, bboxes
|
| 605 |
|
| 606 |
class TextEncoder(nn.Module):
|
| 607 |
+
def __init__(self):
|
| 608 |
super().__init__()
|
|
|
|
| 609 |
config = CLIPTextConfig()
|
| 610 |
self.clip_text_model = CLIPTextModel(config)
|
| 611 |
self.tokenizer = None
|
|
|
|
| 614 |
for param in self.clip_text_model.parameters():
|
| 615 |
param.requires_grad = False
|
| 616 |
|
| 617 |
+
def organ2tokens(self, organ_names, device):
|
| 618 |
text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
|
| 619 |
tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
|
| 620 |
for key in tokens.keys():
|
| 621 |
+
tokens[key] = tokens[key].to(device)
|
| 622 |
return tokens
|
| 623 |
|
| 624 |
+
def forward(self, text, device):
|
| 625 |
if text is None:
|
| 626 |
return None
|
| 627 |
if type(text) is str:
|
| 628 |
# text is supposed to be list
|
| 629 |
text = [text]
|
| 630 |
+
tokens = self.organ2tokens(text, device)
|
| 631 |
clip_outputs = self.clip_text_model(**tokens)
|
| 632 |
text_embedding = clip_outputs.pooler_output
|
| 633 |
text_embedding = self.dim_align(text_embedding)
|