Spaces:
Sleeping
Sleeping
Ishan Kumarasinghe commited on
Commit ·
956cffa
1
Parent(s): 54c16ed
Add Mask generation Model
Browse files- app.py +4 -4
- models/mask_diffusion.pth +3 -0
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -24,7 +24,7 @@ models = {
|
|
| 24 |
|
| 25 |
def load_mask_model():
|
| 26 |
if models["mask"] is None:
|
| 27 |
-
print("
|
| 28 |
model = DiffusionModelUNet(
|
| 29 |
spatial_dims=2,
|
| 30 |
in_channels=4,
|
|
@@ -97,7 +97,7 @@ def synthesize_image(mask_input, source_type, model_choice):
|
|
| 97 |
# A. Handle Input Source
|
| 98 |
if source_type == "Upload Mask":
|
| 99 |
if mask_input is None:
|
| 100 |
-
return None, "
|
| 101 |
# Expecting RGB upload, need to convert to integer map?
|
| 102 |
# Or if your conditional models take RGB, pass raw.
|
| 103 |
# For safety, let's assume we convert upload to numpy.
|
|
@@ -111,7 +111,7 @@ def synthesize_image(mask_input, source_type, model_choice):
|
|
| 111 |
else:
|
| 112 |
# Input comes from the "Generate Mask" step (State variable)
|
| 113 |
if mask_input is None:
|
| 114 |
-
return None, "
|
| 115 |
mask_idx = mask_input
|
| 116 |
|
| 117 |
# B. Run Conditional Inference
|
|
@@ -147,7 +147,7 @@ with gr.Blocks(title="Cardiac MRI Synthesis") as demo:
|
|
| 147 |
|
| 148 |
# Tab 1: Generate
|
| 149 |
with gr.Group(visible=True) as group_gen:
|
| 150 |
-
btn_gen_mask = gr.Button("
|
| 151 |
out_gen_mask = gr.Image(label="Generated Mask", type="numpy", interactive=False)
|
| 152 |
state_mask = gr.State() # Stores the raw integer mask (0-3) hidden from view
|
| 153 |
|
|
|
|
| 24 |
|
| 25 |
def load_mask_model():
|
| 26 |
if models["mask"] is None:
|
| 27 |
+
print("Loading Mask Model...")
|
| 28 |
model = DiffusionModelUNet(
|
| 29 |
spatial_dims=2,
|
| 30 |
in_channels=4,
|
|
|
|
| 97 |
# A. Handle Input Source
|
| 98 |
if source_type == "Upload Mask":
|
| 99 |
if mask_input is None:
|
| 100 |
+
return None, "Please upload a mask first."
|
| 101 |
# Expecting RGB upload, need to convert to integer map?
|
| 102 |
# Or if your conditional models take RGB, pass raw.
|
| 103 |
# For safety, let's assume we convert upload to numpy.
|
|
|
|
| 111 |
else:
|
| 112 |
# Input comes from the "Generate Mask" step (State variable)
|
| 113 |
if mask_input is None:
|
| 114 |
+
return None, "Please generate a mask first."
|
| 115 |
mask_idx = mask_input
|
| 116 |
|
| 117 |
# B. Run Conditional Inference
|
|
|
|
| 147 |
|
| 148 |
# Tab 1: Generate
|
| 149 |
with gr.Group(visible=True) as group_gen:
|
| 150 |
+
btn_gen_mask = gr.Button("Generate Random Mask", variant="primary")
|
| 151 |
out_gen_mask = gr.Image(label="Generated Mask", type="numpy", interactive=False)
|
| 152 |
state_mask = gr.State() # Stores the raw integer mask (0-3) hidden from view
|
| 153 |
|
models/mask_diffusion.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a0482056875d7646c42da7219300c8e9bd57ad5f8051b3426c0caf1adf9b85e7
|
| 3 |
+
size 252760728
|
requirements.txt
CHANGED
|
@@ -8,4 +8,5 @@ pillow
|
|
| 8 |
tqdm
|
| 9 |
gradio
|
| 10 |
scipy
|
| 11 |
-
safetensors
|
|
|
|
|
|
| 8 |
tqdm
|
| 9 |
gradio
|
| 10 |
scipy
|
| 11 |
+
safetensors
|
| 12 |
+
huggingface_hub
|