saim1309 commited on
Commit
85771fd
·
verified ·
1 Parent(s): cd9e6ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -38
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import gradio as gr
2
  import os
3
- os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Force CPU if needed
4
  import torch
 
 
5
  import numpy as np
6
  from PIL import Image
7
  from PIL import Image as PILImage
@@ -13,31 +14,38 @@ from skimage.color import rgb2gray
13
  from csbdeep.utils import normalize
14
  from stardist.models import StarDist2D
15
  from stardist.plot import render_label
16
- from MEDIARFormer import MEDIARFormer
17
- from Predictor import Predictor
18
  from cellpose import models as cellpose_models, io as cellpose_io, plot as cellpose_plot
 
 
 
 
 
19
 
20
  # Load SegFormer
21
- from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
22
  processor_segformer = SegformerImageProcessor(do_reduce_labels=False)
23
  model_segformer = SegformerForSemanticSegmentation.from_pretrained(
24
  "nvidia/segformer-b0-finetuned-ade-512-512",
25
  num_labels=8,
26
  ignore_mismatched_sizes=True
27
  )
28
- model_segformer.load_state_dict(torch.load("trained_model_200.pt", map_location="cpu"))
 
29
  model_segformer.eval()
30
 
31
- # StarDist model
32
  model_stardist = StarDist2D.from_pretrained('2D_versatile_fluo')
33
 
34
- # Cellpose model
35
- model_cellpose = cellpose_models.CellposeModel(gpu=False)
36
 
37
- # Handle SegFormer prediction
38
  def infer_segformer(image):
39
  image = image.convert("RGB")
40
  inputs = processor_segformer(images=image, return_tensors="pt")
 
 
41
  with torch.no_grad():
42
  logits = model_segformer(**inputs).logits
43
  pred_mask = torch.argmax(logits, dim=1)[0].cpu().numpy()
@@ -49,7 +57,7 @@ def infer_segformer(image):
49
  color_mask[pred_mask == c] = colors[c]
50
  return image, Image.fromarray(color_mask)
51
 
52
- # Handle StarDist prediction
53
  def infer_stardist(image):
54
  image_gray = rgb2gray(np.array(image)) if image.mode == 'RGB' else np.array(image)
55
  labels, _ = model_stardist.predict_instances(normalize(image_gray))
@@ -57,12 +65,11 @@ def infer_stardist(image):
57
  overlay = (overlay[..., :3] * 255).astype(np.uint8)
58
  return image, Image.fromarray(overlay)
59
 
60
- # Handle MEDIAR prediction
61
  def infer_mediar(image, temp_dir="temp_mediar"):
62
  os.makedirs(temp_dir, exist_ok=True)
63
  input_path = os.path.join(temp_dir, "input_image.tiff")
64
  output_path = os.path.join(temp_dir, "input_image_label.tiff")
65
-
66
  image.save(input_path)
67
 
68
  model_args = {
@@ -74,11 +81,12 @@ def infer_mediar(image, temp_dir="temp_mediar"):
74
  }
75
 
76
  model = MEDIARFormer(**model_args)
77
- weights = torch.load("from_phase1.pth", map_location="cpu")
78
  model.load_state_dict(weights, strict=False)
 
79
  model.eval()
80
 
81
- predictor = Predictor(model, "cpu", temp_dir, temp_dir, algo_params={"use_tta": False})
82
  predictor.img_names = ["input_image.tiff"]
83
  _ = predictor.conduct_prediction()
84
 
@@ -93,13 +101,13 @@ def infer_mediar(image, temp_dir="temp_mediar"):
93
  buf.seek(0)
94
 
95
  return image, Image.open(buf)
96
- # Handle Cellpose prediction
 
97
  def infer_cellpose(image, temp_dir="temp_cellpose"):
98
  os.makedirs(temp_dir, exist_ok=True)
99
  input_path = os.path.join(temp_dir, "input_image.tif")
100
  output_overlay = os.path.join(temp_dir, "overlay.png")
101
 
102
- # Save image
103
  image.save(input_path)
104
  img = cellpose_io.imread(input_path)
105
  masks, flows, styles = model_cellpose.eval(img, batch_size=1)
@@ -112,18 +120,12 @@ def infer_cellpose(image, temp_dir="temp_cellpose"):
112
 
113
  return image, Image.open(output_overlay)
114
 
115
- # Wrapper function
116
  def segment(model_name, image):
