# SynSpine-AI / app.py — Hugging Face Space by icyriss (commit 909a461)
import gradio as gr
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from PIL import Image
print("Loading models...")

# Prefer the GPU when one is present; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _load_pipeline(repo_id):
    """Download an InstructPix2Pix checkpoint and move it onto `device`."""
    pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        repo_id,
        torch_dtype=torch.float32,
    )
    return pipeline.to(device)


# One checkpoint per translation direction.
pipe_mr2ct = _load_pipeline("icyriss/MR2CT-model")   # MR -> CT
pipe_ct2mr = _load_pipeline("icyriss/CT2MRI-model")  # CT -> MRI
print("Models loaded")
def translate(image, task):
    """Translate a cervical-spine scan between MRI and CT.

    Parameters
    ----------
    image : PIL.Image.Image or None
        The uploaded scan. Gradio passes ``None`` when the user clicks
        Run without uploading an image.
    task : str
        One of the Radio choices; "MRI β†’ CT" selects the MR-to-CT
        pipeline, anything else falls through to CT-to-MRI (same
        behavior as the original if/else).

    Returns
    -------
    PIL.Image.Image
        The translated scan.

    Raises
    ------
    ValueError
        If no image was provided.
    """
    if image is None:
        # Fail with a clear message instead of an AttributeError below.
        raise ValueError("Please upload an image first.")
    # The pipelines expect 3-channel RGB input.
    image = image.convert("RGB")

    # Select the direction-specific pipeline and prompt once, then run a
    # single shared invocation (the two calls were previously duplicated).
    if task == "MRI β†’ CT":
        pipe = pipe_mr2ct
        prompt = "convert MRI scan to CT scan of cervical spine"
    else:
        pipe = pipe_ct2mr
        prompt = "convert CT scan to MRI of cervical spine"

    # Shared generation settings: 20 denoising steps, moderate image
    # conditioning (1.5), standard text guidance (7.5).
    result = pipe(
        prompt=prompt,
        image=image,
        num_inference_steps=20,
        image_guidance_scale=1.5,
        guidance_scale=7.5,
    ).images[0]
    return result
# Build the web UI. Gradio lays components out in the order they are
# created inside the `with` context, so statement order defines the page.
with gr.Blocks(title="SynSpine AI") as demo:
    gr.Markdown("# SynSpine AI")
    gr.Markdown("AI-based CT ↔ MRI Image Translation")
    with gr.Row():
        # Left column: user upload; right column: model output.
        input_image = gr.Image(
            type="pil",
            label="Upload CT or MRI Image"
        )
        output_image = gr.Image(
            label="Translated Image"
        )
    # Translation direction; the chosen string is compared verbatim
    # inside translate(), so these literals must match that check.
    task = gr.Radio(
        ["MRI β†’ CT","CT β†’ MRI"],
        label="Translation Task",
        value="MRI β†’ CT"
    )
    translate_btn = gr.Button("Run Translation")
    # Wire the button: (image, task) in, translated image out.
    translate_btn.click(
        fn=translate,
        inputs=[input_image,task],
        outputs=output_image
    )

# Start the Gradio server (blocking call).
demo.launch()