| | import os |
| | import sys |
| | import subprocess |
| | import tempfile |
| | import threading |
| | import time |
| | from pathlib import Path |
| | import shutil |
| | import ants |
| | import matplotlib.pyplot as plt |
| | import nibabel as nib |
| | import numpy as np |
| | import streamlit as st |
| | import torch |
| | from network.generator import ResnetGenerator |
| | from scipy.ndimage import zoom |
| |
|
| | |
class MRIInference:
    """Generator inference pipeline for a single MRI volume.

    The pipeline loads an aligned NIfTI volume, feeds it to ``model``,
    resamples the network output to ``output_shape`` and rigidly registers
    the result onto an isotropically resampled copy of the original scan.

    Attributes:
        model: Network called on the input tensor; ``eval()`` is invoked
            before the forward pass.
        device: Device the input tensor is moved to before inference.
        input_shape: Three sizes the loaded volume is zoomed to.
        output_shape: Three sizes the network output is zoomed to.
    """

    # Orientation code reported for the original scan -> orientation handed
    # to ``ants.reorient_image2`` for the generated volume.  Codes that are
    # not listed leave the generated volume untouched.
    _REORIENT_TARGETS = {'LSP': 'RAI', 'LPI': 'RIP', 'RAS': 'LSA'}

    def __init__(self, model, device, input_shape, output_shape):
        """Store the model, its device and the input/output volume shapes."""
        self.model = model
        self.device = device
        self.input_shape = input_shape
        self.output_shape = output_shape

    @staticmethod
    def _rescale_intensity(volume):
        """Linearly map ``volume`` so its minimum is 0 and its maximum 255.

        A constant volume has no intensity range to stretch; it is returned
        as all zeros instead of dividing by zero (which would fill the
        tensor with inf/NaN).
        """
        lowest, highest = np.min(volume), np.max(volume)
        span = highest - lowest
        if span == 0:
            return np.zeros_like(volume, dtype=float)
        return (255 / span) * (volume - lowest)

    @staticmethod
    def _zoom_factors(target_shape, current_shape):
        """Return the per-axis zoom factors for the first three axes."""
        return tuple(target_shape[axis] / current_shape[axis] for axis in range(3))

    def load_image(self, file_path):
        """Load a NIfTI file and return it as a ``(1, 1, *input_shape)`` tensor.

        The volume is rotated by 90 degrees in the plane of axes 1 and 2,
        rescaled to the 0-255 range and resampled to ``input_shape`` with
        cubic spline interpolation.
        """
        nib_image = nib.load(file_path)
        image_data = nib_image.get_fdata()
        rotated_image = np.rot90(image_data, k=1, axes=(1, 2))

        normalized_image = self._rescale_intensity(rotated_image)

        scale_factors = self._zoom_factors(self.input_shape, normalized_image.shape)
        resampled_image = zoom(normalized_image, scale_factors, order=3)
        return torch.tensor(resampled_image[np.newaxis, np.newaxis, ...], dtype=torch.float32)

    def save_image(self, image, file_name):
        """Resample ``image`` to ``output_shape`` and write it as NIfTI.

        ``image`` is a tensor whose singleton dimensions are squeezed away
        before resampling.  The file is written with an identity affine.
        """
        image = image.squeeze().cpu().numpy()
        scale_factors = self._zoom_factors(self.output_shape, image.shape)
        resampled_image = zoom(image, scale_factors, order=3)
        nib.save(nib.Nifti1Image(resampled_image, np.eye(4)), file_name)

    def match_sform_affine(self, orig_path, gen_path):
        """Rewrite ``gen_path`` in place using the affine of ``orig_path``."""
        orig_affine = nib.load(orig_path).affine
        gen_data = nib.load(gen_path).get_fdata()
        nib.save(nib.Nifti1Image(gen_data, orig_affine), gen_path)

    def infer(self, aligned_image_path, original_file_path, output_path):
        """Run the model on the aligned volume and register the result.

        Args:
            aligned_image_path: NIfTI file fed to the network.
            original_file_path: Scan the generated volume is mapped back to.
            output_path: Existing folder for the result and temporary files.

        Returns:
            Path of the registered, generated volume inside ``output_path``.

        Temporary files created in ``output_path`` are removed whether or
        not the pipeline succeeds.
        """
        input_tensor = self.load_image(aligned_image_path)

        with torch.no_grad():
            self.model.eval()
            output = self.model(input_tensor.to(self.device))

        # The output is indexed as 5-D; axes 2-4 are the spatial ones.
        scale_factor = self._zoom_factors(self.output_shape, output.shape[2:])
        resampled_output = zoom(
            output.squeeze().cpu().numpy(), scale_factor, order=3)
        generated_image = torch.tensor(resampled_output[np.newaxis, ...])

        temp_generated_path = os.path.join(output_path, 'temp_generated.nii.gz')
        temp_orig_path = os.path.join(output_path, 'temp_orig.nii.gz')
        resampled_generated_path = os.path.join(output_path, 'resampled_generated.nii.gz')

        try:
            self.save_image(generated_image, temp_generated_path)

            # Orientation of the original scan decides how the generated
            # volume is reoriented.
            orig_img = ants.image_read(original_file_path)
            orig_orientation = ants.get_orientation(orig_img)

            gen_data = nib.load(temp_generated_path).get_fdata()
            reoriented_image = ants.from_numpy(gen_data)
            target_orientation = self._REORIENT_TARGETS.get(orig_orientation)
            if target_orientation is not None:
                reoriented_image = ants.reorient_image2(
                    reoriented_image, target_orientation)

            nib.save(nib.Nifti1Image(reoriented_image.numpy(), np.eye(4)), temp_generated_path)

            # Give the generated volume the affine of the isotropically
            # resampled original so both live in the same physical space.
            resampled_file_path = resample_to_isotropic(
                original_file_path, temp_orig_path)
            self.match_sform_affine(resampled_file_path, temp_generated_path)

            # NOTE(review): this resampled copy is written and later deleted
            # but never read; kept to preserve the original behavior.
            resample_to_isotropic(temp_generated_path, resampled_generated_path)

            # Path.stem only strips the last suffix, so a ".nii.gz" input
            # keeps ".nii" inside the generated name.
            base_name = os.path.basename(original_file_path)
            gen_file_name = f"{Path(base_name).stem}_{int(time.time())}_gen.nii.gz"
            warped_file_path = os.path.join(output_path, gen_file_name)
            affine_registration(
                resampled_file_path, temp_generated_path, warped_file_path)
        finally:
            # Fixed temp names would otherwise linger in the output folder
            # after a failed run.
            for temp_file in (temp_orig_path, temp_generated_path, resampled_generated_path):
                if os.path.exists(temp_file):
                    os.remove(temp_file)

        return warped_file_path
