Ishan Kumarasinghe commited on
Commit
956cffa
·
1 Parent(s): 54c16ed

Add Mask generation Model

Browse files
Files changed (3) hide show
  1. app.py +4 -4
  2. models/mask_diffusion.pth +3 -0
  3. requirements.txt +2 -1
app.py CHANGED
@@ -24,7 +24,7 @@ models = {
24
 
25
  def load_mask_model():
26
  if models["mask"] is None:
27
- print("🔄 Loading Mask Model...")
28
  model = DiffusionModelUNet(
29
  spatial_dims=2,
30
  in_channels=4,
@@ -97,7 +97,7 @@ def synthesize_image(mask_input, source_type, model_choice):
97
  # A. Handle Input Source
98
  if source_type == "Upload Mask":
99
  if mask_input is None:
100
- return None, "⚠️ Please upload a mask first."
101
  # Expecting RGB upload, need to convert to integer map?
102
  # Or if your conditional models take RGB, pass raw.
103
  # For safety, let's assume we convert upload to numpy.
@@ -111,7 +111,7 @@ def synthesize_image(mask_input, source_type, model_choice):
111
  else:
112
  # Input comes from the "Generate Mask" step (State variable)
113
  if mask_input is None:
114
- return None, "⚠️ Please generate a mask first."
115
  mask_idx = mask_input
116
 
117
  # B. Run Conditional Inference
@@ -147,7 +147,7 @@ with gr.Blocks(title="Cardiac MRI Synthesis") as demo:
147
 
148
  # Tab 1: Generate
149
  with gr.Group(visible=True) as group_gen:
150
- btn_gen_mask = gr.Button("🎲 Generate Random Mask", variant="primary")
151
  out_gen_mask = gr.Image(label="Generated Mask", type="numpy", interactive=False)
152
  state_mask = gr.State() # Stores the raw integer mask (0-3) hidden from view
153
 
 
24
 
25
  def load_mask_model():
26
  if models["mask"] is None:
27
+ print("Loading Mask Model...")
28
  model = DiffusionModelUNet(
29
  spatial_dims=2,
30
  in_channels=4,
 
97
  # A. Handle Input Source
98
  if source_type == "Upload Mask":
99
  if mask_input is None:
100
+ return None, "Please upload a mask first."
101
  # Expecting RGB upload, need to convert to integer map?
102
  # Or if your conditional models take RGB, pass raw.
103
  # For safety, let's assume we convert upload to numpy.
 
111
  else:
112
  # Input comes from the "Generate Mask" step (State variable)
113
  if mask_input is None:
114
+ return None, "Please generate a mask first."
115
  mask_idx = mask_input
116
 
117
  # B. Run Conditional Inference
 
147
 
148
  # Tab 1: Generate
149
  with gr.Group(visible=True) as group_gen:
150
+ btn_gen_mask = gr.Button("Generate Random Mask", variant="primary")
151
  out_gen_mask = gr.Image(label="Generated Mask", type="numpy", interactive=False)
152
  state_mask = gr.State() # Stores the raw integer mask (0-3) hidden from view
153
 
models/mask_diffusion.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0482056875d7646c42da7219300c8e9bd57ad5f8051b3426c0caf1adf9b85e7
3
+ size 252760728
requirements.txt CHANGED
@@ -8,4 +8,5 @@ pillow
8
  tqdm
9
  gradio
10
  scipy
11
- safetensors
 
 
8
  tqdm
9
  gradio
10
  scipy
11
+ safetensors
12
+ huggingface_hub