Spaces:
Running
Running
init
#1
by
yuxindu
- opened
This view is limited to 50 files because it contains too many changes.
See the raw diff here.
- README.md +4 -5
- SegVol_v1.pth +0 -3
- __pycache__/utils.cpython-39.pyc +0 -0
- app.py +0 -339
- model/LICENSE +0 -21
- model/README.md +0 -74
- model/__pycache__/inference_cpu.cpython-39.pyc +0 -0
- model/asset/FLARE22_Tr_0002_0000.nii.gz +0 -3
- model/asset/FLARE22_Tr_0005_0000.nii.gz +0 -3
- model/asset/FLARE22_Tr_0034_0000.nii.gz +0 -3
- model/asset/FLARE22_Tr_0045_0000.nii.gz +0 -3
- model/asset/model.png +0 -0
- model/asset/overview back.png +0 -0
- model/asset/overview.png +0 -0
- model/config/clip/config.json +0 -157
- model/config/clip/special_tokens_map.json +0 -1
- model/config/clip/tokenizer.json +0 -0
- model/config/clip/tokenizer_config.json +0 -1
- model/config/clip/vocab.json +0 -0
- model/config/config_demo.json +0 -8
- model/data_process/__pycache__/demo_data_process.cpython-39.pyc +0 -0
- model/data_process/demo_data_process.py +0 -95
- model/inference_cpu.py +0 -172
- model/inference_demo.py +0 -219
- model/network/__pycache__/model.cpython-39.pyc +0 -0
- model/network/model.py +0 -91
- model/script/inference_demo.sh +0 -8
- model/segment_anything_volumetric/.ipynb_checkpoints/build_sam-checkpoint.py +0 -172
- model/segment_anything_volumetric/__init__.py +0 -12
- model/segment_anything_volumetric/__pycache__/__init__.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/__init__.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/build_sam.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/build_sam.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/predictor.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/predictor.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/automatic_mask_generator.py +0 -372
- model/segment_anything_volumetric/build_sam.py +0 -111
- model/segment_anything_volumetric/modeling/.ipynb_checkpoints/image_encoder_swin-checkpoint.py +0 -709
- model/segment_anything_volumetric/modeling/.ipynb_checkpoints/prompt_encoder-checkpoint.py +0 -232
- model/segment_anything_volumetric/modeling/__init__.py +0 -11
- model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/common.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/common.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/image_encoder_swin.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/mask_decoder.cpython-310.pyc +0 -0
README.md
CHANGED
|
@@ -1,13 +1,12 @@
|
|
| 1 |
---
|
| 2 |
title: SegVol
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: streamlit
|
| 7 |
-
sdk_version: 1.
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
title: SegVol
|
| 3 |
+
emoji: 📈
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: streamlit
|
| 7 |
+
sdk_version: 1.29.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
SegVol_v1.pth
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:b751dc95f1a0c0c6086c1e6fa7f8a17bbb87635e5226e15f5d156fbd364dbb85
|
| 3 |
-
size 1660308695
|
|
|
|
|
|
|
|
|
|
|
|
__pycache__/utils.cpython-39.pyc
DELETED
|
Binary file (3.88 kB)
|
|
|
app.py
DELETED
|
@@ -1,339 +0,0 @@
|
|
| 1 |
-
import streamlit as st
|
| 2 |
-
from streamlit_drawable_canvas import st_canvas
|
| 3 |
-
from streamlit_image_coordinates import streamlit_image_coordinates
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
from model.data_process.demo_data_process import process_ct_gt
|
| 7 |
-
import numpy as np
|
| 8 |
-
import matplotlib.pyplot as plt
|
| 9 |
-
from PIL import Image, ImageDraw
|
| 10 |
-
import monai.transforms as transforms
|
| 11 |
-
from utils import show_points, make_fig, reflect_points_into_model, initial_rectangle, reflect_json_data_to_3D_box, reflect_box_into_model, run
|
| 12 |
-
import nibabel as nib
|
| 13 |
-
import tempfile
|
| 14 |
-
|
| 15 |
-
print('script run')
|
| 16 |
-
|
| 17 |
-
#############################################
|
| 18 |
-
# init session_state
|
| 19 |
-
if 'option' not in st.session_state:
|
| 20 |
-
st.session_state.option = None
|
| 21 |
-
if 'text_prompt' not in st.session_state:
|
| 22 |
-
st.session_state.text_prompt = None
|
| 23 |
-
|
| 24 |
-
if 'reset_demo_case' not in st.session_state:
|
| 25 |
-
st.session_state.reset_demo_case = False
|
| 26 |
-
|
| 27 |
-
if 'preds_3D' not in st.session_state:
|
| 28 |
-
st.session_state.preds_3D = None
|
| 29 |
-
st.session_state.preds_3D_ori = None
|
| 30 |
-
|
| 31 |
-
if 'data_item' not in st.session_state:
|
| 32 |
-
st.session_state.data_item = None
|
| 33 |
-
|
| 34 |
-
if 'points' not in st.session_state:
|
| 35 |
-
st.session_state.points = []
|
| 36 |
-
|
| 37 |
-
if 'use_text_prompt' not in st.session_state:
|
| 38 |
-
st.session_state.use_text_prompt = False
|
| 39 |
-
|
| 40 |
-
if 'use_point_prompt' not in st.session_state:
|
| 41 |
-
st.session_state.use_point_prompt = False
|
| 42 |
-
|
| 43 |
-
if 'use_box_prompt' not in st.session_state:
|
| 44 |
-
st.session_state.use_box_prompt = False
|
| 45 |
-
|
| 46 |
-
if 'rectangle_3Dbox' not in st.session_state:
|
| 47 |
-
st.session_state.rectangle_3Dbox = [0,0,0,0,0,0]
|
| 48 |
-
|
| 49 |
-
if 'irregular_box' not in st.session_state:
|
| 50 |
-
st.session_state.irregular_box = False
|
| 51 |
-
|
| 52 |
-
if 'running' not in st.session_state:
|
| 53 |
-
st.session_state.running = False
|
| 54 |
-
|
| 55 |
-
if 'transparency' not in st.session_state:
|
| 56 |
-
st.session_state.transparency = 0.25
|
| 57 |
-
|
| 58 |
-
case_list = [
|
| 59 |
-
'model/asset/FLARE22_Tr_0002_0000.nii.gz',
|
| 60 |
-
'model/asset/FLARE22_Tr_0005_0000.nii.gz',
|
| 61 |
-
'model/asset/FLARE22_Tr_0034_0000.nii.gz',
|
| 62 |
-
'model/asset/FLARE22_Tr_0045_0000.nii.gz'
|
| 63 |
-
]
|
| 64 |
-
|
| 65 |
-
#############################################
|
| 66 |
-
|
| 67 |
-
#############################################
|
| 68 |
-
# reset functions
|
| 69 |
-
def clear_prompts():
|
| 70 |
-
st.session_state.points = []
|
| 71 |
-
st.session_state.rectangle_3Dbox = [0,0,0,0,0,0]
|
| 72 |
-
|
| 73 |
-
def reset_demo_case():
|
| 74 |
-
st.session_state.data_item = None
|
| 75 |
-
st.session_state.reset_demo_case = True
|
| 76 |
-
clear_prompts()
|
| 77 |
-
|
| 78 |
-
def clear_file():
|
| 79 |
-
st.session_state.option = None
|
| 80 |
-
process_ct_gt.clear()
|
| 81 |
-
reset_demo_case()
|
| 82 |
-
clear_prompts()
|
| 83 |
-
|
| 84 |
-
#############################################
|
| 85 |
-
|
| 86 |
-
st.image(Image.open('model/asset/overview back.png'), use_column_width=True)
|
| 87 |
-
|
| 88 |
-
github_col, arxive_col = st.columns(2)
|
| 89 |
-
|
| 90 |
-
with github_col:
|
| 91 |
-
st.write('GitHub repo:https://github.com/BAAI-DCAI/SegVol')
|
| 92 |
-
|
| 93 |
-
with arxive_col:
|
| 94 |
-
st.write('Paper:https://arxiv.org/abs/2311.13385')
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
# modify demo case here
|
| 98 |
-
demo_type = st.radio(
|
| 99 |
-
"Demo case source",
|
| 100 |
-
["Select", "Upload"],
|
| 101 |
-
on_change=clear_file
|
| 102 |
-
)
|
| 103 |
-
|
| 104 |
-
if demo_type=="Select":
|
| 105 |
-
uploaded_file = st.selectbox(
|
| 106 |
-
"Select a demo case",
|
| 107 |
-
case_list,
|
| 108 |
-
index=None,
|
| 109 |
-
placeholder="Select a demo case...",
|
| 110 |
-
on_change=reset_demo_case
|
| 111 |
-
)
|
| 112 |
-
else:
|
| 113 |
-
uploaded_file = st.file_uploader("Upload demo case(nii.gz)", type='nii.gz', on_change=reset_demo_case)
|
| 114 |
-
|
| 115 |
-
st.session_state.option = uploaded_file
|
| 116 |
-
|
| 117 |
-
if st.session_state.option is not None and \
|
| 118 |
-
st.session_state.reset_demo_case or (st.session_state.data_item is None and st.session_state.option is not None):
|
| 119 |
-
|
| 120 |
-
st.session_state.data_item = process_ct_gt(st.session_state.option)
|
| 121 |
-
st.session_state.reset_demo_case = False
|
| 122 |
-
st.session_state.preds_3D = None
|
| 123 |
-
st.session_state.preds_3D_ori = None
|
| 124 |
-
|
| 125 |
-
prompt_col1, prompt_col2 = st.columns(2)
|
| 126 |
-
|
| 127 |
-
with prompt_col1:
|
| 128 |
-
st.session_state.use_text_prompt = st.toggle('Sematic prompt')
|
| 129 |
-
text_prompt_type = st.radio(
|
| 130 |
-
"Sematic prompt type",
|
| 131 |
-
["Predefined", "Custom"],
|
| 132 |
-
disabled=(not st.session_state.use_text_prompt)
|
| 133 |
-
)
|
| 134 |
-
if text_prompt_type == "Predefined":
|
| 135 |
-
pre_text = st.selectbox(
|
| 136 |
-
"Predefined anatomical category:",
|
| 137 |
-
['liver', 'right kidney', 'spleen', 'pancreas', 'aorta', 'inferior vena cava', 'right adrenal gland', 'left adrenal gland', 'gallbladder', 'esophagus', 'stomach', 'duodenum', 'left kidney'],
|
| 138 |
-
index=None,
|
| 139 |
-
disabled=(not st.session_state.use_text_prompt)
|
| 140 |
-
)
|
| 141 |
-
else:
|
| 142 |
-
pre_text = st.text_input('Enter an Anatomical word or phrase:', None, max_chars=20,
|
| 143 |
-
disabled=(not st.session_state.use_text_prompt))
|
| 144 |
-
if pre_text is None or len(pre_text) > 0:
|
| 145 |
-
st.session_state.text_prompt = pre_text
|
| 146 |
-
else:
|
| 147 |
-
st.session_state.text_prompt = None
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
with prompt_col2:
|
| 151 |
-
spatial_prompt_on = st.toggle('Spatial prompt', on_change=clear_prompts)
|
| 152 |
-
spatial_prompt = st.radio(
|
| 153 |
-
"Spatial prompt type",
|
| 154 |
-
["Point prompt", "Box prompt"],
|
| 155 |
-
on_change=clear_prompts,
|
| 156 |
-
disabled=(not spatial_prompt_on))
|
| 157 |
-
st.session_state.enforce_zoom = st.checkbox('Enforce zoom-out-zoom-in')
|
| 158 |
-
|
| 159 |
-
if spatial_prompt == "Point prompt":
|
| 160 |
-
st.session_state.use_point_prompt = True
|
| 161 |
-
st.session_state.use_box_prompt = False
|
| 162 |
-
elif spatial_prompt == "Box prompt":
|
| 163 |
-
st.session_state.use_box_prompt = True
|
| 164 |
-
st.session_state.use_point_prompt = False
|
| 165 |
-
else:
|
| 166 |
-
st.session_state.use_point_prompt = False
|
| 167 |
-
st.session_state.use_box_prompt = False
|
| 168 |
-
|
| 169 |
-
if not spatial_prompt_on:
|
| 170 |
-
st.session_state.use_point_prompt = False
|
| 171 |
-
st.session_state.use_box_prompt = False
|
| 172 |
-
|
| 173 |
-
if not st.session_state.use_text_prompt:
|
| 174 |
-
st.session_state.text_prompt = None
|
| 175 |
-
|
| 176 |
-
if st.session_state.option is None:
|
| 177 |
-
st.write('please select demo case first')
|
| 178 |
-
else:
|
| 179 |
-
image_3D = st.session_state.data_item['z_image'][0].numpy()
|
| 180 |
-
col_control1, col_control2 = st.columns(2)
|
| 181 |
-
|
| 182 |
-
with col_control1:
|
| 183 |
-
selected_index_z = st.slider('X-Y view', 0, image_3D.shape[0] - 1, 162, key='xy', disabled=st.session_state.running)
|
| 184 |
-
|
| 185 |
-
with col_control2:
|
| 186 |
-
selected_index_y = st.slider('X-Z view', 0, image_3D.shape[1] - 1, 162, key='xz', disabled=st.session_state.running)
|
| 187 |
-
if st.session_state.use_box_prompt:
|
| 188 |
-
top, bottom = st.select_slider(
|
| 189 |
-
'Top and bottom of box',
|
| 190 |
-
options=range(0, 325),
|
| 191 |
-
value=(0, 324),
|
| 192 |
-
disabled=st.session_state.running
|
| 193 |
-
)
|
| 194 |
-
st.session_state.rectangle_3Dbox[0] = top
|
| 195 |
-
st.session_state.rectangle_3Dbox[3] = bottom
|
| 196 |
-
col_image1, col_image2 = st.columns(2)
|
| 197 |
-
|
| 198 |
-
if st.session_state.preds_3D is not None:
|
| 199 |
-
st.session_state.transparency = st.slider('Mask opacity', 0.0, 1.0, 0.25, disabled=st.session_state.running)
|
| 200 |
-
|
| 201 |
-
with col_image1:
|
| 202 |
-
|
| 203 |
-
image_z_array = image_3D[selected_index_z]
|
| 204 |
-
|
| 205 |
-
preds_z_array = None
|
| 206 |
-
if st.session_state.preds_3D is not None:
|
| 207 |
-
preds_z_array = st.session_state.preds_3D[selected_index_z]
|
| 208 |
-
|
| 209 |
-
image_z = make_fig(image_z_array, preds_z_array, st.session_state.points, selected_index_z, 'xy')
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
if st.session_state.use_point_prompt:
|
| 213 |
-
value_xy = streamlit_image_coordinates(image_z, width=325)
|
| 214 |
-
|
| 215 |
-
if value_xy is not None:
|
| 216 |
-
point_ax_xy = (selected_index_z, value_xy['y'], value_xy['x'])
|
| 217 |
-
if len(st.session_state.points) >= 3:
|
| 218 |
-
st.warning('Max point num is 3', icon="⚠️")
|
| 219 |
-
elif point_ax_xy not in st.session_state.points:
|
| 220 |
-
st.session_state.points.append(point_ax_xy)
|
| 221 |
-
print('point_ax_xy add rerun')
|
| 222 |
-
st.rerun()
|
| 223 |
-
elif st.session_state.use_box_prompt:
|
| 224 |
-
canvas_result_xy = st_canvas(
|
| 225 |
-
fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
|
| 226 |
-
stroke_width=3,
|
| 227 |
-
stroke_color='#2909F1',
|
| 228 |
-
background_image=image_z,
|
| 229 |
-
update_streamlit=True,
|
| 230 |
-
height=325,
|
| 231 |
-
width=325,
|
| 232 |
-
drawing_mode='transform',
|
| 233 |
-
point_display_radius=0,
|
| 234 |
-
key="canvas_xy",
|
| 235 |
-
initial_drawing=initial_rectangle,
|
| 236 |
-
display_toolbar=True
|
| 237 |
-
)
|
| 238 |
-
try:
|
| 239 |
-
print(canvas_result_xy.json_data['objects'][0]['angle'])
|
| 240 |
-
if canvas_result_xy.json_data['objects'][0]['angle'] != 0:
|
| 241 |
-
st.warning('Rotating is undefined behavior', icon="⚠️")
|
| 242 |
-
st.session_state.irregular_box = True
|
| 243 |
-
else:
|
| 244 |
-
st.session_state.irregular_box = False
|
| 245 |
-
reflect_json_data_to_3D_box(canvas_result_xy.json_data, view='xy')
|
| 246 |
-
except:
|
| 247 |
-
print('exception')
|
| 248 |
-
pass
|
| 249 |
-
else:
|
| 250 |
-
st.image(image_z, use_column_width=False)
|
| 251 |
-
|
| 252 |
-
with col_image2:
|
| 253 |
-
image_y_array = image_3D[:, selected_index_y, :]
|
| 254 |
-
|
| 255 |
-
preds_y_array = None
|
| 256 |
-
if st.session_state.preds_3D is not None:
|
| 257 |
-
preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
|
| 258 |
-
|
| 259 |
-
image_y = make_fig(image_y_array, preds_y_array, st.session_state.points, selected_index_y, 'xz')
|
| 260 |
-
|
| 261 |
-
if st.session_state.use_point_prompt:
|
| 262 |
-
value_yz = streamlit_image_coordinates(image_y, width=325)
|
| 263 |
-
|
| 264 |
-
if value_yz is not None:
|
| 265 |
-
point_ax_xz = (value_yz['y'], selected_index_y, value_yz['x'])
|
| 266 |
-
if len(st.session_state.points) >= 3:
|
| 267 |
-
st.warning('Max point num is 3', icon="⚠️")
|
| 268 |
-
elif point_ax_xz not in st.session_state.points:
|
| 269 |
-
st.session_state.points.append(point_ax_xz)
|
| 270 |
-
print('point_ax_xz add rerun')
|
| 271 |
-
st.rerun()
|
| 272 |
-
elif st.session_state.use_box_prompt:
|
| 273 |
-
if st.session_state.rectangle_3Dbox[1] <= selected_index_y and selected_index_y <= st.session_state.rectangle_3Dbox[4]:
|
| 274 |
-
draw = ImageDraw.Draw(image_y)
|
| 275 |
-
#rectangle xz view (upper-left and lower-right)
|
| 276 |
-
rectangle_coords = [(st.session_state.rectangle_3Dbox[2], st.session_state.rectangle_3Dbox[0]),
|
| 277 |
-
(st.session_state.rectangle_3Dbox[5], st.session_state.rectangle_3Dbox[3])]
|
| 278 |
-
# Draw the rectangle on the image
|
| 279 |
-
draw.rectangle(rectangle_coords, outline='#2909F1', width=3)
|
| 280 |
-
st.image(image_y, use_column_width=False)
|
| 281 |
-
else:
|
| 282 |
-
st.image(image_y, use_column_width=False)
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
col1, col2, col3 = st.columns(3)
|
| 286 |
-
|
| 287 |
-
with col1:
|
| 288 |
-
if st.button("Clear", use_container_width=True,
|
| 289 |
-
disabled=(st.session_state.option is None or (len(st.session_state.points)==0 and not st.session_state.use_box_prompt and st.session_state.preds_3D is None))):
|
| 290 |
-
clear_prompts()
|
| 291 |
-
st.session_state.preds_3D = None
|
| 292 |
-
st.session_state.preds_3D_ori = None
|
| 293 |
-
st.rerun()
|
| 294 |
-
|
| 295 |
-
with col2:
|
| 296 |
-
img_nii = None
|
| 297 |
-
if st.session_state.preds_3D_ori is not None and st.session_state.data_item is not None:
|
| 298 |
-
meta_dict = st.session_state.data_item['meta']
|
| 299 |
-
foreground_start_coord = st.session_state.data_item['foreground_start_coord']
|
| 300 |
-
foreground_end_coord = st.session_state.data_item['foreground_end_coord']
|
| 301 |
-
original_shape = st.session_state.data_item['ori_shape']
|
| 302 |
-
pred_array = st.session_state.preds_3D_ori
|
| 303 |
-
original_array = np.zeros(original_shape)
|
| 304 |
-
original_array[foreground_start_coord[0]:foreground_end_coord[0],
|
| 305 |
-
foreground_start_coord[1]:foreground_end_coord[1],
|
| 306 |
-
foreground_start_coord[2]:foreground_end_coord[2]] = pred_array
|
| 307 |
-
|
| 308 |
-
original_array = original_array.transpose(2, 1, 0)
|
| 309 |
-
img_nii = nib.Nifti1Image(original_array, affine=meta_dict['affine'])
|
| 310 |
-
|
| 311 |
-
with tempfile.NamedTemporaryFile(suffix=".nii.gz") as tmpfile:
|
| 312 |
-
nib.save(img_nii, tmpfile.name)
|
| 313 |
-
with open(tmpfile.name, "rb") as f:
|
| 314 |
-
bytes_data = f.read()
|
| 315 |
-
st.download_button(
|
| 316 |
-
label="Download result(.nii.gz)",
|
| 317 |
-
data=bytes_data,
|
| 318 |
-
file_name="segvol_preds.nii.gz",
|
| 319 |
-
mime="application/octet-stream",
|
| 320 |
-
disabled=img_nii is None
|
| 321 |
-
)
|
| 322 |
-
|
| 323 |
-
with col3:
|
| 324 |
-
run_button_name = 'Run'if not st.session_state.running else 'Running'
|
| 325 |
-
if st.button(run_button_name, type="primary", use_container_width=True,
|
| 326 |
-
disabled=(
|
| 327 |
-
st.session_state.data_item is None or
|
| 328 |
-
(st.session_state.text_prompt is None and len(st.session_state.points) == 0 and st.session_state.use_box_prompt is False) or
|
| 329 |
-
st.session_state.irregular_box or
|
| 330 |
-
st.session_state.running
|
| 331 |
-
)):
|
| 332 |
-
st.session_state.running = True
|
| 333 |
-
st.rerun()
|
| 334 |
-
|
| 335 |
-
if st.session_state.running:
|
| 336 |
-
st.session_state.running = False
|
| 337 |
-
with st.status("Running...", expanded=False) as status:
|
| 338 |
-
run()
|
| 339 |
-
st.rerun()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/LICENSE
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
MIT License
|
| 2 |
-
|
| 3 |
-
Copyright (c) 2023 BAAI-DCAI
|
| 4 |
-
|
| 5 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
-
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
-
in the Software without restriction, including without limitation the rights
|
| 8 |
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
-
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
-
furnished to do so, subject to the following conditions:
|
| 11 |
-
|
| 12 |
-
The above copyright notice and this permission notice shall be included in all
|
| 13 |
-
copies or substantial portions of the Software.
|
| 14 |
-
|
| 15 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
-
SOFTWARE.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/README.md
DELETED
|
@@ -1,74 +0,0 @@
|
|
| 1 |
-
# SegVol: Universal and Interactive Volumetric Medical Image Segmentation
|
| 2 |
-
This repo is the official implementation of [SegVol: Universal and Interactive Volumetric Medical Image Segmentation](https://arxiv.org/abs/2311.13385).
|
| 3 |
-
|
| 4 |
-
## News🚀
|
| 5 |
-
(2023.11.24) *You can download weight files of SegVol and ViT(CTs pre-train) [here](https://drive.google.com/drive/folders/1TEJtgctH534Ko5r4i79usJvqmXVuLf54?usp=drive_link).* 🔥
|
| 6 |
-
|
| 7 |
-
(2023.11.23) *The brief introduction and instruction have been uploaded.*
|
| 8 |
-
|
| 9 |
-
(2023.11.23) *The inference demo code has been uploaded.*
|
| 10 |
-
|
| 11 |
-
(2023.11.22) *The first edition of our paper has been uploaded to arXiv.* 📃
|
| 12 |
-
|
| 13 |
-
## Introduction
|
| 14 |
-
<img src="https://github.com/BAAI-DCAI/SegVol/blob/main/asset/overview.png" width="60%" height="60%">
|
| 15 |
-
|
| 16 |
-
The SegVol is a universal and interactive model for volumetric medical image segmentation. SegVol accepts **point**, **box** and **text** prompt while output volumetric segmentation. By training on 90k unlabeled Computed Tomography (CT) volumes and 6k labeled CTs, this foundation model supports the segmentation of over 200 anatomical categories.
|
| 17 |
-
|
| 18 |
-
We will release SegVol's **inference code**, **training code**, **model params** and **ViT pre-training params** (pre-training is performed over 2,000 epochs on 96k CTs).
|
| 19 |
-
|
| 20 |
-
## Usage
|
| 21 |
-
### Requirements
|
| 22 |
-
The [pytorch v1.11.0](https://pytorch.org/get-started/previous-versions/) (or higher virsion) is needed first. Following install key requirements using commands:
|
| 23 |
-
|
| 24 |
-
```
|
| 25 |
-
pip install 'monai[all]==0.9.0'
|
| 26 |
-
pip install einops==0.6.1
|
| 27 |
-
pip install transformers==4.18.0
|
| 28 |
-
pip install matplotlib
|
| 29 |
-
```
|
| 30 |
-
### Config and run demo script
|
| 31 |
-
1. You can download the demo case [here](https://drive.google.com/drive/folders/1TEJtgctH534Ko5r4i79usJvqmXVuLf54?usp=drive_link), or download the whole demo dataset [AbdomenCT-1K](https://github.com/JunMa11/AbdomenCT-1K) and choose any demo case you want.
|
| 32 |
-
2. Please set CT path and Ground Truth path of the case in the [config_demo.json](https://github.com/BAAI-DCAI/SegVol/blob/main/config/config_demo.json).
|
| 33 |
-
3. After that, config the [inference_demo.sh](https://github.com/BAAI-DCAI/SegVol/blob/main/script/inference_demo.sh) file for execution:
|
| 34 |
-
|
| 35 |
-
- `$segvol_ckpt`: the path of SegVol's checkpoint (Download from [here](https://drive.google.com/drive/folders/1TEJtgctH534Ko5r4i79usJvqmXVuLf54?usp=drive_link)).
|
| 36 |
-
|
| 37 |
-
- `$work_dir`: any path of folder you want to save the log files and visualizaion results.
|
| 38 |
-
|
| 39 |
-
4. Finally, you can control the **prompt type**, **zoom-in-zoom-out mechanism** and **visualizaion switch** [here](https://github.com/BAAI-DCAI/SegVol/blob/35f3ff9c943a74f630e6948051a1fe21aaba91bc/inference_demo.py#L208C11-L208C11).
|
| 40 |
-
5. Now, just run `bash script/inference_demo.sh` to infer your demo case.
|
| 41 |
-
|
| 42 |
-
## Citation
|
| 43 |
-
If you find this repository helpful, please consider citing:
|
| 44 |
-
```
|
| 45 |
-
@misc{du2023segvol,
|
| 46 |
-
title={SegVol: Universal and Interactive Volumetric Medical Image Segmentation},
|
| 47 |
-
author={Yuxin Du and Fan Bai and Tiejun Huang and Bo Zhao},
|
| 48 |
-
year={2023},
|
| 49 |
-
eprint={2311.13385},
|
| 50 |
-
archivePrefix={arXiv},
|
| 51 |
-
primaryClass={cs.CV}
|
| 52 |
-
}
|
| 53 |
-
```
|
| 54 |
-
|
| 55 |
-
## Acknowledgement
|
| 56 |
-
Thanks for the following amazing works:
|
| 57 |
-
|
| 58 |
-
[HuggingFace](https://huggingface.co/).
|
| 59 |
-
|
| 60 |
-
[CLIP](https://github.com/openai/CLIP).
|
| 61 |
-
|
| 62 |
-
[MONAI](https://github.com/Project-MONAI/MONAI).
|
| 63 |
-
|
| 64 |
-
[Image by brgfx](https://www.freepik.com/free-vector/anatomical-structure-human-bodies_26353260.htm) on Freepik.
|
| 65 |
-
|
| 66 |
-
[Image by muammark](https://www.freepik.com/free-vector/people-icon-collection_1157380.htm#query=user&position=2&from_view=search&track=sph) on Freepik.
|
| 67 |
-
|
| 68 |
-
[Image by pch.vector](https://www.freepik.com/free-vector/different-phone-hand-gestures-set_9649376.htm#query=Vector%20touch%20screen%20hand%20gestures&position=4&from_view=search&track=ais) on Freepik.
|
| 69 |
-
|
| 70 |
-
[Image by starline](https://www.freepik.com/free-vector/set-three-light-bulb-represent-effective-business-idea-concept_37588597.htm#query=idea&position=0&from_view=search&track=sph) on Freepik.
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/__pycache__/inference_cpu.cpython-39.pyc
DELETED
|
Binary file (4.77 kB)
|
|
|
model/asset/FLARE22_Tr_0002_0000.nii.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:eb16eced003524fa005e28b2822c0b53503f1223d758cdf72528fad359aa10ba
|
| 3 |
-
size 30611274
|
|
|
|
|
|
|
|
|
|
|
|
model/asset/FLARE22_Tr_0005_0000.nii.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:2be5019bfc7e805d5e24785bcd44ffe7720e13e38b2a3124ad25b454811b221c
|
| 3 |
-
size 26615527
|
|
|
|
|
|
|
|
|
|
|
|
model/asset/FLARE22_Tr_0034_0000.nii.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:023c5d06ea2a6c8866c1e214ecee06a4447a8d0c50225142cdfdbbccc2bf8c66
|
| 3 |
-
size 28821917
|
|
|
|
|
|
|
|
|
|
|
|
model/asset/FLARE22_Tr_0045_0000.nii.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:336b3719af673fd6fafe89d7d5d95d5f18239a9faccde9753703fc1465f43736
|
| 3 |
-
size 32885093
|
|
|
|
|
|
|
|
|
|
|
|
model/asset/model.png
DELETED
|
Binary file (888 kB)
|
|
|
model/asset/overview back.png
DELETED
|
Binary file (237 kB)
|
|
|
model/asset/overview.png
DELETED
|
Binary file (226 kB)
|
|
|
model/config/clip/config.json
DELETED
|
@@ -1,157 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"_name_or_path": "openai/clip-vit-base-patch32",
|
| 3 |
-
"architectures": [
|
| 4 |
-
"CLIPModel"
|
| 5 |
-
],
|
| 6 |
-
"initializer_factor": 1.0,
|
| 7 |
-
"logit_scale_init_value": 2.6592,
|
| 8 |
-
"model_type": "clip",
|
| 9 |
-
"projection_dim": 512,
|
| 10 |
-
"text_config": {
|
| 11 |
-
"_name_or_path": "",
|
| 12 |
-
"add_cross_attention": false,
|
| 13 |
-
"architectures": null,
|
| 14 |
-
"attention_dropout": 0.0,
|
| 15 |
-
"bad_words_ids": null,
|
| 16 |
-
"bos_token_id": 0,
|
| 17 |
-
"chunk_size_feed_forward": 0,
|
| 18 |
-
"cross_attention_hidden_size": null,
|
| 19 |
-
"decoder_start_token_id": null,
|
| 20 |
-
"diversity_penalty": 0.0,
|
| 21 |
-
"do_sample": false,
|
| 22 |
-
"dropout": 0.0,
|
| 23 |
-
"early_stopping": false,
|
| 24 |
-
"encoder_no_repeat_ngram_size": 0,
|
| 25 |
-
"eos_token_id": 2,
|
| 26 |
-
"finetuning_task": null,
|
| 27 |
-
"forced_bos_token_id": null,
|
| 28 |
-
"forced_eos_token_id": null,
|
| 29 |
-
"hidden_act": "quick_gelu",
|
| 30 |
-
"hidden_size": 512,
|
| 31 |
-
"id2label": {
|
| 32 |
-
"0": "LABEL_0",
|
| 33 |
-
"1": "LABEL_1"
|
| 34 |
-
},
|
| 35 |
-
"initializer_factor": 1.0,
|
| 36 |
-
"initializer_range": 0.02,
|
| 37 |
-
"intermediate_size": 2048,
|
| 38 |
-
"is_decoder": false,
|
| 39 |
-
"is_encoder_decoder": false,
|
| 40 |
-
"label2id": {
|
| 41 |
-
"LABEL_0": 0,
|
| 42 |
-
"LABEL_1": 1
|
| 43 |
-
},
|
| 44 |
-
"layer_norm_eps": 1e-05,
|
| 45 |
-
"length_penalty": 1.0,
|
| 46 |
-
"max_length": 20,
|
| 47 |
-
"max_position_embeddings": 77,
|
| 48 |
-
"min_length": 0,
|
| 49 |
-
"model_type": "clip_text_model",
|
| 50 |
-
"no_repeat_ngram_size": 0,
|
| 51 |
-
"num_attention_heads": 8,
|
| 52 |
-
"num_beam_groups": 1,
|
| 53 |
-
"num_beams": 1,
|
| 54 |
-
"num_hidden_layers": 12,
|
| 55 |
-
"num_return_sequences": 1,
|
| 56 |
-
"output_attentions": false,
|
| 57 |
-
"output_hidden_states": false,
|
| 58 |
-
"output_scores": false,
|
| 59 |
-
"pad_token_id": 1,
|
| 60 |
-
"prefix": null,
|
| 61 |
-
"projection_dim": 512,
|
| 62 |
-
"problem_type": null,
|
| 63 |
-
"pruned_heads": {},
|
| 64 |
-
"remove_invalid_values": false,
|
| 65 |
-
"repetition_penalty": 1.0,
|
| 66 |
-
"return_dict": true,
|
| 67 |
-
"return_dict_in_generate": false,
|
| 68 |
-
"sep_token_id": null,
|
| 69 |
-
"task_specific_params": null,
|
| 70 |
-
"temperature": 1.0,
|
| 71 |
-
"tie_encoder_decoder": false,
|
| 72 |
-
"tie_word_embeddings": true,
|
| 73 |
-
"tokenizer_class": null,
|
| 74 |
-
"top_k": 50,
|
| 75 |
-
"top_p": 1.0,
|
| 76 |
-
"torch_dtype": null,
|
| 77 |
-
"torchscript": false,
|
| 78 |
-
"transformers_version": "4.16.0.dev0",
|
| 79 |
-
"use_bfloat16": false,
|
| 80 |
-
"vocab_size": 49408
|
| 81 |
-
},
|
| 82 |
-
"text_config_dict": null,
|
| 83 |
-
"transformers_version": null,
|
| 84 |
-
"vision_config": {
|
| 85 |
-
"_name_or_path": "",
|
| 86 |
-
"add_cross_attention": false,
|
| 87 |
-
"architectures": null,
|
| 88 |
-
"attention_dropout": 0.0,
|
| 89 |
-
"bad_words_ids": null,
|
| 90 |
-
"bos_token_id": null,
|
| 91 |
-
"chunk_size_feed_forward": 0,
|
| 92 |
-
"cross_attention_hidden_size": null,
|
| 93 |
-
"decoder_start_token_id": null,
|
| 94 |
-
"diversity_penalty": 0.0,
|
| 95 |
-
"do_sample": false,
|
| 96 |
-
"dropout": 0.0,
|
| 97 |
-
"early_stopping": false,
|
| 98 |
-
"encoder_no_repeat_ngram_size": 0,
|
| 99 |
-
"eos_token_id": null,
|
| 100 |
-
"finetuning_task": null,
|
| 101 |
-
"forced_bos_token_id": null,
|
| 102 |
-
"forced_eos_token_id": null,
|
| 103 |
-
"hidden_act": "quick_gelu",
|
| 104 |
-
"hidden_size": 768,
|
| 105 |
-
"id2label": {
|
| 106 |
-
"0": "LABEL_0",
|
| 107 |
-
"1": "LABEL_1"
|
| 108 |
-
},
|
| 109 |
-
"image_size": 224,
|
| 110 |
-
"initializer_factor": 1.0,
|
| 111 |
-
"initializer_range": 0.02,
|
| 112 |
-
"intermediate_size": 3072,
|
| 113 |
-
"is_decoder": false,
|
| 114 |
-
"is_encoder_decoder": false,
|
| 115 |
-
"label2id": {
|
| 116 |
-
"LABEL_0": 0,
|
| 117 |
-
"LABEL_1": 1
|
| 118 |
-
},
|
| 119 |
-
"layer_norm_eps": 1e-05,
|
| 120 |
-
"length_penalty": 1.0,
|
| 121 |
-
"max_length": 20,
|
| 122 |
-
"min_length": 0,
|
| 123 |
-
"model_type": "clip_vision_model",
|
| 124 |
-
"no_repeat_ngram_size": 0,
|
| 125 |
-
"num_attention_heads": 12,
|
| 126 |
-
"num_beam_groups": 1,
|
| 127 |
-
"num_beams": 1,
|
| 128 |
-
"num_hidden_layers": 12,
|
| 129 |
-
"num_return_sequences": 1,
|
| 130 |
-
"output_attentions": false,
|
| 131 |
-
"output_hidden_states": false,
|
| 132 |
-
"output_scores": false,
|
| 133 |
-
"pad_token_id": null,
|
| 134 |
-
"patch_size": 32,
|
| 135 |
-
"prefix": null,
|
| 136 |
-
"projection_dim" : 512,
|
| 137 |
-
"problem_type": null,
|
| 138 |
-
"pruned_heads": {},
|
| 139 |
-
"remove_invalid_values": false,
|
| 140 |
-
"repetition_penalty": 1.0,
|
| 141 |
-
"return_dict": true,
|
| 142 |
-
"return_dict_in_generate": false,
|
| 143 |
-
"sep_token_id": null,
|
| 144 |
-
"task_specific_params": null,
|
| 145 |
-
"temperature": 1.0,
|
| 146 |
-
"tie_encoder_decoder": false,
|
| 147 |
-
"tie_word_embeddings": true,
|
| 148 |
-
"tokenizer_class": null,
|
| 149 |
-
"top_k": 50,
|
| 150 |
-
"top_p": 1.0,
|
| 151 |
-
"torch_dtype": null,
|
| 152 |
-
"torchscript": false,
|
| 153 |
-
"transformers_version": "4.16.0.dev0",
|
| 154 |
-
"use_bfloat16": false
|
| 155 |
-
},
|
| 156 |
-
"vision_config_dict": null
|
| 157 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/config/clip/special_tokens_map.json
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
|
|
|
|
|
|
model/config/clip/tokenizer.json
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
model/config/clip/tokenizer_config.json
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": "<|endoftext|>", "add_prefix_space": false, "errors": "replace", "do_lower_case": true, "name_or_path": "./clip_ViT_B_32/"}
|
|
|
|
|
|
model/config/clip/vocab.json
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
model/config/config_demo.json
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"dataset_name": "AbdomenCT-1k",
|
| 3 |
-
"categories": ["liver", "kidney", "spleen", "pancreas"],
|
| 4 |
-
"demo_case": {
|
| 5 |
-
"ct_path": "path/to/Case_image",
|
| 6 |
-
"gt_path": "path/to/Case_label"
|
| 7 |
-
}
|
| 8 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/data_process/__pycache__/demo_data_process.cpython-39.pyc
DELETED
|
Binary file (3.4 kB)
|
|
|
model/data_process/demo_data_process.py
DELETED
|
@@ -1,95 +0,0 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
import monai.transforms as transforms
|
| 3 |
-
import streamlit as st
|
| 4 |
-
import tempfile
|
| 5 |
-
|
| 6 |
-
class MinMaxNormalization(transforms.Transform):
|
| 7 |
-
def __call__(self, data):
|
| 8 |
-
d = dict(data)
|
| 9 |
-
k = "image"
|
| 10 |
-
d[k] = d[k] - d[k].min()
|
| 11 |
-
d[k] = d[k] / np.clip(d[k].max(), a_min=1e-8, a_max=None)
|
| 12 |
-
return d
|
| 13 |
-
|
| 14 |
-
class DimTranspose(transforms.Transform):
|
| 15 |
-
def __init__(self, keys):
|
| 16 |
-
self.keys = keys
|
| 17 |
-
|
| 18 |
-
def __call__(self, data):
|
| 19 |
-
d = dict(data)
|
| 20 |
-
for key in self.keys:
|
| 21 |
-
d[key] = np.swapaxes(d[key], -1, -3)
|
| 22 |
-
return d
|
| 23 |
-
|
| 24 |
-
class ForegroundNormalization(transforms.Transform):
|
| 25 |
-
def __init__(self, keys):
|
| 26 |
-
self.keys = keys
|
| 27 |
-
|
| 28 |
-
def __call__(self, data):
|
| 29 |
-
d = dict(data)
|
| 30 |
-
|
| 31 |
-
for key in self.keys:
|
| 32 |
-
d[key] = self.normalize(d[key])
|
| 33 |
-
return d
|
| 34 |
-
|
| 35 |
-
def normalize(self, ct_narray):
|
| 36 |
-
ct_voxel_ndarray = ct_narray.copy()
|
| 37 |
-
ct_voxel_ndarray = ct_voxel_ndarray.flatten()
|
| 38 |
-
thred = np.mean(ct_voxel_ndarray)
|
| 39 |
-
voxel_filtered = ct_voxel_ndarray[(ct_voxel_ndarray > thred)]
|
| 40 |
-
upper_bound = np.percentile(voxel_filtered, 99.95)
|
| 41 |
-
lower_bound = np.percentile(voxel_filtered, 00.05)
|
| 42 |
-
mean = np.mean(voxel_filtered)
|
| 43 |
-
std = np.std(voxel_filtered)
|
| 44 |
-
### transform ###
|
| 45 |
-
ct_narray = np.clip(ct_narray, lower_bound, upper_bound)
|
| 46 |
-
ct_narray = (ct_narray - mean) / max(std, 1e-8)
|
| 47 |
-
return ct_narray
|
| 48 |
-
|
| 49 |
-
@st.cache_data
|
| 50 |
-
def process_ct_gt(case_path, spatial_size=(32,256,256)):
|
| 51 |
-
if case_path is None:
|
| 52 |
-
return None
|
| 53 |
-
print('Data preprocessing...')
|
| 54 |
-
# transform
|
| 55 |
-
img_loader = transforms.LoadImage(dtype=np.float32)
|
| 56 |
-
transform = transforms.Compose(
|
| 57 |
-
[
|
| 58 |
-
transforms.Orientationd(keys=["image"], axcodes="RAS"),
|
| 59 |
-
ForegroundNormalization(keys=["image"]),
|
| 60 |
-
DimTranspose(keys=["image"]),
|
| 61 |
-
MinMaxNormalization(),
|
| 62 |
-
transforms.SpatialPadd(keys=["image"], spatial_size=spatial_size, mode='constant'),
|
| 63 |
-
transforms.CropForegroundd(keys=["image"], source_key="image"),
|
| 64 |
-
transforms.ToTensord(keys=["image"]),
|
| 65 |
-
]
|
| 66 |
-
)
|
| 67 |
-
zoom_out_transform = transforms.Resized(keys=["image"], spatial_size=spatial_size, mode='nearest-exact')
|
| 68 |
-
z_transform = transforms.Resized(keys=["image"], spatial_size=(325,325,325), mode='nearest-exact')
|
| 69 |
-
###
|
| 70 |
-
item = {}
|
| 71 |
-
# generate ct_voxel_ndarray
|
| 72 |
-
if type(case_path) is str:
|
| 73 |
-
ct_voxel_ndarray, meta_tensor_dict = img_loader(case_path)
|
| 74 |
-
else:
|
| 75 |
-
bytes_data = case_path.read()
|
| 76 |
-
with tempfile.NamedTemporaryFile(suffix='.nii.gz') as tmp:
|
| 77 |
-
tmp.write(bytes_data)
|
| 78 |
-
tmp.seek(0)
|
| 79 |
-
ct_voxel_ndarray, meta_tensor_dict = img_loader(tmp.name)
|
| 80 |
-
|
| 81 |
-
ct_voxel_ndarray = np.array(ct_voxel_ndarray).squeeze()
|
| 82 |
-
ct_voxel_ndarray = np.expand_dims(ct_voxel_ndarray, axis=0)
|
| 83 |
-
item['image'] = ct_voxel_ndarray
|
| 84 |
-
ori_shape = np.swapaxes(ct_voxel_ndarray, -1, -3).shape[1:]
|
| 85 |
-
|
| 86 |
-
# transform
|
| 87 |
-
item = transform(item)
|
| 88 |
-
item_zoom_out = zoom_out_transform(item)
|
| 89 |
-
item['zoom_out_image'] = item_zoom_out['image']
|
| 90 |
-
item['ori_shape'] = ori_shape
|
| 91 |
-
|
| 92 |
-
item_z = z_transform(item)
|
| 93 |
-
item['z_image'] = item_z['image']
|
| 94 |
-
item['meta'] = meta_tensor_dict
|
| 95 |
-
return item
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/inference_cpu.py
DELETED
|
@@ -1,172 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import os
|
| 3 |
-
import torch
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
import json
|
| 6 |
-
import monai.transforms as transforms
|
| 7 |
-
|
| 8 |
-
from model.segment_anything_volumetric import sam_model_registry
|
| 9 |
-
from model.network.model import SegVol
|
| 10 |
-
from model.data_process.demo_data_process import process_ct_gt
|
| 11 |
-
from model.utils.monai_inferers_utils import sliding_window_inference, generate_box, select_points, build_binary_cube, build_binary_points, logits2roi_coor
|
| 12 |
-
from model.utils.visualize import draw_result
|
| 13 |
-
import streamlit as st
|
| 14 |
-
|
| 15 |
-
def set_parse():
|
| 16 |
-
# %% set up parser
|
| 17 |
-
parser = argparse.ArgumentParser()
|
| 18 |
-
parser.add_argument("--test_mode", default=True, type=bool)
|
| 19 |
-
parser.add_argument("--resume", type = str, default = 'SegVol_v1.pth')
|
| 20 |
-
parser.add_argument("-infer_overlap", default=0.0, type=float, help="sliding window inference overlap")
|
| 21 |
-
parser.add_argument("-spatial_size", default=(32, 256, 256), type=tuple)
|
| 22 |
-
parser.add_argument("-patch_size", default=(4, 16, 16), type=tuple)
|
| 23 |
-
parser.add_argument('-work_dir', type=str, default='./work_dir')
|
| 24 |
-
### demo
|
| 25 |
-
parser.add_argument("--clip_ckpt", type = str, default = 'model/config/clip')
|
| 26 |
-
args = parser.parse_args()
|
| 27 |
-
return args
|
| 28 |
-
|
| 29 |
-
def zoom_in_zoom_out(args, segvol_model, image, image_resize, text_prompt, point_prompt, box_prompt):
|
| 30 |
-
image_single_resize = image_resize
|
| 31 |
-
image_single = image[0,0]
|
| 32 |
-
ori_shape = image_single.shape
|
| 33 |
-
resize_shape = image_single_resize.shape[2:]
|
| 34 |
-
|
| 35 |
-
# generate prompts
|
| 36 |
-
text_single = None if text_prompt is None else [text_prompt]
|
| 37 |
-
points_single = None
|
| 38 |
-
box_single = None
|
| 39 |
-
|
| 40 |
-
if args.use_point_prompt:
|
| 41 |
-
point, point_label = point_prompt
|
| 42 |
-
points_single = (point.unsqueeze(0).float(), point_label.unsqueeze(0).float())
|
| 43 |
-
binary_points_resize = build_binary_points(point, point_label, resize_shape)
|
| 44 |
-
if args.use_box_prompt:
|
| 45 |
-
box_single = box_prompt.unsqueeze(0).float()
|
| 46 |
-
binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=resize_shape)
|
| 47 |
-
|
| 48 |
-
####################
|
| 49 |
-
# zoom-out inference:
|
| 50 |
-
print('--- zoom out inference ---')
|
| 51 |
-
print(text_single)
|
| 52 |
-
print(f'use text-prompt [{text_single!=None}], use box-prompt [{box_single!=None}], use point-prompt [{points_single!=None}]')
|
| 53 |
-
with torch.no_grad():
|
| 54 |
-
logits_global_single = segvol_model(image_single_resize,
|
| 55 |
-
text=text_single,
|
| 56 |
-
boxes=box_single,
|
| 57 |
-
points=points_single)
|
| 58 |
-
|
| 59 |
-
# resize back global logits
|
| 60 |
-
logits_global_single = F.interpolate(
|
| 61 |
-
logits_global_single.cpu(),
|
| 62 |
-
size=ori_shape, mode='nearest')[0][0]
|
| 63 |
-
|
| 64 |
-
# build prompt reflection for zoom-in
|
| 65 |
-
if args.use_point_prompt:
|
| 66 |
-
binary_points = F.interpolate(
|
| 67 |
-
binary_points_resize.unsqueeze(0).unsqueeze(0).float(),
|
| 68 |
-
size=ori_shape, mode='nearest')[0][0]
|
| 69 |
-
if args.use_box_prompt:
|
| 70 |
-
binary_cube = F.interpolate(
|
| 71 |
-
binary_cube_resize.unsqueeze(0).unsqueeze(0).float(),
|
| 72 |
-
size=ori_shape, mode='nearest')[0][0]
|
| 73 |
-
# draw_result('unknow', image_single_resize, None, point_prompt, logits_global_single, logits_global_single)
|
| 74 |
-
if not args.use_zoom_in:
|
| 75 |
-
return logits_global_single
|
| 76 |
-
|
| 77 |
-
####################
|
| 78 |
-
# zoom-in inference:
|
| 79 |
-
min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(args.spatial_size, logits_global_single)
|
| 80 |
-
if min_d is None:
|
| 81 |
-
print('Fail to detect foreground!')
|
| 82 |
-
return logits_global_single
|
| 83 |
-
|
| 84 |
-
# Crop roi
|
| 85 |
-
image_single_cropped = image_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
|
| 86 |
-
global_preds = (torch.sigmoid(logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1])>0.5).long()
|
| 87 |
-
|
| 88 |
-
assert not (args.use_box_prompt and args.use_point_prompt)
|
| 89 |
-
# label_single_cropped = label_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
|
| 90 |
-
prompt_reflection = None
|
| 91 |
-
if args.use_box_prompt:
|
| 92 |
-
binary_cube_cropped = binary_cube[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
|
| 93 |
-
prompt_reflection = (
|
| 94 |
-
binary_cube_cropped.unsqueeze(0).unsqueeze(0),
|
| 95 |
-
global_preds.unsqueeze(0).unsqueeze(0)
|
| 96 |
-
)
|
| 97 |
-
if args.use_point_prompt:
|
| 98 |
-
binary_points_cropped = binary_points[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
|
| 99 |
-
prompt_reflection = (
|
| 100 |
-
binary_points_cropped.unsqueeze(0).unsqueeze(0),
|
| 101 |
-
global_preds.unsqueeze(0).unsqueeze(0)
|
| 102 |
-
)
|
| 103 |
-
|
| 104 |
-
## inference
|
| 105 |
-
with torch.no_grad():
|
| 106 |
-
logits_single_cropped = sliding_window_inference(
|
| 107 |
-
image_single_cropped, prompt_reflection,
|
| 108 |
-
args.spatial_size, 1, segvol_model, args.infer_overlap,
|
| 109 |
-
text=text_single,
|
| 110 |
-
use_box=args.use_box_prompt,
|
| 111 |
-
use_point=args.use_point_prompt,
|
| 112 |
-
logits_global_single=logits_global_single,
|
| 113 |
-
)
|
| 114 |
-
logits_single_cropped = logits_single_cropped.cpu().squeeze()
|
| 115 |
-
if logits_single_cropped.shape != logits_global_single.shape:
|
| 116 |
-
logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped
|
| 117 |
-
|
| 118 |
-
return logits_global_single
|
| 119 |
-
|
| 120 |
-
@st.cache_resource
|
| 121 |
-
def build_model():
|
| 122 |
-
# build model
|
| 123 |
-
st.write('building model')
|
| 124 |
-
clip_ckpt = 'model/config/clip'
|
| 125 |
-
resume = 'SegVol_v1.pth'
|
| 126 |
-
sam_model = sam_model_registry['vit']()
|
| 127 |
-
segvol_model = SegVol(
|
| 128 |
-
image_encoder=sam_model.image_encoder,
|
| 129 |
-
mask_decoder=sam_model.mask_decoder,
|
| 130 |
-
prompt_encoder=sam_model.prompt_encoder,
|
| 131 |
-
clip_ckpt=clip_ckpt,
|
| 132 |
-
roi_size=(32,256,256),
|
| 133 |
-
patch_size=(4,16,16),
|
| 134 |
-
test_mode=True,
|
| 135 |
-
)
|
| 136 |
-
segvol_model = torch.nn.DataParallel(segvol_model)
|
| 137 |
-
segvol_model.eval()
|
| 138 |
-
# load param
|
| 139 |
-
if os.path.isfile(resume):
|
| 140 |
-
## Map model to be loaded to specified single GPU
|
| 141 |
-
loc = 'cpu'
|
| 142 |
-
checkpoint = torch.load(resume, map_location=loc)
|
| 143 |
-
segvol_model.load_state_dict(checkpoint['model'], strict=True)
|
| 144 |
-
print("loaded checkpoint '{}' (epoch {})".format(resume, checkpoint['epoch']))
|
| 145 |
-
print('model build done!')
|
| 146 |
-
return segvol_model
|
| 147 |
-
|
| 148 |
-
@st.cache_data
|
| 149 |
-
def inference_case(_image, _image_zoom_out, _point_prompt, text_prompt, _box_prompt):
|
| 150 |
-
# seg config
|
| 151 |
-
args = set_parse()
|
| 152 |
-
args.use_zoom_in = True
|
| 153 |
-
args.use_text_prompt = text_prompt is not None
|
| 154 |
-
args.use_box_prompt = _box_prompt is not None
|
| 155 |
-
args.use_point_prompt = _point_prompt is not None
|
| 156 |
-
|
| 157 |
-
segvol_model = build_model()
|
| 158 |
-
|
| 159 |
-
# run inference
|
| 160 |
-
logits = zoom_in_zoom_out(
|
| 161 |
-
args, segvol_model,
|
| 162 |
-
_image.unsqueeze(0), _image_zoom_out.unsqueeze(0),
|
| 163 |
-
text_prompt, _point_prompt, _box_prompt)
|
| 164 |
-
print(logits.shape)
|
| 165 |
-
resize_transform = transforms.Compose([
|
| 166 |
-
transforms.AddChannel(),
|
| 167 |
-
transforms.Resize((325,325,325), mode='trilinear')
|
| 168 |
-
]
|
| 169 |
-
)
|
| 170 |
-
logits_resize = resize_transform(logits)[0]
|
| 171 |
-
return (torch.sigmoid(logits_resize) > 0.5).int().numpy(), (torch.sigmoid(logits) > 0.5).int().numpy()
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/inference_demo.py
DELETED
|
@@ -1,219 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import os
|
| 3 |
-
import torch
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
import json
|
| 6 |
-
from segment_anything_volumetric import sam_model_registry
|
| 7 |
-
from network.model import SegVol
|
| 8 |
-
from data_process.demo_data_process import process_ct_gt
|
| 9 |
-
import monai.transforms as transforms
|
| 10 |
-
from utils.monai_inferers_utils import sliding_window_inference, generate_box, select_points, build_binary_cube, build_binary_points, logits2roi_coor
|
| 11 |
-
from utils.visualize import draw_result
|
| 12 |
-
|
| 13 |
-
def set_parse():
|
| 14 |
-
# %% set up parser
|
| 15 |
-
parser = argparse.ArgumentParser()
|
| 16 |
-
parser.add_argument("--test_mode", default=True, type=bool)
|
| 17 |
-
parser.add_argument("--resume", type = str, default = '')
|
| 18 |
-
parser.add_argument("-infer_overlap", default=0.5, type=float, help="sliding window inference overlap")
|
| 19 |
-
parser.add_argument("-spatial_size", default=(32, 256, 256), type=tuple)
|
| 20 |
-
parser.add_argument("-patch_size", default=(4, 16, 16), type=tuple)
|
| 21 |
-
parser.add_argument('-work_dir', type=str, default='./work_dir')
|
| 22 |
-
### demo
|
| 23 |
-
parser.add_argument('--demo_config', type=str, required=True)
|
| 24 |
-
parser.add_argument("--clip_ckpt", type = str, default = './config/clip')
|
| 25 |
-
args = parser.parse_args()
|
| 26 |
-
return args
|
| 27 |
-
|
| 28 |
-
def dice_score(preds, labels): # on GPU
|
| 29 |
-
assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match\n" + str(preds.shape) + str(labels.shape)
|
| 30 |
-
predict = preds.view(1, -1)
|
| 31 |
-
target = labels.view(1, -1)
|
| 32 |
-
if target.shape[1] < 1e8:
|
| 33 |
-
predict = predict.cuda()
|
| 34 |
-
target = target.cuda()
|
| 35 |
-
predict = torch.sigmoid(predict)
|
| 36 |
-
predict = torch.where(predict > 0.5, 1., 0.)
|
| 37 |
-
|
| 38 |
-
tp = torch.sum(torch.mul(predict, target))
|
| 39 |
-
den = torch.sum(predict) + torch.sum(target) + 1
|
| 40 |
-
dice = 2 * tp / den
|
| 41 |
-
|
| 42 |
-
if target.shape[1] < 1e8:
|
| 43 |
-
predict = predict.cpu()
|
| 44 |
-
target = target.cpu()
|
| 45 |
-
return dice
|
| 46 |
-
|
| 47 |
-
def zoom_in_zoom_out(args, segvol_model, image, image_resize, gt3D, gt3D_resize, categories=None):
|
| 48 |
-
logits_labels_record = {}
|
| 49 |
-
image_single_resize = image_resize
|
| 50 |
-
image_single = image[0,0]
|
| 51 |
-
ori_shape = image_single.shape
|
| 52 |
-
for item_idx in range(len(categories)):
|
| 53 |
-
# get label to generate prompts
|
| 54 |
-
label_single = gt3D[0][item_idx]
|
| 55 |
-
label_single_resize = gt3D_resize[0][item_idx]
|
| 56 |
-
# skip meaningless categories
|
| 57 |
-
if torch.sum(label_single) == 0:
|
| 58 |
-
print('No object, skip')
|
| 59 |
-
continue
|
| 60 |
-
# generate prompts
|
| 61 |
-
text_single = categories[item_idx] if args.use_text_prompt else None
|
| 62 |
-
if categories is not None: print(f'inference |{categories[item_idx]}| target...')
|
| 63 |
-
points_single = None
|
| 64 |
-
box_single = None
|
| 65 |
-
if args.use_point_prompt:
|
| 66 |
-
point, point_label = select_points(label_single_resize, num_positive_extra=3, num_negative_extra=3)
|
| 67 |
-
points_single = (point.unsqueeze(0).float().cuda(), point_label.unsqueeze(0).float().cuda())
|
| 68 |
-
binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape)
|
| 69 |
-
if args.use_box_prompt:
|
| 70 |
-
box_single = generate_box(label_single_resize).unsqueeze(0).float().cuda()
|
| 71 |
-
binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape)
|
| 72 |
-
|
| 73 |
-
####################
|
| 74 |
-
# zoom-out inference:
|
| 75 |
-
print('--- zoom out inference ---')
|
| 76 |
-
print(f'use text-prompt [{text_single!=None}], use box-prompt [{box_single!=None}], use point-prompt [{points_single!=None}]')
|
| 77 |
-
with torch.no_grad():
|
| 78 |
-
logits_global_single = segvol_model(image_single_resize.cuda(),
|
| 79 |
-
text=text_single,
|
| 80 |
-
boxes=box_single,
|
| 81 |
-
points=points_single)
|
| 82 |
-
|
| 83 |
-
# resize back global logits
|
| 84 |
-
logits_global_single = F.interpolate(
|
| 85 |
-
logits_global_single.cpu(),
|
| 86 |
-
size=ori_shape, mode='nearest')[0][0]
|
| 87 |
-
|
| 88 |
-
# build prompt reflection for zoom-in
|
| 89 |
-
if args.use_point_prompt:
|
| 90 |
-
binary_points = F.interpolate(
|
| 91 |
-
binary_points_resize.unsqueeze(0).unsqueeze(0).float(),
|
| 92 |
-
size=ori_shape, mode='nearest')[0][0]
|
| 93 |
-
if args.use_box_prompt:
|
| 94 |
-
binary_cube = F.interpolate(
|
| 95 |
-
binary_cube_resize.unsqueeze(0).unsqueeze(0).float(),
|
| 96 |
-
size=ori_shape, mode='nearest')[0][0]
|
| 97 |
-
zoom_out_dice = dice_score(logits_global_single.squeeze(), label_single.squeeze())
|
| 98 |
-
logits_labels_record[categories[item_idx]] = (
|
| 99 |
-
zoom_out_dice,
|
| 100 |
-
image_single,
|
| 101 |
-
points_single,
|
| 102 |
-
box_single,
|
| 103 |
-
logits_global_single,
|
| 104 |
-
label_single)
|
| 105 |
-
print(f'zoom out inference done with zoom_out_dice: {zoom_out_dice:.4f}')
|
| 106 |
-
if not args.use_zoom_in:
|
| 107 |
-
continue
|
| 108 |
-
|
| 109 |
-
####################
|
| 110 |
-
# zoom-in inference:
|
| 111 |
-
min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(args.spatial_size, logits_global_single)
|
| 112 |
-
if min_d is None:
|
| 113 |
-
print('Fail to detect foreground!')
|
| 114 |
-
continue
|
| 115 |
-
|
| 116 |
-
# Crop roi
|
| 117 |
-
image_single_cropped = image_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
|
| 118 |
-
global_preds = (torch.sigmoid(logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1])>0.5).long()
|
| 119 |
-
|
| 120 |
-
assert not (args.use_box_prompt and args.use_point_prompt)
|
| 121 |
-
# label_single_cropped = label_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
|
| 122 |
-
prompt_reflection = None
|
| 123 |
-
if args.use_box_prompt:
|
| 124 |
-
binary_cube_cropped = binary_cube[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
|
| 125 |
-
prompt_reflection = (
|
| 126 |
-
binary_cube_cropped.unsqueeze(0).unsqueeze(0),
|
| 127 |
-
global_preds.unsqueeze(0).unsqueeze(0)
|
| 128 |
-
)
|
| 129 |
-
if args.use_point_prompt:
|
| 130 |
-
binary_points_cropped = binary_points[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
|
| 131 |
-
prompt_reflection = (
|
| 132 |
-
binary_points_cropped.unsqueeze(0).unsqueeze(0),
|
| 133 |
-
global_preds.unsqueeze(0).unsqueeze(0)
|
| 134 |
-
)
|
| 135 |
-
|
| 136 |
-
## inference
|
| 137 |
-
with torch.no_grad():
|
| 138 |
-
logits_single_cropped = sliding_window_inference(
|
| 139 |
-
image_single_cropped.cuda(), prompt_reflection,
|
| 140 |
-
args.spatial_size, 1, segvol_model, args.infer_overlap,
|
| 141 |
-
text=text_single,
|
| 142 |
-
use_box=args.use_box_prompt,
|
| 143 |
-
use_point=args.use_point_prompt,
|
| 144 |
-
)
|
| 145 |
-
logits_single_cropped = logits_single_cropped.cpu().squeeze()
|
| 146 |
-
logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped
|
| 147 |
-
zoom_in_dice = dice_score(logits_global_single.squeeze(), label_single.squeeze())
|
| 148 |
-
logits_labels_record[categories[item_idx]] = (
|
| 149 |
-
zoom_in_dice,
|
| 150 |
-
image_single,
|
| 151 |
-
points_single,
|
| 152 |
-
box_single,
|
| 153 |
-
logits_global_single,
|
| 154 |
-
label_single)
|
| 155 |
-
print(f'===> zoom out dice {zoom_out_dice:.4f} -> zoom-out-zoom-in dice {zoom_in_dice:.4f} <===')
|
| 156 |
-
return logits_labels_record
|
| 157 |
-
|
| 158 |
-
def inference_single_ct(args, segvol_model, data_item, categories):
|
| 159 |
-
segvol_model.eval()
|
| 160 |
-
image, gt3D = data_item["image"].float(), data_item["label"]
|
| 161 |
-
image_zoom_out, gt3D__zoom_out = data_item["zoom_out_image"].float(), data_item['zoom_out_label']
|
| 162 |
-
|
| 163 |
-
logits_labels_record = zoom_in_zoom_out(
|
| 164 |
-
args, segvol_model,
|
| 165 |
-
image.unsqueeze(0), image_zoom_out.unsqueeze(0),
|
| 166 |
-
gt3D.unsqueeze(0), gt3D__zoom_out.unsqueeze(0), # add batch dim
|
| 167 |
-
categories=categories)
|
| 168 |
-
|
| 169 |
-
# visualize
|
| 170 |
-
if args.visualize:
|
| 171 |
-
for target, values in logits_labels_record.items():
|
| 172 |
-
dice_score, image, point_prompt, box_prompt, logits, labels = values
|
| 173 |
-
print(f'{target} result with Dice score {dice_score:.4f} visualizing')
|
| 174 |
-
draw_result(target + f"-Dice {dice_score:.4f}", image, box_prompt, point_prompt, logits, labels, args.spatial_size, args.work_dir)
|
| 175 |
-
|
| 176 |
-
def main(args):
|
| 177 |
-
gpu = 0
|
| 178 |
-
torch.cuda.set_device(gpu)
|
| 179 |
-
# build model
|
| 180 |
-
sam_model = sam_model_registry['vit'](args=args)
|
| 181 |
-
segvol_model = SegVol(
|
| 182 |
-
image_encoder=sam_model.image_encoder,
|
| 183 |
-
mask_decoder=sam_model.mask_decoder,
|
| 184 |
-
prompt_encoder=sam_model.prompt_encoder,
|
| 185 |
-
clip_ckpt=args.clip_ckpt,
|
| 186 |
-
roi_size=args.spatial_size,
|
| 187 |
-
patch_size=args.patch_size,
|
| 188 |
-
test_mode=args.test_mode,
|
| 189 |
-
).cuda()
|
| 190 |
-
segvol_model = torch.nn.DataParallel(segvol_model, device_ids=[gpu])
|
| 191 |
-
|
| 192 |
-
# load param
|
| 193 |
-
if os.path.isfile(args.resume):
|
| 194 |
-
## Map model to be loaded to specified single GPU
|
| 195 |
-
loc = 'cuda:{}'.format(gpu)
|
| 196 |
-
checkpoint = torch.load(args.resume, map_location=loc)
|
| 197 |
-
segvol_model.load_state_dict(checkpoint['model'], strict=True)
|
| 198 |
-
print("loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
|
| 199 |
-
|
| 200 |
-
# load demo config
|
| 201 |
-
with open(args.demo_config, 'r') as file:
|
| 202 |
-
config_dict = json.load(file)
|
| 203 |
-
ct_path, gt_path, categories = config_dict['demo_case']['ct_path'], config_dict['demo_case']['gt_path'], config_dict['categories']
|
| 204 |
-
|
| 205 |
-
# preprocess for data
|
| 206 |
-
data_item = process_ct_gt(ct_path, gt_path, categories, args.spatial_size) # keys: image, label
|
| 207 |
-
|
| 208 |
-
# seg config for prompt & zoom-in-zoom-out
|
| 209 |
-
args.use_zoom_in = True
|
| 210 |
-
args.use_text_prompt = True
|
| 211 |
-
args.use_box_prompt = True
|
| 212 |
-
args.use_point_prompt = False
|
| 213 |
-
args.visualize = False
|
| 214 |
-
|
| 215 |
-
inference_single_ct(args, segvol_model, data_item, categories)
|
| 216 |
-
|
| 217 |
-
if __name__ == "__main__":
|
| 218 |
-
args = set_parse()
|
| 219 |
-
main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/network/__pycache__/model.cpython-39.pyc
DELETED
|
Binary file (3.29 kB)
|
|
|
model/network/model.py
DELETED
|
@@ -1,91 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
import numpy as np
|
| 5 |
-
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextConfig
|
| 6 |
-
|
| 7 |
-
#%% set up model
|
| 8 |
-
class SegVol(nn.Module):
|
| 9 |
-
def __init__(self,
|
| 10 |
-
image_encoder,
|
| 11 |
-
mask_decoder,
|
| 12 |
-
prompt_encoder,
|
| 13 |
-
clip_ckpt,
|
| 14 |
-
roi_size,
|
| 15 |
-
patch_size,
|
| 16 |
-
test_mode=False,
|
| 17 |
-
):
|
| 18 |
-
super().__init__()
|
| 19 |
-
self.image_encoder = image_encoder
|
| 20 |
-
self.mask_decoder = mask_decoder
|
| 21 |
-
self.prompt_encoder = prompt_encoder
|
| 22 |
-
self.text_encoder = TextEncoder(clip_ckpt)
|
| 23 |
-
self.feat_shape = np.array(roi_size)/np.array(patch_size)
|
| 24 |
-
self.test_mode = test_mode
|
| 25 |
-
|
| 26 |
-
def forward(self, image, text=None, boxes=None, points=None, **kwargs):
|
| 27 |
-
bs = image.shape[0]
|
| 28 |
-
img_shape = (image.shape[2], image.shape[3], image.shape[4])
|
| 29 |
-
image_embedding, _ = self.image_encoder(image)
|
| 30 |
-
image_embedding = image_embedding.transpose(1, 2).view(bs, -1,
|
| 31 |
-
int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))
|
| 32 |
-
# test mode
|
| 33 |
-
if self.test_mode:
|
| 34 |
-
return self.forward_decoder(image_embedding, img_shape, text, boxes, points)
|
| 35 |
-
# train mode
|
| 36 |
-
# future release
|
| 37 |
-
|
| 38 |
-
def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
|
| 39 |
-
with torch.no_grad():
|
| 40 |
-
if boxes is not None:
|
| 41 |
-
if len(boxes.shape) == 2:
|
| 42 |
-
boxes = boxes[:, None, :] # (B, 1, 6)
|
| 43 |
-
if text is not None:
|
| 44 |
-
text_embedding = self.text_encoder(text) # (B, 768)
|
| 45 |
-
else:
|
| 46 |
-
text_embedding = None
|
| 47 |
-
sparse_embeddings, dense_embeddings = self.prompt_encoder(
|
| 48 |
-
points=points,
|
| 49 |
-
boxes=boxes,
|
| 50 |
-
masks=None,
|
| 51 |
-
text_embedding=text_embedding,
|
| 52 |
-
)
|
| 53 |
-
|
| 54 |
-
dense_pe = self.prompt_encoder.get_dense_pe()
|
| 55 |
-
low_res_masks, _ = self.mask_decoder(
|
| 56 |
-
image_embeddings=image_embedding,
|
| 57 |
-
text_embedding = text_embedding,
|
| 58 |
-
image_pe=dense_pe,
|
| 59 |
-
sparse_prompt_embeddings=sparse_embeddings,
|
| 60 |
-
dense_prompt_embeddings=dense_embeddings,
|
| 61 |
-
multimask_output=False,
|
| 62 |
-
)
|
| 63 |
-
logits = F.interpolate(low_res_masks, size=img_shape, mode='trilinear', align_corners=False)
|
| 64 |
-
return logits
|
| 65 |
-
|
| 66 |
-
class TextEncoder(nn.Module):
|
| 67 |
-
def __init__(self, clip_ckpt):
|
| 68 |
-
super().__init__()
|
| 69 |
-
config = CLIPTextConfig()
|
| 70 |
-
self.clip_text_model = CLIPTextModel(config)
|
| 71 |
-
self.tokenizer = AutoTokenizer.from_pretrained(clip_ckpt)
|
| 72 |
-
self.dim_align = nn.Linear(512, 768)
|
| 73 |
-
# freeze text encoder
|
| 74 |
-
for param in self.clip_text_model.parameters():
|
| 75 |
-
param.requires_grad = False
|
| 76 |
-
|
| 77 |
-
def organ2tokens(self, organ_names):
|
| 78 |
-
text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
|
| 79 |
-
tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
|
| 80 |
-
return tokens
|
| 81 |
-
|
| 82 |
-
def forward(self, text):
|
| 83 |
-
if text is None:
|
| 84 |
-
return None
|
| 85 |
-
if type(text) is str:
|
| 86 |
-
text = [text]
|
| 87 |
-
tokens = self.organ2tokens(text)
|
| 88 |
-
clip_outputs = self.clip_text_model(**tokens)
|
| 89 |
-
text_embedding = clip_outputs.pooler_output
|
| 90 |
-
text_embedding = self.dim_align(text_embedding)
|
| 91 |
-
return text_embedding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/script/inference_demo.sh
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
export segvol_ckpt="path/to/SegVol_v1.pth"
|
| 2 |
-
export work_dir="path/to/work_dir"
|
| 3 |
-
export demo_config_path="./config/config_demo.json"
|
| 4 |
-
|
| 5 |
-
CUDA_VISIBLE_DEVICES=0 python inference_demo.py \
|
| 6 |
-
--resume $segvol_ckpt \
|
| 7 |
-
-work_dir $work_dir \
|
| 8 |
-
--demo_config $demo_config_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/segment_anything_volumetric/.ipynb_checkpoints/build_sam-checkpoint.py
DELETED
|
@@ -1,172 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
from functools import partial
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
import urllib.request
|
| 9 |
-
import torch
|
| 10 |
-
|
| 11 |
-
from .modeling import (
|
| 12 |
-
ImageEncoderViT,
|
| 13 |
-
MaskDecoder,
|
| 14 |
-
PromptEncoder,
|
| 15 |
-
Sam,
|
| 16 |
-
TwoWayTransformer,
|
| 17 |
-
)
|
| 18 |
-
|
| 19 |
-
from .modeling.image_encoder_swin import SwinTransformer
|
| 20 |
-
|
| 21 |
-
from monai.utils import ensure_tuple_rep, optional_import
|
| 22 |
-
|
| 23 |
-
def build_sam_vit_h(checkpoint=None, image_size=1024):
|
| 24 |
-
return _build_sam(
|
| 25 |
-
encoder_embed_dim=1280,
|
| 26 |
-
encoder_depth=32,
|
| 27 |
-
encoder_num_heads=16,
|
| 28 |
-
encoder_global_attn_indexes=[7, 15, 23, 31],
|
| 29 |
-
checkpoint=checkpoint,
|
| 30 |
-
image_size=image_size,
|
| 31 |
-
)
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
build_sam = build_sam_vit_h
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def build_sam_vit_l(checkpoint=None, image_size=1024):
|
| 38 |
-
return _build_sam(
|
| 39 |
-
encoder_embed_dim=1024,
|
| 40 |
-
encoder_depth=24,
|
| 41 |
-
encoder_num_heads=16,
|
| 42 |
-
encoder_global_attn_indexes=[5, 11, 17, 23],
|
| 43 |
-
checkpoint=checkpoint,
|
| 44 |
-
image_size=image_size,
|
| 45 |
-
)
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def build_sam_vit_b(checkpoint=None, image_size=1024):
|
| 49 |
-
return _build_sam(
|
| 50 |
-
encoder_embed_dim=768,
|
| 51 |
-
encoder_depth=12,
|
| 52 |
-
encoder_num_heads=12,
|
| 53 |
-
encoder_global_attn_indexes=[2, 5, 8, 11],
|
| 54 |
-
checkpoint=checkpoint,
|
| 55 |
-
image_size=image_size,
|
| 56 |
-
)
|
| 57 |
-
"""
|
| 58 |
-
Examples::
|
| 59 |
-
# for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
|
| 60 |
-
>>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
|
| 61 |
-
# for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
|
| 62 |
-
>>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
|
| 63 |
-
# for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
|
| 64 |
-
>>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
|
| 65 |
-
"""
|
| 66 |
-
|
| 67 |
-
def build_sam_vit_swin(checkpoint=None, image_size=96):
|
| 68 |
-
print('==> build_sam_vit_swin')
|
| 69 |
-
return _build_sam(
|
| 70 |
-
encoder_embed_dim=48,
|
| 71 |
-
encoder_depth=12,
|
| 72 |
-
encoder_num_heads=12,
|
| 73 |
-
encoder_global_attn_indexes=[2, 5, 8, 11],
|
| 74 |
-
checkpoint=checkpoint,
|
| 75 |
-
image_size=image_size,
|
| 76 |
-
)
|
| 77 |
-
|
| 78 |
-
sam_model_registry = {
|
| 79 |
-
"default": build_sam_vit_h,
|
| 80 |
-
"vit_h": build_sam_vit_h,
|
| 81 |
-
"vit_l": build_sam_vit_l,
|
| 82 |
-
"vit_b": build_sam_vit_b,
|
| 83 |
-
"swin_vit": build_sam_vit_swin,
|
| 84 |
-
}
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
def _build_sam(
|
| 88 |
-
encoder_embed_dim,
|
| 89 |
-
encoder_depth,
|
| 90 |
-
encoder_num_heads,
|
| 91 |
-
encoder_global_attn_indexes,
|
| 92 |
-
checkpoint=None,
|
| 93 |
-
image_size=None,
|
| 94 |
-
spatial_dims=3,
|
| 95 |
-
):
|
| 96 |
-
prompt_embed_dim = 768
|
| 97 |
-
patch_size = ensure_tuple_rep(2, spatial_dims)
|
| 98 |
-
window_size = ensure_tuple_rep(7, spatial_dims)
|
| 99 |
-
image_embedding_size = [size // 32 for size in image_size]
|
| 100 |
-
sam = Sam(
|
| 101 |
-
image_encoder=SwinTransformer(
|
| 102 |
-
in_chans=1,
|
| 103 |
-
embed_dim=encoder_embed_dim,
|
| 104 |
-
window_size=window_size,
|
| 105 |
-
patch_size=patch_size,
|
| 106 |
-
depths=(2, 2, 6, 2), #(2, 2, 6, 2),
|
| 107 |
-
num_heads=(3, 6, 12, 24),
|
| 108 |
-
mlp_ratio=4.0,
|
| 109 |
-
qkv_bias=True,
|
| 110 |
-
spatial_dims=spatial_dims,
|
| 111 |
-
),
|
| 112 |
-
prompt_encoder=PromptEncoder(
|
| 113 |
-
embed_dim=prompt_embed_dim,
|
| 114 |
-
image_embedding_size=image_embedding_size,
|
| 115 |
-
input_image_size=image_size,
|
| 116 |
-
mask_in_chans=16,
|
| 117 |
-
),
|
| 118 |
-
mask_decoder=MaskDecoder(
|
| 119 |
-
num_multimask_outputs=3,
|
| 120 |
-
transformer=TwoWayTransformer(
|
| 121 |
-
depth=2,
|
| 122 |
-
embedding_dim=prompt_embed_dim,
|
| 123 |
-
mlp_dim=2048,
|
| 124 |
-
num_heads=8,
|
| 125 |
-
),
|
| 126 |
-
transformer_dim=prompt_embed_dim,
|
| 127 |
-
iou_head_depth=3,
|
| 128 |
-
iou_head_hidden_dim=256,
|
| 129 |
-
),
|
| 130 |
-
pixel_mean=[123.675, 116.28, 103.53],
|
| 131 |
-
pixel_std=[58.395, 57.12, 57.375],
|
| 132 |
-
)
|
| 133 |
-
sam.eval()
|
| 134 |
-
if checkpoint is not None:
|
| 135 |
-
checkpoint = Path(checkpoint)
|
| 136 |
-
if checkpoint.name == "sam_vit_b_01ec64.pth" and not checkpoint.exists():
|
| 137 |
-
cmd = input("Download sam_vit_b_01ec64.pth from facebook AI? [y]/n: ")
|
| 138 |
-
if len(cmd) == 0 or cmd.lower() == 'y':
|
| 139 |
-
checkpoint.parent.mkdir(parents=True, exist_ok=True)
|
| 140 |
-
print("Downloading SAM ViT-B checkpoint...")
|
| 141 |
-
urllib.request.urlretrieve(
|
| 142 |
-
"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
|
| 143 |
-
checkpoint,
|
| 144 |
-
)
|
| 145 |
-
print(checkpoint.name, " is downloaded!")
|
| 146 |
-
elif checkpoint.name == "sam_vit_h_4b8939.pth" and not checkpoint.exists():
|
| 147 |
-
cmd = input("Download sam_vit_h_4b8939.pth from facebook AI? [y]/n: ")
|
| 148 |
-
if len(cmd) == 0 or cmd.lower() == 'y':
|
| 149 |
-
checkpoint.parent.mkdir(parents=True, exist_ok=True)
|
| 150 |
-
print("Downloading SAM ViT-H checkpoint...")
|
| 151 |
-
urllib.request.urlretrieve(
|
| 152 |
-
"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
|
| 153 |
-
checkpoint,
|
| 154 |
-
)
|
| 155 |
-
print(checkpoint.name, " is downloaded!")
|
| 156 |
-
elif checkpoint.name == "sam_vit_l_0b3195.pth" and not checkpoint.exists():
|
| 157 |
-
cmd = input("Download sam_vit_l_0b3195.pth from facebook AI? [y]/n: ")
|
| 158 |
-
if len(cmd) == 0 or cmd.lower() == 'y':
|
| 159 |
-
checkpoint.parent.mkdir(parents=True, exist_ok=True)
|
| 160 |
-
print("Downloading SAM ViT-L checkpoint...")
|
| 161 |
-
urllib.request.urlretrieve(
|
| 162 |
-
"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
|
| 163 |
-
checkpoint,
|
| 164 |
-
)
|
| 165 |
-
print(checkpoint.name, " is downloaded!")
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
if checkpoint is not None:
|
| 169 |
-
with open(checkpoint, "rb") as f:
|
| 170 |
-
state_dict = torch.load(f)
|
| 171 |
-
sam.load_state_dict(state_dict)
|
| 172 |
-
return sam
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/segment_anything_volumetric/__init__.py
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
from .build_sam import (
|
| 8 |
-
build_sam_vit_3d,
|
| 9 |
-
sam_model_registry,
|
| 10 |
-
)
|
| 11 |
-
from .predictor import SamPredictor
|
| 12 |
-
from .automatic_mask_generator import SamAutomaticMaskGenerator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/segment_anything_volumetric/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (407 Bytes)
|
|
|
model/segment_anything_volumetric/__pycache__/__init__.cpython-39.pyc
DELETED
|
Binary file (386 Bytes)
|
|
|
model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-310.pyc
DELETED
|
Binary file (11.4 kB)
|
|
|
model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-39.pyc
DELETED
|
Binary file (11.4 kB)
|
|
|
model/segment_anything_volumetric/__pycache__/build_sam.cpython-310.pyc
DELETED
|
Binary file (3.3 kB)
|
|
|
model/segment_anything_volumetric/__pycache__/build_sam.cpython-39.pyc
DELETED
|
Binary file (2.63 kB)
|
|
|
model/segment_anything_volumetric/__pycache__/predictor.cpython-310.pyc
DELETED
|
Binary file (9.96 kB)
|
|
|
model/segment_anything_volumetric/__pycache__/predictor.cpython-39.pyc
DELETED
|
Binary file (9.99 kB)
|
|
|
model/segment_anything_volumetric/automatic_mask_generator.py
DELETED
|
@@ -1,372 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
import numpy as np
|
| 8 |
-
import torch
|
| 9 |
-
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
| 10 |
-
|
| 11 |
-
from typing import Any, Dict, List, Optional, Tuple
|
| 12 |
-
|
| 13 |
-
from .modeling import Sam
|
| 14 |
-
from .predictor import SamPredictor
|
| 15 |
-
from .utils.amg import (
|
| 16 |
-
MaskData,
|
| 17 |
-
area_from_rle,
|
| 18 |
-
batch_iterator,
|
| 19 |
-
batched_mask_to_box,
|
| 20 |
-
box_xyxy_to_xywh,
|
| 21 |
-
build_all_layer_point_grids,
|
| 22 |
-
calculate_stability_score,
|
| 23 |
-
coco_encode_rle,
|
| 24 |
-
generate_crop_boxes,
|
| 25 |
-
is_box_near_crop_edge,
|
| 26 |
-
mask_to_rle_pytorch,
|
| 27 |
-
remove_small_regions,
|
| 28 |
-
rle_to_mask,
|
| 29 |
-
uncrop_boxes_xyxy,
|
| 30 |
-
uncrop_masks,
|
| 31 |
-
uncrop_points,
|
| 32 |
-
)
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
class SamAutomaticMaskGenerator:
|
| 36 |
-
def __init__(
|
| 37 |
-
self,
|
| 38 |
-
model: Sam,
|
| 39 |
-
points_per_side: Optional[int] = 32,
|
| 40 |
-
points_per_batch: int = 64,
|
| 41 |
-
pred_iou_thresh: float = 0.88,
|
| 42 |
-
stability_score_thresh: float = 0.95,
|
| 43 |
-
stability_score_offset: float = 1.0,
|
| 44 |
-
box_nms_thresh: float = 0.7,
|
| 45 |
-
crop_n_layers: int = 0,
|
| 46 |
-
crop_nms_thresh: float = 0.7,
|
| 47 |
-
crop_overlap_ratio: float = 512 / 1500,
|
| 48 |
-
crop_n_points_downscale_factor: int = 1,
|
| 49 |
-
point_grids: Optional[List[np.ndarray]] = None,
|
| 50 |
-
min_mask_region_area: int = 0,
|
| 51 |
-
output_mode: str = "binary_mask",
|
| 52 |
-
) -> None:
|
| 53 |
-
"""
|
| 54 |
-
Using a SAM model, generates masks for the entire image.
|
| 55 |
-
Generates a grid of point prompts over the image, then filters
|
| 56 |
-
low quality and duplicate masks. The default settings are chosen
|
| 57 |
-
for SAM with a ViT-H backbone.
|
| 58 |
-
|
| 59 |
-
Arguments:
|
| 60 |
-
model (Sam): The SAM model to use for mask prediction.
|
| 61 |
-
points_per_side (int or None): The number of points to be sampled
|
| 62 |
-
along one side of the image. The total number of points is
|
| 63 |
-
points_per_side**2. If None, 'point_grids' must provide explicit
|
| 64 |
-
point sampling.
|
| 65 |
-
points_per_batch (int): Sets the number of points run simultaneously
|
| 66 |
-
by the model. Higher numbers may be faster but use more GPU memory.
|
| 67 |
-
pred_iou_thresh (float): A filtering threshold in [0,1], using the
|
| 68 |
-
model's predicted mask quality.
|
| 69 |
-
stability_score_thresh (float): A filtering threshold in [0,1], using
|
| 70 |
-
the stability of the mask under changes to the cutoff used to binarize
|
| 71 |
-
the model's mask predictions.
|
| 72 |
-
stability_score_offset (float): The amount to shift the cutoff when
|
| 73 |
-
calculated the stability score.
|
| 74 |
-
box_nms_thresh (float): The box IoU cutoff used by non-maximal
|
| 75 |
-
suppression to filter duplicate masks.
|
| 76 |
-
crop_n_layers (int): If >0, mask prediction will be run again on
|
| 77 |
-
crops of the image. Sets the number of layers to run, where each
|
| 78 |
-
layer has 2**i_layer number of image crops.
|
| 79 |
-
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
|
| 80 |
-
suppression to filter duplicate masks between different crops.
|
| 81 |
-
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
| 82 |
-
In the first crop layer, crops will overlap by this fraction of
|
| 83 |
-
the image length. Later layers with more crops scale down this overlap.
|
| 84 |
-
crop_n_points_downscale_factor (int): The number of points-per-side
|
| 85 |
-
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
| 86 |
-
point_grids (list(np.ndarray) or None): A list over explicit grids
|
| 87 |
-
of points used for sampling, normalized to [0,1]. The nth grid in the
|
| 88 |
-
list is used in the nth crop layer. Exclusive with points_per_side.
|
| 89 |
-
min_mask_region_area (int): If >0, postprocessing will be applied
|
| 90 |
-
to remove disconnected regions and holes in masks with area smaller
|
| 91 |
-
than min_mask_region_area. Requires opencv.
|
| 92 |
-
output_mode (str): The form masks are returned in. Can be 'binary_mask',
|
| 93 |
-
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
|
| 94 |
-
For large resolutions, 'binary_mask' may consume large amounts of
|
| 95 |
-
memory.
|
| 96 |
-
"""
|
| 97 |
-
|
| 98 |
-
assert (points_per_side is None) != (
|
| 99 |
-
point_grids is None
|
| 100 |
-
), "Exactly one of points_per_side or point_grid must be provided."
|
| 101 |
-
if points_per_side is not None:
|
| 102 |
-
self.point_grids = build_all_layer_point_grids(
|
| 103 |
-
points_per_side,
|
| 104 |
-
crop_n_layers,
|
| 105 |
-
crop_n_points_downscale_factor,
|
| 106 |
-
)
|
| 107 |
-
elif point_grids is not None:
|
| 108 |
-
self.point_grids = point_grids
|
| 109 |
-
else:
|
| 110 |
-
raise ValueError("Can't have both points_per_side and point_grid be None.")
|
| 111 |
-
|
| 112 |
-
assert output_mode in [
|
| 113 |
-
"binary_mask",
|
| 114 |
-
"uncompressed_rle",
|
| 115 |
-
"coco_rle",
|
| 116 |
-
], f"Unknown output_mode {output_mode}."
|
| 117 |
-
if output_mode == "coco_rle":
|
| 118 |
-
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
|
| 119 |
-
|
| 120 |
-
if min_mask_region_area > 0:
|
| 121 |
-
import cv2 # type: ignore # noqa: F401
|
| 122 |
-
|
| 123 |
-
self.predictor = SamPredictor(model)
|
| 124 |
-
self.points_per_batch = points_per_batch
|
| 125 |
-
self.pred_iou_thresh = pred_iou_thresh
|
| 126 |
-
self.stability_score_thresh = stability_score_thresh
|
| 127 |
-
self.stability_score_offset = stability_score_offset
|
| 128 |
-
self.box_nms_thresh = box_nms_thresh
|
| 129 |
-
self.crop_n_layers = crop_n_layers
|
| 130 |
-
self.crop_nms_thresh = crop_nms_thresh
|
| 131 |
-
self.crop_overlap_ratio = crop_overlap_ratio
|
| 132 |
-
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
|
| 133 |
-
self.min_mask_region_area = min_mask_region_area
|
| 134 |
-
self.output_mode = output_mode
|
| 135 |
-
|
| 136 |
-
@torch.no_grad()
|
| 137 |
-
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
| 138 |
-
"""
|
| 139 |
-
Generates masks for the given image.
|
| 140 |
-
|
| 141 |
-
Arguments:
|
| 142 |
-
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
|
| 143 |
-
|
| 144 |
-
Returns:
|
| 145 |
-
list(dict(str, any)): A list over records for masks. Each record is
|
| 146 |
-
a dict containing the following keys:
|
| 147 |
-
segmentation (dict(str, any) or np.ndarray): The mask. If
|
| 148 |
-
output_mode='binary_mask', is an array of shape HW. Otherwise,
|
| 149 |
-
is a dictionary containing the RLE.
|
| 150 |
-
bbox (list(float)): The box around the mask, in XYWH format.
|
| 151 |
-
area (int): The area in pixels of the mask.
|
| 152 |
-
predicted_iou (float): The model's own prediction of the mask's
|
| 153 |
-
quality. This is filtered by the pred_iou_thresh parameter.
|
| 154 |
-
point_coords (list(list(float))): The point coordinates input
|
| 155 |
-
to the model to generate this mask.
|
| 156 |
-
stability_score (float): A measure of the mask's quality. This
|
| 157 |
-
is filtered on using the stability_score_thresh parameter.
|
| 158 |
-
crop_box (list(float)): The crop of the image used to generate
|
| 159 |
-
the mask, given in XYWH format.
|
| 160 |
-
"""
|
| 161 |
-
|
| 162 |
-
# Generate masks
|
| 163 |
-
mask_data = self._generate_masks(image)
|
| 164 |
-
|
| 165 |
-
# Filter small disconnected regions and holes in masks
|
| 166 |
-
if self.min_mask_region_area > 0:
|
| 167 |
-
mask_data = self.postprocess_small_regions(
|
| 168 |
-
mask_data,
|
| 169 |
-
self.min_mask_region_area,
|
| 170 |
-
max(self.box_nms_thresh, self.crop_nms_thresh),
|
| 171 |
-
)
|
| 172 |
-
|
| 173 |
-
# Encode masks
|
| 174 |
-
if self.output_mode == "coco_rle":
|
| 175 |
-
mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
|
| 176 |
-
elif self.output_mode == "binary_mask":
|
| 177 |
-
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
| 178 |
-
else:
|
| 179 |
-
mask_data["segmentations"] = mask_data["rles"]
|
| 180 |
-
|
| 181 |
-
# Write mask records
|
| 182 |
-
curr_anns = []
|
| 183 |
-
for idx in range(len(mask_data["segmentations"])):
|
| 184 |
-
ann = {
|
| 185 |
-
"segmentation": mask_data["segmentations"][idx],
|
| 186 |
-
"area": area_from_rle(mask_data["rles"][idx]),
|
| 187 |
-
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
| 188 |
-
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
| 189 |
-
"point_coords": [mask_data["points"][idx].tolist()],
|
| 190 |
-
"stability_score": mask_data["stability_score"][idx].item(),
|
| 191 |
-
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
| 192 |
-
}
|
| 193 |
-
curr_anns.append(ann)
|
| 194 |
-
|
| 195 |
-
return curr_anns
|
| 196 |
-
|
| 197 |
-
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
| 198 |
-
orig_size = image.shape[:2]
|
| 199 |
-
crop_boxes, layer_idxs = generate_crop_boxes(
|
| 200 |
-
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
| 201 |
-
)
|
| 202 |
-
|
| 203 |
-
# Iterate over image crops
|
| 204 |
-
data = MaskData()
|
| 205 |
-
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
| 206 |
-
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
|
| 207 |
-
data.cat(crop_data)
|
| 208 |
-
|
| 209 |
-
# Remove duplicate masks between crops
|
| 210 |
-
if len(crop_boxes) > 1:
|
| 211 |
-
# Prefer masks from smaller crops
|
| 212 |
-
scores = 1 / box_area(data["crop_boxes"])
|
| 213 |
-
scores = scores.to(data["boxes"].device)
|
| 214 |
-
keep_by_nms = batched_nms(
|
| 215 |
-
data["boxes"].float(),
|
| 216 |
-
scores,
|
| 217 |
-
torch.zeros_like(data["boxes"][:, 0]), # categories
|
| 218 |
-
iou_threshold=self.crop_nms_thresh,
|
| 219 |
-
)
|
| 220 |
-
data.filter(keep_by_nms)
|
| 221 |
-
|
| 222 |
-
data.to_numpy()
|
| 223 |
-
return data
|
| 224 |
-
|
| 225 |
-
def _process_crop(
|
| 226 |
-
self,
|
| 227 |
-
image: np.ndarray,
|
| 228 |
-
crop_box: List[int],
|
| 229 |
-
crop_layer_idx: int,
|
| 230 |
-
orig_size: Tuple[int, ...],
|
| 231 |
-
) -> MaskData:
|
| 232 |
-
# Crop the image and calculate embeddings
|
| 233 |
-
x0, y0, x1, y1 = crop_box
|
| 234 |
-
cropped_im = image[y0:y1, x0:x1, :]
|
| 235 |
-
cropped_im_size = cropped_im.shape[:2]
|
| 236 |
-
self.predictor.set_image(cropped_im)
|
| 237 |
-
|
| 238 |
-
# Get points for this crop
|
| 239 |
-
points_scale = np.array(cropped_im_size)[None, ::-1]
|
| 240 |
-
points_for_image = self.point_grids[crop_layer_idx] * points_scale
|
| 241 |
-
|
| 242 |
-
# Generate masks for this crop in batches
|
| 243 |
-
data = MaskData()
|
| 244 |
-
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
| 245 |
-
batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
|
| 246 |
-
data.cat(batch_data)
|
| 247 |
-
del batch_data
|
| 248 |
-
self.predictor.reset_image()
|
| 249 |
-
|
| 250 |
-
# Remove duplicates within this crop.
|
| 251 |
-
keep_by_nms = batched_nms(
|
| 252 |
-
data["boxes"].float(),
|
| 253 |
-
data["iou_preds"],
|
| 254 |
-
torch.zeros_like(data["boxes"][:, 0]), # categories
|
| 255 |
-
iou_threshold=self.box_nms_thresh,
|
| 256 |
-
)
|
| 257 |
-
data.filter(keep_by_nms)
|
| 258 |
-
|
| 259 |
-
# Return to the original image frame
|
| 260 |
-
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
|
| 261 |
-
data["points"] = uncrop_points(data["points"], crop_box)
|
| 262 |
-
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
|
| 263 |
-
|
| 264 |
-
return data
|
| 265 |
-
|
| 266 |
-
def _process_batch(
|
| 267 |
-
self,
|
| 268 |
-
points: np.ndarray,
|
| 269 |
-
im_size: Tuple[int, ...],
|
| 270 |
-
crop_box: List[int],
|
| 271 |
-
orig_size: Tuple[int, ...],
|
| 272 |
-
) -> MaskData:
|
| 273 |
-
orig_h, orig_w = orig_size
|
| 274 |
-
|
| 275 |
-
# Run model on this batch
|
| 276 |
-
transformed_points = self.predictor.transform.apply_coords(points, im_size)
|
| 277 |
-
in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
|
| 278 |
-
in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
|
| 279 |
-
masks, iou_preds, _ = self.predictor.predict_torch(
|
| 280 |
-
in_points[:, None, :],
|
| 281 |
-
in_labels[:, None],
|
| 282 |
-
multimask_output=True,
|
| 283 |
-
return_logits=True,
|
| 284 |
-
)
|
| 285 |
-
|
| 286 |
-
# Serialize predictions and store in MaskData
|
| 287 |
-
data = MaskData(
|
| 288 |
-
masks=masks.flatten(0, 1),
|
| 289 |
-
iou_preds=iou_preds.flatten(0, 1),
|
| 290 |
-
points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
|
| 291 |
-
)
|
| 292 |
-
del masks
|
| 293 |
-
|
| 294 |
-
# Filter by predicted IoU
|
| 295 |
-
if self.pred_iou_thresh > 0.0:
|
| 296 |
-
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
| 297 |
-
data.filter(keep_mask)
|
| 298 |
-
|
| 299 |
-
# Calculate stability score
|
| 300 |
-
data["stability_score"] = calculate_stability_score(
|
| 301 |
-
data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
|
| 302 |
-
)
|
| 303 |
-
if self.stability_score_thresh > 0.0:
|
| 304 |
-
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
| 305 |
-
data.filter(keep_mask)
|
| 306 |
-
|
| 307 |
-
# Threshold masks and calculate boxes
|
| 308 |
-
data["masks"] = data["masks"] > self.predictor.model.mask_threshold
|
| 309 |
-
data["boxes"] = batched_mask_to_box(data["masks"])
|
| 310 |
-
|
| 311 |
-
# Filter boxes that touch crop boundaries
|
| 312 |
-
keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
|
| 313 |
-
if not torch.all(keep_mask):
|
| 314 |
-
data.filter(keep_mask)
|
| 315 |
-
|
| 316 |
-
# Compress to RLE
|
| 317 |
-
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
|
| 318 |
-
data["rles"] = mask_to_rle_pytorch(data["masks"])
|
| 319 |
-
del data["masks"]
|
| 320 |
-
|
| 321 |
-
return data
|
| 322 |
-
|
| 323 |
-
@staticmethod
|
| 324 |
-
def postprocess_small_regions(
|
| 325 |
-
mask_data: MaskData, min_area: int, nms_thresh: float
|
| 326 |
-
) -> MaskData:
|
| 327 |
-
"""
|
| 328 |
-
Removes small disconnected regions and holes in masks, then reruns
|
| 329 |
-
box NMS to remove any new duplicates.
|
| 330 |
-
|
| 331 |
-
Edits mask_data in place.
|
| 332 |
-
|
| 333 |
-
Requires open-cv as a dependency.
|
| 334 |
-
"""
|
| 335 |
-
if len(mask_data["rles"]) == 0:
|
| 336 |
-
return mask_data
|
| 337 |
-
|
| 338 |
-
# Filter small disconnected regions and holes
|
| 339 |
-
new_masks = []
|
| 340 |
-
scores = []
|
| 341 |
-
for rle in mask_data["rles"]:
|
| 342 |
-
mask = rle_to_mask(rle)
|
| 343 |
-
|
| 344 |
-
mask, changed = remove_small_regions(mask, min_area, mode="holes")
|
| 345 |
-
unchanged = not changed
|
| 346 |
-
mask, changed = remove_small_regions(mask, min_area, mode="islands")
|
| 347 |
-
unchanged = unchanged and not changed
|
| 348 |
-
|
| 349 |
-
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
| 350 |
-
# Give score=0 to changed masks and score=1 to unchanged masks
|
| 351 |
-
# so NMS will prefer ones that didn't need postprocessing
|
| 352 |
-
scores.append(float(unchanged))
|
| 353 |
-
|
| 354 |
-
# Recalculate boxes and remove any new duplicates
|
| 355 |
-
masks = torch.cat(new_masks, dim=0)
|
| 356 |
-
boxes = batched_mask_to_box(masks)
|
| 357 |
-
keep_by_nms = batched_nms(
|
| 358 |
-
boxes.float(),
|
| 359 |
-
torch.as_tensor(scores),
|
| 360 |
-
torch.zeros_like(boxes[:, 0]), # categories
|
| 361 |
-
iou_threshold=nms_thresh,
|
| 362 |
-
)
|
| 363 |
-
|
| 364 |
-
# Only recalculate RLEs for masks that have changed
|
| 365 |
-
for i_mask in keep_by_nms:
|
| 366 |
-
if scores[i_mask] == 0.0:
|
| 367 |
-
mask_torch = masks[i_mask].unsqueeze(0)
|
| 368 |
-
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
|
| 369 |
-
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
|
| 370 |
-
mask_data.filter(keep_by_nms)
|
| 371 |
-
|
| 372 |
-
return mask_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/segment_anything_volumetric/build_sam.py
DELETED
|
@@ -1,111 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
from functools import partial
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
import urllib.request
|
| 9 |
-
import torch
|
| 10 |
-
|
| 11 |
-
from .modeling import (
|
| 12 |
-
ImageEncoderViT,
|
| 13 |
-
MaskDecoder,
|
| 14 |
-
PromptEncoder,
|
| 15 |
-
Sam,
|
| 16 |
-
TwoWayTransformer,
|
| 17 |
-
)
|
| 18 |
-
import numpy as np
|
| 19 |
-
from .modeling.image_encoder_swin import SwinTransformer
|
| 20 |
-
from monai.networks.nets import ViT
|
| 21 |
-
from monai.networks.nets.swin_unetr import SwinTransformer as SwinViT
|
| 22 |
-
|
| 23 |
-
from monai.utils import ensure_tuple_rep, optional_import
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
"""
|
| 27 |
-
Examples::
|
| 28 |
-
# for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
|
| 29 |
-
>>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
|
| 30 |
-
# for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
|
| 31 |
-
>>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
|
| 32 |
-
# for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
|
| 33 |
-
>>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
|
| 34 |
-
"""
|
| 35 |
-
|
| 36 |
-
def build_sam_vit_3d(checkpoint=None):
|
| 37 |
-
print('build_sam_vit_3d...')
|
| 38 |
-
return _build_sam(
|
| 39 |
-
image_encoder_type='vit',
|
| 40 |
-
embed_dim = 768,
|
| 41 |
-
patch_size=[4,16,16],
|
| 42 |
-
checkpoint=checkpoint,
|
| 43 |
-
image_size=[32,256,256],
|
| 44 |
-
)
|
| 45 |
-
|
| 46 |
-
sam_model_registry = {
|
| 47 |
-
"vit": build_sam_vit_3d,
|
| 48 |
-
}
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def _build_sam(
|
| 52 |
-
image_encoder_type,
|
| 53 |
-
embed_dim,
|
| 54 |
-
patch_size,
|
| 55 |
-
checkpoint,
|
| 56 |
-
image_size,
|
| 57 |
-
):
|
| 58 |
-
mlp_dim = 3072
|
| 59 |
-
num_layers = 12
|
| 60 |
-
num_heads = 12
|
| 61 |
-
pos_embed = 'perceptron'
|
| 62 |
-
dropout_rate = 0.0
|
| 63 |
-
|
| 64 |
-
image_encoder=ViT(
|
| 65 |
-
in_channels=1,
|
| 66 |
-
img_size=image_size,
|
| 67 |
-
patch_size=patch_size,
|
| 68 |
-
hidden_size=embed_dim,
|
| 69 |
-
mlp_dim=mlp_dim,
|
| 70 |
-
num_layers=num_layers,
|
| 71 |
-
num_heads=num_heads,
|
| 72 |
-
pos_embed=pos_embed,
|
| 73 |
-
classification=False,
|
| 74 |
-
dropout_rate=dropout_rate,
|
| 75 |
-
)
|
| 76 |
-
image_embedding_size = [int(item) for item in (np.array(image_size) / np.array(patch_size))]
|
| 77 |
-
|
| 78 |
-
if checkpoint is not None:
|
| 79 |
-
with open(checkpoint, "rb") as f:
|
| 80 |
-
state_dict = torch.load(f, map_location='cpu')['state_dict']
|
| 81 |
-
encoder_dict = {k.replace('model.encoder.', ''): v for k, v in state_dict.items() if 'model.encoder.' in k}
|
| 82 |
-
image_encoder.load_state_dict(encoder_dict)
|
| 83 |
-
print(f'===> image_encoder.load_param: {checkpoint}')
|
| 84 |
-
sam = Sam(
|
| 85 |
-
image_encoder=image_encoder,
|
| 86 |
-
prompt_encoder=PromptEncoder(
|
| 87 |
-
embed_dim=embed_dim,
|
| 88 |
-
image_embedding_size=image_embedding_size,
|
| 89 |
-
input_image_size=image_size,
|
| 90 |
-
mask_in_chans=16,
|
| 91 |
-
),
|
| 92 |
-
mask_decoder=MaskDecoder(
|
| 93 |
-
image_encoder_type=image_encoder_type,
|
| 94 |
-
num_multimask_outputs=3,
|
| 95 |
-
transformer=TwoWayTransformer(
|
| 96 |
-
depth=2,
|
| 97 |
-
embedding_dim=embed_dim,
|
| 98 |
-
mlp_dim=2048,
|
| 99 |
-
num_heads=8,
|
| 100 |
-
),
|
| 101 |
-
transformer_dim=embed_dim,
|
| 102 |
-
iou_head_depth=3,
|
| 103 |
-
iou_head_hidden_dim=256,
|
| 104 |
-
image_size=np.array(image_size),
|
| 105 |
-
patch_size=np.array(patch_size),
|
| 106 |
-
),
|
| 107 |
-
pixel_mean=[123.675, 116.28, 103.53],
|
| 108 |
-
pixel_std=[58.395, 57.12, 57.375],
|
| 109 |
-
)
|
| 110 |
-
sam.eval()
|
| 111 |
-
return sam
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/segment_anything_volumetric/modeling/.ipynb_checkpoints/image_encoder_swin-checkpoint.py
DELETED
|
@@ -1,709 +0,0 @@
|
|
| 1 |
-
from typing import Sequence, Tuple, Type, Union
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
import torch.nn as nn
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
import torch.utils.checkpoint as checkpoint
|
| 8 |
-
from torch.nn import LayerNorm
|
| 9 |
-
|
| 10 |
-
from monai.networks.blocks import MLPBlock as Mlp
|
| 11 |
-
from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
|
| 12 |
-
from monai.networks.layers import DropPath, trunc_normal_
|
| 13 |
-
from monai.utils import ensure_tuple_rep, optional_import
|
| 14 |
-
|
| 15 |
-
rearrange, _ = optional_import("einops", name="rearrange")
|
| 16 |
-
|
| 17 |
-
def window_partition(x, window_size):
|
| 18 |
-
"""window partition operation based on: "Liu et al.,
|
| 19 |
-
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
| 20 |
-
<https://arxiv.org/abs/2103.14030>"
|
| 21 |
-
https://github.com/microsoft/Swin-Transformer
|
| 22 |
-
Args:
|
| 23 |
-
x: input tensor.
|
| 24 |
-
window_size: local window size.
|
| 25 |
-
"""
|
| 26 |
-
x_shape = x.size()
|
| 27 |
-
if len(x_shape) == 5:
|
| 28 |
-
b, d, h, w, c = x_shape
|
| 29 |
-
x = x.view(
|
| 30 |
-
b,
|
| 31 |
-
d // window_size[0],
|
| 32 |
-
window_size[0],
|
| 33 |
-
h // window_size[1],
|
| 34 |
-
window_size[1],
|
| 35 |
-
w // window_size[2],
|
| 36 |
-
window_size[2],
|
| 37 |
-
c,
|
| 38 |
-
)
|
| 39 |
-
windows = (
|
| 40 |
-
x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c)
|
| 41 |
-
)
|
| 42 |
-
elif len(x_shape) == 4:
|
| 43 |
-
b, h, w, c = x.shape
|
| 44 |
-
x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c)
|
| 45 |
-
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c)
|
| 46 |
-
return windows
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def window_reverse(windows, window_size, dims):
|
| 50 |
-
"""window reverse operation based on: "Liu et al.,
|
| 51 |
-
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
| 52 |
-
<https://arxiv.org/abs/2103.14030>"
|
| 53 |
-
https://github.com/microsoft/Swin-Transformer
|
| 54 |
-
Args:
|
| 55 |
-
windows: windows tensor.
|
| 56 |
-
window_size: local window size.
|
| 57 |
-
dims: dimension values.
|
| 58 |
-
"""
|
| 59 |
-
if len(dims) == 4:
|
| 60 |
-
b, d, h, w = dims
|
| 61 |
-
x = windows.view(
|
| 62 |
-
b,
|
| 63 |
-
d // window_size[0],
|
| 64 |
-
h // window_size[1],
|
| 65 |
-
w // window_size[2],
|
| 66 |
-
window_size[0],
|
| 67 |
-
window_size[1],
|
| 68 |
-
window_size[2],
|
| 69 |
-
-1,
|
| 70 |
-
)
|
| 71 |
-
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1)
|
| 72 |
-
|
| 73 |
-
elif len(dims) == 3:
|
| 74 |
-
b, h, w = dims
|
| 75 |
-
x = windows.view(b, h // window_size[0], w // window_size[0], window_size[0], window_size[1], -1)
|
| 76 |
-
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
|
| 77 |
-
return x
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def get_window_size(x_size, window_size, shift_size=None):
|
| 81 |
-
"""Computing window size based on: "Liu et al.,
|
| 82 |
-
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
| 83 |
-
<https://arxiv.org/abs/2103.14030>"
|
| 84 |
-
https://github.com/microsoft/Swin-Transformer
|
| 85 |
-
Args:
|
| 86 |
-
x_size: input size.
|
| 87 |
-
window_size: local window size.
|
| 88 |
-
shift_size: window shifting size.
|
| 89 |
-
"""
|
| 90 |
-
|
| 91 |
-
use_window_size = list(window_size)
|
| 92 |
-
if shift_size is not None:
|
| 93 |
-
use_shift_size = list(shift_size)
|
| 94 |
-
for i in range(len(x_size)):
|
| 95 |
-
if x_size[i] <= window_size[i]:
|
| 96 |
-
use_window_size[i] = x_size[i]
|
| 97 |
-
if shift_size is not None:
|
| 98 |
-
use_shift_size[i] = 0
|
| 99 |
-
|
| 100 |
-
if shift_size is None:
|
| 101 |
-
return tuple(use_window_size)
|
| 102 |
-
else:
|
| 103 |
-
return tuple(use_window_size), tuple(use_shift_size)
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
class WindowAttention(nn.Module):
|
| 107 |
-
"""
|
| 108 |
-
Window based multi-head self attention module with relative position bias based on: "Liu et al.,
|
| 109 |
-
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
| 110 |
-
<https://arxiv.org/abs/2103.14030>"
|
| 111 |
-
https://github.com/microsoft/Swin-Transformer
|
| 112 |
-
"""
|
| 113 |
-
|
| 114 |
-
def __init__(
|
| 115 |
-
self,
|
| 116 |
-
dim: int,
|
| 117 |
-
num_heads: int,
|
| 118 |
-
window_size: Sequence[int],
|
| 119 |
-
qkv_bias: bool = False,
|
| 120 |
-
attn_drop: float = 0.0,
|
| 121 |
-
proj_drop: float = 0.0,
|
| 122 |
-
) -> None:
|
| 123 |
-
"""
|
| 124 |
-
Args:
|
| 125 |
-
dim: number of feature channels.
|
| 126 |
-
num_heads: number of attention heads.
|
| 127 |
-
window_size: local window size.
|
| 128 |
-
qkv_bias: add a learnable bias to query, key, value.
|
| 129 |
-
attn_drop: attention dropout rate.
|
| 130 |
-
proj_drop: dropout rate of output.
|
| 131 |
-
"""
|
| 132 |
-
|
| 133 |
-
super().__init__()
|
| 134 |
-
self.dim = dim
|
| 135 |
-
self.window_size = window_size
|
| 136 |
-
self.num_heads = num_heads
|
| 137 |
-
head_dim = dim // num_heads
|
| 138 |
-
self.scale = head_dim**-0.5
|
| 139 |
-
mesh_args = torch.meshgrid.__kwdefaults__
|
| 140 |
-
|
| 141 |
-
if len(self.window_size) == 3:
|
| 142 |
-
self.relative_position_bias_table = nn.Parameter(
|
| 143 |
-
torch.zeros(
|
| 144 |
-
(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1),
|
| 145 |
-
num_heads,
|
| 146 |
-
)
|
| 147 |
-
)
|
| 148 |
-
coords_d = torch.arange(self.window_size[0])
|
| 149 |
-
coords_h = torch.arange(self.window_size[1])
|
| 150 |
-
coords_w = torch.arange(self.window_size[2])
|
| 151 |
-
if mesh_args is not None:
|
| 152 |
-
coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij"))
|
| 153 |
-
else:
|
| 154 |
-
coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))
|
| 155 |
-
coords_flatten = torch.flatten(coords, 1)
|
| 156 |
-
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
| 157 |
-
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
| 158 |
-
relative_coords[:, :, 0] += self.window_size[0] - 1
|
| 159 |
-
relative_coords[:, :, 1] += self.window_size[1] - 1
|
| 160 |
-
relative_coords[:, :, 2] += self.window_size[2] - 1
|
| 161 |
-
relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
|
| 162 |
-
relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
|
| 163 |
-
elif len(self.window_size) == 2:
|
| 164 |
-
self.relative_position_bias_table = nn.Parameter(
|
| 165 |
-
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
|
| 166 |
-
)
|
| 167 |
-
coords_h = torch.arange(self.window_size[0])
|
| 168 |
-
coords_w = torch.arange(self.window_size[1])
|
| 169 |
-
if mesh_args is not None:
|
| 170 |
-
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
|
| 171 |
-
else:
|
| 172 |
-
coords = torch.stack(torch.meshgrid(coords_h, coords_w))
|
| 173 |
-
coords_flatten = torch.flatten(coords, 1)
|
| 174 |
-
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
| 175 |
-
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
| 176 |
-
relative_coords[:, :, 0] += self.window_size[0] - 1
|
| 177 |
-
relative_coords[:, :, 1] += self.window_size[1] - 1
|
| 178 |
-
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
| 179 |
-
|
| 180 |
-
relative_position_index = relative_coords.sum(-1)
|
| 181 |
-
self.register_buffer("relative_position_index", relative_position_index)
|
| 182 |
-
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 183 |
-
self.attn_drop = nn.Dropout(attn_drop)
|
| 184 |
-
self.proj = nn.Linear(dim, dim)
|
| 185 |
-
self.proj_drop = nn.Dropout(proj_drop)
|
| 186 |
-
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
| 187 |
-
self.softmax = nn.Softmax(dim=-1)
|
| 188 |
-
|
| 189 |
-
def forward(self, x, mask):
|
| 190 |
-
b, n, c = x.shape
|
| 191 |
-
qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 192 |
-
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 193 |
-
q = q * self.scale
|
| 194 |
-
attn = q @ k.transpose(-2, -1)
|
| 195 |
-
relative_position_bias = self.relative_position_bias_table[
|
| 196 |
-
self.relative_position_index.clone()[:n, :n].reshape(-1)
|
| 197 |
-
].reshape(n, n, -1)
|
| 198 |
-
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
|
| 199 |
-
attn = attn + relative_position_bias.unsqueeze(0)
|
| 200 |
-
if mask is not None:
|
| 201 |
-
nw = mask.shape[0]
|
| 202 |
-
attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
|
| 203 |
-
attn = attn.view(-1, self.num_heads, n, n)
|
| 204 |
-
attn = self.softmax(attn)
|
| 205 |
-
else:
|
| 206 |
-
attn = self.softmax(attn)
|
| 207 |
-
|
| 208 |
-
attn = self.attn_drop(attn)
|
| 209 |
-
x = (attn @ v).transpose(1, 2).reshape(b, n, c)
|
| 210 |
-
x = self.proj(x)
|
| 211 |
-
x = self.proj_drop(x)
|
| 212 |
-
return x
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
class SwinTransformerBlock(nn.Module):
|
| 216 |
-
"""
|
| 217 |
-
Swin Transformer block based on: "Liu et al.,
|
| 218 |
-
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
| 219 |
-
<https://arxiv.org/abs/2103.14030>"
|
| 220 |
-
https://github.com/microsoft/Swin-Transformer
|
| 221 |
-
"""
|
| 222 |
-
|
| 223 |
-
def __init__(
|
| 224 |
-
self,
|
| 225 |
-
dim: int,
|
| 226 |
-
num_heads: int,
|
| 227 |
-
window_size: Sequence[int],
|
| 228 |
-
shift_size: Sequence[int],
|
| 229 |
-
mlp_ratio: float = 4.0,
|
| 230 |
-
qkv_bias: bool = True,
|
| 231 |
-
drop: float = 0.0,
|
| 232 |
-
attn_drop: float = 0.0,
|
| 233 |
-
drop_path: float = 0.0,
|
| 234 |
-
act_layer: str = "GELU",
|
| 235 |
-
norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
|
| 236 |
-
use_checkpoint: bool = False,
|
| 237 |
-
) -> None:
|
| 238 |
-
"""
|
| 239 |
-
Args:
|
| 240 |
-
dim: number of feature channels.
|
| 241 |
-
num_heads: number of attention heads.
|
| 242 |
-
window_size: local window size.
|
| 243 |
-
shift_size: window shift size.
|
| 244 |
-
mlp_ratio: ratio of mlp hidden dim to embedding dim.
|
| 245 |
-
qkv_bias: add a learnable bias to query, key, value.
|
| 246 |
-
drop: dropout rate.
|
| 247 |
-
attn_drop: attention dropout rate.
|
| 248 |
-
drop_path: stochastic depth rate.
|
| 249 |
-
act_layer: activation layer.
|
| 250 |
-
norm_layer: normalization layer.
|
| 251 |
-
use_checkpoint: use gradient checkpointing for reduced memory usage.
|
| 252 |
-
"""
|
| 253 |
-
|
| 254 |
-
super().__init__()
|
| 255 |
-
self.dim = dim
|
| 256 |
-
self.num_heads = num_heads
|
| 257 |
-
self.window_size = window_size
|
| 258 |
-
self.shift_size = shift_size
|
| 259 |
-
self.mlp_ratio = mlp_ratio
|
| 260 |
-
self.use_checkpoint = use_checkpoint
|
| 261 |
-
self.norm1 = norm_layer(dim)
|
| 262 |
-
self.attn = WindowAttention(
|
| 263 |
-
dim,
|
| 264 |
-
window_size=self.window_size,
|
| 265 |
-
num_heads=num_heads,
|
| 266 |
-
qkv_bias=qkv_bias,
|
| 267 |
-
attn_drop=attn_drop,
|
| 268 |
-
proj_drop=drop,
|
| 269 |
-
)
|
| 270 |
-
|
| 271 |
-
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 272 |
-
self.norm2 = norm_layer(dim)
|
| 273 |
-
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 274 |
-
self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin")
|
| 275 |
-
|
| 276 |
-
def forward_part1(self, x, mask_matrix):
|
| 277 |
-
x_shape = x.size()
|
| 278 |
-
x = self.norm1(x)
|
| 279 |
-
if len(x_shape) == 5:
|
| 280 |
-
b, d, h, w, c = x.shape
|
| 281 |
-
window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
|
| 282 |
-
pad_l = pad_t = pad_d0 = 0
|
| 283 |
-
pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]
|
| 284 |
-
pad_b = (window_size[1] - h % window_size[1]) % window_size[1]
|
| 285 |
-
pad_r = (window_size[2] - w % window_size[2]) % window_size[2]
|
| 286 |
-
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
|
| 287 |
-
_, dp, hp, wp, _ = x.shape
|
| 288 |
-
dims = [b, dp, hp, wp]
|
| 289 |
-
|
| 290 |
-
elif len(x_shape) == 4:
|
| 291 |
-
b, h, w, c = x.shape
|
| 292 |
-
window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
|
| 293 |
-
pad_l = pad_t = 0
|
| 294 |
-
pad_r = (window_size[0] - h % window_size[0]) % window_size[0]
|
| 295 |
-
pad_b = (window_size[1] - w % window_size[1]) % window_size[1]
|
| 296 |
-
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
| 297 |
-
_, hp, wp, _ = x.shape
|
| 298 |
-
dims = [b, hp, wp]
|
| 299 |
-
|
| 300 |
-
if any(i > 0 for i in shift_size):
|
| 301 |
-
if len(x_shape) == 5:
|
| 302 |
-
shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
|
| 303 |
-
elif len(x_shape) == 4:
|
| 304 |
-
shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
|
| 305 |
-
attn_mask = mask_matrix
|
| 306 |
-
else:
|
| 307 |
-
shifted_x = x
|
| 308 |
-
attn_mask = None
|
| 309 |
-
x_windows = window_partition(shifted_x, window_size)
|
| 310 |
-
attn_windows = self.attn(x_windows, mask=attn_mask)
|
| 311 |
-
attn_windows = attn_windows.view(-1, *(window_size + (c,)))
|
| 312 |
-
shifted_x = window_reverse(attn_windows, window_size, dims)
|
| 313 |
-
if any(i > 0 for i in shift_size):
|
| 314 |
-
if len(x_shape) == 5:
|
| 315 |
-
x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
|
| 316 |
-
elif len(x_shape) == 4:
|
| 317 |
-
x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
|
| 318 |
-
else:
|
| 319 |
-
x = shifted_x
|
| 320 |
-
|
| 321 |
-
if len(x_shape) == 5:
|
| 322 |
-
if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
|
| 323 |
-
x = x[:, :d, :h, :w, :].contiguous()
|
| 324 |
-
elif len(x_shape) == 4:
|
| 325 |
-
if pad_r > 0 or pad_b > 0:
|
| 326 |
-
x = x[:, :h, :w, :].contiguous()
|
| 327 |
-
|
| 328 |
-
return x
|
| 329 |
-
|
| 330 |
-
def forward_part2(self, x):
|
| 331 |
-
return self.drop_path(self.mlp(self.norm2(x)))
|
| 332 |
-
|
| 333 |
-
def load_from(self, weights, n_block, layer):
|
| 334 |
-
root = f"module.{layer}.0.blocks.{n_block}."
|
| 335 |
-
block_names = [
|
| 336 |
-
"norm1.weight",
|
| 337 |
-
"norm1.bias",
|
| 338 |
-
"attn.relative_position_bias_table",
|
| 339 |
-
"attn.relative_position_index",
|
| 340 |
-
"attn.qkv.weight",
|
| 341 |
-
"attn.qkv.bias",
|
| 342 |
-
"attn.proj.weight",
|
| 343 |
-
"attn.proj.bias",
|
| 344 |
-
"norm2.weight",
|
| 345 |
-
"norm2.bias",
|
| 346 |
-
"mlp.fc1.weight",
|
| 347 |
-
"mlp.fc1.bias",
|
| 348 |
-
"mlp.fc2.weight",
|
| 349 |
-
"mlp.fc2.bias",
|
| 350 |
-
]
|
| 351 |
-
with torch.no_grad():
|
| 352 |
-
self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]])
|
| 353 |
-
self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
|
| 354 |
-
self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]])
|
| 355 |
-
self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])
|
| 356 |
-
self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]])
|
| 357 |
-
self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]])
|
| 358 |
-
self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]])
|
| 359 |
-
self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]])
|
| 360 |
-
self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]])
|
| 361 |
-
self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]])
|
| 362 |
-
self.mlp.linear1.weight.copy_(weights["state_dict"][root + block_names[10]])
|
| 363 |
-
self.mlp.linear1.bias.copy_(weights["state_dict"][root + block_names[11]])
|
| 364 |
-
self.mlp.linear2.weight.copy_(weights["state_dict"][root + block_names[12]])
|
| 365 |
-
self.mlp.linear2.bias.copy_(weights["state_dict"][root + block_names[13]])
|
| 366 |
-
|
| 367 |
-
def forward(self, x, mask_matrix):
|
| 368 |
-
shortcut = x
|
| 369 |
-
if self.use_checkpoint:
|
| 370 |
-
x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
|
| 371 |
-
else:
|
| 372 |
-
x = self.forward_part1(x, mask_matrix)
|
| 373 |
-
x = shortcut + self.drop_path(x)
|
| 374 |
-
if self.use_checkpoint:
|
| 375 |
-
x = x + checkpoint.checkpoint(self.forward_part2, x)
|
| 376 |
-
else:
|
| 377 |
-
x = x + self.forward_part2(x)
|
| 378 |
-
return x
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
class PatchMerging(nn.Module):
|
| 382 |
-
"""
|
| 383 |
-
Patch merging layer based on: "Liu et al.,
|
| 384 |
-
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
| 385 |
-
<https://arxiv.org/abs/2103.14030>"
|
| 386 |
-
https://github.com/microsoft/Swin-Transformer
|
| 387 |
-
"""
|
| 388 |
-
|
| 389 |
-
def __init__(
|
| 390 |
-
self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3
|
| 391 |
-
) -> None: # type: ignore
|
| 392 |
-
"""
|
| 393 |
-
Args:
|
| 394 |
-
dim: number of feature channels.
|
| 395 |
-
norm_layer: normalization layer.
|
| 396 |
-
spatial_dims: number of spatial dims.
|
| 397 |
-
"""
|
| 398 |
-
|
| 399 |
-
super().__init__()
|
| 400 |
-
self.dim = dim
|
| 401 |
-
if spatial_dims == 3:
|
| 402 |
-
self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
|
| 403 |
-
self.norm = norm_layer(8 * dim)
|
| 404 |
-
elif spatial_dims == 2:
|
| 405 |
-
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
| 406 |
-
self.norm = norm_layer(4 * dim)
|
| 407 |
-
|
| 408 |
-
def forward(self, x):
|
| 409 |
-
|
| 410 |
-
x_shape = x.size()
|
| 411 |
-
if len(x_shape) == 5:
|
| 412 |
-
b, d, h, w, c = x_shape
|
| 413 |
-
pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
|
| 414 |
-
if pad_input:
|
| 415 |
-
x = F.pad(x, (0, 0, 0, d % 2, 0, w % 2, 0, h % 2))
|
| 416 |
-
x0 = x[:, 0::2, 0::2, 0::2, :]
|
| 417 |
-
x1 = x[:, 1::2, 0::2, 0::2, :]
|
| 418 |
-
x2 = x[:, 0::2, 1::2, 0::2, :]
|
| 419 |
-
x3 = x[:, 0::2, 0::2, 1::2, :]
|
| 420 |
-
x4 = x[:, 1::2, 0::2, 1::2, :]
|
| 421 |
-
x5 = x[:, 0::2, 1::2, 0::2, :]
|
| 422 |
-
x6 = x[:, 0::2, 0::2, 1::2, :]
|
| 423 |
-
x7 = x[:, 1::2, 1::2, 1::2, :]
|
| 424 |
-
x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
|
| 425 |
-
|
| 426 |
-
elif len(x_shape) == 4:
|
| 427 |
-
b, h, w, c = x_shape
|
| 428 |
-
pad_input = (h % 2 == 1) or (w % 2 == 1)
|
| 429 |
-
if pad_input:
|
| 430 |
-
x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))
|
| 431 |
-
x0 = x[:, 0::2, 0::2, :]
|
| 432 |
-
x1 = x[:, 1::2, 0::2, :]
|
| 433 |
-
x2 = x[:, 0::2, 1::2, :]
|
| 434 |
-
x3 = x[:, 1::2, 1::2, :]
|
| 435 |
-
x = torch.cat([x0, x1, x2, x3], -1)
|
| 436 |
-
|
| 437 |
-
x = self.norm(x)
|
| 438 |
-
x = self.reduction(x)
|
| 439 |
-
return x
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
def compute_mask(dims, window_size, shift_size, device):
|
| 443 |
-
"""Computing region masks based on: "Liu et al.,
|
| 444 |
-
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
| 445 |
-
<https://arxiv.org/abs/2103.14030>"
|
| 446 |
-
https://github.com/microsoft/Swin-Transformer
|
| 447 |
-
Args:
|
| 448 |
-
dims: dimension values.
|
| 449 |
-
window_size: local window size.
|
| 450 |
-
shift_size: shift size.
|
| 451 |
-
device: device.
|
| 452 |
-
"""
|
| 453 |
-
|
| 454 |
-
cnt = 0
|
| 455 |
-
|
| 456 |
-
if len(dims) == 3:
|
| 457 |
-
d, h, w = dims
|
| 458 |
-
img_mask = torch.zeros((1, d, h, w, 1), device=device)
|
| 459 |
-
for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
|
| 460 |
-
for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
|
| 461 |
-
for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):
|
| 462 |
-
img_mask[:, d, h, w, :] = cnt
|
| 463 |
-
cnt += 1
|
| 464 |
-
|
| 465 |
-
elif len(dims) == 2:
|
| 466 |
-
h, w = dims
|
| 467 |
-
img_mask = torch.zeros((1, h, w, 1), device=device)
|
| 468 |
-
for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
|
| 469 |
-
for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
|
| 470 |
-
img_mask[:, h, w, :] = cnt
|
| 471 |
-
cnt += 1
|
| 472 |
-
|
| 473 |
-
mask_windows = window_partition(img_mask, window_size)
|
| 474 |
-
mask_windows = mask_windows.squeeze(-1)
|
| 475 |
-
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
| 476 |
-
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
| 477 |
-
|
| 478 |
-
return attn_mask
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
class BasicLayer(nn.Module):
|
| 482 |
-
"""
|
| 483 |
-
Basic Swin Transformer layer in one stage based on: "Liu et al.,
|
| 484 |
-
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
| 485 |
-
<https://arxiv.org/abs/2103.14030>"
|
| 486 |
-
https://github.com/microsoft/Swin-Transformer
|
| 487 |
-
"""
|
| 488 |
-
|
| 489 |
-
def __init__(
|
| 490 |
-
self,
|
| 491 |
-
dim: int,
|
| 492 |
-
depth: int,
|
| 493 |
-
num_heads: int,
|
| 494 |
-
window_size: Sequence[int],
|
| 495 |
-
drop_path: list,
|
| 496 |
-
mlp_ratio: float = 4.0,
|
| 497 |
-
qkv_bias: bool = False,
|
| 498 |
-
drop: float = 0.0,
|
| 499 |
-
attn_drop: float = 0.0,
|
| 500 |
-
norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
|
| 501 |
-
downsample: isinstance = None, # type: ignore
|
| 502 |
-
use_checkpoint: bool = False,
|
| 503 |
-
) -> None:
|
| 504 |
-
"""
|
| 505 |
-
Args:
|
| 506 |
-
dim: number of feature channels.
|
| 507 |
-
depths: number of layers in each stage.
|
| 508 |
-
num_heads: number of attention heads.
|
| 509 |
-
window_size: local window size.
|
| 510 |
-
drop_path: stochastic depth rate.
|
| 511 |
-
mlp_ratio: ratio of mlp hidden dim to embedding dim.
|
| 512 |
-
qkv_bias: add a learnable bias to query, key, value.
|
| 513 |
-
drop: dropout rate.
|
| 514 |
-
attn_drop: attention dropout rate.
|
| 515 |
-
norm_layer: normalization layer.
|
| 516 |
-
downsample: downsample layer at the end of the layer.
|
| 517 |
-
use_checkpoint: use gradient checkpointing for reduced memory usage.
|
| 518 |
-
"""
|
| 519 |
-
|
| 520 |
-
super().__init__()
|
| 521 |
-
self.window_size = window_size
|
| 522 |
-
self.shift_size = tuple(i // 2 for i in window_size)
|
| 523 |
-
self.no_shift = tuple(0 for i in window_size)
|
| 524 |
-
self.depth = depth
|
| 525 |
-
self.use_checkpoint = use_checkpoint
|
| 526 |
-
self.blocks = nn.ModuleList(
|
| 527 |
-
[
|
| 528 |
-
SwinTransformerBlock(
|
| 529 |
-
dim=dim,
|
| 530 |
-
num_heads=num_heads,
|
| 531 |
-
window_size=self.window_size,
|
| 532 |
-
shift_size=self.no_shift if (i % 2 == 0) else self.shift_size,
|
| 533 |
-
mlp_ratio=mlp_ratio,
|
| 534 |
-
qkv_bias=qkv_bias,
|
| 535 |
-
drop=drop,
|
| 536 |
-
attn_drop=attn_drop,
|
| 537 |
-
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
| 538 |
-
norm_layer=norm_layer,
|
| 539 |
-
use_checkpoint=use_checkpoint,
|
| 540 |
-
)
|
| 541 |
-
for i in range(depth)
|
| 542 |
-
]
|
| 543 |
-
)
|
| 544 |
-
self.downsample = downsample
|
| 545 |
-
if self.downsample is not None:
|
| 546 |
-
self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))
|
| 547 |
-
|
| 548 |
-
def forward(self, x):
|
| 549 |
-
x_shape = x.size()
|
| 550 |
-
if len(x_shape) == 5:
|
| 551 |
-
b, c, d, h, w = x_shape
|
| 552 |
-
window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
|
| 553 |
-
x = rearrange(x, "b c d h w -> b d h w c")
|
| 554 |
-
dp = int(np.ceil(d / window_size[0])) * window_size[0]
|
| 555 |
-
hp = int(np.ceil(h / window_size[1])) * window_size[1]
|
| 556 |
-
wp = int(np.ceil(w / window_size[2])) * window_size[2]
|
| 557 |
-
attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device)
|
| 558 |
-
for blk in self.blocks:
|
| 559 |
-
x = blk(x, attn_mask)
|
| 560 |
-
x = x.view(b, d, h, w, -1)
|
| 561 |
-
if self.downsample is not None:
|
| 562 |
-
x = self.downsample(x)
|
| 563 |
-
x = rearrange(x, "b d h w c -> b c d h w")
|
| 564 |
-
|
| 565 |
-
elif len(x_shape) == 4:
|
| 566 |
-
b, c, h, w = x_shape
|
| 567 |
-
window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
|
| 568 |
-
x = rearrange(x, "b c h w -> b h w c")
|
| 569 |
-
hp = int(np.ceil(h / window_size[0])) * window_size[0]
|
| 570 |
-
wp = int(np.ceil(w / window_size[1])) * window_size[1]
|
| 571 |
-
attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)
|
| 572 |
-
for blk in self.blocks:
|
| 573 |
-
x = blk(x, attn_mask)
|
| 574 |
-
x = x.view(b, h, w, -1)
|
| 575 |
-
if self.downsample is not None:
|
| 576 |
-
x = self.downsample(x)
|
| 577 |
-
x = rearrange(x, "b h w c -> b c h w")
|
| 578 |
-
return x
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
class SwinTransformer(nn.Module):
|
| 582 |
-
"""
|
| 583 |
-
Swin Transformer based on: "Liu et al.,
|
| 584 |
-
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
| 585 |
-
<https://arxiv.org/abs/2103.14030>"
|
| 586 |
-
https://github.com/microsoft/Swin-Transformer
|
| 587 |
-
"""
|
| 588 |
-
|
| 589 |
-
def __init__(
|
| 590 |
-
self,
|
| 591 |
-
in_chans: int,
|
| 592 |
-
embed_dim: int,
|
| 593 |
-
window_size: Sequence[int],
|
| 594 |
-
patch_size: Sequence[int],
|
| 595 |
-
depths: Sequence[int],
|
| 596 |
-
num_heads: Sequence[int],
|
| 597 |
-
mlp_ratio: float = 4.0,
|
| 598 |
-
qkv_bias: bool = True,
|
| 599 |
-
drop_rate: float = 0.0,
|
| 600 |
-
attn_drop_rate: float = 0.0,
|
| 601 |
-
drop_path_rate: float = 0.0,
|
| 602 |
-
norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
|
| 603 |
-
patch_norm: bool = False,
|
| 604 |
-
use_checkpoint: bool = False,
|
| 605 |
-
spatial_dims: int = 3,
|
| 606 |
-
) -> None:
|
| 607 |
-
"""
|
| 608 |
-
Args:
|
| 609 |
-
in_chans: dimension of input channels.
|
| 610 |
-
embed_dim: number of linear projection output channels.
|
| 611 |
-
window_size: local window size.
|
| 612 |
-
patch_size: patch size.
|
| 613 |
-
depths: number of layers in each stage.
|
| 614 |
-
num_heads: number of attention heads.
|
| 615 |
-
mlp_ratio: ratio of mlp hidden dim to embedding dim.
|
| 616 |
-
qkv_bias: add a learnable bias to query, key, value.
|
| 617 |
-
drop_rate: dropout rate.
|
| 618 |
-
attn_drop_rate: attention dropout rate.
|
| 619 |
-
drop_path_rate: stochastic depth rate.
|
| 620 |
-
norm_layer: normalization layer.
|
| 621 |
-
patch_norm: add normalization after patch embedding.
|
| 622 |
-
use_checkpoint: use gradient checkpointing for reduced memory usage.
|
| 623 |
-
spatial_dims: spatial dimension.
|
| 624 |
-
"""
|
| 625 |
-
|
| 626 |
-
super().__init__()
|
| 627 |
-
self.num_layers = len(depths)
|
| 628 |
-
self.embed_dim = embed_dim
|
| 629 |
-
self.patch_norm = patch_norm
|
| 630 |
-
self.window_size = window_size
|
| 631 |
-
self.patch_size = patch_size
|
| 632 |
-
self.patch_embed = PatchEmbed(
|
| 633 |
-
patch_size=self.patch_size,
|
| 634 |
-
in_chans=in_chans,
|
| 635 |
-
embed_dim=embed_dim,
|
| 636 |
-
norm_layer=norm_layer if self.patch_norm else None, # type: ignore
|
| 637 |
-
spatial_dims=spatial_dims,
|
| 638 |
-
)
|
| 639 |
-
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 640 |
-
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
| 641 |
-
# self.layers1 = nn.ModuleList()
|
| 642 |
-
# self.layers2 = nn.ModuleList()
|
| 643 |
-
# self.layers3 = nn.ModuleList()
|
| 644 |
-
# self.layers4 = nn.ModuleList()
|
| 645 |
-
self.layers = nn.ModuleList()
|
| 646 |
-
for i_layer in range(self.num_layers):
|
| 647 |
-
layer = BasicLayer(
|
| 648 |
-
dim=int(embed_dim * 2**i_layer),
|
| 649 |
-
depth=depths[i_layer],
|
| 650 |
-
num_heads=num_heads[i_layer],
|
| 651 |
-
window_size=self.window_size,
|
| 652 |
-
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
|
| 653 |
-
mlp_ratio=mlp_ratio,
|
| 654 |
-
qkv_bias=qkv_bias,
|
| 655 |
-
drop=drop_rate,
|
| 656 |
-
attn_drop=attn_drop_rate,
|
| 657 |
-
norm_layer=norm_layer,
|
| 658 |
-
downsample=PatchMerging,
|
| 659 |
-
use_checkpoint=use_checkpoint,
|
| 660 |
-
)
|
| 661 |
-
self.layers.append(layer)
|
| 662 |
-
# if i_layer == 0:
|
| 663 |
-
# self.layers1.append(layer)
|
| 664 |
-
# elif i_layer == 1:
|
| 665 |
-
# self.layers2.append(layer)
|
| 666 |
-
# elif i_layer == 2:
|
| 667 |
-
# self.layers3.append(layer)
|
| 668 |
-
# elif i_layer == 3:
|
| 669 |
-
# self.layers4.append(layer)
|
| 670 |
-
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
|
| 671 |
-
|
| 672 |
-
def proj_out(self, x, normalize=False):
|
| 673 |
-
if normalize:
|
| 674 |
-
x_shape = x.size()
|
| 675 |
-
if len(x_shape) == 5:
|
| 676 |
-
n, ch, d, h, w = x_shape
|
| 677 |
-
x = rearrange(x, "n c d h w -> n d h w c")
|
| 678 |
-
x = F.layer_norm(x, [ch])
|
| 679 |
-
x = rearrange(x, "n d h w c -> n c d h w")
|
| 680 |
-
elif len(x_shape) == 4:
|
| 681 |
-
n, ch, h, w = x_shape
|
| 682 |
-
x = rearrange(x, "n c h w -> n h w c")
|
| 683 |
-
x = F.layer_norm(x, [ch])
|
| 684 |
-
x = rearrange(x, "n h w c -> n c h w")
|
| 685 |
-
return x
|
| 686 |
-
|
| 687 |
-
def forward(self, x, normalize=True):
|
| 688 |
-
# x input: [B*sample, C(1), H, W, D]
|
| 689 |
-
# x = rearrange(x, "b c h w d -> b c d h w")
|
| 690 |
-
# print('>> input: ', x.shape)
|
| 691 |
-
x = self.patch_embed(x)
|
| 692 |
-
# print('>> patch_embed: ', x.shape)
|
| 693 |
-
x = self.pos_drop(x)
|
| 694 |
-
for layer in self.layers:
|
| 695 |
-
x = layer(x.contiguous())
|
| 696 |
-
# print('>> layer: ', x.shape)
|
| 697 |
-
return x
|
| 698 |
-
# # x0_out = self.proj_out(x0, normalize)
|
| 699 |
-
# x1 = self.layers1[0](x0.contiguous())
|
| 700 |
-
# # x1_out = self.proj_out(x1, normalize)
|
| 701 |
-
# x2 = self.layers2[0](x1.contiguous())
|
| 702 |
-
# # x2_out = self.proj_out(x2, normalize)
|
| 703 |
-
# x3 = self.layers3[0](x2.contiguous())
|
| 704 |
-
# # x3_out = self.proj_out(x3, normalize)
|
| 705 |
-
# x4 = self.layers4[0](x3.contiguous())
|
| 706 |
-
# # x4_out = self.proj_out(x4, normalize)
|
| 707 |
-
# # return [x0_out, x1_out, x2_out, x3_out, x4_out]
|
| 708 |
-
|
| 709 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/segment_anything_volumetric/modeling/.ipynb_checkpoints/prompt_encoder-checkpoint.py
DELETED
|
@@ -1,232 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
import numpy as np
|
| 8 |
-
import torch
|
| 9 |
-
from torch import nn
|
| 10 |
-
|
| 11 |
-
from typing import Any, Optional, Tuple, Type
|
| 12 |
-
|
| 13 |
-
from .common import LayerNorm2d
|
| 14 |
-
import os
|
| 15 |
-
|
| 16 |
-
class PromptEncoder(nn.Module):
|
| 17 |
-
def __init__(
|
| 18 |
-
self,
|
| 19 |
-
embed_dim: int,
|
| 20 |
-
image_embedding_size: Tuple[int, int, int],
|
| 21 |
-
input_image_size: Tuple[int, int, int],
|
| 22 |
-
mask_in_chans: int,
|
| 23 |
-
activation: Type[nn.Module] = nn.GELU,
|
| 24 |
-
) -> None:
|
| 25 |
-
"""
|
| 26 |
-
Encodes prompts for input to SAM's mask decoder.
|
| 27 |
-
|
| 28 |
-
Arguments:
|
| 29 |
-
embed_dim (int): The prompts' embedding dimension
|
| 30 |
-
image_embedding_size (tuple(int, int)): The spatial size of the
|
| 31 |
-
image embedding, as (H, W).
|
| 32 |
-
input_image_size (int): The padded size of the image as input
|
| 33 |
-
to the image encoder, as (H, W).
|
| 34 |
-
mask_in_chans (int): The number of hidden channels used for
|
| 35 |
-
encoding input masks.
|
| 36 |
-
activation (nn.Module): The activation to use when encoding
|
| 37 |
-
input masks.
|
| 38 |
-
"""
|
| 39 |
-
super().__init__()
|
| 40 |
-
self.embed_dim = embed_dim
|
| 41 |
-
self.input_image_size = input_image_size
|
| 42 |
-
self.image_embedding_size = image_embedding_size
|
| 43 |
-
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
| 44 |
-
|
| 45 |
-
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
|
| 46 |
-
point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
|
| 47 |
-
self.point_embeddings = nn.ModuleList(point_embeddings)
|
| 48 |
-
self.not_a_point_embed = nn.Embedding(1, embed_dim)
|
| 49 |
-
|
| 50 |
-
self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1], 4 * image_embedding_size[2])
|
| 51 |
-
self.mask_downscaling = nn.Sequential(
|
| 52 |
-
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
|
| 53 |
-
LayerNorm2d(mask_in_chans // 4),
|
| 54 |
-
activation(),
|
| 55 |
-
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
|
| 56 |
-
LayerNorm2d(mask_in_chans),
|
| 57 |
-
activation(),
|
| 58 |
-
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
|
| 59 |
-
)
|
| 60 |
-
self.no_mask_embed = nn.Embedding(1, embed_dim)
|
| 61 |
-
|
| 62 |
-
def get_dense_pe(self) -> torch.Tensor:
|
| 63 |
-
"""
|
| 64 |
-
Returns the positional encoding used to encode point prompts,
|
| 65 |
-
applied to a dense set of points the shape of the image encoding.
|
| 66 |
-
|
| 67 |
-
Returns:
|
| 68 |
-
torch.Tensor: Positional encoding with shape
|
| 69 |
-
1x(embed_dim)x(embedding_h)x(embedding_w)
|
| 70 |
-
"""
|
| 71 |
-
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
| 72 |
-
|
| 73 |
-
def _embed_points(
|
| 74 |
-
self,
|
| 75 |
-
points: torch.Tensor,
|
| 76 |
-
labels: torch.Tensor,
|
| 77 |
-
pad: bool,
|
| 78 |
-
) -> torch.Tensor:
|
| 79 |
-
"""Embeds point prompts."""
|
| 80 |
-
points = points + 0.5 # Shift to center of pixel
|
| 81 |
-
if pad:
|
| 82 |
-
padding_point = torch.zeros((points.shape[0], 1, 3), device=points.device)
|
| 83 |
-
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
|
| 84 |
-
points = torch.cat([points, padding_point], dim=1)
|
| 85 |
-
labels = torch.cat([labels, padding_label], dim=1)
|
| 86 |
-
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
|
| 87 |
-
point_embedding[labels == -1] = 0.0
|
| 88 |
-
point_embedding[labels == -1] += self.not_a_point_embed.weight
|
| 89 |
-
point_embedding[labels == 0] += self.point_embeddings[0].weight
|
| 90 |
-
point_embedding[labels == 1] += self.point_embeddings[1].weight
|
| 91 |
-
return point_embedding
|
| 92 |
-
|
| 93 |
-
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
| 94 |
-
"""Embeds box prompts."""
|
| 95 |
-
boxes = boxes + 0.5 # Shift to center of pixel
|
| 96 |
-
coords = boxes.reshape(-1, 2, 3)
|
| 97 |
-
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
|
| 98 |
-
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
|
| 99 |
-
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
|
| 100 |
-
return corner_embedding
|
| 101 |
-
|
| 102 |
-
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
|
| 103 |
-
"""Embeds mask inputs."""
|
| 104 |
-
mask_embedding = self.mask_downscaling(masks)
|
| 105 |
-
return mask_embedding
|
| 106 |
-
|
| 107 |
-
def _get_batch_size(
|
| 108 |
-
self,
|
| 109 |
-
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
| 110 |
-
boxes: Optional[torch.Tensor],
|
| 111 |
-
masks: Optional[torch.Tensor],
|
| 112 |
-
text_embedding: Optional[torch.Tensor],
|
| 113 |
-
) -> int:
|
| 114 |
-
"""
|
| 115 |
-
Gets the batch size of the output given the batch size of the input prompts.
|
| 116 |
-
"""
|
| 117 |
-
if points is not None:
|
| 118 |
-
return points[0].shape[0]
|
| 119 |
-
elif boxes is not None:
|
| 120 |
-
return boxes.shape[0]
|
| 121 |
-
elif masks is not None:
|
| 122 |
-
return masks.shape[0]
|
| 123 |
-
elif text_embedding is not None:
|
| 124 |
-
return text_embedding.shape[0]
|
| 125 |
-
else:
|
| 126 |
-
return 1
|
| 127 |
-
|
| 128 |
-
def _get_device(self) -> torch.device:
|
| 129 |
-
return self.point_embeddings[0].weight.device
|
| 130 |
-
|
| 131 |
-
def forward(
|
| 132 |
-
self,
|
| 133 |
-
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
| 134 |
-
boxes: Optional[torch.Tensor],
|
| 135 |
-
masks: Optional[torch.Tensor],
|
| 136 |
-
text_embedding: Optional[torch.Tensor],
|
| 137 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 138 |
-
"""
|
| 139 |
-
Embeds different types of prompts, returning both sparse and dense
|
| 140 |
-
embeddings.
|
| 141 |
-
|
| 142 |
-
Arguments:
|
| 143 |
-
points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
|
| 144 |
-
and labels to embed.
|
| 145 |
-
boxes (torch.Tensor or none): boxes to embed
|
| 146 |
-
masks (torch.Tensor or none): masks to embed
|
| 147 |
-
text: test prompt (B, 768)
|
| 148 |
-
|
| 149 |
-
Returns:
|
| 150 |
-
torch.Tensor: sparse embeddings for the points and boxes, with shape
|
| 151 |
-
BxNx(embed_dim), where N is determined by the number of input points
|
| 152 |
-
and boxes.
|
| 153 |
-
torch.Tensor: dense embeddings for the masks, in the shape
|
| 154 |
-
Bx(embed_dim)x(embed_H)x(embed_W)
|
| 155 |
-
"""
|
| 156 |
-
# print('prompt encoder here...')
|
| 157 |
-
|
| 158 |
-
bs = self._get_batch_size(points, boxes, masks, text_embedding)
|
| 159 |
-
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
|
| 160 |
-
# print('sparse_embeddings ', sparse_embeddings.shape)
|
| 161 |
-
if points is not None:
|
| 162 |
-
coords, labels = points
|
| 163 |
-
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
|
| 164 |
-
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
|
| 165 |
-
|
| 166 |
-
if boxes is not None:
|
| 167 |
-
box_embeddings = self._embed_boxes(boxes)
|
| 168 |
-
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
|
| 169 |
-
|
| 170 |
-
if text_embedding is not None:
|
| 171 |
-
sparse_embeddings = torch.cat([sparse_embeddings, text_embedding.unsqueeze(dim=1)], dim=1)
|
| 172 |
-
|
| 173 |
-
# print('box_embeddings ', box_embeddings.shape)
|
| 174 |
-
# print('sparse_embeddings after box/point/text', sparse_embeddings.shape)
|
| 175 |
-
|
| 176 |
-
if masks is not None:
|
| 177 |
-
dense_embeddings = self._embed_masks(masks)
|
| 178 |
-
else:
|
| 179 |
-
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1, 1).expand(
|
| 180 |
-
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1], self.image_embedding_size[2]
|
| 181 |
-
)
|
| 182 |
-
# print('dense_embeddings ', dense_embeddings.shape)
|
| 183 |
-
return sparse_embeddings, dense_embeddings
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
class PositionEmbeddingRandom(nn.Module):
|
| 187 |
-
"""
|
| 188 |
-
Positional encoding using random spatial frequencies.
|
| 189 |
-
"""
|
| 190 |
-
|
| 191 |
-
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
|
| 192 |
-
super().__init__()
|
| 193 |
-
if scale is None or scale <= 0.0:
|
| 194 |
-
scale = 1.0
|
| 195 |
-
self.register_buffer(
|
| 196 |
-
"positional_encoding_gaussian_matrix",
|
| 197 |
-
scale * torch.randn((3, num_pos_feats)),
|
| 198 |
-
)
|
| 199 |
-
|
| 200 |
-
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
| 201 |
-
"""Positionally encode points that are normalized to [0,1]."""
|
| 202 |
-
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
| 203 |
-
coords = 2 * coords - 1
|
| 204 |
-
coords = coords @ self.positional_encoding_gaussian_matrix
|
| 205 |
-
coords = 2 * np.pi * coords
|
| 206 |
-
# outputs d_1 x ... x d_n x C shape
|
| 207 |
-
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
|
| 208 |
-
|
| 209 |
-
def forward(self, size: Tuple[int, int, int]) -> torch.Tensor:
|
| 210 |
-
"""Generate positional encoding for a grid of the specified size."""
|
| 211 |
-
h, w, d = size
|
| 212 |
-
device: Any = self.positional_encoding_gaussian_matrix.device
|
| 213 |
-
grid = torch.ones((h, w, d), device=device, dtype=torch.float32)
|
| 214 |
-
y_embed = grid.cumsum(dim=0) - 0.5
|
| 215 |
-
x_embed = grid.cumsum(dim=1) - 0.5
|
| 216 |
-
z_embed = grid.cumsum(dim=2) - 0.5
|
| 217 |
-
y_embed = y_embed / h
|
| 218 |
-
x_embed = x_embed / w
|
| 219 |
-
z_embed = z_embed / d
|
| 220 |
-
|
| 221 |
-
pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1))
|
| 222 |
-
return pe.permute(3, 0, 1, 2) # C x H x W x D
|
| 223 |
-
|
| 224 |
-
def forward_with_coords(
|
| 225 |
-
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
|
| 226 |
-
) -> torch.Tensor:
|
| 227 |
-
"""Positionally encode points that are not normalized to [0,1]."""
|
| 228 |
-
coords = coords_input.clone()
|
| 229 |
-
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
|
| 230 |
-
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
|
| 231 |
-
coords[:, :, 2] = coords[:, :, 2] / image_size[2]
|
| 232 |
-
return self._pe_encoding(coords.to(torch.float)) # B x N x C
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/segment_anything_volumetric/modeling/__init__.py
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
|
| 4 |
-
# This source code is licensed under the license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
from .sam import Sam
|
| 8 |
-
from .image_encoder import ImageEncoderViT
|
| 9 |
-
from .mask_decoder import MaskDecoder
|
| 10 |
-
from .prompt_encoder import PromptEncoder
|
| 11 |
-
from .transformer import TwoWayTransformer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-310.pyc
DELETED
|
Binary file (394 Bytes)
|
|
|
model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-39.pyc
DELETED
|
Binary file (433 Bytes)
|
|
|
model/segment_anything_volumetric/modeling/__pycache__/common.cpython-310.pyc
DELETED
|
Binary file (1.75 kB)
|
|
|
model/segment_anything_volumetric/modeling/__pycache__/common.cpython-39.pyc
DELETED
|
Binary file (1.78 kB)
|
|
|
model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-310.pyc
DELETED
|
Binary file (12.6 kB)
|
|
|
model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-39.pyc
DELETED
|
Binary file (11.5 kB)
|
|
|
model/segment_anything_volumetric/modeling/__pycache__/image_encoder_swin.cpython-39.pyc
DELETED
|
Binary file (21.5 kB)
|
|
|
model/segment_anything_volumetric/modeling/__pycache__/mask_decoder.cpython-310.pyc
DELETED
|
Binary file (5.5 kB)
|
|
|