| |
|
| | |
def resample_to_isotropic(image_path, output_path, spacing=(0.15, 0.15, 0.15)):
    """Resample an image to the given voxel spacing and write it to disk.

    Args:
        image_path: Image file read with ``ants.image_read``.
        output_path: Destination file for the resampled image.
        spacing: Target spacing per axis, passed to ``ants.resample_image``
            with ``use_voxels=False`` so it is interpreted as spacing rather
            than a voxel count.  Defaults to 0.15 on every axis, the value
            that used to be hard-coded.

    Returns:
        ``output_path``, so the call can be chained.
    """
    image = ants.image_read(image_path)
    # interp_type=3 is kept from the original call; see the ANTsPy
    # resample_image documentation for the interpolator codes.
    resampled_image = ants.resample_image(
        image, tuple(spacing), use_voxels=False, interp_type=3)
    ants.image_write(resampled_image, output_path)
    return output_path
| |
|
def affine_registration(fixed_image_path, moving_image_path, output_path):
    """Register the moving image onto the fixed image and save the result.

    Despite the function name, the transform requested from ANTs is
    ``'Rigid'``.  The warped moving image is written to ``output_path``;
    nothing is returned.
    """
    fixed_image = ants.image_read(fixed_image_path)
    moving_image = ants.image_read(moving_image_path)
    result = ants.registration(
        fixed=fixed_image,
        moving=moving_image,
        type_of_transform='Rigid',
    )
    warped = result['warpedmovout']
    ants.image_write(warped, output_path)