117
- # Gradio passes a PIL.Image without filename attribute
118
- # Try to check format if available, else skip check
119
- ext = None
120
- if hasattr(image, 'format') and image.format is not None:
121
- ext = image.format.lower()
122
- if model_name == "Cellpose":
123
- # Accept only TIFF images for Cellpose
124
- if ext not in ["tiff", "tif", None]:
125
- return None, f"❌ Cellpose only supports `.tif` or `.tiff` images."
126
- # ...existing code...
127
  if model_name == "SegFormer":
128
  return infer_segformer(image)
129
  elif model_name == "StarDist":
@@ -135,6 +137,7 @@ def segment(model_name, image):
135
  else:
136
  return None, f"❌ Unknown model: {model_name}"
137
 
 
138
  with gr.Blocks(title="Cell Segmentation Explorer") as app:
139
  gr.Markdown("## Cell Segmentation Explorer")
140
  gr.Markdown("Choose a segmentation model, upload an appropriate image, and view the predicted mask.")
@@ -156,7 +159,7 @@ with gr.Blocks(title="Cell Segmentation Explorer") as app:
156
  def handle_submit(model_name, img):
157
  if img is None:
158
  return None
159
- _, result = segment(model_name, img) # Only return the mask (segmentation result)
160
  return result
161
 
162
  submit_btn.click(
@@ -171,18 +174,12 @@ with gr.Blocks(title="Cell Segmentation Explorer") as app:
171
  outputs=[image_input, output_image]
172
  )
173
 
174
- # === SAMPLE IMAGES SECTION ===
175
  gr.Markdown("---")
176
  gr.Markdown("### Sample Images (click to use as input)")
177
 
178
- # Original and resized thumbnails
179
- original_sample_paths = [
180
- "img1.png",
181
- "img2.png",
182
- "img3.png"
183
- ]
184
-
185
  resized_sample_paths = []
 
186
  for idx, p in enumerate(original_sample_paths):
187
  img = PILImage.open(p).resize((128, 128))
188
  temp_path = f"/tmp/sample_resized_{idx}.png"
@@ -192,7 +189,7 @@ with gr.Blocks(title="Cell Segmentation Explorer") as app:
192
  sample_image_components = []
193
  with gr.Row():
194
  for i, img_path in enumerate(resized_sample_paths):
195
- def load_full_image(idx=i): # Capture loop index properly
196
  return PILImage.open(original_sample_paths[idx])
197
 
198
  sample_img = gr.Image(value=img_path, type="pil", interactive=True, show_label=False)
@@ -203,5 +200,4 @@ with gr.Blocks(title="Cell Segmentation Explorer") as app:
203
  )
204
  sample_image_components.append(sample_img)
205
 
206
-
207
  app.launch()
 
1
  import gradio as gr
2
  import os
 
3
  import torch
4
+ import tensorflow as tf
5
+ tf.config.set_visible_devices([], 'GPU')
6
  import numpy as np
7
  from PIL import Image
8
  from PIL import Image as PILImage
 
14
  from csbdeep.utils import normalize
15
  from stardist.models import StarDist2D
16
  from stardist.plot import render_label
17
+ from train_tools.models import MEDIARFormer
18
+ from core.MEDIAR import Predictor
19
  from cellpose import models as cellpose_models, io as cellpose_io, plot as cellpose_plot
20
+ from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
21
+
22
+ # === Setup for GPU or CPU ===
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ print(f"Using device: {device}")
25
 
26
  # Load SegFormer
 
27
  processor_segformer = SegformerImageProcessor(do_reduce_labels=False)
28
  model_segformer = SegformerForSemanticSegmentation.from_pretrained(
29
  "nvidia/segformer-b0-finetuned-ade-512-512",
30
  num_labels=8,
31
  ignore_mismatched_sizes=True
32
  )
33
+ model_segformer.load_state_dict(torch.load("trained_model_200.pt", map_location=device))
34
+ model_segformer.to(device)
35
  model_segformer.eval()
36
 
37
+ # Load StarDist model (CPU-only, no GPU support)
38
  model_stardist = StarDist2D.from_pretrained('2D_versatile_fluo')
39
 
40
+ # Load Cellpose model with GPU if available
41
+ model_cellpose = cellpose_models.CellposeModel(gpu=torch.cuda.is_available())
42
 
43
+ # SegFormer Inference
44
  def infer_segformer(image):
45
  image = image.convert("RGB")
46
  inputs = processor_segformer(images=image, return_tensors="pt")
