icyriss commited on
Commit
909a461
Β·
verified Β·
1 Parent(s): bfa2007

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -10
app.py CHANGED
@@ -1,14 +1,85 @@
1
-
2
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- def hello(image):
5
- return image
6
 
7
- demo = gr.Interface(
8
- fn=hello,
9
- inputs="image",
10
- outputs="image",
11
- title="SynSpine AI"
12
- )
13
 
14
- demo.launch()
 
 
1
  import gradio as gr
2
+ import torch
3
+ from diffusers import StableDiffusionInstructPix2PixPipeline
4
+ from PIL import Image
5
+
6
+ print("Loading models...")
7
+
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+ # MR β†’ CT
11
+ pipe_mr2ct = StableDiffusionInstructPix2PixPipeline.from_pretrained(
12
+ "icyriss/MR2CT-model",
13
+ torch_dtype=torch.float32
14
+ ).to(device)
15
+
16
+ # CT β†’ MRI
17
+ pipe_ct2mr = StableDiffusionInstructPix2PixPipeline.from_pretrained(
18
+ "icyriss/CT2MRI-model",
19
+ torch_dtype=torch.float32
20
+ ).to(device)
21
+
22
+ print("Models loaded")
23
+
24
+ def translate(image, task):
25
+
26
+ image = image.convert("RGB")
27
+
28
+ if task == "MRI β†’ CT":
29
+
30
+ prompt = "convert MRI scan to CT scan of cervical spine"
31
+
32
+ result = pipe_mr2ct(
33
+ prompt=prompt,
34
+ image=image,
35
+ num_inference_steps=20,
36
+ image_guidance_scale=1.5,
37
+ guidance_scale=7.5
38
+ ).images[0]
39
+
40
+ else:
41
+
42
+ prompt = "convert CT scan to MRI of cervical spine"
43
+
44
+ result = pipe_ct2mr(
45
+ prompt=prompt,
46
+ image=image,
47
+ num_inference_steps=20,
48
+ image_guidance_scale=1.5,
49
+ guidance_scale=7.5
50
+ ).images[0]
51
+
52
+ return result
53
+
54
+
55
+ with gr.Blocks(title="SynSpine AI") as demo:
56
+
57
+ gr.Markdown("# SynSpine AI")
58
+ gr.Markdown("AI-based CT ↔ MRI Image Translation")
59
+
60
+ with gr.Row():
61
+
62
+ input_image = gr.Image(
63
+ type="pil",
64
+ label="Upload CT or MRI Image"
65
+ )
66
+
67
+ output_image = gr.Image(
68
+ label="Translated Image"
69
+ )
70
+
71
+ task = gr.Radio(
72
+ ["MRI β†’ CT","CT β†’ MRI"],
73
+ label="Translation Task",
74
+ value="MRI β†’ CT"
75
+ )
76
 
77
+ translate_btn = gr.Button("Run Translation")
 
78
 
79
+ translate_btn.click(
80
+ fn=translate,
81
+ inputs=[input_image,task],
82
+ outputs=output_image
83
+ )
 
84
 
85
+ demo.launch()