Hpsoyl commited on
Commit
88153a5
·
1 Parent(s): 98cfd70
Files changed (1) hide show
  1. app.py +52 -2
app.py CHANGED
@@ -100,7 +100,7 @@ CLS_CLASS_NAMES = ['Nucleus', 'Endoplasmic Reticulum', 'Giantin', 'GPP130', 'Lys
100
  CLS_EXAMPLE_IMG_DIR = "example_images_cls"
101
 
102
  # --- Constants for Visualization ---
103
- COLOR_MAPS = ["Grayscale", "Green (GFP)", "Red (RFP)", "Blue (DAPI)", "Magenta", "Cyan", "Yellow", "Fire", "Viridis", "Inferno", "Magma", "Plasma"]
104
 
105
  # --- Helper Functions ---
106
  def sanitize_prompt_for_filename(prompt):
@@ -120,6 +120,22 @@ def generate_colorbar_preview(color_name):
120
 
121
  gradient = np.linspace(0, 1, 256).reshape(1, 256)
122
  rgb = np.zeros((1, 256, 3))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  if color_name == "Green (GFP)": rgb[..., 1] = gradient
125
  elif color_name == "Red (RFP)": rgb[..., 0] = gradient
@@ -155,6 +171,22 @@ def apply_pseudocolor(image_np, color_name="Grayscale"):
155
 
156
  h, w = norm_img.shape
157
  rgb = np.zeros((h, w, 3), dtype=np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  if color_name == "Green (GFP)": rgb[..., 1] = norm_img
160
  elif color_name == "Red (RFP)": rgb[..., 0] = norm_img
@@ -175,6 +207,23 @@ def apply_pseudocolor(image_np, color_name="Grayscale"):
175
 
176
  return Image.fromarray((rgb * 255).astype(np.uint8))
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  def save_temp_tiff(image_np, prefix="output"):
179
  tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".tif", prefix=f"{prefix}_")
180
  if image_np.dtype == np.float16: save_data = image_np.astype(np.float32)
@@ -598,7 +647,8 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
598
 
599
  sr_gal = gr.Gallery(value=sr_examples, label="Examples", columns=6, height="auto")
600
 
601
- sr_model.change(update_sr_prompt, sr_model, sr_prompt)
 
602
  # Run returns both input/output states and displays
603
  sr_btn.click(run_super_resolution, [sr_file, sr_model, sr_prompt, sr_steps, sr_seed, sr_color], [sr_in_disp, sr_out_disp, sr_dl, sr_input_state, sr_raw_state, sr_colorbar])
604
  # Change color updates both displays
 
100
  CLS_EXAMPLE_IMG_DIR = "example_images_cls"
101
 
102
  # --- Constants for Visualization ---
103
+ COLOR_MAPS = ["Grayscale", "Green (GFP)", "Red (RFP)", "Blue (DAPI)", "Magenta", "Cyan", "Yellow", "Fire", "Viridis", "Inferno", "Magma", "Plasma", "Red Hot", "Cyan Hot", "Magenta Hot"]
104
 
105
  # --- Helper Functions ---
106
  def sanitize_prompt_for_filename(prompt):
 
120
 
121
  gradient = np.linspace(0, 1, 256).reshape(1, 256)
122
  rgb = np.zeros((1, 256, 3))
123
+
124
+ if "Hot" in color_name:
125
+ low_half = np.clip(gradient * 2, 0, 1)
126
+ high_half = np.clip((gradient - 0.5) * 2, 0, 1)
127
+ if color_name == "Red Hot":
128
+ rgb[..., 0] = low_half
129
+ rgb[..., 1] = high_half
130
+ rgb[..., 2] = high_half
131
+ elif color_name == "Cyan Hot":
132
+ rgb[..., 0] = high_half
133
+ rgb[..., 1] = low_half
134
+ rgb[..., 2] = low_half
135
+ elif color_name == "Magenta Hot":
136
+ rgb[..., 0] = low_half
137
+ rgb[..., 1] = high_half
138
+ rgb[..., 2] = low_half
139
 
140
  if color_name == "Green (GFP)": rgb[..., 1] = gradient
141
  elif color_name == "Red (RFP)": rgb[..., 0] = gradient
 
171
 
172
  h, w = norm_img.shape
173
  rgb = np.zeros((h, w, 3), dtype=np.float32)
174
+
175
+ if "Hot" in color_name:
176
+ low_half = np.clip(norm_img * 2, 0, 1)
177
+ high_half = np.clip((norm_img - 0.5) * 2, 0, 1)
178
+ if color_name == "Red Hot":
179
+ rgb[..., 0] = low_half
180
+ rgb[..., 1] = high_half
181
+ rgb[..., 2] = high_half
182
+ elif color_name == "Cyan Hot":
183
+ rgb[..., 0] = high_half
184
+ rgb[..., 1] = low_half
185
+ rgb[..., 2] = low_half
186
+ elif color_name == "Magenta Hot":
187
+ rgb[..., 0] = low_half
188
+ rgb[..., 1] = high_half
189
+ rgb[..., 2] = low_half
190
 
191
  if color_name == "Green (GFP)": rgb[..., 1] = norm_img
192
  elif color_name == "Red (RFP)": rgb[..., 0] = norm_img
 
207
 
208
  return Image.fromarray((rgb * 255).astype(np.uint8))
209
 
210
+ def update_sr_settings(model_name):
211
+ """
212
+ - Microtubules -> Cyan Hot
213
+ - F-actin -> Red Hot
214
+ - CCPs -> Green (GFP)
215
+ - ER -> Magenta Hot
216
+ """
217
+ if model_name == "Checkpoint ER":
218
+ return "ER of COS-7", "Magenta Hot"
219
+ if model_name == "Checkpoint Microtubules":
220
+ return "Microtubules of COS-7", "Cyan Hot"
221
+ if model_name == "Checkpoint CCPs":
222
+ return "CCPs of COS-7", "Green (GFP)"
223
+ elif model_name == "Checkpoint F-actin":
224
+ return "F-actin of COS-7", "Red Hot"
225
+ return "", "Grayscale"
226
+
227
  def save_temp_tiff(image_np, prefix="output"):
228
  tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".tif", prefix=f"{prefix}_")
229
  if image_np.dtype == np.float16: save_data = image_np.astype(np.float32)
 
647
 
648
  sr_gal = gr.Gallery(value=sr_examples, label="Examples", columns=6, height="auto")
649
 
650
+ # sr_model.change(update_sr_prompt, sr_model, sr_prompt)
651
+ sr_model.change(update_sr_settings, inputs=sr_model, outputs=[sr_prompt, sr_color])
652
  # Run returns both input/output states and displays
653
  sr_btn.click(run_super_resolution, [sr_file, sr_model, sr_prompt, sr_steps, sr_seed, sr_color], [sr_in_disp, sr_out_disp, sr_dl, sr_input_state, sr_raw_state, sr_colorbar])
654
  # Change color updates both displays