| |
|
def align_to_template(resampled_image_path, template_path, output_path):
    """Rigidly align an image to a template and write the aligned image.

    The image at ``resampled_image_path`` is the moving image and the
    template is the fixed image.  Returns ``output_path``.
    """
    moving_image = ants.image_read(resampled_image_path)
    fixed_image = ants.image_read(template_path)
    result = ants.registration(
        fixed=fixed_image,
        moving=moving_image,
        type_of_transform='Rigid',
    )
    ants.image_write(result['warpedmovout'], output_path)
    return output_path
| |
|
def download_model_if_needed(templates_folder):
    """Downloads model from Hugging Face if template folder is empty or doesn't exist.

    Args:
        templates_folder: Folder expected to hold the template files.  It is
            created when missing and used as the download target.

    Raises:
        subprocess.CalledProcessError: If ``huggingface-cli`` exits with a
            non-zero status.
    """
    if os.path.exists(templates_folder) and os.listdir(templates_folder):
        return  # Something is already there; nothing to download.

    print("Downloading model from Hugging Face...")
    os.makedirs(templates_folder, exist_ok=True)
    # Download into the folder that was checked above, not a hard-coded
    # "templates" directory, so a custom location is actually populated.
    subprocess.run(["huggingface-cli", "download", "hwonheo/easysr_templates",
                    "--local-dir", templates_folder, "--local-dir-use-symlinks", "False"], check=True)
