Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import numpy as np | |
| import skimage | |
| from skimage import io | |
| import torch | |
| import monai | |
| from monai.transforms import Rotate | |
# Placeholder for the 3D reconstruction model
class Simple3DReconstructionModel:
    """Placeholder reconstruct -> rotate -> project pipeline used by the demo.

    ``reconstruct_3d`` is a stub that returns an all-zero 128x128x128 volume;
    swap in a real model in ``__init__`` / ``reconstruct_3d``.
    """

    def __init__(self):
        # Load your pre-trained model here
        self.model = None  # replace with actual model loading

    def reconstruct_3d(self, image):
        """Return a 3D volume for ``image``.

        Placeholder: ignores ``image`` and returns ``np.zeros((128, 128, 128))``.
        """
        return np.zeros((128, 128, 128))

    def rotate_3d(self, volume, angles):
        """Rotate a 3D volume with MONAI's ``Rotate`` transform.

        Args:
            volume: Spatial-only 3D array (no channel axis), e.g. the output
                of ``reconstruct_3d``.
            angles: Three rotation angles in radians (MONAI's convention).

        Returns:
            The rotated volume as a numpy array with the same shape as
            ``volume``.
        """
        rotate = Rotate(angles, mode='bilinear')
        # MONAI transforms take channel-first input ([C, H, W, D]); without a
        # leading channel axis a (D, H, W) array would be read as D channels
        # of 2D images, which does not match three rotation angles.
        rotated_volume = rotate(volume[None])
        # MONAI returns a torch tensor (MetaTensor); convert back to numpy so
        # downstream numpy reductions and Gradio receive a plain array.
        if isinstance(rotated_volume, torch.Tensor):
            rotated_volume = rotated_volume.detach().cpu().numpy()
        # Drop the channel axis added above.
        return np.asarray(rotated_volume)[0]

    def project_2d(self, volume):
        """Project a 3D volume to 2D by taking the maximum along axis 0.

        Accepts numpy arrays or CPU torch tensors; ``np.asarray`` avoids
        dispatching ``np.max`` to ``Tensor.max``, which returns a
        (values, indices) pair instead of an array.
        """
        return np.max(np.asarray(volume), axis=0)
# Module-level model instance; process_image reads it on every request.
model: Simple3DReconstructionModel = Simple3DReconstructionModel()
# Gradio helper functions
def process_image(img, xt, yt, zt):
    """Reconstruct ``img`` to 3D, rotate it, and project it back to 2D.

    Args:
        img: Input 2D image as a numpy array.
        xt, yt, zt: Rotation angles in degrees, as set by the UI sliders
            (range -15..15, step 5).

    Returns:
        The 2D projection of the rotated volume.
    """
    volume = model.reconstruct_3d(img)
    # The sliders are in degrees, but MONAI's Rotate (used by rotate_3d)
    # expects radians; passing 15 unconverted would rotate by ~859 degrees.
    angles = tuple(float(a) for a in np.deg2rad((xt, yt, zt)))
    rotated_volume = model.rotate_3d(volume, angles)
    return model.project_2d(rotated_volume)
# Anchor points (x, value) of the "bone" colormap per RGB channel, i.e. the
# same piecewise-linear segment data matplotlib's `bone` uses.
_BONE_SEGMENTS = {
    "red": ((0.0, 0.0), (0.746032, 0.652778), (1.0, 1.0)),
    "green": ((0.0, 0.0), (0.365079, 0.319444), (0.746032, 0.777778), (1.0, 1.0)),
    "blue": ((0.0, 0.0), (0.365079, 0.444444), (1.0, 1.0)),
}


def _apply_bone_cmap(img):
    """Map intensities to an RGB uint8 image using a bone colormap.

    Values are clipped to [0, 1] (NaN treated as 0) and interpolated
    linearly between the ``_BONE_SEGMENTS`` anchor points; this closely
    approximates matplotlib's 256-entry ``bone`` lookup table.
    """
    values = np.clip(np.nan_to_num(np.asarray(img, dtype=np.float64)), 0.0, 1.0)
    channels = [
        np.interp(values, *zip(*_BONE_SEGMENTS[name]))
        for name in ("red", "green", "blue")
    ]
    return (np.stack(channels, axis=-1) * 255).astype(np.uint8)


