yuxin committed
Commit 6e933c4 · 1 Parent(s): da33dfe

add config

Files changed:
- config.json +0 -1
- merges.txt +0 -0
- model_segvol_single.py +8 -13
- special_tokens_map.json +1 -0
- tokenizer.json +0 -0
- tokenizer_config.json +1 -0
- vocab.json +0 -0
config.json
CHANGED
@@ -6,7 +6,6 @@
     "AutoConfig": "model_segvol_single.SegVolConfig",
     "AutoModel": "model_segvol_single.SegVolModel"
   },
-  "clip_model": "openai/clip-vit-base-patch32",
   "model_type": "segvol",
   "patch_size": [
     4,
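Since config.json keeps the auto_map entries pointing at the custom classes in model_segvol_single.py, the checkpoint still loads through trust_remote_code, now without any CLIP path in the config. A minimal loading sketch; the repo id BAAI/SegVol is an assumption, since the commit page does not name the repository:

from transformers import AutoModel

# Hypothetical repo id; substitute the repository this commit belongs to.
repo_id = "BAAI/SegVol"

# auto_map routes loading to model_segvol_single.SegVolConfig / SegVolModel,
# so the repo's custom code must be trusted explicitly.
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)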
merges.txt
ADDED
The diff for this file is too large to render. See raw diff.
model_segvol_single.py
CHANGED
@@ -9,13 +9,13 @@ class SegVolConfig(PretrainedConfig):
     def __init__(
         self,
         test_mode=True,
-        clip_model='openai/clip-vit-base-patch32',
+        # clip_model='.',
         **kwargs,
     ):
         self.spatial_size = [32, 256, 256]
         self.patch_size = [4, 16, 16]
         self.test_mode = test_mode
-        self.clip_model = clip_model
+        # self.clip_model = clip_model
         super().__init__(**kwargs)
 
 class SegVolModel(PreTrainedModel):

@@ -36,7 +36,7 @@ class SegVolModel(PreTrainedModel):
             prompt_encoder=sam_model.prompt_encoder,
             roi_size=self.config.spatial_size,
             patch_size=self.config.patch_size,
-            clip_model=self.config.clip_model,
+            # clip_model=self.config.clip_model,
             test_mode=self.config.test_mode,
         )
 
@@ -118,7 +118,6 @@ class SegVolModel(PreTrainedModel):
         return logits_global_single
 
     def forward_train(self, image, train_organs, train_labels):
-        print('in forward_train')
         loss = self.model(image, text=None, boxes=None, points=None,
                           train_organs=train_organs,
                           train_labels=train_labels)

@@ -318,7 +317,6 @@ def generate_box(pred_pre, bbox_shift=None):
     ones_idx = (meaning_post_label > 0).nonzero(as_tuple=True)
     if all(tensor.nelement() == 0 for tensor in ones_idx):
         bboxes = torch.tensor([-1,-1,-1,-1,-1,-1])
-        # print(bboxes, bboxes.shape)
         return bboxes
     min_coords = [dim.min() for dim in ones_idx] # [x_min, y_min, z_min]
     max_coords = [dim.max() for dim in ones_idx] # [x_max, y_max, z_max]

@@ -395,8 +393,6 @@ def select_points(preds, num_positive_extra=4, num_negative_extra=0, fix_extra_p
     extra_negative_points = torch.tensor(extra_negative_points).reshape(-1, 3)
     points = torch.cat((points, extra_negative_points), dim=0)
     labels = torch.cat((labels, torch.zeros((extra_negative_points.shape[0]))))
-    # print('extra_negative_points ', extra_negative_points, extra_negative_points.shape)
-    # print('==> points ', points.shape, labels)
 
     if fix_extra_point_num is None:
         left_point_num = num_positive_extra + num_negative_extra + 1 - labels.shape[0]

@@ -415,7 +411,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from transformers import
+from transformers import CLIPTextModel, CLIPTextConfig
 import random
 
 #%% set up model

@@ -426,7 +422,7 @@ class SegVol(nn.Module):
         prompt_encoder,
         roi_size,
         patch_size,
-        clip_model,
+        # clip_model,
         test_mode=False,
     ):
         super().__init__()

@@ -434,7 +430,7 @@ class SegVol(nn.Module):
         self.image_encoder = image_encoder
         self.mask_decoder = mask_decoder
         self.prompt_encoder = prompt_encoder
-        self.text_encoder = TextEncoder(
+        self.text_encoder = TextEncoder()
         self.feat_shape = np.array(roi_size)/np.array(patch_size)
         self.test_mode = test_mode
         self.dice_loss = BinaryDiceLoss().to(self.custom_device)

@@ -453,7 +449,6 @@ class SegVol(nn.Module):
 
         # train mode
         ## sl
-        print('supervised_forward ready')
         sl_loss = self.supervised_forward(image, image_embedding, img_shape, kwargs['train_organs'], kwargs['train_labels'])
         ## ssl
         # ssl_loss = self.unsupervised_forward(image, image_embedding, kwargs['pseudo_seg_cleaned'], img_shape)

@@ -612,12 +607,12 @@ class SegVol(nn.Module):
     # return pseudo_labels, bboxes
 
 class TextEncoder(nn.Module):
-    def __init__(self
+    def __init__(self):
         super().__init__()
         self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         config = CLIPTextConfig()
         self.clip_text_model = CLIPTextModel(config)
-        self.tokenizer =
+        self.tokenizer = None
         self.dim_align = nn.Linear(512, 768)
         # freeze text encoder
         for param in self.clip_text_model.parameters():
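Because TextEncoder.__init__ now builds CLIPTextModel from a bare CLIPTextConfig and leaves self.tokenizer as None, a tokenizer has to be attached after loading. A sketch of one way to wire in the tokenizer files this commit adds; the repo id is hypothetical, and the model.model.text_encoder path follows the structure visible in the hunks above (SegVolModel holds the inner SegVol as self.model, which holds self.text_encoder):

from transformers import AutoModel, CLIPTokenizer

repo_id = "BAAI/SegVol"  # hypothetical repo id
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

# tokenizer_config.json below pins tokenizer_class to CLIPTokenizer, so the
# vocab.json / merges.txt / tokenizer.json added in this commit load directly.
tokenizer = CLIPTokenizer.from_pretrained(repo_id)

# TextEncoder.__init__ sets self.tokenizer = None; attach the tokenizer by hand.
model.model.text_encoder.tokenizer = tokenizer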
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
+{"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.
tokenizer_config.json
ADDED
@@ -0,0 +1 @@
+{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": "<|endoftext|>", "add_prefix_space": false, "errors": "replace", "do_lower_case": true, "name_or_path": "/home/yuxin/BAAI/code_release/segvol_transformers/config/clip", "special_tokens_map_file": "/home/yuxin/BAAI/code_release/segvol_transformers/config/clip/special_tokens_map.json", "tokenizer_class": "CLIPTokenizer"}
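tokenizer_config.json pins tokenizer_class to CLIPTokenizer, enables do_lower_case, and reuses <|endoftext|> as both eos and pad token, so padded prompt batches share that token at the tail. A small sketch, assuming the files from this commit sit in the current directory; the organ names are made-up examples:

from transformers import CLIPTokenizer

# Load from the directory holding vocab.json, merges.txt and the two token config files.
tok = CLIPTokenizer.from_pretrained(".")

batch = tok(["Liver", "left kidney"], padding=True, return_tensors="pt")
# do_lower_case normalizes "Liver" -> "liver"; the shorter prompt is padded
# with <|endoftext|>, which doubles as the eos token here.
print(batch["input_ids"].shape, tok.pad_token)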
vocab.json
ADDED
The diff for this file is too large to render. See raw diff.
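Tying the pieces together: the forward_train hunk above passes the image with organ names and voxel labels into the inner SegVol, whose TextEncoder tokenizes the organ names (hence the tokenizer attached earlier). A shape-level sketch under stated assumptions: single-channel volumes matching spatial_size = [32, 256, 256], batch size 1, a made-up organ list, and a model built with test_mode=False:

import torch

# Assumed: `model` and `tokenizer` wired up as in the earlier sketches.
image = torch.randn(1, 1, 32, 256, 256)                        # (B, C, D, H, W)
train_labels = torch.randint(0, 2, (1, 32, 256, 256)).float()  # binary voxel mask
train_organs = ["liver"]                                       # prompts for TextEncoder

loss = model.forward_train(image, train_organs=train_organs, train_labels=train_labels)
loss.backward()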