Upload 2 files
Browse files- app.py +244 -142
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
import os
|
|
|
|
|
|
|
| 2 |
import tempfile
|
| 3 |
import threading
|
| 4 |
import time
|
|
@@ -23,22 +25,24 @@ class MRIInference:
|
|
| 23 |
self.output_shape = output_shape
|
| 24 |
|
| 25 |
def load_image(self, file_path):
|
| 26 |
-
# Load
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
| 32 |
scale = 255 / (max_val - min_val)
|
| 33 |
-
normalized_image = scale * (
|
|
|
|
| 34 |
scale_factors = (
|
| 35 |
self.input_shape[0] / normalized_image.shape[0],
|
| 36 |
self.input_shape[1] / normalized_image.shape[1],
|
| 37 |
self.input_shape[2] / normalized_image.shape[2]
|
| 38 |
)
|
| 39 |
-
resampled_image = zoom(normalized_image, scale_factors, order=
|
| 40 |
-
return torch.tensor(
|
| 41 |
-
resampled_image[np.newaxis, np.newaxis, ...], dtype=torch.float32)
|
| 42 |
|
| 43 |
def save_image(self, image, file_name):
|
| 44 |
# Save processed image to file
|
|
@@ -48,8 +52,7 @@ class MRIInference:
|
|
| 48 |
self.output_shape[1] / image.shape[1],
|
| 49 |
self.output_shape[2] / image.shape[2]
|
| 50 |
)
|
| 51 |
-
resampled_image = zoom(image, scale_factors, order=
|
| 52 |
-
resampled_image = np.rot90(resampled_image, k=-1, axes=(1, 2))
|
| 53 |
nib.save(nib.Nifti1Image(resampled_image, np.eye(4)), file_name)
|
| 54 |
|
| 55 |
def match_sform_affine(self, orig_path, gen_path):
|
|
@@ -61,62 +64,77 @@ class MRIInference:
|
|
| 61 |
matched_gen_img = nib.Nifti1Image(gen_data, orig_affine)
|
| 62 |
nib.save(matched_gen_img, gen_path)
|
| 63 |
|
| 64 |
-
def infer(self,
|
|
|
|
|
|
|
|
|
|
| 65 |
# Perform inference on input tensor
|
| 66 |
with torch.no_grad():
|
| 67 |
self.model.eval()
|
| 68 |
output = self.model(input_tensor.to(self.device))
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
resampled_output = zoom(
|
| 71 |
-
output.squeeze().cpu().numpy(), scale_factor, order=
|
| 72 |
generated_image = torch.tensor(resampled_output[np.newaxis, ...])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
temp_orig_path = os.path.join(output_path, 'temp_orig.nii.gz')
|
| 74 |
resampled_file_path = resample_to_isotropic(
|
| 75 |
original_file_path, temp_orig_path)
|
| 76 |
-
temp_generated_path = os.path.join(output_path, 'temp_generated.nii.gz')
|
| 77 |
-
self.save_image(generated_image, temp_generated_path)
|
| 78 |
self.match_sform_affine(resampled_file_path, temp_generated_path)
|
|
|
|
| 79 |
resampled_generated_path = os.path.join(output_path, 'resampled_generated.nii.gz')
|
| 80 |
resample_to_isotropic(temp_generated_path, resampled_generated_path)
|
|
|
|
| 81 |
base_name = os.path.basename(original_file_path)
|
| 82 |
gen_file_name = f"{Path(base_name).stem}_{int(time.time())}_gen.nii.gz"
|
| 83 |
warped_file_path = os.path.join(output_path, gen_file_name)
|
| 84 |
affine_registration(
|
| 85 |
-
resampled_file_path,
|
|
|
|
|
|
|
| 86 |
for temp_file in [temp_orig_path, temp_generated_path, resampled_generated_path]:
|
| 87 |
os.remove(temp_file)
|
|
|
|
| 88 |
return warped_file_path
|
| 89 |
-
|
| 90 |
-
# Perform inference and handle images
|
| 91 |
-
def run_inference(input_tensor, temp_file_path, output_path):
|
| 92 |
-
try:
|
| 93 |
-
warped_image_path = inference_engine.infer(
|
| 94 |
-
input_tensor, temp_file_path, output_path)
|
| 95 |
-
|
| 96 |
-
gen_file_name = temp_file_path.replace(".nii", "_gen.nii")
|
| 97 |
-
download_file_path = os.path.join(output_path, gen_file_name)
|
| 98 |
-
shutil.copy(warped_image_path, download_file_path)
|
| 99 |
-
|
| 100 |
-
original_img = nib.load(temp_file_path).get_fdata()
|
| 101 |
-
inferred_img = nib.load(warped_image_path).get_fdata()
|
| 102 |
-
|
| 103 |
-
original_slice_path = os.path.join(output_path, "original_slice.jpg")
|
| 104 |
-
inferred_slice_path = os.path.join(output_path, "inferred_slice.jpg")
|
| 105 |
-
save_middle_slice(original_img, original_slice_path)
|
| 106 |
-
save_middle_slice(inferred_img, inferred_slice_path)
|
| 107 |
-
|
| 108 |
-
return (original_slice_path, inferred_slice_path,
|
| 109 |
-
download_file_path, gen_file_name)
|
| 110 |
-
except Exception as e:
|
| 111 |
-
st.error(f"Error during inference: {e}")
|
| 112 |
-
return None, None, None, None
|
| 113 |
|
| 114 |
# Image processing functions
|
| 115 |
def resample_to_isotropic(image_path, output_path):
|
| 116 |
# Resample image to isotropic resolution
|
| 117 |
image = ants.image_read(image_path)
|
| 118 |
resampled_image = ants.resample_image(
|
| 119 |
-
image, (0.15, 0.15, 0.15), use_voxels=False, interp_type=
|
| 120 |
ants.image_write(resampled_image, output_path)
|
| 121 |
return output_path
|
| 122 |
|
|
@@ -126,33 +144,123 @@ def affine_registration(fixed_image_path, moving_image_path, output_path):
|
|
| 126 |
moving_image = ants.image_read(moving_image_path)
|
| 127 |
registration = ants.registration(
|
| 128 |
fixed=fixed_image, moving=moving_image,
|
| 129 |
-
type_of_transform='
|
| 130 |
ants.image_write(registration['warpedmovout'], output_path)
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
@st.cache_data
|
| 133 |
-
def load_model():
|
| 134 |
-
# Load pre-trained model
|
| 135 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 136 |
generator = ResnetGenerator().to(device)
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 139 |
generator.load_state_dict(checkpoint)
|
| 140 |
return generator, device
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
def save_middle_slice(image, file_path):
|
| 147 |
# Save the middle slice of the MRI image
|
| 148 |
middle_slice = image[image.shape[0] // 2]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
fig, ax = plt.subplots(figsize=(5, 5))
|
| 150 |
-
ax.imshow(
|
| 151 |
ax.axis('off')
|
| 152 |
plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
|
| 153 |
plt.savefig(file_path, format='jpg', bbox_inches='tight', pad_inches=0, dpi=500)
|
| 154 |
plt.close()
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
def clear_output_folder(folder_path):
|
| 157 |
# Clear contents of a specified folder
|
| 158 |
for filename in os.listdir(folder_path):
|
|
@@ -169,27 +277,46 @@ def clear_session():
|
|
| 169 |
|
| 170 |
# Main function for Streamlit UI
|
| 171 |
def main():
|
| 172 |
-
global original_slice_path, inferred_slice_path, download_file_path, gen_file_name
|
| 173 |
|
| 174 |
-
|
| 175 |
-
st.sidebar.subheader("_How to Use EasySR_", divider='red')
|
| 176 |
st.sidebar.markdown(
|
| 177 |
-
"
|
| 178 |
-
"
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
"
|
| 186 |
-
"
|
| 187 |
-
|
| 188 |
-
"Continue using EasySR for more enhancements.\n\n"
|
| 189 |
-
":rocket: :red[*EasySR*] \t [Github](https://github.com/hwonheo/easysr)\n\n"
|
| 190 |
-
":hugging_face: :orange[*EasySR*] \t [Huggingface](https://huggingface.co/spaces/hwonheo/easysr)"
|
| 191 |
)
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
# Main interface layout
|
| 194 |
st.markdown("<h1 style='text-align: center;'>EasySR</h1>", unsafe_allow_html=True)
|
| 195 |
st.subheader("_Easy Web UI for Generative 3D Inference of Rat Brain MRI_", divider='red')
|
|
@@ -206,88 +333,63 @@ def main():
|
|
| 206 |
# File uploader for MRI files
|
| 207 |
uploaded_file = st.file_uploader("_MRI File Upload (NIFTI)_",
|
| 208 |
type=["nii", "nii.gz"], key='file_uploader')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
if uploaded_file is not None:
|
| 211 |
# Store uploaded file in session state
|
| 212 |
st.session_state['uploaded_file'] = uploaded_file
|
| 213 |
file_name = uploaded_file.name
|
| 214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
# Inference start button
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
# Save middle slice of both images for comparison
|
| 247 |
-
save_middle_slice(original_img, original_slice_path)
|
| 248 |
-
save_middle_slice(inferred_img, inferred_slice_path)
|
| 249 |
-
except Exception as e:
|
| 250 |
-
st.error(f"Error during inference: {e}")
|
| 251 |
-
finally:
|
| 252 |
-
if temp_file_path and os.path.exists(temp_file_path):
|
| 253 |
-
os.remove(temp_file_path)
|
| 254 |
-
|
| 255 |
-
# Start thread for inference
|
| 256 |
-
inference_thread = threading.Thread(target=inference_wrapper)
|
| 257 |
-
inference_thread.start()
|
| 258 |
-
|
| 259 |
-
# Display spinner while processing
|
| 260 |
-
with st.spinner("Processing your MRI image..."):
|
| 261 |
-
inference_thread.join()
|
| 262 |
-
|
| 263 |
-
# Display comparison images and download button after processing
|
| 264 |
-
if original_slice_path and os.path.exists(original_slice_path) \
|
| 265 |
-
and inferred_slice_path and os.path.exists(inferred_slice_path):
|
| 266 |
-
st.subheader("Comparison of Original and EasySR Inferred Slice")
|
| 267 |
-
col1, col2 = st.columns([0.5, 0.5])
|
| 268 |
-
with col1:
|
| 269 |
-
st.markdown("**Original**")
|
| 270 |
-
st.image(original_slice_path, caption="Original MRI", width=300)
|
| 271 |
-
with col2:
|
| 272 |
-
st.markdown("**EasySR**")
|
| 273 |
-
st.image(inferred_slice_path, caption="Inferred MRI", width=300)
|
| 274 |
-
|
| 275 |
-
if download_file_path and os.path.exists(download_file_path):
|
| 276 |
-
with open(download_file_path, "rb") as file:
|
| 277 |
-
st.download_button(
|
| 278 |
-
label="Download (EasySR Inferred-MRI)",
|
| 279 |
-
data=file,
|
| 280 |
-
file_name=gen_file_name,
|
| 281 |
-
mime="application/gzip",
|
| 282 |
-
type="primary"
|
| 283 |
-
)
|
| 284 |
-
|
| 285 |
-
# Button to clear generated content
|
| 286 |
-
if st.button('Clear Generated All',
|
| 287 |
help='Pressing this will delete the contents of the generate folder.'):
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
|
| 292 |
# Entry point for the Streamlit application
|
| 293 |
if __name__ == '__main__':
|
|
|
|
| 1 |
import os
|
| 2 |
+
import sys
|
| 3 |
+
import subprocess
|
| 4 |
import tempfile
|
| 5 |
import threading
|
| 6 |
import time
|
|
|
|
| 25 |
self.output_shape = output_shape
|
| 26 |
|
| 27 |
def load_image(self, file_path):
|
| 28 |
+
# Load the image using nibabel
|
| 29 |
+
nib_image = nib.load(file_path)
|
| 30 |
+
|
| 31 |
+
image_data = nib_image.get_fdata()
|
| 32 |
+
rotated_image = np.rot90(image_data, k=1, axes=(1, 2))
|
| 33 |
+
|
| 34 |
+
# Standard normalization to 0-255
|
| 35 |
+
min_val, max_val = np.min(rotated_image), np.max(rotated_image)
|
| 36 |
scale = 255 / (max_val - min_val)
|
| 37 |
+
normalized_image = scale * (rotated_image - min_val)
|
| 38 |
+
|
| 39 |
scale_factors = (
|
| 40 |
self.input_shape[0] / normalized_image.shape[0],
|
| 41 |
self.input_shape[1] / normalized_image.shape[1],
|
| 42 |
self.input_shape[2] / normalized_image.shape[2]
|
| 43 |
)
|
| 44 |
+
resampled_image = zoom(normalized_image, scale_factors, order=3)
|
| 45 |
+
return torch.tensor(resampled_image[np.newaxis, np.newaxis, ...], dtype=torch.float32)
|
|
|
|
| 46 |
|
| 47 |
def save_image(self, image, file_name):
|
| 48 |
# Save processed image to file
|
|
|
|
| 52 |
self.output_shape[1] / image.shape[1],
|
| 53 |
self.output_shape[2] / image.shape[2]
|
| 54 |
)
|
| 55 |
+
resampled_image = zoom(image, scale_factors, order=3)
|
|
|
|
| 56 |
nib.save(nib.Nifti1Image(resampled_image, np.eye(4)), file_name)
|
| 57 |
|
| 58 |
def match_sform_affine(self, orig_path, gen_path):
|
|
|
|
| 64 |
matched_gen_img = nib.Nifti1Image(gen_data, orig_affine)
|
| 65 |
nib.save(matched_gen_img, gen_path)
|
| 66 |
|
| 67 |
+
def infer(self, aligned_image_path, original_file_path, output_path):
|
| 68 |
+
# Load and preprocess the image from aligned_image_path
|
| 69 |
+
input_tensor = self.load_image(aligned_image_path)
|
| 70 |
+
|
| 71 |
# Perform inference on input tensor
|
| 72 |
with torch.no_grad():
|
| 73 |
self.model.eval()
|
| 74 |
output = self.model(input_tensor.to(self.device))
|
| 75 |
+
|
| 76 |
+
# Resample output to target shape
|
| 77 |
+
scale_factor = (
|
| 78 |
+
self.output_shape[0] / output.shape[2],
|
| 79 |
+
self.output_shape[1] / output.shape[3],
|
| 80 |
+
self.output_shape[2] / output.shape[4]
|
| 81 |
+
)
|
| 82 |
resampled_output = zoom(
|
| 83 |
+
output.squeeze().cpu().numpy(), scale_factor, order=3)
|
| 84 |
generated_image = torch.tensor(resampled_output[np.newaxis, ...])
|
| 85 |
+
|
| 86 |
+
# Save the generated image
|
| 87 |
+
temp_generated_path = os.path.join(output_path, 'temp_generated.nii.gz')
|
| 88 |
+
self.save_image(generated_image, temp_generated_path)
|
| 89 |
+
|
| 90 |
+
# Get and print orientation code of the original image
|
| 91 |
+
orig_img = ants.image_read(original_file_path)
|
| 92 |
+
orig_orientation = ants.get_orientation(orig_img)
|
| 93 |
+
|
| 94 |
+
# Reorient the generated image based on original orientation
|
| 95 |
+
gen_img = nib.load(temp_generated_path)
|
| 96 |
+
gen_data = gen_img.get_fdata()
|
| 97 |
+
reoriented_image = ants.from_numpy(gen_data)
|
| 98 |
+
#print(f"Orientation of the original image: {orig_orientation}") ## print Orientation ##
|
| 99 |
+
|
| 100 |
+
if orig_orientation == 'LSP':
|
| 101 |
+
reoriented_image = ants.reorient_image2(reoriented_image, 'RAI')
|
| 102 |
+
elif orig_orientation == 'LPI':
|
| 103 |
+
reoriented_image = ants.reorient_image2(reoriented_image, 'RIP')
|
| 104 |
+
elif orig_orientation == 'RAS':
|
| 105 |
+
reoriented_image = ants.reorient_image2(reoriented_image, 'LSA')
|
| 106 |
+
# No reorientation for other cases
|
| 107 |
+
|
| 108 |
+
# Save the reoriented image
|
| 109 |
+
nib.save(nib.Nifti1Image(reoriented_image.numpy(), np.eye(4)), temp_generated_path)
|
| 110 |
+
|
| 111 |
+
# Match affine and resample
|
| 112 |
temp_orig_path = os.path.join(output_path, 'temp_orig.nii.gz')
|
| 113 |
resampled_file_path = resample_to_isotropic(
|
| 114 |
original_file_path, temp_orig_path)
|
|
|
|
|
|
|
| 115 |
self.match_sform_affine(resampled_file_path, temp_generated_path)
|
| 116 |
+
|
| 117 |
resampled_generated_path = os.path.join(output_path, 'resampled_generated.nii.gz')
|
| 118 |
resample_to_isotropic(temp_generated_path, resampled_generated_path)
|
| 119 |
+
|
| 120 |
base_name = os.path.basename(original_file_path)
|
| 121 |
gen_file_name = f"{Path(base_name).stem}_{int(time.time())}_gen.nii.gz"
|
| 122 |
warped_file_path = os.path.join(output_path, gen_file_name)
|
| 123 |
affine_registration(
|
| 124 |
+
resampled_file_path, temp_generated_path, warped_file_path)
|
| 125 |
+
|
| 126 |
+
# Remove temporary files
|
| 127 |
for temp_file in [temp_orig_path, temp_generated_path, resampled_generated_path]:
|
| 128 |
os.remove(temp_file)
|
| 129 |
+
|
| 130 |
return warped_file_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
# Image processing functions
|
| 133 |
def resample_to_isotropic(image_path, output_path):
|
| 134 |
# Resample image to isotropic resolution
|
| 135 |
image = ants.image_read(image_path)
|
| 136 |
resampled_image = ants.resample_image(
|
| 137 |
+
image, (0.15, 0.15, 0.15), use_voxels=False, interp_type=3)
|
| 138 |
ants.image_write(resampled_image, output_path)
|
| 139 |
return output_path
|
| 140 |
|
|
|
|
| 144 |
moving_image = ants.image_read(moving_image_path)
|
| 145 |
registration = ants.registration(
|
| 146 |
fixed=fixed_image, moving=moving_image,
|
| 147 |
+
type_of_transform='Rigid')
|
| 148 |
ants.image_write(registration['warpedmovout'], output_path)
|
| 149 |
|
| 150 |
+
def align_to_template(resampled_image_path, template_path, output_path):
|
| 151 |
+
# Align the resampled image to the template
|
| 152 |
+
moving_image = ants.image_read(resampled_image_path)
|
| 153 |
+
fixed_image = ants.image_read(template_path)
|
| 154 |
+
registration = ants.registration(
|
| 155 |
+
fixed=fixed_image, moving=moving_image,
|
| 156 |
+
type_of_transform='Rigid')
|
| 157 |
+
aligned_image = registration['warpedmovout']
|
| 158 |
+
ants.image_write(aligned_image, output_path)
|
| 159 |
+
return output_path
|
| 160 |
+
|
| 161 |
+
def download_model_if_needed(templates_folder):
|
| 162 |
+
"""Downloads model from Hugging Face if template folder is empty or doesn't exist."""
|
| 163 |
+
if not os.path.exists(templates_folder) or not os.listdir(templates_folder):
|
| 164 |
+
print("Downloading model from Hugging Face...")
|
| 165 |
+
os.makedirs(templates_folder, exist_ok=True)
|
| 166 |
+
subprocess.run(["huggingface-cli", "download", "hwonheo/easysr_templates",
|
| 167 |
+
"--local-dir", "templates", "--local-dir-use-symlinks", "False"], check=True)
|
| 168 |
+
|
| 169 |
@st.cache_data
|
| 170 |
+
def load_model(model_choice):
|
| 171 |
+
# Load pre-trained model based on user selection
|
| 172 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 173 |
generator = ResnetGenerator().to(device)
|
| 174 |
+
|
| 175 |
+
if model_choice == "T1-Model":
|
| 176 |
+
checkpoint_path = 'ckpt/ckpt_final/G_latest_T1.pth'
|
| 177 |
+
else: # "Mixed-Model"
|
| 178 |
+
checkpoint_path = 'ckpt/ckpt_final/G_latest_Mixed.pth'
|
| 179 |
+
|
| 180 |
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 181 |
generator.load_state_dict(checkpoint)
|
| 182 |
return generator, device
|
| 183 |
|
| 184 |
+
def run_bias_field_correction(file_path, output_path, correction_type):
|
| 185 |
+
"""Bias field correction script and return corrected file path"""
|
| 186 |
+
corrected_file_name = os.path.basename(file_path).replace('.nii', '_corrected.nii')
|
| 187 |
+
corrected_file_path = os.path.join(output_path, corrected_file_name)
|
| 188 |
+
|
| 189 |
+
subprocess.run([
|
| 190 |
+
sys.executable, "utils/BiasFieldCorrection.py",
|
| 191 |
+
"--input", file_path,
|
| 192 |
+
"--output", output_path,
|
| 193 |
+
"--type", correction_type
|
| 194 |
+
])
|
| 195 |
+
|
| 196 |
+
# Rename the processed file if necessary
|
| 197 |
+
original_corrected_file_path = os.path.join(output_path, os.path.basename(file_path))
|
| 198 |
+
if os.path.exists(original_corrected_file_path) and original_corrected_file_path != corrected_file_path:
|
| 199 |
+
shutil.move(original_corrected_file_path, corrected_file_path)
|
| 200 |
+
|
| 201 |
+
return corrected_file_path
|
| 202 |
+
|
| 203 |
+
# Perform inference and handle images
|
| 204 |
+
def run_inference(inference_engine, aligned_image_path, original_file_path, output_path):
|
| 205 |
+
try:
|
| 206 |
+
# Perform inference using the aligned image and original file path
|
| 207 |
+
warped_image_path = inference_engine.infer(aligned_image_path, original_file_path, output_path)
|
| 208 |
+
|
| 209 |
+
# Generate file name for output
|
| 210 |
+
gen_file_name = os.path.basename(original_file_path).replace(".nii", "_gen.nii")
|
| 211 |
+
download_file_path = os.path.join(output_path, gen_file_name)
|
| 212 |
+
|
| 213 |
+
# Copy the processed file to the download path
|
| 214 |
+
shutil.copy(warped_image_path, download_file_path)
|
| 215 |
+
|
| 216 |
+
# Load original and inferred images for display
|
| 217 |
+
original_img = nib.load(original_file_path).get_fdata()
|
| 218 |
+
inferred_img = nib.load(warped_image_path).get_fdata()
|
| 219 |
+
|
| 220 |
+
# Save middle slices of both images for comparison
|
| 221 |
+
original_slice_path = os.path.join(output_path, "original_slice.jpg")
|
| 222 |
+
inferred_slice_path = os.path.join(output_path, "inferred_slice.jpg")
|
| 223 |
+
save_middle_slice(original_img, original_slice_path)
|
| 224 |
+
save_middle_slice(inferred_img, inferred_slice_path)
|
| 225 |
+
|
| 226 |
+
# Return paths for UI display
|
| 227 |
+
return (original_slice_path, inferred_slice_path, download_file_path, gen_file_name)
|
| 228 |
+
except Exception as e:
|
| 229 |
+
st.error(f"Error during inference: {e}")
|
| 230 |
+
return None, None, None, None
|
| 231 |
|
| 232 |
def save_middle_slice(image, file_path):
|
| 233 |
# Save the middle slice of the MRI image
|
| 234 |
middle_slice = image[image.shape[0] // 2]
|
| 235 |
+
|
| 236 |
+
# Rotate the image 90 degrees counterclockwise
|
| 237 |
+
rotated_slice = np.rot90(middle_slice)
|
| 238 |
+
|
| 239 |
fig, ax = plt.subplots(figsize=(5, 5))
|
| 240 |
+
ax.imshow(rotated_slice, cmap='gray', aspect='auto')
|
| 241 |
ax.axis('off')
|
| 242 |
plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
|
| 243 |
plt.savefig(file_path, format='jpg', bbox_inches='tight', pad_inches=0, dpi=500)
|
| 244 |
plt.close()
|
| 245 |
|
| 246 |
+
def display_results(original_slice_path, inferred_slice_path, download_file_path, gen_file_name):
|
| 247 |
+
st.subheader("Comparison of Original and EasySR Inferred Slice")
|
| 248 |
+
col1, col2 = st.columns([0.5, 0.5])
|
| 249 |
+
with col1:
|
| 250 |
+
st.image(original_slice_path, caption="Original MRI", width=300)
|
| 251 |
+
with col2:
|
| 252 |
+
st.image(inferred_slice_path, caption="Inferred MRI", width=300)
|
| 253 |
+
|
| 254 |
+
if os.path.exists(download_file_path):
|
| 255 |
+
with open(download_file_path, "rb") as file:
|
| 256 |
+
st.download_button(
|
| 257 |
+
label="Download (EasySR Inferred-MRI)",
|
| 258 |
+
data=file,
|
| 259 |
+
file_name=gen_file_name,
|
| 260 |
+
mime="application/gzip",
|
| 261 |
+
type="primary"
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
def clear_output_folder(folder_path):
|
| 265 |
# Clear contents of a specified folder
|
| 266 |
for filename in os.listdir(folder_path):
|
|
|
|
| 277 |
|
| 278 |
# Main function for Streamlit UI
|
| 279 |
def main():
|
| 280 |
+
global original_slice_path, inferred_slice_path, download_file_path, gen_file_name, intensity_adjust
|
| 281 |
|
| 282 |
+
st.sidebar.markdown("# ")
|
|
|
|
| 283 |
st.sidebar.markdown(
|
| 284 |
+
"[]"
|
| 285 |
+
"(https://github.com/hwonheo/easysr)"
|
| 286 |
+
)
|
| 287 |
+
st.sidebar.markdown("# ")
|
| 288 |
+
|
| 289 |
+
# Setup sidebar with instructions and model selection
|
| 290 |
+
st.sidebar.subheader("*Model Selection*", divider='red')
|
| 291 |
+
model_choice = st.sidebar.selectbox(
|
| 292 |
+
"Choose the model type:",
|
| 293 |
+
("Mixed-Model", "T1-Model"),
|
| 294 |
+
index=1 # Default is Combined-Model
|
|
|
|
|
|
|
|
|
|
| 295 |
)
|
| 296 |
|
| 297 |
+
st.sidebar.header("\n")
|
| 298 |
+
|
| 299 |
+
# Setup sidebar with instructions
|
| 300 |
+
st.sidebar.subheader("_How to Use EasySR_", divider='red')
|
| 301 |
+
with st.sidebar.expander("Step-by-Step Guide:"):
|
| 302 |
+
st.markdown(
|
| 303 |
+
"1. **Prepare Your Data**: Make sure your rat brain MRI data "
|
| 304 |
+
"is in NIFTI format. Convert if needed.\n\n"
|
| 305 |
+
"2. **Upload Your MRI**: Drag and drop your NIFTI file "
|
| 306 |
+
"or use the upload button.\n\n"
|
| 307 |
+
"3. **Start the EasySR**: Click 'EasySR' to begin processing. "
|
| 308 |
+
"It usually takes a few minutes.\n\n"
|
| 309 |
+
"4. **Sit Back and Relax**: Wait while your data is processed quickly.\n\n"
|
| 310 |
+
"5. **View and Download**: After processing, view the results and "
|
| 311 |
+
"use the download button to save the enhanced MRI data.\n\n"
|
| 312 |
+
"6. **Use as Needed**: Download and utilize your enhanced MRI. "
|
| 313 |
+
"Continue using EasySR for more enhancements.\n\n"
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
# Initialize model and inference engine with the selected model
|
| 317 |
+
generator, device = load_model(model_choice)
|
| 318 |
+
inference_engine = MRIInference(generator, device, (128, 128, 64), (128, 128, 192))
|
| 319 |
+
|
| 320 |
# Main interface layout
|
| 321 |
st.markdown("<h1 style='text-align: center;'>EasySR</h1>", unsafe_allow_html=True)
|
| 322 |
st.subheader("_Easy Web UI for Generative 3D Inference of Rat Brain MRI_", divider='red')
|
|
|
|
| 333 |
# File uploader for MRI files
|
| 334 |
uploaded_file = st.file_uploader("_MRI File Upload (NIFTI)_",
|
| 335 |
type=["nii", "nii.gz"], key='file_uploader')
|
| 336 |
+
|
| 337 |
+
# Checkbox for intensity adjustment
|
| 338 |
+
intensity_adjust = st.checkbox("Bias Field Correction (enhance signal intensity)",
|
| 339 |
+
help="Apply intensity truncation and bias correction to an image: "
|
| 340 |
+
"Check this option if the input image exhibits low signal intensity "
|
| 341 |
+
"(common in T2RARE, TOF, etc.) or if the output from the inference "
|
| 342 |
+
"process appears weakly signaled. This will enhance the signals by "
|
| 343 |
+
"N4-bias correction and very low- or high-signal intensity truncation, "
|
| 344 |
+
"yielding clearer and more defined results.")
|
| 345 |
|
| 346 |
if uploaded_file is not None:
|
| 347 |
# Store uploaded file in session state
|
| 348 |
st.session_state['uploaded_file'] = uploaded_file
|
| 349 |
file_name = uploaded_file.name
|
| 350 |
|
| 351 |
+
# Temporary directory for file processing
|
| 352 |
+
temp_dir = tempfile.gettempdir()
|
| 353 |
+
temp_file_path = os.path.join(temp_dir, file_name)
|
| 354 |
+
|
| 355 |
+
# Write uploaded file to temp directory
|
| 356 |
+
with open(temp_file_path, "wb") as tmp_file:
|
| 357 |
+
tmp_file.write(uploaded_file.getvalue())
|
| 358 |
+
|
| 359 |
# Inference start button
|
| 360 |
+
if st.button("EasySR (start inference)", type="primary"):
|
| 361 |
+
try:
|
| 362 |
+
# Bias Field Correction
|
| 363 |
+
corrected_file_path = run_bias_field_correction(
|
| 364 |
+
temp_file_path, temp_dir, "abp") if intensity_adjust else temp_file_path
|
| 365 |
+
|
| 366 |
+
# Ensure template files are available
|
| 367 |
+
templates_folder = "templates"
|
| 368 |
+
download_model_if_needed(templates_folder)
|
| 369 |
+
template_path = os.path.join(templates_folder, "bmc_t2_rat.nii.gz")
|
| 370 |
+
|
| 371 |
+
# Resample and align the image
|
| 372 |
+
resampled_path = resample_to_isotropic(
|
| 373 |
+
corrected_file_path, os.path.join(temp_dir, "resampled.nii.gz"))
|
| 374 |
+
aligned_path = align_to_template(
|
| 375 |
+
resampled_path, template_path, os.path.join(temp_dir, "aligned.nii.gz"))
|
| 376 |
+
|
| 377 |
+
# Perform inference and process results
|
| 378 |
+
original_slice_path, inferred_slice_path, download_file_path, gen_file_name = run_inference(
|
| 379 |
+
inference_engine, aligned_path, corrected_file_path, output_path)
|
| 380 |
+
|
| 381 |
+
# Display results
|
| 382 |
+
display_results(original_slice_path, inferred_slice_path, download_file_path, gen_file_name)
|
| 383 |
+
|
| 384 |
+
except Exception as e:
|
| 385 |
+
st.error(f"Error during inference: {e}")
|
| 386 |
+
|
| 387 |
+
# Button to clear generated content
|
| 388 |
+
if st.button('Clear Generated All',
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
help='Pressing this will delete the contents of the generate folder.'):
|
| 390 |
+
clear_output_folder('infer/generate')
|
| 391 |
+
clear_session()
|
| 392 |
+
st.rerun()
|
| 393 |
|
| 394 |
# Entry point for the Streamlit application
|
| 395 |
if __name__ == '__main__':
|
requirements.txt
CHANGED
|
@@ -6,4 +6,5 @@ matplotlib
|
|
| 6 |
SimpleITK
|
| 7 |
torchio
|
| 8 |
antspyx
|
| 9 |
-
streamlit
|
|
|
|
|
|
| 6 |
SimpleITK
|
| 7 |
torchio
|
| 8 |
antspyx
|
| 9 |
+
streamlit
|
| 10 |
+
scikit-image
|