def rotate_btn_fn(img, xt, yt, zt, add_bone_cmap=False):
    """Gradio callback for the "Rotate!" button.

    Args:
        img: Input image, either a numpy array (what ``gr.Image(type='numpy')``
            delivers) or a path to an existing image file.
        xt, yt, zt: Rotation angles, forwarded unchanged to ``process_image``.
        add_bone_cmap: If True, return an RGB uint8 image colored with a bone
            colormap instead of the raw projection.

    Returns:
        The output image, or ``None`` if anything fails. Errors are printed
        rather than raised, so the UI simply shows an empty output.
    """
    try:
        angles = (xt, yt, zt)
        print(f"Rotating with angles: {angles}")
        # The array is used directly; the old copy to a fixed
        # "uploaded_image.png" was never read back and raced between
        # concurrent requests.
        if isinstance(img, str) and os.path.exists(img):
            img = skimage.io.imread(img)
        elif not isinstance(img, np.ndarray):
            raise ValueError("Invalid input image")
        # Process the image with the model
        out_img = process_image(img, xt, yt, zt)
        if not add_bone_cmap:
            return out_img
        # Previously used matplotlib's `plt`, which was never imported.
        return _apply_bone_cmap(out_img)
    except Exception as e:
        # Deliberate best-effort boundary for the UI callback.
        print(f"Error in rotate_btn_fn: {e}")
        return None
css_style = "./style.css"
callback = gr.CSVLogger()  # NOTE(review): not attached to any component below.

examples_dir = "./data/examples"
# Sorted so the example gallery order is stable; os.listdir order is arbitrary.
example_paths = [
    os.path.join(examples_dir, f)
    for f in sorted(os.listdir(examples_dir))
    if "xr" in f
]

with gr.Blocks(css=css_style, title="RadRotator") as app:
    gr.HTML("RadRotator: 3D Rotation of Radiographs with Diffusion Models", elem_classes="title")
    gr.HTML("Developed by:<br>Pouria Rouzrokh, Bardia Khosravi, Shahriar Faghani, Kellen Mulford, Michael J. Taunton, Bradley J. Erickson, Cody C. Wyles<br><a href='https://pouriarouzrokh.github.io/RadRotator'>[Our website]</a>, <a href='https://arxiv.org/abs/2404.13000'>[arXiv Paper]</a>", elem_classes="note")
    gr.HTML("Note: The demo operates on a CPU, and since diffusion models require more computational capacity to function, all predictions are precomputed.", elem_classes="note")
    with gr.TabItem("Demo"):
        with gr.Row():
            input_img = gr.Image(type='numpy', label='Input image', interactive=True, elem_classes='imgs')
            output_img = gr.Image(type='numpy', label='Output image', interactive=False, elem_classes='imgs')
        with gr.Row():
            # Gradio expects integer `scale`; 1:4:1 keeps the original
            # 0.25:1:0.25 proportions for the centered examples column.
            with gr.Column(scale=1):
                pass
            with gr.Column(scale=4):
                gr.Examples(
                    examples=example_paths,
                    inputs=[input_img],
                    label="Xray Examples",
                    elem_id='examples',
                )
            with gr.Column(scale=1):
                pass
        with gr.Row():
            gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')
        with gr.Row():
            with gr.Column(scale=1):
                xt = gr.Slider(label='x axis (medial/lateral rotation):', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
            with gr.Column(scale=1):
                yt = gr.Slider(label='y axis (inlet/outlet rotation):', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
            with gr.Column(scale=1):
                zt = gr.Slider(label='z axis (plane rotation):', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
        with gr.Row():
            rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
        rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)

# Best-effort cleanup of any previously running Gradio servers before launch.
try:
    app.close()
    gr.close_all()
except Exception as e:
    print(f"Error closing app: {e}")

demo = app.launch(
    max_threads=4,
    share=True,
    inline=False,
    show_api=False,
    show_error=False,
)