| |
|
# st.cache_resource (rather than st.cache_data) is Streamlit's cache for
# objects such as ML models: the loaded generator is stored once and the same
# instance is handed back, instead of being serialized and copied per call.
@st.cache_resource
def load_model(model_choice):
    """Build the generator and load the checkpoint for ``model_choice``.

    Args:
        model_choice: ``"T1-Model"`` selects the T1 checkpoint; any other
            value selects the mixed checkpoint.

    Returns:
        A ``(generator, device)`` tuple.  The device is ``cuda:0`` when CUDA
        is available and ``cpu`` otherwise.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    generator = ResnetGenerator().to(device)

    if model_choice == "T1-Model":
        checkpoint_path = 'ckpt/ckpt_final/G_latest_T1.pth'
    else:
        checkpoint_path = 'ckpt/ckpt_final/G_latest_Mixed.pth'

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint)
    return generator, device
| |
|
def run_bias_field_correction(file_path, output_path, correction_type):
    """Bias field correction script and return corrected file path

    Runs ``utils/BiasFieldCorrection.py`` with the current interpreter.  If
    the script leaves a file named like the input inside ``output_path``,
    that file is renamed to carry a ``_corrected`` marker.

    Args:
        file_path: Input NIfTI file.
        output_path: Folder passed to the script as ``--output``.
        correction_type: Value passed to the script as ``--type``.

    Returns:
        Path of the ``*_corrected`` file inside ``output_path``.

    Raises:
        subprocess.CalledProcessError: If the correction script exits with a
            non-zero status.  Previously the failure was ignored and a path
            to a possibly missing file was returned.
    """
    source_name = os.path.basename(file_path)
    corrected_file_name = source_name.replace('.nii', '_corrected.nii')
    corrected_file_path = os.path.join(output_path, corrected_file_name)

    subprocess.run([
        sys.executable, "utils/BiasFieldCorrection.py",
        "--input", file_path,
        "--output", output_path,
        "--type", correction_type
    ], check=True)

    # Rename an output that kept the input's file name.
    original_corrected_file_path = os.path.join(output_path, source_name)
    if os.path.exists(original_corrected_file_path) and original_corrected_file_path != corrected_file_path:
        shutil.move(original_corrected_file_path, corrected_file_path)

    return corrected_file_path
| |
|
| | |
def run_inference(inference_engine, aligned_image_path, original_file_path, output_path):
    """Run inference and prepare the preview images and download file.

    Returns a tuple ``(original_slice_path, inferred_slice_path,
    download_file_path, gen_file_name)``.  Any exception is reported with
    ``st.error`` and four ``None`` values are returned instead.
    """
    try:
        warped_image_path = inference_engine.infer(
            aligned_image_path, original_file_path, output_path)

        # Copy the result under a name derived from the input file.
        source_name = os.path.basename(original_file_path)
        gen_file_name = source_name.replace(".nii", "_gen.nii")
        download_file_path = os.path.join(output_path, gen_file_name)
        shutil.copy(warped_image_path, download_file_path)

        # Load both volumes before any preview is written.
        volumes = [
            nib.load(volume_path).get_fdata()
            for volume_path in (original_file_path, warped_image_path)
        ]
        preview_paths = [
            os.path.join(output_path, preview_name)
            for preview_name in ("original_slice.jpg", "inferred_slice.jpg")
        ]
        for volume, preview_path in zip(volumes, preview_paths):
            save_middle_slice(volume, preview_path)

        original_slice_path, inferred_slice_path = preview_paths
        return (original_slice_path, inferred_slice_path, download_file_path, gen_file_name)
    except Exception as e:
        st.error(f"Error during inference: {e}")
        return None, None, None, None
| |
|
def save_middle_slice(image, file_path):
    """Save the central slice along axis 0 of ``image`` as a JPEG.

    The slice is rotated by 90 degrees and drawn in grayscale without axes
    or padding on a 5x5 inch figure at 500 dpi.
    """
    center_index = image.shape[0] // 2
    rotated_slice = np.rot90(image[center_index])

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(rotated_slice, cmap='gray', aspect='auto')
    ax.axis('off')
    # Let the image fill the whole canvas.
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    fig.savefig(file_path, format='jpg', bbox_inches='tight', pad_inches=0, dpi=500)
    plt.close(fig)
| |
|
def display_results(original_slice_path, inferred_slice_path, download_file_path, gen_file_name):
    """Show the original and inferred previews side by side.

    A download button for the generated file is added only when
    ``download_file_path`` exists on disk.
    """
    st.subheader("Comparison of Original and EasySR Inferred Slice")
    left_column, right_column = st.columns([0.5, 0.5])
    panels = (
        (left_column, original_slice_path, "Original MRI"),
        (right_column, inferred_slice_path, "Inferred MRI"),
    )
    for column, slice_path, caption in panels:
        with column:
            st.image(slice_path, caption=caption, width=300)

    if not os.path.exists(download_file_path):
        return

    with open(download_file_path, "rb") as payload:
        st.download_button(
            label="Download (EasySR Inferred-MRI)",
            data=payload,
            file_name=gen_file_name,
            mime="application/gzip",
            type="primary",
        )
| |
|
def clear_output_folder(folder_path):
    """Delete everything inside ``folder_path`` but keep the folder itself.

    Regular files and symbolic links are unlinked; directories are removed
    recursively.  Entries of any other kind are left alone.
    """
    for entry_name in os.listdir(folder_path):
        entry_path = os.path.join(folder_path, entry_name)
        is_file_or_link = os.path.isfile(entry_path) or os.path.islink(entry_path)
        if is_file_or_link:
            os.unlink(entry_path)
            continue
        if os.path.isdir(entry_path):
            shutil.rmtree(entry_path)
| |
|
def clear_session():
    """Remove every key currently stored in the Streamlit session state."""
    # Snapshot the keys first so the state is not mutated while iterating.
    stale_keys = tuple(st.session_state.keys())
    for stale_key in stale_keys:
        del st.session_state[stale_key]
| |
|
| | |
def main():
    """Build the EasySR Streamlit page and run the pipeline on request.

    Page flow:
      * sidebar: project link, model selector and a usage guide;
      * main area: NIfTI uploader and the bias-field-correction checkbox;
      * "EasySR (start inference)": optional bias correction, isotropic
        resampling, alignment to the template, generator inference, then the
        side-by-side preview with a download button;
      * "Clear Generated All": empties the output folder, resets the session
        state and reruns the script.

    Intermediate files of a run live in a private temporary folder that is
    removed when the run ends.
    """
    # NOTE(review): these names are only used inside main(); the global
    # declaration is kept so module-level state stays as before.
    global original_slice_path, inferred_slice_path, download_file_path, gen_file_name, intensity_adjust

    st.sidebar.markdown("# ")
    st.sidebar.markdown(
        "[]"
        "(https://github.com/hwonheo/easysr)"
    )
    st.sidebar.markdown("# ")

    # Model selection.
    st.sidebar.subheader("*Model Selection*", divider='red')
    model_choice = st.sidebar.selectbox(
        "Choose the model type:",
        ("Mixed-Model", "T1-Model"),
        index=1
    )

    st.sidebar.header("\n")

    # Usage guide.
    st.sidebar.subheader("_How to Use EasySR_", divider='red')
    with st.sidebar.expander("Step-by-Step Guide:"):
        st.markdown(
            "1. **Prepare Your Data**: Make sure your rat brain MRI data "
            "is in NIFTI format. Convert if needed.\n\n"
            "2. **Upload Your MRI**: Drag and drop your NIFTI file "
            "or use the upload button.\n\n"
            "3. **Start the EasySR**: Click 'EasySR' to begin processing. "
            "It usually takes a few minutes.\n\n"
            "4. **Sit Back and Relax**: Wait while your data is processed quickly.\n\n"
            "5. **View and Download**: After processing, view the results and "
            "use the download button to save the enhanced MRI data.\n\n"
            "6. **Use as Needed**: Download and utilize your enhanced MRI. "
            "Continue using EasySR for more enhancements.\n\n"
        )

    # Model and inference engine.
    generator, device = load_model(model_choice)
    inference_engine = MRIInference(generator, device, (128, 128, 64), (128, 128, 192))

    # Page title.
    st.markdown("<h1 style='text-align: center;'>EasySR</h1>", unsafe_allow_html=True)
    st.subheader("_Easy Web UI for Generative 3D Inference of Rat Brain MRI_", divider='red')

    original_slice_path = None
    inferred_slice_path = None
    download_file_path = None

    output_path = "infer/generate"
    os.makedirs(output_path, exist_ok=True)

    uploaded_file = st.file_uploader("_MRI File Upload (NIFTI)_",
                                     type=["nii", "nii.gz"], key='file_uploader')

    intensity_adjust = st.checkbox("Bias Field Correction (enhance signal intensity)",
                                   help="Apply intensity truncation and bias correction to an image: "
                                        "Check this option if the input image exhibits low signal intensity "
                                        "(common in T2RARE, TOF, etc.) or if the output from the inference "
                                        "process appears weakly signaled. This will enhance the signals by "
                                        "N4-bias correction and very low- or high-signal intensity truncation, "
                                        "yielding clearer and more defined results.")

    if uploaded_file is not None:
        st.session_state['uploaded_file'] = uploaded_file
        # The name comes from the client: keep only its final component so
        # it cannot point outside the working folder.
        file_name = os.path.basename(uploaded_file.name)

        if st.button("EasySR (start inference)", type="primary"):
            # A private folder per run: fixed names in the shared system temp
            # dir would collide between concurrent sessions and were never
            # cleaned up.
            temp_dir = tempfile.mkdtemp(prefix="easysr_")
            try:
                temp_file_path = os.path.join(temp_dir, file_name)
                with open(temp_file_path, "wb") as tmp_file:
                    tmp_file.write(uploaded_file.getvalue())

                # Optional bias field correction.
                corrected_file_path = run_bias_field_correction(
                    temp_file_path, temp_dir, "abp") if intensity_adjust else temp_file_path

                # Template used for alignment (downloaded on first use).
                templates_folder = "templates"
                download_model_if_needed(templates_folder)
                template_path = os.path.join(templates_folder, "bmc_t2_rat.nii.gz")

                # Resample, then align to the template.
                resampled_path = resample_to_isotropic(
                    corrected_file_path, os.path.join(temp_dir, "resampled.nii.gz"))
                aligned_path = align_to_template(
                    resampled_path, template_path, os.path.join(temp_dir, "aligned.nii.gz"))

                original_slice_path, inferred_slice_path, download_file_path, gen_file_name = run_inference(
                    inference_engine, aligned_path, corrected_file_path, output_path)

                # run_inference reports its own failure and returns None
                # values; showing them would only raise a second error.
                if download_file_path is not None:
                    display_results(original_slice_path, inferred_slice_path, download_file_path, gen_file_name)

            except Exception as e:
                st.error(f"Error during inference: {e}")
            finally:
                shutil.rmtree(temp_dir, ignore_errors=True)

    if st.button('Clear Generated All',
                 help='Pressing this will delete the contents of the generate folder.'):
        clear_output_folder(output_path)
        clear_session()
        st.rerun()
| |
|
| | |
# Script entry point: build the page only when this file is executed directly.
if __name__ == "__main__":
    main()
| |
|
| |
|