# Hugging Face Space status when this file was captured: Runtime error
| import gradio as gr | |
| import torch | |
| from diffusers import StableDiffusionInstructPix2PixPipeline | |
| from PIL import Image | |
print("Loading models...")

# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _load_pipeline(repo_id):
    """Load the InstructPix2Pix pipeline stored at *repo_id* onto ``device``.

    Weights are loaded as float32 on every device.
    """
    # NOTE(review): float16 on CUDA would roughly halve GPU memory use, but may
    # change outputs slightly — evaluate before switching.
    pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        repo_id,
        torch_dtype=torch.float32,
    )
    return pipeline.to(device)


# MR → CT translation model (loaded first, as before).
pipe_mr2ct = _load_pipeline("icyriss/MR2CT-model")
# CT → MRI translation model.
pipe_ct2mr = _load_pipeline("icyriss/CT2MRI-model")

print("Models loaded")
# Radio-button labels for the two translation directions. translate() and the
# UI below both use these constants so the labels cannot drift apart.
TASK_MR2CT = "MRI → CT"
TASK_CT2MR = "CT → MRI"
# Mis-encoded MRI→CT label used by the previous version of this file; still
# accepted by translate() so callers that send it keep getting MRI→CT.
_LEGACY_MR2CT_LABEL = "MRI β CT"


def translate(
    image,
    task,
    num_inference_steps=20,
    image_guidance_scale=1.5,
    guidance_scale=7.5,
):
    """Translate an uploaded scan between MRI and CT with InstructPix2Pix.

    Args:
        image: PIL image from the upload widget, or ``None`` when nothing
            has been uploaded.
        task: ``TASK_MR2CT`` (or the legacy label) selects the MRI→CT
            pipeline; any other value selects the CT→MRI pipeline.
        num_inference_steps, image_guidance_scale, guidance_scale: Forwarded
            unchanged to the pipeline call. The defaults are the values this
            function previously hard-coded.

    Returns:
        The first image produced by the selected pipeline.

    Raises:
        gr.Error: If ``image`` is ``None``. Gradio shows this message to the
            user instead of an opaque ``AttributeError``.
    """
    if image is None:
        raise gr.Error("Please upload a CT or MRI image first.")

    # Any label other than MRI→CT falls through to CT→MRI, as before.
    if task in (TASK_MR2CT, _LEGACY_MR2CT_LABEL):
        pipe = pipe_mr2ct
        prompt = "convert MRI scan to CT scan of cervical spine"
    else:
        pipe = pipe_ct2mr
        prompt = "convert CT scan to MRI of cervical spine"

    result = pipe(
        prompt=prompt,
        # Normalise grayscale / RGBA uploads to 3-channel RGB.
        image=image.convert("RGB"),
        num_inference_steps=num_inference_steps,
        image_guidance_scale=image_guidance_scale,
        guidance_scale=guidance_scale,
    ).images[0]
    return result


with gr.Blocks(title="SynSpine AI") as demo:
    gr.Markdown("# SynSpine AI")
    gr.Markdown("AI-based CT ↔ MRI Image Translation")

    # NOTE(review): the original indentation was lost. This layout assumes only
    # the two images share the row, with the task selector and button below.
    with gr.Row():
        input_image = gr.Image(
            type="pil",
            label="Upload CT or MRI Image",
        )
        output_image = gr.Image(
            label="Translated Image",
        )

    task = gr.Radio(
        [TASK_MR2CT, TASK_CT2MR],
        label="Translation Task",
        value=TASK_MR2CT,
    )
    translate_btn = gr.Button("Run Translation")

    translate_btn.click(
        fn=translate,
        inputs=[input_image, task],
        outputs=output_image,
    )

demo.launch()