47
+ inputs = {k: v.to(device) for k, v in inputs.items()}
48
+
49
  with torch.no_grad():
50
  logits = model_segformer(**inputs).logits
51
  pred_mask = torch.argmax(logits, dim=1)[0].cpu().numpy()
 
57
  color_mask[pred_mask == c] = colors[c]
58
  return image, Image.fromarray(color_mask)
59
 
60
+ # StarDist Inference
61
  def infer_stardist(image):
62
  image_gray = rgb2gray(np.array(image)) if image.mode == 'RGB' else np.array(image)
63
  labels, _ = model_stardist.predict_instances(normalize(image_gray))
 
65
  overlay = (overlay[..., :3] * 255).astype(np.uint8)
66
  return image, Image.fromarray(overlay)
67
 
68
+ # MEDIAR Inference
69
  def infer_mediar(image, temp_dir="temp_mediar"):
70
  os.makedirs(temp_dir, exist_ok=True)
71
  input_path = os.path.join(temp_dir, "input_image.tiff")
72
  output_path = os.path.join(temp_dir, "input_image_label.tiff")
 
73
  image.save(input_path)
74
 
75
  model_args = {
 
81
  }
82
 
83
  model = MEDIARFormer(**model_args)
84
+ weights = torch.load("MEDIAR_Weights/from_phase1.pth", map_location=device)
85
  model.load_state_dict(weights, strict=False)
86
+ model.to(device)
87
  model.eval()
88
 
89
+ predictor = Predictor(model, device.type, temp_dir, temp_dir, algo_params={"use_tta": False})
90
  predictor.img_names = ["input_image.tiff"]
91
  _ = predictor.conduct_prediction()
92
 
 
101
  buf.seek(0)
102
 
103
  return image, Image.open(buf)
104
+
105
+ # Cellpose Inference
106
  def infer_cellpose(image, temp_dir="temp_cellpose"):
107
  os.makedirs(temp_dir, exist_ok=True)
108
  input_path = os.path.join(temp_dir, "input_image.tif")
109
  output_overlay = os.path.join(temp_dir, "overlay.png")
110
 
 
111
  image.save(input_path)
112
  img = cellpose_io.imread(input_path)
113
  masks, flows, styles = model_cellpose.eval(img, batch_size=1)
 
120
 
121
  return image, Image.open(output_overlay)
122
 
123
+ # Main segmentation dispatcher
124
  def segment(model_name, image):
125
+ ext = image.format.lower() if hasattr(image, 'format') and image.format else None
126
+ if model_name == "Cellpose" and ext not in ["tif", "tiff", None]:
127
+ return None, f"❌ Cellpose only supports `.tif` or `.tiff` images."
128
+
 
 
 
 
 
 
129
  if model_name == "SegFormer":
130
  return infer_segformer(image)
131
  elif model_name == "StarDist":
 
137
  else:
138
  return None, f"❌ Unknown model: {model_name}"
139
 
140
+ # === Gradio UI ===
141
  with gr.Blocks(title="Cell Segmentation Explorer") as app:
142
  gr.Markdown("## Cell Segmentation Explorer")
143
  gr.Markdown("Choose a segmentation model, upload an appropriate image, and view the predicted mask.")
 
159
  def handle_submit(model_name, img):
160
  if img is None:
161
  return None
162
+ _, result = segment(model_name, img)
163
  return result
164
 
165
  submit_btn.click(
 
174
  outputs=[image_input, output_image]
175
  )
176
 
 
177
  gr.Markdown("---")
178
  gr.Markdown("### Sample Images (click to use as input)")
179
 
180
+ original_sample_paths = ["Sample Images/img1.png", "Sample Images/img2.png", "Sample Images/img3.png"]
 
 
 
 
 
 
181
  resized_sample_paths = []
182
+
183
  for idx, p in enumerate(original_sample_paths):
184
  img = PILImage.open(p).resize((128, 128))
185
  temp_path = f"/tmp/sample_resized_{idx}.png"
 
189
  sample_image_components = []
190
  with gr.Row():
191
  for i, img_path in enumerate(resized_sample_paths):
192
+ def load_full_image(idx=i):
193
  return PILImage.open(original_sample_paths[idx])
194
 
195
  sample_img = gr.Image(value=img_path, type="pil", interactive=True, show_label=False)
 
200
  )
201
  sample_image_components.append(sample_img)
202
 
 
203
  app.launch()